├── .gitignore ├── README.md ├── configs ├── kl.yaml └── mvtec.yaml ├── environment.yaml ├── rec_network ├── data │ ├── __init__.py │ ├── base.py │ ├── mvtec.py │ └── perlin.py ├── lr_scheduler.py ├── main.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── scripts ├── download_dataset.sh └── mvtec.py └── seg_network ├── data_loader.py ├── loss.py ├── model_unet.py ├── perlin.py ├── tensorboard_visualizer.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffAD 2 | [ICCV2023] Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model 3 | 4 | ``` 5 | @inproceedings{zhang2023unsupervised, 6 | title={Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model}, 7 | author={Zhang, Xinyi and Li, Naiqi and Li, Jiawei and Dai, Tao and Jiang, Yong and Xia, Shu-Tao}, 8 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 9 | pages={6782--6791}, 10 | year={2023} 11 | } 12 | ``` 13 | 14 | ## Method overview 15 | image 16 | 17 | ## Installation 18 | ``` 19 | conda env create -f environment.yaml 20 | conda activate DiffAD 21 | ``` 22 | 23 | ## Dataset 24 | Following DRAEM, we use the MVTec-AD and DTD datasets. You can run the download_dataset.sh script from the project directory to download the MVTec and the DTD datasets to the datasets folder in the project directory: 25 | ``` 26 | ./scripts/download_dataset.sh 27 | ``` 28 | 29 | ## Training 30 | ### Reconstruction sub-network 31 | The reconstruction sub-network is based on the latent diffusion model.
32 | #### Training Auto-encoder 33 | ``` 34 | cd rec_network 35 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/kl.yaml -t --gpus 0, 36 | ``` 37 | #### Training LDMs 38 | ``` 39 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/mvtec.yaml -t --gpus 0, --max_epochs 4000 40 | ``` 41 | 42 | ### Discriminative sub-network 43 | ``` 44 | cd seg_network 45 | CUDA_VISIBLE_DEVICES=0 python train.py --gpu_id 0 --lr 0.001 --bs 32 --epochs 700 --data_path ./datasets/mvtec/ --anomaly_source_path ./datasets/dtd/images/ --checkpoint_path ./checkpoints/obj_name --log_path ./logs/ 46 | ``` 47 | 48 | ## Evaluating 49 | ### Reconstruction performance 50 | After training the reconstruction sub-network, you can test the reconstruction performance with the anomalous inputs: 51 | ``` 52 | python scripts/mvtec.py 53 | ``` 54 | For some samples with severe deformations, such as missing transistors, you can add some noise to the anomalous conditions to adjust the sampling. 55 | 56 | ### Anomaly segmentation 57 | ``` 58 | cd seg_network 59 | python test.py --gpu_id 0 --base_model_name "seg_network" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/obj_name/ 60 | ``` 61 | 62 | 63 | -------------------------------------------------------------------------------- /configs/kl.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: rec_network.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: rec_network.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.5 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 |
params: 29 | batch_size: 8 30 | num_workers: 0 31 | wrap: false 32 | train: 33 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset 34 | params: 35 | root_dir: './datasets/mvtec/grid/train/good' 36 | anomaly_source_path: './datasets/dtd/images' 37 | resize_shape: [256, 256] 38 | validation: 39 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset 40 | params: 41 | root_dir: './datasets/mvtec/grid/train/good' 42 | anomaly_source_path: './datasets/dtd/images' 43 | resize_shape: [ 256, 256 ] 44 | lightning: 45 | callbacks: 46 | image_logger: 47 | target: main.ImageLogger 48 | params: 49 | batch_frequency: 1000 50 | max_images: 8 51 | increase_log_steps: True 52 | 53 | trainer: 54 | benchmark: True 55 | accumulate_grad_batches: 2 56 | -------------------------------------------------------------------------------- /configs/mvtec.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: rec_network.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | #ckpt_path: modify the ckpt_path of rec_network when training the seg_network 6 | linear_start: 0.0015 7 | linear_end: 0.02 8 | num_timesteps_cond: 1 9 | log_every_t: 100 10 | timesteps: 1000 11 | first_stage_key: image 12 | cond_stage_key: augmented_image 13 | image_size: 32 14 | channels: 4 15 | concat_mode: true 16 | cond_stage_trainable: false 17 | conditioning_key: concat 18 | monitor: val/loss_simple_ema 19 | unet_config: 20 | target: rec_network.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 32 23 | in_channels: 8 24 | out_channels: 4 25 | model_channels: 256 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 4 35 | num_head_channels: 32 36 | 37 | first_stage_config: 38 | target: rec_network.models.autoencoder.AutoencoderKL 39 | params: 40 | embed_dim: 4 41 | monitor: "val/rec_loss" 42 | ckpt_path: "./VAE/bottle.ckpt" #TODO: 
modify the ckpt_path of VAE 43 | ddconfig: 44 | double_z: True 45 | z_channels: 4 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 51 | num_res_blocks: 2 52 | attn_resolutions: [ ] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_first_stage__ 57 | 58 | data: 59 | target: main.DataModuleFromConfig 60 | params: 61 | batch_size: 32 62 | num_workers: 0 63 | wrap: false 64 | train: 65 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset 66 | params: 67 | root_dir: './datasets/mvtec/bottle/train/good' #TODO: modify the path of training samples 68 | anomaly_source_path: './datasets/dtd/images' 69 | resize_shape: 70 | - 256 71 | - 256 72 | 73 | 74 | lightning: 75 | callbacks: 76 | metrics_over_trainsteps_checkpoint: 77 | target: pytorch_lightning.callbacks.ModelCheckpoint 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 1000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True 87 | 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: DiffAD 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - transformers==4.3.1 25 | - imgaug 26 | -------------------------------------------------------------------------------- /rec_network/data/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/data/__init__.py -------------------------------------------------------------------------------- /rec_network/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /rec_network/data/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | import cv2 6 | import glob 7 | import imgaug.augmenters as iaa 8 | from perlin import rand_perlin_2d_np 9 | 10 | obj_list = ['capsule', 11 | 'bottle', 12 | 'carpet', 13 | 'leather', 14 | 'pill', 15 | 'transistor', 16 | 'tile', 17 | 'cable', 18 | 'zipper', 19 | 'toothbrush', 20 | 'metal_nut', 21 | 'hazelnut', 22 | 'screw', 23 | 'grid', 24 | 'wood'] 25 | 26 | class MVTecDRAEMTestDataset(Dataset): 27 | 28 | def __init__(self, root_dir, resize_shape=None): 29 | self.root_dir = root_dir 30 | self.images = sorted(glob.glob(root_dir+"/*/*.png")) 31 | 32 | # self.images = [] 33 | # for obj in obj_list: 34 | # 
self.images += sorted(glob.glob(root_dir + obj + "/test/*/*.png")) 35 | 36 | self.resize_shape=resize_shape 37 | 38 | def __len__(self): 39 | return len(self.images) 40 | 41 | def transform_image(self, image_path, mask_path): 42 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 43 | if mask_path is not None: 44 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 45 | else: 46 | mask = np.zeros((image.shape[0],image.shape[1])) 47 | if self.resize_shape != None: 48 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0])) 49 | mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0])) 50 | 51 | image = image / 255.0 52 | mask = mask / 255.0 53 | 54 | image = np.array(image).reshape((image.shape[0], image.shape[1], 3)).astype(np.float32) 55 | mask = np.array(mask).reshape((mask.shape[0], mask.shape[1], 1)).astype(np.float32) 56 | 57 | image = np.transpose(image, (2, 0, 1)) 58 | mask = np.transpose(mask, (2, 0, 1)) 59 | return image, mask 60 | 61 | def __getitem__(self, idx): 62 | if torch.is_tensor(idx): 63 | idx = idx.tolist() 64 | 65 | img_path = self.images[idx] 66 | dir_path, file_name = os.path.split(img_path) 67 | base_dir = os.path.basename(dir_path) 68 | if base_dir == 'good': 69 | image, mask = self.transform_image(img_path, None) 70 | has_anomaly = np.array([0], dtype=np.float32) 71 | else: 72 | mask_path = os.path.join(dir_path, '../../ground_truth/') 73 | mask_path = os.path.join(mask_path, base_dir) 74 | mask_file_name = file_name.split(".")[0]+"_mask.png" 75 | mask_path = os.path.join(mask_path, mask_file_name) 76 | image, mask = self.transform_image(img_path, mask_path) 77 | has_anomaly = np.array([1], dtype=np.float32) 78 | 79 | sample = {'image': image, 'has_anomaly': has_anomaly,'mask': mask, 'idx': idx} 80 | 81 | return sample 82 | 83 | 84 | 85 | class MVTecDRAEMTrainDataset(Dataset): 86 | 87 | def __init__(self, root_dir, anomaly_source_path, resize_shape=None): 88 | """ 89 | Args: 90 | root_dir (string): 
Directory with all the images. 91 | transform (callable, optional): Optional transform to be applied 92 | on a sample. 93 | """ 94 | 95 | self.root_dir = root_dir 96 | self.resize_shape=resize_shape 97 | 98 | # /data / zhangxinyi / dataset / mvtec / toothbrush / train / good 99 | # self.image_paths = [] 100 | # for obj in obj_list: 101 | # obj = 'cable' 102 | # self.image_paths += sorted(glob.glob(root_dir + obj + "/train/good/*.png")) 103 | # print(len(self.image_paths)) 104 | 105 | self.image_paths = sorted(glob.glob(root_dir + "/*.png")) 106 | 107 | self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path+"/*/*.jpg")) 108 | 109 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True), 110 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 111 | iaa.pillike.EnhanceSharpness(), 112 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 113 | iaa.Solarize(0.5, threshold=(32,128)), 114 | iaa.Posterize(), 115 | iaa.Invert(), 116 | iaa.pillike.Autocontrast(), 117 | iaa.pillike.Equalize(), 118 | iaa.Affine(rotate=(-45, 45)) 119 | ] 120 | 121 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 122 | 123 | 124 | def __len__(self): 125 | return len(self.image_paths) 126 | 127 | 128 | def randAugmenter(self): 129 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) 130 | aug = iaa.Sequential([self.augmenters[aug_ind[0]], 131 | self.augmenters[aug_ind[1]], 132 | self.augmenters[aug_ind[2]]] 133 | ) 134 | return aug 135 | 136 | def augment_image(self, image, anomaly_source_path): 137 | aug = self.randAugmenter() 138 | perlin_scale = 6 139 | min_perlin_scale = 0 140 | anomaly_source_img = cv2.imread(anomaly_source_path) 141 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize_shape[1], self.resize_shape[0])) 142 | 143 | anomaly_img_augmented = aug(image=anomaly_source_img) 144 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 145 | perlin_scaley = 2 ** 
(torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 146 | 147 | perlin_noise = rand_perlin_2d_np((self.resize_shape[0], self.resize_shape[1]), (perlin_scalex, perlin_scaley)) 148 | perlin_noise = self.rot(image=perlin_noise) 149 | threshold = 0.5 150 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 151 | perlin_thr = np.expand_dims(perlin_thr, axis=2) 152 | 153 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 154 | 155 | beta = torch.rand(1).numpy()[0] * 0.8 156 | 157 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * ( 158 | perlin_thr) 159 | 160 | no_anomaly = torch.rand(1).numpy()[0] 161 | if no_anomaly > 0.5: 162 | image = image.astype(np.float32) 163 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32) 164 | else: 165 | augmented_image = augmented_image.astype(np.float32) 166 | msk = (perlin_thr).astype(np.float32) 167 | augmented_image = msk * augmented_image + (1-msk)*image 168 | has_anomaly = 1.0 169 | if np.sum(msk) == 0: 170 | has_anomaly=0.0 171 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32) 172 | 173 | def transform_image(self, image_path, anomaly_source_path): 174 | image = cv2.imread(image_path) 175 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0])) 176 | 177 | do_aug_orig = torch.rand(1).numpy()[0] > 0.7 178 | if do_aug_orig: 179 | image = self.rot(image=image) 180 | 181 | image = np.array(image).reshape((image.shape[0], image.shape[1], image.shape[2])).astype(np.float32) / 255.0 182 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path) 183 | augmented_image = np.transpose(augmented_image, (2, 0, 1)) 184 | image = np.transpose(image, (2, 0, 1)) 185 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1)) 186 | return image, augmented_image, anomaly_mask, has_anomaly 187 | 188 | def 
__getitem__(self, idx): 189 | idx = torch.randint(0, len(self.image_paths), (1,)).item() 190 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item() 191 | image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(self.image_paths[idx], 192 | self.anomaly_source_paths[anomaly_source_idx]) 193 | sample = {'image': image, "anomaly_mask": anomaly_mask, 194 | 'augmented_image': augmented_image, 'has_anomaly': has_anomaly, 'idx': idx} 195 | return sample 196 | -------------------------------------------------------------------------------- /rec_network/data/perlin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | def lerp_np(x,y,w): 6 | fin_out = (y-x)*w + x 7 | return fin_out 8 | 9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 10 | noise = np.zeros(shape) 11 | frequency = 1 12 | amplitude = 1 13 | for _ in range(octaves): 14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 15 | frequency *= 2 16 | amplitude *= persistence 17 | return noise 18 | 19 | 20 | def generate_perlin_noise_2d(shape, res): 21 | def f(t): 22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 23 | 24 | delta = (res[0] / shape[0], res[1] / shape[1]) 25 | d = (shape[0] // res[0], shape[1] // res[1]) 26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 27 | # Gradients 28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 29 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 34 | # Ramps 35 | n00 = np.sum(grid * g00, 2) 36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 37 | n01 
= np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 39 | # Interpolation 40 | t = f(grid) 41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 44 | 45 | 46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 47 | delta = (res[0] / shape[0], res[1] / shape[1]) 48 | d = (shape[0] // res[0], shape[1] // res[1]) 49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 50 | 51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 54 | 55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 56 | dot = lambda grad, shift: ( 57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 59 | 60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 64 | t = fade(grid[:shape[0], :shape[1]]) 65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 66 | 67 | 68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 69 | delta = (res[0] / shape[0], res[1] / shape[1]) 70 | d = (shape[0] // res[0], shape[1] // res[1]) 71 | 72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 74 | gradients = 
torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 75 | 76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 77 | 0).repeat_interleave( 78 | d[1], 1) 79 | dot = lambda grad, shift: ( 80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 82 | 83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 84 | 85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 88 | t = fade(grid[:shape[0], :shape[1]]) 89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 90 | 91 | 92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 93 | noise = torch.zeros(shape) 94 | frequency = 1 95 | amplitude = 1 96 | for _ in range(octaves): 97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) 98 | frequency *= 2 99 | amplitude *= persistence 100 | return noise -------------------------------------------------------------------------------- /rec_network/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /rec_network/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | from PIL import Image 6 | import numpy as np 7 | 8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 9 | 10 | from rec_network.modules.diffusionmodules.model import Encoder, Decoder 11 | from rec_network.modules.distributions.distributions import DiagonalGaussianDistribution 12 | 13 | from rec_network.util import instantiate_from_config 14 | 15 | 16 | class VQModel(pl.LightningModule): 17 | def __init__(self, 18 | ddconfig, 19 | lossconfig, 20 | n_embed, 21 | embed_dim, 22 | ckpt_path=None, 23 | ignore_keys=[], 24 | image_key="image", 25 | colorize_nlabels=None, 26 | monitor=None, 27 | batch_resize_range=None, 28 | scheduler_config=None, 29 | lr_g_factor=1.0, 30 | remap=None, 31 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 32 | use_ema=False 33 | ): 34 | super().__init__() 35 | self.embed_dim = embed_dim 36 | self.n_embed = n_embed 37 | self.image_key = image_key 38 | self.encoder = Encoder(**ddconfig) 39 | self.decoder = Decoder(**ddconfig) 40 | self.loss = instantiate_from_config(lossconfig) 41 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 42 | remap=remap, 43 | sane_index_shape=sane_index_shape) 44 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 45 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 46 | if colorize_nlabels is not None: 47 | assert type(colorize_nlabels)==int 48 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 49 | if monitor is not None: 50 | self.monitor = monitor 51 | self.batch_resize_range = batch_resize_range 52 | if self.batch_resize_range is not None: 53 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") 54 | 55 | self.use_ema = use_ema 56 | if self.use_ema: 57 | self.model_ema = LitEma(self) 58 | print(f"Keeping EMAs 
of {len(list(self.model_ema.buffers()))}.") 59 | 60 | if ckpt_path is not None: 61 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 62 | self.scheduler_config = scheduler_config 63 | self.lr_g_factor = lr_g_factor 64 | 65 | @contextmanager 66 | def ema_scope(self, context=None): 67 | if self.use_ema: 68 | self.model_ema.store(self.parameters()) 69 | self.model_ema.copy_to(self) 70 | if context is not None: 71 | print(f"{context}: Switched to EMA weights") 72 | try: 73 | yield None 74 | finally: 75 | if self.use_ema: 76 | self.model_ema.restore(self.parameters()) 77 | if context is not None: 78 | print(f"{context}: Restored training weights") 79 | 80 | def init_from_ckpt(self, path, ignore_keys=list()): 81 | sd = torch.load(path, map_location="cpu")["state_dict"] 82 | keys = list(sd.keys()) 83 | for k in keys: 84 | for ik in ignore_keys: 85 | if k.startswith(ik): 86 | print("Deleting key {} from state_dict.".format(k)) 87 | del sd[k] 88 | missing, unexpected = self.load_state_dict(sd, strict=False) 89 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 90 | if len(missing) > 0: 91 | print(f"Missing Keys: {missing}") 92 | print(f"Unexpected Keys: {unexpected}") 93 | 94 | def on_train_batch_end(self, *args, **kwargs): 95 | if self.use_ema: 96 | self.model_ema(self) 97 | 98 | def encode(self, x): 99 | h = self.encoder(x) 100 | h = self.quant_conv(h) 101 | quant, emb_loss, info = self.quantize(h) 102 | return quant, emb_loss, info 103 | 104 | def encode_to_prequant(self, x): 105 | h = self.encoder(x) 106 | h = self.quant_conv(h) 107 | return h 108 | 109 | def decode(self, quant): 110 | quant = self.post_quant_conv(quant) 111 | dec = self.decoder(quant) 112 | return dec 113 | 114 | def decode_code(self, code_b): 115 | quant_b = self.quantize.embed_code(code_b) 116 | dec = self.decode(quant_b) 117 | return dec 118 | 119 | def forward(self, input, return_pred_indices=False): 120 | quant, diff, (_,_,ind) = 
self.encode(input) 121 | dec = self.decode(quant) 122 | if return_pred_indices: 123 | return dec, diff, ind 124 | return dec, diff 125 | 126 | def get_input(self, batch, k): 127 | x = batch[k] 128 | if len(x.shape) == 3: 129 | x = x[..., None] 130 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 131 | x = x.to(memory_format=torch.contiguous_format).float() 132 | if self.batch_resize_range is not None: 133 | lower_size = self.batch_resize_range[0] 134 | upper_size = self.batch_resize_range[1] 135 | if self.global_step <= 4: 136 | # do the first few batches with max size to avoid later oom 137 | new_resize = upper_size 138 | else: 139 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) 140 | if new_resize != x.shape[2]: 141 | x = F.interpolate(x, size=new_resize, mode="bicubic") 142 | x = x.detach() 143 | return x 144 | 145 | def training_step(self, batch, batch_idx, optimizer_idx): 146 | # https://github.com/pytorch/pytorch/issues/37142 147 | # try not to fool the heuristics 148 | x = self.get_input(batch, self.image_key) 149 | 150 | # outpath = './test/input.jpg' 151 | # sample = x.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 152 | # Image.fromarray(sample.astype(np.uint8)).save(outpath) 153 | 154 | xrec, qloss, ind = self(x, return_pred_indices=True) 155 | 156 | if optimizer_idx == 0: 157 | # autoencode 158 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 159 | last_layer=self.get_last_layer(), split="train" 160 | ) 161 | 162 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 163 | return aeloss 164 | 165 | if optimizer_idx == 1: 166 | # discriminator 167 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 168 | last_layer=self.get_last_layer(), split="train") 169 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 170 | return discloss 171 | 172 | def 
validation_step(self, batch, batch_idx): 173 | log_dict = self._validation_step(batch, batch_idx) 174 | with self.ema_scope(): 175 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") 176 | return log_dict 177 | 178 | def _validation_step(self, batch, batch_idx, suffix=""): 179 | x = self.get_input(batch, self.image_key) 180 | xrec, qloss, ind = self(x, return_pred_indices=True) 181 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, 182 | self.global_step, 183 | last_layer=self.get_last_layer(), 184 | split="val"+suffix 185 | ) 186 | 187 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, 188 | self.global_step, 189 | last_layer=self.get_last_layer(), 190 | split="val"+suffix 191 | ) 192 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] 193 | self.log(f"val{suffix}/rec_loss", rec_loss, 194 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 195 | self.log(f"val{suffix}/aeloss", aeloss, 196 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 197 | 198 | del log_dict_ae[f"val{suffix}/rec_loss"] 199 | self.log_dict(log_dict_ae) 200 | self.log_dict(log_dict_disc) 201 | return self.log_dict 202 | 203 | def configure_optimizers(self): 204 | lr_d = self.learning_rate 205 | lr_g = self.lr_g_factor*self.learning_rate 206 | print("lr_d", lr_d) 207 | print("lr_g", lr_g) 208 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 209 | list(self.decoder.parameters())+ 210 | list(self.quantize.parameters())+ 211 | list(self.quant_conv.parameters())+ 212 | list(self.post_quant_conv.parameters()), 213 | lr=lr_g, betas=(0.5, 0.9)) 214 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 215 | lr=lr_d, betas=(0.5, 0.9)) 216 | 217 | if self.scheduler_config is not None: 218 | scheduler = instantiate_from_config(self.scheduler_config) 219 | 220 | print("Setting up LambdaLR scheduler...") 221 | scheduler = [ 222 | { 223 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 224 | 
'interval': 'step', 225 | 'frequency': 1 226 | }, 227 | { 228 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 229 | 'interval': 'step', 230 | 'frequency': 1 231 | }, 232 | ] 233 | return [opt_ae, opt_disc], scheduler 234 | return [opt_ae, opt_disc], [] 235 | 236 | def get_last_layer(self): 237 | return self.decoder.conv_out.weight 238 | 239 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): 240 | log = dict() 241 | x = self.get_input(batch, self.image_key) 242 | x = x.to(self.device) 243 | if only_inputs: 244 | log["inputs"] = x 245 | return log 246 | xrec, _ = self(x) 247 | if x.shape[1] > 3: 248 | # colorize with random projection 249 | assert xrec.shape[1] > 3 250 | x = self.to_rgb(x) 251 | xrec = self.to_rgb(xrec) 252 | log["inputs"] = x 253 | log["reconstructions"] = xrec 254 | if plot_ema: 255 | with self.ema_scope(): 256 | xrec_ema, _ = self(x) 257 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) 258 | log["reconstructions_ema"] = xrec_ema 259 | return log 260 | 261 | def to_rgb(self, x): 262 | assert self.image_key == "segmentation" 263 | if not hasattr(self, "colorize"): 264 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 265 | x = F.conv2d(x, weight=self.colorize) 266 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
267 | return x 268 | 269 | 270 | class VQModelInterface(VQModel): 271 | def __init__(self, embed_dim, *args, **kwargs): 272 | super().__init__(embed_dim=embed_dim, *args, **kwargs) 273 | self.embed_dim = embed_dim 274 | 275 | def encode(self, x): 276 | h = self.encoder(x) 277 | h = self.quant_conv(h) 278 | # print("!!!!encoding: ", h.shape) 279 | return h 280 | 281 | def decode(self, h, force_not_quantize=False): 282 | # also go through quantization layer 283 | if not force_not_quantize: 284 | quant, emb_loss, info = self.quantize(h) 285 | else: 286 | quant = h 287 | quant = self.post_quant_conv(quant) 288 | dec = self.decoder(quant) 289 | return dec 290 | 291 | 292 | class AutoencoderKL(pl.LightningModule): 293 | def __init__(self, 294 | ddconfig, 295 | lossconfig, 296 | embed_dim, 297 | ckpt_path=None, 298 | ignore_keys=[], 299 | image_key="image", 300 | colorize_nlabels=None, 301 | monitor=None, 302 | ): 303 | super().__init__() 304 | self.image_key = image_key 305 | self.encoder = Encoder(**ddconfig) 306 | self.decoder = Decoder(**ddconfig) 307 | self.loss = instantiate_from_config(lossconfig) 308 | assert ddconfig["double_z"] 309 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 310 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 311 | self.embed_dim = embed_dim 312 | if colorize_nlabels is not None: 313 | assert type(colorize_nlabels)==int 314 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 315 | if monitor is not None: 316 | self.monitor = monitor 317 | if ckpt_path is not None: 318 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 319 | 320 | def init_from_ckpt(self, path, ignore_keys=list()): 321 | sd = torch.load(path, map_location="cpu")["state_dict"] 322 | keys = list(sd.keys()) 323 | for k in keys: 324 | for ik in ignore_keys: 325 | if k.startswith(ik): 326 | print("Deleting key {} from state_dict.".format(k)) 327 | del sd[k] 328 | self.load_state_dict(sd, 
strict=False) 329 | print(f"Restored from {path}") 330 | 331 | def encode(self, x): 332 | h = self.encoder(x) 333 | moments = self.quant_conv(h) 334 | posterior = DiagonalGaussianDistribution(moments) 335 | return posterior 336 | 337 | def decode(self, z): 338 | z = self.post_quant_conv(z) 339 | dec = self.decoder(z) 340 | return dec 341 | 342 | def forward(self, input, sample_posterior=True): 343 | posterior = self.encode(input) 344 | if sample_posterior: 345 | z = posterior.sample() 346 | else: 347 | z = posterior.mode() 348 | dec = self.decode(z) 349 | return dec, posterior 350 | 351 | def get_input(self, batch, k): 352 | x = batch[k] 353 | if len(x.shape) == 3: 354 | x = x[..., None] 355 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 356 | return x 357 | 358 | def training_step(self, batch, batch_idx, optimizer_idx): 359 | inputs = self.get_input(batch, self.image_key) 360 | reconstructions, posterior = self(inputs) ## forward 361 | 362 | if optimizer_idx == 0: 363 | # train encoder+decoder+logvar 364 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 365 | last_layer=self.get_last_layer(), split="train") 366 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 367 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 368 | return aeloss 369 | 370 | if optimizer_idx == 1: 371 | # train the discriminator 372 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 373 | last_layer=self.get_last_layer(), split="train") 374 | 375 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 376 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 377 | return discloss 378 | 379 | def validation_step(self, batch, batch_idx): 380 | inputs = self.get_input(batch, self.image_key) 381 | reconstructions, 
posterior = self(inputs) 382 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 383 | last_layer=self.get_last_layer(), split="val") 384 | 385 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 386 | last_layer=self.get_last_layer(), split="val") 387 | 388 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 389 | self.log_dict(log_dict_ae) 390 | self.log_dict(log_dict_disc) 391 | return self.log_dict 392 | 393 | def configure_optimizers(self): 394 | lr = self.learning_rate 395 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 396 | list(self.decoder.parameters())+ 397 | list(self.quant_conv.parameters())+ 398 | list(self.post_quant_conv.parameters()), 399 | lr=lr, betas=(0.5, 0.9)) 400 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 401 | lr=lr, betas=(0.5, 0.9)) 402 | return [opt_ae, opt_disc], [] 403 | 404 | def get_last_layer(self): 405 | return self.decoder.conv_out.weight 406 | 407 | @torch.no_grad() 408 | def log_images(self, batch, only_inputs=False, **kwargs): 409 | log = dict() 410 | x = self.get_input(batch, self.image_key) 411 | x = x.to(self.device) 412 | if not only_inputs: 413 | xrec, posterior = self(x) 414 | if x.shape[1] > 3: 415 | # colorize with random projection 416 | assert xrec.shape[1] > 3 417 | x = self.to_rgb(x) 418 | xrec = self.to_rgb(xrec) 419 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 420 | log["reconstructions"] = xrec 421 | log["inputs"] = x 422 | return log 423 | 424 | def to_rgb(self, x): 425 | assert self.image_key == "segmentation" 426 | if not hasattr(self, "colorize"): 427 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 428 | x = F.conv2d(x, weight=self.colorize) 429 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
430 | return x 431 | 432 | 433 | class IdentityFirstStage(torch.nn.Module): 434 | def __init__(self, *args, vq_interface=False, **kwargs): 435 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 436 | super().__init__() 437 | 438 | def encode(self, x, *args, **kwargs): 439 | return x 440 | 441 | def decode(self, x, *args, **kwargs): 442 | return x 443 | 444 | def quantize(self, x, *args, **kwargs): 445 | if self.vq_interface: 446 | return x, None, [None, None, None] 447 | return x 448 | 449 | def forward(self, x, *args, **kwargs): 450 | return x 451 | -------------------------------------------------------------------------------- /rec_network/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/models/diffusion/__init__.py -------------------------------------------------------------------------------- /rec_network/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from rec_network.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from rec_network.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class 
NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored 
from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | 
x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, 
k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return 
optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /rec_network/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from rec_network.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class DDIMSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | 
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | self.register_buffer('betas', to_torch(self.model.betas)) 32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 34 | 35 | # calculations for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 41 | 42 | # ddim sampling parameters 43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 44 | ddim_timesteps=self.ddim_timesteps, 45 | eta=ddim_eta,verbose=verbose) 46 | self.register_buffer('ddim_sigmas', ddim_sigmas) 47 | self.register_buffer('ddim_alphas', ddim_alphas) 48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 54 | 55 | @torch.no_grad() 56 | def sample(self, 57 | S, 58 | batch_size, 59 | shape, 60 | conditioning=None, 61 | callback=None, 62 | normals_sequence=None, 63 | img_callback=None, 64 | quantize_x0=False, 65 | eta=0., 66 | mask=None, 67 | x0=None, 68 | temperature=1., 69 | noise_dropout=0., 70 | score_corrector=None, 71 | corrector_kwargs=None, 72 | verbose=True, 73 | x_T=None, 74 | log_every_t=100, 75 | unconditional_guidance_scale=1., 76 | unconditional_conditioning=None, 77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 78 | **kwargs 79 | ): 80 | if conditioning is not None: 81 | if isinstance(conditioning, dict): 82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 83 | if cbs != batch_size: 84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 85 | else: 86 | if conditioning.shape[0] != batch_size: 87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 88 | 89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 90 | # sampling 91 | C, H, W = shape 92 | size = (batch_size, C, H, W) 93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 94 | 95 | samples, intermediates = self.ddim_sampling(conditioning, size, 96 | callback=callback, 97 | img_callback=img_callback, 98 | quantize_denoised=quantize_x0, 99 | mask=mask, x0=x0, 100 | ddim_use_original_steps=False, 101 | noise_dropout=noise_dropout, 102 | temperature=temperature, 103 | score_corrector=score_corrector, 104 | corrector_kwargs=corrector_kwargs, 105 | x_T=x_T, 106 | log_every_t=log_every_t, 107 | unconditional_guidance_scale=unconditional_guidance_scale, 108 
| unconditional_conditioning=unconditional_conditioning, 109 | ) 110 | return samples, intermediates 111 | 112 | @torch.no_grad() 113 | def ddim_sampling(self, cond, shape, 114 | x_T=None, ddim_use_original_steps=False, 115 | callback=None, timesteps=None, quantize_denoised=False, 116 | mask=None, x0=None, img_callback=None, log_every_t=100, 117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 118 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 119 | device = self.model.betas.device 120 | b = shape[0] 121 | if x_T is None: 122 | img = torch.randn(shape, device=device) 123 | else: 124 | img = x_T 125 | 126 | if timesteps is None: 127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 128 | elif timesteps is not None and not ddim_use_original_steps: 129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 130 | timesteps = self.ddim_timesteps[:subset_end] 131 | 132 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 135 | print(f"Running DDIM Sampling with {total_steps} timesteps") 136 | 137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 138 | 139 | for i, step in enumerate(iterator): 140 | index = total_steps - i - 1 141 | ts = torch.full((b,), step, device=device, dtype=torch.long) 142 | 143 | if mask is not None: 144 | assert x0 is not None 145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 146 | img = img_orig * mask + (1. 
- mask) * img 147 | 148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 149 | quantize_denoised=quantize_denoised, temperature=temperature, 150 | noise_dropout=noise_dropout, score_corrector=score_corrector, 151 | corrector_kwargs=corrector_kwargs, 152 | unconditional_guidance_scale=unconditional_guidance_scale, 153 | unconditional_conditioning=unconditional_conditioning) 154 | img, pred_x0 = outs 155 | if callback: callback(i) 156 | if img_callback: img_callback(pred_x0, i) 157 | 158 | if index % log_every_t == 0 or index == total_steps - 1: 159 | intermediates['x_inter'].append(img) 160 | intermediates['pred_x0'].append(pred_x0) 161 | 162 | return img, intermediates 163 | 164 | @torch.no_grad() 165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 167 | unconditional_guidance_scale=1., unconditional_conditioning=None): 168 | b, *_, device = *x.shape, x.device 169 | 170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 171 | e_t = self.model.apply_model(x, t, c) 172 | else: 173 | x_in = torch.cat([x] * 2) 174 | t_in = torch.cat([t] * 2) 175 | c_in = torch.cat([unconditional_conditioning, c]) 176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 178 | 179 | if score_corrector is not None: 180 | assert self.model.parameterization == "eps" 181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 182 | 183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 186 | 
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 187 | # select parameters corresponding to the currently considered timestep 188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 192 | 193 | # current prediction for x_0 194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 195 | if quantize_denoised: 196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 197 | # direction pointing to x_t 198 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 200 | if noise_dropout > 0.: 201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 203 | return x_prev, pred_x0 204 | -------------------------------------------------------------------------------- /rec_network/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from rec_network.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class PLMSSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, 
ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | if ddim_eta != 0: 26 | raise ValueError('ddim_eta must be 0 for PLMS') 27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 29 | alphas_cumprod = self.model.alphas_cumprod 30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 32 | 33 | self.register_buffer('betas', to_torch(self.model.betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 43 | 44 | # ddim sampling parameters 45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 46 | ddim_timesteps=self.ddim_timesteps, 47 | eta=ddim_eta,verbose=verbose) 48 | self.register_buffer('ddim_sigmas', ddim_sigmas) 49 | self.register_buffer('ddim_alphas', ddim_alphas) 50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 56 | 57 | @torch.no_grad() 58 | def sample(self, 59 | S, 60 | batch_size, 61 | shape, 62 | conditioning=None, 63 | callback=None, 64 | normals_sequence=None, 65 | img_callback=None, 66 | quantize_x0=False, 67 | eta=0., 68 | mask=None, 69 | x0=None, 70 | temperature=1., 71 | noise_dropout=0., 72 | score_corrector=None, 73 | corrector_kwargs=None, 74 | verbose=True, 75 | x_T=None, 76 | log_every_t=100, 77 | unconditional_guidance_scale=1., 78 | unconditional_conditioning=None, 79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 80 | **kwargs 81 | ): 82 | if conditioning is not None: 83 | if isinstance(conditioning, dict): 84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 85 | if cbs != batch_size: 86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 87 | else: 88 | if conditioning.shape[0] != batch_size: 89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 90 | 91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 92 | # sampling 93 | C, H, W = shape 94 | size = (batch_size, C, H, W) 95 | print(f'Data shape for PLMS sampling is {size}') 96 | 97 | samples, intermediates = self.plms_sampling(conditioning, size, 98 | callback=callback, 99 | img_callback=img_callback, 100 | quantize_denoised=quantize_x0, 101 | mask=mask, x0=x0, 102 | ddim_use_original_steps=False, 103 | noise_dropout=noise_dropout, 104 | temperature=temperature, 105 | score_corrector=score_corrector, 106 | corrector_kwargs=corrector_kwargs, 107 | x_T=x_T, 108 | log_every_t=log_every_t, 109 | unconditional_guidance_scale=unconditional_guidance_scale, 110 | 
unconditional_conditioning=unconditional_conditioning, 111 | ) 112 | return samples, intermediates 113 | 114 | @torch.no_grad() 115 | def plms_sampling(self, cond, shape, 116 | x_T=None, ddim_use_original_steps=False, 117 | callback=None, timesteps=None, quantize_denoised=False, 118 | mask=None, x0=None, img_callback=None, log_every_t=100, 119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 120 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 121 | device = self.model.betas.device 122 | b = shape[0] 123 | if x_T is None: 124 | img = torch.randn(shape, device=device) 125 | else: 126 | img = x_T 127 | 128 | if timesteps is None: 129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 130 | elif timesteps is not None and not ddim_use_original_steps: 131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 132 | timesteps = self.ddim_timesteps[:subset_end] 133 | 134 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 137 | print(f"Running PLMS Sampling with {total_steps} timesteps") 138 | 139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 140 | old_eps = [] 141 | 142 | for i, step in enumerate(iterator): 143 | index = total_steps - i - 1 144 | ts = torch.full((b,), step, device=device, dtype=torch.long) 145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 146 | 147 | if mask is not None: 148 | assert x0 is not None 149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 150 | img = img_orig * mask + (1. 
- mask) * img 151 | 152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 153 | quantize_denoised=quantize_denoised, temperature=temperature, 154 | noise_dropout=noise_dropout, score_corrector=score_corrector, 155 | corrector_kwargs=corrector_kwargs, 156 | unconditional_guidance_scale=unconditional_guidance_scale, 157 | unconditional_conditioning=unconditional_conditioning, 158 | old_eps=old_eps, t_next=ts_next) 159 | img, pred_x0, e_t = outs 160 | old_eps.append(e_t) 161 | if len(old_eps) >= 4: 162 | old_eps.pop(0) 163 | if callback: callback(i) 164 | if img_callback: img_callback(pred_x0, i) 165 | 166 | if index % log_every_t == 0 or index == total_steps - 1: 167 | intermediates['x_inter'].append(img) 168 | intermediates['pred_x0'].append(pred_x0) 169 | 170 | return img, intermediates 171 | 172 | @torch.no_grad() 173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 176 | b, *_, device = *x.shape, x.device 177 | 178 | def get_model_output(x, t): 179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 180 | e_t = self.model.apply_model(x, t, c) 181 | else: 182 | x_in = torch.cat([x] * 2) 183 | t_in = torch.cat([t] * 2) 184 | c_in = torch.cat([unconditional_conditioning, c]) 185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 187 | 188 | if score_corrector is not None: 189 | assert self.model.parameterization == "eps" 190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 191 | 192 | return e_t 193 | 194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 195 | alphas_prev = 
self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 198 | 199 | def get_x_prev_and_pred_x0(e_t, index): 200 | # select parameters corresponding to the currently considered timestep 201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 205 | 206 | # current prediction for x_0 207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 208 | if quantize_denoised: 209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 210 | # direction pointing to x_t 211 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 213 | if noise_dropout > 0.: 214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 216 | return x_prev, pred_x0 217 | 218 | e_t = get_model_output(x, t) 219 | if len(old_eps) == 0: 220 | # Pseudo Improved Euler (2nd order) 221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 222 | e_t_next = get_model_output(x_prev, t_next) 223 | e_t_prime = (e_t + e_t_next) / 2 224 | elif len(old_eps) == 1: 225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 227 | elif len(old_eps) == 2: 228 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 230 | elif len(old_eps) >= 3: 231 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 233 | 234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 235 | 236 | return x_prev, pred_x0, e_t 237 | -------------------------------------------------------------------------------- /rec_network/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from rec_network.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / 
math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def 
__init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): #context_dim :y query_dim:z 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | # context_dim = query_dim 158 | 159 | self.scale = dim_head ** -0.5 160 | self.heads = heads 161 | 162 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 163 | # print("!!!!!!",inner_dim,context_dim) 164 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 165 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 166 | 167 | self.to_out = nn.Sequential( 168 | nn.Linear(inner_dim, 
query_dim), 169 | nn.Dropout(dropout) 170 | ) 171 | 172 | def forward(self, x, context=None, mask=None): 173 | h = self.heads 174 | 175 | q = self.to_q(x) 176 | context = default(context, x) 177 | # print("input",x.shape,context.shape) 178 | k = self.to_k(context) 179 | v = self.to_v(context) 180 | 181 | # print("!!!!!!!!!!",h,q.shape,k.shape,v.shape,context.shape) 182 | 183 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 184 | 185 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 186 | 187 | if exists(mask): 188 | mask = rearrange(mask, 'b ... -> b (...)') 189 | max_neg_value = -torch.finfo(sim.dtype).max 190 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 191 | sim.masked_fill_(~mask, max_neg_value) 192 | 193 | # attention, what we cannot get enough of 194 | attn = sim.softmax(dim=-1) 195 | 196 | out = einsum('b i j, b j d -> b i d', attn, v) 197 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 198 | return self.to_out(out) 199 | 200 | 201 | class BasicTransformerBlock(nn.Module): 202 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 203 | super().__init__() 204 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 205 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 206 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 207 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 208 | self.norm1 = nn.LayerNorm(dim) 209 | self.norm2 = nn.LayerNorm(dim) 210 | self.norm3 = nn.LayerNorm(dim) 211 | self.checkpoint = checkpoint 212 | 213 | def forward(self, x, context=None): 214 | 215 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 216 | 217 | def _forward(self, x, context=None): 218 | x = self.attn1(self.norm1(x)) + x 219 | x = self.attn2(self.norm2(x), context=context) + x 220 | x = self.ff(self.norm3(x)) + x 
221 | return x 222 | 223 | 224 | class SpatialTransformer(nn.Module): 225 | """ 226 | Transformer block for image-like data. 227 | First, project the input (aka embedding) 228 | and reshape to b, t, d. 229 | Then apply standard transformer action. 230 | Finally, reshape to image 231 | """ 232 | def __init__(self, in_channels, n_heads, d_head, 233 | depth=1, dropout=0., context_dim=None): 234 | super().__init__() 235 | self.in_channels = in_channels 236 | inner_dim = n_heads * d_head 237 | self.norm = Normalize(in_channels) 238 | 239 | self.proj_in = nn.Conv2d(in_channels, 240 | inner_dim, 241 | kernel_size=1, 242 | stride=1, 243 | padding=0) 244 | 245 | self.transformer_blocks = nn.ModuleList( 246 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 247 | for d in range(depth)] 248 | ) 249 | 250 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 251 | in_channels, 252 | kernel_size=1, 253 | stride=1, 254 | padding=0)) 255 | 256 | def forward(self, x, context=None): 257 | # note: if no context is given, cross-attention defaults to self-attention 258 | b, c, h, w = x.shape 259 | # print("253 shape", x.shape, context.shape) 260 | x_in = x 261 | x = self.norm(x) 262 | x = self.proj_in(x) 263 | x = rearrange(x, 'b c h w -> b (h w) c') 264 | context = rearrange(context, 'b c h w -> b (h w) c') 265 | # print("258 shape", x.shape, context.shape) 266 | for block in self.transformer_blocks: 267 | x = block(x, context=context) 268 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 269 | x = self.proj_out(x) 270 | return x + x_in -------------------------------------------------------------------------------- /rec_network/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/diffusionmodules/__init__.py 
-------------------------------------------------------------------------------- /rec_network/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from rec_network.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = 
num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 
85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 
195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /rec_network/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/distributions/__init__.py -------------------------------------------------------------------------------- /rec_network/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return 
self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 
71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /rec_network/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = 
dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /rec_network/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/encoders/__init__.py -------------------------------------------------------------------------------- /rec_network/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from rec_network.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | 
tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, 
return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 
141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 
195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /rec_network/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from rec_network.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from rec_network.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /rec_network/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /rec_network/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from rec_network.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /rec_network/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight 
> 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is 
None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /rec_network/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | 
self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = 
torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = 
self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /rec_network/modules/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | from functools import partial 6 | from inspect import isfunction 7 | from collections import namedtuple 8 | from einops import rearrange, repeat, reduce 9 | 10 | # constants 11 | 12 | DEFAULT_DIM_HEAD = 64 13 | 14 | Intermediates = namedtuple('Intermediates', [ 15 | 'pre_softmax_attn', 16 | 'post_softmax_attn' 17 | ]) 18 | 19 | LayerIntermediates = namedtuple('Intermediates', [ 20 | 'hiddens', 21 | 'attn_intermediates' 22 | ]) 23 | 24 | 25 | class AbsolutePositionalEmbedding(nn.Module): 26 | def __init__(self, dim, max_seq_len): 27 | super().__init__() 28 | self.emb = nn.Embedding(max_seq_len, dim) 29 | self.init_() 30 | 31 | def init_(self): 32 | nn.init.normal_(self.emb.weight, std=0.02) 33 | 34 | def forward(self, x): 35 | n = torch.arange(x.shape[1], device=x.device) 36 | return self.emb(n)[None, :, :] 37 | 38 | 39 | class FixedPositionalEmbedding(nn.Module): 40 | def __init__(self, dim): 41 | 
super().__init__() 42 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 43 | self.register_buffer('inv_freq', inv_freq) 44 | 45 | def forward(self, x, seq_dim=1, offset=0): 46 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 47 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 48 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 49 | return emb[None, :, :] 50 | 51 | 52 | # helpers 53 | 54 | def exists(val): 55 | return val is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def always(val): 65 | def inner(*args, **kwargs): 66 | return val 67 | return inner 68 | 69 | 70 | def not_equals(val): 71 | def inner(x): 72 | return x != val 73 | return inner 74 | 75 | 76 | def equals(val): 77 | def inner(x): 78 | return x == val 79 | return inner 80 | 81 | 82 | def max_neg_value(tensor): 83 | return -torch.finfo(tensor.dtype).max 84 | 85 | 86 | # keyword argument helpers 87 | 88 | def pick_and_pop(keys, d): 89 | values = list(map(lambda key: d.pop(key), keys)) 90 | return dict(zip(keys, values)) 91 | 92 | 93 | def group_dict_by_key(cond, d): 94 | return_val = [dict(), dict()] 95 | for key in d.keys(): 96 | match = bool(cond(key)) 97 | ind = int(not match) 98 | return_val[ind][key] = d[key] 99 | return (*return_val,) 100 | 101 | 102 | def string_begins_with(prefix, str): 103 | return str.startswith(prefix) 104 | 105 | 106 | def group_by_key_prefix(prefix, d): 107 | return group_dict_by_key(partial(string_begins_with, prefix), d) 108 | 109 | 110 | def groupby_prefix_and_trim(prefix, d): 111 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 112 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 113 | return kwargs_without_prefix, kwargs 114 | 115 | 116 | # classes 117 | class Scale(nn.Module): 118 | def 
__init__(self, value, fn): 119 | super().__init__() 120 | self.value = value 121 | self.fn = fn 122 | 123 | def forward(self, x, **kwargs): 124 | x, *rest = self.fn(x, **kwargs) 125 | return (x * self.value, *rest) 126 | 127 | 128 | class Rezero(nn.Module): 129 | def __init__(self, fn): 130 | super().__init__() 131 | self.fn = fn 132 | self.g = nn.Parameter(torch.zeros(1)) 133 | 134 | def forward(self, x, **kwargs): 135 | x, *rest = self.fn(x, **kwargs) 136 | return (x * self.g, *rest) 137 | 138 | 139 | class ScaleNorm(nn.Module): 140 | def __init__(self, dim, eps=1e-5): 141 | super().__init__() 142 | self.scale = dim ** -0.5 143 | self.eps = eps 144 | self.g = nn.Parameter(torch.ones(1)) 145 | 146 | def forward(self, x): 147 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 148 | return x / norm.clamp(min=self.eps) * self.g 149 | 150 | 151 | class RMSNorm(nn.Module): 152 | def __init__(self, dim, eps=1e-8): 153 | super().__init__() 154 | self.scale = dim ** -0.5 155 | self.eps = eps 156 | self.g = nn.Parameter(torch.ones(dim)) 157 | 158 | def forward(self, x): 159 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 160 | return x / norm.clamp(min=self.eps) * self.g 161 | 162 | 163 | class Residual(nn.Module): 164 | def forward(self, x, residual): 165 | return x + residual 166 | 167 | 168 | class GRUGating(nn.Module): 169 | def __init__(self, dim): 170 | super().__init__() 171 | self.gru = nn.GRUCell(dim, dim) 172 | 173 | def forward(self, x, residual): 174 | gated_output = self.gru( 175 | rearrange(x, 'b n d -> (b n) d'), 176 | rearrange(residual, 'b n d -> (b n) d') 177 | ) 178 | 179 | return gated_output.reshape_as(x) 180 | 181 | 182 | # feedforward 183 | 184 | class GEGLU(nn.Module): 185 | def __init__(self, dim_in, dim_out): 186 | super().__init__() 187 | self.proj = nn.Linear(dim_in, dim_out * 2) 188 | 189 | def forward(self, x): 190 | x, gate = self.proj(x).chunk(2, dim=-1) 191 | return x * F.gelu(gate) 192 | 193 | 194 | class 
FeedForward(nn.Module): 195 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 196 | super().__init__() 197 | inner_dim = int(dim * mult) 198 | dim_out = default(dim_out, dim) 199 | project_in = nn.Sequential( 200 | nn.Linear(dim, inner_dim), 201 | nn.GELU() 202 | ) if not glu else GEGLU(dim, inner_dim) 203 | 204 | self.net = nn.Sequential( 205 | project_in, 206 | nn.Dropout(dropout), 207 | nn.Linear(inner_dim, dim_out) 208 | ) 209 | 210 | def forward(self, x): 211 | return self.net(x) 212 | 213 | 214 | # attention. 215 | class Attention(nn.Module): 216 | def __init__( 217 | self, 218 | dim, 219 | dim_head=DEFAULT_DIM_HEAD, 220 | heads=8, 221 | causal=False, 222 | mask=None, 223 | talking_heads=False, 224 | sparse_topk=None, 225 | use_entmax15=False, 226 | num_mem_kv=0, 227 | dropout=0., 228 | on_attn=False 229 | ): 230 | super().__init__() 231 | if use_entmax15: 232 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 233 | self.scale = dim_head ** -0.5 234 | self.heads = heads 235 | self.causal = causal 236 | self.mask = mask 237 | 238 | inner_dim = dim_head * heads 239 | 240 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 242 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 243 | self.dropout = nn.Dropout(dropout) 244 | 245 | # talking heads 246 | self.talking_heads = talking_heads 247 | if talking_heads: 248 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 250 | 251 | # explicit topk sparse attention 252 | self.sparse_topk = sparse_topk 253 | 254 | # entmax 255 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 256 | self.attn_fn = F.softmax 257 | 258 | # add memory key / values 259 | self.num_mem_kv = num_mem_kv 260 | if num_mem_kv > 0: 261 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | self.mem_v = 
nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 263 | 264 | # attention on attention 265 | self.attn_on_attn = on_attn 266 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 267 | 268 | def forward( 269 | self, 270 | x, 271 | context=None, 272 | mask=None, 273 | context_mask=None, 274 | rel_pos=None, 275 | sinusoidal_emb=None, 276 | prev_attn=None, 277 | mem=None 278 | ): 279 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 280 | kv_input = default(context, x) 281 | 282 | q_input = x 283 | k_input = kv_input 284 | v_input = kv_input 285 | 286 | if exists(mem): 287 | k_input = torch.cat((mem, k_input), dim=-2) 288 | v_input = torch.cat((mem, v_input), dim=-2) 289 | 290 | if exists(sinusoidal_emb): 291 | # in shortformer, the query would start at a position offset depending on the past cached memory 292 | offset = k_input.shape[-2] - q_input.shape[-2] 293 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 294 | k_input = k_input + sinusoidal_emb(k_input) 295 | 296 | q = self.to_q(q_input) 297 | k = self.to_k(k_input) 298 | v = self.to_v(v_input) 299 | 300 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 301 | 302 | input_mask = None 303 | if any(map(exists, (mask, context_mask))): 304 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 305 | k_mask = q_mask if not exists(context) else context_mask 306 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 307 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 308 | k_mask = rearrange(k_mask, 'b j -> b () () j') 309 | input_mask = q_mask * k_mask 310 | 311 | if self.num_mem_kv > 0: 312 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 313 | k = torch.cat((mem_k, k), dim=-2) 314 | v = torch.cat((mem_v, v), dim=-2) 315 | if exists(input_mask): 316 | input_mask = 
F.pad(input_mask, (self.num_mem_kv, 0), value=True) 317 | 318 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 319 | mask_value = max_neg_value(dots) 320 | 321 | if exists(prev_attn): 322 | dots = dots + prev_attn 323 | 324 | pre_softmax_attn = dots 325 | 326 | if talking_heads: 327 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 328 | 329 | if exists(rel_pos): 330 | dots = rel_pos(dots) 331 | 332 | if exists(input_mask): 333 | dots.masked_fill_(~input_mask, mask_value) 334 | del input_mask 335 | 336 | if self.causal: 337 | i, j = dots.shape[-2:] 338 | r = torch.arange(i, device=device) 339 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 340 | mask = F.pad(mask, (j - i, 0), value=False) 341 | dots.masked_fill_(mask, mask_value) 342 | del mask 343 | 344 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 345 | top, _ = dots.topk(self.sparse_topk, dim=-1) 346 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 347 | mask = dots < vk 348 | dots.masked_fill_(mask, mask_value) 349 | del mask 350 | 351 | attn = self.attn_fn(dots, dim=-1) 352 | post_softmax_attn = attn 353 | 354 | attn = self.dropout(attn) 355 | 356 | if talking_heads: 357 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 358 | 359 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 360 | out = rearrange(out, 'b h n d -> b n (h d)') 361 | 362 | intermediates = Intermediates( 363 | pre_softmax_attn=pre_softmax_attn, 364 | post_softmax_attn=post_softmax_attn 365 | ) 366 | 367 | return self.to_out(out), intermediates 368 | 369 | 370 | class AttentionLayers(nn.Module): 371 | def __init__( 372 | self, 373 | dim, 374 | depth, 375 | heads=8, 376 | causal=False, 377 | cross_attend=False, 378 | only_cross=False, 379 | use_scalenorm=False, 380 | use_rmsnorm=False, 381 | use_rezero=False, 382 | rel_pos_num_buckets=32, 383 | rel_pos_max_distance=128, 384 | position_infused_attn=False, 
385 | custom_layers=None, 386 | sandwich_coef=None, 387 | par_ratio=None, 388 | residual_attn=False, 389 | cross_residual_attn=False, 390 | macaron=False, 391 | pre_norm=True, 392 | gate_residual=False, 393 | **kwargs 394 | ): 395 | super().__init__() 396 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 397 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 398 | 399 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 400 | 401 | self.dim = dim 402 | self.depth = depth 403 | self.layers = nn.ModuleList([]) 404 | 405 | self.has_pos_emb = position_infused_attn 406 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 407 | self.rotary_pos_emb = always(None) 408 | 409 | assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 410 | self.rel_pos = None 411 | 412 | self.pre_norm = pre_norm 413 | 414 | self.residual_attn = residual_attn 415 | self.cross_residual_attn = cross_residual_attn 416 | 417 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 418 | norm_class = RMSNorm if use_rmsnorm else norm_class 419 | norm_fn = partial(norm_class, dim) 420 | 421 | norm_fn = nn.Identity if use_rezero else norm_fn 422 | branch_fn = Rezero if use_rezero else None 423 | 424 | if cross_attend and not only_cross: 425 | default_block = ('a', 'c', 'f') 426 | elif cross_attend and only_cross: 427 | default_block = ('c', 'f') 428 | else: 429 | default_block = ('a', 'f') 430 | 431 | if macaron: 432 | default_block = ('f',) + default_block 433 | 434 | if exists(custom_layers): 435 | layer_types = custom_layers 436 | elif exists(par_ratio): 437 | par_depth = depth * len(default_block) 438 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 439 | default_block = tuple(filter(not_equals('f'), default_block)) 440 | par_attn = par_depth // par_ratio 441 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR 
paper 442 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 443 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 444 | par_block = default_block + ('f',) * (par_width - len(default_block)) 445 | par_head = par_block * par_attn 446 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 447 | elif exists(sandwich_coef): 448 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 449 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 450 | else: 451 | layer_types = default_block * depth 452 | 453 | self.layer_types = layer_types 454 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 455 | 456 | for layer_type in self.layer_types: 457 | if layer_type == 'a': 458 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 459 | elif layer_type == 'c': 460 | layer = Attention(dim, heads=heads, **attn_kwargs) 461 | elif layer_type == 'f': 462 | layer = FeedForward(dim, **ff_kwargs) 463 | layer = layer if not macaron else Scale(0.5, layer) 464 | else: 465 | raise Exception(f'invalid layer type {layer_type}') 466 | 467 | if isinstance(layer, Attention) and exists(branch_fn): 468 | layer = branch_fn(layer) 469 | 470 | if gate_residual: 471 | residual_fn = GRUGating(dim) 472 | else: 473 | residual_fn = Residual() 474 | 475 | self.layers.append(nn.ModuleList([ 476 | norm_fn(), 477 | layer, 478 | residual_fn 479 | ])) 480 | 481 | def forward( 482 | self, 483 | x, 484 | context=None, 485 | mask=None, 486 | context_mask=None, 487 | mems=None, 488 | return_hiddens=False 489 | ): 490 | hiddens = [] 491 | intermediates = [] 492 | prev_attn = None 493 | prev_cross_attn = None 494 | 495 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 496 | 497 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 498 | is_last = ind == 
(len(self.layers) - 1) 499 | 500 | if layer_type == 'a': 501 | hiddens.append(x) 502 | layer_mem = mems.pop(0) 503 | 504 | residual = x 505 | 506 | if self.pre_norm: 507 | x = norm(x) 508 | 509 | if layer_type == 'a': 510 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 511 | prev_attn=prev_attn, mem=layer_mem) 512 | elif layer_type == 'c': 513 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 514 | elif layer_type == 'f': 515 | out = block(x) 516 | 517 | x = residual_fn(out, residual) 518 | 519 | if layer_type in ('a', 'c'): 520 | intermediates.append(inter) 521 | 522 | if layer_type == 'a' and self.residual_attn: 523 | prev_attn = inter.pre_softmax_attn 524 | elif layer_type == 'c' and self.cross_residual_attn: 525 | prev_cross_attn = inter.pre_softmax_attn 526 | 527 | if not self.pre_norm and not is_last: 528 | x = norm(x) 529 | 530 | if return_hiddens: 531 | intermediates = LayerIntermediates( 532 | hiddens=hiddens, 533 | attn_intermediates=intermediates 534 | ) 535 | 536 | return x, intermediates 537 | 538 | return x 539 | 540 | 541 | class Encoder(AttentionLayers): 542 | def __init__(self, **kwargs): 543 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 544 | super().__init__(causal=False, **kwargs) 545 | 546 | 547 | 548 | class TransformerWrapper(nn.Module): 549 | def __init__( 550 | self, 551 | *, 552 | num_tokens, 553 | max_seq_len, 554 | attn_layers, 555 | emb_dim=None, 556 | max_mem_len=0., 557 | emb_dropout=0., 558 | num_memory_tokens=None, 559 | tie_embedding=False, 560 | use_pos_emb=True 561 | ): 562 | super().__init__() 563 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 564 | 565 | dim = attn_layers.dim 566 | emb_dim = default(emb_dim, dim) 567 | 568 | self.max_seq_len = max_seq_len 569 | self.max_mem_len = max_mem_len 570 | self.num_tokens = num_tokens 571 | 572 | self.token_emb = 
nn.Embedding(num_tokens, emb_dim) 573 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 574 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 575 | self.emb_dropout = nn.Dropout(emb_dropout) 576 | 577 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 578 | self.attn_layers = attn_layers 579 | self.norm = nn.LayerNorm(dim) 580 | 581 | self.init_() 582 | 583 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 584 | 585 | # memory tokens (like [cls]) from Memory Transformers paper 586 | num_memory_tokens = default(num_memory_tokens, 0) 587 | self.num_memory_tokens = num_memory_tokens 588 | if num_memory_tokens > 0: 589 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 590 | 591 | # let funnel encoder know number of memory tokens, if specified 592 | if hasattr(attn_layers, 'num_memory_tokens'): 593 | attn_layers.num_memory_tokens = num_memory_tokens 594 | 595 | def init_(self): 596 | nn.init.normal_(self.token_emb.weight, std=0.02) 597 | 598 | def forward( 599 | self, 600 | x, 601 | return_embeddings=False, 602 | mask=None, 603 | return_mems=False, 604 | return_attn=False, 605 | mems=None, 606 | **kwargs 607 | ): 608 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 609 | x = self.token_emb(x) 610 | x += self.pos_emb(x) 611 | x = self.emb_dropout(x) 612 | 613 | x = self.project_emb(x) 614 | 615 | if num_mem > 0: 616 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 617 | x = torch.cat((mem, x), dim=1) 618 | 619 | # auto-handle masking after appending memory tokens 620 | if exists(mask): 621 | mask = F.pad(mask, (num_mem, 0), value=True) 622 | 623 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 624 | x = self.norm(x) 625 | 626 | mem, x = x[:, :num_mem], x[:, num_mem:] 627 | 628 | out = self.to_logits(x) if not return_embeddings else x 629 | 630 | if 
return_mems: 631 | hiddens = intermediates.hiddens 632 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 633 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 634 | return out, new_mems 635 | 636 | if return_attn: 637 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 638 | return out, attn_maps 639 | 640 | return out 641 | 642 | -------------------------------------------------------------------------------- /rec_network/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. 
Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 
102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. 
[{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir datasets 4 | cd datasets 5 | # Download describable textures dataset 6 | wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz 7 | tar -xf dtd-r1.0.1.tar.gz 8 | rm dtd-r1.0.1.tar.gz 9 | 10 | mkdir mvtec 11 | cd mvtec 12 | # Download MVTec anomaly detection dataset 13 | wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz 14 | tar -xf mvtec_anomaly_detection.tar.xz 15 | rm mvtec_anomaly_detection.tar.xz 16 | 17 | -------------------------------------------------------------------------------- /scripts/mvtec.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import cv2 8 | from rec_network.main import instantiate_from_config 9 | from rec_network.models.diffusion.ddim import DDIMSampler 10 | from rec_network.data.mvtec import MVTecDRAEMTestDataset 11 | from torch.utils.data import DataLoader 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--outdir", 18 | type=str, 19 | default="./samples/bottle/", 20 | help="dir to write results to", 21 | ) 22 | 
parser.add_argument( 23 | "--steps", 24 | type=int, 25 | default=50, 26 | help="number of ddim sampling steps", 27 | ) 28 | opt = parser.parse_args() 29 | ddim_eta = 0.0 30 | 31 | mvtec_path = './datasets/mvtec/bottle' 32 | 33 | dataset = MVTecDRAEMTestDataset(mvtec_path + "/test/", resize_shape=[256, 256]) 34 | dataloader = DataLoader(dataset, batch_size=1, 35 | shuffle=False, num_workers=0) 36 | print(f"Found {len(dataloader)} inputs.") 37 | 38 | config = OmegaConf.load("../configs/mvtec.yaml") 39 | 40 | model = instantiate_from_config(config.model) 41 | model.load_state_dict(torch.load("./logs/2023-01-31T16-58-02_mvtec/checkpoints/last.ckpt")["state_dict"], 42 | strict=False) #TODO: modify the ckpt path 43 | 44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 45 | model = model.to(device) 46 | sampler = DDIMSampler(model) 47 | 48 | os.makedirs(opt.outdir, exist_ok=True) 49 | cnt = 0 50 | with torch.no_grad(): 51 | with model.ema_scope(): 52 | for i_batch, batch in enumerate(dataloader): 53 | 54 | c_outpath = os.path.join(opt.outdir, 'condition'+str(cnt)+'.jpg') 55 | outpath = os.path.join(opt.outdir, str(cnt)+'.jpg') 56 | # print(outpath) 57 | condition = batch["image"].cpu().numpy().transpose(0,2,3,1)[0]*255 58 | cv2.imwrite(c_outpath,condition) 59 | 60 | c = batch["image"].to(device) 61 | c = model.cond_stage_model.encode(c) 62 | c = c.mode() 63 | 64 | noise = torch.randn_like(c) 65 | t = torch.randint(400, 500, (c.shape[0],), device=device).long() 66 | c_noisy = model.q_sample(x_start=c, t=t, noise=noise) 67 | 68 | shape = c.shape[1:] 69 | samples_ddim, _ = sampler.sample(S=opt.steps, 70 | conditioning=c, # or conditioning=c_noisy 71 | batch_size=c.shape[0], 72 | shape=shape, 73 | verbose=False) 74 | x_samples_ddim = model.decode_first_stage(samples_ddim) 75 | 76 | sample = x_samples_ddim.cpu().numpy().transpose(0,2,3,1)[0]*255 77 | cv2.imwrite(outpath, sample) 78 | cnt+=1 79 | 
-------------------------------------------------------------------------------- /seg_network/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | import cv2 6 | import glob 7 | import imgaug.augmenters as iaa 8 | from perlin import rand_perlin_2d_np 9 | 10 | class MVTecDRAEMTestDataset(Dataset): 11 | 12 | def __init__(self, root_dir, resize_shape=None): 13 | self.root_dir = root_dir 14 | self.images = sorted(glob.glob(root_dir+"/*/*.png")) 15 | self.resize_shape=resize_shape 16 | 17 | def __len__(self): 18 | return len(self.images) 19 | 20 | def transform_image(self, image_path, mask_path): 21 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 22 | if mask_path is not None: 23 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 24 | else: 25 | mask = np.zeros((image.shape[0],image.shape[1])) 26 | if self.resize_shape != None: 27 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0])) 28 | mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0])) 29 | 30 | image = image / 255.0 31 | mask = mask / 255.0 32 | 33 | image = np.array(image).reshape((image.shape[0], image.shape[1], 3)).astype(np.float32) 34 | mask = np.array(mask).reshape((mask.shape[0], mask.shape[1], 1)).astype(np.float32) 35 | 36 | image = np.transpose(image, (2, 0, 1)) 37 | mask = np.transpose(mask, (2, 0, 1)) 38 | return image, mask 39 | 40 | def __getitem__(self, idx): 41 | if torch.is_tensor(idx): 42 | idx = idx.tolist() 43 | 44 | img_path = self.images[idx] 45 | dir_path, file_name = os.path.split(img_path) 46 | base_dir = os.path.basename(dir_path) 47 | if base_dir == 'good': 48 | image, mask = self.transform_image(img_path, None) 49 | has_anomaly = np.array([0], dtype=np.float32) 50 | else: 51 | mask_path = os.path.join(dir_path, '../../ground_truth/') 52 | mask_path = os.path.join(mask_path, base_dir) 53 | 
mask_file_name = file_name.split(".")[0]+"_mask.png" 54 | mask_path = os.path.join(mask_path, mask_file_name) 55 | image, mask = self.transform_image(img_path, mask_path) 56 | has_anomaly = np.array([1], dtype=np.float32) 57 | 58 | sample = {'image': image, 'has_anomaly': has_anomaly,'mask': mask, 'idx': idx} 59 | 60 | return sample 61 | 62 | 63 | 64 | class MVTecDRAEMTrainDataset(Dataset): 65 | 66 | def __init__(self, root_dir, anomaly_source_path, resize_shape=None): 67 | """ 68 | Args: 69 | root_dir (string): Directory with all the images. 70 | transform (callable, optional): Optional transform to be applied 71 | on a sample. 72 | """ 73 | self.root_dir = root_dir 74 | self.resize_shape=resize_shape 75 | 76 | self.image_paths = sorted(glob.glob(root_dir+"/*.png")) 77 | 78 | self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path+"/*/*.jpg")) 79 | 80 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True), 81 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 82 | iaa.pillike.EnhanceSharpness(), 83 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 84 | iaa.Solarize(0.5, threshold=(32,128)), 85 | iaa.Posterize(), 86 | iaa.Invert(), 87 | iaa.pillike.Autocontrast(), 88 | iaa.pillike.Equalize(), 89 | iaa.Affine(rotate=(-45, 45)) 90 | ] 91 | 92 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 93 | 94 | 95 | def __len__(self): 96 | return len(self.image_paths) 97 | 98 | 99 | def randAugmenter(self): 100 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) 101 | aug = iaa.Sequential([self.augmenters[aug_ind[0]], 102 | self.augmenters[aug_ind[1]], 103 | self.augmenters[aug_ind[2]]] 104 | ) 105 | return aug 106 | 107 | def augment_image(self, image, anomaly_source_path): 108 | aug = self.randAugmenter() 109 | perlin_scale = 6 110 | min_perlin_scale = 0 111 | anomaly_source_img = cv2.imread(anomaly_source_path) 112 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize_shape[1], 
self.resize_shape[0])) 113 | 114 | anomaly_img_augmented = aug(image=anomaly_source_img) 115 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 116 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 117 | 118 | perlin_noise = rand_perlin_2d_np((self.resize_shape[0], self.resize_shape[1]), (perlin_scalex, perlin_scaley)) 119 | perlin_noise = self.rot(image=perlin_noise) 120 | threshold = 0.5 121 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 122 | perlin_thr = np.expand_dims(perlin_thr, axis=2) 123 | 124 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 125 | 126 | beta = torch.rand(1).numpy()[0] * 0.8 127 | 128 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * ( 129 | perlin_thr) 130 | 131 | no_anomaly = torch.rand(1).numpy()[0] 132 | if no_anomaly > 0.5: 133 | image = image.astype(np.float32) 134 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32) 135 | else: 136 | augmented_image = augmented_image.astype(np.float32) 137 | msk = (perlin_thr).astype(np.float32) 138 | augmented_image = msk * augmented_image + (1-msk)*image 139 | has_anomaly = 1.0 140 | if np.sum(msk) == 0: 141 | has_anomaly=0.0 142 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32) 143 | 144 | def transform_image(self, image_path, anomaly_source_path): 145 | image = cv2.imread(image_path) 146 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0])) 147 | 148 | do_aug_orig = torch.rand(1).numpy()[0] > 0.7 149 | if do_aug_orig: 150 | image = self.rot(image=image) 151 | 152 | image = np.array(image).reshape((image.shape[0], image.shape[1], image.shape[2])).astype(np.float32) / 255.0 153 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path) 154 | augmented_image = 
np.transpose(augmented_image, (2, 0, 1)) 155 | image = np.transpose(image, (2, 0, 1)) 156 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1)) 157 | return image, augmented_image, anomaly_mask, has_anomaly 158 | 159 | def __getitem__(self, idx): 160 | idx = torch.randint(0, len(self.image_paths), (1,)).item() 161 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item() 162 | image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(self.image_paths[idx], 163 | self.anomaly_source_paths[anomaly_source_idx]) 164 | sample = {'image': image, "anomaly_mask": anomaly_mask, 165 | 'augmented_image': augmented_image, 'has_anomaly': has_anomaly, 'idx': idx} 166 | 167 | return sample 168 | -------------------------------------------------------------------------------- /seg_network/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | class FocalLoss(nn.Module): 8 | """ 9 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 10 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 11 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 12 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 13 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 14 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 15 | focus on hard misclassified example 16 | :param smooth: (float,double) smooth value when cross entropy 17 | :param balance_index: (int) balance class index, should be specific when alpha is float 18 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 
19 | """ 20 | 21 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 22 | super(FocalLoss, self).__init__() 23 | self.apply_nonlin = apply_nonlin 24 | self.alpha = alpha 25 | self.gamma = gamma 26 | self.balance_index = balance_index 27 | self.smooth = smooth 28 | self.size_average = size_average 29 | 30 | if self.smooth is not None: 31 | if self.smooth < 0 or self.smooth > 1.0: 32 | raise ValueError('smooth value should be in [0,1]') 33 | 34 | def forward(self, logit, target): 35 | if self.apply_nonlin is not None: 36 | logit = self.apply_nonlin(logit) 37 | num_class = logit.shape[1] 38 | 39 | if logit.dim() > 2: 40 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 41 | logit = logit.view(logit.size(0), logit.size(1), -1) 42 | logit = logit.permute(0, 2, 1).contiguous() 43 | logit = logit.view(-1, logit.size(-1)) 44 | target = torch.squeeze(target, 1) 45 | target = target.view(-1, 1) 46 | alpha = self.alpha 47 | 48 | if alpha is None: 49 | alpha = torch.ones(num_class, 1) 50 | elif isinstance(alpha, (list, np.ndarray)): 51 | assert len(alpha) == num_class 52 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 53 | alpha = alpha / alpha.sum() 54 | elif isinstance(alpha, float): 55 | alpha = torch.ones(num_class, 1) 56 | alpha = alpha * (1 - self.alpha) 57 | alpha[self.balance_index] = self.alpha 58 | 59 | else: 60 | raise TypeError('Not support alpha type') 61 | 62 | if alpha.device != logit.device: 63 | alpha = alpha.to(logit.device) 64 | 65 | idx = target.cpu().long() 66 | 67 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 68 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 69 | if one_hot_key.device != logit.device: 70 | one_hot_key = one_hot_key.to(logit.device) 71 | 72 | if self.smooth: 73 | one_hot_key = torch.clamp( 74 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 75 | pt = (one_hot_key * logit).sum(1) + self.smooth 76 | logpt = pt.log() 77 | 78 | gamma = self.gamma 79 | 
80 | alpha = alpha[idx] 81 | alpha = torch.squeeze(alpha) 82 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 83 | 84 | if self.size_average: 85 | loss = loss.mean() 86 | return loss 87 | 88 | def gaussian(window_size, sigma): 89 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 90 | return gauss/gauss.sum() 91 | 92 | def create_window(window_size, channel=1): 93 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 94 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 95 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 96 | return window 97 | 98 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 99 | if val_range is None: 100 | if torch.max(img1) > 128: 101 | max_val = 255 102 | else: 103 | max_val = 1 104 | 105 | if torch.min(img1) < -0.5: 106 | min_val = -1 107 | else: 108 | min_val = 0 109 | l = max_val - min_val 110 | else: 111 | l = val_range 112 | 113 | padd = window_size//2 114 | (_, channel, height, width) = img1.size() 115 | if window is None: 116 | real_size = min(window_size, height, width) 117 | window = create_window(real_size, channel=channel).to(img1.device) 118 | 119 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 120 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 121 | 122 | mu1_sq = mu1.pow(2) 123 | mu2_sq = mu2.pow(2) 124 | mu1_mu2 = mu1 * mu2 125 | 126 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 127 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 128 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 129 | 130 | c1 = (0.01 * l) ** 2 131 | c2 = (0.03 * l) ** 2 132 | 133 | v1 = 2.0 * sigma12 + c2 134 | v2 = sigma1_sq + sigma2_sq + c2 135 | cs = torch.mean(v1 / v2) # contrast sensitivity 136 | 137 | ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + 
mu2_sq + c1) * v2) 138 | 139 | if size_average: 140 | ret = ssim_map.mean() 141 | else: 142 | ret = ssim_map.mean(1).mean(1).mean(1) 143 | 144 | if full: 145 | return ret, cs 146 | return ret, ssim_map 147 | 148 | 149 | class SSIM(torch.nn.Module): 150 | def __init__(self, window_size=11, size_average=True, val_range=None): 151 | super(SSIM, self).__init__() 152 | self.window_size = window_size 153 | self.size_average = size_average 154 | self.val_range = val_range 155 | 156 | # Assume 1 channel for SSIM 157 | self.channel = 1 158 | self.window = create_window(window_size).cuda() 159 | 160 | def forward(self, img1, img2): 161 | (_, channel, _, _) = img1.size() 162 | 163 | if channel == self.channel and self.window.dtype == img1.dtype: 164 | window = self.window 165 | else: 166 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 167 | self.window = window 168 | self.channel = channel 169 | 170 | s_score, ssim_map = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 171 | return 1.0 - s_score 172 | -------------------------------------------------------------------------------- /seg_network/model_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ReconstructiveSubNetwork(nn.Module): 6 | def __init__(self,in_channels=3, out_channels=3, base_width=128): 7 | super(ReconstructiveSubNetwork, self).__init__() 8 | self.encoder = EncoderReconstructive(in_channels, base_width) 9 | self.decoder = DecoderReconstructive(base_width, out_channels=out_channels) 10 | 11 | def forward(self, x): 12 | b5 = self.encoder(x) 13 | output = self.decoder(b5) 14 | return output 15 | 16 | class DiscriminativeSubNetwork(nn.Module): 17 | def __init__(self,in_channels=3, out_channels=3, base_channels=64, out_features=False): 18 | super(DiscriminativeSubNetwork, self).__init__() 19 | base_width = base_channels 20 | 
self.encoder_segment = EncoderDiscriminative(in_channels, base_width) 21 | self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels) 22 | #self.segment_act = torch.nn.Sigmoid() 23 | self.out_features = out_features 24 | def forward(self, x): 25 | b1,b2,b3,b4,b5,b6 = self.encoder_segment(x) 26 | output_segment = self.decoder_segment(b1,b2,b3,b4,b5,b6) 27 | if self.out_features: 28 | return output_segment, b2, b3, b4, b5, b6 29 | else: 30 | return output_segment 31 | 32 | class EncoderDiscriminative(nn.Module): 33 | def __init__(self, in_channels, base_width): 34 | super(EncoderDiscriminative, self).__init__() 35 | self.block1 = nn.Sequential( 36 | nn.Conv2d(in_channels,base_width, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(base_width), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(base_width), 41 | nn.ReLU(inplace=True)) 42 | self.mp1 = nn.Sequential(nn.MaxPool2d(2)) 43 | self.block2 = nn.Sequential( 44 | nn.Conv2d(base_width,base_width*2, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(base_width*2), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1), 48 | nn.BatchNorm2d(base_width*2), 49 | nn.ReLU(inplace=True)) 50 | self.mp2 = nn.Sequential(nn.MaxPool2d(2)) 51 | self.block3 = nn.Sequential( 52 | nn.Conv2d(base_width*2,base_width*4, kernel_size=3, padding=1), 53 | nn.BatchNorm2d(base_width*4), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1), 56 | nn.BatchNorm2d(base_width*4), 57 | nn.ReLU(inplace=True)) 58 | self.mp3 = nn.Sequential(nn.MaxPool2d(2)) 59 | self.block4 = nn.Sequential( 60 | nn.Conv2d(base_width*4,base_width*8, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(base_width*8), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 64 | nn.BatchNorm2d(base_width*8), 65 | nn.ReLU(inplace=True)) 66 | self.mp4 = 
nn.Sequential(nn.MaxPool2d(2)) 67 | self.block5 = nn.Sequential( 68 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1), 69 | nn.BatchNorm2d(base_width*8), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 72 | nn.BatchNorm2d(base_width*8), 73 | nn.ReLU(inplace=True)) 74 | 75 | self.mp5 = nn.Sequential(nn.MaxPool2d(2)) 76 | self.block6 = nn.Sequential( 77 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1), 78 | nn.BatchNorm2d(base_width*8), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 81 | nn.BatchNorm2d(base_width*8), 82 | nn.ReLU(inplace=True)) 83 | 84 | 85 | def forward(self, x): 86 | b1 = self.block1(x) 87 | mp1 = self.mp1(b1) 88 | b2 = self.block2(mp1) 89 | mp2 = self.mp3(b2) 90 | b3 = self.block3(mp2) 91 | mp3 = self.mp3(b3) 92 | b4 = self.block4(mp3) 93 | mp4 = self.mp4(b4) 94 | b5 = self.block5(mp4) 95 | mp5 = self.mp5(b5) 96 | b6 = self.block6(mp5) 97 | return b1,b2,b3,b4,b5,b6 98 | 99 | class DecoderDiscriminative(nn.Module): 100 | def __init__(self, base_width, out_channels=1): 101 | super(DecoderDiscriminative, self).__init__() 102 | 103 | self.up_b = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 104 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), 105 | nn.BatchNorm2d(base_width * 8), 106 | nn.ReLU(inplace=True)) 107 | self.db_b = nn.Sequential( 108 | nn.Conv2d(base_width*(8+8), base_width*8, kernel_size=3, padding=1), 109 | nn.BatchNorm2d(base_width*8), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), 112 | nn.BatchNorm2d(base_width * 8), 113 | nn.ReLU(inplace=True) 114 | ) 115 | 116 | 117 | self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 118 | nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), 119 | nn.BatchNorm2d(base_width * 4), 120 | 
nn.ReLU(inplace=True)) 121 | self.db1 = nn.Sequential( 122 | nn.Conv2d(base_width*(4+8), base_width*4, kernel_size=3, padding=1), 123 | nn.BatchNorm2d(base_width*4), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), 126 | nn.BatchNorm2d(base_width * 4), 127 | nn.ReLU(inplace=True) 128 | ) 129 | 130 | self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 131 | nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), 132 | nn.BatchNorm2d(base_width * 2), 133 | nn.ReLU(inplace=True)) 134 | self.db2 = nn.Sequential( 135 | nn.Conv2d(base_width*(2+4), base_width*2, kernel_size=3, padding=1), 136 | nn.BatchNorm2d(base_width*2), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), 139 | nn.BatchNorm2d(base_width * 2), 140 | nn.ReLU(inplace=True) 141 | ) 142 | 143 | self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 144 | nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), 145 | nn.BatchNorm2d(base_width), 146 | nn.ReLU(inplace=True)) 147 | self.db3 = nn.Sequential( 148 | nn.Conv2d(base_width*(2+1), base_width, kernel_size=3, padding=1), 149 | nn.BatchNorm2d(base_width), 150 | nn.ReLU(inplace=True), 151 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 152 | nn.BatchNorm2d(base_width), 153 | nn.ReLU(inplace=True) 154 | ) 155 | 156 | self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 157 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 158 | nn.BatchNorm2d(base_width), 159 | nn.ReLU(inplace=True)) 160 | self.db4 = nn.Sequential( 161 | nn.Conv2d(base_width*2, base_width, kernel_size=3, padding=1), 162 | nn.BatchNorm2d(base_width), 163 | nn.ReLU(inplace=True), 164 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 165 | nn.BatchNorm2d(base_width), 166 | nn.ReLU(inplace=True) 167 | ) 168 | 169 | 
170 | 171 | self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) 172 | 173 | def forward(self, b1,b2,b3,b4,b5,b6): 174 | up_b = self.up_b(b6) 175 | cat_b = torch.cat((up_b,b5),dim=1) 176 | db_b = self.db_b(cat_b) 177 | 178 | up1 = self.up1(db_b) 179 | cat1 = torch.cat((up1,b4),dim=1) 180 | db1 = self.db1(cat1) 181 | 182 | up2 = self.up2(db1) 183 | cat2 = torch.cat((up2,b3),dim=1) 184 | db2 = self.db2(cat2) 185 | 186 | up3 = self.up3(db2) 187 | cat3 = torch.cat((up3,b2),dim=1) 188 | db3 = self.db3(cat3) 189 | 190 | up4 = self.up4(db3) 191 | cat4 = torch.cat((up4,b1),dim=1) 192 | db4 = self.db4(cat4) 193 | 194 | out = self.fin_out(db4) 195 | return out 196 | 197 | 198 | 199 | class EncoderReconstructive(nn.Module): 200 | def __init__(self, in_channels, base_width): 201 | super(EncoderReconstructive, self).__init__() 202 | self.block1 = nn.Sequential( 203 | nn.Conv2d(in_channels,base_width, kernel_size=3, padding=1), 204 | nn.BatchNorm2d(base_width), 205 | nn.ReLU(inplace=True), 206 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 207 | nn.BatchNorm2d(base_width), 208 | nn.ReLU(inplace=True)) 209 | self.mp1 = nn.Sequential(nn.MaxPool2d(2)) 210 | self.block2 = nn.Sequential( 211 | nn.Conv2d(base_width,base_width*2, kernel_size=3, padding=1), 212 | nn.BatchNorm2d(base_width*2), 213 | nn.ReLU(inplace=True), 214 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1), 215 | nn.BatchNorm2d(base_width*2), 216 | nn.ReLU(inplace=True)) 217 | self.mp2 = nn.Sequential(nn.MaxPool2d(2)) 218 | self.block3 = nn.Sequential( 219 | nn.Conv2d(base_width*2,base_width*4, kernel_size=3, padding=1), 220 | nn.BatchNorm2d(base_width*4), 221 | nn.ReLU(inplace=True), 222 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1), 223 | nn.BatchNorm2d(base_width*4), 224 | nn.ReLU(inplace=True)) 225 | self.mp3 = nn.Sequential(nn.MaxPool2d(2)) 226 | self.block4 = nn.Sequential( 227 | nn.Conv2d(base_width*4,base_width*8, 
kernel_size=3, padding=1), 228 | nn.BatchNorm2d(base_width*8), 229 | nn.ReLU(inplace=True), 230 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 231 | nn.BatchNorm2d(base_width*8), 232 | nn.ReLU(inplace=True)) 233 | self.mp4 = nn.Sequential(nn.MaxPool2d(2)) 234 | self.block5 = nn.Sequential( 235 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1), 236 | nn.BatchNorm2d(base_width*8), 237 | nn.ReLU(inplace=True), 238 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 239 | nn.BatchNorm2d(base_width*8), 240 | nn.ReLU(inplace=True)) 241 | 242 | 243 | def forward(self, x): 244 | b1 = self.block1(x) 245 | mp1 = self.mp1(b1) 246 | b2 = self.block2(mp1) 247 | mp2 = self.mp3(b2) 248 | b3 = self.block3(mp2) 249 | mp3 = self.mp3(b3) 250 | b4 = self.block4(mp3) 251 | mp4 = self.mp4(b4) 252 | b5 = self.block5(mp4) 253 | return b5 254 | 255 | 256 | class DecoderReconstructive(nn.Module): 257 | def __init__(self, base_width, out_channels=1): 258 | super(DecoderReconstructive, self).__init__() 259 | 260 | self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 261 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), 262 | nn.BatchNorm2d(base_width * 8), 263 | nn.ReLU(inplace=True)) 264 | self.db1 = nn.Sequential( 265 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1), 266 | nn.BatchNorm2d(base_width*8), 267 | nn.ReLU(inplace=True), 268 | nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), 269 | nn.BatchNorm2d(base_width * 4), 270 | nn.ReLU(inplace=True) 271 | ) 272 | 273 | self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 274 | nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), 275 | nn.BatchNorm2d(base_width * 4), 276 | nn.ReLU(inplace=True)) 277 | self.db2 = nn.Sequential( 278 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1), 279 | nn.BatchNorm2d(base_width*4), 280 | 
nn.ReLU(inplace=True), 281 | nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), 282 | nn.BatchNorm2d(base_width * 2), 283 | nn.ReLU(inplace=True) 284 | ) 285 | 286 | self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 287 | nn.Conv2d(base_width * 2, base_width*2, kernel_size=3, padding=1), 288 | nn.BatchNorm2d(base_width*2), 289 | nn.ReLU(inplace=True)) 290 | # cat with base*1 291 | self.db3 = nn.Sequential( 292 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1), 293 | nn.BatchNorm2d(base_width*2), 294 | nn.ReLU(inplace=True), 295 | nn.Conv2d(base_width*2, base_width*1, kernel_size=3, padding=1), 296 | nn.BatchNorm2d(base_width*1), 297 | nn.ReLU(inplace=True) 298 | ) 299 | 300 | self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 301 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 302 | nn.BatchNorm2d(base_width), 303 | nn.ReLU(inplace=True)) 304 | self.db4 = nn.Sequential( 305 | nn.Conv2d(base_width*1, base_width, kernel_size=3, padding=1), 306 | nn.BatchNorm2d(base_width), 307 | nn.ReLU(inplace=True), 308 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), 309 | nn.BatchNorm2d(base_width), 310 | nn.ReLU(inplace=True) 311 | ) 312 | 313 | self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) 314 | #self.fin_out = nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1) 315 | 316 | def forward(self, b5): 317 | up1 = self.up1(b5) 318 | db1 = self.db1(up1) 319 | 320 | up2 = self.up2(db1) 321 | db2 = self.db2(up2) 322 | 323 | up3 = self.up3(db2) 324 | db3 = self.db3(up3) 325 | 326 | up4 = self.up4(db3) 327 | db4 = self.db4(up4) 328 | 329 | out = self.fin_out(db4) 330 | return out -------------------------------------------------------------------------------- /seg_network/perlin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
math 3 | import numpy as np 4 | 5 | def lerp_np(x,y,w): 6 | fin_out = (y-x)*w + x 7 | return fin_out 8 | 9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 10 | noise = np.zeros(shape) 11 | frequency = 1 12 | amplitude = 1 13 | for _ in range(octaves): 14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 15 | frequency *= 2 16 | amplitude *= persistence 17 | return noise 18 | 19 | 20 | def generate_perlin_noise_2d(shape, res): 21 | def f(t): 22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 23 | 24 | delta = (res[0] / shape[0], res[1] / shape[1]) 25 | d = (shape[0] // res[0], shape[1] // res[1]) 26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 27 | # Gradients 28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 29 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 34 | # Ramps 35 | n00 = np.sum(grid * g00, 2) 36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 39 | # Interpolation 40 | t = f(grid) 41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 44 | 45 | 46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 47 | delta = (res[0] / shape[0], res[1] / shape[1]) 48 | d = (shape[0] // res[0], shape[1] // res[1]) 49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 50 | 51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 52 | 
gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 54 | 55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 56 | dot = lambda grad, shift: ( 57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 59 | 60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 64 | t = fade(grid[:shape[0], :shape[1]]) 65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 66 | 67 | 68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 69 | delta = (res[0] / shape[0], res[1] / shape[1]) 70 | d = (shape[0] // res[0], shape[1] // res[1]) 71 | 72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 75 | 76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 77 | 0).repeat_interleave( 78 | d[1], 1) 79 | dot = lambda grad, shift: ( 80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 82 | 83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 84 | 85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 88 | t = fade(grid[:shape[0], :shape[1]]) 89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, 
n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 90 | 91 | 92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 93 | noise = torch.zeros(shape) 94 | frequency = 1 95 | amplitude = 1 96 | for _ in range(octaves): 97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) 98 | frequency *= 2 99 | amplitude *= persistence 100 | return noise -------------------------------------------------------------------------------- /seg_network/tensorboard_visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.tensorboard import SummaryWriter 4 | from torchvision import datasets, transforms 5 | import os 6 | # Writer will output to ./runs/ directory by default 7 | 8 | class TensorboardVisualizer(): 9 | 10 | def __init__(self,log_dir='./logs/'): 11 | if not os.path.exists(log_dir): 12 | os.makedirs(log_dir) 13 | self.writer = SummaryWriter(log_dir=log_dir) 14 | 15 | def visualize_image_batch(self,image_batch,n_iter,image_name='Image_batch'): 16 | grid = torchvision.utils.make_grid(image_batch) 17 | self.writer.add_image(image_name,grid,n_iter) 18 | 19 | def plot_loss(self, loss_val, n_iter, loss_name='loss'): 20 | self.writer.add_scalar(loss_name, loss_val, n_iter) 21 | 22 | -------------------------------------------------------------------------------- /seg_network/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.diffusion import Model 4 | from models.ema import EMAHelper 5 | from omegaconf import OmegaConf 6 | from PIL import Image 7 | from data_loader import MVTecDRAEMTestDataset, MVTecDRAEMTrainDataset 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | from sklearn.metrics import roc_auc_score, average_precision_score 11 | from model_unet import ReconstructiveSubNetwork, 
DiscriminativeSubNetwork
import os
import cv2
from rec_network.util import instantiate_from_config
from rec_network.models.diffusion.ddim import DDIMSampler


def write_results_to_file(run_name, image_auc, pixel_auc, image_ap, pixel_ap):
    """Append a CSV-style metrics summary to ``./outputs/results.txt``.

    Each metric may be a single score or a sequence of scores. For every metric
    one line ``<label>,<run_name>,<score>...,<mean>`` is written, with values
    rounded to 3 decimals, followed by a dashed separator line.
    """
    os.makedirs('./outputs/', exist_ok=True)

    rows = (("img_auc", image_auc), ("pixel_auc", pixel_auc),
            ("img_ap", image_ap), ("pixel_ap", pixel_ap))
    fin_str = ""
    for label, scores in rows:
        # atleast_1d accepts the bare floats test() passes as well as lists;
        # iterating a float directly raised TypeError.
        scores = np.atleast_1d(scores)
        fields = [label + "," + run_name]
        fields.extend(str(np.round(score, 3)) for score in scores)
        fields.append(str(np.round(np.mean(scores), 3)))
        fin_str += ",".join(fields) + "\n"
    fin_str += "--------------------------\n"

    with open("./outputs/results.txt", 'a+') as file:
        file.write(fin_str)


def _seg_checkpoint_file(checkpoint_path, run_name):
    """Return the segmentation checkpoint path for ``run_name``.

    train.py saves ``run_name + "seg.pckl"`` (its ``run_name`` already ends with
    '_'); this script previously looked for ``run_name + "_seg.pckl"``. The
    train.py name is preferred and the legacy name is used as a fallback.
    """
    candidates = [os.path.join(checkpoint_path, run_name + "seg.pckl"),
                  os.path.join(checkpoint_path, run_name + "_seg.pckl")]
    for path in candidates:
        if os.path.exists(path):
            return path
    # Neither exists: return the train.py name so torch.load reports it.
    return candidates[0]


def test(obj_name, mvtec_path, checkpoint_path, base_model_name):
    """Evaluate the segmentation network for ``obj_name`` and record the metrics.

    Each test image is encoded by ``model.cond_stage_model`` and sampled with
    10 DDIM steps. Two reconstructions are decoded: the DDIM sample, and its
    average with the input latent. Both reconstructions and the input image are
    concatenated and segmented. The image score is the max of the 21x21
    box-averaged anomaly map; pixel scores are the raw map. Image/pixel AUROC
    and AP are printed and appended to ./outputs/results.txt, and preview
    images are written to ./samples/<obj_name>/test/.
    """
    img_dim = 256
    run_name = base_model_name + "_" + obj_name + '_'

    config = OmegaConf.load("../configs/mvtec.yaml")
    model = instantiate_from_config(config.model)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    model_seg = DiscriminativeSubNetwork(in_channels=9, out_channels=2)
    model_seg.load_state_dict(torch.load(_seg_checkpoint_file(checkpoint_path, run_name), map_location=device))
    model_seg.to(device)
    model_seg.eval()

    sample_dir = os.path.join('./samples', obj_name, 'test')
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(sample_dir, exist_ok=True)

    def sample_path(prefix, index):
        return os.path.join(sample_dir, prefix + str(index) + '.jpg')

    dataset = MVTecDRAEMTestDataset(mvtec_path + obj_name + "/test", resize_shape=[img_dim, img_dim])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    pixels_per_image = img_dim * img_dim
    total_pixel_scores = np.zeros((pixels_per_image * len(dataset)))
    total_gt_pixel_scores = np.zeros((pixels_per_image * len(dataset)))
    mask_cnt = 0

    anomaly_score_gt = []
    anomaly_score_prediction = []

    cnt_display = 0

    # Pure inference: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(dataloader):
            gray_batch = sample_batched["image"].to(device)

            has_anomaly = sample_batched["has_anomaly"].detach().numpy()[0, 0]
            anomaly_score_gt.append(has_anomaly)
            true_mask = sample_batched["mask"]
            true_mask_cv = true_mask.detach().numpy()[0, :, :, :].transpose((1, 2, 0))

            c = model.cond_stage_model.encode(gray_batch)
            c = c.mode()
            # noise / t / c_noisy only matter if conditioning=c_noisy is chosen
            # below; they are kept so the random stream matches earlier runs.
            noise = torch.randn_like(c)
            t = torch.randint(400, 500, (c.shape[0],), device=device).long()
            c_noisy = model.q_sample(x_start=c, t=t, noise=noise)

            shape = c.shape[1:]
            samples_ddim, _ = sampler.sample(S=10,
                                             conditioning=c,  # or conditioning=c_noisy
                                             batch_size=c.shape[0],
                                             shape=shape,
                                             verbose=False)
            gray_rec = model.decode_first_stage(samples_ddim)

            # Second, milder reconstruction halfway between sample and input latent.
            samples_ddim1 = 0.5 * samples_ddim + 0.5 * c
            gray_rec1 = model.decode_first_stage(samples_ddim1)

            joined_in = torch.cat((gray_rec.detach(), gray_rec1.detach(), gray_batch), dim=1)

            out_mask = model_seg(joined_in)
            out_mask_sm = torch.softmax(out_mask, dim=1)
            # Channel 1 of the 2-channel softmax is used as the anomaly map.
            t_mask = out_mask_sm[:, 1:, :, :]

            cv2.imwrite(sample_path('rec_images', cnt_display),
                        gray_rec.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255)
            cv2.imwrite(sample_path('gt_images', cnt_display),
                        gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255)
            cv2.imwrite(sample_path('out_masks', cnt_display),
                        t_mask[0].detach().cpu().numpy()[0] * 255)
            cv2.imwrite(sample_path('in_masks', cnt_display),
                        true_mask[0].detach().cpu().numpy()[0] * 255)

            heatmap = t_mask[0].detach().cpu().numpy()[0]
            heatmap_max = np.max(heatmap)
            if heatmap_max > 0:
                # Normalise to [0, 1]; an all-zero map would divide by zero.
                heatmap = heatmap / heatmap_max
            heatmap = np.uint8(255 * heatmap)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            show = heatmap * 0.5 + gray_batch.detach().cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
            cv2.imwrite(sample_path('heatmap', cnt_display), show)
            cnt_display += 1

            out_mask_cv = out_mask_sm[0, 1, :, :].detach().cpu().numpy()
            out_mask_averaged = torch.nn.functional.avg_pool2d(out_mask_sm[:, 1:, :, :], 21, stride=1,
                                                               padding=21 // 2).cpu().detach().numpy()
            image_score = np.max(out_mask_averaged)
            anomaly_score_prediction.append(image_score)

            flat_true_mask = true_mask_cv.flatten()
            flat_out_mask = out_mask_cv.flatten()
            start = mask_cnt * pixels_per_image
            total_pixel_scores[start:start + pixels_per_image] = flat_out_mask
            total_gt_pixel_scores[start:start + pixels_per_image] = flat_true_mask
            mask_cnt += 1

    anomaly_score_prediction = np.array(anomaly_score_prediction)
    anomaly_score_gt = np.array(anomaly_score_gt)
    auroc = roc_auc_score(anomaly_score_gt, anomaly_score_prediction)
    ap = average_precision_score(anomaly_score_gt, anomaly_score_prediction)

    total_gt_pixel_scores = total_gt_pixel_scores.astype(np.uint8)
    total_gt_pixel_scores = total_gt_pixel_scores[:pixels_per_image * mask_cnt]
    total_pixel_scores = total_pixel_scores[:pixels_per_image * mask_cnt]
    auroc_pixel = roc_auc_score(total_gt_pixel_scores, total_pixel_scores)
    ap_pixel = average_precision_score(total_gt_pixel_scores, total_pixel_scores)

    print(obj_name)
    print("AUC Image: " + str(auroc))
    print("AP Image: " + str(ap))
    print("AUC Pixel: " + str(auroc_pixel))
    print("AP Pixel: " + str(ap_pixel))

    write_results_to_file(run_name, auroc, auroc_pixel, ap, ap_pixel)


if __name__ == "__main__":
    import argparse

    obj_list = ['capsule',
                'bottle',
                'carpet',
                'leather',
                'pill',
                'transistor',
                'tile',
                'cable',
                'zipper',
                'toothbrush',
                'metal_nut',
                'hazelnut',
                'screw',
                'grid',
                'wood'
                ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', action='store', type=int, required=True)
    parser.add_argument('--base_model_name', action='store', type=str, required=True)
    parser.add_argument('--data_path', action='store', type=str, required=True)
    parser.add_argument('--checkpoint_path', action='store', type=str, required=True)
    parser.add_argument('--obj_name', action='store', type=str, default='bottle',
                        help="MVTec category to evaluate, e.g. one of: " + ", ".join(obj_list))
    parser.add_argument('--ema', action='store_true')

    parser.add_argument("--logit_transform", default=False)
    parser.add_argument("--uniform_dequantization", default=False)
    parser.add_argument("--gaussian_dequantization", default=False)
    parser.add_argument("--random_flip", default=True)
    parser.add_argument("--rescaled", default=True)
    parser.add_argument("--sample_type", type=str, default="generalized",
                        help="sampling approach (generalized or ddpm_noisy)")
    parser.add_argument("--skip_type", type=str, default="uniform", help="skip according to (uniform or quadratic)")

    args = parser.parse_args()

    with torch.cuda.device(args.gpu_id):
        test(args.obj_name, args.data_path, args.checkpoint_path, args.base_model_name)
-------------------------------------------------------------------------------- /seg_network/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.diffusion import Model 3 | from models.ema import EMAHelper 4 | from omegaconf import OmegaConf 5 | from data_loader import MVTecDRAEMTrainDataset 6 | from torch.utils.data import DataLoader 7 | from torch import optim 8 | from tensorboard_visualizer import TensorboardVisualizer 9 | from model_unet import DiscriminativeSubNetwork 10 | from loss import FocalLoss, SSIM 11 | from rec_network.util import instantiate_from_config 12 | from rec_network.models.diffusion.ddim import DDIMSampler 13 | import cv2 14 | import os 15 | 16 | 17 | def get_lr(optimizer): 18 | for param_group in optimizer.param_groups: 19 | return param_group['lr'] 20 | 21 | 22 | def weights_init(m): 23 | classname = m.__class__.__name__ 24 | if classname.find('Conv') != -1: 25 | m.weight.data.normal_(0.0, 0.02) 26 | elif classname.find('BatchNorm') != -1: 27 | m.weight.data.normal_(1.0, 0.02) 28 | m.bias.data.fill_(0) 29 | 30 | 31 | def train_on_device(obj_name, args): 32 | if not os.path.exists(args.checkpoint_path): 33 | os.makedirs(args.checkpoint_path) 34 | 35 | if not os.path.exists(args.log_path): 36 | os.makedirs(args.log_path) 37 | 38 | run_name = 'DRAEM_test_' + str(args.lr) + '_' + str(args.epochs) + '_bs' + str(args.bs) + "_" + obj_name + '_' 39 | 40 | visualizer = TensorboardVisualizer(log_dir=os.path.join(args.log_path, run_name + "/")) 41 | 42 | config = OmegaConf.load("../configs/mvtec.yaml") 43 | 44 | model = instantiate_from_config(config.model) 45 | 46 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 47 | model = model.to(device) 48 | sampler = DDIMSampler(model) 49 | 50 | model_seg = DiscriminativeSubNetwork(in_channels=9, out_channels=2) 51 | model_seg.cuda() 52 | model_seg.apply(weights_init) 53 | 54 | optimizer = 
torch.optim.Adam([ 55 | {"params": model_seg.parameters(), "lr": args.lr}]) 56 | 57 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [args.epochs * 0.8, args.epochs * 0.9], gamma=0.2, 58 | last_epoch=-1) 59 | 60 | loss_focal = FocalLoss() 61 | 62 | dataset = MVTecDRAEMTrainDataset(args.data_path + obj_name + "/train/good/", args.anomaly_source_path, 63 | resize_shape=[256, 256]) 64 | 65 | dataloader = DataLoader(dataset, batch_size=args.bs, 66 | shuffle=True, num_workers=0) 67 | 68 | n_iter = 0 69 | for epoch in range(args.epochs): 70 | lr = scheduler.get_last_lr()[0] 71 | print("Epoch: " + str(epoch) + " Learning rate: " + str(lr)) 72 | for i_batch, sample_batched in enumerate(dataloader): 73 | gray_batch = sample_batched["image"].cuda() 74 | aug_gray_batch = sample_batched["augmented_image"].cuda() 75 | anomaly_mask = sample_batched["anomaly_mask"].cuda() 76 | 77 | c = model.cond_stage_model.encode(aug_gray_batch) 78 | c = c.mode() 79 | 80 | shape = c.shape[1:] 81 | noise = torch.randn_like(c) 82 | t = torch.randint(400, 500, (c.shape[0],), device=device).long() 83 | c_noisy = model.q_sample(x_start=c, t=t, noise=noise) 84 | samples_ddim, _ = sampler.sample(S=50, 85 | conditioning=c, # or conditioning=c_noisy 86 | batch_size=c.shape[0], 87 | shape=shape, 88 | verbose=False) 89 | gray_rec = model.decode_first_stage(samples_ddim) 90 | 91 | samples_ddim1 = 0.5 * samples_ddim + 0.5 * c 92 | gray_rec1 = model.decode_first_stage(samples_ddim1) 93 | 94 | joined_in = torch.cat((gray_rec, gray_rec1, aug_gray_batch), dim=1) 95 | 96 | out_mask = model_seg(joined_in) 97 | out_mask_sm = torch.softmax(out_mask, dim=1) 98 | 99 | segment_loss = loss_focal(out_mask_sm, anomaly_mask) 100 | loss = segment_loss 101 | 102 | optimizer.zero_grad() 103 | 104 | loss.backward() 105 | optimizer.step() 106 | 107 | if args.visualize and n_iter % 200 == 0: 108 | visualizer.plot_loss(segment_loss, n_iter, loss_name='segment_loss') 109 | if args.visualize and n_iter % 100 == 0: 110 | 
t_mask = out_mask_sm[:, 1:, :, :] 111 | 112 | outpath = os.path.join('./samples',obj_name,'train', 'batch_augmented' + str(n_iter) + '.jpg') 113 | sample = aug_gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 114 | cv2.imwrite(outpath, sample) 115 | 116 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_target' + str(n_iter) + '.jpg') 117 | sample = gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 118 | cv2.imwrite(outpath, sample) 119 | 120 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_out' + str(n_iter) + '.jpg') 121 | sample = gray_rec.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 122 | cv2.imwrite(outpath, sample) 123 | 124 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_inter' + str(n_iter) + '.jpg') 125 | sample = gray_rec1.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 126 | cv2.imwrite(outpath, sample) 127 | 128 | outpath = os.path.join('./samples',obj_name,'train', 'mask_target' + str(n_iter) + '.jpg') 129 | sample = anomaly_mask[0].detach().cpu().numpy()[0] * 255 130 | cv2.imwrite(outpath, sample) 131 | 132 | outpath = os.path.join('./samples',obj_name,'train', 'mask_out' + str(n_iter) + '.jpg') 133 | sample = t_mask[0].detach().cpu().numpy()[0] * 255 134 | cv2.imwrite(outpath, sample) 135 | 136 | n_iter += 1 137 | 138 | scheduler.step() 139 | 140 | torch.save(model_seg.state_dict(), os.path.join(args.checkpoint_path, run_name + "seg.pckl")) 141 | # if epoch % 100 == 0 : 142 | # torch.save(model_seg.state_dict(), 143 | # os.path.join(args.checkpoint_path, run_name + str(epoch) + "_seg.pckl")) 144 | 145 | 146 | if __name__ == "__main__": 147 | import argparse 148 | 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--obj_id', action='store', type=int, required=True) 151 | parser.add_argument('--bs', action='store', type=int, required=True) 152 | parser.add_argument('--lr', action='store', type=float, required=True) 153 | parser.add_argument('--epochs', action='store', 
type=int, required=True) 154 | parser.add_argument('--gpu_id', action='store', type=int, default=0, required=False) 155 | parser.add_argument('--data_path', action='store', type=str, required=True) 156 | parser.add_argument('--anomaly_source_path', action='store', type=str, required=True) 157 | parser.add_argument('--checkpoint_path', action='store', type=str, required=True) 158 | parser.add_argument('--log_path', action='store', type=str, required=True) 159 | parser.add_argument('--visualize', action='store_true') 160 | 161 | parser.add_argument("--type", type=str, default="simple") 162 | parser.add_argument("--in_channels", type=int, default=3) 163 | parser.add_argument("--out_ch", type=int, default=3) 164 | parser.add_argument("--ch", type=int, default=128) 165 | parser.add_argument("--ch_mult", type=list, default=[1, 1, 2, 2, 4, 4]) 166 | parser.add_argument("--num_res_blocks", type=int, default=2) 167 | parser.add_argument("--attn_resolutions", type=list, default=[16, ]) 168 | parser.add_argument("--dropout", type=float, default=0.0) 169 | parser.add_argument("--var_type", type=str, default="fixedsmall") 170 | parser.add_argument("--ema_rate", type=float, default=0.999) 171 | parser.add_argument('--ema', action='store_true') 172 | parser.add_argument('--resamp_with_conv', action='store_true') 173 | parser.add_argument("--num_diffusion_timesteps", type=int, default=1000) 174 | parser.add_argument('--ddim_log_path', action='store', type=str, default="../ddim-main/mvtec/logs") 175 | parser.add_argument("--timesteps", type=int, default=10, help="number of steps involved") 176 | parser.add_argument("--eta", type=float, default=0.0) 177 | 178 | parser.add_argument("--logit_transform", default=False) 179 | parser.add_argument("--uniform_dequantization", default=False) 180 | parser.add_argument("--gaussian_dequantization", default=False) 181 | parser.add_argument("--random_flip", default=True) 182 | parser.add_argument("--rescaled", default=True) 183 | 
parser.add_argument("--sample_type", type=str, default="generalized", 184 | help="sampling approach (generalized or ddpm_noisy)") 185 | parser.add_argument("--skip_type", type=str, default="uniform", help="skip according to (uniform or quadratic)") 186 | 187 | args = parser.parse_args() 188 | 189 | obj_name = 'bottle' 190 | # 'capsule' 191 | # 'carpet' 192 | # 'leather' 193 | # 'pill' 194 | # 'transistor' 195 | # 'tile' 196 | # 'cable' 197 | # 'zipper' 198 | # 'toothbrush' 199 | # 'metal_nut' 200 | # 'hazelnut' 201 | # 'screw' 202 | # 'grid' 203 | # 'wood' 204 | 205 | with torch.cuda.device(args.gpu_id): 206 | train_on_device(obj_name, args) 207 | 208 | --------------------------------------------------------------------------------