├── ldm ├── data │ ├── __init__.py │ ├── util.py │ ├── base.py │ └── simple.py ├── models │ └── diffusion │ │ ├── __init__.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ └── sampler.py │ │ └── sampling_util.py ├── modules │ ├── encoders │ │ └── __init__.py │ ├── midas │ │ ├── __init__.py │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── midas_net.py │ │ │ ├── dpt_depth.py │ │ │ └── midas_net_custom.py │ │ ├── utils.py │ │ └── api.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── upscaling.py │ ├── losses │ │ ├── __init__.py │ │ └── contperceptual.py │ ├── image_degradation │ │ ├── utils │ │ │ └── test.png │ │ └── __init__.py │ └── ema.py └── lr_scheduler.py ├── .gitignore ├── figures ├── main.png ├── sketch_to_art.jpg └── visualization.jpg ├── domainbed ├── optimizers.py ├── algorithms │ └── __init__.py ├── datasets │ ├── transforms.py │ └── __init__.py ├── lib │ ├── writers.py │ ├── fast_data_loader.py │ ├── wide_resnet.py │ ├── logger.py │ ├── swa_utils.py │ ├── query.py │ └── misc.py ├── models │ └── mixstyle.py ├── evaluator.py ├── hparams_registry.py ├── networks.py ├── swad.py └── command_launchers.py ├── bash ├── train_cls │ ├── pacs.sh │ ├── officehome.sh │ ├── vlcs.sh │ ├── pacs_interpolation.sh │ ├── officehome_interpolation.sh │ └── vlcs_interpolation.sh ├── train_dm │ ├── pacs.sh │ ├── vlcs.sh │ └── officehome.sh └── eval_cls │ ├── pacs.sh │ ├── vlcs.sh │ └── officehome.sh ├── environment.yaml ├── LICENCE ├── configs ├── OH │ ├── d012.yaml │ ├── d013.yaml │ ├── d023.yaml │ └── d123.yaml ├── PACS │ ├── d012.yaml │ ├── d013.yaml │ ├── d023.yaml │ └── d123.yaml └── VLCS │ ├── d012.yaml │ ├── d013.yaml │ ├── d023.yaml │ └── d123.yaml ├── eval_cls.py └── README.md /ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | save/ -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import 
DPMSolverSampler -------------------------------------------------------------------------------- /figures/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mehrdad-Noori/FDS/HEAD/figures/main.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /figures/sketch_to_art.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mehrdad-Noori/FDS/HEAD/figures/sketch_to_art.jpg -------------------------------------------------------------------------------- /figures/visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mehrdad-Noori/FDS/HEAD/figures/visualization.jpg -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mehrdad-Noori/FDS/HEAD/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /domainbed/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_optimizer(name, params, **kwargs): 5 | name = name.lower() 6 | optimizers = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD, "adamw": torch.optim.AdamW} 7 | optim_cls = optimizers[name] 8 | 9 | return optim_cls(params, **kwargs) 10 | -------------------------------------------------------------------------------- /domainbed/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .algorithms import * 2 | 3 | 4 | def get_algorithm_class(algorithm_name): 5 | """Return the algorithm class with the given name.""" 6 | if algorithm_name not in globals(): 7 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 8 | return globals()[algorithm_name] 9 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /domainbed/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as T 2 | 3 | 4 | basic = T.Compose( 5 | [ 6 | T.Resize((224, 224)), 7 | T.ToTensor(), 8 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 9 | ] 10 | ) 11 | aug = T.Compose( 12 | [ 13 | T.RandomResizedCrop(224, scale=(0.7, 1.0)), 14 | T.RandomHorizontalFlip(), 15 | T.ColorJitter(0.3, 0.3, 0.3, 0.3), 16 | T.RandomGrayscale(p=0.1), 17 | T.ToTensor(), 18 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 19 | ] 20 | ) 21 | -------------------------------------------------------------------------------- /bash/train_cls/pacs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=PACS 6 | save_dir="./save/train_cls/${dataset}/" 7 | data_dir="./data" 8 | 9 | 10 | 11 | ### first seed 12 | python train_cls.py ${dataset}0 --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 13 | 14 | ### second seed 15 | python train_cls.py ${dataset}1 --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 16 | 17 | ### third seed 18 | python train_cls.py ${dataset}2 --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 19 | -------------------------------------------------------------------------------- /bash/train_cls/officehome.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=OfficeHome 6 | save_dir="./save/train_cls/${dataset}/" 7 | data_dir="./data" 8 | 9 | 10 | 11 | ### first seed 12 | python train_cls.py ${dataset}0 --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 13 | 14 | ### second seed 15 | python train_cls.py ${dataset}1 --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 16 | 17 | ### third seed 18 | python train_cls.py ${dataset}2 --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir ${data_dir} --work_dir $save_dir 19 | -------------------------------------------------------------------------------- /bash/train_cls/vlcs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=VLCS 6 | save_dir="./save/train_cls/${dataset}/" 7 | data_dir="./data" 8 | 9 | 10 | 11 | ### first seed 12 | python train_cls.py ${dataset}0 --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir ${data_dir} --work_dir $save_dir 13 | 14 | ### second seed 15 | python train_cls.py ${dataset}1 --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir ${data_dir} --work_dir $save_dir 16 | 17 | ### third seed 18 | python train_cls.py ${dataset}2 --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir 
${data_dir} --work_dir $save_dir 19 | -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | class AddMiDaS(object): 6 | def __init__(self, model_type): 7 | super().__init__() 8 | self.transform = load_midas_transform(model_type) 9 | 10 | def pt2np(self, x): 11 | x = ((x + 1.0) * .5).detach().cpu().numpy() 12 | return x 13 | 14 | def np2pt(self, x): 15 | x = torch.from_numpy(x) * 2 - 1. 16 | return x 17 | 18 | def __call__(self, sample): 19 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 20 | x = self.pt2np(sample['jpg']) 21 | x = self.transform({"image": x})["image"] 22 | sample['midas_in'] = x 23 | return sample -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /domainbed/lib/writers.py: -------------------------------------------------------------------------------- 1 | class Writer: 2 | def add_scalars(self, tag_scalar_dic, global_step): 3 | raise NotImplementedError() 4 | 5 | def add_scalars_with_prefix(self, tag_scalar_dic, global_step, prefix): 6 | tag_scalar_dic = {prefix + k: v for k, v in tag_scalar_dic.items()} 7 | self.add_scalars(tag_scalar_dic, global_step) 8 | 9 | 10 | class TBWriter(Writer): 11 | def __init__(self, dir_path): 12 | from tensorboardX import SummaryWriter 13 | 14 | self.writer = SummaryWriter(dir_path, flush_secs=30) 15 | 16 | def add_scalars(self, tag_scalar_dic, global_step): 17 | for tag, scalar in tag_scalar_dic.items(): 18 | self.writer.add_scalar(tag, scalar, global_step) 19 | 20 | 21 | def get_writer(dir_path): 22 | """ 23 | Args: 24 | dir_path: tb dir 25 | """ 26 | writer = TBWriter(dir_path) 27 | 28 | return writer 29 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: fds 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.16.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 
25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 36 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 XXXX XXXX 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from abc import abstractmethod 4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 5 | 6 | 7 | class Txt2ImgIterableBaseDataset(IterableDataset): 8 | ''' 9 | Define an interface to make the IterableDatasets for text2img data chainable 10 | ''' 11 | def __init__(self, num_records=0, valid_ids=None, size=256): 12 | super().__init__() 13 | self.num_records = num_records 14 | self.valid_ids = valid_ids 15 | self.sample_ids = valid_ids 16 | self.size = size 17 | 18 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 19 | 20 | def __len__(self): 21 | return self.num_records 22 | 23 | @abstractmethod 24 | def __iter__(self): 25 | pass 26 | 27 | 28 | class PRNGMixin(object): 29 | """ 30 | Adds a prng property which is a numpy RandomState which gets 31 | reinitialized whenever the pid changes to avoid synchronized sampling 32 | behavior when used in conjunction with multiprocessing. 
33 | """ 34 | @property 35 | def prng(self): 36 | currentpid = os.getpid() 37 | if getattr(self, "_initpid", None) != currentpid: 38 | self._initpid = currentpid 39 | self._prng = np.random.RandomState() 40 | return self._prng 41 | -------------------------------------------------------------------------------- /bash/train_dm/pacs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=1 4 | nvidia-smi 5 | 6 | 7 | ### parameters 8 | dataset=PACS 9 | scale_lr=False 10 | no_test=True 11 | num_nodes=1 12 | check_val_every_n_epoch=5 13 | gpus="0,1,2,3" 14 | init_weights="init_weights/v1-5-pruned.ckpt" 15 | 16 | 17 | 18 | ### training for PACS dataset, 012 domain indexes as source domains (art, cartoon, photo) 19 | source_domains="012" 20 | config_dir="configs/${dataset}/d${source_domains}.yaml" 21 | logdir="save/dm/${dataset}/${source_domains}/" 22 | 23 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 24 | 25 | 26 | 27 | 28 | ### training for PACS dataset, 013 domain indexes as source domains (art, cartoon, sketch) 29 | source_domains="013" 30 | config_dir="configs/${dataset}/d${source_domains}.yaml" 31 | logdir="save/dm/${dataset}/${source_domains}/" 32 | 33 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 34 | 35 | 36 | 37 | 38 | ### training for PACS dataset, 023 domain indexes as source domains (art, photo, sketch) 39 | source_domains="023" 40 | config_dir="configs/${dataset}/d${source_domains}.yaml" 41 | logdir="save/dm/${dataset}/${source_domains}/" 42 | 43 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 44 | 45 | 46 | 47 | 48 | ### training for PACS dataset, 123 domain indexes as source domains (cartoon, photo, sketch) 49 | source_domains="123" 50 | config_dir="configs/${dataset}/d${source_domains}.yaml" 51 | logdir="save/dm/${dataset}/${source_domains}/" 52 | 53 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 54 | 55 | -------------------------------------------------------------------------------- /bash/train_dm/vlcs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=1 4 | nvidia-smi 5 | 6 | 7 | ### parameters 8 | dataset=VLCS 9 | scale_lr=False 10 | no_test=True 11 | num_nodes=1 12 | check_val_every_n_epoch=5 13 | gpus="0,1,2,3" 14 | init_weights="init_weights/v1-5-pruned.ckpt" 15 | 16 | 17 | 18 | ### training for VLCS dataset, 012 domain indexes as source domains (Caltech101, LabelMe, SUN09) 19 | source_domains="012" 20 | config_dir="configs/${dataset}/d${source_domains}.yaml" 21 | logdir="save/dm/${dataset}/${source_domains}/" 22 | 23 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch 
${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 24 | 25 | 26 | 27 | 28 | ### training for VLCS dataset, 013 domain indexes as source domains (Caltech101, LabelMe, VOC2007) 29 | source_domains="013" 30 | config_dir="configs/${dataset}/d${source_domains}.yaml" 31 | logdir="save/dm/${dataset}/${source_domains}/" 32 | 33 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 34 | 35 | 36 | 37 | 38 | ### training for VLCS dataset, 023 domain indexes as source domains (Caltech101, SUN09, VOC2007) 39 | source_domains="023" 40 | config_dir="configs/${dataset}/d${source_domains}.yaml" 41 | logdir="save/dm/${dataset}/${source_domains}/" 42 | 43 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 44 | 45 | 46 | 47 | 48 | ### training for VLCS dataset, 123 domain indexes as source domains (LabelMe, SUN09, VOC2007) 49 | source_domains="123" 50 | config_dir="configs/${dataset}/d${source_domains}.yaml" 51 | logdir="save/dm/${dataset}/${source_domains}/" 52 | 53 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 54 | 55 | -------------------------------------------------------------------------------- /bash/train_dm/officehome.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=1 4 | nvidia-smi 5 | 6 | 7 | ### parameters 8 | dataset=OfficeHome 9 | scale_lr=False 10 | no_test=True 11 | num_nodes=1 12 | check_val_every_n_epoch=5 13 | gpus="0,1,2,3" 14 | init_weights="init_weights/v1-5-pruned.ckpt" 15 | 16 | 17 | 18 | ### training for OfficeHome dataset, 012 domain indexes as source domains (Art, Clipart, Product) 19 | source_domains="012" 20 | config_dir="configs/${dataset}/d${source_domains}.yaml" 21 | logdir="save/dm/${dataset}/${source_domains}/" 22 | 23 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 24 | 25 | 26 | 27 | 28 | ### training for OfficeHome dataset, 013 domain indexes as source domains (Art, Clipart, Real World) 29 | source_domains="013" 30 | config_dir="configs/${dataset}/d${source_domains}.yaml" 31 | logdir="save/dm/${dataset}/${source_domains}/" 32 | 33 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 34 | 35 | 36 | 37 | 38 | ### training for OfficeHome dataset, 023 domain indexes as source domains (Art, Product, Real World) 39 | source_domains="023" 40 | config_dir="configs/${dataset}/d${source_domains}.yaml" 41 | logdir="save/dm/${dataset}/${source_domains}/" 42 | 43 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights 
${init_weights} 44 | 45 | 46 | 47 | 48 | ### training for OfficeHome dataset, 123 domain indexes as source domains (Clipart, Product, Real World) 49 | source_domains="123" 50 | config_dir="configs/${dataset}/d${source_domains}.yaml" 51 | logdir="save/dm/${dataset}/${source_domains}/" 52 | 53 | python train_dm.py -t --base ${config_dir} --gpus ${gpus} --scale_lr ${scale_lr} --no-test ${no_test} --num_nodes ${num_nodes} --check_val_every_n_epoch ${check_val_every_n_epoch} --logdir ${logdir} --init_weights ${init_weights} 54 | 55 | -------------------------------------------------------------------------------- /domainbed/lib/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | 5 | 6 | class _InfiniteSampler(torch.utils.data.Sampler): 7 | """Wraps another Sampler to yield an infinite stream.""" 8 | 9 | def __init__(self, sampler): 10 | self.sampler = sampler 11 | 12 | def __iter__(self): 13 | while True: 14 | for batch in self.sampler: 15 | yield batch 16 | 17 | 18 | class InfiniteDataLoader: 19 | def __init__(self, dataset, weights, batch_size, num_workers): 20 | super().__init__() 21 | 22 | if weights: 23 | sampler = torch.utils.data.WeightedRandomSampler( 24 | weights, replacement=True, num_samples=batch_size 25 | ) 26 | else: 27 | sampler = torch.utils.data.RandomSampler(dataset, replacement=True) 28 | 29 | batch_sampler = torch.utils.data.BatchSampler( 30 | sampler, batch_size=batch_size, drop_last=True 31 | ) 32 | 33 | self._infinite_iterator = iter( 34 | torch.utils.data.DataLoader( 35 | dataset, 36 | num_workers=num_workers, 37 | batch_sampler=_InfiniteSampler(batch_sampler), 38 | ) 39 | ) 40 | 41 | def __iter__(self): 42 | while True: 43 | yield next(self._infinite_iterator) 44 | 45 | def __len__(self): 46 | raise ValueError 47 | 48 | 49 | class FastDataLoader: 50 | """ 51 | DataLoader wrapper with slightly improved speed by not respawning worker 52 | processes at every epoch. 53 | """ 54 | 55 | def __init__(self, dataset, batch_size, num_workers, shuffle=False): 56 | super().__init__() 57 | 58 | if shuffle: 59 | sampler = torch.utils.data.RandomSampler(dataset, replacement=False) 60 | else: 61 | sampler = torch.utils.data.SequentialSampler(dataset) 62 | 63 | batch_sampler = torch.utils.data.BatchSampler( 64 | sampler, 65 | batch_size=batch_size, 66 | drop_last=False, 67 | ) 68 | 69 | self._infinite_iterator = iter( 70 | torch.utils.data.DataLoader( 71 | dataset, 72 | num_workers=num_workers, 73 | batch_sampler=_InfiniteSampler(batch_sampler), 74 | ) 75 | ) 76 | 77 | self._length = len(batch_sampler) 78 | 79 | def __iter__(self): 80 | for _ in range(len(self)): 81 | yield next(self._infinite_iterator) 82 | 83 | def __len__(self): 84 | return self._length 85 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 
32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, 
dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 
62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /domainbed/lib/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | From https://github.com/meliketoy/wide-resnet.pytorch 5 | """ 6 | 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 19 | 20 | 21 | def conv_init(m): 22 | classname = m.__class__.__name__ 23 | if classname.find("Conv") != -1: 24 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 25 | init.constant_(m.bias, 0) 26 | elif classname.find("BatchNorm") != -1: 27 | init.constant_(m.weight, 1) 28 | init.constant_(m.bias, 0) 29 | 30 | 31 | class wide_basic(nn.Module): 32 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 33 | super(wide_basic, self).__init__() 34 | self.bn1 = nn.BatchNorm2d(in_planes) 35 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 36 | self.dropout = nn.Dropout(p=dropout_rate) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != planes: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 44 | ) 45 | 46 | def forward(self, x): 47 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 48 | out = self.conv2(F.relu(self.bn2(out))) 49 | out += self.shortcut(x) 50 | 51 | return out 52 | 53 | 54 | class Wide_ResNet(nn.Module): 55 | """Wide Resnet with the softmax layer chopped off""" 56 | 57 | def __init__(self, input_shape, depth, widen_factor, dropout_rate): 58 | super(Wide_ResNet, self).__init__() 59 | self.in_planes = 16 60 | 61 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4" 62 | n = (depth - 4) / 6 63 | k = widen_factor 64 | 65 | # print('| Wide-Resnet %dx%d' % (depth, k)) 66 | nStages = [16, 16 * k, 32 * k, 64 * k] 67 | 68 | self.conv1 = conv3x3(input_shape[0], nStages[0]) 69 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 70 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 71 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 72 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 73 | 74 | self.n_outputs = nStages[3] 75 | 76 | def 
_wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 77 | strides = [stride] + [1] * (int(num_blocks) - 1) 78 | layers = [] 79 | 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 82 | self.in_planes = planes 83 | 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = self.conv1(x) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = F.relu(self.bn1(out)) 92 | out = F.avg_pool2d(out, 8) 93 | return out[:, :, 0, 0] 94 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if 
non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /configs/OH/d012.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: OH 82 | root: ./data 83 | test_envs: [0, 1, 2] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Art, Bike" 91 | - "Clipart, Bike" 92 | - "Product, Bike" 93 | - "Real World, Bike" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/OH/d013.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: OH 82 | root: ./data 83 | test_envs: [0, 1, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Art, Bike" 91 | - "Clipart, Bike" 92 | - "Product, Bike" 93 | - "Real World, Bike" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/OH/d023.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: OH 82 | root: ./data 83 | test_envs: [0, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Art, Bike" 91 | - "Clipart, Bike" 92 | - "Product, Bike" 93 | - "Real World, Bike" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/OH/d123.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: OH 82 | root: ./data 83 | test_envs: [1, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Art, Bike" 91 | - "Clipart, Bike" 92 | - "Product, Bike" 93 | - "Real World, Bike" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/PACS/d012.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: PACS 82 | root: ./data 83 | test_envs: [0, 1, 2] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "art_painting, dog" 91 | - "cartoon, dog" 92 | - "photo, dog" 93 | - "sketch, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/PACS/d013.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: PACS 82 | root: ./data 83 | test_envs: [0, 1, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "art_painting, dog" 91 | - "cartoon, dog" 92 | - "photo, dog" 93 | - "sketch, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/PACS/d023.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: PACS 82 | root: ./data 83 | test_envs: [0, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "art_painting, dog" 91 | - "cartoon, dog" 92 | - "photo, dog" 93 | - "sketch, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/PACS/d123.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: PACS 82 | root: ./data 83 | test_envs: [1, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "art_painting, dog" 91 | - "cartoon, dog" 92 | - "photo, dog" 93 | - "sketch, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/VLCS/d012.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: VLCS 82 | root: ./data 83 | test_envs: [0, 1, 2] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Caltech101, dog" 91 | - "LabelMe, dog" 92 | - "SUN09, dog" 93 | - "VOC2007, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/VLCS/d013.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: VLCS 82 | root: ./data 83 | test_envs: [0, 1, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Caltech101, dog" 91 | - "LabelMe, dog" 92 | - "SUN09, dog" 93 | - "VOC2007, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/VLCS/d023.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: VLCS 82 | root: ./data 83 | test_envs: [0, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Caltech101, dog" 91 | - "LabelMe, dog" 92 | - "SUN09, dog" 93 | - "VOC2007, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /configs/VLCS/d123.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "txt" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | scale_factor: 0.18215 17 | 18 | scheduler_config: 19 | target: ldm.lr_scheduler.LambdaLinearScheduler 20 | params: 21 | warm_up_steps: [ 1 ] 22 | cycle_lengths: [ 10000000000000 ] 23 | f_start: [ 1.e-6 ] 24 | f_max: [ 1. ] 25 | f_min: [ 1. 
] 26 | 27 | unet_config: 28 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_heads: 8 38 | use_spatial_transformer: True 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: True 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 70 | 71 | 72 | data: 73 | target: ldm.util.DataModuleFromConfig 74 | params: 75 | batch_size: 24 76 | num_workers: 8 77 | num_val_workers: 0 78 | train: 79 | target: ldm.data.datasets.CombinedTextDomains 80 | params: 81 | dataset_name: VLCS 82 | root: ./data 83 | test_envs: [1, 2, 3] 84 | data_augmentation: False 85 | image_size: 256 86 | validation: 87 | target: ldm.data.simple.TextOnly 88 | params: 89 | captions: 90 | - "Caltech101, dog" 91 | - "LabelMe, dog" 92 | - "SUN09, dog" 93 | - "VOC2007, dog" 94 | output_size: 256 95 | n_gpus: 4 96 | 97 | 98 | lightning: 99 | find_unused_parameters: False 100 | 101 | modelcheckpoint: 102 | params: 103 | every_n_train_steps: 2000 104 | save_top_k: -1 105 | monitor: null 106 | 107 | callbacks: 108 | image_logger: 109 | target: ldm.util.ImageLogger 110 | params: 111 | batch_frequency: 2000 112 | max_images: 4 113 | increase_log_steps: False 114 | log_first_step: True 115 | log_all_val: True 116 | log_images_kwargs: 117 | use_ema_scope: True 118 | inpaint: False 119 | plot_progressive_rows: False 120 | plot_diffusion_rows: False 121 | N: 4 122 | unconditional_guidance_scale: 5.0 123 | unconditional_guidance_label: [""] 124 | 125 | trainer: 126 | benchmark: True 127 | num_sanity_val_steps: 0 128 | accumulate_grad_batches: 1 129 | precision: 16 130 | max_steps: 10000 -------------------------------------------------------------------------------- /domainbed/lib/logger.py: -------------------------------------------------------------------------------- 1 | """ Singleton Logger """ 2 | import sys 3 | import logging 4 | 5 | 6 | def levelize(levelname): 7 | """Convert levelname to level only if it is levelname""" 8 | if isinstance(levelname, str): 9 | return logging.getLevelName(levelname) 10 | else: 11 | return levelname # already level 12 | 13 | 14 | class ColorFormatter(logging.Formatter): 15 | color_dic = { 16 | "DEBUG": 37, # white 17 | "INFO": 36, # cyan 18 | "WARNING": 33, # yellow 19 | "ERROR": 31, # red 20 | "CRITICAL": 41, # white on red bg 21 | } 22 | 23 | def format(self, record): 24 | color = self.color_dic.get(record.levelname, 37) # default white 25 | record.levelname = "\033[{}m{}\033[0m".format(color, record.levelname) 26 | return logging.Formatter.format(self, record) 27 | 28 | 29 | class Logger(logging.Logger): 30 | NAME = "SingletonLogger" 31 | 32 | @classmethod 33 | def get(cls, file_path=None, level="INFO", colorize=True, track_code=False): 34 | logging.setLoggerClass(cls) 35 | logger = logging.getLogger(cls.NAME) 36 | 
logging.setLoggerClass(logging.Logger) # restore 37 | logger.setLevel(level) 38 | 39 | if logger.hasHandlers(): 40 | # If logger already got all handlers (# handlers == 2), use the logger. 41 | # else, re-set handlers. 42 | if len(logger.handlers) == 2: 43 | return logger 44 | 45 | logger.handlers.clear() 46 | 47 | log_format = "%(levelname)s %(asctime)s | %(message)s" 48 | # log_format = '%(asctime)s | %(message)s' 49 | if track_code: 50 | log_format = ( 51 | "%(levelname)s::%(asctime)s | [%(filename)s] [%(funcName)s:%(lineno)d] " 52 | "%(message)s" 53 | ) 54 | date_format = "%m/%d %H:%M:%S" 55 | if colorize: 56 | formatter = ColorFormatter(log_format, date_format) 57 | else: 58 | formatter = logging.Formatter(log_format, date_format) 59 | 60 | # standard output handler 61 | # NOTE as default, StreamHandler use stderr stream instead of stdout stream. 62 | # Use StreamHandler(sys.stdout) for stdout stream. 63 | stream_handler = logging.StreamHandler(sys.stdout) 64 | stream_handler.setFormatter(formatter) 65 | logger.addHandler(stream_handler) 66 | 67 | if file_path: 68 | # file output handler 69 | file_handler = logging.FileHandler(file_path) 70 | file_handler.setFormatter(formatter) 71 | logger.addHandler(file_handler) 72 | 73 | logger.propagate = False 74 | 75 | return logger 76 | 77 | def nofmt(self, msg, *args, level="INFO", **kwargs): 78 | level = levelize(level) 79 | formatters = self.remove_formats() 80 | super().log(level, msg, *args, **kwargs) 81 | self.set_formats(formatters) 82 | 83 | def remove_formats(self): 84 | """Remove all formats from logger""" 85 | formatters = [] 86 | for handler in self.handlers: 87 | formatters.append(handler.formatter) 88 | handler.setFormatter(logging.Formatter("%(message)s")) 89 | 90 | return formatters 91 | 92 | def set_formats(self, formatters): 93 | """Set formats to every handler of logger""" 94 | for handler, formatter in zip(self.handlers, formatters): 95 | handler.setFormatter(formatter) 96 | 97 | def set_file_handler(self, file_path): 98 | file_handler = logging.FileHandler(file_path) 99 | formatter = self.handlers[0].formatter 100 | file_handler.setFormatter(formatter) 101 | self.addHandler(file_handler) 102 | -------------------------------------------------------------------------------- /bash/eval_cls/pacs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### 012 domain indexes to get the generated data related to source domains (art, cartoon, photo) => indexes: "0" "1" "2" 5 | dataset="PACS" 6 | source_domains="012" 7 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 8 | 9 | #seed0 10 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 11 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 12 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 13 | 14 | #seed1 15 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 16 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 17 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 18 | 19 | #seed2 20 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 21 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 22 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 23 | 24 | 25 | 26 | 27 | 28 | 29 | ### 013 domain indexes to get the generated data related to source domains (art, cartoon, sketch) => indexes: "0" "1" "3" 30 | dataset="PACS" 31 | 
source_domains="013" 32 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 33 | 34 | #seed0 35 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 36 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 37 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 38 | 39 | #seed1 40 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 41 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 42 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 43 | 44 | #seed2 45 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 46 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 47 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | ### 023 domain indexes to get the generated data related to source domains (art, photo, sketch) => indexes: "0" "2" "3" 56 | dataset="PACS" 57 | source_domains="023" 58 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 59 | 60 | #seed0 61 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 62 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 63 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 64 | 65 | #seed1 66 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 67 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 68 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 69 | 70 | #seed2 71 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 72 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 73 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | ### 123 domain indexes to get the generated data related to source domains (cartoon, photo, sketch) => indexes: "1" "2" "3" 82 | dataset="PACS" 83 | source_domains="123" 84 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 85 | 86 | #seed0 87 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 88 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 89 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 90 | 91 | #seed1 92 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 93 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 94 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 95 | 96 | #seed2 97 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 98 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 99 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 100 | -------------------------------------------------------------------------------- /bash/eval_cls/vlcs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### 012 domain indexes to get the generated data related to source domains (Caltech101, LabelMe, SUN09) => indexes: "0" "1" "2" 5 | dataset="VLCS" 6 | source_domains="012" 7 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 8 | 9 | #seed0 10 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 11 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 12 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 13 | 14 | #seed1 15 | 
save_dir="save/eval/${dataset}/${source_domains}/seed1" 16 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 17 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 18 | 19 | #seed2 20 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 21 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 22 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 23 | 24 | 25 | 26 | 27 | 28 | 29 | ### 013 domain indexes to get the generated data related to source domains (Caltech101, LabelMe, VOC2007) => indexes: "0" "1" "3" 30 | dataset="PACS" 31 | source_domains="013" 32 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 33 | 34 | #seed0 35 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 36 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 37 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 38 | 39 | #seed1 40 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 41 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 42 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 43 | 44 | #seed2 45 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 46 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 47 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | ### 023 domain indexes to get the generated data related to source domains (Caltech101, SUN09, VOC2007) => indexes: "0" "2" "3" 56 | dataset="PACS" 57 | source_domains="023" 58 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 59 | 60 | #seed0 61 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 62 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 63 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 64 | 65 | #seed1 66 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 67 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 68 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 69 | 70 | #seed2 71 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 72 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 73 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | ### 123 domain indexes to get the generated data related to source domains (LabelMe, SUN09, VOC2007) => indexes: "1" "2" "3" 82 | dataset="PACS" 83 | source_domains="123" 84 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 85 | 86 | #seed0 87 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 88 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 89 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 90 | 91 | #seed1 92 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 93 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 94 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 95 | 96 | #seed2 97 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 98 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 99 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 100 | -------------------------------------------------------------------------------- 
/bash/eval_cls/officehome.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### 012 domain indexes to get the generated data related to source domains (Art, Clipart, Product) => indexes: "0" "1" "2" 5 | dataset="OfficeHome" 6 | source_domains="012" 7 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 8 | 9 | #seed0 10 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 11 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 12 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 13 | 14 | #seed1 15 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 16 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 17 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 18 | 19 | #seed2 20 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 21 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 22 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 23 | 24 | 25 | 26 | 27 | 28 | 29 | ### 013 domain indexes to get the generated data related to source domains (Art, Clipart, Real World) => indexes: "0" "1" "3" 30 | dataset="OfficeHome" 31 | source_domains="013" 32 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 33 | 34 | #seed0 35 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 36 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 37 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 38 | 39 | #seed1 40 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 41 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 42 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 43 | 44 | #seed2 45 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 46 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 47 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | ### 023 domain indexes to get the generated data related to source domains (Art, Product, Real World) => indexes: "0" "2" "3" 56 | dataset="OfficeHome" 57 | source_domains="023" 58 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 59 | 60 | #seed0 61 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 62 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 63 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 64 | 65 | #seed1 66 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 67 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 68 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 69 | 70 | #seed2 71 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 72 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 73 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | ### 123 domain indexes to get the generated data related to source domains (Clipart, Product, Real World) => indexes: "1" "2" "3" 82 | dataset="OfficeHome" 83 | source_domains="123" 84 | generated_data_dir="save/dm/${dataset}/${source_domains}/generation" 85 | 86 | #seed0 87 | save_dir="save/eval/${dataset}/${source_domains}/seed0" 88 | ckpt_dir="path/to/the/pretrained/model/seed0/model.pth" 89 | python eval_cls.py
--data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 90 | 91 | #seed1 92 | save_dir="save/eval/${dataset}/${source_domains}/seed1" 93 | ckpt_dir="path/to/the/pretrained/model/seed1/model.pth" 94 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 95 | 96 | #seed2 97 | save_dir="save/eval/${dataset}/${source_domains}/seed2" 98 | ckpt_dir="path/to/the/pretrained/model/seed2/model.pth" 99 | python eval_cls.py --data_dir $generated_data_dir --save_dir $save_dir --ckpt_dir ${ckpt_dir} 100 | -------------------------------------------------------------------------------- /domainbed/models/mixstyle.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/KaiyangZhou/mixstyle-release/blob/master/imcls/models/mixstyle.py 3 | """ 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MixStyle(nn.Module): 10 | """MixStyle. 11 | Reference: 12 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 13 | """ 14 | 15 | def __init__(self, p=0.5, alpha=0.3, eps=1e-6): 16 | """ 17 | Args: 18 | p (float): probability of using MixStyle. 19 | alpha (float): parameter of the Beta distribution. 20 | eps (float): scaling parameter to avoid numerical issues. 21 | """ 22 | super().__init__() 23 | self.p = p 24 | self.beta = torch.distributions.Beta(alpha, alpha) 25 | self.eps = eps 26 | self.alpha = alpha 27 | 28 | print("* MixStyle params") 29 | print(f"- p: {p}") 30 | print(f"- alpha: {alpha}") 31 | 32 | def __repr__(self): 33 | return f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})" 34 | 35 | def forward(self, x): 36 | if not self.training: 37 | return x 38 | 39 | if random.random() > self.p: 40 | return x 41 | 42 | B = x.size(0) 43 | 44 | mu = x.mean(dim=[2, 3], keepdim=True) 45 | var = x.var(dim=[2, 3], keepdim=True) 46 | sig = (var + self.eps).sqrt() 47 | mu, sig = mu.detach(), sig.detach() 48 | x_normed = (x - mu) / sig 49 | 50 | lmda = self.beta.sample((B, 1, 1, 1)) 51 | lmda = lmda.to(x.device) 52 | 53 | perm = torch.randperm(B) 54 | mu2, sig2 = mu[perm], sig[perm] 55 | mu_mix = mu * lmda + mu2 * (1 - lmda) 56 | sig_mix = sig * lmda + sig2 * (1 - lmda) 57 | 58 | return x_normed * sig_mix + mu_mix 59 | 60 | 61 | class MixStyle2(nn.Module): 62 | """MixStyle (w/ domain prior). 63 | The input should contain two equal-sized mini-batches from two distinct domains. 64 | Reference: 65 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 66 | """ 67 | 68 | def __init__(self, p=0.5, alpha=0.3, eps=1e-6): 69 | """ 70 | Args: 71 | p (float): probability of using MixStyle. 72 | alpha (float): parameter of the Beta distribution. 73 | eps (float): scaling parameter to avoid numerical issues. 74 | """ 75 | super().__init__() 76 | self.p = p 77 | self.beta = torch.distributions.Beta(alpha, alpha) 78 | self.eps = eps 79 | self.alpha = alpha 80 | 81 | print("* MixStyle params") 82 | print(f"- p: {p}") 83 | print(f"- alpha: {alpha}") 84 | 85 | def __repr__(self): 86 | return f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})" 87 | 88 | def forward(self, x): 89 | """ 90 | For the input x, the first half comes from one domain, 91 | while the second half comes from the other domain. 
92 | """ 93 | if not self.training: 94 | return x 95 | 96 | if random.random() > self.p: 97 | return x 98 | 99 | B = x.size(0) 100 | 101 | mu = x.mean(dim=[2, 3], keepdim=True) 102 | var = x.var(dim=[2, 3], keepdim=True) 103 | sig = (var + self.eps).sqrt() 104 | mu, sig = mu.detach(), sig.detach() 105 | x_normed = (x - mu) / sig 106 | 107 | lmda = self.beta.sample((B, 1, 1, 1)) 108 | lmda = lmda.to(x.device) 109 | 110 | perm = torch.arange(B - 1, -1, -1) # inverse index 111 | perm_b, perm_a = perm.chunk(2) 112 | perm_b = perm_b[torch.randperm(B // 2)] 113 | perm_a = perm_a[torch.randperm(B // 2)] 114 | perm = torch.cat([perm_b, perm_a], 0) 115 | 116 | mu2, sig2 = mu[perm], sig[perm] 117 | mu_mix = mu * lmda + mu2 * (1 - lmda) 118 | sig_mix = sig * lmda + sig2 * (1 - lmda) 119 | 120 | return x_normed * sig_mix + mu_mix 121 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /domainbed/evaluator.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from domainbed.lib.fast_data_loader import FastDataLoader 6 | 7 | if torch.cuda.is_available(): 8 | device = "cuda" 9 | else: 10 | device = "cpu" 11 | 12 | 13 | def accuracy_from_loader(algorithm, loader, weights, debug=False): 14 | correct = 0 15 | total = 0 16 | losssum = 0.0 17 | weights_offset = 0 18 | 19 | algorithm.eval() 20 | 21 | for i, batch in enumerate(loader): 22 | x = batch["x"].to(device) 23 | y = batch["y"].to(device) 24 | 25 | with torch.no_grad(): 26 | logits = algorithm.predict(x) 27 | loss = F.cross_entropy(logits, y).item() 28 | 29 | B = len(x) 30 | losssum += loss * B 31 | 32 | if weights is None: 33 | batch_weights = torch.ones(len(x)) 34 | else: 35 | batch_weights = weights[weights_offset : weights_offset + len(x)] 36 | weights_offset += len(x) 37 | batch_weights = batch_weights.to(device) 38 | if logits.size(1) == 1: 39 | correct += (logits.gt(0).eq(y).float() * batch_weights).sum().item() 40 | else: 41 | correct += (logits.argmax(1).eq(y).float() * batch_weights).sum().item() 42 | total += batch_weights.sum().item() 43 | 44 | if debug: 45 | break 46 | 47 | algorithm.train() 48 | 49 | acc = correct / total 50 | loss = losssum / total 51 | return acc, loss 52 | 53 | 54 | def accuracy(algorithm, loader_kwargs, weights, **kwargs): 55 | if isinstance(loader_kwargs, dict): 56 | loader = FastDataLoader(**loader_kwargs) 57 | elif isinstance(loader_kwargs, FastDataLoader): 58 | loader = loader_kwargs 59 | else: 60 | raise 
ValueError(loader_kwargs) 61 | return accuracy_from_loader(algorithm, loader, weights, **kwargs) 62 | 63 | 64 | class Evaluator: 65 | def __init__( 66 | self, test_envs, eval_meta, n_envs, logger, evalmode="fast", debug=False, target_env=None 67 | ): 68 | all_envs = list(range(n_envs)) 69 | train_envs = sorted(set(all_envs) - set(test_envs)) 70 | self.test_envs = test_envs 71 | self.train_envs = train_envs 72 | self.eval_meta = eval_meta 73 | self.n_envs = n_envs 74 | self.logger = logger 75 | self.evalmode = evalmode 76 | self.debug = debug 77 | 78 | if target_env is not None: 79 | self.set_target_env(target_env) 80 | 81 | def set_target_env(self, target_env): 82 | """When len(test_envs) == 2, you can specify target env for computing exact test acc.""" 83 | self.test_envs = [target_env] 84 | 85 | def evaluate(self, algorithm, ret_losses=False): 86 | n_train_envs = len(self.train_envs) 87 | n_test_envs = len(self.test_envs) 88 | assert n_test_envs == 1 89 | summaries = collections.defaultdict(float) 90 | # for key order 91 | summaries["test_in"] = 0.0 92 | summaries["test_out"] = 0.0 93 | summaries["train_in"] = 0.0 94 | summaries["train_out"] = 0.0 95 | accuracies = {} 96 | losses = {} 97 | 98 | # order: in_splits + out_splits. 99 | for name, loader_kwargs, weights in self.eval_meta: 100 | # env\d_[in|out] 101 | env_name, inout = name.split("_") 102 | env_num = int(env_name[3:]) 103 | 104 | skip_eval = self.evalmode == "fast" and inout == "in" and env_num not in self.test_envs 105 | if skip_eval: 106 | continue 107 | 108 | is_test = env_num in self.test_envs 109 | acc, loss = accuracy(algorithm, loader_kwargs, weights, debug=self.debug) 110 | accuracies[name] = acc 111 | losses[name] = loss 112 | 113 | if env_num in self.train_envs: 114 | summaries["train_" + inout] += acc / n_train_envs 115 | if inout == "out": 116 | summaries["tr_" + inout + "loss"] += loss / n_train_envs 117 | elif is_test: 118 | summaries["test_" + inout] += acc / n_test_envs 119 | 120 | if ret_losses: 121 | return accuracies, summaries, losses 122 | else: 123 | return accuracies, summaries 124 | -------------------------------------------------------------------------------- /domainbed/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from domainbed.datasets import datasets 5 | from domainbed.lib import misc 6 | from domainbed.datasets import transforms as DBT 7 | 8 | 9 | def set_transfroms(dset, data_type, hparams, algorithm_class=None): 10 | """ 11 | Args: 12 | data_type: ['train', 'valid', 'test'] 13 | """ 14 | assert hparams["data_augmentation"] 15 | 16 | additional_data = False 17 | if data_type == "train": 18 | dset.transforms = {"x": DBT.aug} 19 | additional_data = True 20 | elif data_type == "valid": 21 | if hparams["val_augment"] is False: 22 | dset.transforms = {"x": DBT.basic} 23 | else: 24 | # Originally, DomainBed use same training augmentation policy to validation. 25 | # We turn off the augmentation for validation as default, 26 | # but left the option to reproducibility. 
27 | dset.transforms = {"x": DBT.aug} 28 | elif data_type == "test": 29 | dset.transforms = {"x": DBT.basic} 30 | else: 31 | raise ValueError(data_type) 32 | 33 | if additional_data and algorithm_class is not None: 34 | for key, transform in algorithm_class.transforms.items(): 35 | dset.transforms[key] = transform 36 | 37 | 38 | def get_dataset(test_envs, args, hparams, algorithm_class=None): 39 | """Get dataset and split.""" 40 | if args.dataset=="PACSGenCSV" or args.dataset=="VLCSGenCSV" or args.dataset=="OfficeHomeGenCSV": 41 | dataset = vars(datasets)[args.dataset](args.data_dir, args.gen_data_dir, 42 | args.gen_csv_dir, args.gen_num_per_class, args.gen_max_entropy, 43 | args.gen_random_selection, args.gen_all_data, args.gen_only_correct) 44 | 45 | 46 | else: 47 | dataset = vars(datasets)[args.dataset](args.data_dir) 48 | 49 | in_splits = [] 50 | out_splits = [] 51 | for env_i, env in enumerate(dataset): 52 | # The split only depends on seed_hash (= trial_seed). 53 | # It means that the split is always identical only if use same trial_seed, 54 | # independent to run the code where, when, or how many times. 55 | out, in_ = split_dataset( 56 | env, 57 | int(len(env) * args.holdout_fraction), 58 | misc.seed_hash(args.trial_seed, env_i), 59 | ) 60 | if env_i in test_envs: 61 | in_type = "test" 62 | out_type = "test" 63 | else: 64 | in_type = "train" 65 | out_type = "valid" 66 | 67 | set_transfroms(in_, in_type, hparams, algorithm_class) 68 | set_transfroms(out, out_type, hparams, algorithm_class) 69 | 70 | if hparams["class_balanced"]: 71 | in_weights = misc.make_weights_for_balanced_classes(in_) 72 | out_weights = misc.make_weights_for_balanced_classes(out) 73 | else: 74 | in_weights, out_weights = None, None 75 | in_splits.append((in_, in_weights)) 76 | out_splits.append((out, out_weights)) 77 | 78 | print(f"Num datasets: {len(dataset)}") 79 | 80 | return dataset, in_splits, out_splits 81 | 82 | 83 | class _SplitDataset(torch.utils.data.Dataset): 84 | """Used by split_dataset""" 85 | 86 | def __init__(self, underlying_dataset, keys): 87 | super(_SplitDataset, self).__init__() 88 | self.underlying_dataset = underlying_dataset 89 | self.keys = keys 90 | self.transforms = {} 91 | 92 | self.direct_return = isinstance(underlying_dataset, _SplitDataset) 93 | 94 | def __getitem__(self, key): 95 | if self.direct_return: 96 | return self.underlying_dataset[self.keys[key]] 97 | 98 | x, y = self.underlying_dataset[self.keys[key]] 99 | ret = {"y": y} 100 | 101 | for key, transform in self.transforms.items(): 102 | ret[key] = transform(x) 103 | 104 | return ret 105 | 106 | def __len__(self): 107 | return len(self.keys) 108 | 109 | 110 | def split_dataset(dataset, n, seed=0): 111 | """ 112 | Return a pair of datasets corresponding to a random split of the given 113 | dataset, with n datapoints in the first dataset and the rest in the last, 114 | using the given random seed 115 | """ 116 | assert n <= len(dataset) 117 | keys = list(range(len(dataset))) 118 | np.random.RandomState(seed).shuffle(keys) 119 | keys_1 = keys[:n] 120 | keys_2 = keys[n:] 121 | return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2) 122 | -------------------------------------------------------------------------------- /domainbed/lib/swa_utils.py: -------------------------------------------------------------------------------- 1 | # Burrowed from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py 2 | # modified for the DomainBed. 
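# Summary of this module: AveragedModel keeps a running average of a wrapped model's parameters
# (the default avg_fn is an incremental mean); update_parameters() folds the current weights into
# that average and records the averaged step range, and update_bn() re-estimates BatchNorm running
# statistics by forwarding batches from the training iterator through the model.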
3 | import copy 4 | import torch 5 | from torch.nn import Module 6 | from copy import deepcopy 7 | 8 | 9 | class AveragedModel(Module): 10 | def __init__(self, model, device=None, avg_fn=None, rm_optimizer=False): 11 | super(AveragedModel, self).__init__() 12 | self.start_step = -1 13 | self.end_step = -1 14 | if isinstance(model, AveragedModel): 15 | # prevent nested averagedmodel 16 | model = model.module 17 | self.module = deepcopy(model) 18 | if rm_optimizer: 19 | for k, v in vars(self.module).items(): 20 | if isinstance(v, torch.optim.Optimizer): 21 | setattr(self.module, k, None) 22 | 23 | if device is not None: 24 | self.module = self.module.to(device) 25 | 26 | self.register_buffer("n_averaged", torch.tensor(0, dtype=torch.long, device=device)) 27 | 28 | if avg_fn is None: 29 | def avg_fn(averaged_model_parameter, model_parameter, num_averaged): 30 | return averaged_model_parameter + (model_parameter - averaged_model_parameter) / ( 31 | num_averaged + 1 32 | ) 33 | 34 | self.avg_fn = avg_fn 35 | 36 | def forward(self, *args, **kwargs): 37 | # return self.predict(*args, **kwargs) 38 | return self.module(*args, **kwargs) 39 | 40 | def predict(self, *args, **kwargs): 41 | return self.module(*args, **kwargs) 42 | 43 | @property 44 | def network(self): 45 | return self.module.network 46 | 47 | def update_parameters(self, model, step=None, start_step=None, end_step=None): 48 | """Update averaged model parameters 49 | 50 | Args: 51 | model: current model to update params 52 | step: current step. step is saved for log the averaged range 53 | start_step: set start_step only for first update 54 | end_step: set end_step 55 | """ 56 | if isinstance(model, AveragedModel): 57 | model = model.module 58 | for p_swa, p_model in zip(self.parameters(), model.parameters()): 59 | device = p_swa.device 60 | p_model_ = p_model.detach().to(device) 61 | if self.n_averaged == 0: 62 | p_swa.detach().copy_(p_model_) 63 | else: 64 | p_swa.detach().copy_( 65 | self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)) 66 | ) 67 | self.n_averaged += 1 68 | 69 | if step is not None: 70 | if start_step is None: 71 | start_step = step 72 | if end_step is None: 73 | end_step = step 74 | 75 | if start_step is not None: 76 | if self.n_averaged == 1: 77 | self.start_step = start_step 78 | 79 | if end_step is not None: 80 | self.end_step = end_step 81 | 82 | def clone(self): 83 | clone = copy.deepcopy(self.module) 84 | clone.optimizer = clone.new_optimizer(clone.network.parameters()) 85 | return clone 86 | 87 | 88 | def cvt_dbiterator_to_loader(dbiterator, n_iter): 89 | """Convert DB iterator to the loader""" 90 | for _ in range(n_iter): 91 | minibatches = [(x, y) for x, y in next(dbiterator)] 92 | all_x = torch.cat([x for x, y in minibatches]) 93 | all_y = torch.cat([y for x, y in minibatches]) 94 | 95 | yield all_x, all_y 96 | 97 | 98 | @torch.no_grad() 99 | def update_bn(iterator, model, n_steps, device="cuda"): 100 | momenta = {} 101 | for module in model.modules(): 102 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 103 | module.running_mean = torch.zeros_like(module.running_mean) 104 | module.running_var = torch.ones_like(module.running_var) 105 | momenta[module] = module.momentum 106 | 107 | if not momenta: 108 | return 109 | 110 | was_training = model.training 111 | model.train() 112 | for module in momenta.keys(): 113 | module.momentum = None 114 | module.num_batches_tracked *= 0 115 | 116 | for i in range(n_steps): 117 | # batches_dictlist: [{env0_data_key: tensor, env0_...}, env1_..., 
...] 118 | batches_dictlist = next(iterator) 119 | x = torch.cat([dic["x"] for dic in batches_dictlist]) 120 | x = x.to(device) 121 | 122 | model(x) 123 | 124 | for bn_module in momenta.keys(): 125 | bn_module.momentum = momenta[bn_module] 126 | model.train(was_training) 127 | -------------------------------------------------------------------------------- /eval_cls.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import csv 4 | import tqdm 5 | import torch 6 | import pickle 7 | import random 8 | import numpy as np 9 | 10 | from domainbed.datasets import transforms as DBT 11 | from domainbed.datasets.datasets import GeneralDGEvalDataset 12 | 13 | from domainbed import algorithms 14 | from domainbed import hparams_registry 15 | 16 | from torch.utils.data import ConcatDataset 17 | from torch.utils.data import DataLoader 18 | from torch.utils.tensorboard import SummaryWriter 19 | from scipy.stats import entropy 20 | 21 | 22 | 23 | 24 | ### Argument parser 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="DomainBed Evaluation Script") 27 | parser.add_argument("--data_dir", type=str, required=True, help="Root directory for the dataset") 28 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save results") 29 | parser.add_argument("--algorithm", type=str, default='ERM', choices=['ERM', 'OtherAlgorithm'], help="Algorithm to use") 30 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size") 31 | parser.add_argument("--ckpt_dir", type=str, default=64, help="Path to a trained model") 32 | 33 | return parser.parse_args() 34 | 35 | 36 | args = parse_args() 37 | 38 | # Hyperparameters 39 | hparams = {"data_augmentation": False, 40 | "resnet18": False, 41 | "resnet_dropout": 0.0, 42 | } 43 | 44 | 45 | 46 | print("===="*10) 47 | print(f"data dir: {args.data_dir}") 48 | print(f"checkpoint_dir: {args.ckpt_dir}") 49 | 50 | save_dir = args.save_dir 51 | 52 | ### Create save root 53 | os.makedirs(save_dir, exist_ok=True) 54 | 55 | ### Data 56 | dataset = GeneralDGEvalDataset(args.data_dir, -1 , hparams) 57 | 58 | print(f"Num datasets: {len(dataset)}") 59 | # dataset = iter(dataset) 60 | 61 | combined_dataset = ConcatDataset(dataset) 62 | class_to_idx = combined_dataset.datasets[0].class_to_idx 63 | idx_to_class = {v: k for k, v in class_to_idx.items()} 64 | 65 | 66 | ### loading the model 67 | default_hparams = hparams_registry.default_hparams(args.algorithm, dataset) 68 | default_hparams.update(hparams) 69 | hparams = default_hparams 70 | 71 | 72 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 73 | model = algorithm_class(dataset.input_shape, dataset.num_classes, len(dataset), 74 | hparams) 75 | checkpoint = torch.load(args.ckpt_dir, map_location='cpu') 76 | state_dict = checkpoint['model_dict'] 77 | new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 78 | msg = model.load_state_dict(new_state_dict, strict=False) 79 | print("+++ msg", msg) 80 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 81 | model.to(device) 82 | model.eval() 83 | 84 | 85 | # Create a DataLoader 86 | data_loader = DataLoader(combined_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True) 87 | 88 | # Initialize accuracy tracking variables 89 | total = 0 90 | correct = 0 91 | 92 | # Iterate over batches 93 | with torch.no_grad(): 94 | for batch in tqdm.tqdm(data_loader): 95 | x, y, paths = batch 96 | x = 
x.to(device) 97 | y = y.to(device) 98 | 99 | # Batch prediction 100 | outputs, embeddings = model.predict_emb(x) 101 | _, predicted = torch.max(outputs.data, 1) 102 | 103 | 104 | 105 | # Update correct and total counts for accuracy 106 | total += y.size(0) 107 | correct += (predicted == y).sum().item() 108 | 109 | 110 | 111 | # Process each item in the batch 112 | for idx, (path, pred, logit, label, emb) in enumerate(zip(paths, predicted, outputs.data.cpu(), y, embeddings)): 113 | correct_prediction = "Yes" if pred == label else "No" 114 | 115 | # Calculate entropy 116 | prob = torch.nn.functional.softmax(logit, dim=0).numpy() 117 | prediction_entropy = entropy(prob) 118 | 119 | # Save image information to CSV 120 | with open(os.path.join(save_dir, "image_predictions.csv"), 'a', newline='') as file: 121 | writer = csv.writer(file) 122 | writer.writerow([path, label.item(), pred.item(), correct_prediction, prediction_entropy, logit.tolist()]) 123 | 124 | 125 | 126 | # Calculate and print final accuracy 127 | accuracy = 100 * correct / total 128 | print(f'Final Accuracy: {accuracy}%') 129 | 130 | 131 | # Save the summary information to a TXT file 132 | summary_file = os.path.join(save_dir, "summary_info.txt") 133 | with open(summary_file, 'w') as file: 134 | file.write(f'checkpoint_dir: {args.ckpt_dir}\n') 135 | file.write(f'Total images: {total}\n') 136 | file.write(f'Accuracy: {accuracy}%\n') 137 | 138 | print(f"Saved summary information to {summary_file}") 139 | 140 | -------------------------------------------------------------------------------- /domainbed/hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import numpy as np 4 | 5 | 6 | def _hparams(algorithm, dataset, random_state): 7 | """ 8 | Global registry of hyperparams. Each entry is a (default, random) tuple. 9 | New algorithms / networks / etc. should add entries here. 
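    Illustration (annotation, not in the original docstring): an entry such as

        hparams["lr"] = (5e-5, 10 ** random_state.uniform(-5, -3.5))

    stores (default_value, random_draw); default_hparams() below keeps the first
    element of each tuple and random_hparams() keeps the second.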
10 | """ 11 | SMALL_IMAGES = ["Debug28", "RotatedMNIST", "ColoredMNIST"] 12 | 13 | hparams = {} 14 | 15 | hparams["data_augmentation"] = (True, True) 16 | hparams["val_augment"] = (False, False) # augmentation for in-domain validation set 17 | hparams["resnet18"] = (False, False) 18 | hparams["resnet_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5])) 19 | hparams["class_balanced"] = (False, False) 20 | hparams["optimizer"] = ("adam", "adam") 21 | 22 | hparams["freeze_bn"] = (True, True) 23 | hparams["pretrained"] = (True, True) # only for ResNet 24 | 25 | if dataset not in SMALL_IMAGES: 26 | hparams["lr"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 27 | if dataset == "DomainNet": 28 | hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5))) 29 | else: 30 | hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5.5))) 31 | if algorithm == "ARM": 32 | hparams["batch_size"] = (8, 8) 33 | else: 34 | hparams["lr"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 35 | hparams["batch_size"] = (64, int(2 ** random_state.uniform(3, 9))) 36 | 37 | if dataset in SMALL_IMAGES: 38 | hparams["weight_decay"] = (0.0, 0.0) 39 | else: 40 | hparams["weight_decay"] = (0.0, 10 ** random_state.uniform(-6, -2)) 41 | 42 | if algorithm in ["DANN", "CDANN"]: 43 | if dataset not in SMALL_IMAGES: 44 | hparams["lr_g"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 45 | hparams["lr_d"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 46 | else: 47 | hparams["lr_g"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 48 | hparams["lr_d"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 49 | 50 | if dataset in SMALL_IMAGES: 51 | hparams["weight_decay_g"] = (0.0, 0.0) 52 | else: 53 | hparams["weight_decay_g"] = (0.0, 10 ** random_state.uniform(-6, -2)) 54 | 55 | hparams["lambda"] = (1.0, 10 ** random_state.uniform(-2, 2)) 56 | hparams["weight_decay_d"] = (0.0, 10 ** random_state.uniform(-6, -2)) 57 | hparams["d_steps_per_g_step"] = (1, int(2 ** random_state.uniform(0, 3))) 58 | hparams["grad_penalty"] = (0.0, 10 ** random_state.uniform(-2, 1)) 59 | hparams["beta1"] = (0.5, random_state.choice([0.0, 0.5])) 60 | hparams["mlp_width"] = (256, int(2 ** random_state.uniform(6, 10))) 61 | hparams["mlp_depth"] = (3, int(random_state.choice([3, 4, 5]))) 62 | hparams["mlp_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5])) 63 | elif algorithm == "RSC": 64 | hparams["rsc_f_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5)) 65 | hparams["rsc_b_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5)) 66 | elif algorithm == "SagNet": 67 | hparams["sag_w_adv"] = (0.1, 10 ** random_state.uniform(-2, 1)) 68 | elif algorithm == "IRM": 69 | hparams["irm_lambda"] = (1e2, 10 ** random_state.uniform(-1, 5)) 70 | hparams["irm_penalty_anneal_iters"] = ( 71 | 500, 72 | int(10 ** random_state.uniform(0, 4)), 73 | ) 74 | elif algorithm in ["Mixup", "OrgMixup"]: 75 | hparams["mixup_alpha"] = (0.2, 10 ** random_state.uniform(-1, -1)) 76 | elif algorithm == "GroupDRO": 77 | hparams["groupdro_eta"] = (1e-2, 10 ** random_state.uniform(-3, -1)) 78 | elif algorithm in ("MMD", "CORAL"): 79 | hparams["mmd_gamma"] = (1.0, 10 ** random_state.uniform(-1, 1)) 80 | elif algorithm in ("MLDG", "SOMLDG"): 81 | hparams["mldg_beta"] = (1.0, 10 ** random_state.uniform(-1, 1)) 82 | elif algorithm == "MTL": 83 | hparams["mtl_ema"] = (0.99, random_state.choice([0.5, 0.9, 0.99, 1.0])) 84 | elif algorithm == "VREx": 85 | hparams["vrex_lambda"] = (1e1, 10 ** random_state.uniform(-1, 5)) 86 | hparams["vrex_penalty_anneal_iters"] = ( 87 | 500, 88 | 
int(10 ** random_state.uniform(0, 4)), 89 | ) 90 | elif algorithm == "SAM": 91 | hparams["rho"] = (0.05, random_state.choice([0.01, 0.02, 0.05, 0.1])) 92 | elif algorithm == "CutMix": 93 | hparams["beta"] = (1.0, 1.0) 94 | # cutmix_prob is set to 1.0 for ImageNet and 0.5 for CIFAR100 in the original paper. 95 | hparams["cutmix_prob"] = (1.0, 1.0) 96 | 97 | return hparams 98 | 99 | 100 | def default_hparams(algorithm, dataset): 101 | dummy_random_state = np.random.RandomState(0) 102 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, dummy_random_state).items()} 103 | 104 | 105 | def random_hparams(algorithm, dataset, seed): 106 | random_state = np.random.RandomState(seed) 107 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, random_state).items()} 108 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 
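    Note (annotation): cv2.imread yields BGR uint8 data, so the function promotes
    grayscale inputs to three channels, reorders BGR -> RGB, and divides by 255
    to return a float image in [0, 1].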
99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
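        Annotation (summarizing the code below, not in the original docstring):
        the four pretrained encoder stages feed the per-stage scratch.layerX_rn
        projection convs, the refinenet4..refinenet1 blocks fuse them top-down,
        and output_conv upsamples to a single-channel depth map that is squeezed
        over dim 1 before returning.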
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /domainbed/lib/query.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Small query library.""" 4 | 5 | import inspect 6 | import json 7 | import types 8 | import warnings 9 | 10 | import numpy as np 11 | 12 | 13 | def make_selector_fn(selector): 14 | """ 15 | If selector is a function, return selector. 16 | Otherwise, return a function corresponding to the selector string. Examples 17 | of valid selector strings and the corresponding functions: 18 | x lambda obj: obj['x'] 19 | x.y lambda obj: obj['x']['y'] 20 | x,y lambda obj: (obj['x'], obj['y']) 21 | """ 22 | if isinstance(selector, str): 23 | if "," in selector: 24 | parts = selector.split(",") 25 | part_selectors = [make_selector_fn(part) for part in parts] 26 | return lambda obj: tuple(sel(obj) for sel in part_selectors) 27 | elif "." 
in selector: 28 | parts = selector.split(".") 29 | part_selectors = [make_selector_fn(part) for part in parts] 30 | 31 | def f(obj): 32 | for sel in part_selectors: 33 | obj = sel(obj) 34 | return obj 35 | 36 | return f 37 | else: 38 | key = selector.strip() 39 | return lambda obj: obj[key] 40 | elif isinstance(selector, types.FunctionType): 41 | return selector 42 | else: 43 | raise TypeError 44 | 45 | 46 | def hashable(obj): 47 | try: 48 | hash(obj) 49 | return obj 50 | except TypeError: 51 | return json.dumps({"_": obj}, sort_keys=True) 52 | 53 | 54 | class Q(object): 55 | def __init__(self, list_): 56 | super(Q, self).__init__() 57 | self._list = list_ 58 | 59 | def __len__(self): 60 | return len(self._list) 61 | 62 | def __getitem__(self, key): 63 | return self._list[key] 64 | 65 | def __eq__(self, other): 66 | if isinstance(other, self.__class__): 67 | return self._list == other._list 68 | else: 69 | return self._list == other 70 | 71 | def __str__(self): 72 | return str(self._list) 73 | 74 | def __repr__(self): 75 | return repr(self._list) 76 | 77 | def _append(self, item): 78 | """Unsafe, be careful you know what you're doing.""" 79 | self._list.append(item) 80 | 81 | def group(self, selector): 82 | """ 83 | Group elements by selector and return a list of (group, group_records) 84 | tuples. 85 | """ 86 | selector = make_selector_fn(selector) 87 | groups = {} 88 | for x in self._list: 89 | group = selector(x) 90 | group_key = hashable(group) 91 | if group_key not in groups: 92 | groups[group_key] = (group, Q([])) 93 | groups[group_key][1]._append(x) 94 | results = [groups[key] for key in sorted(groups.keys())] 95 | return Q(results) 96 | 97 | def group_map(self, selector, fn): 98 | """ 99 | Group elements by selector, apply fn to each group, and return a list 100 | of the results. 101 | """ 102 | return self.group(selector).map(fn) 103 | 104 | def map(self, fn): 105 | """ 106 | map self onto fn. If fn takes multiple args, tuple-unpacking 107 | is applied. 
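        Example (annotation, not in the original docstring):
            Q([(1, 2), (3, 4)]).map(lambda a, b: a + b)   # -> Q([3, 7])
            Q([1, 2, 3]).map(lambda x: x * 10)            # -> Q([10, 20, 30])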
108 | """ 109 | if len(inspect.signature(fn).parameters) > 1: 110 | return Q([fn(*x) for x in self._list]) 111 | else: 112 | return Q([fn(x) for x in self._list]) 113 | 114 | def select(self, selector): 115 | selector = make_selector_fn(selector) 116 | return Q([selector(x) for x in self._list]) 117 | 118 | def min(self): 119 | return min(self._list) 120 | 121 | def max(self): 122 | return max(self._list) 123 | 124 | def sum(self): 125 | return sum(self._list) 126 | 127 | def len(self): 128 | return len(self._list) 129 | 130 | def mean(self): 131 | with warnings.catch_warnings(): 132 | warnings.simplefilter("ignore") 133 | return float(np.mean(self._list)) 134 | 135 | def std(self): 136 | with warnings.catch_warnings(): 137 | warnings.simplefilter("ignore") 138 | return float(np.std(self._list)) 139 | 140 | def mean_std(self): 141 | return (self.mean(), self.std()) 142 | 143 | def argmax(self, selector): 144 | selector = make_selector_fn(selector) 145 | return max(self._list, key=selector) 146 | 147 | def filter(self, fn): 148 | return Q([x for x in self._list if fn(x)]) 149 | 150 | def filter_equals(self, selector, value): 151 | """like [x for x in y if x.selector == value]""" 152 | selector = make_selector_fn(selector) 153 | return self.filter(lambda r: selector(r) == value) 154 | 155 | def filter_not_none(self): 156 | return self.filter(lambda r: r is not None) 157 | 158 | def filter_not_nan(self): 159 | return self.filter(lambda r: not np.isnan(r)) 160 | 161 | def flatten(self): 162 | return Q([y for x in self._list for y in x]) 163 | 164 | def unique(self): 165 | result = [] 166 | result_set = set() 167 | for x in self._list: 168 | hashable_x = hashable(x) 169 | if hashable_x not in result_set: 170 | result_set.add(hashable_x) 171 | result.append(x) 172 | return Q(result) 173 | 174 | def sorted(self, key=None, reverse=False): 175 | if key is None: 176 | key = lambda x: x 177 | 178 | def key2(x): 179 | x = key(x) 180 | if isinstance(x, (np.floating, float)) and np.isnan(x): 181 | return float("-inf") 182 | else: 183 | return x 184 | 185 | return Q(sorted(self._list, key=key2, reverse=reverse)) 186 | -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 
65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /domainbed/lib/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Things that don't belong anywhere else 5 | """ 6 | 7 | import hashlib 8 | import os 9 | import shutil 10 | import errno 11 | from datetime import datetime 12 | from collections import Counter 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | def make_weights_for_balanced_classes(dataset): 20 | counts = Counter() 21 | classes = [] 22 | for _, y in dataset: 23 | y = int(y) 24 | counts[y] += 1 25 | classes.append(y) 26 | 27 | n_classes = len(counts) 28 | 29 | weight_per_class = {} 30 | for y in counts: 31 | weight_per_class[y] = 1 / (counts[y] * n_classes) 32 | 33 | weights = torch.zeros(len(dataset)) 34 | for i, y in enumerate(classes): 35 | weights[i] = weight_per_class[int(y)] 36 | 37 | return weights 38 | 39 | 40 | def seed_hash(*args): 41 | """ 42 | Derive an integer hash from all args, for use as a random seed. 43 | """ 44 | args_str = str(args) 45 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2 ** 31) 46 | 47 | 48 | def to_row(row, colwidth=10, latex=False): 49 | """Convert value list to row string""" 50 | if latex: 51 | sep = " & " 52 | end_ = "\\\\" 53 | else: 54 | sep = " " 55 | end_ = "" 56 | 57 | def format_val(x): 58 | if np.issubdtype(type(x), np.floating): 59 | x = "{:.6f}".format(x) 60 | return str(x).ljust(colwidth)[:colwidth] 61 | 62 | return sep.join([format_val(x) for x in row]) + " " + end_ 63 | 64 | 65 | def random_pairs_of_minibatches(minibatches): 66 | # n_tr_envs = len(minibatches) 67 | perm = torch.randperm(len(minibatches)).tolist() 68 | pairs = [] 69 | 70 | for i in range(len(minibatches)): 71 | # j = cyclic(i + 1) 72 | j = i + 1 if i < (len(minibatches) - 1) else 0 73 | 74 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1] 75 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1] 76 | 77 | min_n = min(len(xi), len(xj)) 78 | 79 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 80 | 81 | return pairs 82 | 83 | 84 | ########################################################### 85 | # Custom utils 86 | ########################################################### 87 | 88 | 89 | def index_conditional_iterate(skip_condition, iterable, index): 90 | for i, x in enumerate(iterable): 91 | if skip_condition(i): 92 | continue 93 | 94 | if index: 95 | yield i, x 96 | else: 97 | yield x 98 | 99 | 100 | class SplitIterator: 101 | def __init__(self, test_envs): 102 | self.test_envs = test_envs 103 | 104 | def train(self, iterable, index=False): 105 | return index_conditional_iterate(lambda idx: idx in self.test_envs, iterable, index) 106 | 107 | def test(self, iterable, index=False): 108 | return index_conditional_iterate(lambda idx: idx not in self.test_envs, iterable, index) 109 | 110 | 111 | class AverageMeter: 112 | """Computes and stores the average and current value""" 113 | 114 | def __init__(self): 115 | self.reset() 116 | 117 | def reset(self): 118 | """Reset all statistics""" 119 | self.val = 0 120 | self.avg = 0 121 | self.sum = 0 122 | self.count = 0 123 | 124 | def update(self, val, n=1): 125 | """Update statistics""" 126 | self.val = val 127 | self.sum += val * n 128 | self.count += n 129 | self.avg = self.sum / self.count 130 | 131 | def __repr__(self): 132 | return "{:.3f} (val={:.3f}, count={})".format(self.avg, self.val, self.count) 133 | 134 | 135 | class AverageMeters: 136 | def __init__(self, *keys): 137 | self.keys = keys 138 | for k in keys: 139 | setattr(self, k, AverageMeter()) 140 | 141 | def resets(self): 142 | for k in 
self.keys: 143 | getattr(self, k).reset() 144 | 145 | def updates(self, dic, n=1): 146 | for k, v in dic.items(): 147 | getattr(self, k).update(v, n) 148 | 149 | def __repr__(self): 150 | return " ".join(["{}: {}".format(k, str(getattr(self, k))) for k in self.keys]) 151 | 152 | def get_averages(self): 153 | dic = {k: getattr(self, k).avg for k in self.keys} 154 | return dic 155 | 156 | 157 | def timestamp(fmt="%y%m%d_%H-%M-%S"): 158 | return datetime.now().strftime(fmt) 159 | 160 | 161 | def makedirs(path): 162 | if not os.path.exists(path): 163 | try: 164 | os.makedirs(path) 165 | except OSError as exc: 166 | if exc.errno != errno.EEXIST: 167 | raise 168 | 169 | 170 | def rm(path): 171 | """remove dir recursively""" 172 | if os.path.isdir(path): 173 | shutil.rmtree(path, ignore_errors=True) 174 | elif os.path.exists(path): 175 | os.remove(path) 176 | 177 | 178 | def cp(src, dst): 179 | shutil.copy2(src, dst) 180 | 181 | 182 | def get_lr(optimizer): 183 | """Assume that the optimizer has single lr""" 184 | lr = optimizer.param_groups[0]["lr"] 185 | 186 | return lr 187 | 188 | 189 | @torch.no_grad() 190 | def hash_bn(module): 191 | summary = [] 192 | for m in module.modules(): 193 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 194 | w = m.weight.detach().mean().item() 195 | b = m.bias.detach().mean().item() 196 | rm = m.running_mean.detach().mean().item() 197 | rv = m.running_var.detach().mean().item() 198 | summary.append((w, b, rm, rv)) 199 | w, b, rm, rv = [np.mean(col) for col in zip(*summary)] 200 | 201 | return w, b, rm, rv 202 | 203 | 204 | def merge_dictlist(dictlist): 205 | """Merge list of dicts into dict of lists, by grouping same key.""" 206 | ret = {k: [] for k in dictlist[0].keys()} 207 | for dic in dictlist: 208 | for data_key, v in dic.items(): 209 | ret[data_key].append(v) 210 | return ret 211 | -------------------------------------------------------------------------------- /domainbed/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models 7 | 8 | from domainbed.lib import wide_resnet 9 | 10 | 11 | class Identity(nn.Module): 12 | """An identity layer""" 13 | 14 | def __init__(self): 15 | super(Identity, self).__init__() 16 | 17 | def forward(self, x): 18 | return x 19 | 20 | 21 | class SqueezeLastTwo(nn.Module): 22 | """ 23 | A module which squeezes the last two dimensions, 24 | ordinary squeeze can be a problem for batch size 1 25 | """ 26 | 27 | def __init__(self): 28 | super(SqueezeLastTwo, self).__init__() 29 | 30 | def forward(self, x): 31 | return x.view(x.shape[0], x.shape[1]) 32 | 33 | 34 | class MLP(nn.Module): 35 | """Just an MLP""" 36 | 37 | def __init__(self, n_inputs, n_outputs, hparams): 38 | super(MLP, self).__init__() 39 | self.input = nn.Linear(n_inputs, hparams["mlp_width"]) 40 | self.dropout = nn.Dropout(hparams["mlp_dropout"]) 41 | self.hiddens = nn.ModuleList( 42 | [ 43 | nn.Linear(hparams["mlp_width"], hparams["mlp_width"]) 44 | for _ in range(hparams["mlp_depth"] - 2) 45 | ] 46 | ) 47 | self.output = nn.Linear(hparams["mlp_width"], n_outputs) 48 | self.n_outputs = n_outputs 49 | 50 | def forward(self, x): 51 | x = self.input(x) 52 | x = self.dropout(x) 53 | x = F.relu(x) 54 | for hidden in self.hiddens: 55 | x = hidden(x) 56 | x = self.dropout(x) 57 | x = F.relu(x) 58 | x = self.output(x) 59 | return x 60 | 61 | 62 | class ResNet(torch.nn.Module): 63 | """ResNet with the softmax chopped off and the batchnorm frozen""" 64 | 65 | def __init__(self, input_shape, hparams, network=None): 66 | super(ResNet, self).__init__() 67 | if hparams["resnet18"]: 68 | if network is None: 69 | network = torchvision.models.resnet18(pretrained=hparams["pretrained"]) 70 | self.network = network 71 | self.n_outputs = 512 72 | else: 73 | if network is None: 74 | network = torchvision.models.resnet50(pretrained=hparams["pretrained"]) 75 | self.network = network 76 | self.n_outputs = 2048 77 | 78 | # adapt number of channels 79 | nc = input_shape[0] 80 | if nc != 3: 81 | tmp = self.network.conv1.weight.data.clone() 82 | 83 | self.network.conv1 = nn.Conv2d( 84 | nc, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 85 | ) 86 | 87 | for i in range(nc): 88 | self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :] 89 | 90 | # save memory 91 | del self.network.fc 92 | self.network.fc = Identity() 93 | 94 | self.hparams = hparams 95 | self.dropout = nn.Dropout(hparams["resnet_dropout"]) 96 | self.freeze_bn() 97 | 98 | def forward(self, x): 99 | """Encode x into a feature vector of size n_outputs.""" 100 | return self.dropout(self.network(x)) 101 | 102 | def train(self, mode=True): 103 | """ 104 | Override the default train() to freeze the BN parameters 105 | """ 106 | super().train(mode) 107 | self.freeze_bn() 108 | 109 | def freeze_bn(self): 110 | if self.hparams["freeze_bn"] is False: 111 | return 112 | 113 | for m in self.network.modules(): 114 | if isinstance(m, nn.BatchNorm2d): 115 | m.eval() 116 | 117 | 118 | class MNIST_CNN(nn.Module): 119 | """ 120 | Hand-tuned architecture for MNIST. 121 | Weirdness I've noticed so far with this architecture: 122 | - adding a linear layer after the mean-pool in features hurts 123 | RotatedMNIST-100 generalization severely. 
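    Layout (annotation, summarizing the code below): four 3x3 convolutions
    (input -> 64 -> 128 -> 128 -> 128 channels, the second with stride 2), each
    followed by ReLU and a GroupNorm with 8 groups, then a global average pool
    and a squeeze of the last two dims, yielding a 128-dimensional feature
    vector (n_outputs = 128).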
124 | """ 125 | 126 | n_outputs = 128 127 | 128 | def __init__(self, input_shape): 129 | super(MNIST_CNN, self).__init__() 130 | self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1) 131 | self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1) 132 | self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1) 133 | self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1) 134 | 135 | self.bn0 = nn.GroupNorm(8, 64) 136 | self.bn1 = nn.GroupNorm(8, 128) 137 | self.bn2 = nn.GroupNorm(8, 128) 138 | self.bn3 = nn.GroupNorm(8, 128) 139 | 140 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 141 | self.squeezeLastTwo = SqueezeLastTwo() 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = F.relu(x) 146 | x = self.bn0(x) 147 | 148 | x = self.conv2(x) 149 | x = F.relu(x) 150 | x = self.bn1(x) 151 | 152 | x = self.conv3(x) 153 | x = F.relu(x) 154 | x = self.bn2(x) 155 | 156 | x = self.conv4(x) 157 | x = F.relu(x) 158 | x = self.bn3(x) 159 | 160 | x = self.avgpool(x) 161 | x = self.squeezeLastTwo(x) 162 | return x 163 | 164 | 165 | class ContextNet(nn.Module): 166 | def __init__(self, input_shape): 167 | super(ContextNet, self).__init__() 168 | 169 | # Keep same dimensions 170 | padding = (5 - 1) // 2 171 | self.context_net = nn.Sequential( 172 | nn.Conv2d(input_shape[0], 64, 5, padding=padding), 173 | nn.BatchNorm2d(64), 174 | nn.ReLU(), 175 | nn.Conv2d(64, 64, 5, padding=padding), 176 | nn.BatchNorm2d(64), 177 | nn.ReLU(), 178 | nn.Conv2d(64, 1, 5, padding=padding), 179 | ) 180 | 181 | def forward(self, x): 182 | return self.context_net(x) 183 | 184 | 185 | def Featurizer(input_shape, hparams): 186 | """Auto-select an appropriate featurizer for the given input shape.""" 187 | if len(input_shape) == 1: 188 | return MLP(input_shape[0], 128, hparams) 189 | elif input_shape[1:3] == (28, 28): 190 | return MNIST_CNN(input_shape) 191 | elif input_shape[1:3] == (32, 32): 192 | return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.0) 193 | elif input_shape[1:3] == (224, 224): 194 | return ResNet(input_shape, hparams) 195 | else: 196 | raise NotImplementedError(f"Input shape {input_shape} is not supported") 197 | -------------------------------------------------------------------------------- /bash/train_cls/pacs_interpolation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=PACS 6 | save_dir="./save/train_cls_fds/${dataset}/" 7 | data_dir="./data" 8 | gen_num_per_class=570 # number of images/class of the interpolated (generated) data that will be selecte to used in training 9 | 10 | 11 | 12 | ########################## training the classifier using both original and interpolated data using the source domains: "art_painting cartoon photo" domains => indexes: "0" "1" "2" 13 | source_domains="012" 14 | test_env=3 15 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 16 | 17 | ### first seed 18 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 19 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 20 | 21 | 22 | ### second seed 23 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 24 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic 
--trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 25 | 26 | 27 | ### third seed 28 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 29 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 30 | 31 | 32 | 33 | ########################## training the classifier using both original and interpolated data using the source domains: "art_painting cartoon sketch" domains => indexes: "0" "1" "3" 34 | source_domains="013" 35 | test_env=2 36 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 37 | 38 | ### first seed 39 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 40 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 41 | 42 | 43 | ### second seed 44 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 45 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 46 | 47 | 48 | ### third seed 49 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 50 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 51 | 52 | 53 | 54 | ########################## training the classifier using both original and interpolated data using the source domains: "art_painting photo sketch" domains => indexes: "0" "2" "3" 55 | source_domains="023" 56 | test_env=1 57 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 58 | 59 | ### first seed 60 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 61 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 62 | 63 | 64 | ### second seed 65 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 66 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 67 | 68 | 69 | ### third seed 70 | 
gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 71 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 72 | 73 | 74 | 75 | ########################## training the classifier using both original and interpolated data using the source domains: "cartoon photo sketch" domains => indexes: "1" "2" "3" 76 | source_domains="123" 77 | test_env=0 78 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 79 | 80 | ### first seed 81 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 82 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 83 | 84 | 85 | ### second seed 86 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 87 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 88 | 89 | 90 | ### third seed 91 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 92 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 93 | 94 | -------------------------------------------------------------------------------- /bash/train_cls/officehome_interpolation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=OfficeHome 6 | save_dir="./save/train_cls_fds/${dataset}/" 7 | data_dir="./data" 8 | gen_num_per_class=70 # number of images/class of the interpolated (generated) data that will be selecte to used in training 9 | 10 | 11 | 12 | ########################## training the classifier using both original and interpolated data using the source domains: (Art, Clipart, Product) domains => indexes: "0" "1" "2" 13 | source_domains="012" 14 | test_env=3 15 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 16 | 17 | ### first seed 18 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 19 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 20 | 21 | 22 | ### second seed 23 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 24 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} 
--work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 25 | 26 | 27 | ### third seed 28 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 29 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 30 | 31 | 32 | 33 | ########################## training the classifier using both original and interpolated data using the source domains: (Art, Clipart, Real World) domains => indexes: "0" "1" "3" 34 | source_domains="013" 35 | test_env=2 36 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 37 | 38 | ### first seed 39 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 40 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 41 | 42 | 43 | ### second seed 44 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 45 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 46 | 47 | 48 | ### third seed 49 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 50 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 51 | 52 | 53 | 54 | ########################## training the classifier using both original and interpolated data using the source domains: (Art, Product, Real World) domains => indexes: "0" "2" "3" 55 | source_domains="023" 56 | test_env=1 57 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 58 | 59 | ### first seed 60 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 61 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 62 | 63 | 64 | ### second seed 65 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 66 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 67 | 68 | 69 | ### third seed 70 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 71 | python 
train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 72 | 73 | 74 | 75 | ########################## training the classifier using both original and interpolated data using the source domains: (Clipart, Product, Real World) domains => indexes: "1" "2" "3" 76 | source_domains="123" 77 | test_env=0 78 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 79 | 80 | ### first seed 81 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 82 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 83 | 84 | 85 | ### second seed 86 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 87 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 88 | 89 | 90 | ### third seed 91 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 92 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 100 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 93 | 94 | -------------------------------------------------------------------------------- /domainbed/swad.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import deque 3 | import numpy as np 4 | from domainbed.lib import swa_utils 5 | 6 | 7 | class SWADBase: 8 | def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn): 9 | raise NotImplementedError() 10 | 11 | def get_final_model(self): 12 | raise NotImplementedError() 13 | 14 | 15 | class IIDMax(SWADBase): 16 | """SWAD start from iid max acc and select last by iid max swa acc""" 17 | 18 | def __init__(self, evaluator, **kwargs): 19 | self.iid_max_acc = 0.0 20 | self.swa_max_acc = 0.0 21 | self.avgmodel = None 22 | self.final_model = None 23 | self.evaluator = evaluator 24 | 25 | def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn): 26 | if self.iid_max_acc < val_acc: 27 | self.iid_max_acc = val_acc 28 | self.avgmodel = swa_utils.AveragedModel(segment_swa.module, rm_optimizer=True) 29 | self.avgmodel.start_step = segment_swa.start_step 30 | 31 | self.avgmodel.update_parameters(segment_swa.module) 32 | self.avgmodel.end_step = segment_swa.end_step 33 | 34 | # evaluate 35 | accuracies, summaries = self.evaluator.evaluate(self.avgmodel) 36 | results = {**summaries, **accuracies} 37 | prt_fn(results, self.avgmodel) 38 | 39 | swa_val_acc = results["train_out"] 40 | if swa_val_acc > self.swa_max_acc: 41 | self.swa_max_acc = swa_val_acc 42 | self.final_model = copy.deepcopy(self.avgmodel) 43 | 44 | def get_final_model(self): 45 | 
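        # final_model holds the deepcopy made in update_and_evaluate whenever the
        # SWA validation accuracy improved, i.e. the best-accuracy averaged model so far.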
return self.final_model 46 | 47 | 48 | class LossValley(SWADBase): 49 | """IIDMax has a potential problem that bias to validation dataset. 50 | LossValley choose SWAD range by detecting loss valley. 51 | """ 52 | 53 | def __init__(self, evaluator, n_converge, n_tolerance, tolerance_ratio, **kwargs): 54 | """ 55 | Args: 56 | evaluator 57 | n_converge: converge detector window size. 58 | n_tolerance: loss min smoothing window size 59 | tolerance_ratio: decision ratio for dead loss valley 60 | """ 61 | self.evaluator = evaluator 62 | self.n_converge = n_converge 63 | self.n_tolerance = n_tolerance 64 | self.tolerance_ratio = tolerance_ratio 65 | 66 | self.converge_Q = deque(maxlen=n_converge) 67 | self.smooth_Q = deque(maxlen=n_tolerance) 68 | 69 | self.final_model = None 70 | 71 | self.converge_step = None 72 | self.dead_valley = False 73 | self.threshold = None 74 | 75 | def get_smooth_loss(self, idx): 76 | smooth_loss = min([model.end_loss for model in list(self.smooth_Q)[idx:]]) 77 | return smooth_loss 78 | 79 | @property 80 | def is_converged(self): 81 | return self.converge_step is not None 82 | 83 | def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn): 84 | if self.dead_valley: 85 | return 86 | 87 | frozen = copy.deepcopy(segment_swa.cpu()) 88 | frozen.end_loss = val_loss 89 | self.converge_Q.append(frozen) 90 | self.smooth_Q.append(frozen) 91 | 92 | if not self.is_converged: 93 | if len(self.converge_Q) < self.n_converge: 94 | return 95 | 96 | min_idx = np.argmin([model.end_loss for model in self.converge_Q]) 97 | untilmin_segment_swa = self.converge_Q[min_idx] # until-min segment swa. 98 | if min_idx == 0: 99 | self.converge_step = self.converge_Q[0].end_step 100 | self.final_model = swa_utils.AveragedModel(untilmin_segment_swa) 101 | 102 | th_base = np.mean([model.end_loss for model in self.converge_Q]) 103 | self.threshold = th_base * (1.0 + self.tolerance_ratio) 104 | 105 | if self.n_tolerance < self.n_converge: 106 | for i in range(self.n_converge - self.n_tolerance): 107 | model = self.converge_Q[1 + i] 108 | self.final_model.update_parameters( 109 | model, start_step=model.start_step, end_step=model.end_step 110 | ) 111 | elif self.n_tolerance > self.n_converge: 112 | converge_idx = self.n_tolerance - self.n_converge 113 | Q = list(self.smooth_Q)[: converge_idx + 1] 114 | start_idx = 0 115 | for i in reversed(range(len(Q))): 116 | model = Q[i] 117 | if model.end_loss > self.threshold: 118 | start_idx = i + 1 119 | break 120 | for model in Q[start_idx + 1 :]: 121 | self.final_model.update_parameters( 122 | model, start_step=model.start_step, end_step=model.end_step 123 | ) 124 | print( 125 | f"Model converged at step {self.converge_step}, " 126 | f"Start step = {self.final_model.start_step}; " 127 | f"Threshold = {self.threshold:.6f}, " 128 | ) 129 | return 130 | 131 | if self.smooth_Q[0].end_step < self.converge_step: 132 | return 133 | 134 | # converged -> loss valley 135 | min_vloss = self.get_smooth_loss(0) 136 | if min_vloss > self.threshold: 137 | self.dead_valley = True 138 | print(f"Valley is dead at step {self.final_model.end_step}") 139 | return 140 | 141 | model = self.smooth_Q[0] 142 | self.final_model.update_parameters( 143 | model, start_step=model.start_step, end_step=model.end_step 144 | ) 145 | 146 | def get_final_model(self): 147 | if not self.is_converged: 148 | self.evaluator.logger.error( 149 | "Requested final model, but model is not yet converged; return last model instead" 150 | ) 151 | return self.converge_Q[-1].cuda() 152 | 153 | if not 
self.dead_valley: 154 | self.smooth_Q.popleft() 155 | while self.smooth_Q: 156 | smooth_loss = self.get_smooth_loss(0) 157 | if smooth_loss > self.threshold: 158 | break 159 | segment_swa = self.smooth_Q.popleft() 160 | self.final_model.update_parameters(segment_swa, step=segment_swa.end_step) 161 | 162 | return self.final_model.cuda() 163 | -------------------------------------------------------------------------------- /bash/train_cls/vlcs_interpolation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ### parameters 5 | dataset=VLCS 6 | save_dir="./save/train_cls_fds/${dataset}/" 7 | data_dir="./data" 8 | gen_num_per_class=680 # number of images/class of the interpolated (generated) data that will be selecte to used in training 9 | 10 | 11 | 12 | ########################## training the classifier using both original and interpolated data using the source domains: (Caltech101, LabelMe, SUN09) domains => indexes: "0" "1" "2" 13 | source_domains="012" 14 | test_env=3 15 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 16 | 17 | ### first seed 18 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 19 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 20 | 21 | 22 | ### second seed 23 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 24 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 25 | 26 | 27 | ### third seed 28 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 29 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 30 | 31 | 32 | 33 | ########################## training the classifier using both original and interpolated data using the source domains: (Caltech101, LabelMe, VOC2007) domains => indexes: "0" "1" "3" 34 | source_domains="013" 35 | test_env=2 36 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 37 | 38 | ### first seed 39 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 40 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 41 | 42 | 43 | ### second seed 44 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 45 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs 
${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 46 | 47 | 48 | ### third seed 49 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 50 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 51 | 52 | 53 | 54 | ########################## training the classifier using both original and interpolated data using the source domains: (Caltech101, SUN09, VOC2007) domains => indexes: "0" "2" "3" 55 | source_domains="023" 56 | test_env=1 57 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 58 | 59 | ### first seed 60 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 61 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 62 | 63 | 64 | ### second seed 65 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 66 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 67 | 68 | 69 | ### third seed 70 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 71 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 72 | 73 | 74 | 75 | ########################## training the classifier using both original and interpolated data using the source domains: (LabelMe, SUN09, VOC2007) domains => indexes: "1" "2" "3" 76 | source_domains="123" 77 | test_env=0 78 | gen_data_dir="save/dm/${dataset}/${source_domains}/generation" 79 | 80 | ### first seed 81 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed0/image_predictions.csv" 82 | python train_cls.py ${dataset}0_${source_domains} --dataset ${dataset} --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 83 | 84 | 85 | ### second seed 86 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed1/image_predictions.csv" 87 | python train_cls.py ${dataset}1_${source_domains} --dataset ${dataset} --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class 
${gen_num_per_class} --gen_only_correct 88 | 89 | 90 | ### third seed 91 | gen_csv_dir="save/eval/${dataset}/${source_domains}/seed2/image_predictions.csv" 92 | python train_cls.py ${dataset}2_${source_domains} --dataset ${dataset} --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --test_envs ${test_env} --data_dir ${data_dir} --work_dir $save_dir --use_gen --gen_data_dir ${gen_data_dir} --gen_csv_dir ${gen_csv_dir} --gen_num_per_class ${gen_num_per_class} --gen_only_correct 93 | 94 | -------------------------------------------------------------------------------- /ldm/data/simple.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | from omegaconf import DictConfig, ListConfig 4 | import torch 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import json 8 | from PIL import Image 9 | from torchvision import transforms 10 | from einops import rearrange 11 | from ldm.util import instantiate_from_config 12 | 13 | def make_multi_folder_data(paths, caption_files=None, **kwargs): 14 | """Make a concat dataset from multiple folders 15 | Don't suport captions yet 16 | 17 | If paths is a list, that's ok, if it's a Dict interpret it as: 18 | k=folder v=n_times to repeat that 19 | """ 20 | list_of_paths = [] 21 | if isinstance(paths, (Dict, DictConfig)): 22 | assert caption_files is None, \ 23 | "Caption files not yet supported for repeats" 24 | for folder_path, repeats in paths.items(): 25 | list_of_paths.extend([folder_path]*repeats) 26 | paths = list_of_paths 27 | 28 | if caption_files is not None: 29 | datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] 30 | else: 31 | datasets = [FolderData(p, **kwargs) for p in paths] 32 | return torch.utils.data.ConcatDataset(datasets) 33 | 34 | class FolderData(Dataset): 35 | def __init__(self, 36 | root_dir, 37 | caption_file=None, 38 | image_transforms=[], 39 | ext="jpg", 40 | default_caption="", 41 | postprocess=None, 42 | ) -> None: 43 | """Create a dataset from a folder of images. 44 | If you pass in a root directory it will be searched for images 45 | ending in ext (ext can be a list) 46 | """ 47 | self.root_dir = Path(root_dir) 48 | self.default_caption = default_caption 49 | if isinstance(postprocess, DictConfig): 50 | postprocess = instantiate_from_config(postprocess) 51 | self.postprocess = postprocess 52 | if caption_file is not None: 53 | with open(caption_file, "rt") as f: 54 | ext = Path(caption_file).suffix.lower() 55 | if ext == ".json": 56 | captions = json.load(f) 57 | elif ext == ".jsonl": 58 | lines = f.readlines() 59 | lines = [json.loads(x) for x in lines] 60 | captions = {x["file_name"]: x["text"].strip("\n") for x in lines} 61 | else: 62 | raise ValueError(f"Unrecognised format: {ext}") 63 | self.captions = captions 64 | else: 65 | self.captions = None 66 | 67 | if not isinstance(ext, (tuple, list, ListConfig)): 68 | ext = [ext] 69 | 70 | # Only used if there is no caption file 71 | self.paths = [] 72 | for e in ext: 73 | self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) 74 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 75 | image_transforms.extend([transforms.ToTensor(), 76 | transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) 77 | image_transforms = transforms.Compose(image_transforms) 78 | self.tform = image_transforms 79 | 80 | # assert all(['full/' + str(x.name) in self.captions for x in self.paths]) 81 | 82 | def __len__(self): 83 | if self.captions is not None: 84 | return len(self.captions.keys()) 85 | else: 86 | return len(self.paths) 87 | 88 | def __getitem__(self, index): 89 | if self.captions is not None: 90 | chosen = list(self.captions.keys())[index] 91 | caption = self.captions.get(chosen, None) 92 | if caption is None: 93 | caption = self.default_caption 94 | im = Image.open(self.root_dir/chosen) 95 | else: 96 | im = Image.open(self.paths[index]) 97 | 98 | im = self.process_im(im) 99 | data = {"image": im} 100 | if self.captions is not None: 101 | data["txt"] = caption 102 | else: 103 | data["txt"] = self.default_caption 104 | 105 | if self.postprocess is not None: 106 | data = self.postprocess(data) 107 | return data 108 | 109 | def process_im(self, im): 110 | im = im.convert("RGB") 111 | return self.tform(im) 112 | 113 | def hf_dataset( 114 | name, 115 | image_transforms=[], 116 | image_column="image", 117 | text_column="text", 118 | split='train', 119 | image_key='image', 120 | caption_key='txt', 121 | ): 122 | """Make huggingface dataset with appropriate list of transforms applied 123 | """ 124 | ds = load_dataset(name, split=split) 125 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 126 | image_transforms.extend([transforms.ToTensor(), 127 | transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) 128 | tform = transforms.Compose(image_transforms) 129 | 130 | assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" 131 | assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" 132 | 133 | def pre_process(examples): 134 | processed = {} 135 | processed[image_key] = [tform(im) for im in examples[image_column]] 136 | processed[caption_key] = examples[text_column] 137 | return processed 138 | 139 | ds.set_transform(pre_process) 140 | return ds 141 | 142 | class TextOnly(Dataset): 143 | def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): 144 | """Returns only captions with dummy images""" 145 | self.output_size = output_size 146 | self.image_key = image_key 147 | self.caption_key = caption_key 148 | if isinstance(captions, Path): 149 | self.captions = self._load_caption_file(captions) 150 | else: 151 | self.captions = captions 152 | 153 | if n_gpus > 1: 154 | # hack to make sure that all the captions appear on each gpu 155 | repeated = [n_gpus*[x] for x in self.captions] 156 | self.captions = [] 157 | [self.captions.extend(x) for x in repeated] 158 | 159 | def __len__(self): 160 | return len(self.captions) 161 | 162 | def __getitem__(self, index): 163 | dummy_im = torch.zeros(3, self.output_size, self.output_size) 164 | dummy_im = rearrange(dummy_im * 2. 
- 1., 'c h w -> h w c') 165 | return {self.image_key: dummy_im, self.caption_key: self.captions[index]} 166 | 167 | def _load_caption_file(self, filename): 168 | with open(filename, 'rt') as f: 169 | captions = f.readlines() 170 | return [x.strip('\n') for x in captions] 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FDS: Feedback-guided Domain Synthesis with Multi-Source Conditional Diffusion Models for Domain Generalization 2 | 3 | The official implementation of our paper "Feedback-guided Domain Synthesis with Multi-Source Conditional Diffusion Models for Domain Generalization". 4 | 5 | 6 | ## Method 7 | 8 |
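The interpolation training scripts in `bash/train_cls/` pass `--use_gen`, `--gen_data_dir`, `--gen_csv_dir`, `--gen_num_per_class`, and `--gen_only_correct` to `train_cls.py`, so the generated (interpolated) images are screened by classifier feedback before being mixed with the original source-domain data. The sketch below illustrates what such a selection step could look like; it is not the repository's actual loader, and the CSV column names (`image_path`, `label`, `prediction`) are assumptions made for the example.

```python
import csv
from collections import defaultdict


def select_generated_images(csv_path, num_per_class, only_correct=True):
    """Pick up to `num_per_class` generated images per class from a prediction CSV.

    The column names used below (image_path, label, prediction) are assumptions
    made for this illustration, not necessarily those written by this repository.
    """
    per_class = defaultdict(list)
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            # feedback filter: optionally keep only correctly classified images
            if only_correct and row["prediction"] != row["label"]:
                continue
            if len(per_class[row["label"]]) < num_per_class:
                per_class[row["label"]].append(row["image_path"])
    return [path for paths in per_class.values() for path in paths]


# Hypothetical call mirroring the OfficeHome script above:
# select_generated_images("save/eval/OfficeHome/012/seed0/image_predictions.csv",
#                         num_per_class=70, only_correct=True)
```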
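Each training script sweeps every leave-one-domain-out split of the four benchmark domains, running three trial seeds per split. For reference, the snippet below enumerates those splits using the domain-index convention stated in the PACS script comments (0: art_painting, 1: cartoon, 2: photo, 3: sketch).

```python
# Enumerate leave-one-domain-out splits with the PACS index convention used in the scripts.
domains = ["art_painting", "cartoon", "photo", "sketch"]
for test_env, held_out in enumerate(domains):
    source_idx = [i for i in range(len(domains)) if i != test_env]
    source_domains = "".join(str(i) for i in source_idx)  # e.g. "123" when testing on art_painting
    sources = ", ".join(domains[i] for i in source_idx)
    print(f"source_domains={source_domains} ({sources})  test_env={test_env} ({held_out})")
```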