├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── harmonization.jpg
│   ├── matanyone_logo.png
│   ├── pipeline.jpg
│   ├── teaser.jpg
│   └── teaser_demo.gif
├── hugging_face
│   ├── app.py
│   ├── matanyone_wrapper.py
│   ├── requirements.txt
│   └── tools
│       ├── __init__.py
│       ├── base_segmenter.py
│       ├── download_util.py
│       ├── interact_tools.py
│       ├── mask_painter.py
│       ├── misc.py
│       └── painter.py
├── inference_hf.py
├── inference_matanyone.py
├── inputs
│   ├── mask
│   │   ├── test-sample0_1.png
│   │   ├── test-sample0_2.png
│   │   ├── test-sample1.png
│   │   ├── test-sample2.png
│   │   └── test-sample3.png
│   └── video
│       ├── test-sample0
│       │   ├── 0000.jpg
│       │   ├── 0001.jpg
│       │   ├── 0002.jpg
│       │   ├── 0003.jpg
│       │   ├── 0004.jpg
│       │   ├── 0005.jpg
│       │   ├── 0006.jpg
│       │   ├── 0007.jpg
│       │   ├── 0008.jpg
│       │   ├── 0009.jpg
│       │   ├── 0010.jpg
│       │   ├── 0011.jpg
│       │   ├── 0012.jpg
│       │   ├── 0013.jpg
│       │   ├── 0014.jpg
│       │   ├── 0015.jpg
│       │   ├── 0016.jpg
│       │   ├── 0017.jpg
│       │   ├── 0018.jpg
│       │   ├── 0019.jpg
│       │   ├── 0020.jpg
│       │   ├── 0021.jpg
│       │   ├── 0022.jpg
│       │   ├── 0023.jpg
│       │   ├── 0024.jpg
│       │   ├── 0025.jpg
│       │   ├── 0026.jpg
│       │   ├── 0027.jpg
│       │   ├── 0028.jpg
│       │   ├── 0029.jpg
│       │   ├── 0030.jpg
│       │   ├── 0031.jpg
│       │   ├── 0032.jpg
│       │   ├── 0033.jpg
│       │   ├── 0034.jpg
│       │   ├── 0035.jpg
│       │   ├── 0036.jpg
│       │   ├── 0037.jpg
│       │   ├── 0038.jpg
│       │   ├── 0039.jpg
│       │   ├── 0040.jpg
│       │   ├── 0041.jpg
│       │   ├── 0042.jpg
│       │   ├── 0043.jpg
│       │   ├── 0044.jpg
│       │   ├── 0045.jpg
│       │   ├── 0046.jpg
│       │   ├── 0047.jpg
│       │   ├── 0048.jpg
│       │   ├── 0049.jpg
│       │   ├── 0050.jpg
│       │   ├── 0051.jpg
│       │   ├── 0052.jpg
│       │   ├── 0053.jpg
│       │   ├── 0054.jpg
│       │   ├── 0055.jpg
│       │   ├── 0056.jpg
│       │   ├── 0057.jpg
│       │   ├── 0058.jpg
│       │   ├── 0059.jpg
│       │   ├── 0060.jpg
│       │   ├── 0061.jpg
│       │   ├── 0062.jpg
│       │   ├── 0063.jpg
│       │   ├── 0064.jpg
│       │   ├── 0065.jpg
│       │   ├── 0066.jpg
│       │   ├── 0067.jpg
│       │   ├── 0068.jpg
│       │   ├── 0069.jpg
│       │   ├── 0070.jpg
│       │   └── 0071.jpg
│       ├── test-sample1.mp4
│       ├── test-sample2.mp4
│       └── test-sample3.mp4
├── matanyone
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── eval_matanyone_config.yaml
│   │   ├── hydra
│   │   │   └── job_logging
│   │   │       ├── custom-no-rank.yaml
│   │   │       └── custom.yaml
│   │   └── model
│   │       └── base.yaml
│   ├── inference
│   │   ├── __init__.py
│   │   ├── image_feature_store.py
│   │   ├── inference_core.py
│   │   ├── kv_memory_store.py
│   │   ├── memory_manager.py
│   │   ├── object_info.py
│   │   ├── object_manager.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── args_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── aux_modules.py
│   │   ├── big_modules.py
│   │   ├── channel_attn.py
│   │   ├── group_modules.py
│   │   ├── matanyone.py
│   │   ├── modules.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── object_summarizer.py
│   │   │   ├── object_transformer.py
│   │   │   ├── positional_encoding.py
│   │   │   └── transformer_layers.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── memory_utils.py
│   │       ├── parameter_groups.py
│   │       └── resnet.py
│   └── utils
│       ├── __init__.py
│       ├── get_default_model.py
│       ├── inference_utils.py
│       └── tensor_utils.py
└── pyproject.toml
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .vscode/
3 | .DS_Store
4 | hugging_face/assets/
5 | results/
6 | test_sample/
7 | pretrained_models/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # S-Lab License 1.0
2 |
3 | Copyright 2023 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
36 |
37 |
38 | ---
39 | For inquiries permission for commercial use, please consult our team:
40 | Peiqing Yang (peiqingyang99@outlook.com),
41 | Dr. Shangchen Zhou (shangchenzhou@gmail.com),
42 | Prof. Chen Change Loy (ccloy@ntu.edu.sg)
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![MatAnyone logo](assets/matanyone_logo.png)
2 |
3 | # MatAnyone: Stable Video Matting with Consistent Memory Propagation
4 |
5 | <sup>1</sup>S-Lab, Nanyang Technological University
6 | <sup>2</sup>SenseTime Research, Singapore
7 |
8 | MatAnyone is a practical human video matting framework supporting target assignment, with stable performance in both semantics of core regions and fine-grained boundary details.
9 |
10 | ![Teaser demo](assets/teaser_demo.gif)
11 |
12 | :movie_camera: For more visual results, check out our project page.
13 |
46 | ---
47 |
48 |
49 |
50 | ## 📮 Update
51 | - [2025.03] Release our evaluation benchmark - [YouTubeMatte](https://github.com/pq-yang/MatAnyone?tab=readme-ov-file#-evaluation-benchmark).
52 | - [2025.03] Integrate MatAnyone with Hugging Face 🤗
53 | - [2025.02] Release inference code and Gradio demo.
54 | - [2025.02] This repo is created.
55 |
56 | ## 🔎 Overview
57 | ![MatAnyone pipeline](assets/pipeline.jpg)
58 |
59 | ## 🔧 Installation
60 | 1. Clone Repo
61 | ```bash
62 | git clone https://github.com/pq-yang/MatAnyone
63 | cd MatAnyone
64 | ```
65 |
66 | 2. Create Conda Environment and Install Dependencies
67 | ```bash
68 | # create new conda env
69 | conda create -n matanyone python=3.8 -y
70 | conda activate matanyone
71 |
72 | # install python dependencies
73 | pip install -e .
74 | # [optional] install python dependencies for gradio demo
75 | pip3 install -r hugging_face/requirements.txt
76 | ```
77 |
78 | ## 🤗 Load from Hugging Face
79 | Alternatively, the model can be loaded directly from [Hugging Face](https://huggingface.co/PeiqingYang/MatAnyone) to run inference.
80 |
81 | ```shell
82 | pip install -q git+https://github.com/pq-yang/MatAnyone
83 | ```
84 |
85 | To extract the foreground and alpha videos, you can directly run the following lines. Please refer to [inference_hf.py](https://github.com/pq-yang/MatAnyone/blob/main/inference_hf.py) for more arguments.
86 | ```python
87 | from matanyone import InferenceCore
88 | processor = InferenceCore("PeiqingYang/MatAnyone")
89 |
90 | foreground_path, alpha_path = processor.process_video(
91 | input_path = "inputs/video/test-sample1.mp4",
92 | mask_path = "inputs/mask/test-sample1.png",
93 | output_path = "outputs"
94 | )
95 | ```
96 |
97 | ## 🔥 Inference
98 |
99 | ### Download Model
100 | Download our pretrained model from [MatAnyone v1.0.0](https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth) to the `pretrained_models` folder (the pretrained model can also be downloaded automatically during the first inference).
101 |
102 | The directory structure will be arranged as:
103 | ```
104 | pretrained_models
105 | |- matanyone.pth
106 | ```
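For example, to fetch the checkpoint from the command line (assuming `wget` is available; the URL is the release asset linked above):

```shell
wget -P pretrained_models https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth
```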
107 |
108 | ### Quick Test
109 | We provide some examples in the [`inputs`](./inputs) folder. **For each run, we take a video and its first-frame segmentation mask as input.** The segmentation mask can be obtained from interactive segmentation models such as the [SAM2 demo](https://huggingface.co/spaces/fffiloni/SAM2-Image-Predictor). For example, the directory structure can be arranged as:
110 | ```
111 | inputs
112 | |- video
113 | |- test-sample0 # folder containing all frames
114 | |- test-sample1.mp4 # .mp4, .mov, .avi
115 | |- mask
116 | |- test-sample0_1.png # mask for person 1
117 | |- test-sample0_2.png # mask for person 2
118 | |- test-sample1.png
119 | ```
120 | Run the following command to try it out:
121 |
122 | ```shell
123 | ## single target
124 | # short video; 720p
125 | python inference_matanyone.py -i inputs/video/test-sample1.mp4 -m inputs/mask/test-sample1.png
126 | # short video; 1080p
127 | python inference_matanyone.py -i inputs/video/test-sample2.mp4 -m inputs/mask/test-sample2.png
128 | # long video; 1080p
129 | python inference_matanyone.py -i inputs/video/test-sample3.mp4 -m inputs/mask/test-sample3.png
130 |
131 | ## multiple targets (control by mask)
132 | # obtain matte for target 1
133 | python inference_matanyone.py -i inputs/video/test-sample0 -m inputs/mask/test-sample0_1.png --suffix target1
134 | # obtain matte for target 2
135 | python inference_matanyone.py -i inputs/video/test-sample0 -m inputs/mask/test-sample0_2.png --suffix target2
136 | ```
137 | The results will be saved in the `results` folder, including the foreground output video and the alpha output video.
138 | - If you want to save the results as per-frame images, you can set `--save_image`.
139 | - If you want to limit the maximum input resolution, you can set `--max_size`; the video will be downsampled if min(w, h) exceeds this value. By default, no limit is set. See the example below.
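For example, both options can be combined on one of the provided samples (the values here are illustrative):

```shell
# save per-frame images and downsample so that min(w, h) is at most 1080
python inference_matanyone.py -i inputs/video/test-sample2.mp4 -m inputs/mask/test-sample2.png --save_image --max_size 1080
```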
140 |
141 | ## 🎪 Interactive Demo
142 | To get rid of the preparation for the first-frame segmentation mask, we provide a Gradio demo on [Hugging Face](https://huggingface.co/spaces/PeiqingYang/MatAnyone), which can also be **launched locally**. Just drop in your video/image, assign the target masks with a few clicks, and get the matting results!
143 | ```shell
144 | cd hugging_face
145 |
146 | # install python dependencies
147 | pip3 install -r requirements.txt # FFmpeg required
148 |
149 | # launch the demo
150 | python app.py
151 | ```
152 |
153 | After launching, an interactive interface will appear as follows:
154 |
155 | <!-- screenshot of the interactive demo interface -->
156 |
157 |
158 | ## 📊 Evaluation Benchmark
159 |
160 | We provide a synthetic benchmark **[YouTubeMatte](https://drive.google.com/file/d/1IEH0RaimT_hSp38AWF6wuwNJzzNSHpJ4/view?usp=sharing)** to extend the commonly used [VideoMatte240K-Test](https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/training.md#evaluation). A comparison between them is summarized in the table below.
161 |
162 | | Dataset | #Foregrounds | Source | Harmonized |
163 | | :------------------ | :----------: | :----------------: | :--------: |
164 | | VideoMatte240K-Test | 5 | Purchased Footage | ❌ |
165 | | **YouTubeMatte** | **32** | **YouTube Videos** | ✅ |
166 |
167 | It is noteworthy that we applied **harmonization** (using [Harmonizer](https://github.com/ZHKKKe/Harmonizer)) when compositing the foreground onto a background. This operation effectively makes YouTubeMatte a more *challenging* benchmark that is closer to the *real* distribution. As shown in the figure below, while [RVM](https://github.com/PeterL1n/RobustVideoMatting) is confused by the harmonized frame, our method still yields robust performance.
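For reference, each benchmark frame is composited with the standard alpha-matting equation, using the harmonized foreground in place of the raw one (notation ours, not from the original README):

$$
I = \alpha \hat{F} + (1 - \alpha) B,
$$

where $\alpha$ is the ground-truth alpha matte, $\hat{F}$ is the Harmonizer-adjusted foreground, and $B$ is the background plate.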
168 |
169 | ![Comparison on a harmonized frame](assets/harmonization.jpg)
170 |
171 |
172 | ## 📑 Citation
173 |
174 | If you find our repo useful for your research, please consider citing our paper:
175 |
176 | ```bibtex
177 | @inProceedings{yang2025matanyone,
178 | title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
179 | author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
180 | booktitle = {CVPR},
181 | year = {2025}
182 | }
183 | ```
184 |
185 | ## 📝 License
186 |
187 | This project is licensed under the NTU S-Lab License 1.0. Redistribution and use should follow the terms of this license.
188 |
189 | ## 👏 Acknowledgement
190 |
191 | This project is built upon [Cutie](https://github.com/hkchengrex/Cutie), with the interactive demo adapted from [ProPainter](https://github.com/sczhou/ProPainter), leveraging segmentation capabilities from [Segment Anything Model](https://github.com/facebookresearch/segment-anything) and [Segment Anything Model 2](https://github.com/facebookresearch/sam2). Thanks for their awesome work!
192 |
193 | ---
194 |
195 | This study is supported under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s).
196 |
197 | ## 📧 Contact
198 |
199 | If you have any questions, please feel free to reach us at `peiqingyang99@outlook.com`.
200 |
--------------------------------------------------------------------------------
/assets/harmonization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/harmonization.jpg
--------------------------------------------------------------------------------
/assets/matanyone_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/matanyone_logo.png
--------------------------------------------------------------------------------
/assets/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/pipeline.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/teaser.jpg
--------------------------------------------------------------------------------
/assets/teaser_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/teaser_demo.gif
--------------------------------------------------------------------------------
/hugging_face/matanyone_wrapper.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import torch
3 | from torchvision.transforms.functional import to_tensor
4 | import numpy as np
5 | import random
6 | import cv2
7 |
8 | def gen_dilate(alpha, min_kernel_size, max_kernel_size):
9 | kernel_size = random.randint(min_kernel_size, max_kernel_size)
10 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
11 | fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
12 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
13 | return dilate.astype(np.float32)
14 |
15 | def gen_erosion(alpha, min_kernel_size, max_kernel_size):
16 | kernel_size = random.randint(min_kernel_size, max_kernel_size)
17 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
18 | fg = np.array(np.equal(alpha, 255).astype(np.float32))
19 | erode = cv2.erode(fg, kernel, iterations=1)*255
20 | return erode.astype(np.float32)
21 |
22 | @torch.inference_mode()
23 | @torch.amp.autocast("cuda")
24 | def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
25 | """
26 | Args:
27 | frames_np: [(H,W,C)]*n, uint8
28 | mask: (H,W), uint8
29 | Outputs:
30 | com: [(H,W,C)]*n, uint8
31 | pha: [(H,W,C)]*n, uint8
32 | """
33 |
34 | # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====')
35 | bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3))
36 | objects = [1]
37 |
38 | # [optional] erode & dilate on given seg mask
39 | if r_dilate > 0:
40 | mask = gen_dilate(mask, r_dilate, r_dilate)
41 | if r_erode > 0:
42 | mask = gen_erosion(mask, r_erode, r_erode)
43 |
44 | mask = torch.from_numpy(mask).cuda()
45 |
46 | frames_np = [frames_np[0]]* n_warmup + frames_np
47 |
48 | frames = []
49 | phas = []
50 | for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
51 | image = to_tensor(frame_single).cuda().float()
52 |
53 | if ti == 0:
54 | output_prob = processor.step(image, mask, objects=objects) # encode given mask
55 | output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
56 | else:
57 | if ti <= n_warmup:
58 | output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
59 | else:
60 | output_prob = processor.step(image)
61 |
62 | # convert output probabilities to an object mask
63 | mask = processor.output_prob_to_mask(output_prob)
64 |
65 | pha = mask.unsqueeze(2).cpu().numpy()
66 | com_np = frame_single / 255. * pha + bgr * (1 - pha)
67 |
68 | # DO NOT save the warmup frames
69 | if ti > (n_warmup-1):
70 | frames.append((com_np*255).astype(np.uint8))
71 | phas.append((pha*255).astype(np.uint8))
72 |
73 | return frames, phas
--------------------------------------------------------------------------------
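Usage note (editor's sketch): `matanyone()` above expects an `InferenceCore` processor, a list of `(H, W, 3)` uint8 RGB frames, and an `(H, W)` uint8 first-frame mask, and it requires a CUDA device. A minimal way to drive it, assuming the repository root is on the import path and the checkpoint has already been downloaded (paths are illustrative):

```python
import cv2

from matanyone.utils.get_default_model import get_matanyone_model
from matanyone.inference.inference_core import InferenceCore
from hugging_face.matanyone_wrapper import matanyone

# build the model and the inference processor (same pattern as inference_matanyone.py)
model = get_matanyone_model("pretrained_models/matanyone.pth")
processor = InferenceCore(model, cfg=model.cfg)

# read video frames as a list of (H, W, 3) uint8 RGB arrays
cap = cv2.VideoCapture("inputs/video/test-sample1.mp4")
frames = []
while True:
    ok, frame = cap.read()
    if not ok:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

# first-frame segmentation mask as an (H, W) uint8 array
mask = cv2.imread("inputs/mask/test-sample1.png", cv2.IMREAD_GRAYSCALE)

# returns per-frame green-screen composites and alpha mattes (uint8)
composites, alphas = matanyone(processor, frames, mask, r_erode=10, r_dilate=10)
```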
/hugging_face/requirements.txt:
--------------------------------------------------------------------------------
1 | progressbar2
2 | gdown >= 4.7.1
3 | gitpython >= 3.1
4 | git+https://github.com/cheind/py-thin-plate-spline
5 | hickle >= 5.0
6 | tensorboard >= 2.11
7 | numpy >= 1.21
8 | git+https://github.com/facebookresearch/segment-anything.git
9 | gradio==4.31.0
10 | fastapi==0.111.0
11 | pydantic==2.7.1
12 | opencv-python >= 4.8
13 | matplotlib
14 | pyyaml
15 | av >= 0.5.2
16 | openmim
17 | tqdm >= 4.66.1
18 | psutil
19 | ffmpeg-python
20 | cython
21 | Pillow >= 9.5
22 | scipy >= 1.7
23 | pycocotools >= 2.0.7
24 | einops >= 0.6
25 | hydra-core >= 1.3.2
26 | PySide6 >= 6.2.0
27 | charset-normalizer >= 3.1.0
28 | netifaces >= 0.11.0
29 | cchardet >= 2.1.7
30 | easydict
31 | requests
32 | pyqtdarktheme
33 | imageio == 2.25.0
34 | imageio[ffmpeg]
35 | ffmpeg-python
--------------------------------------------------------------------------------
/hugging_face/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/hugging_face/tools/__init__.py
--------------------------------------------------------------------------------
/hugging_face/tools/base_segmenter.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import cv2
4 | from PIL import Image, ImageDraw, ImageOps
5 | import numpy as np
6 | from typing import Union
7 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8 | import matplotlib.pyplot as plt
9 | import PIL
10 | from .mask_painter import mask_painter
11 |
12 |
13 | class BaseSegmenter:
14 | def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15 | """
16 | device: model device
17 | SAM_checkpoint: path of SAM checkpoint
18 | model_type: vit_b, vit_l, vit_h
19 | """
20 | print(f"Initializing BaseSegmenter to {device}")
21 | assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22 |
23 | self.device = device
24 | self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25 | self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
26 | self.model.to(device=self.device)
27 | self.predictor = SamPredictor(self.model)
28 | self.embedded = False
29 |
30 | @torch.no_grad()
31 | def set_image(self, image: np.ndarray):
32 | # PIL.open(image_path) 3channel: RGB
33 | # image embedding: avoid encode the same image multiple times
34 | self.orignal_image = image
35 | if self.embedded:
36 | print('repeat embedding, please reset_image.')
37 | return
38 | self.predictor.set_image(image)
39 | self.embedded = True
40 | return
41 |
42 | @torch.no_grad()
43 | def reset_image(self):
44 | # reset image embedding
45 | self.predictor.reset_image()
46 | self.embedded = False
47 |
48 | def predict(self, prompts, mode, multimask=True):
49 | """
50 | image: numpy array, h, w, 3
51 | prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
52 | prompts['point_coords']: numpy array [N,2]
53 | prompts['point_labels']: numpy array [1,N]
54 | prompts['mask_input']: numpy array [1,256,256]
55 | mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
56 | multimask: True (return 3 masks), False (return 1 mask only)
57 | when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
58 | """
59 | assert self.embedded, 'prediction is called before set_image (feature embedding).'
60 | assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
61 |
62 | if mode == 'point':
63 | masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
64 | point_labels=prompts['point_labels'],
65 | multimask_output=multimask)
66 | elif mode == 'mask':
67 | masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
68 | multimask_output=multimask)
69 | elif mode == 'both': # both
70 | masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
71 | point_labels=prompts['point_labels'],
72 | mask_input=prompts['mask_input'],
73 | multimask_output=multimask)
74 | else:
75 | raise NotImplementedError("mode must be 'point', 'mask', or 'both'")
76 | # masks (n, h, w), scores (n,), logits (n, 256, 256)
77 | return masks, scores, logits
78 |
79 |
80 | if __name__ == "__main__":
81 | # load and show an image
82 | image = cv2.imread('/hhd3/gaoshang/truck.jpg')
83 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
84 |
85 | # initialise BaseSegmenter
86 | SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
87 | model_type = 'vit_h'
88 | device = "cuda:4"
89 | base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
90 |
91 | # image embedding (once embedded, multiple prompts can be applied)
92 | base_segmenter.set_image(image)
93 |
94 | # examples
95 | # point only ------------------------
96 | mode = 'point'
97 | prompts = {
98 | 'point_coords': np.array([[500, 375], [1125, 625]]),
99 | 'point_labels': np.array([1, 1]),
100 | }
101 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
102 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
103 | painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
104 | cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
105 |
106 | # both ------------------------
107 | mode = 'both'
108 | mask_input = logits[np.argmax(scores), :, :]
109 | prompts = {'mask_input': mask_input [None, :, :]}
110 | prompts = {
111 | 'point_coords': np.array([[500, 375], [1125, 625]]),
112 | 'point_labels': np.array([1, 0]),
113 | 'mask_input': mask_input[None, :, :]
114 | }
115 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
116 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
117 | painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
118 | cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
119 |
120 | # mask only ------------------------
121 | mode = 'mask'
122 | mask_input = logits[np.argmax(scores), :, :]
123 |
124 | prompts = {'mask_input': mask_input[None, :, :]}
125 |
126 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
127 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
128 | painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
129 | cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
130 |
--------------------------------------------------------------------------------
/hugging_face/tools/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import requests
4 | from torch.hub import download_url_to_file, get_dir
5 | from tqdm import tqdm
6 | from urllib.parse import urlparse
7 |
8 | def sizeof_fmt(size, suffix='B'):
9 | """Get human readable file size.
10 |
11 | Args:
12 | size (int): File size.
13 | suffix (str): Suffix. Default: 'B'.
14 |
15 | Return:
16 | str: Formated file siz.
17 | """
18 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
19 | if abs(size) < 1024.0:
20 | return f'{size:3.1f} {unit}{suffix}'
21 | size /= 1024.0
22 | return f'{size:3.1f} Y{suffix}'
23 |
24 |
25 | def download_file_from_google_drive(file_id, save_path):
26 | """Download files from google drive.
27 | Ref:
28 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
29 | Args:
30 | file_id (str): File id.
31 | save_path (str): Save path.
32 | """
33 |
34 | session = requests.Session()
35 | URL = 'https://docs.google.com/uc?export=download'
36 | params = {'id': file_id}
37 |
38 | response = session.get(URL, params=params, stream=True)
39 | token = get_confirm_token(response)
40 | if token:
41 | params['confirm'] = token
42 | response = session.get(URL, params=params, stream=True)
43 |
44 | # get file size
45 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
46 | print(response_file_size)
47 | if 'Content-Range' in response_file_size.headers:
48 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
49 | else:
50 | file_size = None
51 |
52 | save_response_content(response, save_path, file_size)
53 |
54 |
55 | def get_confirm_token(response):
56 | for key, value in response.cookies.items():
57 | if key.startswith('download_warning'):
58 | return value
59 | return None
60 |
61 |
62 | def save_response_content(response, destination, file_size=None, chunk_size=32768):
63 | if file_size is not None:
64 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
65 |
66 | readable_file_size = sizeof_fmt(file_size)
67 | else:
68 | pbar = None
69 |
70 | with open(destination, 'wb') as f:
71 | downloaded_size = 0
72 | for chunk in response.iter_content(chunk_size):
73 | downloaded_size += chunk_size
74 | if pbar is not None:
75 | pbar.update(1)
76 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
77 | if chunk: # filter out keep-alive new chunks
78 | f.write(chunk)
79 | if pbar is not None:
80 | pbar.close()
81 |
82 |
83 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
84 | """Load file form http url, will download models if necessary.
85 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
86 | Args:
87 | url (str): URL to be downloaded.
88 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
89 | Default: None.
90 | progress (bool): Whether to show the download progress. Default: True.
91 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
92 | Returns:
93 | str: The path to the downloaded file.
94 | """
95 | if model_dir is None: # use the pytorch hub_dir
96 | hub_dir = get_dir()
97 | model_dir = os.path.join(hub_dir, 'checkpoints')
98 |
99 | os.makedirs(model_dir, exist_ok=True)
100 |
101 | parts = urlparse(url)
102 | filename = os.path.basename(parts.path)
103 | if file_name is not None:
104 | filename = file_name
105 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
106 | if not os.path.exists(cached_file):
107 | print(f'Downloading: "{url}" to {cached_file}\n')
108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109 | return cached_file
--------------------------------------------------------------------------------
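Usage note (editor's sketch): `load_file_from_url` is how `inference_matanyone.py` fetches the checkpoint on first run; it returns the cached path and skips the download if the file already exists:

```python
from hugging_face.tools.download_util import load_file_from_url

url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
ckpt_path = load_file_from_url(url, model_dir="pretrained_models")  # downloads only if missing
print(ckpt_path)  # absolute path to pretrained_models/matanyone.pth
```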
/hugging_face/tools/interact_tools.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import cv2
4 | from PIL import Image, ImageDraw, ImageOps
5 | import numpy as np
6 | from typing import Union
7 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8 | import matplotlib.pyplot as plt
9 | import PIL
10 | from .mask_painter import mask_painter as mask_painter2
11 | from .base_segmenter import BaseSegmenter
12 | from .painter import mask_painter, point_painter
13 | import os
14 | import requests
15 | import sys
16 |
17 |
18 | mask_color = 3
19 | mask_alpha = 0.7
20 | contour_color = 1
21 | contour_width = 5
22 | point_color_ne = 8
23 | point_color_ps = 50
24 | point_alpha = 0.9
25 | point_radius = 15
26 | contour_color = 2
27 | contour_width = 5
28 |
29 |
30 | class SamControler():
31 | def __init__(self, SAM_checkpoint, model_type, device):
32 | '''
33 | initialize sam controler
34 | '''
35 | self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36 |
37 |
38 | # def seg_again(self, image: np.ndarray):
39 | # '''
40 | # it is used when interact in video
41 | # '''
42 | # self.sam_controler.reset_image()
43 | # self.sam_controler.set_image(image)
44 | # return
45 |
46 |
47 | def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
48 | '''
49 | it is used in first frame in video
50 | return: mask, logit, painted image(mask+point)
51 | '''
52 | # self.sam_controler.set_image(image)
53 | origal_image = self.sam_controler.orignal_image
54 | neg_flag = labels[-1]
55 | if neg_flag==1:
56 | #find neg
57 | prompts = {
58 | 'point_coords': points,
59 | 'point_labels': labels,
60 | }
61 | masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
62 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
63 | prompts = {
64 | 'point_coords': points,
65 | 'point_labels': labels,
66 | 'mask_input': logit[None, :, :]
67 | }
68 | masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
69 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
70 | else:
71 | #find positive
72 | prompts = {
73 | 'point_coords': points,
74 | 'point_labels': labels,
75 | }
76 | masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
77 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78 |
79 |
80 | assert len(points)==len(labels)
81 |
82 | painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
83 | painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
84 | painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
85 | painted_image = Image.fromarray(painted_image)
86 |
87 | return mask, logit, painted_image
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
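Usage note (editor's sketch): `SamControler` wraps `BaseSegmenter`, so the frame must be embedded with `set_image` before `first_frame_click` is called. A minimal click-to-mask flow, assuming the repository root is on the import path and a SAM checkpoint is available locally (paths and click coordinates below are placeholders):

```python
import cv2
import numpy as np

from hugging_face.tools.interact_tools import SamControler

controler = SamControler("pretrained_models/sam_vit_h_4b8939.pth", "vit_h", "cuda:0")

# embed the first frame once; later prompts reuse this embedding
image = cv2.cvtColor(cv2.imread("inputs/first_frame.jpg"), cv2.COLOR_BGR2RGB)
controler.sam_controler.set_image(image)

# one positive click at (x, y); label 1 = foreground, 0 = background
points = np.array([[500, 375]])
labels = np.array([1])
mask, logit, painted = controler.first_frame_click(image, points, labels, multimask=True)
painted.save("first_frame_mask_preview.png")  # painted is a PIL.Image
```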
/hugging_face/tools/mask_painter.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import copy
6 | import time
7 |
8 |
9 | def colormap(rgb=True):
10 | color_list = np.array(
11 | [
12 | 0.000, 0.000, 0.000,
13 | 1.000, 1.000, 1.000,
14 | 1.000, 0.498, 0.313,
15 | 0.392, 0.581, 0.929,
16 | 0.000, 0.447, 0.741,
17 | 0.850, 0.325, 0.098,
18 | 0.929, 0.694, 0.125,
19 | 0.494, 0.184, 0.556,
20 | 0.466, 0.674, 0.188,
21 | 0.301, 0.745, 0.933,
22 | 0.635, 0.078, 0.184,
23 | 0.300, 0.300, 0.300,
24 | 0.600, 0.600, 0.600,
25 | 1.000, 0.000, 0.000,
26 | 1.000, 0.500, 0.000,
27 | 0.749, 0.749, 0.000,
28 | 0.000, 1.000, 0.000,
29 | 0.000, 0.000, 1.000,
30 | 0.667, 0.000, 1.000,
31 | 0.333, 0.333, 0.000,
32 | 0.333, 0.667, 0.000,
33 | 0.333, 1.000, 0.000,
34 | 0.667, 0.333, 0.000,
35 | 0.667, 0.667, 0.000,
36 | 0.667, 1.000, 0.000,
37 | 1.000, 0.333, 0.000,
38 | 1.000, 0.667, 0.000,
39 | 1.000, 1.000, 0.000,
40 | 0.000, 0.333, 0.500,
41 | 0.000, 0.667, 0.500,
42 | 0.000, 1.000, 0.500,
43 | 0.333, 0.000, 0.500,
44 | 0.333, 0.333, 0.500,
45 | 0.333, 0.667, 0.500,
46 | 0.333, 1.000, 0.500,
47 | 0.667, 0.000, 0.500,
48 | 0.667, 0.333, 0.500,
49 | 0.667, 0.667, 0.500,
50 | 0.667, 1.000, 0.500,
51 | 1.000, 0.000, 0.500,
52 | 1.000, 0.333, 0.500,
53 | 1.000, 0.667, 0.500,
54 | 1.000, 1.000, 0.500,
55 | 0.000, 0.333, 1.000,
56 | 0.000, 0.667, 1.000,
57 | 0.000, 1.000, 1.000,
58 | 0.333, 0.000, 1.000,
59 | 0.333, 0.333, 1.000,
60 | 0.333, 0.667, 1.000,
61 | 0.333, 1.000, 1.000,
62 | 0.667, 0.000, 1.000,
63 | 0.667, 0.333, 1.000,
64 | 0.667, 0.667, 1.000,
65 | 0.667, 1.000, 1.000,
66 | 1.000, 0.000, 1.000,
67 | 1.000, 0.333, 1.000,
68 | 1.000, 0.667, 1.000,
69 | 0.167, 0.000, 0.000,
70 | 0.333, 0.000, 0.000,
71 | 0.500, 0.000, 0.000,
72 | 0.667, 0.000, 0.000,
73 | 0.833, 0.000, 0.000,
74 | 1.000, 0.000, 0.000,
75 | 0.000, 0.167, 0.000,
76 | 0.000, 0.333, 0.000,
77 | 0.000, 0.500, 0.000,
78 | 0.000, 0.667, 0.000,
79 | 0.000, 0.833, 0.000,
80 | 0.000, 1.000, 0.000,
81 | 0.000, 0.000, 0.167,
82 | 0.000, 0.000, 0.333,
83 | 0.000, 0.000, 0.500,
84 | 0.000, 0.000, 0.667,
85 | 0.000, 0.000, 0.833,
86 | 0.000, 0.000, 1.000,
87 | 0.143, 0.143, 0.143,
88 | 0.286, 0.286, 0.286,
89 | 0.429, 0.429, 0.429,
90 | 0.571, 0.571, 0.571,
91 | 0.714, 0.714, 0.714,
92 | 0.857, 0.857, 0.857
93 | ]
94 | ).astype(np.float32)
95 | color_list = color_list.reshape((-1, 3)) * 255
96 | if not rgb:
97 | color_list = color_list[:, ::-1]
98 | return color_list
99 |
100 |
101 | color_list = colormap()
102 | color_list = color_list.astype('uint8').tolist()
103 |
104 |
105 | def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106 | background_color = np.array(background_color)
107 | contour_color = np.array(contour_color)
108 |
109 | # background_mask = 1 - background_mask
110 | # contour_mask = 1 - contour_mask
111 |
112 | for i in range(3):
113 | image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114 | + background_color[i] * (background_alpha-background_mask*background_alpha)
115 |
116 | image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117 | + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118 |
119 | return image.astype('uint8')
120 |
121 |
122 | def mask_generator_00(mask, background_radius, contour_radius):
123 | # no background width when '00'
124 | # distance map
125 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127 | dist_map = dist_transform_fore - dist_transform_back
128 | # ...:::!!!:::...
129 | contour_radius += 2
130 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131 | contour_mask = contour_mask / np.max(contour_mask)
132 | contour_mask[contour_mask>0.5] = 1.
133 |
134 | return mask, contour_mask
135 |
136 |
137 | def mask_generator_01(mask, background_radius, contour_radius):
138 | # no background blur when '01'; keep a soft contour
139 | # distance map
140 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142 | dist_map = dist_transform_fore - dist_transform_back
143 | # ...:::!!!:::...
144 | contour_radius += 2
145 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146 | contour_mask = contour_mask / np.max(contour_mask)
147 | return mask, contour_mask
148 |
149 |
150 | def mask_generator_10(mask, background_radius, contour_radius):
151 | # distance map
152 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154 | dist_map = dist_transform_fore - dist_transform_back
155 | # .....:::::!!!!!
156 | background_mask = np.clip(dist_map, -background_radius, background_radius)
157 | background_mask = (background_mask - np.min(background_mask))
158 | background_mask = background_mask / np.max(background_mask)
159 | # ...:::!!!:::...
160 | contour_radius += 2
161 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162 | contour_mask = contour_mask / np.max(contour_mask)
163 | contour_mask[contour_mask>0.5] = 1.
164 | return background_mask, contour_mask
165 |
166 |
167 | def mask_generator_11(mask, background_radius, contour_radius):
168 | # distance map
169 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171 | dist_map = dist_transform_fore - dist_transform_back
172 | # .....:::::!!!!!
173 | background_mask = np.clip(dist_map, -background_radius, background_radius)
174 | background_mask = (background_mask - np.min(background_mask))
175 | background_mask = background_mask / np.max(background_mask)
176 | # ...:::!!!:::...
177 | contour_radius += 2
178 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179 | contour_mask = contour_mask / np.max(contour_mask)
180 | return background_mask, contour_mask
181 |
182 |
183 | def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184 | """
185 | Input:
186 | input_image: numpy array
187 | input_mask: numpy array
188 | background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189 | background_blur_radius: radius of background blur, must be odd number
190 | contour_width: width of mask contour, must be odd number
191 | contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192 | contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193 | mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
194 |
195 | Output:
196 | painted_image: numpy array
197 | """
198 | assert input_image.shape[:2] == input_mask.shape, 'different shape'
199 | assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
200 | assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201 |
202 | # downsample input image and mask
203 | width, height = input_image.shape[0], input_image.shape[1]
204 | res = 1024
205 | ratio = min(1.0 * res / max(width, height), 1.0)
206 | input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
207 | input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
208 |
209 | # 0: background, 1: foreground
210 | msk = np.clip(input_mask, 0, 1)
211 |
212 | # generate masks for background and contour pixels
213 | background_radius = (background_blur_radius - 1) // 2
214 | contour_radius = (contour_width - 1) // 2
215 | generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216 | background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217 |
218 | # paint
219 | painted_image = vis_add_mask\
220 | (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221 |
222 | return painted_image
223 |
224 |
225 | if __name__ == '__main__':
226 |
227 | background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228 | background_blur_radius = 31 # radius of background blur, must be odd number
229 | contour_width = 11 # contour width, must be odd number
230 | contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231 | contour_alpha = 1 # transparency of background, 0: no contour highlighted
232 |
233 | # load input image and mask
234 | input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235 | input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236 |
237 | # paint
238 | overall_time_1 = 0
239 | overall_time_2 = 0
240 | overall_time_3 = 0
241 | overall_time_4 = 0
242 | overall_time_5 = 0
243 |
244 | for i in range(50):
245 | t2 = time.time()
246 | painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247 | e2 = time.time()
248 |
249 | t3 = time.time()
250 | painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251 | e3 = time.time()
252 |
253 | t1 = time.time()
254 | painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255 | e1 = time.time()
256 |
257 | t4 = time.time()
258 | painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259 | e4 = time.time()
260 |
261 | t5 = time.time()
262 | painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263 | e5 = time.time()
264 |
265 | overall_time_1 += (e1 - t1)
266 | overall_time_2 += (e2 - t2)
267 | overall_time_3 += (e3 - t3)
268 | overall_time_4 += (e4 - t4)
269 | overall_time_5 += (e5 - t5)
270 |
271 | print(f'average time w gaussian: {overall_time_1/50}')
272 | print(f'average time w/o gaussian00: {overall_time_2/50}')
273 | print(f'average time w/o gaussian10: {overall_time_3/50}')
274 | print(f'average time w/o gaussian01: {overall_time_4/50}')
275 | print(f'average time w/o gaussian11: {overall_time_5/50}')
276 |
277 | # save
278 | painted_image_00 = Image.fromarray(painted_image_00)
279 | painted_image_00.save('./test_img/painter_output_image_00.png')
280 |
281 | painted_image_10 = Image.fromarray(painted_image_10)
282 | painted_image_10.save('./test_img/painter_output_image_10.png')
283 |
284 | painted_image_01 = Image.fromarray(painted_image_01)
285 | painted_image_01.save('./test_img/painter_output_image_01.png')
286 |
287 | painted_image_11 = Image.fromarray(painted_image_11)
288 | painted_image_11.save('./test_img/painter_output_image_11.png')
289 |
--------------------------------------------------------------------------------
/hugging_face/tools/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import logging
8 | import numpy as np
9 | from os import path as osp
10 |
11 | def constant_init(module, val, bias=0):
12 | if hasattr(module, 'weight') and module.weight is not None:
13 | nn.init.constant_(module.weight, val)
14 | if hasattr(module, 'bias') and module.bias is not None:
15 | nn.init.constant_(module.bias, bias)
16 |
17 | initialized_logger = {}
18 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19 | """Get the root logger.
20 | The logger will be initialized if it has not been initialized. By default a
21 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
22 | also be added.
23 | Args:
24 | logger_name (str): root logger name. Default: 'basicsr'.
25 | log_file (str | None): The log filename. If specified, a FileHandler
26 | will be added to the root logger.
27 | log_level (int): The root logger level. Note that only the process of
28 | rank 0 is affected, while other processes will set the level to
29 | "Error" and be silent most of the time.
30 | Returns:
31 | logging.Logger: The root logger.
32 | """
33 | logger = logging.getLogger(logger_name)
34 | # if the logger has been initialized, just return it
35 | if logger_name in initialized_logger:
36 | return logger
37 |
38 | format_str = '%(asctime)s %(levelname)s: %(message)s'
39 | stream_handler = logging.StreamHandler()
40 | stream_handler.setFormatter(logging.Formatter(format_str))
41 | logger.addHandler(stream_handler)
42 | logger.propagate = False
43 |
44 | if log_file is not None:
45 | logger.setLevel(log_level)
46 | # add file handler
47 | # file_handler = logging.FileHandler(log_file, 'w')
48 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49 | file_handler.setFormatter(logging.Formatter(format_str))
50 | file_handler.setLevel(log_level)
51 | logger.addHandler(file_handler)
52 | initialized_logger[logger_name] = True
53 | return logger
54 |
55 |
56 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57 | torch.__version__)[0][:3])] >= [1, 12, 0]
58 |
59 | def gpu_is_available():
60 | if IS_HIGH_VERSION:
61 | if torch.backends.mps.is_available():
62 | return True
63 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64 |
65 | def get_device(gpu_id=None):
66 | if gpu_id is None:
67 | gpu_str = ''
68 | elif isinstance(gpu_id, int):
69 | gpu_str = f':{gpu_id}'
70 | else:
71 | raise TypeError('Input should be int value.')
72 |
73 | if IS_HIGH_VERSION:
74 | if torch.backends.mps.is_available():
75 | return torch.device('mps'+gpu_str)
76 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77 |
78 |
79 | def set_random_seed(seed):
80 | """Set random seeds."""
81 | random.seed(seed)
82 | np.random.seed(seed)
83 | torch.manual_seed(seed)
84 | torch.cuda.manual_seed(seed)
85 | torch.cuda.manual_seed_all(seed)
86 |
87 |
88 | def get_time_str():
89 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90 |
91 |
92 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93 | """Scan a directory to find the interested files.
94 |
95 | Args:
96 | dir_path (str): Path of the directory.
97 | suffix (str | tuple(str), optional): File suffix that we are
98 | interested in. Default: None.
99 | recursive (bool, optional): If set to True, recursively scan the
100 | directory. Default: False.
101 | full_path (bool, optional): If set to True, include the dir_path.
102 | Default: False.
103 |
104 | Returns:
105 | A generator for all the files of interest, with relative paths.
106 | """
107 |
108 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109 | raise TypeError('"suffix" must be a string or tuple of strings')
110 |
111 | root = dir_path
112 |
113 | def _scandir(dir_path, suffix, recursive):
114 | for entry in os.scandir(dir_path):
115 | if not entry.name.startswith('.') and entry.is_file():
116 | if full_path:
117 | return_path = entry.path
118 | else:
119 | return_path = osp.relpath(entry.path, root)
120 |
121 | if suffix is None:
122 | yield return_path
123 | elif return_path.endswith(suffix):
124 | yield return_path
125 | else:
126 | if recursive:
127 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128 | else:
129 | continue
130 |
131 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
--------------------------------------------------------------------------------
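Usage note (editor's sketch): the helpers above are small standalone utilities; a couple of representative calls (the directory below is one of the provided sample folders):

```python
from hugging_face.tools.misc import get_device, scandir, set_random_seed

set_random_seed(42)
device = get_device()  # torch.device: cuda / mps / cpu, depending on availability

# scandir yields paths relative to the scanned directory
frames = sorted(scandir("inputs/video/test-sample0", suffix=".jpg"))
print(device, len(frames))
```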
/hugging_face/tools/painter.py:
--------------------------------------------------------------------------------
1 | # paint masks, contours, or points on images, with specified colors
2 | import cv2
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import copy
7 | import time
8 |
9 |
10 | def colormap(rgb=True):
11 | color_list = np.array(
12 | [
13 | 0.000, 0.000, 0.000,
14 | 1.000, 1.000, 1.000,
15 | 1.000, 0.498, 0.313,
16 | 0.392, 0.581, 0.929,
17 | 0.000, 0.447, 0.741,
18 | 0.850, 0.325, 0.098,
19 | 0.929, 0.694, 0.125,
20 | 0.494, 0.184, 0.556,
21 | 0.466, 0.674, 0.188,
22 | 0.301, 0.745, 0.933,
23 | 0.635, 0.078, 0.184,
24 | 0.300, 0.300, 0.300,
25 | 0.600, 0.600, 0.600,
26 | 1.000, 0.000, 0.000,
27 | 1.000, 0.500, 0.000,
28 | 0.749, 0.749, 0.000,
29 | 0.000, 1.000, 0.000,
30 | 0.000, 0.000, 1.000,
31 | 0.667, 0.000, 1.000,
32 | 0.333, 0.333, 0.000,
33 | 0.333, 0.667, 0.000,
34 | 0.333, 1.000, 0.000,
35 | 0.667, 0.333, 0.000,
36 | 0.667, 0.667, 0.000,
37 | 0.667, 1.000, 0.000,
38 | 1.000, 0.333, 0.000,
39 | 1.000, 0.667, 0.000,
40 | 1.000, 1.000, 0.000,
41 | 0.000, 0.333, 0.500,
42 | 0.000, 0.667, 0.500,
43 | 0.000, 1.000, 0.500,
44 | 0.333, 0.000, 0.500,
45 | 0.333, 0.333, 0.500,
46 | 0.333, 0.667, 0.500,
47 | 0.333, 1.000, 0.500,
48 | 0.667, 0.000, 0.500,
49 | 0.667, 0.333, 0.500,
50 | 0.667, 0.667, 0.500,
51 | 0.667, 1.000, 0.500,
52 | 1.000, 0.000, 0.500,
53 | 1.000, 0.333, 0.500,
54 | 1.000, 0.667, 0.500,
55 | 1.000, 1.000, 0.500,
56 | 0.000, 0.333, 1.000,
57 | 0.000, 0.667, 1.000,
58 | 0.000, 1.000, 1.000,
59 | 0.333, 0.000, 1.000,
60 | 0.333, 0.333, 1.000,
61 | 0.333, 0.667, 1.000,
62 | 0.333, 1.000, 1.000,
63 | 0.667, 0.000, 1.000,
64 | 0.667, 0.333, 1.000,
65 | 0.667, 0.667, 1.000,
66 | 0.667, 1.000, 1.000,
67 | 1.000, 0.000, 1.000,
68 | 1.000, 0.333, 1.000,
69 | 1.000, 0.667, 1.000,
70 | 0.167, 0.000, 0.000,
71 | 0.333, 0.000, 0.000,
72 | 0.500, 0.000, 0.000,
73 | 0.667, 0.000, 0.000,
74 | 0.833, 0.000, 0.000,
75 | 1.000, 0.000, 0.000,
76 | 0.000, 0.167, 0.000,
77 | 0.000, 0.333, 0.000,
78 | 0.000, 0.500, 0.000,
79 | 0.000, 0.667, 0.000,
80 | 0.000, 0.833, 0.000,
81 | 0.000, 1.000, 0.000,
82 | 0.000, 0.000, 0.167,
83 | 0.000, 0.000, 0.333,
84 | 0.000, 0.000, 0.500,
85 | 0.000, 0.000, 0.667,
86 | 0.000, 0.000, 0.833,
87 | 0.000, 0.000, 1.000,
88 | 0.143, 0.143, 0.143,
89 | 0.286, 0.286, 0.286,
90 | 0.429, 0.429, 0.429,
91 | 0.571, 0.571, 0.571,
92 | 0.714, 0.714, 0.714,
93 | 0.857, 0.857, 0.857
94 | ]
95 | ).astype(np.float32)
96 | color_list = color_list.reshape((-1, 3)) * 255
97 | if not rgb:
98 | color_list = color_list[:, ::-1]
99 | return color_list
100 |
101 |
102 | color_list = colormap()
103 | color_list = color_list.astype('uint8').tolist()
104 |
105 |
106 | def vis_add_mask(image, mask, color, alpha):
107 | color = np.array(color_list[color])
108 | mask = mask > 0.5
109 | image[mask] = image[mask] * (1-alpha) + color * alpha
110 | return image.astype('uint8')
111 |
112 | def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113 | h, w = input_image.shape[:2]
114 | point_mask = np.zeros((h, w)).astype('uint8')
115 | for point in input_points:
116 | point_mask[point[1], point[0]] = 1
117 |
118 | kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
119 | point_mask = cv2.dilate(point_mask, kernel)
120 |
121 | contour_radius = (contour_width - 1) // 2
122 | dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123 | dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124 | dist_map = dist_transform_fore - dist_transform_back
125 | # ...:::!!!:::...
126 | contour_radius += 2
127 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128 | contour_mask = contour_mask / np.max(contour_mask)
129 | contour_mask[contour_mask>0.5] = 1.
130 |
131 | # paint mask
132 | painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133 | # paint contour
134 | painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135 | return painted_image
136 |
137 | def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138 | assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139 | # 0: background, 1: foreground
140 | mask = np.clip(input_mask, 0, 1)
141 | contour_radius = (contour_width - 1) // 2
142 |
143 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145 | dist_map = dist_transform_fore - dist_transform_back
146 | # ...:::!!!:::...
147 | contour_radius += 2
148 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149 | contour_mask = contour_mask / np.max(contour_mask)
150 | contour_mask[contour_mask>0.5] = 1.
151 |
152 | # paint mask
153 | painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154 | # paint contour
155 | painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156 |
157 | return painted_image
158 |
159 | def background_remover(input_image, input_mask):
160 | """
161 | input_image: H, W, 3, np.array
162 | input_mask: H, W, np.array
163 |
164 | image_wo_background: PIL.Image
165 | """
166 | assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167 | # 0: background, 1: foreground
168 | mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169 | image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170 | image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171 |
172 | return image_wo_background
173 |
174 | if __name__ == '__main__':
175 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176 | input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177 |
178 | # example of mask painter
179 | mask_color = 3
180 | mask_alpha = 0.7
181 | contour_color = 1
182 | contour_width = 5
183 |
184 | # save
185 | painted_image = Image.fromarray(input_image)
186 | painted_image.save('images/original.png')
187 |
188 | painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189 | # save
190 | painted_image = Image.fromarray(painted_image)
191 | painted_image.save('images/original1.png')
192 |
193 | # example of point painter
194 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195 | input_points = np.array([[500, 375], [70, 600]]) # x, y
196 | point_color = 5
197 | point_alpha = 0.9
198 | point_radius = 15
199 | contour_color = 2
200 | contour_width = 5
201 | painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202 | # save
203 | painted_image = Image.fromarray(painted_image_1)
204 | painted_image.save('images/point_painter_1.png')
205 |
206 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207 | painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208 | # save
209 | painted_image = Image.fromarray(painted_image_2)
210 | painted_image.save('images/point_painter_2.png')
211 |
212 | # example of background remover
213 | input_image = np.array(Image.open('images/original.png').convert('RGB'))
214 | image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215 | image_wo_background.save('images/image_wo_background.png')
216 |
--------------------------------------------------------------------------------
/inference_hf.py:
--------------------------------------------------------------------------------
1 | from matanyone import InferenceCore
2 |
3 |
4 | def main(
5 | input_path,
6 | mask_path,
7 | output_path,
8 | n_warmup=10,
9 | r_erode=10,
10 | r_dilate=10,
11 | suffix="",
12 | save_image=False,
13 | max_size=-1,
14 | ):
15 | processor = InferenceCore("PeiqingYang/MatAnyone")
16 | fgr, alpha = processor.process_video(
17 | input_path=input_path,
18 | mask_path=mask_path,
19 | output_path=output_path,
20 | n_warmup=n_warmup,
21 | r_erode=r_erode,
22 | r_dilate=r_dilate,
23 | suffix=suffix,
24 | save_image=save_image,
25 | max_size=max_size,
26 | )
27 | return fgr, alpha
--------------------------------------------------------------------------------
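As listed here, `inference_hf.py` only defines `main()`. A minimal command-line entry point would simply mirror the function's keyword arguments (the flag names below are assumptions for illustration, not the repository's official CLI):

```python
# hypothetical CLI wrapper around inference_hf.main
import argparse

from inference_hf import main

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_path", required=True)
    parser.add_argument("-m", "--mask_path", required=True)
    parser.add_argument("-o", "--output_path", default="outputs")
    parser.add_argument("--n_warmup", type=int, default=10)
    parser.add_argument("--r_erode", type=int, default=10)
    parser.add_argument("--r_dilate", type=int, default=10)
    parser.add_argument("--suffix", default="")
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument("--max_size", type=int, default=-1)
    args = parser.parse_args()
    main(**vars(args))  # argument names match main()'s parameters
```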
/inference_matanyone.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import tqdm
4 | import imageio
5 | import numpy as np
6 | from PIL import Image
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from hugging_face.tools.download_util import load_file_from_url
12 | from matanyone.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos
13 |
14 | from matanyone.inference.inference_core import InferenceCore
15 | from matanyone.utils.get_default_model import get_matanyone_model
16 |
17 | import warnings
18 | warnings.filterwarnings("ignore")
19 |
20 | @torch.inference_mode()
21 | @torch.amp.autocast("cuda")
22 | def main(input_path, mask_path, output_path, ckpt_path, n_warmup=10, r_erode=10, r_dilate=10, suffix="", save_image=False, max_size=-1):
23 |     # download ckpt for the first inference, unless a local ckpt_path already exists
24 |     pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
25 |     ckpt_path = ckpt_path if os.path.exists(ckpt_path) else load_file_from_url(pretrain_model_url, 'pretrained_models')
26 |
27 | # load MatAnyone model
28 | matanyone = get_matanyone_model(ckpt_path)
29 |
30 | # init inference processor
31 | processor = InferenceCore(matanyone, cfg=matanyone.cfg)
32 |
33 | # inference parameters
34 | r_erode = int(r_erode)
35 | r_dilate = int(r_dilate)
36 | n_warmup = int(n_warmup)
37 | max_size = int(max_size)
38 |
39 | # load input frames
40 | vframes, fps, length, video_name = read_frame_from_videos(input_path)
41 | repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1) # repeat the first frame for warmup
42 | vframes = torch.cat([repeated_frames, vframes], dim=0).float()
43 | length += n_warmup # update length
44 |
45 | # resize if needed
46 | if max_size > 0:
47 | h, w = vframes.shape[-2:]
48 | min_side = min(h, w)
49 | if min_side > max_size:
50 | new_h = int(h / min_side * max_size)
51 | new_w = int(w / min_side * max_size)
52 |
53 | vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area")
54 |
55 | # set output paths
56 | os.makedirs(output_path, exist_ok=True)
57 | if suffix != "":
58 | video_name = f'{video_name}_{suffix}'
59 | if save_image:
60 | os.makedirs(f'{output_path}/{video_name}', exist_ok=True)
61 | os.makedirs(f'{output_path}/{video_name}/pha', exist_ok=True)
62 | os.makedirs(f'{output_path}/{video_name}/fgr', exist_ok=True)
63 |
64 | # load the first-frame mask
65 | mask = Image.open(mask_path).convert('L')
66 | mask = np.array(mask)
67 |
68 | bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3)) # green screen to paste fgr
69 | objects = [1]
70 |
71 | # [optional] erode & dilate
72 | if r_dilate > 0:
73 | mask = gen_dilate(mask, r_dilate, r_dilate)
74 | if r_erode > 0:
75 | mask = gen_erosion(mask, r_erode, r_erode)
76 |
77 | mask = torch.from_numpy(mask).cuda()
78 |
79 | if max_size > 0: # resize needed
80 | mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest")
81 | mask = mask[0,0]
82 |
83 | # inference start
84 | phas = []
85 | fgrs = []
86 | for ti in tqdm.tqdm(range(length)):
87 | # load the image as RGB; normalization is done within the model
88 | image = vframes[ti]
89 |
90 |         image_np = np.array(image.permute(1,2,0)) # for output visualization
91 | image = (image / 255.).cuda().float() # for network input
92 |
93 | if ti == 0:
94 | output_prob = processor.step(image, mask, objects=objects) # encode given mask
95 | output_prob = processor.step(image, first_frame_pred=True) # first frame for prediction
96 | else:
97 | if ti <= n_warmup:
98 | output_prob = processor.step(image, first_frame_pred=True) # reinit as the first frame for prediction
99 | else:
100 | output_prob = processor.step(image)
101 |
102 | # convert output probabilities to alpha matte
103 | mask = processor.output_prob_to_mask(output_prob)
104 |
105 | # visualize prediction
106 | pha = mask.unsqueeze(2).cpu().numpy()
107 | com_np = image_np / 255. * pha + bgr * (1 - pha)
108 |
109 |         # DO NOT save the warmup frames
110 | if ti > (n_warmup-1):
111 | com_np = (com_np*255).astype(np.uint8)
112 | pha = (pha*255).astype(np.uint8)
113 | fgrs.append(com_np)
114 | phas.append(pha)
115 | if save_image:
116 | cv2.imwrite(f'{output_path}/{video_name}/pha/{str(ti-n_warmup).zfill(5)}.png', pha)
117 | cv2.imwrite(f'{output_path}/{video_name}/fgr/{str(ti-n_warmup).zfill(5)}.png', com_np[...,[2,1,0]])
118 |
119 | phas = np.array(phas)
120 | fgrs = np.array(fgrs)
121 |
122 | imageio.mimwrite(f'{output_path}/{video_name}_fgr.mp4', fgrs, fps=fps, quality=7)
123 | imageio.mimwrite(f'{output_path}/{video_name}_pha.mp4', phas, fps=fps, quality=7)
124 |
125 | if __name__ == '__main__':
126 | import argparse
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument('-i', '--input_path', type=str, default="inputs/video/test-sample1.mp4", help='Path of the input video or frame folder.')
129 | parser.add_argument('-m', '--mask_path', type=str, default="inputs/mask/test-sample1.png", help='Path of the first-frame segmentation mask.')
130 | parser.add_argument('-o', '--output_path', type=str, default="results/", help='Output folder. Default: results')
131 | parser.add_argument('-c', '--ckpt_path', type=str, default="pretrained_models/matanyone.pth", help='Path of the MatAnyone model.')
132 | parser.add_argument('-w', '--warmup', type=str, default="10", help='Number of warmup iterations for the first frame alpha prediction.')
133 | parser.add_argument('-e', '--erode_kernel', type=str, default="10", help='Erosion kernel on the input mask.')
134 | parser.add_argument('-d', '--dilate_kernel', type=str, default="10", help='Dilation kernel on the input mask.')
135 | parser.add_argument('--suffix', type=str, default="", help='Suffix to specify different target when saving, e.g., target1.')
136 | parser.add_argument('--save_image', action='store_true', default=False, help='Save output frames. Default: False')
137 |     parser.add_argument('--max_size', type=str, default="-1", help='When positive, the video will be downsampled if min(w, h) exceeds this value. Default: -1 (no limit)')
138 |
139 |
140 | args = parser.parse_args()
141 |
142 | main(input_path=args.input_path, \
143 | mask_path=args.mask_path, \
144 | output_path=args.output_path, \
145 | ckpt_path=args.ckpt_path, \
146 | n_warmup=args.warmup, \
147 | r_erode=args.erode_kernel, \
148 | r_dilate=args.dilate_kernel, \
149 | suffix=args.suffix, \
150 | save_image=args.save_image, \
151 | max_size=args.max_size)
152 |
--------------------------------------------------------------------------------
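A sketch of calling this script's main() directly rather than via the command line, mirroring the argparse defaults defined in the file above; it assumes the repo root as working directory, a CUDA device, and the bundled test sample:

# hypothetical usage sketch, not part of the repository
from inference_matanyone import main

main(
    input_path="inputs/video/test-sample1.mp4",
    mask_path="inputs/mask/test-sample1.png",
    output_path="results/",
    ckpt_path="pretrained_models/matanyone.pth",
    n_warmup=10,
    r_erode=10,
    r_dilate=10,
    save_image=True,
)

The equivalent command line, using the flags defined above, is: python inference_matanyone.py -i inputs/video/test-sample1.mp4 -m inputs/mask/test-sample1.png --save_image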
/inputs/mask/test-sample0_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample0_1.png
--------------------------------------------------------------------------------
/inputs/mask/test-sample0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample0_2.png
--------------------------------------------------------------------------------
/inputs/mask/test-sample1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample1.png
--------------------------------------------------------------------------------
/inputs/mask/test-sample2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample2.png
--------------------------------------------------------------------------------
/inputs/mask/test-sample3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample3.png
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0000.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0001.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0002.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0003.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0004.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0005.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0006.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0007.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0008.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0009.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0010.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0011.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0012.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0012.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0013.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0013.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0014.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0014.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0015.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0016.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0016.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0017.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0017.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0018.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0018.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0019.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0019.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0020.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0020.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0021.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0021.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0022.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0022.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0023.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0023.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0024.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0025.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0026.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0026.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0027.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0027.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0028.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0028.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0029.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0029.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0030.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0031.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0031.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0032.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0032.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0033.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0033.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0034.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0034.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0035.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0035.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0036.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0037.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0037.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0038.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0038.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0039.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0039.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0040.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0040.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0041.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0041.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0042.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0042.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0043.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0043.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0044.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0044.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0045.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0045.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0046.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0046.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0047.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0047.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0048.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0049.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0049.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0050.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0050.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0051.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0051.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0052.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0052.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0053.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0053.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0054.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0054.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0055.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0055.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0056.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0056.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0057.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0057.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0058.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0058.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0059.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0059.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0060.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0061.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0061.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0062.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0062.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0063.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0063.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0064.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0064.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0065.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0065.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0066.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0066.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0067.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0067.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0068.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0068.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0069.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0069.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0070.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample0/0071.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0071.jpg
--------------------------------------------------------------------------------
/inputs/video/test-sample1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample1.mp4
--------------------------------------------------------------------------------
/inputs/video/test-sample2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample2.mp4
--------------------------------------------------------------------------------
/inputs/video/test-sample3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample3.mp4
--------------------------------------------------------------------------------
/matanyone/__init__.py:
--------------------------------------------------------------------------------
1 | from matanyone.inference.inference_core import InferenceCore
2 | from matanyone.model.matanyone import MatAnyone
3 |
--------------------------------------------------------------------------------
/matanyone/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/config/__init__.py
--------------------------------------------------------------------------------
/matanyone/config/eval_matanyone_config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model: base
4 | - override hydra/job_logging: custom-no-rank.yaml
5 |
6 | hydra:
7 | run:
8 | dir: ../output/${exp_id}/${dataset}
9 | output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10 |
11 | amp: False
12 | weights: pretrained_models/matanyone.pth # default (can be modified from outside)
13 | output_dir: null # defaults to run_dir; specify this to override
14 | flip_aug: False
15 |
16 |
17 | # maximum shortest side of the input; -1 means no resizing
18 | # With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
19 | # this parameter is added for the sole purpose of the GUI in the current codebase
20 | # InferenceCore will downsize the input and restore the output to the original size if needed
21 | # if you are using this code for some other project, you can also utilize this parameter
22 | max_internal_size: -1
23 |
24 | # these parameters, when set, override the dataset's default; useful for debugging
25 | save_all: True
26 | use_all_masks: False
27 | use_long_term: False
28 | mem_every: 5
29 |
30 | # only relevant when long_term is not enabled
31 | max_mem_frames: 5
32 |
33 | # only relevant when long_term is enabled
34 | long_term:
35 | count_usage: True
36 | max_mem_frames: 10
37 | min_mem_frames: 5
38 | num_prototypes: 128
39 | max_num_tokens: 10000
40 | buffer_tokens: 2000
41 |
42 | top_k: 30
43 | stagger_updates: 5
44 | chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
45 | save_scores: False
46 | save_aux: False
47 | visualize: False
48 |
--------------------------------------------------------------------------------
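The plain inference keys in this config can be read directly with OmegaConf; the defaults and hydra sections only take effect when the file is composed through Hydra. A minimal sketch, assuming omegaconf is installed and the repo root is the working directory:

# hypothetical sketch, not part of the repository
from omegaconf import OmegaConf

cfg = OmegaConf.load("matanyone/config/eval_matanyone_config.yaml")
print(cfg.weights)                                                      # pretrained_models/matanyone.pth
print(cfg.mem_every, cfg.max_mem_frames, cfg.long_term.max_mem_frames)  # 5 5 10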
/matanyone/config/hydra/job_logging/custom-no-rank.yaml:
--------------------------------------------------------------------------------
1 | # python logging configuration for tasks
2 | version: 1
3 | formatters:
4 | simple:
5 | format: '[%(asctime)s][%(levelname)s] - %(message)s'
6 | datefmt: '%Y-%m-%d %H:%M:%S'
7 | handlers:
8 | console:
9 | class: logging.StreamHandler
10 | formatter: simple
11 | stream: ext://sys.stdout
12 | file:
13 | class: logging.FileHandler
14 | formatter: simple
15 | # absolute file path
16 | filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
17 | mode: w
18 | root:
19 | level: INFO
20 | handlers: [console, file]
21 |
22 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/matanyone/config/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | # python logging configuration for tasks
2 | version: 1
3 | formatters:
4 | simple:
5 | format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6 | datefmt: '%Y-%m-%d %H:%M:%S'
7 | handlers:
8 | console:
9 | class: logging.StreamHandler
10 | formatter: simple
11 | stream: ext://sys.stdout
12 | file:
13 | class: logging.FileHandler
14 | formatter: simple
15 | # absolute file path
16 | filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
17 | mode: w
18 | root:
19 | level: INFO
20 | handlers: [console, file]
21 |
22 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/matanyone/config/model/base.yaml:
--------------------------------------------------------------------------------
1 | pixel_mean: [0.485, 0.456, 0.406]
2 | pixel_std: [0.229, 0.224, 0.225]
3 |
4 | pixel_dim: 256
5 | key_dim: 64
6 | value_dim: 256
7 | sensory_dim: 256
8 | embed_dim: 256
9 |
10 | pixel_encoder:
11 | type: resnet50
12 | ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
13 |
14 | mask_encoder:
15 | type: resnet18
16 | final_dim: 256
17 |
18 | pixel_pe_scale: 32
19 | pixel_pe_temperature: 128
20 |
21 | object_transformer:
22 | embed_dim: ${model.embed_dim}
23 | ff_dim: 2048
24 | num_heads: 8
25 | num_blocks: 3
26 | num_queries: 16
27 | read_from_pixel:
28 | input_norm: False
29 | input_add_pe: False
30 | add_pe_to_qkv: [True, True, False]
31 | read_from_past:
32 | add_pe_to_qkv: [True, True, False]
33 | read_from_memory:
34 | add_pe_to_qkv: [True, True, False]
35 | read_from_query:
36 | add_pe_to_qkv: [True, True, False]
37 | output_norm: False
38 | query_self_attention:
39 | add_pe_to_qkv: [True, True, False]
40 | pixel_self_attention:
41 | add_pe_to_qkv: [True, True, False]
42 |
43 | object_summarizer:
44 | embed_dim: ${model.object_transformer.embed_dim}
45 | num_summaries: ${model.object_transformer.num_queries}
46 | add_pe: True
47 |
48 | aux_loss:
49 | sensory:
50 | enabled: True
51 | weight: 0.01
52 | query:
53 | enabled: True
54 | weight: 0.01
55 |
56 | mask_decoder:
57 | # first value must equal embed_dim
58 | up_dims: [256, 128, 128, 64, 16]
59 |
--------------------------------------------------------------------------------
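Interpolations such as ${model.embed_dim} in this file are absolute references: they resolve only after the file is mounted under the top-level model key, which is what the defaults list in eval_matanyone_config.yaml arranges. A minimal sketch of this behaviour, assuming omegaconf is installed:

# hypothetical sketch, not part of the repository
from omegaconf import OmegaConf

model_cfg = OmegaConf.load("matanyone/config/model/base.yaml")
root = OmegaConf.create({"model": OmegaConf.to_container(model_cfg, resolve=False)})
print(root.model.object_transformer.embed_dim)  # 256, resolved via ${model.embed_dim}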
/matanyone/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/inference/__init__.py
--------------------------------------------------------------------------------
/matanyone/inference/image_feature_store.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Iterable
3 | import torch
4 | from matanyone.model.matanyone import MatAnyone
5 |
6 |
7 | class ImageFeatureStore:
8 | """
9 | A cache for image features.
10 | These features might be reused at different parts of the inference pipeline.
11 |     This class provides an interface for reusing these features.
12 | It is the user's responsibility to delete redundant features.
13 |
14 |     The features of a frame should be associated with a unique index -- typically the frame id.
15 | """
16 | def __init__(self, network: MatAnyone, no_warning: bool = False):
17 | self.network = network
18 | self._store = {}
19 | self.no_warning = no_warning
20 |
21 | def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
22 | ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
23 | key, shrinkage, selection = self.network.transform_key(ms_features[0])
24 | self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
25 |
26 | def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
27 | seq_length = images.shape[0]
28 | ms_features, pix_feat = self.network.encode_image(images, seq_length)
29 | key, shrinkage, selection = self.network.transform_key(ms_features[0])
30 | for index in range(seq_length):
31 | self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
32 |
33 | def get_features(self, index: int,
34 | image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
35 | if index not in self._store:
36 | self._encode_feature(index, image, last_feats)
37 |
38 | return self._store[index][:2]
39 |
40 | def get_key(self, index: int,
41 | image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
42 | if index not in self._store:
43 | self._encode_feature(index, image, last_feats)
44 |
45 | return self._store[index][2:]
46 |
47 | def delete(self, index: int) -> None:
48 | if index in self._store:
49 | del self._store[index]
50 |
51 | def __len__(self):
52 | return len(self._store)
53 |
54 | def __del__(self):
55 | if len(self._store) > 0 and not self.no_warning:
56 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
57 |
--------------------------------------------------------------------------------
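A small self-contained sketch of the caching contract. FakeNetwork below is hypothetical (it only mimics the two methods the store calls, not the real MatAnyone model); the point is that features are encoded once per frame index, later calls are served from the cache, and the caller frees entries explicitly:

# hypothetical sketch, not part of the repository
import torch
from matanyone.inference.image_feature_store import ImageFeatureStore

class FakeNetwork:  # stand-in for MatAnyone; returns dummy tensors of plausible shapes
    def encode_image(self, image, last_feats=None):
        return [torch.zeros(1, 64, 8, 8)], torch.zeros(1, 256, 8, 8)
    def transform_key(self, f16):
        return torch.zeros(1, 64, 64), torch.zeros(1, 1, 64), torch.zeros(1, 64, 64)

store = ImageFeatureStore(FakeNetwork(), no_warning=True)
image = torch.zeros(1, 3, 128, 128)
ms_features, pix_feat = store.get_features(0, image)  # encodes and caches frame 0
key, shrinkage, selection = store.get_key(0, image)   # served from the cache, no re-encode
store.delete(0)                                       # the caller frees the entry when done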
/matanyone/inference/kv_memory_store.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Literal
2 | from collections import defaultdict
3 | import torch
4 |
5 |
6 | def _add_last_dim(dictionary, key, new_value, prepend=False):
7 | # append/prepend a new value to the last dimension of a tensor in a dictionary
8 | # if the key does not exist, put the new value in
9 | # append by default
10 | if key in dictionary:
11 |         dictionary[key] = torch.cat([new_value, dictionary[key]] if prepend else [dictionary[key], new_value], -1)
12 | else:
13 | dictionary[key] = new_value
14 |
15 |
16 | class KeyValueMemoryStore:
17 | """
18 | Works for key/value pairs type storage
19 | e.g., working and long-term memory
20 | """
21 | def __init__(self, save_selection: bool = False, save_usage: bool = False):
22 | """
23 | We store keys and values of objects that first appear in the same frame in a bucket.
24 | Each bucket contains a set of object ids.
25 | Each bucket is associated with a single key tensor
26 | and a dictionary of value tensors indexed by object id.
27 |
28 | The keys and values are stored as the concatenation of a permanent part and a temporary part.
29 | """
30 | self.save_selection = save_selection
31 | self.save_usage = save_usage
32 |
33 | self.global_bucket_id = 0 # does not reduce even if buckets are removed
34 | self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
35 | self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
36 | self.v: Dict[int, torch.Tensor] = {} # indexed by object id
37 |
38 | # indexed by bucket id; the end point of permanent memory
39 | self.perm_end_pt: Dict[int, int] = defaultdict(int)
40 |
41 | # shrinkage and selection are just like the keys
42 | self.s = {}
43 | if self.save_selection:
44 | self.e = {} # does not contain the permanent memory part
45 |
46 | # usage
47 | if self.save_usage:
48 | self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
49 | self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
50 |
51 | def add(self,
52 | key: torch.Tensor,
53 | values: Dict[int, torch.Tensor],
54 | shrinkage: torch.Tensor,
55 | selection: torch.Tensor,
56 | supposed_bucket_id: int = -1,
57 | as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
58 | """
59 | key: (1/2)*C*N
60 | values: dict of values ((1/2)*C*N), object ids are used as keys
61 | shrinkage: (1/2)*1*N
62 | selection: (1/2)*C*N
63 |
64 | supposed_bucket_id: used to sync the bucket id between working and long-term memory
65 | if provided, the input should all be in a single bucket indexed by this id
66 | as_permanent: whether to store the input as permanent memory
67 | 'no': don't
68 | 'first': only store it as permanent memory if the bucket is empty
69 | 'all': always store it as permanent memory
70 | """
71 | bs = key.shape[0]
72 | ne = key.shape[-1]
73 | assert len(key.shape) == 3
74 | assert len(shrinkage.shape) == 3
75 | assert not self.save_selection or len(selection.shape) == 3
76 | assert as_permanent in ['no', 'first', 'all']
77 |
78 | # add the value and create new buckets if necessary
79 | if supposed_bucket_id >= 0:
80 | enabled_buckets = [supposed_bucket_id]
81 | bucket_exist = supposed_bucket_id in self.buckets
82 | for obj, value in values.items():
83 | if bucket_exist:
84 | assert obj in self.v
85 | assert obj in self.buckets[supposed_bucket_id]
86 | _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
87 | else:
88 | assert obj not in self.v
89 | self.v[obj] = value
90 | self.buckets[supposed_bucket_id] = list(values.keys())
91 | else:
92 | new_bucket_id = None
93 | enabled_buckets = set()
94 | for obj, value in values.items():
95 | assert len(value.shape) == 3
96 | if obj in self.v:
97 | _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
98 | bucket_used = [
99 | bucket_id for bucket_id, object_ids in self.buckets.items()
100 | if obj in object_ids
101 | ]
102 | assert len(bucket_used) == 1 # each object should only be in one bucket
103 | enabled_buckets.add(bucket_used[0])
104 | else:
105 | self.v[obj] = value
106 | if new_bucket_id is None:
107 | # create new bucket
108 | new_bucket_id = self.global_bucket_id
109 | self.global_bucket_id += 1
110 | self.buckets[new_bucket_id] = []
111 | # put the new object into the corresponding bucket
112 | self.buckets[new_bucket_id].append(obj)
113 | enabled_buckets.add(new_bucket_id)
114 |
115 | # increment the permanent size if necessary
116 | add_as_permanent = {} # indexed by bucket id
117 | for bucket_id in enabled_buckets:
118 | add_as_permanent[bucket_id] = False
119 | if as_permanent == 'all':
120 | self.perm_end_pt[bucket_id] += ne
121 | add_as_permanent[bucket_id] = True
122 | elif as_permanent == 'first':
123 | if self.perm_end_pt[bucket_id] == 0:
124 | self.perm_end_pt[bucket_id] = ne
125 | add_as_permanent[bucket_id] = True
126 |
127 | # create new counters for usage if necessary
128 | if self.save_usage and as_permanent != 'all':
129 | new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
130 | new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
131 |
132 | # add the key to every bucket
133 | for bucket_id in self.buckets:
134 | if bucket_id not in enabled_buckets:
135 | # if we are not adding new values to a bucket, we should skip it
136 | continue
137 |
138 | _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
139 | _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
140 | if not add_as_permanent[bucket_id]:
141 | if self.save_selection:
142 | _add_last_dim(self.e, bucket_id, selection)
143 | if self.save_usage:
144 | _add_last_dim(self.use_cnt, bucket_id, new_count)
145 | _add_last_dim(self.life_cnt, bucket_id, new_life)
146 |
147 | def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
148 | # increase all life count by 1
149 | # increase use of indexed elements
150 | if not self.save_usage:
151 | return
152 |
153 | usage = usage[:, self.perm_end_pt[bucket_id]:]
154 | if usage.shape[-1] == 0:
155 | # if there is no temporary memory, we don't need to update
156 | return
157 | self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
158 | self.life_cnt[bucket_id] += 1
159 |
160 | def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
161 | # keep only the temporary elements *outside* of this range (with some boundary conditions)
162 | # the permanent elements are ignored in this computation
163 | # i.e., concat (a[:start], a[end:])
164 | # bucket with size <= min_size are not modified
165 |
166 | assert start >= 0
167 | assert end <= 0
168 |
169 | object_ids = self.buckets[bucket_id]
170 | bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
171 | if bucket_num_elements <= min_size:
172 | return
173 |
174 | if end == 0:
175 | # negative 0 would not work as the end index!
176 | # effectively make the second part an empty slice
177 | end = self.k[bucket_id].shape[-1] + 1
178 |
179 | p_size = self.perm_end_pt[bucket_id]
180 | start = start + p_size
181 |
182 | k = self.k[bucket_id]
183 | s = self.s[bucket_id]
184 | if self.save_selection:
185 | e = self.e[bucket_id]
186 | if self.save_usage:
187 | use_cnt = self.use_cnt[bucket_id]
188 | life_cnt = self.life_cnt[bucket_id]
189 |
190 | self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
191 | self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
192 | if self.save_selection:
193 | self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
194 | if self.save_usage:
195 | self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
196 | self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
197 | -1)
198 | for obj_id in object_ids:
199 | v = self.v[obj_id]
200 | self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
201 |
202 | def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
203 | self.sieve_by_range(bucket_id, 0, -max_len, max_len)
204 |
205 | def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
206 | # for long-term memory only
207 | object_ids = self.buckets[bucket_id]
208 |
209 | assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
210 |
211 | # normalize with life duration
212 | usage = self.get_usage(bucket_id)
213 | bs = usage.shape[0]
214 |
215 | survivals = []
216 |
217 | for bi in range(bs):
218 | _, survived = torch.topk(usage[bi], k=max_size)
219 | survivals.append(survived.flatten())
220 | assert survived.shape[-1] == survivals[0].shape[-1]
221 |
222 | self.k[bucket_id] = torch.stack(
223 | [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
224 | self.s[bucket_id] = torch.stack(
225 | [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
226 |
227 | if self.save_selection:
228 | # Long-term memory does not store selection so this should not be needed
229 | self.e[bucket_id] = torch.stack(
230 | [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
231 | for obj_id in object_ids:
232 | self.v[obj_id] = torch.stack(
233 | [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
234 |
235 | self.use_cnt[bucket_id] = torch.stack(
236 | [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
237 | self.life_cnt[bucket_id] = torch.stack(
238 | [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
239 |
240 | def get_usage(self, bucket_id: int) -> torch.Tensor:
241 | # return normalized usage
242 | if not self.save_usage:
243 | raise RuntimeError('I did not count usage!')
244 | else:
245 | usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
246 | return usage
247 |
248 | def get_all_sliced(
249 | self, bucket_id: int, start: int, end: int
250 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
251 | # return k, sk, ek, value, normalized usage in order, sliced by start and end
252 | # this only queries the temporary memory
253 |
254 | assert start >= 0
255 | assert end <= 0
256 |
257 | p_size = self.perm_end_pt[bucket_id]
258 | start = start + p_size
259 |
260 | if end == 0:
261 | # negative 0 would not work as the end index!
262 | k = self.k[bucket_id][:, :, start:]
263 | sk = self.s[bucket_id][:, :, start:]
264 | ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
265 | value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
266 | usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
267 | else:
268 | k = self.k[bucket_id][:, :, start:end]
269 | sk = self.s[bucket_id][:, :, start:end]
270 | ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
271 | value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
272 | usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
273 |
274 | return k, sk, ek, value, usage
275 |
276 | def purge_except(self, obj_keep_idx: List[int]):
277 | # purge certain objects from the memory except the one listed
278 | obj_keep_idx = set(obj_keep_idx)
279 |
280 | # remove objects that are not in the keep list from the buckets
281 | buckets_to_remove = []
282 | for bucket_id, object_ids in self.buckets.items():
283 | self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
284 | if len(self.buckets[bucket_id]) == 0:
285 | buckets_to_remove.append(bucket_id)
286 |
287 | # remove object values that are not in the keep list
288 | self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
289 |
290 | # remove buckets that are empty
291 | for bucket_id in buckets_to_remove:
292 | del self.buckets[bucket_id]
293 | del self.k[bucket_id]
294 | del self.s[bucket_id]
295 | if self.save_selection:
296 | del self.e[bucket_id]
297 | if self.save_usage:
298 | del self.use_cnt[bucket_id]
299 | del self.life_cnt[bucket_id]
300 |
301 | def clear_non_permanent_memory(self):
302 | # clear all non-permanent memory
303 | for bucket_id in self.buckets:
304 | self.sieve_by_range(bucket_id, 0, 0, 0)
305 |
306 | def get_v_size(self, obj_id: int) -> int:
307 | return self.v[obj_id].shape[-1]
308 |
309 | def size(self, bucket_id: int) -> int:
310 | if bucket_id not in self.k:
311 | return 0
312 | else:
313 | return self.k[bucket_id].shape[-1]
314 |
315 | def perm_size(self, bucket_id: int) -> int:
316 | return self.perm_end_pt[bucket_id]
317 |
318 | def non_perm_size(self, bucket_id: int) -> int:
319 | return self.size(bucket_id) - self.perm_size(bucket_id)
320 |
321 | def engaged(self, bucket_id: Optional[int] = None) -> bool:
322 | if bucket_id is None:
323 | return len(self.buckets) > 0
324 | else:
325 | return bucket_id in self.buckets
326 |
327 | @property
328 | def num_objects(self) -> int:
329 | return len(self.v)
330 |
331 | @property
332 | def key(self) -> Dict[int, torch.Tensor]:
333 | return self.k
334 |
335 | @property
336 | def value(self) -> Dict[int, torch.Tensor]:
337 | return self.v
338 |
339 | @property
340 | def shrinkage(self) -> Dict[int, torch.Tensor]:
341 | return self.s
342 |
343 | @property
344 | def selection(self) -> Dict[int, torch.Tensor]:
345 | return self.e
346 |
347 | def __contains__(self, key):
348 | return key in self.v
349 |
--------------------------------------------------------------------------------
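A hypothetical usage sketch with dummy tensors (shapes follow the add() docstring: batch x channels x number of elements); it shows how permanent and temporary memory are tracked per bucket:

# hypothetical sketch, not part of the repository
import torch
from matanyone.inference.kv_memory_store import KeyValueMemoryStore

store = KeyValueMemoryStore(save_selection=True, save_usage=True)
key = torch.randn(1, 64, 16)           # batch * key dim * num elements
shrinkage = torch.randn(1, 1, 16)
selection = torch.randn(1, 64, 16)
values = {1: torch.randn(1, 256, 16)}  # indexed by object id

store.add(key, values, shrinkage, selection, as_permanent='first')  # stored as permanent memory
store.add(key, values, shrinkage, selection)                        # appended as temporary memory
print(store.size(0), store.perm_size(0), store.non_perm_size(0))    # 32 16 16

store.update_bucket_usage(0, torch.rand(1, store.size(0)))  # only the temporary part is counted
store.remove_old_memory(0, max_len=8)                       # keep permanent + the last 8 temporary
print(store.non_perm_size(0))                               # 8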
/matanyone/inference/object_info.py:
--------------------------------------------------------------------------------
1 | class ObjectInfo:
2 | """
3 | Store meta information for an object
4 | """
5 | def __init__(self, id: int):
6 | self.id = id
7 | self.poke_count = 0 # count number of detections missed
8 |
9 | def poke(self) -> None:
10 | self.poke_count += 1
11 |
12 | def unpoke(self) -> None:
13 | self.poke_count = 0
14 |
15 | def __hash__(self):
16 | return hash(self.id)
17 |
18 | def __eq__(self, other):
19 | if type(other) == int:
20 | return self.id == other
21 | return self.id == other.id
22 |
23 | def __repr__(self):
24 | return f'(ID: {self.id})'
25 |
--------------------------------------------------------------------------------
/matanyone/inference/object_manager.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Dict
2 |
3 | import torch
4 | from matanyone.inference.object_info import ObjectInfo
5 |
6 |
7 | class ObjectManager:
8 | """
9 |     Object IDs are immutable. The same ID always represents the same object.
10 |     Temporary IDs are the positions of each object in the tensor. They change as objects get removed.
11 | Temporary IDs start from 1.
12 | """
13 |
14 | def __init__(self):
15 | self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
16 | self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
17 | self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
18 |
19 | self.all_historical_object_ids: List[int] = []
20 |
21 | def _recompute_obj_id_to_obj_mapping(self) -> None:
22 | self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
23 |
24 | def add_new_objects(
25 | self, objects: Union[List[ObjectInfo], ObjectInfo,
26 | List[int]]) -> (List[int], List[int]):
27 | if not isinstance(objects, list):
28 | objects = [objects]
29 |
30 | corresponding_tmp_ids = []
31 | corresponding_obj_ids = []
32 | for obj in objects:
33 | if isinstance(obj, int):
34 | obj = ObjectInfo(id=obj)
35 |
36 | if obj in self.obj_to_tmp_id:
37 | # old object
38 | corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
39 | corresponding_obj_ids.append(obj.id)
40 | else:
41 | # new object
42 | new_obj = ObjectInfo(id=obj.id)
43 |
44 |                 # assign the next temporary id
45 | new_tmp_id = len(self.obj_to_tmp_id) + 1
46 | self.obj_to_tmp_id[new_obj] = new_tmp_id
47 | self.tmp_id_to_obj[new_tmp_id] = new_obj
48 | self.all_historical_object_ids.append(new_obj.id)
49 | corresponding_tmp_ids.append(new_tmp_id)
50 | corresponding_obj_ids.append(new_obj.id)
51 |
52 | self._recompute_obj_id_to_obj_mapping()
53 | assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
54 | return corresponding_tmp_ids, corresponding_obj_ids
55 |
56 | def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
57 | # delete an object or a list of objects
58 | # re-sort the tmp ids
59 | if isinstance(obj_ids_to_remove, int):
60 | obj_ids_to_remove = [obj_ids_to_remove]
61 |
62 | new_tmp_id = 1
63 | total_num_id = len(self.obj_to_tmp_id)
64 |
65 | local_obj_to_tmp_id = {}
66 | local_tmp_to_obj_id = {}
67 |
68 | for tmp_iter in range(1, total_num_id + 1):
69 | obj = self.tmp_id_to_obj[tmp_iter]
70 | if obj.id not in obj_ids_to_remove:
71 | local_obj_to_tmp_id[obj] = new_tmp_id
72 | local_tmp_to_obj_id[new_tmp_id] = obj
73 | new_tmp_id += 1
74 |
75 | self.obj_to_tmp_id = local_obj_to_tmp_id
76 | self.tmp_id_to_obj = local_tmp_to_obj_id
77 | self._recompute_obj_id_to_obj_mapping()
78 |
79 | def purge_inactive_objects(self,
80 | max_missed_detection_count: int) -> (bool, List[int], List[int]):
81 | # remove tmp ids of objects that are removed
82 | obj_id_to_be_deleted = []
83 | tmp_id_to_be_deleted = []
84 | tmp_id_to_keep = []
85 | obj_id_to_keep = []
86 |
87 | for obj in self.obj_to_tmp_id:
88 | if obj.poke_count > max_missed_detection_count:
89 | obj_id_to_be_deleted.append(obj.id)
90 | tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
91 | else:
92 | tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
93 | obj_id_to_keep.append(obj.id)
94 |
95 | purge_activated = len(obj_id_to_be_deleted) > 0
96 | if purge_activated:
97 | self.delete_objects(obj_id_to_be_deleted)
98 | return purge_activated, tmp_id_to_keep, obj_id_to_keep
99 |
100 | def tmp_to_obj_cls(self, mask) -> torch.Tensor:
101 | # remap tmp id cls representation to the true object id representation
102 | new_mask = torch.zeros_like(mask)
103 | for tmp_id, obj in self.tmp_id_to_obj.items():
104 | new_mask[mask == tmp_id] = obj.id
105 | return new_mask
106 |
107 | def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
108 | # returns the mapping in a dict format for saving it with pickle
109 |         return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}
110 |
111 | def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
112 | # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
113 | output = []
114 | for _, obj in self.tmp_id_to_obj.items():
115 | if obj.id not in obj_dict:
116 | raise NotImplementedError
117 | output.append(obj_dict[obj.id])
118 | output = torch.stack(output, dim=dim)
119 | return output
120 |
121 | def make_one_hot(self, cls_mask) -> torch.Tensor:
122 | output = []
123 | for _, obj in self.tmp_id_to_obj.items():
124 | output.append(cls_mask == obj.id)
125 | if len(output) == 0:
126 | output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
127 | else:
128 | output = torch.stack(output, dim=0)
129 | return output
130 |
131 | @property
132 | def all_obj_ids(self) -> List[int]:
133 | return [k.id for k in self.obj_to_tmp_id]
134 |
135 | @property
136 | def num_obj(self) -> int:
137 | return len(self.obj_to_tmp_id)
138 |
139 | def has_all(self, objects: List[int]) -> bool:
140 | for obj in objects:
141 | if obj not in self.obj_to_tmp_id:
142 | return False
143 | return True
144 |
145 | def find_object_by_id(self, obj_id) -> ObjectInfo:
146 | return self.obj_id_to_obj[obj_id]
147 |
148 | def find_tmp_by_id(self, obj_id) -> int:
149 | return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
150 |
--------------------------------------------------------------------------------
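
A minimal sketch of the temporary-id bookkeeping described in the class docstring: object ids never change, while temporary ids are re-packed from 1 whenever objects are deleted.

from matanyone.inference.object_manager import ObjectManager

mgr = ObjectManager()
tmp_ids, obj_ids = mgr.add_new_objects([3, 7, 10])
print(tmp_ids, obj_ids)        # [1, 2, 3] [3, 7, 10]

mgr.delete_objects(7)          # remaining objects are re-packed from tmp id 1
print(mgr.all_obj_ids)         # [3, 10]
print(mgr.find_tmp_by_id(10))  # 2 -- the temporary id moved, the object id did not
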
/matanyone/inference/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/inference/utils/__init__.py
--------------------------------------------------------------------------------
/matanyone/inference/utils/args_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from omegaconf import DictConfig
3 |
4 | log = logging.getLogger()
5 |
6 |
7 | def get_dataset_cfg(cfg: DictConfig):
8 | dataset_name = cfg.dataset
9 | data_cfg = cfg.datasets[dataset_name]
10 |
11 | potential_overrides = [
12 | 'image_directory',
13 | 'mask_directory',
14 | 'json_directory',
15 | 'size',
16 | 'save_all',
17 | 'use_all_masks',
18 | 'use_long_term',
19 | 'mem_every',
20 | ]
21 |
22 | for override in potential_overrides:
23 | if cfg[override] is not None:
24 | log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
25 | data_cfg[override] = cfg[override]
26 |     # escalate all potential overrides to the top-level config
27 | if override in data_cfg:
28 | cfg[override] = data_cfg[override]
29 |
30 | return data_cfg
31 |
--------------------------------------------------------------------------------
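
A toy config showing how get_dataset_cfg behaves: non-None top-level values override the per-dataset block, and the (possibly overridden) dataset values are then escalated back to the top level. The keys and values below are placeholders, not the project's shipped defaults.

from omegaconf import OmegaConf
from matanyone.inference.utils.args_utils import get_dataset_cfg

cfg = OmegaConf.create({
    'dataset': 'd0',
    'datasets': {'d0': {'image_directory': 'imgs/', 'mask_directory': 'masks/',
                        'json_directory': None, 'size': 480, 'save_all': False,
                        'use_all_masks': False, 'use_long_term': False, 'mem_every': 5}},
    # top-level (e.g. command-line) values; None means "no override"
    'image_directory': None, 'mask_directory': None, 'json_directory': None,
    'size': 1080, 'save_all': None, 'use_all_masks': None,
    'use_long_term': None, 'mem_every': None,
})
data_cfg = get_dataset_cfg(cfg)
print(data_cfg.size, cfg.image_directory)  # 1080 imgs/
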
/matanyone/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/__init__.py
--------------------------------------------------------------------------------
/matanyone/model/aux_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | For computing auxiliary outputs for auxiliary losses
3 | """
4 | from typing import Dict
5 | from omegaconf import DictConfig
6 | import torch
7 | import torch.nn as nn
8 |
9 | from matanyone.model.group_modules import GConv2d
10 | from matanyone.utils.tensor_utils import aggregate
11 |
12 |
13 | class LinearPredictor(nn.Module):
14 | def __init__(self, x_dim: int, pix_dim: int):
15 | super().__init__()
16 | self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
17 |
18 | def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
19 |         # pix_feat: B*pix_dim*H*W
20 | # x: B*num_objects*x_dim*H*W
21 | num_objects = x.shape[1]
22 | x = self.projection(x)
23 |
24 | pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
25 | logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
26 | return logits
27 |
28 |
29 | class DirectPredictor(nn.Module):
30 | def __init__(self, x_dim: int):
31 | super().__init__()
32 | self.projection = GConv2d(x_dim, 1, kernel_size=1)
33 |
34 | def forward(self, x: torch.Tensor) -> torch.Tensor:
35 | # x: B*num_objects*x_dim*H*W
36 | logits = self.projection(x).squeeze(2)
37 | return logits
38 |
39 |
40 | class AuxComputer(nn.Module):
41 | def __init__(self, cfg: DictConfig):
42 | super().__init__()
43 |
44 | use_sensory_aux = cfg.model.aux_loss.sensory.enabled
45 | self.use_query_aux = cfg.model.aux_loss.query.enabled
46 | self.use_sensory_aux = use_sensory_aux
47 |
48 | sensory_dim = cfg.model.sensory_dim
49 | embed_dim = cfg.model.embed_dim
50 |
51 | if use_sensory_aux:
52 | self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
53 |
54 | def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
55 | prob = torch.sigmoid(logits)
56 | if selector is not None:
57 | prob = prob * selector
58 | logits = aggregate(prob, dim=1)
59 | return logits
60 |
61 | def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
62 | selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
63 | sensory = aux_input['sensory']
64 | q_logits = aux_input['q_logits']
65 |
66 | aux_output = {}
67 | aux_output['attn_mask'] = aux_input['attn_mask']
68 |
69 | if self.use_sensory_aux:
70 | # B*num_objects*H*W
71 | logits = self.sensory_aux(pix_feat, sensory)
72 | aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
73 | if self.use_query_aux:
74 | # B*num_objects*num_levels*H*W
75 | aux_output['q_logits'] = self._aggregate_with_selector(
76 | torch.stack(q_logits, dim=2),
77 | selector.unsqueeze(2) if selector is not None else None)
78 |
79 | return aux_output
80 |
81 | def compute_mask(self, aux_input: Dict[str, torch.Tensor],
82 | selector: torch.Tensor) -> Dict[str, torch.Tensor]:
83 | # sensory = aux_input['sensory']
84 | q_logits = aux_input['q_logits']
85 |
86 | aux_output = {}
87 |
88 | # B*num_objects*num_levels*H*W
89 | aux_output['q_logits'] = self._aggregate_with_selector(
90 | torch.stack(q_logits, dim=2),
91 | selector.unsqueeze(2) if selector is not None else None)
92 |
93 | return aux_output
--------------------------------------------------------------------------------
/matanyone/model/big_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | big_modules.py - This file stores higher-level network blocks.
3 |
4 | x - usually denotes features that are shared between objects.
5 | g - usually denotes features that are not shared between objects
6 | with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
7 |
8 | The trailing number of a variable usually denotes the stride
9 | """
10 |
11 | from typing import Iterable
12 | from omegaconf import DictConfig
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18 | from matanyone.model.utils import resnet
19 | from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
20 |
21 | class UncertPred(nn.Module):
22 | def __init__(self, model_cfg: DictConfig):
23 | super().__init__()
24 | self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
25 | self.bn1 = nn.BatchNorm2d(64)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
28 | self.bn2 = nn.BatchNorm2d(32)
29 | self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
30 |
31 | def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
32 | last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
33 | x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
34 | x = self.conv1x1_v2(x)
35 | x = self.bn1(x)
36 | x = self.relu(x)
37 | x = self.conv3x3(x)
38 | x = self.bn2(x)
39 | x = self.relu(x)
40 | x = self.conv3x3_out(x)
41 | return x
42 |
43 | # override the default train() to freeze BN statistics
44 | def train(self, mode=True):
45 | self.training = False
46 | for module in self.children():
47 | module.train(False)
48 | return self
49 |
50 | class PixelEncoder(nn.Module):
51 | def __init__(self, model_cfg: DictConfig):
52 | super().__init__()
53 |
54 | self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
55 |         # use model_cfg.pretrained_resnet if it is set in the config,
56 |         # otherwise default to True
57 |         is_pretrained_resnet = getattr(model_cfg, "pretrained_resnet", True)
58 | if self.is_resnet:
59 | if model_cfg.pixel_encoder.type == 'resnet18':
60 | network = resnet.resnet18(pretrained=is_pretrained_resnet)
61 | elif model_cfg.pixel_encoder.type == 'resnet50':
62 | network = resnet.resnet50(pretrained=is_pretrained_resnet)
63 | else:
64 | raise NotImplementedError
65 | self.conv1 = network.conv1
66 | self.bn1 = network.bn1
67 | self.relu = network.relu
68 | self.maxpool = network.maxpool
69 |
70 | self.res2 = network.layer1
71 | self.layer2 = network.layer2
72 | self.layer3 = network.layer3
73 | else:
74 | raise NotImplementedError
75 |
76 |     def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
77 | f1 = x
78 | x = self.conv1(x)
79 | x = self.bn1(x)
80 | x = self.relu(x)
81 | f2 = x
82 | x = self.maxpool(x)
83 | f4 = self.res2(x)
84 | f8 = self.layer2(f4)
85 | f16 = self.layer3(f8)
86 |
87 | return f16, f8, f4, f2, f1
88 |
89 | # override the default train() to freeze BN statistics
90 | def train(self, mode=True):
91 | self.training = False
92 | for module in self.children():
93 | module.train(False)
94 | return self
95 |
96 |
97 | class KeyProjection(nn.Module):
98 | def __init__(self, model_cfg: DictConfig):
99 | super().__init__()
100 | in_dim = model_cfg.pixel_encoder.ms_dims[0]
101 | mid_dim = model_cfg.pixel_dim
102 | key_dim = model_cfg.key_dim
103 |
104 | self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
105 | self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
106 | # shrinkage
107 | self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
108 | # selection
109 | self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
110 |
111 | nn.init.orthogonal_(self.key_proj.weight.data)
112 | nn.init.zeros_(self.key_proj.bias.data)
113 |
114 | def forward(self, x: torch.Tensor, *, need_s: bool,
115 | need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
116 | x = self.pix_feat_proj(x)
117 | shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
118 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
119 |
120 | return self.key_proj(x), shrinkage, selection
121 |
122 |
123 | class MaskEncoder(nn.Module):
124 | def __init__(self, model_cfg: DictConfig, single_object=False):
125 | super().__init__()
126 | pixel_dim = model_cfg.pixel_dim
127 | value_dim = model_cfg.value_dim
128 | sensory_dim = model_cfg.sensory_dim
129 | final_dim = model_cfg.mask_encoder.final_dim
130 |
131 | self.single_object = single_object
132 | extra_dim = 1 if single_object else 2
133 |
134 |         # use model_cfg.pretrained_resnet if it is set in the config,
135 |         # otherwise default to True
136 |         is_pretrained_resnet = getattr(model_cfg, "pretrained_resnet", True)
137 | if model_cfg.mask_encoder.type == 'resnet18':
138 | network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
139 | elif model_cfg.mask_encoder.type == 'resnet50':
140 | network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
141 | else:
142 | raise NotImplementedError
143 | self.conv1 = network.conv1
144 | self.bn1 = network.bn1
145 | self.relu = network.relu
146 | self.maxpool = network.maxpool
147 |
148 | self.layer1 = network.layer1
149 | self.layer2 = network.layer2
150 | self.layer3 = network.layer3
151 |
152 | self.distributor = MainToGroupDistributor()
153 | self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
154 |
155 | self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
156 |
157 | def forward(self,
158 | image: torch.Tensor,
159 | pix_feat: torch.Tensor,
160 | sensory: torch.Tensor,
161 | masks: torch.Tensor,
162 | others: torch.Tensor,
163 | *,
164 | deep_update: bool = True,
165 | chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
166 | # ms_features are from the key encoder
167 | # we only use the first one (lowest resolution), following XMem
168 | if self.single_object:
169 | g = masks.unsqueeze(2)
170 | else:
171 | g = torch.stack([masks, others], dim=2)
172 |
173 | g = self.distributor(image, g)
174 |
175 | batch_size, num_objects = g.shape[:2]
176 | if chunk_size < 1 or chunk_size >= num_objects:
177 | chunk_size = num_objects
178 | fast_path = True
179 | new_sensory = sensory
180 | else:
181 | if deep_update:
182 | new_sensory = torch.empty_like(sensory)
183 | else:
184 | new_sensory = sensory
185 | fast_path = False
186 |
187 | # chunk-by-chunk inference
188 | all_g = []
189 | for i in range(0, num_objects, chunk_size):
190 | if fast_path:
191 | g_chunk = g
192 | else:
193 | g_chunk = g[:, i:i + chunk_size]
194 | actual_chunk_size = g_chunk.shape[1]
195 | g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
196 |
197 | g_chunk = self.conv1(g_chunk)
198 | g_chunk = self.bn1(g_chunk) # 1/2, 64
199 | g_chunk = self.maxpool(g_chunk) # 1/4, 64
200 | g_chunk = self.relu(g_chunk)
201 |
202 | g_chunk = self.layer1(g_chunk) # 1/4
203 | g_chunk = self.layer2(g_chunk) # 1/8
204 | g_chunk = self.layer3(g_chunk) # 1/16
205 |
206 | g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
207 | g_chunk = self.fuser(pix_feat, g_chunk)
208 | all_g.append(g_chunk)
209 | if deep_update:
210 | if fast_path:
211 | new_sensory = self.sensory_update(g_chunk, sensory)
212 | else:
213 | new_sensory[:, i:i + chunk_size] = self.sensory_update(
214 | g_chunk, sensory[:, i:i + chunk_size])
215 | g = torch.cat(all_g, dim=1)
216 |
217 | return g, new_sensory
218 |
219 | # override the default train() to freeze BN statistics
220 | def train(self, mode=True):
221 | self.training = False
222 | for module in self.children():
223 | module.train(False)
224 | return self
225 |
226 |
227 | class PixelFeatureFuser(nn.Module):
228 | def __init__(self, model_cfg: DictConfig, single_object=False):
229 | super().__init__()
230 | value_dim = model_cfg.value_dim
231 | sensory_dim = model_cfg.sensory_dim
232 | pixel_dim = model_cfg.pixel_dim
233 | embed_dim = model_cfg.embed_dim
234 | self.single_object = single_object
235 |
236 | self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
237 | if self.single_object:
238 | self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
239 | else:
240 | self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
241 |
242 | def forward(self,
243 | pix_feat: torch.Tensor,
244 | pixel_memory: torch.Tensor,
245 | sensory_memory: torch.Tensor,
246 | last_mask: torch.Tensor,
247 | last_others: torch.Tensor,
248 | *,
249 | chunk_size: int = -1) -> torch.Tensor:
250 | batch_size, num_objects = pixel_memory.shape[:2]
251 |
252 | if self.single_object:
253 | last_mask = last_mask.unsqueeze(2)
254 | else:
255 | last_mask = torch.stack([last_mask, last_others], dim=2)
256 |
257 | if chunk_size < 1:
258 | chunk_size = num_objects
259 |
260 | # chunk-by-chunk inference
261 | all_p16 = []
262 | for i in range(0, num_objects, chunk_size):
263 | sensory_readout = self.sensory_compress(
264 | torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
265 | p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
266 | p16 = self.fuser(pix_feat, p16)
267 | all_p16.append(p16)
268 | p16 = torch.cat(all_p16, dim=1)
269 |
270 | return p16
271 |
272 |
273 | class MaskDecoder(nn.Module):
274 | def __init__(self, model_cfg: DictConfig):
275 | super().__init__()
276 | embed_dim = model_cfg.embed_dim
277 | sensory_dim = model_cfg.sensory_dim
278 | ms_image_dims = model_cfg.pixel_encoder.ms_dims
279 | up_dims = model_cfg.mask_decoder.up_dims
280 |
281 | assert embed_dim == up_dims[0]
282 |
283 | self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
284 | sensory_dim)
285 |
286 | self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
287 | self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
288 | self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
289 |         # additional upsampling stages for full-resolution alpha matte prediction
290 | self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
291 | self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])
292 |
293 | self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
294 | self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
295 |
296 | def forward(self,
297 | ms_image_feat: Iterable[torch.Tensor],
298 | memory_readout: torch.Tensor,
299 | sensory: torch.Tensor,
300 | *,
301 | chunk_size: int = -1,
302 | update_sensory: bool = True,
303 | seg_pass: bool = False,
304 | last_mask=None,
305 | sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):
306 |
307 | batch_size, num_objects = memory_readout.shape[:2]
308 | f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
309 | if chunk_size < 1 or chunk_size >= num_objects:
310 | chunk_size = num_objects
311 | fast_path = True
312 | new_sensory = sensory
313 | else:
314 | if update_sensory:
315 | new_sensory = torch.empty_like(sensory)
316 | else:
317 | new_sensory = sensory
318 | fast_path = False
319 |
320 | # chunk-by-chunk inference
321 | all_logits = []
322 | for i in range(0, num_objects, chunk_size):
323 | if fast_path:
324 | p16 = memory_readout
325 | else:
326 | p16 = memory_readout[:, i:i + chunk_size]
327 | actual_chunk_size = p16.shape[1]
328 |
329 | p8 = self.up_16_8(p16, f8)
330 | p4 = self.up_8_4(p8, f4)
331 | p2 = self.up_4_2(p4, f2)
332 | p1 = self.up_2_1(p2, f1)
333 | with torch.amp.autocast("cuda",enabled=False):
334 | if seg_pass:
335 | if last_mask is not None:
336 | res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
337 | if sigmoid_residual:
338 | res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
339 | logits = last_mask + res
340 | else:
341 | logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
342 | else:
343 | if last_mask is not None:
344 | res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
345 | if sigmoid_residual:
346 | res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
347 | logits = last_mask + res
348 | else:
349 | logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
350 | ## SensoryUpdater_fullscale
351 | if update_sensory:
352 | p1 = torch.cat(
353 | [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
354 | if fast_path:
355 | new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
356 | else:
357 | new_sensory[:,
358 | i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1],
359 | sensory[:,
360 | i:i + chunk_size])
361 | all_logits.append(logits)
362 | logits = torch.cat(all_logits, dim=0)
363 | logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
364 |
365 | return new_sensory, logits
366 |
--------------------------------------------------------------------------------
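
A shape sketch for the uncertainty head (UncertPred): it consumes stride-16 features from the previous and current frame, the previous full-resolution mask (resized internally), and the memory-value difference, and returns a one-channel logit map at feature resolution. The dimensions below are placeholders, not the values from the shipped model config.

import torch
from omegaconf import OmegaConf
from matanyone.model.big_modules import UncertPred

cfg = OmegaConf.create({'pixel_dim': 256, 'value_dim': 256})  # placeholder dims
head = UncertPred(cfg).eval()

b, h, w = 1, 30, 54                            # stride-16 feature resolution
last_feat = torch.randn(b, 256, h, w)
cur_feat = torch.randn(b, 256, h, w)
last_mask = torch.rand(b, 1, h * 16, w * 16)   # full-resolution mask, resized inside
mem_val_diff = torch.randn(b, 256, h, w)

with torch.no_grad():
    logits = head(last_feat, cur_feat, last_mask, mem_val_diff)
print(logits.shape)                            # torch.Size([1, 1, 30, 54])
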
/matanyone/model/channel_attn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class CAResBlock(nn.Module):
8 | def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
9 | super().__init__()
10 | self.residual = residual
11 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
12 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
13 |
14 | t = int((abs(math.log2(out_dim)) + 1) // 2)
15 | k = t if t % 2 else t + 1
16 | self.pool = nn.AdaptiveAvgPool2d(1)
17 | self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
18 |
19 | if self.residual:
20 | if in_dim == out_dim:
21 | self.downsample = nn.Identity()
22 | else:
23 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | r = x
27 | x = self.conv1(F.relu(x))
28 | x = self.conv2(F.relu(x))
29 |
30 | b, c = x.shape[:2]
31 | w = self.pool(x).view(b, 1, c)
32 | w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
33 |
34 | if self.residual:
35 | x = x * w + self.downsample(r)
36 | else:
37 | x = x * w
38 |
39 | return x
40 |
--------------------------------------------------------------------------------
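
CAResBlock picks its 1-D channel-attention kernel size from the output width in ECA style (odd, growing roughly with log2 of the channel count). A quick check of that rule together with the block's output shape; the sizes are illustrative only.

import math
import torch
from matanyone.model.channel_attn import CAResBlock

def eca_kernel(out_dim: int) -> int:
    t = int((abs(math.log2(out_dim)) + 1) // 2)
    return t if t % 2 else t + 1

print(eca_kernel(64), eca_kernel(256))  # 3 5

block = CAResBlock(64, 256)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)                   # torch.Size([2, 256, 32, 32])
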
/matanyone/model/group_modules.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from matanyone.model.channel_attn import CAResBlock
6 |
7 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8 | align_corners: bool) -> torch.Tensor:
9 | batch_size, num_objects = g.shape[:2]
10 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
11 | scale_factor=ratio,
12 | mode=mode,
13 | align_corners=align_corners)
14 | g = g.view(batch_size, num_objects, *g.shape[1:])
15 | return g
16 |
17 |
18 | def upsample_groups(g: torch.Tensor,
19 | ratio: float = 2,
20 | mode: str = 'bilinear',
21 | align_corners: bool = False) -> torch.Tensor:
22 | return interpolate_groups(g, ratio, mode, align_corners)
23 |
24 |
25 | def downsample_groups(g: torch.Tensor,
26 | ratio: float = 1 / 2,
27 | mode: str = 'area',
28 | align_corners: bool = None) -> torch.Tensor:
29 | return interpolate_groups(g, ratio, mode, align_corners)
30 |
31 |
32 | class GConv2d(nn.Conv2d):
33 | def forward(self, g: torch.Tensor) -> torch.Tensor:
34 | batch_size, num_objects = g.shape[:2]
35 | g = super().forward(g.flatten(start_dim=0, end_dim=1))
36 | return g.view(batch_size, num_objects, *g.shape[1:])
37 |
38 |
39 | class GroupResBlock(nn.Module):
40 | def __init__(self, in_dim: int, out_dim: int):
41 | super().__init__()
42 |
43 | if in_dim == out_dim:
44 | self.downsample = nn.Identity()
45 | else:
46 | self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
47 |
48 | self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
49 | self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
50 |
51 | def forward(self, g: torch.Tensor) -> torch.Tensor:
52 | out_g = self.conv1(F.relu(g))
53 | out_g = self.conv2(F.relu(out_g))
54 |
55 | g = self.downsample(g)
56 |
57 | return out_g + g
58 |
59 |
60 | class MainToGroupDistributor(nn.Module):
61 | def __init__(self,
62 | x_transform: Optional[nn.Module] = None,
63 | g_transform: Optional[nn.Module] = None,
64 | method: str = 'cat',
65 | reverse_order: bool = False):
66 | super().__init__()
67 |
68 | self.x_transform = x_transform
69 | self.g_transform = g_transform
70 | self.method = method
71 | self.reverse_order = reverse_order
72 |
73 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
74 | num_objects = g.shape[1]
75 |
76 | if self.x_transform is not None:
77 | x = self.x_transform(x)
78 |
79 | if self.g_transform is not None:
80 | g = self.g_transform(g)
81 |
82 | if not skip_expand:
83 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
84 | if self.method == 'cat':
85 | if self.reverse_order:
86 | g = torch.cat([g, x], 2)
87 | else:
88 | g = torch.cat([x, g], 2)
89 | elif self.method == 'add':
90 | g = x + g
91 | elif self.method == 'mulcat':
92 | g = torch.cat([x * g, g], dim=2)
93 | elif self.method == 'muladd':
94 | g = x * g + g
95 | else:
96 | raise NotImplementedError
97 |
98 | return g
99 |
100 |
101 | class GroupFeatureFusionBlock(nn.Module):
102 | def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
103 | super().__init__()
104 |
105 | x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
106 | g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
107 |
108 | self.distributor = MainToGroupDistributor(x_transform=x_transform,
109 | g_transform=g_transform,
110 | method='add')
111 | self.block1 = CAResBlock(out_dim, out_dim)
112 | self.block2 = CAResBlock(out_dim, out_dim)
113 |
114 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
115 | batch_size, num_objects = g.shape[:2]
116 |
117 | g = self.distributor(x, g)
118 |
119 | g = g.flatten(start_dim=0, end_dim=1)
120 |
121 | g = self.block1(g)
122 | g = self.block2(g)
123 |
124 | g = g.view(batch_size, num_objects, *g.shape[1:])
125 |
126 | return g
--------------------------------------------------------------------------------
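
The group modules handle the object axis by folding it into the batch and restoring it afterwards; GConv2d is the smallest example of the pattern. A minimal shape sketch with made-up dimensions:

import torch
from matanyone.model.group_modules import GConv2d, upsample_groups

conv = GConv2d(16, 32, kernel_size=3, padding=1)
g = torch.randn(2, 3, 16, 24, 24)        # batch=2, num_objects=3
print(conv(g).shape)                     # torch.Size([2, 3, 32, 24, 24])
print(upsample_groups(g).shape)          # torch.Size([2, 3, 16, 48, 48])
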
/matanyone/model/matanyone.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Iterable, Optional, Tuple
2 | import logging
3 | from omegaconf import DictConfig
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from omegaconf import OmegaConf
8 | from huggingface_hub import PyTorchModelHubMixin
9 |
10 | from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
11 | from matanyone.model.aux_modules import AuxComputer
12 | from matanyone.model.utils.memory_utils import get_affinity, readout
13 | from matanyone.model.transformer.object_transformer import QueryTransformer
14 | from matanyone.model.transformer.object_summarizer import ObjectSummarizer
15 | from matanyone.utils.tensor_utils import aggregate
16 |
17 | log = logging.getLogger()
18 | class MatAnyone(nn.Module,
19 | PyTorchModelHubMixin,
20 | library_name="matanyone",
21 | repo_url="https://github.com/pq-yang/MatAnyone",
22 | coders={
23 | DictConfig: (
24 | lambda x: OmegaConf.to_container(x),
25 | lambda data: OmegaConf.create(data),
26 | )
27 | },
28 | ):
29 |
30 | def __init__(self, cfg: DictConfig, *, single_object=False):
31 | super().__init__()
32 | self.cfg = cfg
33 | model_cfg = cfg.model
34 | self.ms_dims = model_cfg.pixel_encoder.ms_dims
35 | self.key_dim = model_cfg.key_dim
36 | self.value_dim = model_cfg.value_dim
37 | self.sensory_dim = model_cfg.sensory_dim
38 | self.pixel_dim = model_cfg.pixel_dim
39 | self.embed_dim = model_cfg.embed_dim
40 | self.single_object = single_object
41 |
42 | log.info(f'Single object: {self.single_object}')
43 |
44 | self.pixel_encoder = PixelEncoder(model_cfg)
45 | self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
46 | self.key_proj = KeyProjection(model_cfg)
47 | self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
48 | self.mask_decoder = MaskDecoder(model_cfg)
49 | self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
50 | self.object_transformer = QueryTransformer(model_cfg)
51 | self.object_summarizer = ObjectSummarizer(model_cfg)
52 | self.aux_computer = AuxComputer(cfg)
53 | self.temp_sparity = UncertPred(model_cfg)
54 |
55 | self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
56 | self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
57 |
58 | def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
59 | # for each object, return the sum of masks of all other objects
60 | if self.single_object:
61 | return None
62 |
63 | num_objects = masks.shape[1]
64 | if num_objects >= 1:
65 | others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
66 | else:
67 | others = torch.zeros_like(masks)
68 | return others
69 |
70 | def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
71 | logits = self.temp_sparity(last_frame_feat=last_pix_feat,
72 | cur_frame_feat=cur_pix_feat,
73 | last_mask=last_mask,
74 | mem_val_diff=mem_val_diff)
75 |
76 | prob = torch.sigmoid(logits)
77 | mask = (prob > 0) + 0
78 |
79 | uncert_output = {"logits": logits,
80 | "prob": prob,
81 | "mask": mask}
82 |
83 | return uncert_output
84 |
85 | def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
86 | image = (image - self.pixel_mean) / self.pixel_std
87 | ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
88 | return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
89 |
90 | def encode_mask(
91 | self,
92 | image: torch.Tensor,
93 | ms_features: List[torch.Tensor],
94 | sensory: torch.Tensor,
95 | masks: torch.Tensor,
96 | *,
97 | deep_update: bool = True,
98 | chunk_size: int = -1,
99 | need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
100 | image = (image - self.pixel_mean) / self.pixel_std
101 | others = self._get_others(masks)
102 | mask_value, new_sensory = self.mask_encoder(image,
103 | ms_features,
104 | sensory,
105 | masks,
106 | others,
107 | deep_update=deep_update,
108 | chunk_size=chunk_size)
109 | object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
110 | return mask_value, new_sensory, object_summaries, object_logits
111 |
112 | def transform_key(self,
113 | final_pix_feat: torch.Tensor,
114 | *,
115 | need_sk: bool = True,
116 | need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
117 | key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
118 | return key, shrinkage, selection
119 |
120 | # Used in training only.
121 | # This step is replaced by MemoryManager in test time
122 | def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
123 | memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
124 | msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
125 | sensory: torch.Tensor, last_mask: torch.Tensor,
126 | selector: torch.Tensor, uncert_output=None, seg_pass=False,
127 | last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
128 | """
129 | query_key : B * CK * H * W
130 | query_selection : B * CK * H * W
131 | memory_key : B * CK * T * H * W
132 | memory_shrinkage: B * 1 * T * H * W
133 | msk_value : B * num_objects * CV * T * H * W
134 | obj_memory : B * num_objects * T * num_summaries * C
135 | pixel_feature : B * C * H * W
136 | """
137 | batch_size, num_objects = msk_value.shape[:2]
138 |
139 | uncert_mask = uncert_output["mask"] if uncert_output is not None else None
140 |
141 | # read using visual attention
142 | with torch.amp.autocast("cuda",enabled=False):
143 | affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
144 | query_selection.float(), uncert_mask=uncert_mask)
145 |
146 | msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
147 |
148 | # B * (num_objects*CV) * H * W
149 | pixel_readout = readout(affinity, msk_value, uncert_mask)
150 | pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
151 | *pixel_readout.shape[-2:])
152 |
153 | uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1])
154 | uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
155 | pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob)
156 |
157 | pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
158 |
159 |
160 | # read from query transformer
161 | mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
162 |
163 | aux_output = {
164 | 'sensory': sensory,
165 | 'q_logits': aux_features['logits'] if aux_features else None,
166 | 'attn_mask': aux_features['attn_mask'] if aux_features else None,
167 | }
168 |
169 | return mem_readout, aux_output, uncert_output
170 |
171 | def read_first_frame_memory(self, pixel_readout,
172 | obj_memory: torch.Tensor, pix_feat: torch.Tensor,
173 | sensory: torch.Tensor, last_mask: torch.Tensor,
174 | selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
175 | """
176 | query_key : B * CK * H * W
177 | query_selection : B * CK * H * W
178 | memory_key : B * CK * T * H * W
179 | memory_shrinkage: B * 1 * T * H * W
180 | msk_value : B * num_objects * CV * T * H * W
181 | obj_memory : B * num_objects * T * num_summaries * C
182 | pixel_feature : B * C * H * W
183 | """
184 |
185 | pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
186 |
187 | # read from query transformer
188 | mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
189 |
190 | aux_output = {
191 | 'sensory': sensory,
192 | 'q_logits': aux_features['logits'] if aux_features else None,
193 | 'attn_mask': aux_features['attn_mask'] if aux_features else None,
194 | }
195 |
196 | return mem_readout, aux_output
197 |
198 | def pixel_fusion(self,
199 | pix_feat: torch.Tensor,
200 | pixel: torch.Tensor,
201 | sensory: torch.Tensor,
202 | last_mask: torch.Tensor,
203 | *,
204 | chunk_size: int = -1) -> torch.Tensor:
205 | last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
206 | last_others = self._get_others(last_mask)
207 | fused = self.pixel_fuser(pix_feat,
208 | pixel,
209 | sensory,
210 | last_mask,
211 | last_others,
212 | chunk_size=chunk_size)
213 | return fused
214 |
215 | def readout_query(self,
216 | pixel_readout,
217 | obj_memory,
218 | *,
219 | selector=None,
220 | need_weights=False,
221 | seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
222 | return self.object_transformer(pixel_readout,
223 | obj_memory,
224 | selector=selector,
225 | need_weights=need_weights,
226 | seg_pass=seg_pass)
227 |
228 | def segment(self,
229 | ms_image_feat: List[torch.Tensor],
230 | memory_readout: torch.Tensor,
231 | sensory: torch.Tensor,
232 | *,
233 |                 selector: Optional[torch.Tensor] = None,
234 | chunk_size: int = -1,
235 | update_sensory: bool = True,
236 | seg_pass: bool = False,
237 | clamp_mat: bool = True,
238 | last_mask=None,
239 | sigmoid_residual=False,
240 | seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
241 | """
242 | multi_scale_features is from the key encoder for skip-connection
243 | memory_readout is from working/long-term memory
244 | sensory is the sensory memory
245 | last_mask is the mask from the last frame, supplementing sensory memory
246 | selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
247 | during training.
248 | """
249 | #### use mat head for seg data
250 | if seg_mat:
251 | assert seg_pass
252 | seg_pass = False
253 | ####
254 | sensory, logits = self.mask_decoder(ms_image_feat,
255 | memory_readout,
256 | sensory,
257 | chunk_size=chunk_size,
258 | update_sensory=update_sensory,
259 | seg_pass = seg_pass,
260 | last_mask=last_mask,
261 | sigmoid_residual=sigmoid_residual)
262 | if seg_pass:
263 | prob = torch.sigmoid(logits)
264 | if selector is not None:
265 | prob = prob * selector
266 |
267 | # Softmax over all objects[]
268 | logits = aggregate(prob, dim=1)
269 | prob = F.softmax(logits, dim=1)
270 | else:
271 | if clamp_mat:
272 | logits = logits.clamp(0.0, 1.0)
273 | logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
274 | prob = logits
275 |
276 | return sensory, logits, prob
277 |
278 | def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
279 | selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
280 | return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)
281 |
282 | def forward(self, *args, **kwargs):
283 | raise NotImplementedError
284 |
285 | def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
286 | if not self.single_object:
287 | # Map single-object weight to multi-object weight (4->5 out channels in conv1)
288 | for k in list(src_dict.keys()):
289 | if k == 'mask_encoder.conv1.weight':
290 | if src_dict[k].shape[1] == 4:
291 | log.info(f'Converting {k} from single object to multiple objects.')
292 | pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
293 | if not init_as_zero_if_needed:
294 | nn.init.orthogonal_(pads)
295 | log.info(f'Randomly initialized padding for {k}.')
296 | else:
297 | log.info(f'Zero-initialized padding for {k}.')
298 | src_dict[k] = torch.cat([src_dict[k], pads], 1)
299 | elif k == 'pixel_fuser.sensory_compress.weight':
300 | if src_dict[k].shape[1] == self.sensory_dim + 1:
301 | log.info(f'Converting {k} from single object to multiple objects.')
302 | pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
303 | if not init_as_zero_if_needed:
304 | nn.init.orthogonal_(pads)
305 | log.info(f'Randomly initialized padding for {k}.')
306 | else:
307 | log.info(f'Zero-initialized padding for {k}.')
308 | src_dict[k] = torch.cat([src_dict[k], pads], 1)
309 | elif self.single_object:
310 | """
311 | If the model is multiple-object and we are training in single-object,
312 | we strip the last channel of conv1.
313 | This is not supposed to happen in standard training except when users are trying to
314 | finetune a trained model with single object datasets.
315 | """
316 | if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
317 |                 log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object. '
318 |                             'This is not supposed to happen in standard training.')
319 | src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
320 | src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
321 |
322 | for k in src_dict:
323 | if k not in self.state_dict():
324 | log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
325 | for k in self.state_dict():
326 | if k not in src_dict:
327 | log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
328 |
329 | self.load_state_dict(src_dict, strict=False)
330 |
331 | @property
332 | def device(self) -> torch.device:
333 | return self.pixel_mean.device
334 |
--------------------------------------------------------------------------------
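
The "others" channel fed to the mask encoder is simply the clamped sum of all other objects' masks; a standalone re-implementation of the _get_others computation on a toy two-object example:

import torch

# B=1, num_objects=2, 1x2 spatial maps
masks = torch.tensor([[[[1.0, 0.0]],
                       [[0.0, 1.0]]]])
others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
print(others[0, 0])  # tensor([[0., 1.]]) -- what object 0 sees of everyone else
print(others[0, 1])  # tensor([[1., 0.]])
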
/matanyone/model/modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Iterable
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
7 |
8 |
9 | class UpsampleBlock(nn.Module):
10 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
11 | super().__init__()
12 | self.out_conv = ResBlock(in_dim, out_dim)
13 | self.scale_factor = scale_factor
14 |
15 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
16 | g = F.interpolate(in_g,
17 | scale_factor=self.scale_factor,
18 | mode='bilinear')
19 | g = self.out_conv(g)
20 | g = g + skip_f
21 | return g
22 |
23 | class MaskUpsampleBlock(nn.Module):
24 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
25 | super().__init__()
26 | self.distributor = MainToGroupDistributor(method='add')
27 | self.out_conv = GroupResBlock(in_dim, out_dim)
28 | self.scale_factor = scale_factor
29 |
30 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
31 | g = upsample_groups(in_g, ratio=self.scale_factor)
32 | g = self.distributor(skip_f, g)
33 | g = self.out_conv(g)
34 | return g
35 |
36 |
37 | class DecoderFeatureProcessor(nn.Module):
38 | def __init__(self, decoder_dims: List[int], out_dims: List[int]):
39 | super().__init__()
40 | self.transforms = nn.ModuleList([
41 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
42 | ])
43 |
44 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
45 | outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
46 | return outputs
47 |
48 |
49 | # @torch.jit.script
50 | def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
51 | # h: batch_size * num_objects * hidden_dim * h * w
52 | # values: batch_size * num_objects * (hidden_dim*3) * h * w
53 | dim = values.shape[2] // 3
54 | forget_gate = torch.sigmoid(values[:, :, :dim])
55 | update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
56 | new_value = torch.tanh(values[:, :, dim * 2:])
57 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
58 | return new_h
59 |
60 |
61 | class SensoryUpdater_fullscale(nn.Module):
62 | # Used in the decoder, multi-scale feature + GRU
63 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
64 | super().__init__()
65 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
66 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
67 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
68 | self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
69 | self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
70 |
71 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
72 |
73 | nn.init.xavier_normal_(self.transform.weight)
74 |
75 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
76 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
77 | self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
78 | self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
79 | self.g1_conv(downsample_groups(g[4], ratio=1/16))
80 |
81 | with torch.amp.autocast("cuda",enabled=False):
82 | g = g.float()
83 | h = h.float()
84 | values = self.transform(torch.cat([g, h], dim=2))
85 | new_h = _recurrent_update(h, values)
86 |
87 | return new_h
88 |
89 | class SensoryUpdater(nn.Module):
90 | # Used in the decoder, multi-scale feature + GRU
91 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
92 | super().__init__()
93 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
94 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
95 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
96 |
97 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
98 |
99 | nn.init.xavier_normal_(self.transform.weight)
100 |
101 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
102 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
103 | self.g4_conv(downsample_groups(g[2], ratio=1/4))
104 |
105 | with torch.amp.autocast("cuda",enabled=False):
106 | g = g.float()
107 | h = h.float()
108 | values = self.transform(torch.cat([g, h], dim=2))
109 | new_h = _recurrent_update(h, values)
110 |
111 | return new_h
112 |
113 |
114 | class SensoryDeepUpdater(nn.Module):
115 | def __init__(self, f_dim: int, sensory_dim: int):
116 | super().__init__()
117 | self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
118 |
119 | nn.init.xavier_normal_(self.transform.weight)
120 |
121 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
122 | with torch.amp.autocast("cuda",enabled=False):
123 | g = g.float()
124 | h = h.float()
125 | values = self.transform(torch.cat([g, h], dim=2))
126 | new_h = _recurrent_update(h, values)
127 |
128 | return new_h
129 |
130 |
131 | class ResBlock(nn.Module):
132 | def __init__(self, in_dim: int, out_dim: int):
133 | super().__init__()
134 |
135 | if in_dim == out_dim:
136 | self.downsample = nn.Identity()
137 | else:
138 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
139 |
140 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
141 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
142 |
143 | def forward(self, g: torch.Tensor) -> torch.Tensor:
144 | out_g = self.conv1(F.relu(g))
145 | out_g = self.conv2(F.relu(out_g))
146 |
147 | g = self.downsample(g)
148 |
149 | return out_g + g
--------------------------------------------------------------------------------
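
The sensory memory is refreshed with a GRU-style gate: the transform produces 3*hidden_dim channels that are split into forget, update, and candidate values. A shape sketch calling _recurrent_update directly with toy tensors:

import torch
from matanyone.model.modules import _recurrent_update

h = torch.zeros(1, 1, 2, 4, 4)       # batch, num_objects, hidden_dim, H, W
values = torch.randn(1, 1, 6, 4, 4)  # 3 * hidden_dim gating channels
new_h = _recurrent_update(h, values)
print(new_h.shape)                   # torch.Size([1, 1, 2, 4, 4])
# with h = 0 the result reduces to update_gate * tanh(candidate)
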
/matanyone/model/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/transformer/__init__.py
--------------------------------------------------------------------------------
/matanyone/model/transformer/object_summarizer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from omegaconf import DictConfig
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from matanyone.model.transformer.positional_encoding import PositionalEncoding
8 |
9 |
10 | # @torch.jit.script
11 | def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
12 | logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
13 | # value: B*num_objects*H*W*value_dim
14 | # logits: B*num_objects*H*W*num_summaries
15 | # masks: B*num_objects*H*W*num_summaries: 1 if allowed
16 | weights = logits.sigmoid() * masks
17 | # B*num_objects*num_summaries*value_dim
18 | sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
19 | # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
20 | area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
21 |
22 | # B*num_objects*num_summaries*value_dim
23 | return sums, area
24 |
25 |
26 | class ObjectSummarizer(nn.Module):
27 | def __init__(self, model_cfg: DictConfig):
28 | super().__init__()
29 |
30 | this_cfg = model_cfg.object_summarizer
31 | self.value_dim = model_cfg.value_dim
32 | self.embed_dim = this_cfg.embed_dim
33 | self.num_summaries = this_cfg.num_summaries
34 | self.add_pe = this_cfg.add_pe
35 | self.pixel_pe_scale = model_cfg.pixel_pe_scale
36 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
37 |
38 | if self.add_pe:
39 | self.pos_enc = PositionalEncoding(self.embed_dim,
40 | scale=self.pixel_pe_scale,
41 | temperature=self.pixel_pe_temperature)
42 |
43 | self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
44 | self.feature_pred = nn.Sequential(
45 | nn.Linear(self.embed_dim, self.embed_dim),
46 | nn.ReLU(inplace=True),
47 | nn.Linear(self.embed_dim, self.embed_dim),
48 | )
49 | self.weights_pred = nn.Sequential(
50 | nn.Linear(self.embed_dim, self.embed_dim),
51 | nn.ReLU(inplace=True),
52 | nn.Linear(self.embed_dim, self.num_summaries),
53 | )
54 |
55 | def forward(self,
56 | masks: torch.Tensor,
57 | value: torch.Tensor,
58 | need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
59 | # masks: B*num_objects*(H0)*(W0)
60 | # value: B*num_objects*value_dim*H*W
61 | # -> B*num_objects*H*W*value_dim
62 | h, w = value.shape[-2:]
63 | masks = F.interpolate(masks, size=(h, w), mode='area')
64 | masks = masks.unsqueeze(-1)
65 | inv_masks = 1 - masks
66 | repeated_masks = torch.cat([
67 | masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
68 | inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
69 | ],
70 | dim=-1)
71 |
72 | value = value.permute(0, 1, 3, 4, 2)
73 | value = self.input_proj(value)
74 | if self.add_pe:
75 | pe = self.pos_enc(value)
76 | value = value + pe
77 |
78 | with torch.amp.autocast("cuda",enabled=False):
79 | value = value.float()
80 | feature = self.feature_pred(value)
81 | logits = self.weights_pred(value)
82 | sums, area = _weighted_pooling(repeated_masks, feature, logits)
83 |
84 | summaries = torch.cat([sums, area], dim=-1)
85 |
86 | if need_weights:
87 | return summaries, logits
88 | else:
89 | return summaries, None
--------------------------------------------------------------------------------
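
A shape sketch of the masked weighted pooling behind the object summaries: per-summary sigmoid weights are multiplied by the foreground/background masks, then reduced over space. The dimensions below are illustrative only.

import torch
from matanyone.model.transformer.object_summarizer import _weighted_pooling

B, K, H, W, Q, C = 1, 2, 8, 8, 16, 64   # batch, objects, spatial, summaries, channels
masks = torch.ones(B, K, H, W, Q)
value = torch.randn(B, K, H, W, C)
logits = torch.randn(B, K, H, W, Q)

sums, area = _weighted_pooling(masks, value, logits)
print(sums.shape, area.shape)  # torch.Size([1, 2, 16, 64]) torch.Size([1, 2, 16, 1])
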
/matanyone/model/transformer/object_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from omegaconf import DictConfig
3 |
4 | import torch
5 | import torch.nn as nn
6 | from matanyone.model.group_modules import GConv2d
7 | from matanyone.utils.tensor_utils import aggregate
8 | from matanyone.model.transformer.positional_encoding import PositionalEncoding
9 | from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10 |
11 |
12 | class QueryTransformerBlock(nn.Module):
13 | def __init__(self, model_cfg: DictConfig):
14 | super().__init__()
15 |
16 | this_cfg = model_cfg.object_transformer
17 | self.embed_dim = this_cfg.embed_dim
18 | self.num_heads = this_cfg.num_heads
19 | self.num_queries = this_cfg.num_queries
20 | self.ff_dim = this_cfg.ff_dim
21 |
22 | self.read_from_pixel = CrossAttention(self.embed_dim,
23 | self.num_heads,
24 | add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
25 | self.self_attn = SelfAttention(self.embed_dim,
26 | self.num_heads,
27 | add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
28 | self.ffn = FFN(self.embed_dim, self.ff_dim)
29 | self.read_from_query = CrossAttention(self.embed_dim,
30 | self.num_heads,
31 | add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
32 | norm=this_cfg.read_from_query.output_norm)
33 | self.pixel_ffn = PixelFFN(self.embed_dim)
34 |
35 | def forward(
36 | self,
37 | x: torch.Tensor,
38 | pixel: torch.Tensor,
39 | query_pe: torch.Tensor,
40 | pixel_pe: torch.Tensor,
41 | attn_mask: torch.Tensor,
42 | need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
43 | # x: (bs*num_objects)*num_queries*embed_dim
44 | # pixel: bs*num_objects*C*H*W
45 | # query_pe: (bs*num_objects)*num_queries*embed_dim
46 | # pixel_pe: (bs*num_objects)*(H*W)*C
47 | # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
48 |
49 | # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
50 | pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
51 | x, q_weights = self.read_from_pixel(x,
52 | pixel_flat,
53 | query_pe,
54 | pixel_pe,
55 | attn_mask=attn_mask,
56 | need_weights=need_weights)
57 | x = self.self_attn(x, query_pe)
58 | x = self.ffn(x)
59 |
60 | pixel_flat, p_weights = self.read_from_query(pixel_flat,
61 | x,
62 | pixel_pe,
63 | query_pe,
64 | need_weights=need_weights)
65 | pixel = self.pixel_ffn(pixel, pixel_flat)
66 |
67 | if need_weights:
68 | bs, num_objects, _, h, w = pixel.shape
69 | q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
70 | p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
71 | self.num_queries, h, w)
72 |
73 | return x, pixel, q_weights, p_weights
74 |
75 |
76 | class QueryTransformer(nn.Module):
77 | def __init__(self, model_cfg: DictConfig):
78 | super().__init__()
79 |
80 | this_cfg = model_cfg.object_transformer
81 | self.value_dim = model_cfg.value_dim
82 | self.embed_dim = this_cfg.embed_dim
83 | self.num_heads = this_cfg.num_heads
84 | self.num_queries = this_cfg.num_queries
85 |
86 | # query initialization and embedding
87 | self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
88 | self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
89 |
90 | # projection from object summaries to query initialization and embedding
91 | self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
92 | self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
93 |
94 | self.pixel_pe_scale = model_cfg.pixel_pe_scale
95 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
96 | self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
97 | self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
98 | self.spatial_pe = PositionalEncoding(self.embed_dim,
99 | scale=self.pixel_pe_scale,
100 | temperature=self.pixel_pe_temperature,
101 | channel_last=False,
102 | transpose_output=True)
103 |
104 | # transformer blocks
105 | self.num_blocks = this_cfg.num_blocks
106 | self.blocks = nn.ModuleList(
107 | QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
108 | self.mask_pred = nn.ModuleList(
109 | nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
110 | for _ in range(self.num_blocks + 1))
111 |
112 | self.act = nn.ReLU(inplace=True)
113 |
114 | def forward(self,
115 | pixel: torch.Tensor,
116 | obj_summaries: torch.Tensor,
117 | selector: Optional[torch.Tensor] = None,
118 | need_weights: bool = False,
119 | seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
120 | # pixel: B*num_objects*embed_dim*H*W
121 | # obj_summaries: B*num_objects*T*num_queries*embed_dim
122 | T = obj_summaries.shape[2]
123 | bs, num_objects, _, H, W = pixel.shape
124 |
125 | # normalize object values
126 | # the last channel is the cumulative area of the object
127 | obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
128 | self.embed_dim + 1)
129 | # sum over time
130 | # during inference, T=1 as we already did streaming average in memory_manager
131 | obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
132 | obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
133 | obj_values = obj_sums / (obj_area + 1e-4)
134 | obj_init = self.summary_to_query_init(obj_values)
135 | obj_emb = self.summary_to_query_emb(obj_values)
136 |
137 | # positional embeddings for object queries
138 | query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
139 | query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
140 |
141 | # positional embeddings for pixel features
142 | pixel_init = self.pixel_init_proj(pixel)
143 | pixel_emb = self.pixel_emb_proj(pixel)
144 | pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
145 | pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
146 | pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
147 |
148 | pixel = pixel_init
149 |
150 | # run the transformer
151 | aux_features = {'logits': []}
152 |
153 | # first aux output
154 | aux_logits = self.mask_pred[0](pixel).squeeze(2)
155 | attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
156 | aux_features['logits'].append(aux_logits)
157 | for i in range(self.num_blocks):
158 | query, pixel, q_weights, p_weights = self.blocks[i](query,
159 | pixel,
160 | query_emb,
161 | pixel_pe,
162 | attn_mask,
163 | need_weights=need_weights)
164 |
165 | if self.training or i <= self.num_blocks - 1 or need_weights:
166 | aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
167 | attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
168 | aux_features['logits'].append(aux_logits)
169 |
170 | aux_features['q_weights'] = q_weights # last layer only
171 | aux_features['p_weights'] = p_weights # last layer only
172 |
173 | if self.training:
174 | # no need to save all heads
175 | aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
176 | self.num_queries, H, W)[:, :, 0]
177 |
178 | return pixel, aux_features
179 |
180 | def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
181 | # logits: batch_size*num_objects*H*W
182 | # selector: batch_size*num_objects*1*1
183 | # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
184 | # where True means the attention is blocked
185 |
186 | if selector is None:
187 | prob = logits.sigmoid()
188 | else:
189 | prob = logits.sigmoid() * selector
190 | logits = aggregate(prob, dim=1)
191 |
192 | is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
193 | foreground_mask = is_foreground.bool().flatten(start_dim=2)
194 | inv_foreground_mask = ~foreground_mask
195 | inv_background_mask = foreground_mask
196 |
197 | aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
198 | 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
199 | aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
200 | 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
201 |
202 | aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
203 |
204 | aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
205 |
206 | return aux_mask
--------------------------------------------------------------------------------
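A minimal standalone sketch of the `_get_aux_mask` bookkeeping above, using made-up sizes (bs=1, num_objects=2, num_heads=2, num_queries=4, H=W=2); it is not part of the repository. It illustrates how the foreground/background split produces a boolean mask of shape (bs*num_objects*num_heads) x num_queries x (H*W), where True blocks attention and fully-blocked rows are reset so the attention softmax stays defined.

import torch

bs, num_objects, num_heads, num_queries, H, W = 1, 2, 2, 4, 2, 2
# aggregated logits: channel 0 is background, the rest are per-object
logits = torch.randn(bs, num_objects + 1, H, W)

is_foreground = logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]
foreground_mask = is_foreground.flatten(start_dim=2)  # bs x num_objects x (H*W)

def expand(m):
    # replicate per head and per half of the query set, then merge the leading dims
    return m.unsqueeze(2).unsqueeze(2).repeat(1, 1, num_heads, num_queries // 2, 1) \
            .flatten(start_dim=0, end_dim=2)

# the first half of the queries attends to foreground pixels, the second half to background
aux_mask = torch.cat([expand(~foreground_mask), expand(foreground_mask)], dim=1)
aux_mask[aux_mask.sum(-1) == aux_mask.shape[-1]] = False  # never block an entire row
print(aux_mask.shape)  # torch.Size([4, 4, 4]) == (bs*num_objects*num_heads, num_queries, H*W)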
/matanyone/model/transformer/positional_encoding.py:
--------------------------------------------------------------------------------
1 | # Reference:
2 | # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
3 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
4 |
5 | import math
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 |
12 | def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
13 | """
14 | Gets a base embedding for one dimension with sin and cos intertwined
15 | """
16 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
17 | return torch.flatten(emb, -2, -1)
18 |
19 |
20 | class PositionalEncoding(nn.Module):
21 | def __init__(self,
22 | dim: int,
23 | scale: float = math.pi * 2,
24 | temperature: float = 10000,
25 | normalize: bool = True,
26 | channel_last: bool = True,
27 | transpose_output: bool = False):
28 | super().__init__()
29 | dim = int(np.ceil(dim / 4) * 2)
30 | self.dim = dim
31 | inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
32 | self.register_buffer("inv_freq", inv_freq)
33 | self.normalize = normalize
34 | self.scale = scale
35 | self.eps = 1e-6
36 | self.channel_last = channel_last
37 | self.transpose_output = transpose_output
38 |
39 | self.cached_penc = None # the cache is irrespective of the number of objects
40 |
41 | def forward(self, tensor: torch.Tensor) -> torch.Tensor:
42 | """
43 | :param tensor: A 4/5d tensor of size
44 | channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
45 | channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
46 | :return: positional encoding tensor that has the same shape as the input if the input is 4d
47 | if the input is 5d, the output is broadcastable along the k-dimension
48 | """
49 | if len(tensor.shape) != 4 and len(tensor.shape) != 5:
50 | raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
51 |
52 | if len(tensor.shape) == 5:
53 | # take a sample from the k dimension
54 | num_objects = tensor.shape[1]
55 | tensor = tensor[:, 0]
56 | else:
57 | num_objects = None
58 |
59 | if self.channel_last:
60 | batch_size, h, w, c = tensor.shape
61 | else:
62 | batch_size, c, h, w = tensor.shape
63 |
64 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
65 | if num_objects is None:
66 | return self.cached_penc
67 | else:
68 | return self.cached_penc.unsqueeze(1)
69 |
70 | self.cached_penc = None
71 |
72 | pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
73 | pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
74 | if self.normalize:
75 | pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
76 | pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
77 |
78 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
79 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
80 | emb_y = get_emb(sin_inp_y).unsqueeze(1)
81 | emb_x = get_emb(sin_inp_x)
82 |
83 | emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
84 | emb[:, :, :self.dim] = emb_x
85 | emb[:, :, self.dim:] = emb_y
86 |
87 | if not self.channel_last and self.transpose_output:
88 | # cancelled out
89 | pass
90 | elif (not self.channel_last) or (self.transpose_output):
91 | emb = emb.permute(2, 0, 1)
92 |
93 | self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
94 | if num_objects is None:
95 | return self.cached_penc
96 | else:
97 | return self.cached_penc.unsqueeze(1)
98 |
99 |
100 | if __name__ == '__main__':
101 | pe = PositionalEncoding(8).cuda()
102 | input = torch.ones((1, 8, 8, 8)).cuda()
103 | output = pe(input)
104 | # print(output)
105 | print(output[0, :, 0, 0])
106 | print(output[0, :, 0, 5])
107 | print(output[0, 0, :, 0])
108 | print(output[0, 0, 0, :])
109 |
--------------------------------------------------------------------------------
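A small usage sketch, not part of the repository: instantiate the encoder the same way the object transformer above does (channel_last=False, transpose_output=True) and check that the returned encoding is channel-last with the same spatial size as the input feature map. The embed_dim value is assumed for illustration.

import math
import torch
from matanyone.model.transformer.positional_encoding import PositionalEncoding

embed_dim = 256  # assumed; the real value comes from the model config
pe = PositionalEncoding(embed_dim,
                        scale=2 * math.pi,
                        temperature=10000,
                        channel_last=False,
                        transpose_output=True)
feat = torch.zeros(2, embed_dim, 30, 54)  # batch x C x H x W
enc = pe(feat)
print(enc.shape)  # torch.Size([2, 30, 54, 256]); flatten(1, 2) then gives (2, 30*54, 256)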
/matanyone/model/transformer/transformer_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from PyTorch nn.Transformer
2 |
3 | from typing import List, Optional, Callable
4 |
5 | import torch
6 | from torch import Tensor
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from matanyone.model.channel_attn import CAResBlock
10 |
11 |
12 | class SelfAttention(nn.Module):
13 | def __init__(self,
14 | dim: int,
15 | nhead: int,
16 | dropout: float = 0.0,
17 | batch_first: bool = True,
18 | add_pe_to_qkv: List[bool] = [True, True, False]):
19 | super().__init__()
20 | self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
21 | self.norm = nn.LayerNorm(dim)
22 | self.dropout = nn.Dropout(dropout)
23 | self.add_pe_to_qkv = add_pe_to_qkv
24 |
25 | def forward(self,
26 | x: torch.Tensor,
27 | pe: torch.Tensor,
28 |                 attn_mask: Optional[Tensor] = None,
29 |                 key_padding_mask: Optional[Tensor] = None) -> torch.Tensor:
30 | x = self.norm(x)
31 | if any(self.add_pe_to_qkv):
32 | x_with_pe = x + pe
33 | q = x_with_pe if self.add_pe_to_qkv[0] else x
34 | k = x_with_pe if self.add_pe_to_qkv[1] else x
35 | v = x_with_pe if self.add_pe_to_qkv[2] else x
36 | else:
37 | q = k = v = x
38 |
39 | r = x
40 | x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
41 | return r + self.dropout(x)
42 |
43 |
44 | # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
45 | class CrossAttention(nn.Module):
46 | def __init__(self,
47 | dim: int,
48 | nhead: int,
49 | dropout: float = 0.0,
50 | batch_first: bool = True,
51 | add_pe_to_qkv: List[bool] = [True, True, False],
52 | residual: bool = True,
53 | norm: bool = True):
54 | super().__init__()
55 | self.cross_attn = nn.MultiheadAttention(dim,
56 | nhead,
57 | dropout=dropout,
58 | batch_first=batch_first)
59 | if norm:
60 | self.norm = nn.LayerNorm(dim)
61 | else:
62 | self.norm = nn.Identity()
63 | self.dropout = nn.Dropout(dropout)
64 | self.add_pe_to_qkv = add_pe_to_qkv
65 | self.residual = residual
66 |
67 | def forward(self,
68 | x: torch.Tensor,
69 | mem: torch.Tensor,
70 | x_pe: torch.Tensor,
71 | mem_pe: torch.Tensor,
72 |                 attn_mask: Optional[Tensor] = None,
73 | *,
74 | need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
75 | x = self.norm(x)
76 | if self.add_pe_to_qkv[0]:
77 | q = x + x_pe
78 | else:
79 | q = x
80 |
81 | if any(self.add_pe_to_qkv[1:]):
82 | mem_with_pe = mem + mem_pe
83 | k = mem_with_pe if self.add_pe_to_qkv[1] else mem
84 | v = mem_with_pe if self.add_pe_to_qkv[2] else mem
85 | else:
86 | k = v = mem
87 | r = x
88 | x, weights = self.cross_attn(q,
89 | k,
90 | v,
91 | attn_mask=attn_mask,
92 | need_weights=need_weights,
93 | average_attn_weights=False)
94 |
95 | if self.residual:
96 | return r + self.dropout(x), weights
97 | else:
98 | return self.dropout(x), weights
99 |
100 |
101 | class FFN(nn.Module):
102 | def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
103 | super().__init__()
104 | self.linear1 = nn.Linear(dim_in, dim_ff)
105 | self.linear2 = nn.Linear(dim_ff, dim_in)
106 | self.norm = nn.LayerNorm(dim_in)
107 |
108 | if isinstance(activation, str):
109 | self.activation = _get_activation_fn(activation)
110 | else:
111 | self.activation = activation
112 |
113 | def forward(self, x: torch.Tensor) -> torch.Tensor:
114 | r = x
115 | x = self.norm(x)
116 | x = self.linear2(self.activation(self.linear1(x)))
117 | x = r + x
118 | return x
119 |
120 |
121 | class PixelFFN(nn.Module):
122 | def __init__(self, dim: int):
123 | super().__init__()
124 | self.dim = dim
125 | self.conv = CAResBlock(dim, dim)
126 |
127 | def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
128 | # pixel: batch_size * num_objects * dim * H * W
129 | # pixel_flat: (batch_size*num_objects) * (H*W) * dim
130 | bs, num_objects, _, h, w = pixel.shape
131 | pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
132 | pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
133 |
134 | x = self.conv(pixel_flat)
135 | x = x.view(bs, num_objects, self.dim, h, w)
136 | return x
137 |
138 |
139 | class OutputFFN(nn.Module):
140 | def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
141 | super().__init__()
142 | self.linear1 = nn.Linear(dim_in, dim_out)
143 | self.linear2 = nn.Linear(dim_out, dim_out)
144 |
145 | if isinstance(activation, str):
146 | self.activation = _get_activation_fn(activation)
147 | else:
148 | self.activation = activation
149 |
150 | def forward(self, x: torch.Tensor) -> torch.Tensor:
151 | x = self.linear2(self.activation(self.linear1(x)))
152 | return x
153 |
154 |
155 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
156 | if activation == "relu":
157 | return F.relu
158 | elif activation == "gelu":
159 | return F.gelu
160 |
161 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
162 |
--------------------------------------------------------------------------------
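A quick standalone shape check, not part of the repository, that wires the blocks above together in roughly the order a query-reading layer such as QueryTransformerBlock uses them; all sizes are made up.

import torch
from matanyone.model.transformer.transformer_layers import SelfAttention, CrossAttention, FFN

dim, nhead, num_queries, num_pixels, bs = 64, 4, 16, 30 * 54, 2

query = torch.randn(bs, num_queries, dim)
query_pe = torch.randn(bs, num_queries, dim)
pixel_flat = torch.randn(bs, num_pixels, dim)
pixel_pe = torch.randn(bs, num_pixels, dim)

self_attn = SelfAttention(dim, nhead)
read_from_pixel = CrossAttention(dim, nhead)
ffn = FFN(dim, dim * 4)

query = self_attn(query, query_pe)
query, _ = read_from_pixel(query, pixel_flat, query_pe, pixel_pe)  # weights are None unless need_weights=True
query = ffn(query)
print(query.shape)  # torch.Size([2, 16, 64])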
/matanyone/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/utils/__init__.py
--------------------------------------------------------------------------------
/matanyone/model/utils/memory_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from typing import Optional, Union, Tuple
4 |
5 |
6 | # @torch.jit.script
7 | def get_similarity(mk: torch.Tensor,
8 | ms: torch.Tensor,
9 | qk: torch.Tensor,
10 | qe: torch.Tensor,
11 | add_batch_dim: bool = False,
12 | uncert_mask = None) -> torch.Tensor:
13 | # used for training/inference and memory reading/memory potentiation
14 | # mk: B x CK x [N] - Memory keys
15 | # ms: B x 1 x [N] - Memory shrinkage
16 | # qk: B x CK x [HW/P] - Query keys
17 | # qe: B x CK x [HW/P] - Query selection
18 | # Dimensions in [] are flattened
19 | # Return: B*N*HW
20 | if add_batch_dim:
21 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
22 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
23 |
24 | CK = mk.shape[1]
25 |
26 | mk = mk.flatten(start_dim=2)
27 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
28 | qk = qk.flatten(start_dim=2)
29 | qe = qe.flatten(start_dim=2) if qe is not None else None
30 |
31 | # query token selection based on temporal sparsity
32 | if uncert_mask is not None:
33 | uncert_mask = uncert_mask.flatten(start_dim=2)
34 | uncert_mask = uncert_mask.expand(-1, 64, -1)
35 | qk = qk * uncert_mask
36 | qe = qe * uncert_mask
37 |
38 | if qe is not None:
39 | # See XMem's appendix for derivation
40 | mk = mk.transpose(1, 2)
41 | a_sq = (mk.pow(2) @ qe)
42 | two_ab = 2 * (mk @ (qk * qe))
43 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
44 | similarity = (-a_sq + two_ab - b_sq)
45 | else:
46 | # similar to STCN if we don't have the selection term
47 | a_sq = mk.pow(2).sum(1).unsqueeze(2)
48 | two_ab = 2 * (mk.transpose(1, 2) @ qk)
49 | similarity = (-a_sq + two_ab)
50 |
51 | if ms is not None:
52 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW
53 | else:
54 | similarity = similarity / math.sqrt(CK) # B*N*HW
55 |
56 | return similarity
57 |
58 |
59 | def do_softmax(
60 | similarity: torch.Tensor,
61 | top_k: Optional[int] = None,
62 | inplace: bool = False,
63 | return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
64 | # normalize similarity with top-k softmax
65 | # similarity: B x N x [HW/P]
66 | # use inplace with care
67 | if top_k is not None:
68 | values, indices = torch.topk(similarity, k=top_k, dim=1)
69 |
70 | x_exp = values.exp_()
71 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
72 | if inplace:
73 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
74 | affinity = similarity
75 | else:
76 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
77 | else:
78 | maxes = torch.max(similarity, dim=1, keepdim=True)[0]
79 | x_exp = torch.exp(similarity - maxes)
80 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
81 | affinity = x_exp / x_exp_sum
82 | indices = None
83 |
84 | if return_usage:
85 | return affinity, affinity.sum(dim=2)
86 |
87 | return affinity
88 |
89 |
90 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
91 | qe: torch.Tensor, uncert_mask = None) -> torch.Tensor:
92 | # shorthand used in training with no top-k
93 | similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask)
94 | affinity = do_softmax(similarity)
95 | return affinity
96 |
97 | def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor:
98 | B, CV, T, H, W = mv.shape
99 |
100 | mo = mv.view(B, CV, T * H * W)
101 | mem = torch.bmm(mo, affinity)
102 | if uncert_mask is not None:
103 | uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1)
104 | mem = mem * uncert_mask
105 | mem = mem.view(B, CV, H, W)
106 |
107 | return mem
108 |
--------------------------------------------------------------------------------
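A toy end-to-end memory read, not part of the repository, showing how the helpers above fit together: random memory keys/values, an affinity from get_affinity, and a readout feature map. Shapes follow the comments in get_similarity; the sizes are made up.

import torch
from matanyone.model.utils.memory_utils import get_similarity, do_softmax, get_affinity, readout

B, CK, CV, T, H, W = 1, 64, 512, 3, 30, 54

mk = torch.randn(B, CK, T, H, W)  # memory keys
ms = torch.rand(B, 1, T, H, W)    # memory shrinkage
qk = torch.randn(B, CK, H, W)     # query keys
qe = torch.rand(B, CK, H, W)      # query selection
mv = torch.randn(B, CV, T, H, W)  # memory values

affinity = get_affinity(mk, ms, qk, qe)  # B x (T*H*W) x (H*W), softmax over the memory axis
mem = readout(affinity, mv)              # B x CV x H x W
print(affinity.shape, mem.shape)         # torch.Size([1, 4860, 1620]) torch.Size([1, 512, 30, 54])

# top-k variant as used at inference time
similarity = get_similarity(mk, ms, qk, qe)
affinity_topk = do_softmax(similarity, top_k=30)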
/matanyone/model/utils/parameter_groups.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | log = logging.getLogger()
4 |
5 |
6 | def get_parameter_groups(model, stage_cfg, print_log=False):
7 | """
8 | Assign different weight decays and learning rates to different parameters.
9 | Returns a parameter group which can be passed to the optimizer.
10 | """
11 | weight_decay = stage_cfg.weight_decay
12 | embed_weight_decay = stage_cfg.embed_weight_decay
13 | backbone_lr_ratio = stage_cfg.backbone_lr_ratio
14 | base_lr = stage_cfg.learning_rate
15 |
16 | backbone_params = []
17 | embed_params = []
18 | other_params = []
19 |
20 | embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
21 | embedding_names = [e + '.weight' for e in embedding_names]
22 |
23 | # inspired by detectron2
24 | memo = set()
25 | for name, param in model.named_parameters():
26 | if not param.requires_grad:
27 | continue
28 | # Avoid duplicating parameters
29 | if param in memo:
30 | continue
31 | memo.add(param)
32 |
33 | if name.startswith('module'):
34 | name = name[7:]
35 |
36 | inserted = False
37 | if name.startswith('pixel_encoder.'):
38 | backbone_params.append(param)
39 | inserted = True
40 | if print_log:
41 | log.info(f'{name} counted as a backbone parameter.')
42 | else:
43 | for e in embedding_names:
44 | if name.endswith(e):
45 | embed_params.append(param)
46 | inserted = True
47 | if print_log:
48 | log.info(f'{name} counted as an embedding parameter.')
49 | break
50 |
51 | if not inserted:
52 | other_params.append(param)
53 |
54 | parameter_groups = [
55 | {
56 | 'params': backbone_params,
57 | 'lr': base_lr * backbone_lr_ratio,
58 | 'weight_decay': weight_decay
59 | },
60 | {
61 | 'params': embed_params,
62 | 'lr': base_lr,
63 | 'weight_decay': embed_weight_decay
64 | },
65 | {
66 | 'params': other_params,
67 | 'lr': base_lr,
68 | 'weight_decay': weight_decay
69 | },
70 | ]
71 |
72 | return parameter_groups
--------------------------------------------------------------------------------
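A minimal sketch, not part of the repository, of how these groups are typically consumed: build them from a model and a small config-like object, then hand them to an optimizer. The stage_cfg fields mirror the attributes read by get_parameter_groups; the concrete values are made up.

from types import SimpleNamespace

import torch
from matanyone.model.utils.parameter_groups import get_parameter_groups

stage_cfg = SimpleNamespace(weight_decay=0.05,
                            embed_weight_decay=0.0,
                            backbone_lr_ratio=0.1,
                            learning_rate=1e-4)

model = torch.nn.Linear(8, 8)  # stand-in; any nn.Module works
groups = get_parameter_groups(model, stage_cfg, print_log=False)
optimizer = torch.optim.AdamW(groups)  # backbone/embedding groups get their own lr and weight decay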
/matanyone/model/utils/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | resnet.py - A modified ResNet structure
3 | We append extra channels to the first conv by some network surgery
4 | """
5 |
6 | from collections import OrderedDict
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils import model_zoo
12 |
13 |
14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15 | new_dict = OrderedDict()
16 |
17 | for k1, v1 in target.state_dict().items():
18 | if 'num_batches_tracked' not in k1:
19 | if k1 in source_state:
20 | tar_v = source_state[k1]
21 |
22 | if v1.shape != tar_v.shape:
23 | # Init the new segmentation channel with zeros
24 | # print(v1.shape, tar_v.shape)
25 | c, _, w, h = v1.shape
26 | pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
27 | nn.init.orthogonal_(pads)
28 | tar_v = torch.cat([tar_v, pads], 1)
29 |
30 | new_dict[k1] = tar_v
31 |
32 | target.load_state_dict(new_dict)
33 |
34 |
35 | model_urls = {
36 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
37 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
38 | }
39 |
40 |
41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
42 | return nn.Conv2d(in_planes,
43 | out_planes,
44 | kernel_size=3,
45 | stride=stride,
46 | padding=dilation,
47 | dilation=dilation,
48 | bias=False)
49 |
50 |
51 | class BasicBlock(nn.Module):
52 | expansion = 1
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
55 | super(BasicBlock, self).__init__()
56 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
60 | self.bn2 = nn.BatchNorm2d(planes)
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | residual = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.relu(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 |
74 | if self.downsample is not None:
75 | residual = self.downsample(x)
76 |
77 | out += residual
78 | out = self.relu(out)
79 |
80 | return out
81 |
82 |
83 | class Bottleneck(nn.Module):
84 | expansion = 4
85 |
86 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
87 | super(Bottleneck, self).__init__()
88 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
89 | self.bn1 = nn.BatchNorm2d(planes)
90 | self.conv2 = nn.Conv2d(planes,
91 | planes,
92 | kernel_size=3,
93 | stride=stride,
94 | dilation=dilation,
95 | padding=dilation,
96 | bias=False)
97 | self.bn2 = nn.BatchNorm2d(planes)
98 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
99 | self.bn3 = nn.BatchNorm2d(planes * 4)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.downsample = downsample
102 | self.stride = stride
103 |
104 | def forward(self, x):
105 | residual = x
106 |
107 | out = self.conv1(x)
108 | out = self.bn1(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv2(out)
112 | out = self.bn2(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv3(out)
116 | out = self.bn3(out)
117 |
118 | if self.downsample is not None:
119 | residual = self.downsample(x)
120 |
121 | out += residual
122 | out = self.relu(out)
123 |
124 | return out
125 |
126 |
127 | class ResNet(nn.Module):
128 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
129 | self.inplanes = 64
130 | super(ResNet, self).__init__()
131 | self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
132 | self.bn1 = nn.BatchNorm2d(64)
133 | self.relu = nn.ReLU(inplace=True)
134 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
135 | self.layer1 = self._make_layer(block, 64, layers[0])
136 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
139 |
140 | for m in self.modules():
141 | if isinstance(m, nn.Conv2d):
142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
143 | m.weight.data.normal_(0, math.sqrt(2. / n))
144 | elif isinstance(m, nn.BatchNorm2d):
145 | m.weight.data.fill_(1)
146 | m.bias.data.zero_()
147 |
148 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
149 | downsample = None
150 | if stride != 1 or self.inplanes != planes * block.expansion:
151 | downsample = nn.Sequential(
152 | nn.Conv2d(self.inplanes,
153 | planes * block.expansion,
154 | kernel_size=1,
155 | stride=stride,
156 | bias=False),
157 | nn.BatchNorm2d(planes * block.expansion),
158 | )
159 |
160 | layers = [block(self.inplanes, planes, stride, downsample)]
161 | self.inplanes = planes * block.expansion
162 | for i in range(1, blocks):
163 | layers.append(block(self.inplanes, planes, dilation=dilation))
164 |
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def resnet18(pretrained=True, extra_dim=0):
169 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
170 | if pretrained:
171 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
172 | return model
173 |
174 |
175 | def resnet50(pretrained=True, extra_dim=0):
176 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
177 | if pretrained:
178 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
179 | return model
180 |
--------------------------------------------------------------------------------
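A quick illustration, not part of the repository, of the "network surgery" above: with extra_dim=1 the first conv expects a 4-channel input (RGB plus one extra channel, e.g. a mask). pretrained=False keeps the sketch offline; with pretrained=True the ImageNet weights are downloaded and the extra input channel is filled with an orthogonal initialization. Since ResNet defines no forward(), its stages are called individually, as a feature extractor would.

import torch
from matanyone.model.utils.resnet import resnet50

net = resnet50(pretrained=False, extra_dim=1)
x = torch.randn(2, 4, 224, 224)  # image (3) + extra channel (1)
f = net.maxpool(net.relu(net.bn1(net.conv1(x))))
f4 = net.layer1(f)    # stride-4 features
f8 = net.layer2(f4)   # stride-8 features
print(f4.shape, f8.shape)  # torch.Size([2, 256, 56, 56]) torch.Size([2, 512, 28, 28])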
/matanyone/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/utils/__init__.py
--------------------------------------------------------------------------------
/matanyone/utils/get_default_model.py:
--------------------------------------------------------------------------------
1 | """
2 | A helper function to get a default model for quick testing
3 | """
4 | from omegaconf import open_dict
5 | from hydra import compose, initialize
6 |
7 | import torch
8 | from matanyone.model.matanyone import MatAnyone
9 |
10 | def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
11 | initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
12 | cfg = compose(config_name="eval_matanyone_config")
13 |
14 | with open_dict(cfg):
15 | cfg['weights'] = ckpt_path
16 |
17 | # Load the network weights
18 | if device is not None:
19 | matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
20 | model_weights = torch.load(cfg.weights, map_location=device)
21 | else: # if device is not specified, `.cuda()` by default
22 | matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
23 | model_weights = torch.load(cfg.weights)
24 |
25 | matanyone.load_weights(model_weights)
26 |
27 | return matanyone
28 |
--------------------------------------------------------------------------------
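Hypothetical usage of the helper above; the checkpoint path is a placeholder, not a shipped file.

import torch
from matanyone.utils.get_default_model import get_matanyone_model

ckpt_path = "path/to/matanyone.pth"  # placeholder
device = "cuda" if torch.cuda.is_available() else "cpu"
matanyone = get_matanyone_model(ckpt_path, device=device)
print(sum(p.numel() for p in matanyone.parameters()))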
/matanyone/utils/inference_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 |
6 | import torch
7 | import torchvision
8 |
9 | IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
10 | VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.MP4', '.MOV', '.AVI')
11 |
12 | def read_frame_from_videos(frame_root):
13 | if frame_root.endswith(VIDEO_EXTENSIONS): # Video file path
14 | video_name = os.path.basename(frame_root)[:-4]
15 | frames, _, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', output_format='TCHW') # RGB
16 | fps = info['video_fps']
17 | else:
18 | video_name = os.path.basename(frame_root)
19 | frames = []
20 | fr_lst = sorted(os.listdir(frame_root))
21 | for fr in fr_lst:
22 | frame = cv2.imread(os.path.join(frame_root, fr))[...,[2,1,0]] # RGB, HWC
23 | frames.append(frame)
24 | fps = 24 # default
25 | frames = torch.Tensor(np.array(frames)).permute(0, 3, 1, 2).contiguous() # TCHW
26 |
27 | length = frames.shape[0]
28 |
29 | return frames, fps, length, video_name
30 |
31 | def get_video_paths(input_root):
32 | video_paths = []
33 | for root, _, files in os.walk(input_root):
34 | for file in files:
35 | if file.lower().endswith(VIDEO_EXTENSIONS):
36 | video_paths.append(os.path.join(root, file))
37 | return sorted(video_paths)
38 |
39 | def str_to_list(value):
40 | return list(map(int, value.split(',')))
41 |
42 | def gen_dilate(alpha, min_kernel_size, max_kernel_size):
43 | kernel_size = random.randint(min_kernel_size, max_kernel_size)
44 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
45 | fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
46 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
47 | return dilate.astype(np.float32)
48 |
49 | def gen_erosion(alpha, min_kernel_size, max_kernel_size):
50 | kernel_size = random.randint(min_kernel_size, max_kernel_size)
51 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
52 | fg = np.array(np.equal(alpha, 255).astype(np.float32))
53 | erode = cv2.erode(fg, kernel, iterations=1)*255
54 | return erode.astype(np.float32)
--------------------------------------------------------------------------------
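A brief usage sketch, not part of the repository; the paths are placeholders. It reads a clip (a video file or a folder of frames) and derives dilated/eroded masks from a binary alpha, the way an "unknown" band around the mask boundary is usually built.

import cv2
import numpy as np
from matanyone.utils.inference_utils import read_frame_from_videos, gen_dilate, gen_erosion

frames, fps, length, video_name = read_frame_from_videos("path/to/video.mp4")  # placeholder
print(video_name, length, fps, tuple(frames.shape))  # frames are TCHW, RGB

mask = cv2.imread("path/to/mask.png", cv2.IMREAD_GRAYSCALE)  # placeholder
mask = (mask > 127).astype(np.uint8) * 255
dilated = gen_dilate(mask, 10, 10)  # fixed 10-px kernel because min == max
eroded = gen_erosion(mask, 10, 10)
unknown_band = (dilated > 0) & (eroded == 0)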
/matanyone/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Iterable
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | # STM
7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
8 | h, w = in_img.shape[-2:]
9 |
10 | if h % d > 0:
11 | new_h = h + d - h % d
12 | else:
13 | new_h = h
14 | if w % d > 0:
15 | new_w = w + d - w % d
16 | else:
17 | new_w = w
18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
20 | pad_array = (int(lw), int(uw), int(lh), int(uh))
21 | out = F.pad(in_img, pad_array)
22 | return out, pad_array
23 |
24 |
25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
26 | if len(img.shape) == 4:
27 | if pad[2] + pad[3] > 0:
28 | img = img[:, :, pad[2]:-pad[3], :]
29 | if pad[0] + pad[1] > 0:
30 | img = img[:, :, :, pad[0]:-pad[1]]
31 | elif len(img.shape) == 3:
32 | if pad[2] + pad[3] > 0:
33 | img = img[:, pad[2]:-pad[3], :]
34 | if pad[0] + pad[1] > 0:
35 | img = img[:, :, pad[0]:-pad[1]]
36 | elif len(img.shape) == 5:
37 | if pad[2] + pad[3] > 0:
38 | img = img[:, :, :, pad[2]:-pad[3], :]
39 | if pad[0] + pad[1] > 0:
40 | img = img[:, :, :, :, pad[0]:-pad[1]]
41 | else:
42 | raise NotImplementedError
43 | return img
44 |
45 |
46 | # @torch.jit.script
47 | def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
48 | with torch.amp.autocast("cuda",enabled=False):
49 | prob = prob.float()
50 | new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
51 | dim).clamp(1e-7, 1 - 1e-7)
52 | logits = torch.log((new_prob / (1 - new_prob))) # (0, 1) --> (-inf, inf)
53 |
54 | return logits
55 |
56 |
57 | # @torch.jit.script
58 | def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
59 | # cls_gt: B*1*H*W
60 | B, _, H, W = cls_gt.shape
61 | one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
62 | return one_hot
--------------------------------------------------------------------------------
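A short round-trip check, not part of the repository, for the padding helpers above, plus a look at aggregate(), which prepends a derived background channel before converting probabilities to logits. Sizes are made up.

import torch
from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate

img = torch.randn(1, 3, 257, 481)
padded, pads = pad_divide_by(img, 16)  # pad symmetrically so H and W are divisible by 16
print(padded.shape, pads)              # torch.Size([1, 3, 272, 496]) (7, 8, 7, 8)
restored = unpad(padded, pads)
assert restored.shape == img.shape

prob = torch.rand(1, 2, 8, 8)     # per-object foreground probabilities
logits = aggregate(prob, dim=1)   # 1 x 3 x 8 x 8: background + 2 objects
print(logits.shape)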
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.metadata]
6 | allow-direct-references = true
7 |
8 | [tool.yapf]
9 | based_on_style = "pep8"
10 | indent_width = 4
11 | column_limit = 100
12 |
13 | [project]
14 | name = "matanyone"
15 | version = "1.0.0"
16 | authors = [{ name = "Peiqing Yang", email = "peiqingyang99@outlook.com" }]
17 | description = ""
18 | readme = "README.md"
19 | requires-python = ">=3.8"
20 | classifiers = [
21 | "Programming Language :: Python :: 3",
22 | "Operating System :: OS Independent",
23 | ]
24 | dependencies = [
25 | 'cython',
26 | 'gitpython >= 3.1',
27 | 'thinplate@git+https://github.com/cheind/py-thin-plate-spline',
28 | 'hickle >= 5.0',
29 | 'tensorboard >= 2.11',
30 | 'numpy >= 1.21',
31 | 'Pillow >= 9.5',
32 | 'opencv-python >= 4.8',
33 | 'scipy >= 1.7',
34 | 'pycocotools >= 2.0.7',
35 | 'tqdm >= 4.66.1',
36 | 'gradio >= 3.34',
37 | 'gdown >= 4.7.1',
38 | 'einops >= 0.6',
39 | 'hydra-core >= 1.3.2',
40 | 'PySide6 >= 6.2.0',
41 | 'charset-normalizer >= 3.1.0',
42 | 'netifaces >= 0.11.0',
43 | 'cchardet >= 2.1.7',
44 | 'easydict',
45 | 'av >= 0.5.2',
46 | 'requests',
47 | 'pyqtdarktheme',
48 | 'imageio == 2.25.0',
49 | 'imageio[ffmpeg]',
50 | 'huggingface_hub',
51 | 'safetensors',
52 | ]
53 |
54 | [tool.hatch.build.targets.wheel]
55 | packages = ["matanyone"]
56 |
57 | [project.urls]
58 | "Homepage" = "https://github.com/pq-yang/MatAnyone"
59 | "Bug Tracker" = "https://github.com/pq-yang/MatAnyone/issues"
60 |
61 | [tool.setuptools]
62 | package-dir = {"" = "."}
63 |
--------------------------------------------------------------------------------
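A hedged sanity check, not part of the repository: after installing the package defined above (for example an editable install with `pip install -e .` from the repo root), the matanyone package and its subpackages resolve as regular imports.

import importlib

for mod in ("matanyone", "matanyone.model.matanyone", "matanyone.inference.inference_core"):
    importlib.import_module(mod)
print("matanyone imports OK")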