├── .gitignore ├── LICENSE.txt ├── assets └── Figures │ ├── Teaser.png │ ├── gradio.png │ └── tryon.png ├── cldm ├── cldm.py ├── ddim_hacked.py ├── hack.py ├── logger.py └── model.py ├── cog.yaml ├── configs ├── anydoor.yaml ├── datasets.yaml ├── demo.yaml └── inference.yaml ├── datasets ├── Preprocess │ ├── mvimagenet.txt │ └── uvo_process.py ├── base.py ├── data_utils.py ├── dreambooth.py ├── dresscode.py ├── fashiontryon.py ├── lvis.py ├── mose.py ├── mvimagenet.py ├── saliency_modular.py ├── sam.py ├── uvo.py ├── uvo_val.py ├── vipseg.py ├── vitonhd.py ├── ytb_vis.py └── ytb_vos.py ├── dinov2 ├── .github │ └── workflows │ │ └── lint.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODEL_CARD.md ├── README.md ├── conda.yaml ├── dinov2 │ ├── __init__.py │ ├── configs │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── vitb14_pretrain.yaml │ │ │ ├── vitg14_pretrain.yaml │ │ │ ├── vitl14_pretrain.yaml │ │ │ └── vits14_pretrain.yaml │ │ ├── ssl_default_config.yaml │ │ └── train │ │ │ ├── vitg14.yaml │ │ │ ├── vitl14.yaml │ │ │ └── vitl16_short.yaml │ ├── data │ │ ├── __init__.py │ │ ├── adapters.py │ │ ├── augmentations.py │ │ ├── collate.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── decoders.py │ │ │ ├── extended.py │ │ │ ├── image_net.py │ │ │ └── image_net_22k.py │ │ ├── loaders.py │ │ ├── masking.py │ │ ├── samplers.py │ │ └── transforms.py │ ├── distributed │ │ └── __init__.py │ ├── eval │ │ ├── __init__.py │ │ ├── knn.py │ │ ├── linear.py │ │ ├── log_regression.py │ │ ├── metrics.py │ │ ├── setup.py │ │ └── utils.py │ ├── fsdp │ │ └── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── dino_head.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── logging │ │ ├── __init__.py │ │ └── helpers.py │ ├── loss │ │ ├── __init__.py │ │ ├── dino_clstoken_loss.py │ │ ├── ibot_patch_loss.py │ │ └── koleo_loss.py │ ├── models │ │ ├── __init__.py │ │ └── vision_transformer.py │ ├── run │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── knn.py │ │ │ ├── linear.py │ │ │ └── log_regression.py │ │ ├── submit.py │ │ └── train │ │ │ └── train.py │ ├── train │ │ ├── __init__.py │ │ ├── ssl_meta_arch.py │ │ └── train.py │ └── utils │ │ ├── __init__.py │ │ ├── cluster.py │ │ ├── config.py │ │ ├── dtype.py │ │ ├── param_groups.py │ │ └── utils.py ├── hubconf.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts │ └── lint.sh ├── setup.cfg └── setup.py ├── environment.yaml ├── examples ├── Gradio │ ├── BG │ │ ├── 00.png │ │ ├── 01.png │ │ ├── 02.png │ │ ├── 03.png │ │ ├── 04.jpg │ │ ├── 04.png │ │ ├── 06.png │ │ ├── 07.png │ │ ├── 08.jpg │ │ ├── 13.jpg │ │ ├── 17.jpg │ │ └── 22.png │ └── FG │ │ ├── 00.jpg │ │ ├── 01.jpg │ │ ├── 04.jpg │ │ ├── 06.jpg │ │ ├── 07.png │ │ ├── 09.jpg │ │ ├── 18.png │ │ ├── 22.jpg │ │ ├── 25.png │ │ ├── 28.png │ │ ├── 33.png │ │ ├── 36.jpg │ │ ├── 39.jpg │ │ ├── 43.jpg │ │ ├── 44.jpg │ │ └── 50.jpg └── TestDreamBooth │ ├── BG │ ├── 000000047948_GT.png │ ├── 000000047948_mask.png │ ├── 000000309203_GT.png │ └── 000000309203_mask.png │ ├── FG │ ├── 00.png │ ├── 01.png │ ├── 02.png │ └── 03.png │ └── GEN │ └── gen_res.png ├── install.ps1 ├── install_cn.ps1 ├── iseg ├── coarse_mask_refine.pth └── coarse_mask_refine_util.py ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── 
sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── predict.py ├── readme.md ├── requirements-windows.txt ├── requirements.txt ├── run_dataset_debug.py ├── run_gradio_demo.py ├── run_gui.ps1 ├── run_inference.py ├── run_train_anydoor.py ├── scripts ├── convert_weight.sh ├── inference.sh ├── train.sh └── util.py └── tool_add_control_sd21.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | **/.DS_Store 3 | training/ 4 | lightning_logs/ 5 | image_log/ 6 | 7 | #*.pth 8 | *.pt 9 | *.ckpt 10 | *.safetensors 11 | 12 | gradio_pose2image_private.py 13 | gradio_canny2image_private.py 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | path/ 146 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DAMO Vision Intelligence Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/Figures/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/assets/Figures/Teaser.png -------------------------------------------------------------------------------- /assets/Figures/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/assets/Figures/gradio.png -------------------------------------------------------------------------------- /assets/Figures/tryon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/assets/Figures/tryon.png -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 
| q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with 
torch.no_grad():
55 |                 images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56 | 
57 |             for k in images:
58 |                 N = min(images[k].shape[0], self.max_images)
59 |                 images[k] = images[k][:N]
60 |                 if isinstance(images[k], torch.Tensor):
61 |                     images[k] = images[k].detach().cpu()
62 |                     if self.clamp:
63 |                         images[k] = torch.clamp(images[k], -1., 1.)
64 | 
65 |             self.log_local(pl_module.logger.save_dir, split, images,
66 |                            pl_module.global_step, pl_module.current_epoch, batch_idx)
67 | 
68 |             if is_train:
69 |                 pl_module.train()
70 | 
71 |     def check_frequency(self, check_idx):
72 |         return check_idx % self.batch_freq == 0
73 | 
74 |     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75 |         if not self.disabled:
76 |             self.log_img(pl_module, batch, batch_idx, split="train")
77 | 
-------------------------------------------------------------------------------- /cldm/model.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | from omegaconf import OmegaConf
5 | from ldm.util import instantiate_from_config
6 | 
7 | 
8 | def get_state_dict(d):
9 |     return d.get('state_dict', d)
10 | 
11 | 
12 | def load_state_dict(ckpt_path, location='cpu'):
13 |     _, extension = os.path.splitext(ckpt_path)
14 |     if extension.lower() == ".safetensors":
15 |         import safetensors.torch
16 |         state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17 |     else:
18 |         state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19 |     state_dict = get_state_dict(state_dict)
20 |     print(f'Loaded state_dict from [{ckpt_path}]')
21 |     return state_dict
22 | 
23 | 
24 | def create_model(config_path):
25 |     config = OmegaConf.load(config_path)
26 |     model = instantiate_from_config(config.model).cpu()
27 |     print(f'Loaded model config from [{config_path}]')
28 |     return model
29 | 
-------------------------------------------------------------------------------- /cog.yaml: --------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 | build:
4 |   gpu: true
5 |   system_packages:
6 |     - "mesa-common-dev"
7 |   python_version: "3.8.5"
8 |   python_packages:
9 |     - "albumentations==1.3.0"
10 |     - "einops==0.3.0"
11 |     - "fvcore==0.1.5.post20221221"
12 |     - "gradio==3.39.0"
13 |     - "numpy==1.23.1"
14 |     - "omegaconf==2.1.1"
15 |     - "open_clip_torch==2.17.1"
16 |     - "opencv_python==4.7.0.72"
17 |     - "opencv_python_headless==4.7.0.72"
18 |     - "Pillow==9.4.0"
19 |     - "pytorch_lightning==1.5.0"
20 |     - "safetensors==0.2.7"
21 |     - "scipy==1.9.1"
22 |     - "setuptools==66.0.0"
23 |     - "share==1.0.4"
24 |     - "submitit==1.5.1"
25 |     - "timm==0.6.12"
26 |     - "torch==2.0.0"
27 |     - "torchmetrics==0.6.0"
28 |     - "tqdm==4.65.0"
29 |     - "transformers==4.19.2"
30 |     - "xformers==0.0.18"
31 | 
32 | run:
33 |   - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.3.1/pget" && chmod +x /usr/local/bin/pget
34 | 
35 | # predict.py defines how predictions are run on your model
36 | predict: "predict.py:Predictor"
-------------------------------------------------------------------------------- /configs/anydoor.yaml: --------------------------------------------------------------------------------
1 | model:
2 |   target: cldm.cldm.ControlLDM
3 |   params:
4 |     linear_start: 0.00085
5 |     linear_end: 0.0120
6 |     num_timesteps_cond: 1
7 |     log_every_t: 200
8 |     timesteps: 1000
9 |     first_stage_key: "jpg" 10 |
cond_stage_key: "ref" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: cldm.cldm.ControlNet 23 | params: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | hint_channels: 4 #3 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | unet_config: 40 | target: cldm.cldm.ControlledUnetModel 41 | params: 42 | use_checkpoint: True 43 | image_size: 32 # unused 44 | in_channels: 4 45 | out_channels: 4 46 | model_channels: 320 47 | attention_resolutions: [ 4, 2, 1 ] 48 | num_res_blocks: 2 49 | channel_mult: [ 1, 2, 4, 4 ] 50 | num_head_channels: 64 # need to fix for flash-attn 51 | use_spatial_transformer: True 52 | use_linear_in_transformer: True 53 | transformer_depth: 1 54 | context_dim: 1024 55 | legacy: False 56 | 57 | first_stage_config: 58 | target: ldm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | #attn_type: "vanilla-xformers" 64 | double_z: true 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 | num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | cond_stage_config: 82 | target: ldm.modules.encoders.modules.FrozenDinoV2Encoder 83 | weight: path/dinov2_vitg14_pretrain.pth 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/datasets.yaml: -------------------------------------------------------------------------------- 1 | Train: 2 | YoutubeVOS: 3 | image_dir: path/YTBVOS/train/JPEGImages/ 4 | anno: path/YTBVOS/train/Annotations 5 | meta: path/YTBVOS/train/meta.json 6 | 7 | YoutubeVIS: 8 | image_dir: path/youtubevis/train/JPEGImages/ 9 | anno: path/youtubevis/train/Annotations/ 10 | meta: path/youtubevis/train/meta.json 11 | 12 | VIPSeg: 13 | image_dir: path/VIPSeg/VIPSeg_720P/images/ 14 | anno: path/VIPSeg/VIPSeg_720P/panomasksRGB/ 15 | 16 | UVO: 17 | train: 18 | image_dir: path/UVO/uvo_frames_sparse 19 | video_json: path/UVO/UVO_sparse_train_video_with_interpolation.json 20 | image_json: path/UVO/UVO_sparse_train_video_with_interpolation_reorg.json 21 | val: 22 | image_dir: path/UVO/uvo_frames_sparse 23 | video_json: path/UVO/VideoSparseSet/UVO_sparse_val_video_with_interpolation.json 24 | image_json: path/UVO/VideoSparseSet/UVO_sparse_val_video_interpolation_reorg.json 25 | 26 | Mose: 27 | image_dir: path/MOSE/train/JPEGImages/ 28 | anno: path/MOSE/train/Annotations/ 29 | 30 | MVImageNet: 31 | txt: ./datasets/Preprocess/mvimagenet.txt 32 | image_dir: /mnt/workspace/xizhi/data/MVImgNet/ 33 | 34 | VitonHD: 35 | image_dir: path/TryOn/VitonHD/train/cloth/ 36 | 37 | Dresscode: 38 | image_dir: /mnt/workspace/xizhi/data/dresscode/DressCode/upper_body/label_maps/ 39 | 40 | FashionTryon: 41 | image_dir: path/TryOn/FashionTryOn/train 42 | 43 | Lvis: 44 | image_dir: path/COCO/train2017 45 | json_path: path/lvis_v1/lvis_v1_train.json 46 | 47 | SAM: 48 | sub1: path/SAM/0000 49 | sub2: path/SAM/0001 50 | 
sub3: path/SAM/0002 51 | sub4: path/SAM/0004 52 | 53 | Saliency: 54 | MSRA_root: path/Saliency/MSRA10K_Imgs_GT/ 55 | TR_root: path/Saliency/DUTS-TR/DUTS-TR-Image/ 56 | TE_root: path/Saliency/DUTS-TE/DUTS-TE-Image/ 57 | HFlickr_root: path/HFlickr/masks/ 58 | 59 | Test: 60 | DreamBooth: 61 | fg_dir: path/DreamBooth/AnyDoor_DreamBooth 62 | bg_dir: path/DreamBooth/v1_800 63 | 64 | VitonHDTest: 65 | image_dir: path/TryOn/VitonHD/test/cloth 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model: path/epoch=1-step=8687-pruned.ckpt 2 | config_file: configs/anydoor.yaml 3 | save_memory: False 4 | use_interactive_seg: True 5 | -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model: path/epoch=1-step=8687-pruned.ckpt 2 | config_file: configs/anydoor.yaml 3 | save_memory: False 4 | -------------------------------------------------------------------------------- /datasets/Preprocess/uvo_process.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import os 4 | from pycocotools import mask as mask_utils 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | json_path = 'path/UVO/UVO_sparse_train_video_with_interpolation.json' 9 | output_path = "path/UVO/UVO_sparse_train_video_with_interpolation_reorg.json" 10 | 11 | with open(json_path, 'r') as fcc_file: 12 | data = json.load(fcc_file) 13 | 14 | info = data['info'] 15 | videos = data['videos'] 16 | print(len(videos)) 17 | 18 | 19 | uvo_dict = {} 20 | for video in tqdm(videos): 21 | vid = video['id'] 22 | file_names = video['file_names'] 23 | uvo_dict[vid] = file_names 24 | 25 | 26 | with open(output_path,"w") as f: 27 | json.dump(uvo_dict,f) 28 | print('finish') 29 | 30 | -------------------------------------------------------------------------------- /datasets/dreambooth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | 11 | class DreamBoothDataset(BaseDataset): 12 | def __init__(self, fg_dir, bg_dir): 13 | self.bg_dir = bg_dir 14 | bg_data = os.listdir(self.bg_dir) 15 | self.bg_data = [i for i in bg_data if 'mask' in i] 16 | self.image_dir = fg_dir 17 | self.data = os.listdir(self.image_dir) 18 | self.size = (512,512) 19 | self.clip_size = (224,224) 20 | ''' 21 | Dynamic: 22 | 0: Static View, High Quality 23 | 1: Multi-view, Low Quality 24 | 2: Multi-view, High Quality 25 | ''' 26 | self.dynamic = 1 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, idx): 32 | idx = np.random.randint(0, len(self.data)-1) 33 | item = self.get_sample(idx) 34 | return item 35 | 36 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 37 | pass_flag = True 38 | H,W = image.shape[0], image.shape[1] 39 | H,W = H * ratio, W * ratio 40 | y1,y2,x1,x2 = yyxx 41 | h,w = y2-y1,x2-x1 42 | if mode == 'max': 43 | if h > H and w > W: 44 | pass_flag = False 45 | elif mode == 'min': 46 | if h < H and w < W: 47 | pass_flag = False 48 | return pass_flag 49 | 50 | def get_alpha_mask(self, mask_path): 51 | 
image = cv2.imread( mask_path, cv2.IMREAD_UNCHANGED) 52 | mask = (image[:,:,-1] > 128).astype(np.uint8) 53 | return mask 54 | 55 | def get_sample(self, idx): 56 | dir_name = self.data[idx] 57 | dir_path = os.path.join(self.image_dir, dir_name) 58 | images = os.listdir(dir_path) 59 | image_name = [i for i in images if '.png' in i][0] 60 | image_path = os.path.join(dir_path, image_name) 61 | 62 | image = cv2.imread( image_path, cv2.IMREAD_UNCHANGED) 63 | mask = (image[:,:,-1] > 128).astype(np.uint8) 64 | image = image[:,:,:-1] 65 | 66 | image = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB) 67 | ref_image = image 68 | ref_mask = mask 69 | ref_image, ref_mask = expand_image_mask(image, mask, ratio=1.4) 70 | bg_idx = np.random.randint(0, len(self.bg_data)-1) 71 | 72 | tar_mask_name = self.bg_data[bg_idx] 73 | tar_mask_path = os.path.join(self.bg_dir, tar_mask_name) 74 | tar_image_path = tar_mask_path.replace('_mask','_GT') 75 | 76 | tar_image = cv2.imread(tar_image_path).astype(np.uint8) 77 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 78 | tar_mask = (cv2.imread(tar_mask_path) > 128).astype(np.uint8)[:,:,0] 79 | 80 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 81 | sampled_time_steps = self.sample_timestep() 82 | item_with_collage['time_steps'] = sampled_time_steps 83 | return item_with_collage 84 | 85 | -------------------------------------------------------------------------------- /datasets/dresscode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | import albumentations as A 11 | 12 | class DresscodeDataset(BaseDataset): 13 | def __init__(self, image_dir): 14 | self.image_root = image_dir 15 | self.data = os.listdir(self.image_root) 16 | self.size = (512,512) 17 | self.clip_size = (224,224) 18 | self.dynamic = 2 19 | 20 | def __len__(self): 21 | return 20000 22 | 23 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 24 | pass_flag = True 25 | H,W = image.shape[0], image.shape[1] 26 | H,W = H * ratio, W * ratio 27 | y1,y2,x1,x2 = yyxx 28 | h,w = y2-y1,x2-x1 29 | if mode == 'max': 30 | if h > H and w > W: 31 | pass_flag = False 32 | elif mode == 'min': 33 | if h < H and w < W: 34 | pass_flag = False 35 | return pass_flag 36 | 37 | def get_sample(self, idx): 38 | tar_mask_path = os.path.join(self.image_root, self.data[idx]) 39 | tar_image_path = tar_mask_path.replace('label_maps/','images/').replace('_4.png','_0.jpg') 40 | ref_image_path = tar_mask_path.replace('label_maps/','images/').replace('_4.png','_1.jpg') 41 | 42 | # Read Image and Mask 43 | ref_image = cv2.imread(ref_image_path) 44 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 45 | 46 | tar_image = cv2.imread(tar_image_path) 47 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 48 | 49 | ref_mask = (ref_image < 240).astype(np.uint8)[:,:,0] 50 | 51 | 52 | tar_mask = Image.open(tar_mask_path ).convert('P') 53 | tar_mask= np.array(tar_mask) 54 | tar_mask = tar_mask == 4 55 | 56 | 57 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 1.0) 58 | sampled_time_steps = self.sample_timestep() 59 | item_with_collage['time_steps'] = sampled_time_steps 60 | return item_with_collage 61 | 62 | -------------------------------------------------------------------------------- 
/datasets/fashiontryon.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | import albumentations as A 11 | 12 | class FashionTryonDataset(BaseDataset): 13 | def __init__(self, image_dir): 14 | self.image_root = image_dir 15 | self.data =os.listdir(self.image_root) 16 | self.size = (512,512) 17 | self.clip_size = (224,224) 18 | self.dynamic = 2 19 | 20 | def __len__(self): 21 | return 5000 22 | 23 | def aug_data(self, image): 24 | transform = A.Compose([ 25 | A.RandomBrightnessContrast(p=0.5), 26 | ]) 27 | transformed = transform(image=image.astype(np.uint8)) 28 | transformed_image = transformed["image"] 29 | return transformed_image 30 | 31 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 32 | pass_flag = True 33 | H,W = image.shape[0], image.shape[1] 34 | H,W = H * ratio, W * ratio 35 | y1,y2,x1,x2 = yyxx 36 | h,w = y2-y1,x2-x1 37 | if mode == 'max': 38 | if h > H and w > W: 39 | pass_flag = False 40 | elif mode == 'min': 41 | if h < H and w < W: 42 | pass_flag = False 43 | return pass_flag 44 | 45 | def get_sample(self, idx): 46 | cloth_dir = os.path.join(self.image_root, self.data[idx]) 47 | ref_image_path = os.path.join(cloth_dir, 'target.jpg') 48 | 49 | ref_image = cv2.imread(ref_image_path) 50 | ref_image = cv2.cvtColor(ref_image.copy(), cv2.COLOR_BGR2RGB) 51 | 52 | ref_mask_path = os.path.join(cloth_dir,'mask.jpg') 53 | ref_mask = cv2.imread(ref_mask_path)[:,:,0] > 128 54 | 55 | target_dirs = [i for i in os.listdir(cloth_dir ) if '.jpg' not in i] 56 | target_dir_name = np.random.choice(target_dirs) 57 | 58 | target_image_path = os.path.join(cloth_dir, target_dir_name + '.jpg') 59 | target_image= cv2.imread(target_image_path) 60 | tar_image = cv2.cvtColor(target_image.copy(), cv2.COLOR_BGR2RGB) 61 | 62 | target_mask_path = os.path.join(cloth_dir, target_dir_name, 'segment.png') 63 | tar_mask= cv2.imread(target_mask_path)[:,:,0] 64 | target_mask = tar_mask == 7 65 | kernel = np.ones((3, 3), dtype=np.uint8) 66 | tar_mask = cv2.erode(target_mask.astype(np.uint8), kernel, iterations=3) 67 | 68 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 1.0) 69 | sampled_time_steps = self.sample_timestep() 70 | item_with_collage['time_steps'] = sampled_time_steps 71 | return item_with_collage 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /datasets/lvis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | from pycocotools import mask as mask_utils 11 | from lvis import LVIS 12 | 13 | class LvisDataset(BaseDataset): 14 | def __init__(self, image_dir, json_path): 15 | self.image_dir = image_dir 16 | self.json_path = json_path 17 | lvis_api = LVIS(json_path) 18 | img_ids = sorted(lvis_api.imgs.keys()) 19 | imgs = lvis_api.load_imgs(img_ids) 20 | anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] 21 | self.data = imgs 22 | self.annos = anns 23 | self.lvis_api = lvis_api 24 | self.size = (512,512) 25 | self.clip_size = (224,224) 26 | self.dynamic = 0 27 | 28 | def register_subset(self, 
path): 29 | data = os.listdir(path) 30 | data = [ os.path.join(path, i) for i in data if '.json' in i] 31 | self.data = self.data + data 32 | 33 | def get_sample(self, idx): 34 | # ==== get pairs ===== 35 | image_name = self.data[idx]['coco_url'].split('/')[-1] 36 | image_path = os.path.join(self.image_dir, image_name) 37 | image = cv2.imread(image_path) 38 | ref_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 39 | 40 | anno = self.annos[idx] 41 | obj_ids = [] 42 | for i in range(len(anno)): 43 | obj = anno[i] 44 | area = obj['area'] 45 | if area > 3600: 46 | obj_ids.append(i) 47 | assert len(anno) > 0 48 | obj_id = np.random.choice(obj_ids) 49 | anno = anno[obj_id] 50 | ref_mask = self.lvis_api.ann_to_mask(anno) 51 | 52 | tar_image, tar_mask = ref_image.copy(), ref_mask.copy() 53 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 54 | sampled_time_steps = self.sample_timestep() 55 | item_with_collage['time_steps'] = sampled_time_steps 56 | return item_with_collage 57 | 58 | def __len__(self): 59 | return 20000 60 | 61 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 62 | pass_flag = True 63 | H,W = image.shape[0], image.shape[1] 64 | H,W = H * ratio, W * ratio 65 | y1,y2,x1,x2 = yyxx 66 | h,w = y2-y1,x2-x1 67 | if mode == 'max': 68 | if h > H or w > W: 69 | pass_flag = False 70 | elif mode == 'min': 71 | if h < H or w < W: 72 | pass_flag = False 73 | return pass_flag 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /datasets/mose.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from PIL import Image 10 | from .base import BaseDataset 11 | 12 | class MoseDataset(BaseDataset): 13 | def __init__(self, image_dir, anno): 14 | self.image_root = image_dir 15 | self.anno_root = anno 16 | 17 | video_dirs = [] 18 | video_dirs = os.listdir(self.image_root) 19 | self.data = video_dirs 20 | self.size = (512,512) 21 | self.clip_size = (224,224) 22 | self.dynamic = 2 23 | 24 | def __len__(self): 25 | return 40000 26 | 27 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 28 | pass_flag = True 29 | H,W = image.shape[0], image.shape[1] 30 | H,W = H * ratio, W * ratio 31 | y1,y2,x1,x2 = yyxx 32 | h,w = y2-y1,x2-x1 33 | if mode == 'max': 34 | if h > H or w > W: 35 | pass_flag = False 36 | elif mode == 'min': 37 | if h < H or w < W: 38 | pass_flag = False 39 | return pass_flag 40 | 41 | def get_sample(self, idx): 42 | video_name = self.data[idx] 43 | video_path = os.path.join(self.image_root, video_name) 44 | frames = os.listdir(video_path) 45 | 46 | # Sampling frames 47 | min_interval = len(frames) // 10 48 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 49 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 50 | end_frame_index = min(end_frame_index, len(frames) - 1) 51 | 52 | # Get image path 53 | ref_image_name = frames[start_frame_index] 54 | tar_image_name = frames[end_frame_index] 55 | ref_image_path = os.path.join(self.image_root, video_name, ref_image_name) 56 | tar_image_path = os.path.join(self.image_root, video_name, tar_image_name) 57 | 58 | ref_mask_path = ref_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 59 | tar_mask_path = 
tar_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 60 | 61 | # Read Image and Mask 62 | ref_image = cv2.imread(ref_image_path) 63 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 64 | 65 | tar_image = cv2.imread(tar_image_path) 66 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 67 | 68 | ref_mask = Image.open(ref_mask_path ).convert('P') 69 | ref_mask= np.array(ref_mask) 70 | 71 | tar_mask = Image.open(tar_mask_path ).convert('P') 72 | tar_mask= np.array(tar_mask) 73 | 74 | ref_ids = np.unique(ref_mask) 75 | tar_ids = np.unique(tar_mask) 76 | 77 | common_ids = list(np.intersect1d(ref_ids, tar_ids)) 78 | common_ids = [ i for i in common_ids if i != 0 ] 79 | assert len(common_ids) > 0 80 | chosen_id = np.random.choice(common_ids) 81 | ref_mask = ref_mask == chosen_id 82 | tar_mask = tar_mask == chosen_id 83 | len_mask = len( self.check_connect( ref_mask.astype(np.uint8) ) ) 84 | assert len_mask == 1 85 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 86 | sampled_time_steps = self.sample_timestep() 87 | item_with_collage['time_steps'] = sampled_time_steps 88 | return item_with_collage 89 | 90 | def check_connect(self, mask): 91 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 92 | cnt_area = [cv2.contourArea(cnt) for cnt in contours] 93 | return cnt_area 94 | 95 | -------------------------------------------------------------------------------- /datasets/mvimagenet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | 11 | class MVImageNetDataset(BaseDataset): 12 | def __init__(self, txt, image_dir): 13 | with open(txt,"r") as f: 14 | data = f.read().split('\n')[:-1] 15 | self.image_dir = image_dir 16 | self.data = data 17 | self.size = (512,512) 18 | self.clip_size = (224,224) 19 | self.dynamic = 2 20 | 21 | def __len__(self): 22 | return 40000 23 | 24 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 25 | pass_flag = True 26 | H,W = image.shape[0], image.shape[1] 27 | H,W = H * ratio, W * ratio 28 | y1,y2,x1,x2 = yyxx 29 | h,w = y2-y1,x2-x1 30 | if mode == 'max': 31 | if h > H and w > W: 32 | pass_flag = False 33 | elif mode == 'min': 34 | if h < H and w < W: 35 | pass_flag = False 36 | return pass_flag 37 | 38 | def get_alpha_mask(self, mask_path): 39 | image = cv2.imread( mask_path, cv2.IMREAD_UNCHANGED) 40 | mask = (image[:,:,-1] > 128).astype(np.uint8) 41 | return mask 42 | 43 | def get_sample(self, idx): 44 | object_dir = self.data[idx].replace('MVDir/', self.image_dir) 45 | frames = os.listdir(object_dir) 46 | frames = [ i for i in frames if '.png' in i] 47 | 48 | # Sampling frames 49 | min_interval = len(frames) // 8 50 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 51 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 52 | end_frame_index = min(end_frame_index, len(frames) - 1) 53 | 54 | # Get image path 55 | ref_mask_name = frames[start_frame_index] 56 | tar_mask_name = frames[end_frame_index] 57 | 58 | ref_image_name = ref_mask_name.split('_')[0] + '.jpg' 59 | tar_image_name = tar_mask_name.split('_')[0] + '.jpg' 60 | 61 | ref_mask_path = os.path.join(object_dir, ref_mask_name) 62 | tar_mask_path = os.path.join(object_dir, 
tar_mask_name) 63 | ref_image_path = os.path.join(object_dir, ref_image_name) 64 | tar_image_path = os.path.join(object_dir, tar_image_name) 65 | 66 | # Read Image and Mask 67 | ref_image = cv2.imread(ref_image_path).astype(np.uint8) 68 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 69 | 70 | tar_image = cv2.imread(tar_image_path).astype(np.uint8) 71 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 72 | 73 | ref_mask = self.get_alpha_mask(ref_mask_path) 74 | tar_mask = self.get_alpha_mask(tar_mask_path) 75 | 76 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 77 | sampled_time_steps = self.sample_timestep() 78 | item_with_collage['time_steps'] = sampled_time_steps 79 | 80 | return item_with_collage 81 | 82 | -------------------------------------------------------------------------------- /datasets/saliency_modular.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | 11 | class SaliencyDataset(BaseDataset): 12 | def __init__(self, MSRA_root, TR_root, TE_root, HFlickr_root): 13 | image_mask_dict = {} 14 | 15 | # ====== MSRA-10k ====== 16 | file_lst = os.listdir(MSRA_root) 17 | image_lst = [MSRA_root+i for i in file_lst if '.jpg' in i] 18 | for i in image_lst: 19 | mask_path = i.replace('.jpg','.png') 20 | image_mask_dict[i] = mask_path 21 | 22 | # ===== DUT-TR ======== 23 | file_lst = os.listdir(TR_root) 24 | image_lst = [TR_root+i for i in file_lst if '.jpg' in i] 25 | for i in image_lst: 26 | mask_path = i.replace('.jpg','.png').replace('DUTS-TR-Image','DUTS-TR-Mask') 27 | image_mask_dict[i] = mask_path 28 | 29 | # ===== DUT-TE ======== 30 | file_lst = os.listdir(TE_root) 31 | image_lst = [TE_root+i for i in file_lst if '.jpg' in i] 32 | for i in image_lst: 33 | mask_path = i.replace('.jpg','.png').replace('DUTS-TE-Image','DUTS-TE-Mask') 34 | image_mask_dict[i] = mask_path 35 | 36 | # ===== HFlickr ======= 37 | file_lst = os.listdir(HFlickr_root) 38 | mask_list = [HFlickr_root+i for i in file_lst if '.png' in i] 39 | for i in file_lst: 40 | image_name = i.split('_')[0] +'.jpg' 41 | image_path = HFlickr_root.replace('masks', 'real_images') + image_name 42 | mask_path = HFlickr_root + i 43 | image_mask_dict[image_path] = mask_path 44 | 45 | self.image_mask_dict = image_mask_dict 46 | self.data = list(self.image_mask_dict.keys() ) 47 | self.size = (512,512) 48 | self.clip_size = (224,224) 49 | self.dynamic = 0 50 | 51 | def __len__(self): 52 | return 20000 53 | 54 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 55 | pass_flag = True 56 | H,W = image.shape[0], image.shape[1] 57 | H,W = H * ratio, W * ratio 58 | y1,y2,x1,x2 = yyxx 59 | h,w = y2-y1,x2-x1 60 | if mode == 'max': 61 | if h > H or w > W: 62 | pass_flag = False 63 | elif mode == 'min': 64 | if h < H or w < W: 65 | pass_flag = False 66 | return pass_flag 67 | 68 | def get_sample(self, idx): 69 | 70 | # ==== get pairs ===== 71 | image_path = self.data[idx] 72 | mask_path = self.image_mask_dict[image_path] 73 | 74 | instances_mask = cv2.imread(mask_path) 75 | if len(instances_mask.shape) == 3: 76 | instances_mask = instances_mask[:,:,0] 77 | instances_mask = (instances_mask > 128).astype(np.uint8) 78 | # ====================== 79 | ref_image = cv2.imread(image_path) 80 | ref_image = cv2.cvtColor(ref_image.copy(), 
cv2.COLOR_BGR2RGB) 81 | tar_image = ref_image 82 | 83 | ref_mask = instances_mask 84 | tar_mask = instances_mask 85 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 86 | sampled_time_steps = self.sample_timestep() 87 | item_with_collage['time_steps'] = sampled_time_steps 88 | return item_with_collage 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /datasets/sam.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | from pycocotools import mask as mask_utils 11 | 12 | class SAMDataset(BaseDataset): 13 | def __init__(self, sub1, sub2, sub3, sub4): 14 | image_mask_dict = {} 15 | self.data = [] 16 | self.register_subset(sub1) 17 | self.register_subset(sub2) 18 | self.register_subset(sub3) 19 | self.register_subset(sub4) 20 | self.size = (512,512) 21 | self.clip_size = (224,224) 22 | self.dynamic = 0 23 | 24 | def register_subset(self, path): 25 | data = os.listdir(path) 26 | data = [ os.path.join(path, i) for i in data if '.json' in i] 27 | self.data = self.data + data 28 | 29 | def get_sample(self, idx): 30 | # ==== get pairs ===== 31 | json_path = self.data[idx] 32 | image_path = json_path.replace('.json', '.jpg') 33 | 34 | with open(json_path, 'r') as json_file: 35 | data = json.load(json_file) 36 | annotation = data['annotations'] 37 | 38 | valid_ids = [] 39 | for i in range(len(annotation)): 40 | area = annotation[i]['area'] 41 | if area > 100 * 100 * 5: 42 | valid_ids.append(i) 43 | 44 | chosen_id = np.random.choice(valid_ids) 45 | mask = mask_utils.decode(annotation[chosen_id]["segmentation"] ) 46 | # ====================== 47 | 48 | image = cv2.imread(image_path) 49 | ref_image = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB) 50 | tar_image = ref_image 51 | 52 | ref_mask = mask 53 | tar_mask = mask 54 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 55 | sampled_time_steps = self.sample_timestep() 56 | item_with_collage['time_steps'] = sampled_time_steps 57 | return item_with_collage 58 | 59 | def __len__(self): 60 | return 20000 61 | 62 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 63 | pass_flag = True 64 | H,W = image.shape[0], image.shape[1] 65 | H,W = H * ratio, W * ratio 66 | y1,y2,x1,x2 = yyxx 67 | h,w = y2-y1,x2-x1 68 | if mode == 'max': 69 | if h > H or w > W: 70 | pass_flag = False 71 | elif mode == 'min': 72 | if h < H or w < W: 73 | pass_flag = False 74 | return pass_flag 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /datasets/uvo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | from pycocotools import mask as mask_utils 11 | 12 | class UVODataset(BaseDataset): 13 | def __init__(self, image_dir, video_json, image_json): 14 | json_path = video_json 15 | with open(json_path, 'r') as fcc_file: 16 | data = json.load(fcc_file) 17 | 18 | image_json_path = image_json 19 | with open(image_json_path , 'r') as image_file: 20 | video_dict = json.load(image_file) 21 | 22 | self.image_root = 
image_dir 23 | self.data = data['annotations'] 24 | self.video_dict = video_dict 25 | self.size = (512,512) 26 | self.clip_size = (224,224) 27 | self.dynamic = 1 28 | 29 | def __len__(self): 30 | return 25000 31 | 32 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 33 | pass_flag = True 34 | H,W = image.shape[0], image.shape[1] 35 | H,W = H * ratio, W * ratio 36 | y1,y2,x1,x2 = yyxx 37 | h,w = y2-y1,x2-x1 38 | if mode == 'max': 39 | if h > H and w > W: 40 | pass_flag = False 41 | elif mode == 'min': 42 | if h < H and w < W: 43 | pass_flag = False 44 | return pass_flag 45 | 46 | def get_sample(self, idx): 47 | ins_anno = self.data[idx] 48 | video_id = str(ins_anno['video_id']) 49 | video_names = self.video_dict[video_id] 50 | masks = ins_anno['segmentations'] 51 | frames = video_names 52 | 53 | # Sampling frames 54 | min_interval = len(frames) // 10 55 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 56 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 57 | end_frame_index = min(end_frame_index, len(frames) - 1) 58 | 59 | # Get image path 60 | ref_image_name = frames[start_frame_index] 61 | tar_image_name = frames[end_frame_index] 62 | ref_image_path = os.path.join(self.image_root, ref_image_name) 63 | tar_image_path = os.path.join(self.image_root, tar_image_name) 64 | 65 | # Read Image and Mask 66 | ref_image = cv2.imread(ref_image_path) 67 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 68 | 69 | tar_image = cv2.imread(tar_image_path) 70 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 71 | 72 | ref_mask = mask_utils.decode(masks[start_frame_index]) 73 | tar_mask = mask_utils.decode(masks[end_frame_index]) 74 | 75 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 76 | sampled_time_steps = self.sample_timestep() 77 | item_with_collage['time_steps'] = sampled_time_steps 78 | return item_with_collage 79 | 80 | -------------------------------------------------------------------------------- /datasets/uvo_val.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | from pycocotools import mask as mask_utils 11 | 12 | class UVOValDataset(BaseDataset): 13 | def __init__(self, image_dir, video_json, image_json): 14 | json_path = video_json 15 | with open(json_path, 'r') as fcc_file: 16 | data = json.load(fcc_file) 17 | image_json_path = image_json 18 | with open(image_json_path , 'r') as image_file: 19 | video_dict = json.load(image_file) 20 | self.image_root = image_dir 21 | self.data = data['annotations'] 22 | self.video_dict = video_dict 23 | self.size = (512,512) 24 | self.clip_size = (224,224) 25 | self.dynamic = 1 26 | 27 | def __len__(self): 28 | return 8000 29 | 30 | def __getitem__(self, idx): 31 | while(1): 32 | idx = np.random.randint(0, len(self.data)-1) 33 | try: 34 | item = self.get_sample(idx) 35 | return item 36 | except: 37 | idx = np.random.randint(0, len(self.data)-1) 38 | 39 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 40 | pass_flag = True 41 | H,W = image.shape[0], image.shape[1] 42 | H,W = H * ratio, W * ratio 43 | y1,y2,x1,x2 = yyxx 44 | h,w = y2-y1,x2-x1 45 | if mode == 'max': 46 | if h > H and w > W: 47 | pass_flag = False 48 | elif mode == 'min': 49 | if h < 
H and w < W: 50 | pass_flag = False 51 | return pass_flag 52 | 53 | def get_sample(self, idx): 54 | ins_anno = self.data[idx] 55 | video_id = str(ins_anno['video_id']) 56 | 57 | video_names = self.video_dict[video_id] 58 | masks = ins_anno['segmentations'] 59 | frames = video_names 60 | 61 | # Sampling frames 62 | min_interval = len(frames) // 5 63 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 64 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 65 | end_frame_index = min(end_frame_index, len(frames) - 1) 66 | 67 | # Get image path 68 | ref_image_name = frames[start_frame_index] 69 | tar_image_name = frames[end_frame_index] 70 | ref_image_path = os.path.join(self.image_root, ref_image_name) 71 | tar_image_path = os.path.join(self.image_root, tar_image_name) 72 | 73 | # Read Image and Mask 74 | ref_image = cv2.imread(ref_image_path) 75 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 76 | 77 | tar_image = cv2.imread(tar_image_path) 78 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 79 | 80 | ref_mask = mask_utils.decode(masks[start_frame_index]) 81 | tar_mask = mask_utils.decode(masks[end_frame_index]) 82 | 83 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 84 | sampled_time_steps = self.sample_timestep() 85 | item_with_collage['time_steps'] = sampled_time_steps 86 | return item_with_collage 87 | 88 | -------------------------------------------------------------------------------- /datasets/vipseg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from panopticapi.utils import rgb2id 10 | from PIL import Image 11 | from .base import BaseDataset 12 | 13 | class VIPSegDataset(BaseDataset): 14 | def __init__(self, image_dir, anno): 15 | self.image_root = image_dir 16 | self.anno_root = anno 17 | video_dirs = [] 18 | video_dirs = os.listdir(self.image_root) 19 | self.data = video_dirs 20 | self.size = (512,512) 21 | self.clip_size = (224,224) 22 | self.dynamic = 1 23 | 24 | def __len__(self): 25 | return 30000 26 | 27 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 28 | pass_flag = True 29 | H,W = image.shape[0], image.shape[1] 30 | H,W = H * ratio, W * ratio 31 | y1,y2,x1,x2 = yyxx 32 | h,w = y2-y1,x2-x1 33 | if mode == 'max': 34 | if h > H or w > W: 35 | pass_flag = False 36 | elif mode == 'min': 37 | if h < H or w < W: 38 | pass_flag = False 39 | return pass_flag 40 | 41 | def get_sample(self, idx): 42 | video_name = self.data[idx] 43 | video_path = os.path.join(self.image_root, video_name) 44 | frames = os.listdir(video_path) 45 | 46 | # Sampling frames 47 | min_interval = len(frames) // 100 48 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 49 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 50 | end_frame_index = min(end_frame_index, len(frames) - 1) 51 | 52 | # Get image path 53 | ref_image_name = frames[start_frame_index] 54 | tar_image_name = frames[end_frame_index] 55 | ref_image_path = os.path.join(self.image_root, video_name, ref_image_name) 56 | tar_image_path = os.path.join(self.image_root, video_name, tar_image_name) 57 | 58 | ref_mask_path = ref_image_path.replace('images','panomasksRGB').replace('.jpg', '.png') 59 | 
tar_mask_path = tar_image_path.replace('images','panomasksRGB').replace('.jpg', '.png') 60 | 61 | # Read Image and Mask 62 | ref_image = cv2.imread(ref_image_path) 63 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 64 | 65 | tar_image = cv2.imread(tar_image_path) 66 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 67 | 68 | ref_mask = np.array(Image.open(ref_mask_path).convert('RGB')) 69 | ref_mask = rgb2id(ref_mask) 70 | 71 | tar_mask = np.array(Image.open(tar_mask_path).convert('RGB')) 72 | tar_mask = rgb2id(tar_mask) 73 | 74 | ref_ids = np.unique(ref_mask) 75 | tar_ids = np.unique(tar_mask) 76 | 77 | common_ids = list(np.intersect1d(ref_ids, tar_ids)) 78 | common_ids = [ i for i in common_ids if i != 0 ] 79 | 80 | chosen_id = np.random.choice(common_ids) 81 | ref_mask = ref_mask == chosen_id 82 | tar_mask = tar_mask == chosen_id 83 | 84 | len_mask = len( self.check_connect( ref_mask.astype(np.uint8) ) ) 85 | assert len_mask == 1 86 | 87 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 88 | sampled_time_steps = self.sample_timestep() 89 | item_with_collage['time_steps'] = sampled_time_steps 90 | return item_with_collage 91 | 92 | def check_connect(self, mask): 93 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 94 | cnt_area = [cv2.contourArea(cnt) for cnt in contours] 95 | return cnt_area 96 | 97 | -------------------------------------------------------------------------------- /datasets/vitonhd.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | import albumentations as A 11 | 12 | class VitonHDDataset(BaseDataset): 13 | def __init__(self, image_dir): 14 | self.image_root = image_dir 15 | self.data = os.listdir(self.image_root) 16 | self.size = (512,512) 17 | self.clip_size = (224,224) 18 | self.dynamic = 2 19 | 20 | def __len__(self): 21 | return 20000 22 | 23 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 24 | pass_flag = True 25 | H,W = image.shape[0], image.shape[1] 26 | H,W = H * ratio, W * ratio 27 | y1,y2,x1,x2 = yyxx 28 | h,w = y2-y1,x2-x1 29 | if mode == 'max': 30 | if h > H and w > W: 31 | pass_flag = False 32 | elif mode == 'min': 33 | if h < H and w < W: 34 | pass_flag = False 35 | return pass_flag 36 | 37 | def get_sample(self, idx): 38 | 39 | ref_image_path = os.path.join(self.image_root, self.data[idx]) 40 | tar_image_path = ref_image_path.replace('/cloth/', '/image/') 41 | ref_mask_path = ref_image_path.replace('/cloth/','/cloth-mask/') 42 | tar_mask_path = ref_image_path.replace('/cloth/', '/image-parse-v3/').replace('.jpg','.png') 43 | 44 | # Read Image and Mask 45 | ref_image = cv2.imread(ref_image_path) 46 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 47 | 48 | tar_image = cv2.imread(tar_image_path) 49 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 50 | 51 | ref_mask = (cv2.imread(ref_mask_path) > 128).astype(np.uint8)[:,:,0] 52 | 53 | tar_mask = Image.open(tar_mask_path ).convert('P') 54 | tar_mask= np.array(tar_mask) 55 | tar_mask = tar_mask == 5 56 | 57 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 1.0) 58 | sampled_time_steps = self.sample_timestep() 59 | item_with_collage['time_steps'] = sampled_time_steps 60 | return item_with_collage 61 | 62 | 
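All of the dataset classes under datasets/ follow the same recipe: pick a reference image/mask pair and a target image/mask pair, then hand both to BaseDataset.process_pairs and attach a sampled 'time_steps' value. A rough usage sketch, assuming datasets/base.py (not shown above) provides the usual __getitem__ wrapper around get_sample and that the placeholder paths in configs/datasets.yaml point at real data:

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets.vitonhd import VitonHDDataset  # any of the BaseDataset subclasses above works the same way

# configs/datasets.yaml keeps the per-dataset roots under Train.<Name>
data_cfg = OmegaConf.load('configs/datasets.yaml')
dataset = VitonHDDataset(image_dir=data_cfg.Train.VitonHD.image_dir)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))
# each sample is the dict produced by process_pairs() plus the 'time_steps' key added in get_sample()
print({k: tuple(v.shape) if hasattr(v, 'shape') else type(v) for k, v in batch.items()})

Note that most of these classes return a fixed __len__ (20000 to 40000) rather than the real dataset size, which suggests the base class resamples indices randomly each epoch instead of iterating the data exhaustively.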
-------------------------------------------------------------------------------- /datasets/ytb_vis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | 11 | class YoutubeVISDataset(BaseDataset): 12 | def __init__(self, image_dir, anno, meta): 13 | self.image_root = image_dir 14 | self.anno_root = anno 15 | self.meta_file = meta 16 | 17 | video_dirs = [] 18 | with open(self.meta_file) as f: 19 | records = json.load(f) 20 | records = records["videos"] 21 | for video_id in records: 22 | video_dirs.append(video_id) 23 | 24 | self.records = records 25 | self.data = video_dirs 26 | self.size = (512,512) 27 | self.clip_size = (224,224) 28 | self.dynamic = 1 29 | 30 | def __len__(self): 31 | return 40000 32 | 33 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 34 | pass_flag = True 35 | H,W = image.shape[0], image.shape[1] 36 | H,W = H * ratio, W * ratio 37 | y1,y2,x1,x2 = yyxx 38 | h,w = y2-y1,x2-x1 39 | if mode == 'max': 40 | if h > H and w > W: 41 | pass_flag = False 42 | elif mode == 'min': 43 | if h < H and w < W: 44 | pass_flag = False 45 | return pass_flag 46 | 47 | def get_sample(self, idx): 48 | video_id = list(self.records.keys())[idx] 49 | objects_id = np.random.choice( list(self.records[video_id]["objects"].keys()) ) 50 | frames = self.records[video_id]["objects"][objects_id]["frames"] 51 | 52 | # Sampling frames 53 | min_interval = len(frames) // 10 54 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 55 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 56 | end_frame_index = min(end_frame_index, len(frames) - 1) 57 | 58 | # Get image path 59 | ref_image_name = frames[start_frame_index] 60 | tar_image_name = frames[end_frame_index] 61 | ref_image_path = os.path.join(self.image_root, video_id, ref_image_name) + '.jpg' 62 | tar_image_path = os.path.join(self.image_root, video_id, tar_image_name) + '.jpg' 63 | ref_mask_path = ref_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 64 | tar_mask_path = tar_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 65 | 66 | # Read Image and Mask 67 | ref_image = cv2.imread(ref_image_path) 68 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 69 | 70 | tar_image = cv2.imread(tar_image_path) 71 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 72 | 73 | ref_mask = Image.open(ref_mask_path ).convert('P') 74 | ref_mask= np.array(ref_mask) 75 | ref_mask = ref_mask == int(objects_id) 76 | 77 | tar_mask = Image.open(tar_mask_path ).convert('P') 78 | tar_mask= np.array(tar_mask) 79 | tar_mask = tar_mask == int(objects_id) 80 | 81 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 82 | sampled_time_steps = self.sample_timestep() 83 | item_with_collage['time_steps'] = sampled_time_steps 84 | return item_with_collage 85 | 86 | -------------------------------------------------------------------------------- /datasets/ytb_vos.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import cv2 8 | from .data_utils import * 9 | from .base import BaseDataset 10 | 11 | 
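# Descriptive note (added for readability; mirrors the YouTube-VIS loader above):
# YoutubeVOSDataset samples one object track per video, picks a reference frame
# and a later target frame at least len(frames) // 10 frames apart, reads the
# palette-mode PNGs under Annotations/, and binarizes them with the sampled
# object id before handing the frame pair to process_pairs().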
class YoutubeVOSDataset(BaseDataset): 12 | def __init__(self, image_dir, anno, meta): 13 | self.image_root = image_dir 14 | self.anno_root = anno 15 | self.meta_file = meta 16 | 17 | video_dirs = [] 18 | with open(self.meta_file) as f: 19 | records = json.load(f) 20 | records = records["videos"] 21 | for video_id in records: 22 | video_dirs.append(video_id) 23 | 24 | self.records = records 25 | self.data = video_dirs 26 | self.size = (512,512) 27 | self.clip_size = (224,224) 28 | self.dynamic = 1 29 | 30 | def __len__(self): 31 | return 40000 32 | 33 | def check_region_size(self, image, yyxx, ratio, mode = 'max'): 34 | pass_flag = True 35 | H,W = image.shape[0], image.shape[1] 36 | H,W = H * ratio, W * ratio 37 | y1,y2,x1,x2 = yyxx 38 | h,w = y2-y1,x2-x1 39 | if mode == 'max': 40 | if h > H and w > W: 41 | pass_flag = False 42 | elif mode == 'min': 43 | if h < H and w < W: 44 | pass_flag = False 45 | return pass_flag 46 | 47 | def get_sample(self, idx): 48 | video_id = list(self.records.keys())[idx] 49 | objects_id = np.random.choice( list(self.records[video_id]["objects"].keys()) ) 50 | frames = self.records[video_id]["objects"][objects_id]["frames"] 51 | 52 | # Sampling frames 53 | min_interval = len(frames) // 10 54 | start_frame_index = np.random.randint(low=0, high=len(frames) - min_interval) 55 | end_frame_index = start_frame_index + np.random.randint(min_interval, len(frames) - start_frame_index ) 56 | end_frame_index = min(end_frame_index, len(frames) - 1) 57 | 58 | # Get image path 59 | ref_image_name = frames[start_frame_index] 60 | tar_image_name = frames[end_frame_index] 61 | ref_image_path = os.path.join(self.image_root, video_id, ref_image_name) + '.jpg' 62 | tar_image_path = os.path.join(self.image_root, video_id, tar_image_name) + '.jpg' 63 | ref_mask_path = ref_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 64 | tar_mask_path = tar_image_path.replace('JPEGImages','Annotations').replace('.jpg', '.png') 65 | 66 | # Read Image and Mask 67 | ref_image = cv2.imread(ref_image_path) 68 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 69 | 70 | tar_image = cv2.imread(tar_image_path) 71 | tar_image = cv2.cvtColor(tar_image, cv2.COLOR_BGR2RGB) 72 | 73 | ref_mask = Image.open(ref_mask_path ).convert('P') 74 | ref_mask= np.array(ref_mask) 75 | ref_mask = ref_mask == int(objects_id) 76 | 77 | tar_mask = Image.open(tar_mask_path ).convert('P') 78 | tar_mask= np.array(tar_mask) 79 | tar_mask = tar_mask == int(objects_id) 80 | 81 | 82 | item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask) 83 | sampled_time_steps = self.sample_timestep() 84 | item_with_collage['time_steps'] = sampled_time_steps 85 | return item_with_collage 86 | 87 | 88 | -------------------------------------------------------------------------------- /dinov2/.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - master 10 | - 'gh/**' 11 | 12 | jobs: 13 | run-linters: 14 | name: Run linters 15 | runs-on: ubuntu-20.04 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: 3.9 24 | cache: 'pip' 25 | cache-dependency-path: '**/requirements*.txt' 26 | - name: Install Python (development) dependencies 27 | run: | 28 | pip install -r requirements-dev.txt 29 | - name: Run flake8 30 | run: | 31 | 
flake8 32 | - name: Run black 33 | if: always() 34 | run: | 35 | black --check dinov2 36 | - name: Run pylint 37 | if: always() 38 | run: | 39 | pylint --exit-zero dinov2 40 | -------------------------------------------------------------------------------- /dinov2/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | *.egg-info/ 4 | **/__pycache__/ 5 | 6 | **/.ipynb_checkpoints 7 | **/.ipynb_checkpoints/** 8 | 9 | **/notebooks 10 | 11 | *.swp 12 | 13 | .vscode/ 14 | -------------------------------------------------------------------------------- /dinov2/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /dinov2/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DINOv2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DINOv2, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /dinov2/conda.yaml: -------------------------------------------------------------------------------- 1 | name: dinov2 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | - xformers 7 | - conda-forge 8 | dependencies: 9 | - python=3.9 10 | - pytorch::pytorch=2.0.0 11 | - pytorch::pytorch-cuda=11.7.0 12 | - pytorch::torchvision=0.15.0 13 | - omegaconf 14 | - torchmetrics=0.10.3 15 | - fvcore 16 | - iopath 17 | - xformers::xformers=0.0.18 18 | - pip 19 | - pip: 20 | - git+https://github.com/facebookincubator/submitit 21 | - --extra-index-url https://pypi.nvidia.com 22 | - cuml-cu11 23 | -------------------------------------------------------------------------------- /dinov2/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "0.0.1" 8 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pathlib 8 | 9 | from omegaconf import OmegaConf 10 | 11 | 12 | def load_config(config_name: str): 13 | config_filename = config_name + ".yaml" 14 | return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) 15 | 16 | 17 | dinov2_default_config = load_config("ssl_default_config") 18 | 19 | 20 | def load_and_merge_config(config_name: str): 21 | default_config = OmegaConf.create(dinov2_default_config) 22 | loaded_config = load_config(config_name) 23 | return OmegaConf.merge(default_config, loaded_config) 24 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitb14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_base 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitg14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_giant2 3 | patch_size: 14 4 | ffn_layer: swiglufused 5 | crops: 6 | global_crops_size: 518 # this is to set up the position embeddings properly 7 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitl14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vits14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_small 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/ssl_default_config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHTS: '' 3 | compute_precision: 4 | grad_scaler: true 5 | teacher: 6 | backbone: 7 | sharding_strategy: SHARD_GRAD_OP 8 | mixed_precision: 9 | param_dtype: fp16 10 | reduce_dtype: fp16 11 | buffer_dtype: fp32 12 | dino_head: 13 | sharding_strategy: SHARD_GRAD_OP 14 | mixed_precision: 15 | param_dtype: fp16 16 | reduce_dtype: fp16 17 | buffer_dtype: fp32 18 | ibot_head: 19 | sharding_strategy: SHARD_GRAD_OP 20 | mixed_precision: 21 | param_dtype: fp16 22 | reduce_dtype: fp16 23 | buffer_dtype: fp32 24 | student: 25 | backbone: 26 | sharding_strategy: SHARD_GRAD_OP 27 | mixed_precision: 28 | 
param_dtype: fp16 29 | reduce_dtype: fp16 30 | buffer_dtype: fp32 31 | dino_head: 32 | sharding_strategy: SHARD_GRAD_OP 33 | mixed_precision: 34 | param_dtype: fp16 35 | reduce_dtype: fp32 36 | buffer_dtype: fp32 37 | ibot_head: 38 | sharding_strategy: SHARD_GRAD_OP 39 | mixed_precision: 40 | param_dtype: fp16 41 | reduce_dtype: fp32 42 | buffer_dtype: fp32 43 | dino: 44 | loss_weight: 1.0 45 | head_n_prototypes: 65536 46 | head_bottleneck_dim: 256 47 | head_nlayers: 3 48 | head_hidden_dim: 2048 49 | koleo_loss_weight: 0.1 50 | ibot: 51 | loss_weight: 1.0 52 | mask_sample_probability: 0.5 53 | mask_ratio_min_max: 54 | - 0.1 55 | - 0.5 56 | separate_head: false 57 | head_n_prototypes: 65536 58 | head_bottleneck_dim: 256 59 | head_nlayers: 3 60 | head_hidden_dim: 2048 61 | train: 62 | batch_size_per_gpu: 64 63 | dataset_path: ImageNet:split=TRAIN 64 | output_dir: . 65 | saveckp_freq: 20 66 | seed: 0 67 | num_workers: 10 68 | OFFICIAL_EPOCH_LENGTH: 1250 69 | cache_dataset: true 70 | centering: "centering" # or "sinkhorn_knopp" 71 | student: 72 | arch: vit_large 73 | patch_size: 16 74 | drop_path_rate: 0.3 75 | layerscale: 1.0e-05 76 | drop_path_uniform: true 77 | pretrained_weights: '' 78 | ffn_layer: "mlp" 79 | block_chunks: 0 80 | qkv_bias: true 81 | proj_bias: true 82 | ffn_bias: true 83 | teacher: 84 | momentum_teacher: 0.992 85 | final_momentum_teacher: 1 86 | warmup_teacher_temp: 0.04 87 | teacher_temp: 0.07 88 | warmup_teacher_temp_epochs: 30 89 | optim: 90 | epochs: 100 91 | weight_decay: 0.04 92 | weight_decay_end: 0.4 93 | base_lr: 0.004 # learning rate for a batch size of 1024 94 | lr: 0. # will be set after applying scaling rule 95 | warmup_epochs: 10 96 | min_lr: 1.0e-06 97 | clip_grad: 3.0 98 | freeze_last_layer_epochs: 1 99 | scaling_rule: sqrt_wrt_1024 100 | patch_embed_lr_mult: 0.2 101 | layerwise_decay: 0.9 102 | adamw_beta1: 0.9 103 | adamw_beta2: 0.999 104 | crops: 105 | global_crops_scale: 106 | - 0.32 107 | - 1.0 108 | local_crops_number: 8 109 | local_crops_scale: 110 | - 0.05 111 | - 0.32 112 | global_crops_size: 224 113 | local_crops_size: 96 114 | evaluation: 115 | eval_period_iterations: 12500 116 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitg14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 12 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_giant2 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitl14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 32 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_large 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 
18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitl16_short.yaml: -------------------------------------------------------------------------------- 1 | # this corresponds to the default config 2 | train: 3 | dataset_path: ImageNet:split=TRAIN 4 | batch_size_per_gpu: 64 5 | student: 6 | block_chunks: 4 7 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .adapters import DatasetWithEnumeratedTargets 8 | from .loaders import make_data_loader, make_dataset, SamplerType 9 | from .collate import collate_data_and_cast 10 | from .masking import MaskingGenerator 11 | from .augmentations import DataAugmentationDINO 12 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/adapters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Tuple 8 | 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class DatasetWithEnumeratedTargets(Dataset): 13 | def __init__(self, dataset): 14 | self._dataset = dataset 15 | 16 | def get_image_data(self, index: int) -> bytes: 17 | return self._dataset.get_image_data(index) 18 | 19 | def get_target(self, index: int) -> Tuple[Any, int]: 20 | target = self._dataset.get_target(index) 21 | return (index, target) 22 | 23 | def get_sample_decoder(self, index: int) -> Any: 24 | return self._dataset.get_sample_decoder(index) 25 | 26 | def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: 27 | image, target = self._dataset[index] 28 | target = index if target is None else target 29 | return image, (index, target) 30 | 31 | def __len__(self) -> int: 32 | return len(self._dataset) 33 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | 9 | from torchvision import transforms 10 | 11 | from .transforms import ( 12 | GaussianBlur, 13 | make_normalize_transform, 14 | ) 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | class DataAugmentationDINO(object): 21 | def __init__( 22 | self, 23 | global_crops_scale, 24 | local_crops_scale, 25 | local_crops_number, 26 | global_crops_size=224, 27 | local_crops_size=96, 28 | ): 29 | self.global_crops_scale = global_crops_scale 30 | self.local_crops_scale = local_crops_scale 31 | self.local_crops_number = local_crops_number 32 | self.global_crops_size = global_crops_size 33 | self.local_crops_size = local_crops_size 34 | 35 | logger.info("###################################") 36 | logger.info("Using data augmentation parameters:") 37 | logger.info(f"global_crops_scale: {global_crops_scale}") 38 | logger.info(f"local_crops_scale: {local_crops_scale}") 39 | logger.info(f"local_crops_number: {local_crops_number}") 40 | logger.info(f"global_crops_size: {global_crops_size}") 41 | logger.info(f"local_crops_size: {local_crops_size}") 42 | logger.info("###################################") 43 | 44 | # random resized crop and flip 45 | self.geometric_augmentation_global = transforms.Compose( 46 | [ 47 | transforms.RandomResizedCrop( 48 | global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 49 | ), 50 | transforms.RandomHorizontalFlip(p=0.5), 51 | ] 52 | ) 53 | 54 | self.geometric_augmentation_local = transforms.Compose( 55 | [ 56 | transforms.RandomResizedCrop( 57 | local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 58 | ), 59 | transforms.RandomHorizontalFlip(p=0.5), 60 | ] 61 | ) 62 | 63 | # color distorsions / blurring 64 | color_jittering = transforms.Compose( 65 | [ 66 | transforms.RandomApply( 67 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 68 | p=0.8, 69 | ), 70 | transforms.RandomGrayscale(p=0.2), 71 | ] 72 | ) 73 | 74 | global_transfo1_extra = GaussianBlur(p=1.0) 75 | 76 | global_transfo2_extra = transforms.Compose( 77 | [ 78 | GaussianBlur(p=0.1), 79 | transforms.RandomSolarize(threshold=128, p=0.2), 80 | ] 81 | ) 82 | 83 | local_transfo_extra = GaussianBlur(p=0.5) 84 | 85 | # normalization 86 | self.normalize = transforms.Compose( 87 | [ 88 | transforms.ToTensor(), 89 | make_normalize_transform(), 90 | ] 91 | ) 92 | 93 | self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) 94 | self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) 95 | self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) 96 | 97 | def __call__(self, image): 98 | output = {} 99 | 100 | # global crops: 101 | im1_base = self.geometric_augmentation_global(image) 102 | global_crop_1 = self.global_transfo1(im1_base) 103 | 104 | im2_base = self.geometric_augmentation_global(image) 105 | global_crop_2 = self.global_transfo2(im2_base) 106 | 107 | output["global_crops"] = [global_crop_1, global_crop_2] 108 | 109 | # global crops for teacher: 110 | output["global_crops_teacher"] = [global_crop_1, global_crop_2] 111 | 112 | # local crops: 113 | local_crops = [ 114 | self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) 115 | ] 116 | output["local_crops"] = local_crops 117 | output["offsets"] = () 118 | 119 | return output 120 | 
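For orientation, a short self-contained sketch (not part of the repository) of driving the augmentation pipeline defined above: the image path is a placeholder, the crop parameters simply echo ssl_default_config.yaml, and the snippet assumes the dinov2 package is importable (for example, installed from the dinov2/ subdirectory).

# Illustrative only: placeholder image path; parameters mirror the defaults in
# configs/ssl_default_config.yaml.
from PIL import Image

from dinov2.data.augmentations import DataAugmentationDINO

augmenter = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
    global_crops_size=224,
    local_crops_size=96,
)

image = Image.open("example.jpg").convert("RGB")  # placeholder input
views = augmenter(image)

# Two normalized global crops and eight local crops come back as tensors.
print(len(views["global_crops"]), len(views["local_crops"]))
print(views["global_crops"][0].shape)  # e.g. torch.Size([3, 224, 224])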
-------------------------------------------------------------------------------- /dinov2/dinov2/data/collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import random 9 | 10 | 11 | def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): 12 | # dtype = torch.half # TODO: Remove 13 | 14 | n_global_crops = len(samples_list[0][0]["global_crops"]) 15 | n_local_crops = len(samples_list[0][0]["local_crops"]) 16 | 17 | collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) 18 | 19 | collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) 20 | 21 | B = len(collated_global_crops) 22 | N = n_tokens 23 | n_samples_masked = int(B * mask_probability) 24 | probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) 25 | upperbound = 0 26 | masks_list = [] 27 | for i in range(0, n_samples_masked): 28 | prob_min = probs[i] 29 | prob_max = probs[i + 1] 30 | masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) 31 | upperbound += int(N * prob_max) 32 | for i in range(n_samples_masked, B): 33 | masks_list.append(torch.BoolTensor(mask_generator(0))) 34 | 35 | random.shuffle(masks_list) 36 | 37 | collated_masks = torch.stack(masks_list).flatten(1) 38 | mask_indices_list = collated_masks.flatten().nonzero().flatten() 39 | 40 | masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] 41 | 42 | return { 43 | "collated_global_crops": collated_global_crops.to(dtype), 44 | "collated_local_crops": collated_local_crops.to(dtype), 45 | "collated_masks": collated_masks, 46 | "mask_indices_list": mask_indices_list, 47 | "masks_weight": masks_weight, 48 | "upperbound": upperbound, 49 | "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), 50 | } 51 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .image_net import ImageNet 8 | from .image_net_22k import ImageNet22k 9 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/decoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from io import BytesIO 8 | from typing import Any, Tuple 9 | 10 | from PIL import Image 11 | 12 | 13 | class Decoder: 14 | def decode(self) -> Any: 15 | raise NotImplementedError 16 | 17 | 18 | class ImageDataDecoder(Decoder): 19 | def __init__(self, image_data: bytes) -> None: 20 | self._image_data = image_data 21 | 22 | def decode(self) -> Image: 23 | f = BytesIO(self._image_data) 24 | return Image.open(f).convert(mode="RGB") 25 | 26 | 27 | class TargetDecoder(Decoder): 28 | def __init__(self, target: Any): 29 | self._target = target 30 | 31 | def decode(self) -> Any: 32 | return self._target 33 | 34 | 35 | class TupleDecoder(Decoder): 36 | def __init__(self, *decoders: Decoder): 37 | self._decoders: Tuple[Decoder, ...] = decoders 38 | 39 | def decode(self) -> Any: 40 | return (decoder.decode() for decoder in self._decoders) 41 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/extended.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Tuple 8 | 9 | from torchvision.datasets import VisionDataset 10 | 11 | from .decoders import Decoder, TargetDecoder, ImageDataDecoder, TupleDecoder 12 | 13 | 14 | class ExtendedVisionDataset(VisionDataset): 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) # type: ignore 17 | 18 | def get_image_data(self, index: int) -> bytes: 19 | raise NotImplementedError 20 | 21 | def get_target(self, index: int) -> Any: 22 | raise NotImplementedError 23 | 24 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 25 | try: 26 | image_data = self.get_image_data(index) 27 | image = ImageDataDecoder(image_data).decode() 28 | except Exception as e: 29 | raise RuntimeError(f"can not read image for sample {index}") from e 30 | target = self.get_target(index) 31 | target = TargetDecoder(target).decode() 32 | 33 | if self.transforms is not None: 34 | image, target = self.transforms(image, target) 35 | 36 | return image, target 37 | 38 | def get_sample_decoder(self, index: int) -> Decoder: 39 | image_data = self.get_image_data(index) 40 | target = self.get_target(index) 41 | return TupleDecoder( 42 | ImageDataDecoder(image_data), 43 | TargetDecoder(target), 44 | ) 45 | 46 | def __len__(self) -> int: 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/masking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | import math 9 | import numpy as np 10 | 11 | 12 | class MaskingGenerator: 13 | def __init__( 14 | self, 15 | input_size, 16 | num_masking_patches=None, 17 | min_num_patches=4, 18 | max_num_patches=None, 19 | min_aspect=0.3, 20 | max_aspect=None, 21 | ): 22 | if not isinstance(input_size, tuple): 23 | input_size = (input_size,) * 2 24 | self.height, self.width = input_size 25 | 26 | self.num_patches = self.height * self.width 27 | self.num_masking_patches = num_masking_patches 28 | 29 | self.min_num_patches = min_num_patches 30 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 31 | 32 | max_aspect = max_aspect or 1 / min_aspect 33 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 34 | 35 | def __repr__(self): 36 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 37 | self.height, 38 | self.width, 39 | self.min_num_patches, 40 | self.max_num_patches, 41 | self.num_masking_patches, 42 | self.log_aspect_ratio[0], 43 | self.log_aspect_ratio[1], 44 | ) 45 | return repr_str 46 | 47 | def get_shape(self): 48 | return self.height, self.width 49 | 50 | def _mask(self, mask, max_mask_patches): 51 | delta = 0 52 | for _ in range(10): 53 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 54 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 55 | h = int(round(math.sqrt(target_area * aspect_ratio))) 56 | w = int(round(math.sqrt(target_area / aspect_ratio))) 57 | if w < self.width and h < self.height: 58 | top = random.randint(0, self.height - h) 59 | left = random.randint(0, self.width - w) 60 | 61 | num_masked = mask[top : top + h, left : left + w].sum() 62 | # Overlap 63 | if 0 < h * w - num_masked <= max_mask_patches: 64 | for i in range(top, top + h): 65 | for j in range(left, left + w): 66 | if mask[i, j] == 0: 67 | mask[i, j] = 1 68 | delta += 1 69 | 70 | if delta > 0: 71 | break 72 | return delta 73 | 74 | def __call__(self, num_masking_patches=0): 75 | mask = np.zeros(shape=self.get_shape(), dtype=bool) 76 | mask_count = 0 77 | while mask_count < num_masking_patches: 78 | max_mask_patches = num_masking_patches - mask_count 79 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 80 | 81 | delta = self._mask(mask, max_mask_patches) 82 | if delta == 0: 83 | break 84 | else: 85 | mask_count += delta 86 | 87 | return mask 88 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Sequence 8 | 9 | import torch 10 | from torchvision import transforms 11 | 12 | 13 | class GaussianBlur(transforms.RandomApply): 14 | """ 15 | Apply Gaussian Blur to the PIL image. 16 | """ 17 | 18 | def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): 19 | # NOTE: torchvision is applying 1 - probability to return the original image 20 | keep_p = 1 - p 21 | transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) 22 | super().__init__(transforms=[transform], p=keep_p) 23 | 24 | 25 | class MaybeToTensor(transforms.ToTensor): 26 | """ 27 | Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. 
28 | """ 29 | 30 | def __call__(self, pic): 31 | """ 32 | Args: 33 | pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. 34 | Returns: 35 | Tensor: Converted image. 36 | """ 37 | if isinstance(pic, torch.Tensor): 38 | return pic 39 | return super().__call__(pic) 40 | 41 | 42 | # Use timm's names 43 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 44 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 45 | 46 | 47 | def make_normalize_transform( 48 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 49 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 50 | ) -> transforms.Normalize: 51 | return transforms.Normalize(mean=mean, std=std) 52 | 53 | 54 | # This roughly matches torchvision's preset for classification training: 55 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 56 | def make_classification_train_transform( 57 | *, 58 | crop_size: int = 224, 59 | interpolation=transforms.InterpolationMode.BICUBIC, 60 | hflip_prob: float = 0.5, 61 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 62 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 63 | ): 64 | transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 65 | if hflip_prob > 0.0: 66 | transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) 67 | transforms_list.extend( 68 | [ 69 | MaybeToTensor(), 70 | make_normalize_transform(mean=mean, std=std), 71 | ] 72 | ) 73 | return transforms.Compose(transforms_list) 74 | 75 | 76 | # This matches (roughly) torchvision's preset for classification evaluation: 77 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 78 | def make_classification_eval_transform( 79 | *, 80 | resize_size: int = 256, 81 | interpolation=transforms.InterpolationMode.BICUBIC, 82 | crop_size: int = 224, 83 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 84 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 85 | ) -> transforms.Compose: 86 | transforms_list = [ 87 | transforms.Resize(resize_size, interpolation=interpolation), 88 | transforms.CenterCrop(crop_size), 89 | MaybeToTensor(), 90 | make_normalize_transform(mean=mean, std=std), 91 | ] 92 | return transforms.Compose(transforms_list) 93 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from enum import Enum 8 | import logging 9 | from typing import Any, Dict, Optional 10 | 11 | import torch 12 | from torch import Tensor 13 | from torchmetrics import Metric, MetricCollection 14 | from torchmetrics.classification import MulticlassAccuracy 15 | from torchmetrics.utilities.data import dim_zero_cat, select_topk 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | class MetricType(Enum): 22 | MEAN_ACCURACY = "mean_accuracy" 23 | MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" 24 | PER_CLASS_ACCURACY = "per_class_accuracy" 25 | IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" 26 | 27 | @property 28 | def accuracy_averaging(self): 29 | return getattr(AccuracyAveraging, self.name, None) 30 | 31 | def __str__(self): 32 | return self.value 33 | 34 | 35 | class AccuracyAveraging(Enum): 36 | MEAN_ACCURACY = "micro" 37 | MEAN_PER_CLASS_ACCURACY = "macro" 38 | PER_CLASS_ACCURACY = "none" 39 | 40 | def __str__(self): 41 | return self.value 42 | 43 | 44 | def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): 45 | if metric_type.accuracy_averaging is not None: 46 | return build_topk_accuracy_metric( 47 | average_type=metric_type.accuracy_averaging, 48 | num_classes=num_classes, 49 | ks=(1, 5) if ks is None else ks, 50 | ) 51 | elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: 52 | return build_topk_imagenet_real_accuracy_metric( 53 | num_classes=num_classes, 54 | ks=(1, 5) if ks is None else ks, 55 | ) 56 | 57 | raise ValueError(f"Unknown metric type {metric_type}") 58 | 59 | 60 | def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): 61 | metrics: Dict[str, Metric] = { 62 | f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks 63 | } 64 | return MetricCollection(metrics) 65 | 66 | 67 | def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): 68 | metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} 69 | return MetricCollection(metrics) 70 | 71 | 72 | class ImageNetReaLAccuracy(Metric): 73 | is_differentiable: bool = False 74 | higher_is_better: Optional[bool] = None 75 | full_state_update: bool = False 76 | 77 | def __init__( 78 | self, 79 | num_classes: int, 80 | top_k: int = 1, 81 | **kwargs: Any, 82 | ) -> None: 83 | super().__init__(**kwargs) 84 | self.num_classes = num_classes 85 | self.top_k = top_k 86 | self.add_state("tp", [], dist_reduce_fx="cat") 87 | 88 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 89 | # preds [B, D] 90 | # target [B, A] 91 | # preds_oh [B, D] with 0 and 1 92 | # select top K highest probabilities, use one hot representation 93 | preds_oh = select_topk(preds, self.top_k) 94 | # target_oh [B, D + 1] with 0 and 1 95 | target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) 96 | target = target.long() 97 | # for undefined targets (-1) use a fake value `num_classes` 98 | target[target == -1] = self.num_classes 99 | # fill targets, use one hot representation 100 | target_oh.scatter_(1, target, 1) 101 | # target_oh [B, D] (remove the fake target at index `num_classes`) 102 | target_oh = target_oh[:, :-1] 103 | # tp [B] with 0 and 1 104 | tp = (preds_oh * target_oh == 1).sum(dim=1) 105 | # at least one match between prediction and target 106 | tp.clip_(max=1) 107 | # ignore instances where no targets are defined 108 | mask = 
target_oh.sum(dim=1) > 0 109 | tp = tp[mask] 110 | self.tp.append(tp) # type: ignore 111 | 112 | def compute(self) -> Tensor: 113 | tp = dim_zero_cat(self.tp) # type: ignore 114 | return tp.float().mean() 115 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | from typing import Any, List, Optional, Tuple 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from dinov2.models import build_model_from_cfg 14 | from dinov2.utils.config import setup 15 | import dinov2.utils.utils as dinov2_utils 16 | 17 | 18 | def get_args_parser( 19 | description: Optional[str] = None, 20 | parents: Optional[List[argparse.ArgumentParser]] = [], 21 | add_help: bool = True, 22 | ): 23 | parser = argparse.ArgumentParser( 24 | description=description, 25 | parents=parents, 26 | add_help=add_help, 27 | ) 28 | parser.add_argument( 29 | "--config-file", 30 | type=str, 31 | help="Model configuration file", 32 | ) 33 | parser.add_argument( 34 | "--pretrained-weights", 35 | type=str, 36 | help="Pretrained model weights", 37 | ) 38 | parser.add_argument( 39 | "--output-dir", 40 | default="", 41 | type=str, 42 | help="Output directory to write results and logs", 43 | ) 44 | parser.add_argument( 45 | "--opts", 46 | help="Extra configuration options", 47 | default=[], 48 | nargs="+", 49 | ) 50 | return parser 51 | 52 | 53 | def get_autocast_dtype(config): 54 | teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype 55 | if teacher_dtype_str == "fp16": 56 | return torch.half 57 | elif teacher_dtype_str == "bf16": 58 | return torch.bfloat16 59 | else: 60 | return torch.float 61 | 62 | 63 | def build_model_for_eval(config, pretrained_weights): 64 | model, _ = build_model_from_cfg(config, only_teacher=True) 65 | dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") 66 | model.eval() 67 | model.cuda() 68 | return model 69 | 70 | 71 | def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: 72 | cudnn.benchmark = True 73 | config = setup(args) 74 | model = build_model_for_eval(config, args.pretrained_weights) 75 | autocast_dtype = get_autocast_dtype(config) 76 | return model, autocast_dtype 77 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | if attn_bias is not None: 77 | self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp 78 | else: 79 | self_att_op = None 80 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op) 81 | x = x.reshape([B, N, C]) 82 | 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | return x 86 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
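
# The two modules in this file implement the SwiGLU feed-forward variant: the input is
# projected to 2 * hidden_features, split in half, and combined as silu(x1) * x2 before
# the output projection. A minimal usage sketch (the shapes are illustrative assumptions):
#
#   >>> ffn = SwiGLUFFN(in_features=384, hidden_features=1024)
#   >>> out = ffn(torch.randn(2, 197, 384))  # -> shape (2, 197, 384)
#
# SwiGLUFFNFused rounds hidden_features down to 2/3 of the requested value, padded up to
# a multiple of 8, and inherits from the fused xformers SwiGLU op when it is available
# (falling back to SwiGLUFFN otherwise).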
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /dinov2/dinov2/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | import logging 9 | import os 10 | import sys 11 | from typing import Optional 12 | 13 | import dinov2.distributed as distributed 14 | from .helpers import MetricLogger, SmoothedValue 15 | 16 | 17 | # So that calling _configure_logger multiple times won't add many handlers 18 | @functools.lru_cache() 19 | def _configure_logger( 20 | name: Optional[str] = None, 21 | *, 22 | level: int = logging.DEBUG, 23 | output: Optional[str] = None, 24 | ): 25 | """ 26 | Configure a logger. 27 | 28 | Adapted from Detectron2. 29 | 30 | Args: 31 | name: The name of the logger to configure. 32 | level: The logging level to use. 33 | output: A file name or a directory to save log. If None, will not save log file. 34 | If ends with ".txt" or ".log", assumed to be a file name. 35 | Otherwise, logs will be saved to `output/log.txt`. 36 | 37 | Returns: 38 | The configured logger. 
39 | """ 40 | 41 | logger = logging.getLogger(name) 42 | logger.setLevel(level) 43 | logger.propagate = False 44 | 45 | # Loosely match Google glog format: 46 | # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg 47 | # but use a shorter timestamp and include the logger name: 48 | # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg 49 | fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " 50 | fmt_message = "%(message)s" 51 | fmt = fmt_prefix + fmt_message 52 | datefmt = "%Y%m%d %H:%M:%S" 53 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 54 | 55 | # stdout logging for main worker only 56 | if distributed.is_main_process(): 57 | handler = logging.StreamHandler(stream=sys.stdout) 58 | handler.setLevel(logging.DEBUG) 59 | handler.setFormatter(formatter) 60 | logger.addHandler(handler) 61 | 62 | # file logging for all workers 63 | if output: 64 | if os.path.splitext(output)[-1] in (".txt", ".log"): 65 | filename = output 66 | else: 67 | filename = os.path.join(output, "logs", "log.txt") 68 | 69 | if not distributed.is_main_process(): 70 | global_rank = distributed.get_global_rank() 71 | filename = filename + ".rank{}".format(global_rank) 72 | 73 | os.makedirs(os.path.dirname(filename), exist_ok=True) 74 | 75 | handler = logging.StreamHandler(open(filename, "a")) 76 | handler.setLevel(logging.DEBUG) 77 | handler.setFormatter(formatter) 78 | logger.addHandler(handler) 79 | 80 | return logger 81 | 82 | 83 | def setup_logging( 84 | output: Optional[str] = None, 85 | *, 86 | name: Optional[str] = None, 87 | level: int = logging.DEBUG, 88 | capture_warnings: bool = True, 89 | ) -> None: 90 | """ 91 | Setup logging. 92 | 93 | Args: 94 | output: A file name or a directory to save log files. If None, log 95 | files will not be saved. If output ends with ".txt" or ".log", it 96 | is assumed to be a file name. 97 | Otherwise, logs will be saved to `output/log.txt`. 98 | name: The name of the logger to configure, by default the root logger. 99 | level: The logging level to use. 100 | capture_warnings: Whether warnings should be captured as logs. 101 | """ 102 | logging.captureWarnings(capture_warnings) 103 | _configure_logger(name, level=level, output=output) 104 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_clstoken_loss import DINOLoss 8 | from .ibot_patch_loss import iBOTPatchLoss 9 | from .koleo_loss import KoLeoLoss 10 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/dino_clstoken_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
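
# Sketch of how the DINOLoss below is typically driven (an illustration; the output
# dimension and temperatures are assumed values, not prescribed by this file):
#
#   >>> loss_fn = DINOLoss(out_dim=65536)
#   >>> t = loss_fn.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
#   >>> loss = loss_fn([student_logits], [t])
#   >>> loss_fn.update_center(teacher_logits)  # asynchronous EMA update of the center
#
# softmax_center_teacher subtracts the running center and sharpens with teacher_temp;
# sinkhorn_knopp_teacher is the alternative that iteratively normalises rows and columns
# of exp(logits / temp) so every prototype and every sample receives equal total mass.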
6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | class DINOLoss(nn.Module): 14 | def __init__( 15 | self, 16 | out_dim, 17 | student_temp=0.1, 18 | center_momentum=0.9, 19 | ): 20 | super().__init__() 21 | self.student_temp = student_temp 22 | self.center_momentum = center_momentum 23 | self.register_buffer("center", torch.zeros(1, out_dim)) 24 | self.updated = True 25 | self.reduce_handle = None 26 | self.len_teacher_output = None 27 | self.async_batch_center = None 28 | 29 | @torch.no_grad() 30 | def softmax_center_teacher(self, teacher_output, teacher_temp): 31 | self.apply_center_update() 32 | # teacher centering and sharpening 33 | return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) 34 | 35 | @torch.no_grad() 36 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): 37 | teacher_output = teacher_output.float() 38 | world_size = dist.get_world_size() if dist.is_initialized() else 1 39 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper 40 | B = Q.shape[1] * world_size # number of samples to assign 41 | K = Q.shape[0] # how many prototypes 42 | 43 | # make the matrix sums to 1 44 | sum_Q = torch.sum(Q) 45 | if dist.is_initialized(): 46 | dist.all_reduce(sum_Q) 47 | Q /= sum_Q 48 | 49 | for it in range(n_iterations): 50 | # normalize each row: total weight per prototype must be 1/K 51 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 52 | if dist.is_initialized(): 53 | dist.all_reduce(sum_of_rows) 54 | Q /= sum_of_rows 55 | Q /= K 56 | 57 | # normalize each column: total weight per sample must be 1/B 58 | Q /= torch.sum(Q, dim=0, keepdim=True) 59 | Q /= B 60 | 61 | Q *= B # the colomns must sum to 1 so that Q is an assignment 62 | return Q.t() 63 | 64 | def forward(self, student_output_list, teacher_out_softmaxed_centered_list): 65 | """ 66 | Cross-entropy between softmax outputs of the teacher and student networks. 67 | """ 68 | # TODO: Use cross_entropy_distribution here 69 | total_loss = 0 70 | for s in student_output_list: 71 | lsm = F.log_softmax(s / self.student_temp, dim=-1) 72 | for t in teacher_out_softmaxed_centered_list: 73 | loss = torch.sum(t * lsm, dim=-1) 74 | total_loss -= loss.mean() 75 | return total_loss 76 | 77 | @torch.no_grad() 78 | def update_center(self, teacher_output): 79 | self.reduce_center_update(teacher_output) 80 | 81 | @torch.no_grad() 82 | def reduce_center_update(self, teacher_output): 83 | self.updated = False 84 | self.len_teacher_output = len(teacher_output) 85 | self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 86 | if dist.is_initialized(): 87 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) 88 | 89 | @torch.no_grad() 90 | def apply_center_update(self): 91 | if self.updated is False: 92 | world_size = dist.get_world_size() if dist.is_initialized() else 1 93 | 94 | if self.reduce_handle is not None: 95 | self.reduce_handle.wait() 96 | _t = self.async_batch_center / (self.len_teacher_output * world_size) 97 | 98 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) 99 | 100 | self.updated = True 101 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/koleo_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | # import torch.distributed as dist 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class KoLeoLoss(nn.Module): 20 | """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.pdist = nn.PairwiseDistance(2, eps=1e-8) 25 | 26 | def pairwise_NNs_inner(self, x): 27 | """ 28 | Pairwise nearest neighbors for L2-normalized vectors. 29 | Uses Torch rather than Faiss to remain on GPU. 30 | """ 31 | # parwise dot products (= inverse distance) 32 | dots = torch.mm(x, x.t()) 33 | n = x.shape[0] 34 | dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 35 | # max inner prod -> min distance 36 | _, I = torch.max(dots, dim=1) # noqa: E741 37 | return I 38 | 39 | def forward(self, student_output, eps=1e-8): 40 | """ 41 | Args: 42 | student_output (BxD): backbone output of student 43 | """ 44 | with torch.cuda.amp.autocast(enabled=False): 45 | student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) 46 | I = self.pairwise_NNs_inner(student_output) # noqa: E741 47 | distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B 48 | loss = -torch.log(distances + eps).mean() 49 | return loss 50 | -------------------------------------------------------------------------------- /dinov2/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | from . import vision_transformer as vits 10 | 11 | 12 | logger = logging.getLogger("dinov2") 13 | 14 | 15 | def build_model(args, only_teacher=False, img_size=224): 16 | args.arch = args.arch.removesuffix("_memeff") 17 | if "vit" in args.arch: 18 | vit_kwargs = dict( 19 | img_size=img_size, 20 | patch_size=args.patch_size, 21 | init_values=args.layerscale, 22 | ffn_layer=args.ffn_layer, 23 | block_chunks=args.block_chunks, 24 | qkv_bias=args.qkv_bias, 25 | proj_bias=args.proj_bias, 26 | ffn_bias=args.ffn_bias, 27 | ) 28 | teacher = vits.__dict__[args.arch](**vit_kwargs) 29 | if only_teacher: 30 | return teacher, teacher.embed_dim 31 | student = vits.__dict__[args.arch]( 32 | **vit_kwargs, 33 | drop_path_rate=args.drop_path_rate, 34 | drop_path_uniform=args.drop_path_uniform, 35 | ) 36 | embed_dim = student.embed_dim 37 | return student, teacher, embed_dim 38 | 39 | 40 | def build_model_from_cfg(cfg, only_teacher=False): 41 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 42 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
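
# The modules under dinov2.run are thin submitit launchers: each re-uses the argument
# parser of the corresponding dinov2.eval / dinov2.train entry point, adds the SLURM
# options defined in dinov2/run/submit.py (--ngpus, --nodes, --timeout, --partition, ...)
# and hands the wrapped main() to the cluster through submit_jobs().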
6 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.knn import get_args_parser as get_knn_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.knn import main as knn_main 25 | 26 | self._setup_args() 27 | knn_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 k-NN evaluation" 47 | knn_args_parser = get_knn_args_parser(add_help=False) 48 | parents = [knn_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:knn") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.linear import get_args_parser as get_linear_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.linear import main as linear_main 25 | 26 | self._setup_args() 27 | linear_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 linear evaluation" 47 | linear_args_parser = get_linear_args_parser(add_help=False) 48 | parents = [linear_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:linear") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/log_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.log_regression import main as log_regression_main 25 | 26 | self._setup_args() 27 | log_regression_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 logistic evaluation" 47 | log_regression_args_parser = get_log_regression_args_parser(add_help=False) 48 | parents = [log_regression_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
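    # submit_jobs (dinov2/run/submit.py) builds a submitit AutoExecutor in args.output_dir,
    # maps the CLI options above onto SLURM executor parameters and submits Evaluator;
    # Evaluator.checkpoint() lets submitit requeue the job with the same arguments.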
55 | submit_jobs(Evaluator, args, name="dinov2:logreg") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import logging 9 | import os 10 | from pathlib import Path 11 | from typing import List, Optional 12 | 13 | import submitit 14 | 15 | from dinov2.utils.cluster import ( 16 | get_slurm_executor_parameters, 17 | get_slurm_partition, 18 | get_user_checkpoint_path, 19 | ) 20 | 21 | 22 | logger = logging.getLogger("dinov2") 23 | 24 | 25 | def get_args_parser( 26 | description: Optional[str] = None, 27 | parents: Optional[List[argparse.ArgumentParser]] = [], 28 | add_help: bool = True, 29 | ) -> argparse.ArgumentParser: 30 | slurm_partition = get_slurm_partition() 31 | parser = argparse.ArgumentParser( 32 | description=description, 33 | parents=parents, 34 | add_help=add_help, 35 | ) 36 | parser.add_argument( 37 | "--ngpus", 38 | "--gpus", 39 | "--gpus-per-node", 40 | default=8, 41 | type=int, 42 | help="Number of GPUs to request on each node", 43 | ) 44 | parser.add_argument( 45 | "--nodes", 46 | "--nnodes", 47 | default=2, 48 | type=int, 49 | help="Number of nodes to request", 50 | ) 51 | parser.add_argument( 52 | "--timeout", 53 | default=2800, 54 | type=int, 55 | help="Duration of the job", 56 | ) 57 | parser.add_argument( 58 | "--partition", 59 | default=slurm_partition, 60 | type=str, 61 | help="Partition where to submit", 62 | ) 63 | parser.add_argument( 64 | "--use-volta32", 65 | action="store_true", 66 | help="Request V100-32GB GPUs", 67 | ) 68 | parser.add_argument( 69 | "--comment", 70 | default="", 71 | type=str, 72 | help="Comment to pass to scheduler, e.g. 
priority message", 73 | ) 74 | parser.add_argument( 75 | "--exclude", 76 | default="", 77 | type=str, 78 | help="Nodes to exclude", 79 | ) 80 | return parser 81 | 82 | 83 | def get_shared_folder() -> Path: 84 | user_checkpoint_path = get_user_checkpoint_path() 85 | if user_checkpoint_path is None: 86 | raise RuntimeError("Path to user checkpoint cannot be determined") 87 | path = user_checkpoint_path / "experiments" 88 | path.mkdir(exist_ok=True) 89 | return path 90 | 91 | 92 | def submit_jobs(task_class, args, name: str): 93 | if not args.output_dir: 94 | args.output_dir = str(get_shared_folder() / "%j") 95 | 96 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 97 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 98 | 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs["slurm_constraint"] = "volta32gb" 102 | if args.comment: 103 | kwargs["slurm_comment"] = args.comment 104 | if args.exclude: 105 | kwargs["slurm_exclude"] = args.exclude 106 | 107 | executor_params = get_slurm_executor_parameters( 108 | nodes=args.nodes, 109 | num_gpus_per_node=args.ngpus, 110 | timeout_min=args.timeout, # max is 60 * 72 111 | slurm_signal_delay_s=120, 112 | slurm_partition=args.partition, 113 | **kwargs, 114 | ) 115 | executor.update_parameters(name=name, **executor_params) 116 | 117 | task = task_class(args) 118 | job = executor.submit(task) 119 | 120 | logger.info(f"Submitted job_id: {job.job_id}") 121 | str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) 122 | logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") 123 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/train/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.logging import setup_logging 12 | from dinov2.train import get_args_parser as get_train_args_parser 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.train import main as train_main 25 | 26 | self._setup_args() 27 | train_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 training" 47 | train_args_parser = get_train_args_parser(add_help=False) 48 | parents = [train_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
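    # Illustrative launch (flag spellings are assumptions based on the parsers combined
    # above; the config and output paths are placeholders):
    #   python dinov2/run/train/train.py --nodes 4 \
    #       --config-file dinov2/configs/train/vitl16_short.yaml \
    #       --output-dir <PATH/TO/OUTPUT/DIR>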
55 | submit_jobs(Trainer, args, name="dinov2:train") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .train import get_args_parser, main 8 | from .ssl_meta_arch import SSLMetaArch 9 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from enum import Enum 8 | import os 9 | from pathlib import Path 10 | from typing import Any, Dict, Optional 11 | 12 | 13 | class ClusterType(Enum): 14 | AWS = "aws" 15 | FAIR = "fair" 16 | RSC = "rsc" 17 | 18 | 19 | def _guess_cluster_type() -> ClusterType: 20 | uname = os.uname() 21 | if uname.sysname == "Linux": 22 | if uname.release.endswith("-aws"): 23 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 24 | return ClusterType.AWS 25 | elif uname.nodename.startswith("rsc"): 26 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 27 | return ClusterType.RSC 28 | 29 | return ClusterType.FAIR 30 | 31 | 32 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 33 | if cluster_type is None: 34 | return _guess_cluster_type() 35 | 36 | return cluster_type 37 | 38 | 39 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 40 | cluster_type = get_cluster_type(cluster_type) 41 | if cluster_type is None: 42 | return None 43 | 44 | CHECKPOINT_DIRNAMES = { 45 | ClusterType.AWS: "checkpoints", 46 | ClusterType.FAIR: "checkpoint", 47 | ClusterType.RSC: "checkpoint/dino", 48 | } 49 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 50 | 51 | 52 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 53 | checkpoint_path = get_checkpoint_path(cluster_type) 54 | if checkpoint_path is None: 55 | return None 56 | 57 | username = os.environ.get("USER") 58 | assert username is not None 59 | return checkpoint_path / username 60 | 61 | 62 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 63 | cluster_type = get_cluster_type(cluster_type) 64 | if cluster_type is None: 65 | return None 66 | 67 | SLURM_PARTITIONS = { 68 | ClusterType.AWS: "learnlab", 69 | ClusterType.FAIR: "learnlab", 70 | ClusterType.RSC: "learn", 71 | } 72 | return SLURM_PARTITIONS[cluster_type] 73 | 74 | 75 | def get_slurm_executor_parameters( 76 | nodes: int, 
num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 77 | ) -> Dict[str, Any]: 78 | # create default parameters 79 | params = { 80 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 81 | "gpus_per_node": num_gpus_per_node, 82 | "tasks_per_node": num_gpus_per_node, # one task per GPU 83 | "cpus_per_task": 10, 84 | "nodes": nodes, 85 | "slurm_partition": get_slurm_partition(cluster_type), 86 | } 87 | # apply cluster-specific adjustments 88 | cluster_type = get_cluster_type(cluster_type) 89 | if cluster_type == ClusterType.AWS: 90 | params["cpus_per_task"] = 12 91 | del params["mem_gb"] 92 | elif cluster_type == ClusterType.RSC: 93 | params["cpus_per_task"] = 12 94 | # set additional parameters / apply overrides 95 | params.update(kwargs) 96 | return params 97 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import logging 9 | import os 10 | 11 | from omegaconf import OmegaConf 12 | 13 | import dinov2.distributed as distributed 14 | from dinov2.logging import setup_logging 15 | from dinov2.utils import utils 16 | from dinov2.configs import dinov2_default_config 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | def apply_scaling_rules_to_cfg(cfg): # to fix 23 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 24 | base_lr = cfg.optim.base_lr 25 | cfg.optim.lr = base_lr 26 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 27 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 28 | else: 29 | raise NotImplementedError 30 | return cfg 31 | 32 | 33 | def write_config(cfg, output_dir, name="config.yaml"): 34 | logger.info(OmegaConf.to_yaml(cfg)) 35 | saved_cfg_path = os.path.join(output_dir, name) 36 | with open(saved_cfg_path, "w") as f: 37 | OmegaConf.save(config=cfg, f=f) 38 | return saved_cfg_path 39 | 40 | 41 | def get_cfg_from_args(args): 42 | args.output_dir = os.path.abspath(args.output_dir) 43 | args.opts += [f"train.output_dir={args.output_dir}"] 44 | default_cfg = OmegaConf.create(dinov2_default_config) 45 | cfg = OmegaConf.load(args.config_file) 46 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 47 | return cfg 48 | 49 | 50 | def default_setup(args): 51 | distributed.enable(overwrite=True) 52 | seed = getattr(args, "seed", 0) 53 | rank = distributed.get_global_rank() 54 | 55 | global logger 56 | setup_logging(output=args.output_dir, level=logging.INFO) 57 | logger = logging.getLogger("dinov2") 58 | 59 | utils.fix_random_seeds(seed + rank) 60 | logger.info("git:\n {}\n".format(utils.get_sha())) 61 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 62 | 63 | 64 | def setup(args): 65 | """ 66 | Create configs and perform basic setups. 
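
    The call order is get_cfg_from_args -> default_setup -> apply_scaling_rules_to_cfg ->
    write_config. Under the sqrt_wrt_1024 rule, lr = base_lr * sqrt(batch_size_per_gpu *
    world_size / 1024); with hypothetical values base_lr=1e-3, 32 images per GPU and
    8 GPUs, lr = 1e-3 * sqrt(256 / 1024) = 5e-4.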
67 | """ 68 | cfg = get_cfg_from_args(args) 69 | os.makedirs(args.output_dir, exist_ok=True) 70 | default_setup(args) 71 | apply_scaling_rules_to_cfg(cfg) 72 | write_config(cfg, args.output_dir) 73 | return cfg 74 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | TypeSpec = Union[str, np.dtype, torch.dtype] 15 | 16 | 17 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 18 | np.dtype("bool"): torch.bool, 19 | np.dtype("uint8"): torch.uint8, 20 | np.dtype("int8"): torch.int8, 21 | np.dtype("int16"): torch.int16, 22 | np.dtype("int32"): torch.int32, 23 | np.dtype("int64"): torch.int64, 24 | np.dtype("float16"): torch.float16, 25 | np.dtype("float32"): torch.float32, 26 | np.dtype("float64"): torch.float64, 27 | np.dtype("complex64"): torch.complex64, 28 | np.dtype("complex128"): torch.complex128, 29 | } 30 | 31 | 32 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 33 | if isinstance(dtype, torch.dtype): 34 | return dtype 35 | if isinstance(dtype, str): 36 | dtype = np.dtype(dtype) 37 | assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" 38 | return _NUMPY_TO_TORCH_DTYPE[dtype] 39 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import logging 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 15 | """ 16 | Calculate lr decay rate for different ViT blocks. 17 | Args: 18 | name (string): parameter name. 19 | lr_decay_rate (float): base lr decay rate. 20 | num_layers (int): number of ViT blocks. 21 | Returns: 22 | lr decay rate for the given parameter. 23 | """ 24 | layer_id = num_layers + 1 25 | if name.startswith("backbone") or force_is_backbone: 26 | if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name: 27 | layer_id = 0 28 | elif force_is_backbone and ( 29 | "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name 30 | ): 31 | layer_id = 0 32 | elif ".blocks." in name and ".residual." not in name: 33 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 34 | elif chunked_blocks and "blocks." in name and "residual." not in name: 35 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 36 | elif "blocks." in name and "residual." 
not in name: 37 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 38 | 39 | return lr_decay_rate ** (num_layers + 1 - layer_id) 40 | 41 | 42 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 43 | chunked_blocks = False 44 | if hasattr(model, "n_blocks"): 45 | logger.info("chunked fsdp") 46 | n_blocks = model.n_blocks 47 | chunked_blocks = model.chunked_blocks 48 | elif hasattr(model, "blocks"): 49 | logger.info("first code branch") 50 | n_blocks = len(model.blocks) 51 | elif hasattr(model, "backbone"): 52 | logger.info("second code branch") 53 | n_blocks = len(model.backbone.blocks) 54 | else: 55 | logger.info("else code branch") 56 | n_blocks = 0 57 | all_param_groups = [] 58 | 59 | for name, param in model.named_parameters(): 60 | name = name.replace("_fsdp_wrapped_module.", "") 61 | if not param.requires_grad: 62 | continue 63 | decay_rate = get_vit_lr_decay_rate( 64 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 65 | ) 66 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 67 | 68 | if "last_layer" in name: 69 | d.update({"is_last_layer": True}) 70 | 71 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 72 | d.update({"wd_multiplier": 0.0}) 73 | 74 | if "patch_embed" in name: 75 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 76 | 77 | all_param_groups.append(d) 78 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 79 | 80 | return all_param_groups 81 | 82 | 83 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 84 | fused_params_groups = defaultdict(lambda: {"params": []}) 85 | for d in all_params_groups: 86 | identifier = "" 87 | for k in keys: 88 | identifier += k + str(d[k]) + "_" 89 | 90 | for k in keys: 91 | fused_params_groups[identifier][k] = d[k] 92 | fused_params_groups[identifier]["params"].append(d["params"]) 93 | 94 | return fused_params_groups.values() 95 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
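
# A small usage sketch for the CosineScheduler defined below (the values are illustrative):
#
#   >>> sched = CosineScheduler(base_value=1e-3, final_value=1e-6,
#   ...                         total_iters=1000, warmup_iters=100)
#   >>> sched[0]     # start of the linear warmup (start_warmup_value, 0 by default)
#   >>> sched[100]   # roughly base_value, then cosine decay towards final_value
#   >>> sched[5000]  # any index past total_iters returns final_value
#
# Because out-of-range indices return final_value, the scheduler can be indexed directly
# with the global iteration counter.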
6 | 7 | import logging 8 | import os 9 | import random 10 | import subprocess 11 | from urllib.parse import urlparse 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 22 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 23 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 24 | else: 25 | state_dict = torch.load(pretrained_weights, map_location="cpu") 26 | if checkpoint_key is not None and checkpoint_key in state_dict: 27 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 28 | state_dict = state_dict[checkpoint_key] 29 | # remove `module.` prefix 30 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 31 | # remove `backbone.` prefix induced by multicrop wrapper 32 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 33 | msg = model.load_state_dict(state_dict, strict=False) 34 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 35 | 36 | 37 | def fix_random_seeds(seed=31): 38 | """ 39 | Fix random seeds. 40 | """ 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | 47 | def get_sha(): 48 | cwd = os.path.dirname(os.path.abspath(__file__)) 49 | 50 | def _run(command): 51 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 52 | 53 | sha = "N/A" 54 | diff = "clean" 55 | branch = "N/A" 56 | try: 57 | sha = _run(["git", "rev-parse", "HEAD"]) 58 | subprocess.check_output(["git", "diff"], cwd=cwd) 59 | diff = _run(["git", "diff-index", "HEAD"]) 60 | diff = "has uncommited changes" if diff else "clean" 61 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 62 | except Exception: 63 | pass 64 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 65 | return message 66 | 67 | 68 | class CosineScheduler(object): 69 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 70 | super().__init__() 71 | self.final_value = final_value 72 | self.total_iters = total_iters 73 | 74 | freeze_schedule = np.zeros((freeze_iters)) 75 | 76 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 77 | 78 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 79 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 80 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 81 | 82 | assert len(self.schedule) == self.total_iters 83 | 84 | def __getitem__(self, it): 85 | if it >= self.total_iters: 86 | return self.final_value 87 | else: 88 | return self.schedule[it] 89 | 90 | 91 | def has_batchnorms(model): 92 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 93 | for name, module in model.named_modules(): 94 | if isinstance(module, bn_types): 95 | return True 96 | return False 97 | -------------------------------------------------------------------------------- /dinov2/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.pylint.master] 5 | persistent = false 6 | score = false 7 | 8 | [tool.pylint.messages_control] 9 | disable = "all" 10 | enable = [ 11 | "miscellaneous", 12 | "similarities", 13 
| ] 14 | 15 | [tool.pylint.similarities] 16 | ignore-comments = true 17 | ignore-docstrings = true 18 | ignore-imports = true 19 | min-similarity-lines = 8 20 | 21 | [tool.pylint.reports] 22 | reports = false 23 | 24 | [tool.pylint.miscellaneous] 25 | notes = [ 26 | "FIXME", 27 | "XXX", 28 | "TODO", 29 | ] 30 | -------------------------------------------------------------------------------- /dinov2/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==22.6.0 2 | flake8==5.0.4 3 | pylint==2.15.0 4 | -------------------------------------------------------------------------------- /dinov2/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==2.0.0 3 | torchvision==0.15.0 4 | omegaconf 5 | torchmetrics==0.10.3 6 | fvcore 7 | iopath 8 | xformers==0.0.18 9 | submitit 10 | --extra-index-url https://pypi.nvidia.com 11 | cuml-cu11 12 | -------------------------------------------------------------------------------- /dinov2/scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -n "$1" ]; then 4 | echo "linting \"$1\"" 5 | fi 6 | 7 | echo "running black" 8 | if [ -n "$1" ]; then 9 | black "$1" 10 | else 11 | black dinov2 12 | fi 13 | 14 | echo "running flake8" 15 | if [ -n "$1" ]; then 16 | flake8 "$1" 17 | else 18 | flake8 19 | fi 20 | 21 | echo "running pylint" 22 | if [ -n "$1" ]; then 23 | pylint "$1" 24 | else 25 | pylint dinov2 26 | fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- /dinov2/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E501,W503 4 | per-file-ignores = 5 | __init__.py:F401 6 | -------------------------------------------------------------------------------- /dinov2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import re 9 | from typing import List, Tuple 10 | 11 | from setuptools import setup, find_packages 12 | 13 | 14 | NAME = "dinov2" 15 | DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method." 
16 | 17 | URL = "https://github.com/facebookresearch/dinov2" 18 | AUTHOR = "FAIR" 19 | REQUIRES_PYTHON = ">=3.9.0" 20 | HERE = Path(__file__).parent 21 | 22 | 23 | try: 24 | with open(HERE / "README.md", encoding="utf-8") as f: 25 | long_description = "\n" + f.read() 26 | except FileNotFoundError: 27 | long_description = DESCRIPTION 28 | 29 | 30 | def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: 31 | requirements = [] 32 | extra_indices = [] 33 | with open(path) as f: 34 | for line in f.readlines(): 35 | line = line.rstrip("\r\n") 36 | if line.startswith("--extra-index-url "): 37 | extra_indices.append(line[18:]) 38 | continue 39 | requirements.append(line) 40 | return requirements, extra_indices 41 | 42 | 43 | def get_package_version() -> str: 44 | with open(HERE / "dinov2/__init__.py") as f: 45 | result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 46 | if result: 47 | return result.group(1) 48 | raise RuntimeError("Can't get package version") 49 | 50 | 51 | requirements, extra_indices = get_requirements() 52 | version = get_package_version() 53 | dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") 54 | 55 | 56 | setup( 57 | name=NAME, 58 | version=version, 59 | description=DESCRIPTION, 60 | long_description=long_description, 61 | long_description_content_type="text/markdown", 62 | author=AUTHOR, 63 | python_requires=REQUIRES_PYTHON, 64 | url=URL, 65 | packages=find_packages(), 66 | package_data={ 67 | "": ["*.yaml"], 68 | }, 69 | install_requires=requirements, 70 | dependency_links=extra_indices, 71 | extras_require={ 72 | "dev": dev_requirements, 73 | }, 74 | install_package_data=True, 75 | license="CC-BY-NC", 76 | license_files=("LICENSE",), 77 | classifiers=[ 78 | # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py 79 | "Development Status :: 3 - Alpha", 80 | "Intended Audience :: Developers", 81 | "Intended Audience :: Science/Research", 82 | "License :: Other/Proprietary License", 83 | "Programming Language :: Python :: 3.9", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | "Topic :: Software Development :: Libraries :: Python Modules", 86 | ], 87 | ) 88 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: anydoor 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.3=he6710b0_2 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1w=h7f8727e_0 15 | - pip=23.3.1=py38h06a4308_0 16 | - python=3.8.5=h7579374_1 17 | - readline=8.2=h5eee18b_0 18 | - sqlite=3.41.2=h5eee18b_0 19 | - tk=8.6.12=h1ccaba5_0 20 | - wheel=0.41.2=py38h06a4308_0 21 | - xz=5.4.5=h5eee18b_0 22 | - zlib=1.2.13=h5eee18b_0 23 | - pip: 24 | - absl-py==2.0.0 25 | - aiofiles==23.2.1 26 | - aiohttp==3.9.1 27 | - aiosignal==1.3.1 28 | - albumentations==1.3.0 29 | - altair==5.2.0 30 | - annotated-types==0.6.0 31 | - antlr4-python3-runtime==4.8 32 | - anyio==3.7.1 33 | - async-timeout==4.0.3 34 | - attrs==23.1.0 35 | - cachetools==5.3.2 36 | - certifi==2023.11.17 37 | - charset-normalizer==3.3.2 38 | - click==8.1.7 39 | - cloudpickle==3.0.0 40 | - cmake==3.28.0 41 | - 
contourpy==1.1.1 42 | - cycler==0.12.1 43 | - cython==3.0.6 44 | - einops==0.3.0 45 | - exceptiongroup==1.2.0 46 | - fastapi==0.105.0 47 | - ffmpy==0.3.1 48 | - filelock==3.13.1 49 | - fonttools==4.46.0 50 | - frozenlist==1.4.0 51 | - fsspec==2023.12.2 52 | - ftfy==6.1.3 53 | - future==0.18.3 54 | - fvcore==0.1.5.post20221221 55 | - google-auth==2.25.2 56 | - google-auth-oauthlib==1.0.0 57 | - gradio==3.39.0 58 | - gradio-client==0.7.2 59 | - grpcio==1.60.0 60 | - h11==0.14.0 61 | - httpcore==1.0.2 62 | - httpx==0.25.2 63 | - huggingface-hub==0.19.4 64 | - idna==3.6 65 | - imageio==2.33.1 66 | - importlib-metadata==7.0.0 67 | - importlib-resources==6.1.1 68 | - iopath==0.1.10 69 | - jinja2==3.1.2 70 | - joblib==1.3.2 71 | - jsonschema==4.20.0 72 | - jsonschema-specifications==2023.11.2 73 | - kiwisolver==1.4.5 74 | - lazy-loader==0.3 75 | - linkify-it-py==2.0.2 76 | - lit==17.0.6 77 | - lvis==0.5.3 78 | - markdown==3.5.1 79 | - markdown-it-py==2.2.0 80 | - markupsafe==2.1.3 81 | - matplotlib==3.7.4 82 | - mdit-py-plugins==0.3.3 83 | - mdurl==0.1.2 84 | - mpmath==1.3.0 85 | - multidict==6.0.4 86 | - mypy-extensions==1.0.0 87 | - networkx==3.1 88 | - numpy==1.23.1 89 | - nvidia-cublas-cu11==11.10.3.66 90 | - nvidia-cuda-cupti-cu11==11.7.101 91 | - nvidia-cuda-nvrtc-cu11==11.7.99 92 | - nvidia-cuda-runtime-cu11==11.7.99 93 | - nvidia-cudnn-cu11==8.5.0.96 94 | - nvidia-cufft-cu11==10.9.0.58 95 | - nvidia-curand-cu11==10.2.10.91 96 | - nvidia-cusolver-cu11==11.4.0.1 97 | - nvidia-cusparse-cu11==11.7.4.91 98 | - nvidia-nccl-cu11==2.14.3 99 | - nvidia-nvtx-cu11==11.7.91 100 | - oauthlib==3.2.2 101 | - omegaconf==2.1.1 102 | - open-clip-torch==2.17.1 103 | - opencv-contrib-python==4.3.0.36 104 | - opencv-python==4.7.0.72 105 | - opencv-python-headless==4.7.0.72 106 | - orjson==3.9.10 107 | - packaging==23.2 108 | - pandas==2.0.3 109 | - pillow==9.4.0 110 | - pkgutil-resolve-name==1.3.10 111 | - portalocker==2.8.2 112 | - protobuf==3.20.3 113 | - pyasn1==0.5.1 114 | - pyasn1-modules==0.3.0 115 | - pycocotools==2.0.7 116 | - pydantic==2.5.2 117 | - pydantic-core==2.14.5 118 | - pydeprecate==0.3.1 119 | - pydub==0.25.1 120 | - pyparsing==3.1.1 121 | - pyre-extensions==0.0.23 122 | - python-dateutil==2.8.2 123 | - python-multipart==0.0.6 124 | - pytorch-lightning==1.5.0 125 | - pytz==2023.3.post1 126 | - pywavelets==1.4.1 127 | - pyyaml==6.0.1 128 | - qudida==0.0.4 129 | - referencing==0.32.0 130 | - regex==2023.10.3 131 | - requests==2.31.0 132 | - requests-oauthlib==1.3.1 133 | - rpds-py==0.13.2 134 | - rsa==4.9 135 | - safetensors==0.2.7 136 | - scikit-image==0.21.0 137 | - scikit-learn==1.3.2 138 | - scipy==1.9.1 139 | - semantic-version==2.10.0 140 | - sentencepiece==0.1.99 141 | - setuptools==66.0.0 142 | - share==1.0.4 143 | - six==1.16.0 144 | - sniffio==1.3.0 145 | - starlette==0.27.0 146 | - submitit==1.5.1 147 | - sympy==1.12 148 | - tabulate==0.9.0 149 | - tensorboard==2.14.0 150 | - tensorboard-data-server==0.7.2 151 | - termcolor==2.4.0 152 | - threadpoolctl==3.2.0 153 | - tifffile==2023.7.10 154 | - timm==0.6.12 155 | - tokenizers==0.12.1 156 | - toolz==0.12.0 157 | - torch==2.0.0 158 | - torchmetrics==0.6.0 159 | - torchvision==0.15.1 160 | - tqdm==4.65.0 161 | - transformers==4.19.2 162 | - triton==2.0.0 163 | - typing-extensions==4.9.0 164 | - typing-inspect==0.9.0 165 | - tzdata==2023.3 166 | - uc-micro-py==1.0.2 167 | - urllib3==2.1.0 168 | - uvicorn==0.24.0.post1 169 | - wcwidth==0.2.12 170 | - websockets==11.0.3 171 | - werkzeug==3.0.1 172 | - xformers==0.0.18 173 | - yacs==0.1.8 
174 | - yarl==1.9.4 175 | - zipp==3.17.0 176 | 177 | -------------------------------------------------------------------------------- /examples/Gradio/BG/00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/00.png -------------------------------------------------------------------------------- /examples/Gradio/BG/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/01.png -------------------------------------------------------------------------------- /examples/Gradio/BG/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/02.png -------------------------------------------------------------------------------- /examples/Gradio/BG/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/03.png -------------------------------------------------------------------------------- /examples/Gradio/BG/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/04.jpg -------------------------------------------------------------------------------- /examples/Gradio/BG/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/04.png -------------------------------------------------------------------------------- /examples/Gradio/BG/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/06.png -------------------------------------------------------------------------------- /examples/Gradio/BG/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/07.png -------------------------------------------------------------------------------- /examples/Gradio/BG/08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/08.jpg -------------------------------------------------------------------------------- /examples/Gradio/BG/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/13.jpg -------------------------------------------------------------------------------- /examples/Gradio/BG/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/17.jpg 
-------------------------------------------------------------------------------- /examples/Gradio/BG/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/BG/22.png -------------------------------------------------------------------------------- /examples/Gradio/FG/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/00.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/01.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/04.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/06.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/07.png -------------------------------------------------------------------------------- /examples/Gradio/FG/09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/09.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/18.png -------------------------------------------------------------------------------- /examples/Gradio/FG/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/22.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/25.png -------------------------------------------------------------------------------- /examples/Gradio/FG/28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/28.png 
-------------------------------------------------------------------------------- /examples/Gradio/FG/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/33.png -------------------------------------------------------------------------------- /examples/Gradio/FG/36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/36.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/39.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/43.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/43.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/44.jpg -------------------------------------------------------------------------------- /examples/Gradio/FG/50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/Gradio/FG/50.jpg -------------------------------------------------------------------------------- /examples/TestDreamBooth/BG/000000047948_GT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/BG/000000047948_GT.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/BG/000000047948_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/BG/000000047948_mask.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/BG/000000309203_GT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/BG/000000309203_GT.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/BG/000000309203_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/BG/000000309203_mask.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/FG/00.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/FG/00.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/FG/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/FG/01.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/FG/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/FG/02.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/FG/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/FG/03.png -------------------------------------------------------------------------------- /examples/TestDreamBooth/GEN/gen_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/examples/TestDreamBooth/GEN/gen_res.png -------------------------------------------------------------------------------- /install.ps1: -------------------------------------------------------------------------------- 1 | Set-Location $PSScriptRoot 2 | 3 | $Env:PIP_DISABLE_PIP_VERSION_CHECK = 1 4 | 5 | if (!(Test-Path -Path "venv")) { 6 | Write-Output "Creating venv for python..." 7 | python -m venv venv 8 | } 9 | .\venv\Scripts\activate 10 | 11 | python -m pip install pip==23.0.1 12 | 13 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 14 | 15 | pip install --no-deps xformers==0.0.22 16 | 17 | Write-Output "Installing deps..." 18 | 19 | pip install setuptools==66.0.0 20 | 21 | $SOURCEFILE="scripts/util.py" 22 | 23 | $TARGETFILE="venv/Lib/site-packages/setuptools/_distutils/util.py" 24 | 25 | Copy-Item -Path $SOURCEFILE -Destination $TARGETFILE -Force 26 | 27 | pip install share==1.0.4 28 | 29 | pip install -r requirements-windows.txt 30 | 31 | pip install git+https://github.com/cocodataset/panopticapi.git 32 | 33 | pip install pycocotools 34 | 35 | pip install lvis 36 | 37 | Write-Output "Checking models..." 38 | 39 | if (!(Test-Path -Path "path")) { 40 | Write-Output "Creating pretrained_models..." 41 | mkdir "path" 42 | } 43 | 44 | git lfs install 45 | git lfs clone https://huggingface.co/bdsqlsz/AnyDoor-Pruned ./path 46 | 47 | if (Test-Path -Path "path/.git/lfs") { 48 | Remove-Item -Path path/.git/lfs/* -Recurse -Force 49 | } 50 | 51 | Set-Location .\path 52 | 53 | Write-Output "Downloading dinoV2 models..." 54 | Write-Output "from https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth..." 
55 | Write-Output "If it downloads slowly,you can copy link then close this window for downloading manually" 56 | 57 | wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth 58 | 59 | Write-Output "Install completed" 60 | Read-Host | Out-Null ; 61 | -------------------------------------------------------------------------------- /install_cn.ps1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/install_cn.ps1 -------------------------------------------------------------------------------- /iseg/coarse_mask_refine.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/iseg/coarse_mask_refine.pth -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | 
batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * 
torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 
65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/AnyDoor-for-windows/50640c4dd709b89607ffb336b006127a1bcdaf36/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | 
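The DPT stack above is normally driven through ldm.modules.midas.api (see the AddMiDaS helper in ldm/data/util.py earlier in this listing). As a rough, self-contained sketch of how the pieces fit together — not a file in this repository; the checkpoint path is a placeholder you would have to supply, and the backbone choice is only an assumption — the snippet below loads a DPTDepthModel and runs it on one of the bundled example images using the I/O helpers from ldm/modules/midas/utils.py:

    import torch

    from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
    from ldm.modules.midas.utils import read_image, resize_image, write_depth

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Placeholder checkpoint path (assumption): DPT weights are not shipped with this repo.
    model = DPTDepthModel(path="path/dpt_depth_checkpoint.pt",
                          backbone="vitb_rn50_384",
                          non_negative=True).to(device).eval()

    img = read_image("examples/Gradio/BG/00.png")   # H x W x 3, RGB in [0, 1]
    sample = resize_image(img).to(device)           # 1 x 3 x H' x W', rescaled to the ~384 network size (multiples of 32)

    with torch.no_grad():
        depth = model(sample)                       # 1 x H' x W' (squeezed in DPTDepthModel.forward)

    # Writes depth_vis.pfm and a 16-bit depth_vis.png at the network resolution.
    write_depth("depth_vis", depth[0].cpu().numpy(), bits=2)

This is only a sketch under those assumptions; in the AnyDoor pipeline the depth model is constructed and fed through the midas api/transform code rather than called directly like this.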
-------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /requirements-windows.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.0 2 | einops==0.3.0 3 | fvcore==0.1.5.post20221221 4 | gradio==3.39.0 5 | numpy==1.23.1 6 | omegaconf==2.1.1 7 | open_clip_torch==2.17.1 8 | opencv_contrib_python==4.7.0.72 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | Pillow==9.4.0 12 | pytorch_lightning==1.5.0 13 | safetensors==0.2.7 14 | scipy==1.9.1 15 | submitit==1.5.1 16 | timm==0.6.12 17 | #torch==2.0.0 18 | torchmetrics 19 | tqdm==4.65.0 20 | transformers==4.19.2 21 | #xformers==0.0.18 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.0 2 | einops==0.3.0 3 | fvcore==0.1.5.post20221221 4 | gradio==3.39.0 5 | numpy==1.23.1 6 | omegaconf==2.1.1 7 | open_clip_torch==2.17.1 8 | opencv_contrib_python==4.3.0.36 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | Pillow==9.4.0 12 | pytorch_lightning==1.5.0 13 | safetensors==0.2.7 14 | scipy==1.9.1 15 | setuptools==66.0.0 16 | share==1.0.4 17 | submitit==1.5.1 18 | timm==0.6.12 19 | torch==2.0.0 20 | torchmetrics==0.6.0 21 | tqdm==4.65.0 22 | transformers==4.19.2 23 | xformers==0.0.18 24 | -------------------------------------------------------------------------------- /run_dataset_debug.py: -------------------------------------------------------------------------------- 1 | from datasets.ytb_vos import YoutubeVOSDataset 2 | from datasets.ytb_vis import YoutubeVISDataset 3 | from datasets.saliency_modular import SaliencyDataset 4 | from datasets.vipseg import VIPSegDataset 5 | from datasets.mvimagenet import MVImageNetDataset 6 | from datasets.sam import SAMDataset 7 | from datasets.dreambooth import DreamBoothDataset 8 | from datasets.uvo import UVODataset 9 | from datasets.uvo_val import UVOValDataset 10 | from datasets.mose import MoseDataset 11 | from datasets.vitonhd import VitonHDDataset 12 | from datasets.fashiontryon import FashionTryonDataset 13 | from datasets.lvis import LvisDataset 14 | from torch.utils.data import ConcatDataset 15 | from torch.utils.data import DataLoader 16 | import numpy as np 17 | import cv2 18 | from omegaconf import OmegaConf 19 | 20 | # Datasets 21 | DConf = 
OmegaConf.load('./configs/datasets.yaml') 22 | dataset1 = YoutubeVOSDataset(**DConf.Train.YoutubeVOS) 23 | dataset2 = SaliencyDataset(**DConf.Train.Saliency) 24 | dataset3 = VIPSegDataset(**DConf.Train.VIPSeg) 25 | dataset4 = YoutubeVISDataset(**DConf.Train.YoutubeVIS) 26 | dataset5 = MVImageNetDataset(**DConf.Train.MVImageNet) 27 | dataset6 = SAMDataset(**DConf.Train.SAM) 28 | dataset7 = UVODataset(**DConf.Train.UVO.train) 29 | dataset8 = VitonHDDataset(**DConf.Train.VitonHD) 30 | dataset9 = UVOValDataset(**DConf.Train.UVO.val) 31 | dataset10 = MoseDataset(**DConf.Train.Mose) 32 | dataset11 = FashionTryonDataset(**DConf.Train.FashionTryon) 33 | dataset12 = LvisDataset(**DConf.Train.Lvis) 34 | 35 | dataset = dataset5 36 | 37 | 38 | def vis_sample(item): 39 | ref = item['ref']* 255 40 | tar = item['jpg'] * 127.5 + 127.5 41 | hint = item['hint'] * 127.5 + 127.5 42 | step = item['time_steps'] 43 | print(ref.shape, tar.shape, hint.shape, step.shape) 44 | 45 | ref = ref[0].numpy() 46 | tar = tar[0].numpy() 47 | hint_image = hint[0, :,:,:-1].numpy() 48 | hint_mask = hint[0, :,:,-1].numpy() 49 | hint_mask = np.stack([hint_mask,hint_mask,hint_mask],-1) 50 | ref = cv2.resize(ref.astype(np.uint8), (512,512)) 51 | vis = cv2.hconcat([ref.astype(np.float32), hint_image.astype(np.float32), hint_mask.astype(np.float32), tar.astype(np.float32) ]) 52 | cv2.imwrite('sample_vis.jpg',vis[:,:,::-1]) 53 | 54 | 55 | dataloader = DataLoader(dataset, num_workers=8, batch_size=4, shuffle=True) 56 | print('len dataloader: ', len(dataloader)) 57 | for data in dataloader: 58 | vis_sample(data) 59 | 60 | 61 | -------------------------------------------------------------------------------- /run_gui.ps1: -------------------------------------------------------------------------------- 1 | $configs="configs/stable-zero123.yaml" 2 | $image_path="./load/images/anya_front_rgba.png" 3 | $gpu=0 4 | 5 | Set-Location $PSScriptRoot 6 | .\venv\Scripts\activate 7 | 8 | $Env:HF_HOME = "./huggingface" 9 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 10 | #$Env:PYTHONPATH = $PSScriptRoot 11 | 12 | python run_gradio_demo.py -------------------------------------------------------------------------------- /run_train_anydoor.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader 3 | from datasets.ytb_vos import YoutubeVOSDataset 4 | from datasets.ytb_vis import YoutubeVISDataset 5 | from datasets.saliency_modular import SaliencyDataset 6 | from datasets.vipseg import VIPSegDataset 7 | from datasets.mvimagenet import MVImageNetDataset 8 | from datasets.sam import SAMDataset 9 | from datasets.uvo import UVODataset 10 | from datasets.uvo_val import UVOValDataset 11 | from datasets.mose import MoseDataset 12 | from datasets.vitonhd import VitonHDDataset 13 | from datasets.fashiontryon import FashionTryonDataset 14 | from datasets.lvis import LvisDataset 15 | from cldm.logger import ImageLogger 16 | from cldm.model import create_model, load_state_dict 17 | from torch.utils.data import ConcatDataset 18 | from cldm.hack import disable_verbosity, enable_sliced_attention 19 | from omegaconf import OmegaConf 20 | 21 | save_memory = False 22 | disable_verbosity() 23 | if save_memory: 24 | enable_sliced_attention() 25 | 26 | # Configs 27 | resume_path = 'path/to/weight' 28 | batch_size = 16 29 | logger_freq = 1000 30 | learning_rate = 1e-5 31 | sd_locked = False 32 | only_mid_control = False 33 | n_gpus = 2 34 | accumulate_grad_batches=1 35 | 36 | 
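# NOTE (editorial annotation, not part of the original script):
# - With the settings above, the effective batch per optimizer step under DDP is roughly
#   batch_size * n_gpus * accumulate_grad_batches = 16 * 2 * 1 = 32 samples.
# - 'path/to/weight' is a placeholder; it is typically pointed at the initial checkpoint
#   produced by scripts/convert_weight.sh (control_sd21_ini.ckpt) or at a released
#   AnyDoor checkpoint when resuming.
# - sd_locked = False additionally finetunes the Stable Diffusion UNet decoder instead of
#   training only the control branch (this flag is consumed in cldm/cldm.py's optimizer setup).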
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. 37 | model = create_model('./configs/anydoor.yaml').cpu() 38 | model.load_state_dict(load_state_dict(resume_path, location='cpu')) 39 | model.learning_rate = learning_rate 40 | model.sd_locked = sd_locked 41 | model.only_mid_control = only_mid_control 42 | 43 | # Datasets 44 | DConf = OmegaConf.load('./configs/datasets.yaml') 45 | dataset1 = YoutubeVOSDataset(**DConf.Train.YoutubeVOS) 46 | dataset2 = SaliencyDataset(**DConf.Train.Saliency) 47 | dataset3 = VIPSegDataset(**DConf.Train.VIPSeg) 48 | dataset4 = YoutubeVISDataset(**DConf.Train.YoutubeVIS) 49 | dataset5 = MVImageNetDataset(**DConf.Train.MVImageNet) 50 | dataset6 = SAMDataset(**DConf.Train.SAM) 51 | dataset7 = UVODataset(**DConf.Train.UVO.train) 52 | dataset8 = VitonHDDataset(**DConf.Train.VitonHD) 53 | dataset9 = UVOValDataset(**DConf.Train.UVO.val) 54 | dataset10 = MoseDataset(**DConf.Train.Mose) 55 | dataset11 = FashionTryonDataset(**DConf.Train.FashionTryon) 56 | dataset12 = LvisDataset(**DConf.Train.Lvis) 57 | 58 | image_data = [dataset2, dataset6, dataset12] 59 | video_data = [dataset1, dataset3, dataset4, dataset7, dataset9, dataset10 ] 60 | tryon_data = [dataset8, dataset11] 61 | threed_data = [dataset5] 62 | 63 | # The ratio of each dataset is adjusted by setting the __len__ 64 | dataset = ConcatDataset( image_data + video_data + tryon_data + threed_data + video_data + tryon_data + threed_data ) 65 | dataloader = DataLoader(dataset, num_workers=8, batch_size=batch_size, shuffle=True) 66 | logger = ImageLogger(batch_frequency=logger_freq) 67 | trainer = pl.Trainer(gpus=n_gpus, strategy="ddp", precision=16, accelerator="gpu", callbacks=[logger], progress_bar_refresh_rate=1, accumulate_grad_batches=accumulate_grad_batches) 68 | 69 | # Train! 70 | trainer.fit(model, dataloader) 71 | -------------------------------------------------------------------------------- /scripts/convert_weight.sh: -------------------------------------------------------------------------------- 1 | python tool_add_control_sd21.py path/v2-1_512-ema-pruned.ckpt path/control_sd21_ini.ckpt -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | unset WORLD_SIZE 2 | python run_inference.py -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | unset WORLD_SIZE 2 | python run_train_anydoor.py -------------------------------------------------------------------------------- /tool_add_control_sd21.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | assert len(sys.argv) == 3, 'Args are wrong.' 5 | 6 | input_path = sys.argv[1] 7 | output_path = sys.argv[2] 8 | 9 | assert os.path.exists(input_path), 'Input model does not exist.' 10 | assert not os.path.exists(output_path), 'Output filename already exists.' 11 | assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.' 
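# Usage (mirrors scripts/convert_weight.sh):
#   python tool_add_control_sd21.py path/v2-1_512-ema-pruned.ckpt path/control_sd21_ini.ckpt
# The first argument is the pruned Stable Diffusion 2.1 checkpoint; the second is the
# output file that run_train_anydoor.py can later load as its initial 'resume_path'.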
12 | 13 | import torch 14 | from share import * 15 | from cldm.model import create_model 16 | 17 | 18 | def get_node_name(name, parent_name): 19 | if len(name) <= len(parent_name): 20 | return False, '' 21 | p = name[:len(parent_name)] 22 | if p != parent_name: 23 | return False, '' 24 | return True, name[len(parent_name):] 25 | 26 | 27 | model = create_model(config_path='./configs/anydoor.yaml') 28 | 29 | pretrained_weights = torch.load(input_path) 30 | if 'state_dict' in pretrained_weights: 31 | pretrained_weights = pretrained_weights['state_dict'] 32 | 33 | scratch_dict = model.state_dict() 34 | 35 | target_dict = {} 36 | for k in scratch_dict.keys(): 37 | 38 | is_control, name = get_node_name(k, 'control_') 39 | if 'control_model.input_blocks.0.0' in k: 40 | print('skipped key: ', k) 41 | continue 42 | 43 | if is_control: 44 | copy_k = 'model.diffusion_' + name 45 | else: 46 | copy_k = k 47 | if copy_k in pretrained_weights: 48 | target_dict[k] = pretrained_weights[copy_k].clone() 49 | else: 50 | target_dict[k] = scratch_dict[k].clone() 51 | print(f'These weights are newly added: {k}') 52 | 53 | model.load_state_dict(target_dict, strict=False) 54 | torch.save(model.state_dict(), output_path) 55 | print('Done.') 56 | --------------------------------------------------------------------------------
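To make the key remapping in tool_add_control_sd21.py easier to follow, here is a small self-contained illustration (added for exposition, not a file in the repository) of how get_node_name decides which pretrained Stable Diffusion tensor each ControlNet-branch key is initialised from:

    def get_node_name(name, parent_name):
        # identical to the helper defined in tool_add_control_sd21.py
        if len(name) <= len(parent_name):
            return False, ''
        p = name[:len(parent_name)]
        if p != parent_name:
            return False, ''
        return True, name[len(parent_name):]

    examples = [
        'control_model.input_blocks.1.0.in_layers.0.weight',  # control-branch key
        'model.diffusion_model.out.2.bias',                    # ordinary UNet key
    ]
    for k in examples:
        is_control, name = get_node_name(k, 'control_')
        copy_k = 'model.diffusion_' + name if is_control else k
        print(f'{k}  <-  {copy_k}')

    # Output:
    #   control_model.input_blocks.1.0.in_layers.0.weight  <-  model.diffusion_model.input_blocks.1.0.in_layers.0.weight
    #   model.diffusion_model.out.2.bias  <-  model.diffusion_model.out.2.bias
    # i.e. the control branch starts as a copy of the corresponding UNet weights, everything
    # else is carried over unchanged, and keys with no counterpart in the pretrained weights
    # keep their freshly initialised values (the 'newly added' branch in the conversion loop).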