├── .gitignore
├── README.md
├── configs
│   ├── kl.yaml
│   └── mvtec.yaml
├── environment.yaml
├── rec_network
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mvtec.py
│   │   └── perlin.py
│   ├── lr_scheduler.py
│   ├── main.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       └── plms.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── scripts
│   ├── download_dataset.sh
│   └── mvtec.py
└── seg_network
    ├── data_loader.py
    ├── loss.py
    ├── model_unet.py
    ├── perlin.py
    ├── tensorboard_visualizer.py
    ├── test.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffAD
2 | [ICCV2023] Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model
3 |
4 | ```
5 | @inproceedings{zhang2023unsupervised,
6 | title={Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model},
7 | author={Zhang, Xinyi and Li, Naiqi and Li, Jiawei and Dai, Tao and Jiang, Yong and Xia, Shu-Tao},
8 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
9 | pages={6782--6791},
10 | year={2023}
11 | }
12 | ```
13 |
14 | ## Method overview
15 |
16 |
17 | ## Installation
18 | ```
19 | conda env create -f environment.yaml
20 | conda activate DiffAD
21 | ```
22 |
23 | ## Dataset
24 | Following DRAEM, we use the MVTec-AD and DTD datasets. Run the download_dataset.sh script from the project directory to download both datasets into the datasets folder:
25 | ```
26 | ./scripts/download_dataset.sh
27 | ```
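To sanity-check that everything landed where the code expects it, here is a small, optional snippet; the paths and glob patterns mirror configs/mvtec.yaml and rec_network/data/mvtec.py, and `obj_name` is just an example category:
```
import glob, os

obj_name = "bottle"  # any MVTec-AD category
train_good = f"./datasets/mvtec/{obj_name}/train/good"
dtd_images = "./datasets/dtd/images"

# The data loaders glob "<train_good>/*.png" and "<dtd_images>/*/*.jpg".
print(len(glob.glob(os.path.join(train_good, "*.png"))), "normal training images")
print(len(glob.glob(os.path.join(dtd_images, "*", "*.jpg"))), "anomaly-source images")
```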
28 |
29 | ## Training
30 | ### Reconstruction sub-network
31 | The reconstruction sub-network is based on the latent diffusion model (LDM).
32 | #### Training Auto-encoder
33 | ```
34 | cd rec_network
35 | CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/kl.yaml -t --gpus 0,
36 | ```
37 | #### Training LDMs
38 | ```
39 | CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/mvtec.yaml -t --gpus 0, --max_epochs 4000
40 | ```
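The UNet settings in configs/mvtec.yaml follow from the autoencoder: with `ch_mult: [1,2,4,4]` the first stage downsamples `len(ch_mult)-1 = 3` times, so a 256x256 image becomes a 32x32x4 latent (`image_size: 32`), and concatenating the conditioning latent doubles the UNet input channels to 8. A quick check of that arithmetic (values copied from the configs):
```
resolution, ch_mult, z_channels = 256, [1, 2, 4, 4], 4  # from configs/kl.yaml and mvtec.yaml
num_down = len(ch_mult) - 1                # 3 downsampling stages in the autoencoder
latent_size = resolution // 2 ** num_down  # 256 / 8 = 32 -> unet image_size
unet_in_channels = 2 * z_channels          # image latent + concatenated condition = 8
print(latent_size, unet_in_channels)       # 32 8
```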
41 |
42 | ### Discriminative sub-network
43 | ```
44 | cd seg_network
45 | CUDA_VISIBLE_DEVICES= python train.py --gpu_id 0 --lr 0.001 --bs 32 --epochs 700 --data_path ./datasets/mvtec/ --anomaly_source_path ./datasets/dtd/images/ --checkpoint_path ./checkpoints/obj_name --log_path ./logs/
46 | ```
47 |
48 | ## Evaluating
49 | ### Reconstruction performance
50 | After training the reconstruction sub-network, you can test its reconstruction performance on anomalous inputs:
51 | ```
52 | python scripts/mvtec.py
53 | ```
54 | For samples with severe deformations, such as missing transistors, you can add some noise to the anomalous condition to adjust the sampling, as sketched below.
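A minimal sketch of that trick, assuming `cond` is the encoded anomalous condition (the first-stage latent of the input image) before it is handed to the sampler; `noise_scale` is an illustrative knob, not a value fixed by the paper:
```
import torch

def perturb_condition(cond, noise_scale=0.1):
    # Blend a little Gaussian noise into the conditioning latent so the sampler is
    # less tightly bound to a severely deformed input (e.g. a missing transistor).
    return cond + noise_scale * torch.randn_like(cond)
```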
55 |
56 | ### Anomaly segmentation
57 | ```
58 | cd seg_network
59 | python test.py --gpu_id 0 --base_model_name "seg_network" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/obj_name/
60 | ```
61 |
62 |
63 |
--------------------------------------------------------------------------------
/configs/kl.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: rec_network.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: rec_network.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.5
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 8
30 | num_workers: 0
31 | wrap: false
32 | train:
33 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset
34 | params:
35 | root_dir: './datasets/mvtec/grid/train/good'
36 | anomaly_source_path: './datasets/dtd/images'
37 | resize_shape: [256, 256]
38 | validation:
39 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset
40 | params:
41 | root_dir: './datasets/mvtec/grid/train/good'
42 | anomaly_source_path: './datasets/dtd/images'
43 | resize_shape: [ 256, 256 ]
44 | lightning:
45 | callbacks:
46 | image_logger:
47 | target: main.ImageLogger
48 | params:
49 | batch_frequency: 1000
50 | max_images: 8
51 | increase_log_steps: True
52 |
53 | trainer:
54 | benchmark: True
55 | accumulate_grad_batches: 2
56 |
--------------------------------------------------------------------------------
/configs/mvtec.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: rec_network.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | #ckpt_path: modify the ckpt_path of rec_network when training the seg_network
6 | linear_start: 0.0015
7 | linear_end: 0.02
8 | num_timesteps_cond: 1
9 | log_every_t: 100
10 | timesteps: 1000
11 | first_stage_key: image
12 | cond_stage_key: augmented_image
13 | image_size: 32
14 | channels: 4
15 | concat_mode: true
16 | cond_stage_trainable: false
17 | conditioning_key: concat
18 | monitor: val/loss_simple_ema
19 | unet_config:
20 | target: rec_network.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 32
23 | in_channels: 8
24 | out_channels: 4
25 | model_channels: 256
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 4
35 | num_head_channels: 32
36 |
37 | first_stage_config:
38 | target: rec_network.models.autoencoder.AutoencoderKL
39 | params:
40 | embed_dim: 4
41 | monitor: "val/rec_loss"
42 | ckpt_path: "./VAE/bottle.ckpt" #TODO: modify the ckpt_path of VAE
43 | ddconfig:
44 | double_z: True
45 | z_channels: 4
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
51 | num_res_blocks: 2
52 | attn_resolutions: [ ]
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_first_stage__
57 |
58 | data:
59 | target: main.DataModuleFromConfig
60 | params:
61 | batch_size: 32
62 | num_workers: 0
63 | wrap: false
64 | train:
65 | target: rec_network.data.mvtec.MVTecDRAEMTrainDataset
66 | params:
67 | root_dir: './datasets/mvtec/bottle/train/good' #TODO: modify the path of training samples
68 | anomaly_source_path: './datasets/dtd/images'
69 | resize_shape:
70 | - 256
71 | - 256
72 |
73 |
74 | lightning:
75 | callbacks:
76 | metrics_over_trainsteps_checkpoint:
77 | target: pytorch_lightning.callbacks.ModelCheckpoint
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 1000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
87 |
88 |
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: DiffAD
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.4.2
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - torch-fidelity==0.3.0
24 | - transformers==4.3.1
25 | - imgaug
26 |
--------------------------------------------------------------------------------
/rec_network/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/data/__init__.py
--------------------------------------------------------------------------------
/rec_network/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
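# Illustrative usage sketch (not part of the repo): a concrete subclass only needs to
# implement __iter__, and instances can then be combined with torch's ChainDataset
# (imported above), e.g.:
#
#   class MyIterable(Txt2ImgIterableBaseDataset):
#       def __iter__(self):
#           return iter(range(self.num_records))
#
#   chained = ChainDataset([MyIterable(num_records=10), MyIterable(num_records=5)])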
--------------------------------------------------------------------------------
/rec_network/data/mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | import cv2
6 | import glob
7 | import imgaug.augmenters as iaa
8 | from perlin import rand_perlin_2d_np
9 |
10 | obj_list = ['capsule',
11 | 'bottle',
12 | 'carpet',
13 | 'leather',
14 | 'pill',
15 | 'transistor',
16 | 'tile',
17 | 'cable',
18 | 'zipper',
19 | 'toothbrush',
20 | 'metal_nut',
21 | 'hazelnut',
22 | 'screw',
23 | 'grid',
24 | 'wood']
25 |
26 | class MVTecDRAEMTestDataset(Dataset):
27 |
28 | def __init__(self, root_dir, resize_shape=None):
29 | self.root_dir = root_dir
30 | self.images = sorted(glob.glob(root_dir+"/*/*.png"))
31 |
32 | # self.images = []
33 | # for obj in obj_list:
34 | # self.images += sorted(glob.glob(root_dir + obj + "/test/*/*.png"))
35 |
36 | self.resize_shape=resize_shape
37 |
38 | def __len__(self):
39 | return len(self.images)
40 |
41 | def transform_image(self, image_path, mask_path):
42 | image = cv2.imread(image_path, cv2.IMREAD_COLOR)
43 | if mask_path is not None:
44 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
45 | else:
46 | mask = np.zeros((image.shape[0],image.shape[1]))
47 | if self.resize_shape != None:
48 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
49 | mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0]))
50 |
51 | image = image / 255.0
52 | mask = mask / 255.0
53 |
54 | image = np.array(image).reshape((image.shape[0], image.shape[1], 3)).astype(np.float32)
55 | mask = np.array(mask).reshape((mask.shape[0], mask.shape[1], 1)).astype(np.float32)
56 |
57 | image = np.transpose(image, (2, 0, 1))
58 | mask = np.transpose(mask, (2, 0, 1))
59 | return image, mask
60 |
61 | def __getitem__(self, idx):
62 | if torch.is_tensor(idx):
63 | idx = idx.tolist()
64 |
65 | img_path = self.images[idx]
66 | dir_path, file_name = os.path.split(img_path)
67 | base_dir = os.path.basename(dir_path)
68 | if base_dir == 'good':
69 | image, mask = self.transform_image(img_path, None)
70 | has_anomaly = np.array([0], dtype=np.float32)
71 | else:
72 | mask_path = os.path.join(dir_path, '../../ground_truth/')
73 | mask_path = os.path.join(mask_path, base_dir)
74 | mask_file_name = file_name.split(".")[0]+"_mask.png"
75 | mask_path = os.path.join(mask_path, mask_file_name)
76 | image, mask = self.transform_image(img_path, mask_path)
77 | has_anomaly = np.array([1], dtype=np.float32)
78 |
79 | sample = {'image': image, 'has_anomaly': has_anomaly,'mask': mask, 'idx': idx}
80 |
81 | return sample
82 |
83 |
84 |
85 | class MVTecDRAEMTrainDataset(Dataset):
86 |
87 | def __init__(self, root_dir, anomaly_source_path, resize_shape=None):
88 |         """
89 |         Args:
90 |             root_dir (string): Directory with the normal (defect-free) training images.
91 |             anomaly_source_path (string): Directory containing the anomaly-source (DTD) images.
92 |             resize_shape (list, optional): Target [height, width] to resize images to.
93 |         """
94 |
95 | self.root_dir = root_dir
96 | self.resize_shape=resize_shape
97 |
98 | # /data / zhangxinyi / dataset / mvtec / toothbrush / train / good
99 | # self.image_paths = []
100 | # for obj in obj_list:
101 | # obj = 'cable'
102 | # self.image_paths += sorted(glob.glob(root_dir + obj + "/train/good/*.png"))
103 | # print(len(self.image_paths))
104 |
105 | self.image_paths = sorted(glob.glob(root_dir + "/*.png"))
106 |
107 | self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path+"/*/*.jpg"))
108 |
109 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True),
110 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)),
111 | iaa.pillike.EnhanceSharpness(),
112 | iaa.AddToHueAndSaturation((-50,50),per_channel=True),
113 | iaa.Solarize(0.5, threshold=(32,128)),
114 | iaa.Posterize(),
115 | iaa.Invert(),
116 | iaa.pillike.Autocontrast(),
117 | iaa.pillike.Equalize(),
118 | iaa.Affine(rotate=(-45, 45))
119 | ]
120 |
121 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
122 |
123 |
124 | def __len__(self):
125 | return len(self.image_paths)
126 |
127 |
128 | def randAugmenter(self):
129 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
130 | aug = iaa.Sequential([self.augmenters[aug_ind[0]],
131 | self.augmenters[aug_ind[1]],
132 | self.augmenters[aug_ind[2]]]
133 | )
134 | return aug
135 |
136 | def augment_image(self, image, anomaly_source_path):
137 | aug = self.randAugmenter()
138 | perlin_scale = 6
139 | min_perlin_scale = 0
140 | anomaly_source_img = cv2.imread(anomaly_source_path)
141 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize_shape[1], self.resize_shape[0]))
142 |
143 | anomaly_img_augmented = aug(image=anomaly_source_img)
144 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
145 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
146 |
147 | perlin_noise = rand_perlin_2d_np((self.resize_shape[0], self.resize_shape[1]), (perlin_scalex, perlin_scaley))
148 | perlin_noise = self.rot(image=perlin_noise)
149 | threshold = 0.5
150 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
151 | perlin_thr = np.expand_dims(perlin_thr, axis=2)
152 |
153 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0
154 |
155 | beta = torch.rand(1).numpy()[0] * 0.8
156 |
157 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * (
158 | perlin_thr)
159 |
160 | no_anomaly = torch.rand(1).numpy()[0]
161 | if no_anomaly > 0.5:
162 | image = image.astype(np.float32)
163 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32)
164 | else:
165 | augmented_image = augmented_image.astype(np.float32)
166 | msk = (perlin_thr).astype(np.float32)
167 | augmented_image = msk * augmented_image + (1-msk)*image
168 | has_anomaly = 1.0
169 | if np.sum(msk) == 0:
170 | has_anomaly=0.0
171 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32)
172 |
173 | def transform_image(self, image_path, anomaly_source_path):
174 | image = cv2.imread(image_path)
175 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
176 |
177 | do_aug_orig = torch.rand(1).numpy()[0] > 0.7
178 | if do_aug_orig:
179 | image = self.rot(image=image)
180 |
181 | image = np.array(image).reshape((image.shape[0], image.shape[1], image.shape[2])).astype(np.float32) / 255.0
182 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path)
183 | augmented_image = np.transpose(augmented_image, (2, 0, 1))
184 | image = np.transpose(image, (2, 0, 1))
185 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1))
186 | return image, augmented_image, anomaly_mask, has_anomaly
187 |
188 | def __getitem__(self, idx):
189 | idx = torch.randint(0, len(self.image_paths), (1,)).item()
190 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
191 | image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(self.image_paths[idx],
192 | self.anomaly_source_paths[anomaly_source_idx])
193 | sample = {'image': image, "anomaly_mask": anomaly_mask,
194 | 'augmented_image': augmented_image, 'has_anomaly': has_anomaly, 'idx': idx}
195 | return sample
196 |
--------------------------------------------------------------------------------
/rec_network/data/perlin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | def lerp_np(x,y,w):
6 | fin_out = (y-x)*w + x
7 | return fin_out
8 |
9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5):
10 | noise = np.zeros(shape)
11 | frequency = 1
12 | amplitude = 1
13 | for _ in range(octaves):
14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1]))
15 | frequency *= 2
16 | amplitude *= persistence
17 | return noise
18 |
19 |
20 | def generate_perlin_noise_2d(shape, res):
21 | def f(t):
22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
23 |
24 | delta = (res[0] / shape[0], res[1] / shape[1])
25 | d = (shape[0] // res[0], shape[1] // res[1])
26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
27 | # Gradients
28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
29 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)
34 | # Ramps
35 | n00 = np.sum(grid * g00, 2)
36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
39 | # Interpolation
40 | t = f(grid)
41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
44 |
45 |
46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
47 | delta = (res[0] / shape[0], res[1] / shape[1])
48 | d = (shape[0] // res[0], shape[1] // res[1])
49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
50 |
51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1)
54 |
55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1)
56 | dot = lambda grad, shift: (
57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
59 |
60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
64 | t = fade(grid[:shape[0], :shape[1]])
65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
66 |
67 |
68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
69 | delta = (res[0] / shape[0], res[1] / shape[1])
70 | d = (shape[0] // res[0], shape[1] // res[1])
71 |
72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
75 |
76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
77 | 0).repeat_interleave(
78 | d[1], 1)
79 | dot = lambda grad, shift: (
80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
82 |
83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
84 |
85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
88 | t = fade(grid[:shape[0], :shape[1]])
89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
90 |
91 |
92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
93 | noise = torch.zeros(shape)
94 | frequency = 1
95 | amplitude = 1
96 | for _ in range(octaves):
97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
98 | frequency *= 2
99 | amplitude *= persistence
100 | return noise
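# Illustrative usage sketch (mirrors rec_network/data/mvtec.py): a 256x256 Perlin map,
# thresholded into the binary mask used to paste anomaly-source textures.
#
#   noise = rand_perlin_2d_np((256, 256), (8, 8))   # scales must divide the shape
#   mask = (noise > 0.5).astype(np.float32)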
--------------------------------------------------------------------------------
/rec_network/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
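# Usage sketch: these schedulers return an lr *multiplier* (hence the "base_lr of 1.0"
# note above) and are meant to be wrapped in torch.optim.lr_scheduler.LambdaLR, as in
# configure_optimizers of rec_network/models/autoencoder.py. Values below are illustrative.
#
#   sched = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=0.0, lr_max=1.0,
#                                       lr_start=0.0, max_decay_steps=100000)
#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched.schedule)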
--------------------------------------------------------------------------------
/rec_network/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 | from PIL import Image
6 | import numpy as np
7 | from torch.optim.lr_scheduler import LambdaLR  # used by configure_optimizers when a scheduler_config is set
8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9 | from rec_network.modules.ema import LitEma  # needed when use_ema=True (assumed location, as in latent-diffusion)
10 | from rec_network.modules.diffusionmodules.model import Encoder, Decoder
11 | from rec_network.modules.distributions.distributions import DiagonalGaussianDistribution
12 |
13 | from rec_network.util import instantiate_from_config
14 |
15 |
16 | class VQModel(pl.LightningModule):
17 | def __init__(self,
18 | ddconfig,
19 | lossconfig,
20 | n_embed,
21 | embed_dim,
22 | ckpt_path=None,
23 | ignore_keys=[],
24 | image_key="image",
25 | colorize_nlabels=None,
26 | monitor=None,
27 | batch_resize_range=None,
28 | scheduler_config=None,
29 | lr_g_factor=1.0,
30 | remap=None,
31 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
32 | use_ema=False
33 | ):
34 | super().__init__()
35 | self.embed_dim = embed_dim
36 | self.n_embed = n_embed
37 | self.image_key = image_key
38 | self.encoder = Encoder(**ddconfig)
39 | self.decoder = Decoder(**ddconfig)
40 | self.loss = instantiate_from_config(lossconfig)
41 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
42 | remap=remap,
43 | sane_index_shape=sane_index_shape)
44 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
45 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
46 | if colorize_nlabels is not None:
47 | assert type(colorize_nlabels)==int
48 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
49 | if monitor is not None:
50 | self.monitor = monitor
51 | self.batch_resize_range = batch_resize_range
52 | if self.batch_resize_range is not None:
53 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
54 |
55 | self.use_ema = use_ema
56 | if self.use_ema:
57 | self.model_ema = LitEma(self)
58 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
59 |
60 | if ckpt_path is not None:
61 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
62 | self.scheduler_config = scheduler_config
63 | self.lr_g_factor = lr_g_factor
64 |
65 | @contextmanager
66 | def ema_scope(self, context=None):
67 | if self.use_ema:
68 | self.model_ema.store(self.parameters())
69 | self.model_ema.copy_to(self)
70 | if context is not None:
71 | print(f"{context}: Switched to EMA weights")
72 | try:
73 | yield None
74 | finally:
75 | if self.use_ema:
76 | self.model_ema.restore(self.parameters())
77 | if context is not None:
78 | print(f"{context}: Restored training weights")
79 |
80 | def init_from_ckpt(self, path, ignore_keys=list()):
81 | sd = torch.load(path, map_location="cpu")["state_dict"]
82 | keys = list(sd.keys())
83 | for k in keys:
84 | for ik in ignore_keys:
85 | if k.startswith(ik):
86 | print("Deleting key {} from state_dict.".format(k))
87 | del sd[k]
88 | missing, unexpected = self.load_state_dict(sd, strict=False)
89 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
90 | if len(missing) > 0:
91 | print(f"Missing Keys: {missing}")
92 | print(f"Unexpected Keys: {unexpected}")
93 |
94 | def on_train_batch_end(self, *args, **kwargs):
95 | if self.use_ema:
96 | self.model_ema(self)
97 |
98 | def encode(self, x):
99 | h = self.encoder(x)
100 | h = self.quant_conv(h)
101 | quant, emb_loss, info = self.quantize(h)
102 | return quant, emb_loss, info
103 |
104 | def encode_to_prequant(self, x):
105 | h = self.encoder(x)
106 | h = self.quant_conv(h)
107 | return h
108 |
109 | def decode(self, quant):
110 | quant = self.post_quant_conv(quant)
111 | dec = self.decoder(quant)
112 | return dec
113 |
114 | def decode_code(self, code_b):
115 | quant_b = self.quantize.embed_code(code_b)
116 | dec = self.decode(quant_b)
117 | return dec
118 |
119 | def forward(self, input, return_pred_indices=False):
120 | quant, diff, (_,_,ind) = self.encode(input)
121 | dec = self.decode(quant)
122 | if return_pred_indices:
123 | return dec, diff, ind
124 | return dec, diff
125 |
126 | def get_input(self, batch, k):
127 | x = batch[k]
128 | if len(x.shape) == 3:
129 | x = x[..., None]
130 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
131 | x = x.to(memory_format=torch.contiguous_format).float()
132 | if self.batch_resize_range is not None:
133 | lower_size = self.batch_resize_range[0]
134 | upper_size = self.batch_resize_range[1]
135 | if self.global_step <= 4:
136 | # do the first few batches with max size to avoid later oom
137 | new_resize = upper_size
138 | else:
139 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
140 | if new_resize != x.shape[2]:
141 | x = F.interpolate(x, size=new_resize, mode="bicubic")
142 | x = x.detach()
143 | return x
144 |
145 | def training_step(self, batch, batch_idx, optimizer_idx):
146 | # https://github.com/pytorch/pytorch/issues/37142
147 | # try not to fool the heuristics
148 | x = self.get_input(batch, self.image_key)
149 |
150 | # outpath = './test/input.jpg'
151 | # sample = x.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
152 | # Image.fromarray(sample.astype(np.uint8)).save(outpath)
153 |
154 | xrec, qloss, ind = self(x, return_pred_indices=True)
155 |
156 | if optimizer_idx == 0:
157 | # autoencode
158 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
159 | last_layer=self.get_last_layer(), split="train"
160 | )
161 |
162 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
163 | return aeloss
164 |
165 | if optimizer_idx == 1:
166 | # discriminator
167 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
168 | last_layer=self.get_last_layer(), split="train")
169 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
170 | return discloss
171 |
172 | def validation_step(self, batch, batch_idx):
173 | log_dict = self._validation_step(batch, batch_idx)
174 | with self.ema_scope():
175 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
176 | return log_dict
177 |
178 | def _validation_step(self, batch, batch_idx, suffix=""):
179 | x = self.get_input(batch, self.image_key)
180 | xrec, qloss, ind = self(x, return_pred_indices=True)
181 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
182 | self.global_step,
183 | last_layer=self.get_last_layer(),
184 | split="val"+suffix
185 | )
186 |
187 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
188 | self.global_step,
189 | last_layer=self.get_last_layer(),
190 | split="val"+suffix
191 | )
192 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
193 | self.log(f"val{suffix}/rec_loss", rec_loss,
194 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
195 | self.log(f"val{suffix}/aeloss", aeloss,
196 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
197 |
198 | del log_dict_ae[f"val{suffix}/rec_loss"]
199 | self.log_dict(log_dict_ae)
200 | self.log_dict(log_dict_disc)
201 | return self.log_dict
202 |
203 | def configure_optimizers(self):
204 | lr_d = self.learning_rate
205 | lr_g = self.lr_g_factor*self.learning_rate
206 | print("lr_d", lr_d)
207 | print("lr_g", lr_g)
208 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
209 | list(self.decoder.parameters())+
210 | list(self.quantize.parameters())+
211 | list(self.quant_conv.parameters())+
212 | list(self.post_quant_conv.parameters()),
213 | lr=lr_g, betas=(0.5, 0.9))
214 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
215 | lr=lr_d, betas=(0.5, 0.9))
216 |
217 | if self.scheduler_config is not None:
218 | scheduler = instantiate_from_config(self.scheduler_config)
219 |
220 | print("Setting up LambdaLR scheduler...")
221 | scheduler = [
222 | {
223 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
224 | 'interval': 'step',
225 | 'frequency': 1
226 | },
227 | {
228 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
229 | 'interval': 'step',
230 | 'frequency': 1
231 | },
232 | ]
233 | return [opt_ae, opt_disc], scheduler
234 | return [opt_ae, opt_disc], []
235 |
236 | def get_last_layer(self):
237 | return self.decoder.conv_out.weight
238 |
239 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
240 | log = dict()
241 | x = self.get_input(batch, self.image_key)
242 | x = x.to(self.device)
243 | if only_inputs:
244 | log["inputs"] = x
245 | return log
246 | xrec, _ = self(x)
247 | if x.shape[1] > 3:
248 | # colorize with random projection
249 | assert xrec.shape[1] > 3
250 | x = self.to_rgb(x)
251 | xrec = self.to_rgb(xrec)
252 | log["inputs"] = x
253 | log["reconstructions"] = xrec
254 | if plot_ema:
255 | with self.ema_scope():
256 | xrec_ema, _ = self(x)
257 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
258 | log["reconstructions_ema"] = xrec_ema
259 | return log
260 |
261 | def to_rgb(self, x):
262 | assert self.image_key == "segmentation"
263 | if not hasattr(self, "colorize"):
264 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
265 | x = F.conv2d(x, weight=self.colorize)
266 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
267 | return x
268 |
269 |
270 | class VQModelInterface(VQModel):
271 | def __init__(self, embed_dim, *args, **kwargs):
272 | super().__init__(embed_dim=embed_dim, *args, **kwargs)
273 | self.embed_dim = embed_dim
274 |
275 | def encode(self, x):
276 | h = self.encoder(x)
277 | h = self.quant_conv(h)
278 | # print("!!!!encoding: ", h.shape)
279 | return h
280 |
281 | def decode(self, h, force_not_quantize=False):
282 | # also go through quantization layer
283 | if not force_not_quantize:
284 | quant, emb_loss, info = self.quantize(h)
285 | else:
286 | quant = h
287 | quant = self.post_quant_conv(quant)
288 | dec = self.decoder(quant)
289 | return dec
290 |
291 |
292 | class AutoencoderKL(pl.LightningModule):
293 | def __init__(self,
294 | ddconfig,
295 | lossconfig,
296 | embed_dim,
297 | ckpt_path=None,
298 | ignore_keys=[],
299 | image_key="image",
300 | colorize_nlabels=None,
301 | monitor=None,
302 | ):
303 | super().__init__()
304 | self.image_key = image_key
305 | self.encoder = Encoder(**ddconfig)
306 | self.decoder = Decoder(**ddconfig)
307 | self.loss = instantiate_from_config(lossconfig)
308 | assert ddconfig["double_z"]
309 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
310 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
311 | self.embed_dim = embed_dim
312 | if colorize_nlabels is not None:
313 | assert type(colorize_nlabels)==int
314 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
315 | if monitor is not None:
316 | self.monitor = monitor
317 | if ckpt_path is not None:
318 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
319 |
320 | def init_from_ckpt(self, path, ignore_keys=list()):
321 | sd = torch.load(path, map_location="cpu")["state_dict"]
322 | keys = list(sd.keys())
323 | for k in keys:
324 | for ik in ignore_keys:
325 | if k.startswith(ik):
326 | print("Deleting key {} from state_dict.".format(k))
327 | del sd[k]
328 | self.load_state_dict(sd, strict=False)
329 | print(f"Restored from {path}")
330 |
331 | def encode(self, x):
332 | h = self.encoder(x)
333 | moments = self.quant_conv(h)
334 | posterior = DiagonalGaussianDistribution(moments)
335 | return posterior
336 |
337 | def decode(self, z):
338 | z = self.post_quant_conv(z)
339 | dec = self.decoder(z)
340 | return dec
341 |
342 | def forward(self, input, sample_posterior=True):
343 | posterior = self.encode(input)
344 | if sample_posterior:
345 | z = posterior.sample()
346 | else:
347 | z = posterior.mode()
348 | dec = self.decode(z)
349 | return dec, posterior
350 |
351 | def get_input(self, batch, k):
352 | x = batch[k]
353 | if len(x.shape) == 3:
354 | x = x[..., None]
355 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
356 | return x
357 |
358 | def training_step(self, batch, batch_idx, optimizer_idx):
359 | inputs = self.get_input(batch, self.image_key)
360 | reconstructions, posterior = self(inputs) ## forward
361 |
362 | if optimizer_idx == 0:
363 | # train encoder+decoder+logvar
364 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
365 | last_layer=self.get_last_layer(), split="train")
366 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
367 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
368 | return aeloss
369 |
370 | if optimizer_idx == 1:
371 | # train the discriminator
372 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
373 | last_layer=self.get_last_layer(), split="train")
374 |
375 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
376 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
377 | return discloss
378 |
379 | def validation_step(self, batch, batch_idx):
380 | inputs = self.get_input(batch, self.image_key)
381 | reconstructions, posterior = self(inputs)
382 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
383 | last_layer=self.get_last_layer(), split="val")
384 |
385 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
386 | last_layer=self.get_last_layer(), split="val")
387 |
388 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
389 | self.log_dict(log_dict_ae)
390 | self.log_dict(log_dict_disc)
391 | return self.log_dict
392 |
393 | def configure_optimizers(self):
394 | lr = self.learning_rate
395 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
396 | list(self.decoder.parameters())+
397 | list(self.quant_conv.parameters())+
398 | list(self.post_quant_conv.parameters()),
399 | lr=lr, betas=(0.5, 0.9))
400 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
401 | lr=lr, betas=(0.5, 0.9))
402 | return [opt_ae, opt_disc], []
403 |
404 | def get_last_layer(self):
405 | return self.decoder.conv_out.weight
406 |
407 | @torch.no_grad()
408 | def log_images(self, batch, only_inputs=False, **kwargs):
409 | log = dict()
410 | x = self.get_input(batch, self.image_key)
411 | x = x.to(self.device)
412 | if not only_inputs:
413 | xrec, posterior = self(x)
414 | if x.shape[1] > 3:
415 | # colorize with random projection
416 | assert xrec.shape[1] > 3
417 | x = self.to_rgb(x)
418 | xrec = self.to_rgb(xrec)
419 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
420 | log["reconstructions"] = xrec
421 | log["inputs"] = x
422 | return log
423 |
424 | def to_rgb(self, x):
425 | assert self.image_key == "segmentation"
426 | if not hasattr(self, "colorize"):
427 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
428 | x = F.conv2d(x, weight=self.colorize)
429 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
430 | return x
431 |
432 |
433 | class IdentityFirstStage(torch.nn.Module):
434 | def __init__(self, *args, vq_interface=False, **kwargs):
435 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
436 | super().__init__()
437 |
438 | def encode(self, x, *args, **kwargs):
439 | return x
440 |
441 | def decode(self, x, *args, **kwargs):
442 | return x
443 |
444 | def quantize(self, x, *args, **kwargs):
445 | if self.vq_interface:
446 | return x, None, [None, None, None]
447 | return x
448 |
449 | def forward(self, x, *args, **kwargs):
450 | return x
451 |
--------------------------------------------------------------------------------
/rec_network/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/rec_network/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from rec_network.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from rec_network.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/rec_network/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from rec_network.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class DDIMSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | self.register_buffer('betas', to_torch(self.model.betas))
32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34 |
35 | # calculations for diffusion q(x_t | x_{t-1}) and others
36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41 |
42 | # ddim sampling parameters
43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44 | ddim_timesteps=self.ddim_timesteps,
45 | eta=ddim_eta,verbose=verbose)
46 | self.register_buffer('ddim_sigmas', ddim_sigmas)
47 | self.register_buffer('ddim_alphas', ddim_alphas)
48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54 |
55 | @torch.no_grad()
56 | def sample(self,
57 | S,
58 | batch_size,
59 | shape,
60 | conditioning=None,
61 | callback=None,
62 | normals_sequence=None,
63 | img_callback=None,
64 | quantize_x0=False,
65 | eta=0.,
66 | mask=None,
67 | x0=None,
68 | temperature=1.,
69 | noise_dropout=0.,
70 | score_corrector=None,
71 | corrector_kwargs=None,
72 | verbose=True,
73 | x_T=None,
74 | log_every_t=100,
75 | unconditional_guidance_scale=1.,
76 | unconditional_conditioning=None,
77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78 | **kwargs
79 | ):
80 | if conditioning is not None:
81 | if isinstance(conditioning, dict):
82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83 | if cbs != batch_size:
84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85 | else:
86 | if conditioning.shape[0] != batch_size:
87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88 |
89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90 | # sampling
91 | C, H, W = shape
92 | size = (batch_size, C, H, W)
93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94 |
95 | samples, intermediates = self.ddim_sampling(conditioning, size,
96 | callback=callback,
97 | img_callback=img_callback,
98 | quantize_denoised=quantize_x0,
99 | mask=mask, x0=x0,
100 | ddim_use_original_steps=False,
101 | noise_dropout=noise_dropout,
102 | temperature=temperature,
103 | score_corrector=score_corrector,
104 | corrector_kwargs=corrector_kwargs,
105 | x_T=x_T,
106 | log_every_t=log_every_t,
107 | unconditional_guidance_scale=unconditional_guidance_scale,
108 | unconditional_conditioning=unconditional_conditioning,
109 | )
110 | return samples, intermediates
111 |
112 | @torch.no_grad()
113 | def ddim_sampling(self, cond, shape,
114 | x_T=None, ddim_use_original_steps=False,
115 | callback=None, timesteps=None, quantize_denoised=False,
116 | mask=None, x0=None, img_callback=None, log_every_t=100,
117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
119 | device = self.model.betas.device
120 | b = shape[0]
121 | if x_T is None:
122 | img = torch.randn(shape, device=device)
123 | else:
124 | img = x_T
125 |
126 | if timesteps is None:
127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128 | elif timesteps is not None and not ddim_use_original_steps:
129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130 | timesteps = self.ddim_timesteps[:subset_end]
131 |
132 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135 | print(f"Running DDIM Sampling with {total_steps} timesteps")
136 |
137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138 |
139 | for i, step in enumerate(iterator):
140 | index = total_steps - i - 1
141 | ts = torch.full((b,), step, device=device, dtype=torch.long)
142 |
143 | if mask is not None:
144 | assert x0 is not None
145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146 | img = img_orig * mask + (1. - mask) * img
147 |
148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149 | quantize_denoised=quantize_denoised, temperature=temperature,
150 | noise_dropout=noise_dropout, score_corrector=score_corrector,
151 | corrector_kwargs=corrector_kwargs,
152 | unconditional_guidance_scale=unconditional_guidance_scale,
153 | unconditional_conditioning=unconditional_conditioning)
154 | img, pred_x0 = outs
155 | if callback: callback(i)
156 | if img_callback: img_callback(pred_x0, i)
157 |
158 | if index % log_every_t == 0 or index == total_steps - 1:
159 | intermediates['x_inter'].append(img)
160 | intermediates['pred_x0'].append(pred_x0)
161 |
162 | return img, intermediates
163 |
164 | @torch.no_grad()
165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167 | unconditional_guidance_scale=1., unconditional_conditioning=None):
168 | b, *_, device = *x.shape, x.device
169 |
170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171 | e_t = self.model.apply_model(x, t, c)
172 | else:
173 | x_in = torch.cat([x] * 2)
174 | t_in = torch.cat([t] * 2)
175 | c_in = torch.cat([unconditional_conditioning, c])
176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178 |
179 | if score_corrector is not None:
180 | assert self.model.parameterization == "eps"
181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182 |
183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187 | # select parameters corresponding to the currently considered timestep
188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192 |
193 | # current prediction for x_0
194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195 | if quantize_denoised:
196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197 | # direction pointing to x_t
198 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200 | if noise_dropout > 0.:
201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203 | return x_prev, pred_x0
204 |
--------------------------------------------------------------------------------
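As a rough usage sketch (not part of the repository), DDIMSampler wraps a trained latent-diffusion model: `sample` builds the shortened timestep schedule via `make_schedule`, then `ddim_sampling` walks the reversed schedule and calls `p_sample_ddim` at each step. Assuming an already loaded model `ldm_model` exposing `num_timesteps`, `betas`, `alphas_cumprod` and `apply_model`, and some conditioning tensor `cond`:

    # Hypothetical usage; ldm_model and cond are assumed to exist and match the sampler's expectations.
    import torch
    from rec_network.models.diffusion.ddim import DDIMSampler

    sampler = DDIMSampler(ldm_model)
    shape = (3, 64, 64)                       # (C, H, W) of the latent space, illustrative only

    with torch.no_grad():
        samples, intermediates = sampler.sample(
            S=50,                             # number of DDIM steps (much smaller than ddpm_num_timesteps)
            batch_size=cond.shape[0],
            shape=shape,
            conditioning=cond,
            eta=0.0,                          # eta=0 gives the deterministic DDIM update
            verbose=False,
        )
    # samples holds the final latents; intermediates['pred_x0'] the logged x0 predictions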
/rec_network/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from rec_network.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class PLMSSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | if ddim_eta != 0:
26 | raise ValueError('ddim_eta must be 0 for PLMS')
27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29 | alphas_cumprod = self.model.alphas_cumprod
30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32 |
33 | self.register_buffer('betas', to_torch(self.model.betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43 |
44 | # ddim sampling parameters
45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46 | ddim_timesteps=self.ddim_timesteps,
47 | eta=ddim_eta,verbose=verbose)
48 | self.register_buffer('ddim_sigmas', ddim_sigmas)
49 | self.register_buffer('ddim_alphas', ddim_alphas)
50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56 |
57 | @torch.no_grad()
58 | def sample(self,
59 | S,
60 | batch_size,
61 | shape,
62 | conditioning=None,
63 | callback=None,
64 | normals_sequence=None,
65 | img_callback=None,
66 | quantize_x0=False,
67 | eta=0.,
68 | mask=None,
69 | x0=None,
70 | temperature=1.,
71 | noise_dropout=0.,
72 | score_corrector=None,
73 | corrector_kwargs=None,
74 | verbose=True,
75 | x_T=None,
76 | log_every_t=100,
77 | unconditional_guidance_scale=1.,
78 | unconditional_conditioning=None,
79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80 | **kwargs
81 | ):
82 | if conditioning is not None:
83 | if isinstance(conditioning, dict):
84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85 | if cbs != batch_size:
86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87 | else:
88 | if conditioning.shape[0] != batch_size:
89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90 |
91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92 | # sampling
93 | C, H, W = shape
94 | size = (batch_size, C, H, W)
95 | print(f'Data shape for PLMS sampling is {size}')
96 |
97 | samples, intermediates = self.plms_sampling(conditioning, size,
98 | callback=callback,
99 | img_callback=img_callback,
100 | quantize_denoised=quantize_x0,
101 | mask=mask, x0=x0,
102 | ddim_use_original_steps=False,
103 | noise_dropout=noise_dropout,
104 | temperature=temperature,
105 | score_corrector=score_corrector,
106 | corrector_kwargs=corrector_kwargs,
107 | x_T=x_T,
108 | log_every_t=log_every_t,
109 | unconditional_guidance_scale=unconditional_guidance_scale,
110 | unconditional_conditioning=unconditional_conditioning,
111 | )
112 | return samples, intermediates
113 |
114 | @torch.no_grad()
115 | def plms_sampling(self, cond, shape,
116 | x_T=None, ddim_use_original_steps=False,
117 | callback=None, timesteps=None, quantize_denoised=False,
118 | mask=None, x0=None, img_callback=None, log_every_t=100,
119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
121 | device = self.model.betas.device
122 | b = shape[0]
123 | if x_T is None:
124 | img = torch.randn(shape, device=device)
125 | else:
126 | img = x_T
127 |
128 | if timesteps is None:
129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130 | elif timesteps is not None and not ddim_use_original_steps:
131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132 | timesteps = self.ddim_timesteps[:subset_end]
133 |
134 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137 | print(f"Running PLMS Sampling with {total_steps} timesteps")
138 |
139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140 | old_eps = []
141 |
142 | for i, step in enumerate(iterator):
143 | index = total_steps - i - 1
144 | ts = torch.full((b,), step, device=device, dtype=torch.long)
145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146 |
147 | if mask is not None:
148 | assert x0 is not None
149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150 | img = img_orig * mask + (1. - mask) * img
151 |
152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153 | quantize_denoised=quantize_denoised, temperature=temperature,
154 | noise_dropout=noise_dropout, score_corrector=score_corrector,
155 | corrector_kwargs=corrector_kwargs,
156 | unconditional_guidance_scale=unconditional_guidance_scale,
157 | unconditional_conditioning=unconditional_conditioning,
158 | old_eps=old_eps, t_next=ts_next)
159 | img, pred_x0, e_t = outs
160 | old_eps.append(e_t)
161 | if len(old_eps) >= 4:
162 | old_eps.pop(0)
163 | if callback: callback(i)
164 | if img_callback: img_callback(pred_x0, i)
165 |
166 | if index % log_every_t == 0 or index == total_steps - 1:
167 | intermediates['x_inter'].append(img)
168 | intermediates['pred_x0'].append(pred_x0)
169 |
170 | return img, intermediates
171 |
172 | @torch.no_grad()
173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176 | b, *_, device = *x.shape, x.device
177 |
178 | def get_model_output(x, t):
179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180 | e_t = self.model.apply_model(x, t, c)
181 | else:
182 | x_in = torch.cat([x] * 2)
183 | t_in = torch.cat([t] * 2)
184 | c_in = torch.cat([unconditional_conditioning, c])
185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187 |
188 | if score_corrector is not None:
189 | assert self.model.parameterization == "eps"
190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191 |
192 | return e_t
193 |
194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198 |
199 | def get_x_prev_and_pred_x0(e_t, index):
200 | # select parameters corresponding to the currently considered timestep
201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205 |
206 | # current prediction for x_0
207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208 | if quantize_denoised:
209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210 | # direction pointing to x_t
211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213 | if noise_dropout > 0.:
214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216 | return x_prev, pred_x0
217 |
218 | e_t = get_model_output(x, t)
219 | if len(old_eps) == 0:
220 | # Pseudo Improved Euler (2nd order)
221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222 | e_t_next = get_model_output(x_prev, t_next)
223 | e_t_prime = (e_t + e_t_next) / 2
224 | elif len(old_eps) == 1:
225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
227 | elif len(old_eps) == 2:
228 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230 | elif len(old_eps) >= 3:
231 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233 |
234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235 |
236 | return x_prev, pred_x0, e_t
237 |
--------------------------------------------------------------------------------
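For reference, the branch at the end of `p_sample_plms` is a pseudo linear multistep (Adams-Bashforth) integrator: with no history it takes a pseudo improved Euler step (one extra model evaluation), and as `old_eps` fills up it switches to the 2nd-, 3rd- and finally 4th-order combinations. A small stand-alone sketch of that coefficient logic, assuming `e_t` and the entries of `old_eps` are tensors of the same shape:

    def plms_eps_combination(e_t, old_eps):
        """Adams-Bashforth style combination of noise predictions, mirroring p_sample_plms above."""
        if len(old_eps) == 1:
            return (3 * e_t - old_eps[-1]) / 2
        if len(old_eps) == 2:
            return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        if len(old_eps) >= 3:
            return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
        # len(old_eps) == 0 is handled separately via the pseudo improved Euler step,
        # which needs a second model evaluation at t_next.
        raise ValueError("the empty-history case is handled by the improved-Euler branch")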
/rec_network/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from rec_network.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 |     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):  # context_dim: dim of conditioning y; query_dim: dim of latent z
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 | # context_dim = query_dim
158 |
159 | self.scale = dim_head ** -0.5
160 | self.heads = heads
161 |
162 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
163 | # print("!!!!!!",inner_dim,context_dim)
164 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
165 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
166 |
167 | self.to_out = nn.Sequential(
168 | nn.Linear(inner_dim, query_dim),
169 | nn.Dropout(dropout)
170 | )
171 |
172 | def forward(self, x, context=None, mask=None):
173 | h = self.heads
174 |
175 | q = self.to_q(x)
176 | context = default(context, x)
177 | # print("input",x.shape,context.shape)
178 | k = self.to_k(context)
179 | v = self.to_v(context)
180 |
181 | # print("!!!!!!!!!!",h,q.shape,k.shape,v.shape,context.shape)
182 |
183 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
184 |
185 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
186 |
187 | if exists(mask):
188 | mask = rearrange(mask, 'b ... -> b (...)')
189 | max_neg_value = -torch.finfo(sim.dtype).max
190 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
191 | sim.masked_fill_(~mask, max_neg_value)
192 |
193 | # attention, what we cannot get enough of
194 | attn = sim.softmax(dim=-1)
195 |
196 | out = einsum('b i j, b j d -> b i d', attn, v)
197 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
198 | return self.to_out(out)
199 |
200 |
201 | class BasicTransformerBlock(nn.Module):
202 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
203 | super().__init__()
204 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
205 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
206 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
207 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
208 | self.norm1 = nn.LayerNorm(dim)
209 | self.norm2 = nn.LayerNorm(dim)
210 | self.norm3 = nn.LayerNorm(dim)
211 | self.checkpoint = checkpoint
212 |
213 | def forward(self, x, context=None):
214 |
215 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
216 |
217 | def _forward(self, x, context=None):
218 | x = self.attn1(self.norm1(x)) + x
219 | x = self.attn2(self.norm2(x), context=context) + x
220 | x = self.ff(self.norm3(x)) + x
221 | return x
222 |
223 |
224 | class SpatialTransformer(nn.Module):
225 | """
226 | Transformer block for image-like data.
227 | First, project the input (aka embedding)
228 | and reshape to b, t, d.
229 | Then apply standard transformer action.
230 | Finally, reshape to image
231 | """
232 | def __init__(self, in_channels, n_heads, d_head,
233 | depth=1, dropout=0., context_dim=None):
234 | super().__init__()
235 | self.in_channels = in_channels
236 | inner_dim = n_heads * d_head
237 | self.norm = Normalize(in_channels)
238 |
239 | self.proj_in = nn.Conv2d(in_channels,
240 | inner_dim,
241 | kernel_size=1,
242 | stride=1,
243 | padding=0)
244 |
245 | self.transformer_blocks = nn.ModuleList(
246 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
247 | for d in range(depth)]
248 | )
249 |
250 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
251 | in_channels,
252 | kernel_size=1,
253 | stride=1,
254 | padding=0))
255 |
256 | def forward(self, x, context=None):
257 | # note: if no context is given, cross-attention defaults to self-attention
258 | b, c, h, w = x.shape
259 | # print("253 shape", x.shape, context.shape)
260 | x_in = x
261 | x = self.norm(x)
262 | x = self.proj_in(x)
263 | x = rearrange(x, 'b c h w -> b (h w) c')
264 | context = rearrange(context, 'b c h w -> b (h w) c')
265 | # print("258 shape", x.shape, context.shape)
266 | for block in self.transformer_blocks:
267 | x = block(x, context=context)
268 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
269 | x = self.proj_out(x)
270 | return x + x_in
--------------------------------------------------------------------------------
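A shape-level sketch of how SpatialTransformer is wired in this variant (dimensions below are made up): note that `forward` rearranges `context` from `b c h w` to `b (h w) c`, so the conditioning is expected as a spatial feature map at the same resolution as the input.

    # Shape sketch with illustrative dimensions; requires torch and einops.
    import torch
    from rec_network.modules.attention import SpatialTransformer

    st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=32)
    x = torch.randn(2, 64, 16, 16)         # latent features: b c h w
    context = torch.randn(2, 32, 16, 16)   # conditioning feature map, flattened internally to b (h w) c
    out = st(x, context=context)           # -> (2, 64, 16, 16); the input is added back as a residual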
/rec_network/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/rec_network/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from rec_network.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
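Two of the helpers above are used throughout the samplers: `make_beta_schedule` produces the per-step noise variances for the chosen schedule, and `timestep_embedding` maps integer timesteps to sinusoidal features for the UNet. A brief sketch with illustrative values:

    import torch
    from rec_network.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

    betas = make_beta_schedule("linear", n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
    alphas_cumprod = torch.cumprod(torch.tensor(1.0 - betas), dim=0)   # the alpha-bar_t values used by the samplers

    t = torch.randint(0, 1000, (8,))        # one timestep index per batch element
    emb = timestep_embedding(t, dim=256)    # -> (8, 256) sinusoidal embedding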
/rec_network/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/rec_network/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
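As a small sketch of how DiagonalGaussianDistribution is typically consumed, the encoder emits mean and log-variance stacked along the channel axis; the shapes below are illustrative:

    import torch
    from rec_network.modules.distributions.distributions import DiagonalGaussianDistribution

    moments = torch.randn(4, 8, 32, 32)      # b, 2*C, h, w with C = 4: first half mean, second half logvar
    posterior = DiagonalGaussianDistribution(moments)

    z = posterior.sample()                   # reparameterized sample, shape (4, 4, 32, 32)
    kl = posterior.kl()                      # KL against a standard normal, shape (4,)
    nll = posterior.nll(z)                   # negative log-likelihood of a sample, shape (4,)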
/rec_network/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 |                 # remove, as the '.' character is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
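A minimal sketch of the EMA workflow implemented above: update the shadow weights after every optimizer step, swap them in for evaluation or checkpointing, then restore the training weights. The model and loop here are placeholders:

    import torch
    from rec_network.modules.ema import LitEma

    model = torch.nn.Linear(10, 10)
    ema = LitEma(model, decay=0.9999)

    for _ in range(100):                 # training loop placeholder
        # ... optimizer step on `model` ...
        ema(model)                       # update the shadow parameters

    ema.store(model.parameters())        # stash the current training weights
    ema.copy_to(model)                   # run validation / save checkpoints with EMA weights
    ema.restore(model.parameters())      # put the training weights back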
/rec_network/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/rec_network/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 |
8 |
9 | from rec_network.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 |
138 | class FrozenCLIPTextEmbedder(nn.Module):
139 | """
140 | Uses the CLIP transformer encoder for text.
141 | """
142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
143 | super().__init__()
144 | self.model, _ = clip.load(version, jit=False, device="cpu")
145 | self.device = device
146 | self.max_length = max_length
147 | self.n_repeat = n_repeat
148 | self.normalize = normalize
149 |
150 | def freeze(self):
151 | self.model = self.model.eval()
152 | for param in self.parameters():
153 | param.requires_grad = False
154 |
155 | def forward(self, text):
156 | tokens = clip.tokenize(text).to(self.device)
157 | z = self.model.encode_text(tokens)
158 | if self.normalize:
159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
160 | return z
161 |
162 | def encode(self, text):
163 | z = self(text)
164 | if z.ndim==2:
165 | z = z[:, None, :]
166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
167 | return z
168 |
169 |
170 | class FrozenClipImageEmbedder(nn.Module):
171 | """
172 | Uses the CLIP image encoder.
173 | """
174 | def __init__(
175 | self,
176 | model,
177 | jit=False,
178 | device='cuda' if torch.cuda.is_available() else 'cpu',
179 | antialias=False,
180 | ):
181 | super().__init__()
182 | self.model, _ = clip.load(name=model, device=device, jit=jit)
183 |
184 | self.antialias = antialias
185 |
186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
188 |
189 | def preprocess(self, x):
190 | # normalize to [0,1]
191 | x = kornia.geometry.resize(x, (224, 224),
192 | interpolation='bicubic',align_corners=True,
193 | antialias=self.antialias)
194 | x = (x + 1.) / 2.
195 | # renormalize according to clip
196 | x = kornia.enhance.normalize(x, self.mean, self.std)
197 | return x
198 |
199 | def forward(self, x):
200 | # x is assumed to be in range [-1,1]
201 | return self.model.encode_image(self.preprocess(x))
202 |
203 |
--------------------------------------------------------------------------------
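As a quick sketch, SpatialRescaler above simply interpolates its input `n_stages` times by `multiplier` and optionally remaps channels with a 1x1 convolution, which is one way to match low-resolution conditioning maps to the latent resolution. Shapes below are made up:

    import torch
    from rec_network.modules.encoders.modules import SpatialRescaler

    rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5,
                               in_channels=3, out_channels=4)
    x = torch.randn(1, 3, 256, 256)
    z = rescaler(x)          # two 0.5x stages plus the 1x1 channel mapper -> (1, 4, 64, 64)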
/rec_network/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from rec_network.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from rec_network.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/rec_network/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhang-t/DiffAD/0c39d44b9270740014dcdad987545905ebba60d2/rec_network/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/rec_network/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from rec_network.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/rec_network/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
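
Note on the adaptive discriminator weight used above: calculate_adaptive_weight rescales the adversarial term so that its gradient norm at the decoder's last layer matches that of the reconstruction (NLL) term. Below is a minimal, self-contained sketch of that rule using a toy parameter in place of the real decoder output layer; the module and shapes are assumptions, not values from this repo.

import torch

last_layer = torch.nn.Parameter(torch.randn(4, 4))   # stand-in for the decoder's final conv weight
x = torch.randn(4)
nll_loss = ((last_layer @ x) ** 2).mean()             # reconstruction-style loss
g_loss = -(last_layer @ x).mean()                     # adversarial generator loss

# Same balancing rule as calculate_adaptive_weight: ratio of gradient norms,
# clamped and detached so it acts as a constant scale on the GAN term.
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = (torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)).clamp(0.0, 1e4).detach()
print(d_weight)
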
/rec_network/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
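
Both loss modules gate the adversarial term with adopt_weight: the discriminator factor stays at zero until global_step reaches disc_start, after which it switches to its configured value. A small sketch of that behaviour follows; the threshold of 1000 is an illustrative value, not a config from this repo.

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # identical to the helper defined above: hold the weight at `value`
    # until the step threshold is reached
    if global_step < threshold:
        weight = value
    return weight

for step in (0, 999, 1000, 5000):
    print(step, adopt_weight(1.0, step, threshold=1000))   # 0.0, 0.0, 1.0, 1.0
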
/rec_network/modules/x_transformer.py:
--------------------------------------------------------------------------------
1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from functools import partial
6 | from inspect import isfunction
7 | from collections import namedtuple
8 | from einops import rearrange, repeat, reduce
9 |
10 | # constants
11 |
12 | DEFAULT_DIM_HEAD = 64
13 |
14 | Intermediates = namedtuple('Intermediates', [
15 | 'pre_softmax_attn',
16 | 'post_softmax_attn'
17 | ])
18 |
19 | LayerIntermediates = namedtuple('LayerIntermediates', [
20 | 'hiddens',
21 | 'attn_intermediates'
22 | ])
23 |
24 |
25 | class AbsolutePositionalEmbedding(nn.Module):
26 | def __init__(self, dim, max_seq_len):
27 | super().__init__()
28 | self.emb = nn.Embedding(max_seq_len, dim)
29 | self.init_()
30 |
31 | def init_(self):
32 | nn.init.normal_(self.emb.weight, std=0.02)
33 |
34 | def forward(self, x):
35 | n = torch.arange(x.shape[1], device=x.device)
36 | return self.emb(n)[None, :, :]
37 |
38 |
39 | class FixedPositionalEmbedding(nn.Module):
40 | def __init__(self, dim):
41 | super().__init__()
42 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43 | self.register_buffer('inv_freq', inv_freq)
44 |
45 | def forward(self, x, seq_dim=1, offset=0):
46 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
47 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
48 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
49 | return emb[None, :, :]
50 |
51 |
52 | # helpers
53 |
54 | def exists(val):
55 | return val is not None
56 |
57 |
58 | def default(val, d):
59 | if exists(val):
60 | return val
61 | return d() if isfunction(d) else d
62 |
63 |
64 | def always(val):
65 | def inner(*args, **kwargs):
66 | return val
67 | return inner
68 |
69 |
70 | def not_equals(val):
71 | def inner(x):
72 | return x != val
73 | return inner
74 |
75 |
76 | def equals(val):
77 | def inner(x):
78 | return x == val
79 | return inner
80 |
81 |
82 | def max_neg_value(tensor):
83 | return -torch.finfo(tensor.dtype).max
84 |
85 |
86 | # keyword argument helpers
87 |
88 | def pick_and_pop(keys, d):
89 | values = list(map(lambda key: d.pop(key), keys))
90 | return dict(zip(keys, values))
91 |
92 |
93 | def group_dict_by_key(cond, d):
94 | return_val = [dict(), dict()]
95 | for key in d.keys():
96 | match = bool(cond(key))
97 | ind = int(not match)
98 | return_val[ind][key] = d[key]
99 | return (*return_val,)
100 |
101 |
102 | def string_begins_with(prefix, str):
103 | return str.startswith(prefix)
104 |
105 |
106 | def group_by_key_prefix(prefix, d):
107 | return group_dict_by_key(partial(string_begins_with, prefix), d)
108 |
109 |
110 | def groupby_prefix_and_trim(prefix, d):
111 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
112 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
113 | return kwargs_without_prefix, kwargs
114 |
115 |
116 | # classes
117 | class Scale(nn.Module):
118 | def __init__(self, value, fn):
119 | super().__init__()
120 | self.value = value
121 | self.fn = fn
122 |
123 | def forward(self, x, **kwargs):
124 | x, *rest = self.fn(x, **kwargs)
125 | return (x * self.value, *rest)
126 |
127 |
128 | class Rezero(nn.Module):
129 | def __init__(self, fn):
130 | super().__init__()
131 | self.fn = fn
132 | self.g = nn.Parameter(torch.zeros(1))
133 |
134 | def forward(self, x, **kwargs):
135 | x, *rest = self.fn(x, **kwargs)
136 | return (x * self.g, *rest)
137 |
138 |
139 | class ScaleNorm(nn.Module):
140 | def __init__(self, dim, eps=1e-5):
141 | super().__init__()
142 | self.scale = dim ** -0.5
143 | self.eps = eps
144 | self.g = nn.Parameter(torch.ones(1))
145 |
146 | def forward(self, x):
147 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
148 | return x / norm.clamp(min=self.eps) * self.g
149 |
150 |
151 | class RMSNorm(nn.Module):
152 | def __init__(self, dim, eps=1e-8):
153 | super().__init__()
154 | self.scale = dim ** -0.5
155 | self.eps = eps
156 | self.g = nn.Parameter(torch.ones(dim))
157 |
158 | def forward(self, x):
159 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
160 | return x / norm.clamp(min=self.eps) * self.g
161 |
162 |
163 | class Residual(nn.Module):
164 | def forward(self, x, residual):
165 | return x + residual
166 |
167 |
168 | class GRUGating(nn.Module):
169 | def __init__(self, dim):
170 | super().__init__()
171 | self.gru = nn.GRUCell(dim, dim)
172 |
173 | def forward(self, x, residual):
174 | gated_output = self.gru(
175 | rearrange(x, 'b n d -> (b n) d'),
176 | rearrange(residual, 'b n d -> (b n) d')
177 | )
178 |
179 | return gated_output.reshape_as(x)
180 |
181 |
182 | # feedforward
183 |
184 | class GEGLU(nn.Module):
185 | def __init__(self, dim_in, dim_out):
186 | super().__init__()
187 | self.proj = nn.Linear(dim_in, dim_out * 2)
188 |
189 | def forward(self, x):
190 | x, gate = self.proj(x).chunk(2, dim=-1)
191 | return x * F.gelu(gate)
192 |
193 |
194 | class FeedForward(nn.Module):
195 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
196 | super().__init__()
197 | inner_dim = int(dim * mult)
198 | dim_out = default(dim_out, dim)
199 | project_in = nn.Sequential(
200 | nn.Linear(dim, inner_dim),
201 | nn.GELU()
202 | ) if not glu else GEGLU(dim, inner_dim)
203 |
204 | self.net = nn.Sequential(
205 | project_in,
206 | nn.Dropout(dropout),
207 | nn.Linear(inner_dim, dim_out)
208 | )
209 |
210 | def forward(self, x):
211 | return self.net(x)
212 |
213 |
214 | # attention.
215 | class Attention(nn.Module):
216 | def __init__(
217 | self,
218 | dim,
219 | dim_head=DEFAULT_DIM_HEAD,
220 | heads=8,
221 | causal=False,
222 | mask=None,
223 | talking_heads=False,
224 | sparse_topk=None,
225 | use_entmax15=False,
226 | num_mem_kv=0,
227 | dropout=0.,
228 | on_attn=False
229 | ):
230 | super().__init__()
231 | if use_entmax15:
232 | raise NotImplementedError("Check out entmax activation instead of softmax activation!")
233 | self.scale = dim_head ** -0.5
234 | self.heads = heads
235 | self.causal = causal
236 | self.mask = mask
237 |
238 | inner_dim = dim_head * heads
239 |
240 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
241 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
242 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
243 | self.dropout = nn.Dropout(dropout)
244 |
245 | # talking heads
246 | self.talking_heads = talking_heads
247 | if talking_heads:
248 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
250 |
251 | # explicit topk sparse attention
252 | self.sparse_topk = sparse_topk
253 |
254 | # entmax
255 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax
256 | self.attn_fn = F.softmax
257 |
258 | # add memory key / values
259 | self.num_mem_kv = num_mem_kv
260 | if num_mem_kv > 0:
261 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
263 |
264 | # attention on attention
265 | self.attn_on_attn = on_attn
266 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
267 |
268 | def forward(
269 | self,
270 | x,
271 | context=None,
272 | mask=None,
273 | context_mask=None,
274 | rel_pos=None,
275 | sinusoidal_emb=None,
276 | prev_attn=None,
277 | mem=None
278 | ):
279 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
280 | kv_input = default(context, x)
281 |
282 | q_input = x
283 | k_input = kv_input
284 | v_input = kv_input
285 |
286 | if exists(mem):
287 | k_input = torch.cat((mem, k_input), dim=-2)
288 | v_input = torch.cat((mem, v_input), dim=-2)
289 |
290 | if exists(sinusoidal_emb):
291 | # in shortformer, the query would start at a position offset depending on the past cached memory
292 | offset = k_input.shape[-2] - q_input.shape[-2]
293 | q_input = q_input + sinusoidal_emb(q_input, offset=offset)
294 | k_input = k_input + sinusoidal_emb(k_input)
295 |
296 | q = self.to_q(q_input)
297 | k = self.to_k(k_input)
298 | v = self.to_v(v_input)
299 |
300 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
301 |
302 | input_mask = None
303 | if any(map(exists, (mask, context_mask))):
304 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
305 | k_mask = q_mask if not exists(context) else context_mask
306 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
307 | q_mask = rearrange(q_mask, 'b i -> b () i ()')
308 | k_mask = rearrange(k_mask, 'b j -> b () () j')
309 | input_mask = q_mask * k_mask
310 |
311 | if self.num_mem_kv > 0:
312 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
313 | k = torch.cat((mem_k, k), dim=-2)
314 | v = torch.cat((mem_v, v), dim=-2)
315 | if exists(input_mask):
316 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
317 |
318 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
319 | mask_value = max_neg_value(dots)
320 |
321 | if exists(prev_attn):
322 | dots = dots + prev_attn
323 |
324 | pre_softmax_attn = dots
325 |
326 | if talking_heads:
327 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
328 |
329 | if exists(rel_pos):
330 | dots = rel_pos(dots)
331 |
332 | if exists(input_mask):
333 | dots.masked_fill_(~input_mask, mask_value)
334 | del input_mask
335 |
336 | if self.causal:
337 | i, j = dots.shape[-2:]
338 | r = torch.arange(i, device=device)
339 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
340 | mask = F.pad(mask, (j - i, 0), value=False)
341 | dots.masked_fill_(mask, mask_value)
342 | del mask
343 |
344 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
345 | top, _ = dots.topk(self.sparse_topk, dim=-1)
346 | vk = top[..., -1].unsqueeze(-1).expand_as(dots)
347 | mask = dots < vk
348 | dots.masked_fill_(mask, mask_value)
349 | del mask
350 |
351 | attn = self.attn_fn(dots, dim=-1)
352 | post_softmax_attn = attn
353 |
354 | attn = self.dropout(attn)
355 |
356 | if talking_heads:
357 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
358 |
359 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
360 | out = rearrange(out, 'b h n d -> b n (h d)')
361 |
362 | intermediates = Intermediates(
363 | pre_softmax_attn=pre_softmax_attn,
364 | post_softmax_attn=post_softmax_attn
365 | )
366 |
367 | return self.to_out(out), intermediates
368 |
369 |
370 | class AttentionLayers(nn.Module):
371 | def __init__(
372 | self,
373 | dim,
374 | depth,
375 | heads=8,
376 | causal=False,
377 | cross_attend=False,
378 | only_cross=False,
379 | use_scalenorm=False,
380 | use_rmsnorm=False,
381 | use_rezero=False,
382 | rel_pos_num_buckets=32,
383 | rel_pos_max_distance=128,
384 | position_infused_attn=False,
385 | custom_layers=None,
386 | sandwich_coef=None,
387 | par_ratio=None,
388 | residual_attn=False,
389 | cross_residual_attn=False,
390 | macaron=False,
391 | pre_norm=True,
392 | gate_residual=False,
393 | **kwargs
394 | ):
395 | super().__init__()
396 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
397 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
398 |
399 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
400 |
401 | self.dim = dim
402 | self.depth = depth
403 | self.layers = nn.ModuleList([])
404 |
405 | self.has_pos_emb = position_infused_attn
406 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
407 | self.rotary_pos_emb = always(None)
408 |
409 | assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
410 | self.rel_pos = None
411 |
412 | self.pre_norm = pre_norm
413 |
414 | self.residual_attn = residual_attn
415 | self.cross_residual_attn = cross_residual_attn
416 |
417 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
418 | norm_class = RMSNorm if use_rmsnorm else norm_class
419 | norm_fn = partial(norm_class, dim)
420 |
421 | norm_fn = nn.Identity if use_rezero else norm_fn
422 | branch_fn = Rezero if use_rezero else None
423 |
424 | if cross_attend and not only_cross:
425 | default_block = ('a', 'c', 'f')
426 | elif cross_attend and only_cross:
427 | default_block = ('c', 'f')
428 | else:
429 | default_block = ('a', 'f')
430 |
431 | if macaron:
432 | default_block = ('f',) + default_block
433 |
434 | if exists(custom_layers):
435 | layer_types = custom_layers
436 | elif exists(par_ratio):
437 | par_depth = depth * len(default_block)
438 | assert 1 < par_ratio <= par_depth, 'par ratio out of range'
439 | default_block = tuple(filter(not_equals('f'), default_block))
440 | par_attn = par_depth // par_ratio
441 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
442 | par_width = (depth_cut + depth_cut // par_attn) // par_attn
443 | assert len(default_block) <= par_width, 'default block is too large for par_ratio'
444 | par_block = default_block + ('f',) * (par_width - len(default_block))
445 | par_head = par_block * par_attn
446 | layer_types = par_head + ('f',) * (par_depth - len(par_head))
447 | elif exists(sandwich_coef):
448 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
449 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
450 | else:
451 | layer_types = default_block * depth
452 |
453 | self.layer_types = layer_types
454 | self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
455 |
456 | for layer_type in self.layer_types:
457 | if layer_type == 'a':
458 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
459 | elif layer_type == 'c':
460 | layer = Attention(dim, heads=heads, **attn_kwargs)
461 | elif layer_type == 'f':
462 | layer = FeedForward(dim, **ff_kwargs)
463 | layer = layer if not macaron else Scale(0.5, layer)
464 | else:
465 | raise Exception(f'invalid layer type {layer_type}')
466 |
467 | if isinstance(layer, Attention) and exists(branch_fn):
468 | layer = branch_fn(layer)
469 |
470 | if gate_residual:
471 | residual_fn = GRUGating(dim)
472 | else:
473 | residual_fn = Residual()
474 |
475 | self.layers.append(nn.ModuleList([
476 | norm_fn(),
477 | layer,
478 | residual_fn
479 | ]))
480 |
481 | def forward(
482 | self,
483 | x,
484 | context=None,
485 | mask=None,
486 | context_mask=None,
487 | mems=None,
488 | return_hiddens=False
489 | ):
490 | hiddens = []
491 | intermediates = []
492 | prev_attn = None
493 | prev_cross_attn = None
494 |
495 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
496 |
497 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
498 | is_last = ind == (len(self.layers) - 1)
499 |
500 | if layer_type == 'a':
501 | hiddens.append(x)
502 | layer_mem = mems.pop(0)
503 |
504 | residual = x
505 |
506 | if self.pre_norm:
507 | x = norm(x)
508 |
509 | if layer_type == 'a':
510 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
511 | prev_attn=prev_attn, mem=layer_mem)
512 | elif layer_type == 'c':
513 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
514 | elif layer_type == 'f':
515 | out = block(x)
516 |
517 | x = residual_fn(out, residual)
518 |
519 | if layer_type in ('a', 'c'):
520 | intermediates.append(inter)
521 |
522 | if layer_type == 'a' and self.residual_attn:
523 | prev_attn = inter.pre_softmax_attn
524 | elif layer_type == 'c' and self.cross_residual_attn:
525 | prev_cross_attn = inter.pre_softmax_attn
526 |
527 | if not self.pre_norm and not is_last:
528 | x = norm(x)
529 |
530 | if return_hiddens:
531 | intermediates = LayerIntermediates(
532 | hiddens=hiddens,
533 | attn_intermediates=intermediates
534 | )
535 |
536 | return x, intermediates
537 |
538 | return x
539 |
540 |
541 | class Encoder(AttentionLayers):
542 | def __init__(self, **kwargs):
543 | assert 'causal' not in kwargs, 'cannot set causality on encoder'
544 | super().__init__(causal=False, **kwargs)
545 |
546 |
547 |
548 | class TransformerWrapper(nn.Module):
549 | def __init__(
550 | self,
551 | *,
552 | num_tokens,
553 | max_seq_len,
554 | attn_layers,
555 | emb_dim=None,
556 | max_mem_len=0.,
557 | emb_dropout=0.,
558 | num_memory_tokens=None,
559 | tie_embedding=False,
560 | use_pos_emb=True
561 | ):
562 | super().__init__()
563 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
564 |
565 | dim = attn_layers.dim
566 | emb_dim = default(emb_dim, dim)
567 |
568 | self.max_seq_len = max_seq_len
569 | self.max_mem_len = max_mem_len
570 | self.num_tokens = num_tokens
571 |
572 | self.token_emb = nn.Embedding(num_tokens, emb_dim)
573 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
574 | use_pos_emb and not attn_layers.has_pos_emb) else always(0)
575 | self.emb_dropout = nn.Dropout(emb_dropout)
576 |
577 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
578 | self.attn_layers = attn_layers
579 | self.norm = nn.LayerNorm(dim)
580 |
581 | self.init_()
582 |
583 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
584 |
585 | # memory tokens (like [cls]) from Memory Transformers paper
586 | num_memory_tokens = default(num_memory_tokens, 0)
587 | self.num_memory_tokens = num_memory_tokens
588 | if num_memory_tokens > 0:
589 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
590 |
591 | # let funnel encoder know number of memory tokens, if specified
592 | if hasattr(attn_layers, 'num_memory_tokens'):
593 | attn_layers.num_memory_tokens = num_memory_tokens
594 |
595 | def init_(self):
596 | nn.init.normal_(self.token_emb.weight, std=0.02)
597 |
598 | def forward(
599 | self,
600 | x,
601 | return_embeddings=False,
602 | mask=None,
603 | return_mems=False,
604 | return_attn=False,
605 | mems=None,
606 | **kwargs
607 | ):
608 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
609 | x = self.token_emb(x)
610 | x += self.pos_emb(x)
611 | x = self.emb_dropout(x)
612 |
613 | x = self.project_emb(x)
614 |
615 | if num_mem > 0:
616 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
617 | x = torch.cat((mem, x), dim=1)
618 |
619 | # auto-handle masking after appending memory tokens
620 | if exists(mask):
621 | mask = F.pad(mask, (num_mem, 0), value=True)
622 |
623 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
624 | x = self.norm(x)
625 |
626 | mem, x = x[:, :num_mem], x[:, num_mem:]
627 |
628 | out = self.to_logits(x) if not return_embeddings else x
629 |
630 | if return_mems:
631 | hiddens = intermediates.hiddens
632 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
633 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
634 | return out, new_mems
635 |
636 | if return_attn:
637 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
638 | return out, attn_maps
639 |
640 | return out
641 |
642 |
--------------------------------------------------------------------------------
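
A minimal usage sketch of the Encoder / TransformerWrapper classes defined above. The hyperparameters are illustrative only, and the sketch assumes the classes are importable from rec_network.modules.x_transformer as laid out in this repo.

import torch
from rec_network.modules.x_transformer import Encoder, TransformerWrapper

model = TransformerWrapper(
    num_tokens=1000,
    max_seq_len=64,
    attn_layers=Encoder(dim=128, depth=2, heads=4),
)
tokens = torch.randint(0, 1000, (2, 64))              # (batch, seq_len) token ids
logits = model(tokens)                                # (2, 64, 1000) token logits
embeddings = model(tokens, return_embeddings=True)    # (2, 64, 128) pre-logits features
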
/rec_network/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if "target" not in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
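
instantiate_from_config drives object construction from the rec_network configs: a dict (or OmegaConf node) with a dotted "target" path and optional "params" is resolved by get_obj_from_str and called with those params. A small sketch with a generic torch class as the target; the target below is only an example, not one of this repo's config entries.

from rec_network.util import instantiate_from_config

config = {"target": "torch.nn.Conv2d",
          "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3}}
conv = instantiate_from_config(config)   # builds torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
print(type(conv))                        # <class 'torch.nn.modules.conv.Conv2d'>
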
/scripts/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | mkdir -p datasets
4 | cd datasets
5 | # Download describable textures dataset
6 | wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz
7 | tar -xf dtd-r1.0.1.tar.gz
8 | rm dtd-r1.0.1.tar.gz
9 |
10 | mkdir -p mvtec
11 | cd mvtec
12 | # Download MVTec anomaly detection dataset
13 | wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz
14 | tar -xf mvtec_anomaly_detection.tar.xz
15 | rm mvtec_anomaly_detection.tar.xz
16 |
17 |
--------------------------------------------------------------------------------
/scripts/mvtec.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | import cv2
8 | from rec_network.main import instantiate_from_config
9 | from rec_network.models.diffusion.ddim import DDIMSampler
10 | from rec_network.data.mvtec import MVTecDRAEMTestDataset
11 | from torch.utils.data import DataLoader
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "--outdir",
18 | type=str,
19 | default="./samples/bottle/",
20 | help="dir to write results to",
21 | )
22 | parser.add_argument(
23 | "--steps",
24 | type=int,
25 | default=50,
26 | help="number of ddim sampling steps",
27 | )
28 | opt = parser.parse_args()
29 | ddim_eta = 0.0
30 |
31 | mvtec_path = './datasets/mvtec/bottle'
32 |
33 | dataset = MVTecDRAEMTestDataset(mvtec_path + "/test/", resize_shape=[256, 256])
34 | dataloader = DataLoader(dataset, batch_size=1,
35 | shuffle=False, num_workers=0)
36 | print(f"Found {len(dataloader)} inputs.")
37 |
38 | config = OmegaConf.load("../configs/mvtec.yaml")
39 |
40 | model = instantiate_from_config(config.model)
41 | model.load_state_dict(torch.load("./logs/2023-01-31T16-58-02_mvtec/checkpoints/last.ckpt")["state_dict"],
42 | strict=False) #TODO: modify the ckpt path
43 |
44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
45 | model = model.to(device)
46 | sampler = DDIMSampler(model)
47 |
48 | os.makedirs(opt.outdir, exist_ok=True)
49 | cnt = 0
50 | with torch.no_grad():
51 | with model.ema_scope():
52 | for i_batch, batch in enumerate(dataloader):
53 |
54 | c_outpath = os.path.join(opt.outdir, 'condition'+str(cnt)+'.jpg')
55 | outpath = os.path.join(opt.outdir, str(cnt)+'.jpg')
56 | # print(outpath)
57 | condition = batch["image"].cpu().numpy().transpose(0,2,3,1)[0]*255
58 | cv2.imwrite(c_outpath,condition)
59 |
60 | c = batch["image"].to(device)
61 | c = model.cond_stage_model.encode(c)
62 | c = c.mode()
63 |
64 | noise = torch.randn_like(c)
65 | t = torch.randint(400, 500, (c.shape[0],), device=device).long()
66 | c_noisy = model.q_sample(x_start=c, t=t, noise=noise)
67 |
68 | shape = c.shape[1:]
69 | samples_ddim, _ = sampler.sample(S=opt.steps,
70 | conditioning=c, # or conditioning=c_noisy
71 | batch_size=c.shape[0],
72 | shape=shape,
73 | verbose=False)
74 | x_samples_ddim = model.decode_first_stage(samples_ddim)
75 |
76 | sample = x_samples_ddim.cpu().numpy().transpose(0,2,3,1)[0]*255
77 | cv2.imwrite(outpath, sample)
78 | cnt+=1
79 |
--------------------------------------------------------------------------------
/seg_network/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | import cv2
6 | import glob
7 | import imgaug.augmenters as iaa
8 | from perlin import rand_perlin_2d_np
9 |
10 | class MVTecDRAEMTestDataset(Dataset):
11 |
12 | def __init__(self, root_dir, resize_shape=None):
13 | self.root_dir = root_dir
14 | self.images = sorted(glob.glob(root_dir+"/*/*.png"))
15 | self.resize_shape=resize_shape
16 |
17 | def __len__(self):
18 | return len(self.images)
19 |
20 | def transform_image(self, image_path, mask_path):
21 | image = cv2.imread(image_path, cv2.IMREAD_COLOR)
22 | if mask_path is not None:
23 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
24 | else:
25 | mask = np.zeros((image.shape[0],image.shape[1]))
26 | if self.resize_shape is not None:
27 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
28 | mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0]))
29 |
30 | image = image / 255.0
31 | mask = mask / 255.0
32 |
33 | image = np.array(image).reshape((image.shape[0], image.shape[1], 3)).astype(np.float32)
34 | mask = np.array(mask).reshape((mask.shape[0], mask.shape[1], 1)).astype(np.float32)
35 |
36 | image = np.transpose(image, (2, 0, 1))
37 | mask = np.transpose(mask, (2, 0, 1))
38 | return image, mask
39 |
40 | def __getitem__(self, idx):
41 | if torch.is_tensor(idx):
42 | idx = idx.tolist()
43 |
44 | img_path = self.images[idx]
45 | dir_path, file_name = os.path.split(img_path)
46 | base_dir = os.path.basename(dir_path)
47 | if base_dir == 'good':
48 | image, mask = self.transform_image(img_path, None)
49 | has_anomaly = np.array([0], dtype=np.float32)
50 | else:
51 | mask_path = os.path.join(dir_path, '../../ground_truth/')
52 | mask_path = os.path.join(mask_path, base_dir)
53 | mask_file_name = file_name.split(".")[0]+"_mask.png"
54 | mask_path = os.path.join(mask_path, mask_file_name)
55 | image, mask = self.transform_image(img_path, mask_path)
56 | has_anomaly = np.array([1], dtype=np.float32)
57 |
58 | sample = {'image': image, 'has_anomaly': has_anomaly,'mask': mask, 'idx': idx}
59 |
60 | return sample
61 |
62 |
63 |
64 | class MVTecDRAEMTrainDataset(Dataset):
65 |
66 | def __init__(self, root_dir, anomaly_source_path, resize_shape=None):
67 | """
68 | Args:
69 | root_dir (string): Directory with all the images.
70 | transform (callable, optional): Optional transform to be applied
71 | on a sample.
72 | """
73 | self.root_dir = root_dir
74 | self.resize_shape=resize_shape
75 |
76 | self.image_paths = sorted(glob.glob(root_dir+"/*.png"))
77 |
78 | self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path+"/*/*.jpg"))
79 |
80 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True),
81 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)),
82 | iaa.pillike.EnhanceSharpness(),
83 | iaa.AddToHueAndSaturation((-50,50),per_channel=True),
84 | iaa.Solarize(0.5, threshold=(32,128)),
85 | iaa.Posterize(),
86 | iaa.Invert(),
87 | iaa.pillike.Autocontrast(),
88 | iaa.pillike.Equalize(),
89 | iaa.Affine(rotate=(-45, 45))
90 | ]
91 |
92 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
93 |
94 |
95 | def __len__(self):
96 | return len(self.image_paths)
97 |
98 |
99 | def randAugmenter(self):
100 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
101 | aug = iaa.Sequential([self.augmenters[aug_ind[0]],
102 | self.augmenters[aug_ind[1]],
103 | self.augmenters[aug_ind[2]]]
104 | )
105 | return aug
106 |
107 | def augment_image(self, image, anomaly_source_path):
108 | aug = self.randAugmenter()
109 | perlin_scale = 6
110 | min_perlin_scale = 0
111 | anomaly_source_img = cv2.imread(anomaly_source_path)
112 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize_shape[1], self.resize_shape[0]))
113 |
114 | anomaly_img_augmented = aug(image=anomaly_source_img)
115 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
116 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
117 |
118 | perlin_noise = rand_perlin_2d_np((self.resize_shape[0], self.resize_shape[1]), (perlin_scalex, perlin_scaley))
119 | perlin_noise = self.rot(image=perlin_noise)
120 | threshold = 0.5
121 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
122 | perlin_thr = np.expand_dims(perlin_thr, axis=2)
123 |
124 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0
125 |
126 | beta = torch.rand(1).numpy()[0] * 0.8
127 |
128 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * (
129 | perlin_thr)
130 |
131 | no_anomaly = torch.rand(1).numpy()[0]
132 | if no_anomaly > 0.5:
133 | image = image.astype(np.float32)
134 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32)
135 | else:
136 | augmented_image = augmented_image.astype(np.float32)
137 | msk = (perlin_thr).astype(np.float32)
138 | augmented_image = msk * augmented_image + (1-msk)*image
139 | has_anomaly = 1.0
140 | if np.sum(msk) == 0:
141 | has_anomaly=0.0
142 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32)
143 |
144 | def transform_image(self, image_path, anomaly_source_path):
145 | image = cv2.imread(image_path)
146 | image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
147 |
148 | do_aug_orig = torch.rand(1).numpy()[0] > 0.7
149 | if do_aug_orig:
150 | image = self.rot(image=image)
151 |
152 | image = np.array(image).reshape((image.shape[0], image.shape[1], image.shape[2])).astype(np.float32) / 255.0
153 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path)
154 | augmented_image = np.transpose(augmented_image, (2, 0, 1))
155 | image = np.transpose(image, (2, 0, 1))
156 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1))
157 | return image, augmented_image, anomaly_mask, has_anomaly
158 |
159 | def __getitem__(self, idx):
160 | idx = torch.randint(0, len(self.image_paths), (1,)).item()
161 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
162 | image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(self.image_paths[idx],
163 | self.anomaly_source_paths[anomaly_source_idx])
164 | sample = {'image': image, "anomaly_mask": anomaly_mask,
165 | 'augmented_image': augmented_image, 'has_anomaly': has_anomaly, 'idx': idx}
166 |
167 | return sample
168 |
--------------------------------------------------------------------------------
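
A minimal sketch of wiring MVTecDRAEMTrainDataset into a DataLoader. The MVTec and DTD paths below are placeholders matching the layout produced by scripts/download_dataset.sh; the object category, batch size, and worker count are assumptions, and the sketch assumes it is run from the seg_network directory.

from torch.utils.data import DataLoader
from data_loader import MVTecDRAEMTrainDataset

dataset = MVTecDRAEMTrainDataset(
    "./datasets/mvtec/bottle/train/good/",   # normal training images
    "./datasets/dtd/images/",                # anomaly source textures
    resize_shape=[256, 256],
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["image"].shape, batch["augmented_image"].shape, batch["anomaly_mask"].shape)
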
/seg_network/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from math import exp
6 |
7 | class FocalLoss(nn.Module):
8 | """
9 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
10 | This is an implementation of Focal Loss with smooth-label cross entropy support, as proposed in
11 | 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
12 | Focal_Loss = -1 * alpha * (1 - pt) * log(pt)
13 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
14 | :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5), putting more
15 | focus on hard, misclassified examples
16 | :param smooth: (float, double) smoothing value used for the cross entropy
17 | :param balance_index: (int) balance class index, should be specific when alpha is float
18 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
19 | """
20 |
21 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
22 | super(FocalLoss, self).__init__()
23 | self.apply_nonlin = apply_nonlin
24 | self.alpha = alpha
25 | self.gamma = gamma
26 | self.balance_index = balance_index
27 | self.smooth = smooth
28 | self.size_average = size_average
29 |
30 | if self.smooth is not None:
31 | if self.smooth < 0 or self.smooth > 1.0:
32 | raise ValueError('smooth value should be in [0,1]')
33 |
34 | def forward(self, logit, target):
35 | if self.apply_nonlin is not None:
36 | logit = self.apply_nonlin(logit)
37 | num_class = logit.shape[1]
38 |
39 | if logit.dim() > 2:
40 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
41 | logit = logit.view(logit.size(0), logit.size(1), -1)
42 | logit = logit.permute(0, 2, 1).contiguous()
43 | logit = logit.view(-1, logit.size(-1))
44 | target = torch.squeeze(target, 1)
45 | target = target.view(-1, 1)
46 | alpha = self.alpha
47 |
48 | if alpha is None:
49 | alpha = torch.ones(num_class, 1)
50 | elif isinstance(alpha, (list, np.ndarray)):
51 | assert len(alpha) == num_class
52 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
53 | alpha = alpha / alpha.sum()
54 | elif isinstance(alpha, float):
55 | alpha = torch.ones(num_class, 1)
56 | alpha = alpha * (1 - self.alpha)
57 | alpha[self.balance_index] = self.alpha
58 |
59 | else:
60 | raise TypeError('Not support alpha type')
61 |
62 | if alpha.device != logit.device:
63 | alpha = alpha.to(logit.device)
64 |
65 | idx = target.cpu().long()
66 |
67 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
68 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
69 | if one_hot_key.device != logit.device:
70 | one_hot_key = one_hot_key.to(logit.device)
71 |
72 | if self.smooth:
73 | one_hot_key = torch.clamp(
74 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
75 | pt = (one_hot_key * logit).sum(1) + self.smooth
76 | logpt = pt.log()
77 |
78 | gamma = self.gamma
79 |
80 | alpha = alpha[idx]
81 | alpha = torch.squeeze(alpha)
82 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
83 |
84 | if self.size_average:
85 | loss = loss.mean()
86 | return loss
87 |
88 | def gaussian(window_size, sigma):
89 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
90 | return gauss/gauss.sum()
91 |
92 | def create_window(window_size, channel=1):
93 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
94 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
95 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
96 | return window
97 |
98 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
99 | if val_range is None:
100 | if torch.max(img1) > 128:
101 | max_val = 255
102 | else:
103 | max_val = 1
104 |
105 | if torch.min(img1) < -0.5:
106 | min_val = -1
107 | else:
108 | min_val = 0
109 | l = max_val - min_val
110 | else:
111 | l = val_range
112 |
113 | padd = window_size//2
114 | (_, channel, height, width) = img1.size()
115 | if window is None:
116 | real_size = min(window_size, height, width)
117 | window = create_window(real_size, channel=channel).to(img1.device)
118 |
119 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
120 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
121 |
122 | mu1_sq = mu1.pow(2)
123 | mu2_sq = mu2.pow(2)
124 | mu1_mu2 = mu1 * mu2
125 |
126 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
127 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
128 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
129 |
130 | c1 = (0.01 * l) ** 2
131 | c2 = (0.03 * l) ** 2
132 |
133 | v1 = 2.0 * sigma12 + c2
134 | v2 = sigma1_sq + sigma2_sq + c2
135 | cs = torch.mean(v1 / v2) # contrast sensitivity
136 |
137 | ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
138 |
139 | if size_average:
140 | ret = ssim_map.mean()
141 | else:
142 | ret = ssim_map.mean(1).mean(1).mean(1)
143 |
144 | if full:
145 | return ret, cs
146 | return ret, ssim_map
147 |
148 |
149 | class SSIM(torch.nn.Module):
150 | def __init__(self, window_size=11, size_average=True, val_range=None):
151 | super(SSIM, self).__init__()
152 | self.window_size = window_size
153 | self.size_average = size_average
154 | self.val_range = val_range
155 |
156 | # Assume 1 channel for SSIM
157 | self.channel = 1
158 | self.window = create_window(window_size).cuda()
159 |
160 | def forward(self, img1, img2):
161 | (_, channel, _, _) = img1.size()
162 |
163 | if channel == self.channel and self.window.dtype == img1.dtype:
164 | window = self.window
165 | else:
166 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
167 | self.window = window
168 | self.channel = channel
169 |
170 | s_score, ssim_map = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
171 | return 1.0 - s_score
172 |
--------------------------------------------------------------------------------
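
A short sketch of applying FocalLoss to a two-channel segmentation output, in the way DRAEM-style training loops typically use it (softmax over the channel dimension, then focal loss against the binary anomaly mask). The shapes are illustrative, and the sketch assumes it is run from the seg_network directory.

import torch
from loss import FocalLoss

loss_focal = FocalLoss()
out_mask = torch.randn(2, 2, 256, 256)                      # raw logits from the segmentation net
out_mask_sm = torch.softmax(out_mask, dim=1)                # per-pixel class probabilities
anomaly_mask = torch.randint(0, 2, (2, 1, 256, 256)).float()
loss = loss_focal(out_mask_sm, anomaly_mask)
print(loss.item())
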
/seg_network/model_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ReconstructiveSubNetwork(nn.Module):
6 | def __init__(self,in_channels=3, out_channels=3, base_width=128):
7 | super(ReconstructiveSubNetwork, self).__init__()
8 | self.encoder = EncoderReconstructive(in_channels, base_width)
9 | self.decoder = DecoderReconstructive(base_width, out_channels=out_channels)
10 |
11 | def forward(self, x):
12 | b5 = self.encoder(x)
13 | output = self.decoder(b5)
14 | return output
15 |
16 | class DiscriminativeSubNetwork(nn.Module):
17 | def __init__(self,in_channels=3, out_channels=3, base_channels=64, out_features=False):
18 | super(DiscriminativeSubNetwork, self).__init__()
19 | base_width = base_channels
20 | self.encoder_segment = EncoderDiscriminative(in_channels, base_width)
21 | self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels)
22 | #self.segment_act = torch.nn.Sigmoid()
23 | self.out_features = out_features
24 | def forward(self, x):
25 | b1,b2,b3,b4,b5,b6 = self.encoder_segment(x)
26 | output_segment = self.decoder_segment(b1,b2,b3,b4,b5,b6)
27 | if self.out_features:
28 | return output_segment, b2, b3, b4, b5, b6
29 | else:
30 | return output_segment
31 |
32 | class EncoderDiscriminative(nn.Module):
33 | def __init__(self, in_channels, base_width):
34 | super(EncoderDiscriminative, self).__init__()
35 | self.block1 = nn.Sequential(
36 | nn.Conv2d(in_channels,base_width, kernel_size=3, padding=1),
37 | nn.BatchNorm2d(base_width),
38 | nn.ReLU(inplace=True),
39 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
40 | nn.BatchNorm2d(base_width),
41 | nn.ReLU(inplace=True))
42 | self.mp1 = nn.Sequential(nn.MaxPool2d(2))
43 | self.block2 = nn.Sequential(
44 | nn.Conv2d(base_width,base_width*2, kernel_size=3, padding=1),
45 | nn.BatchNorm2d(base_width*2),
46 | nn.ReLU(inplace=True),
47 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1),
48 | nn.BatchNorm2d(base_width*2),
49 | nn.ReLU(inplace=True))
50 | self.mp2 = nn.Sequential(nn.MaxPool2d(2))
51 | self.block3 = nn.Sequential(
52 | nn.Conv2d(base_width*2,base_width*4, kernel_size=3, padding=1),
53 | nn.BatchNorm2d(base_width*4),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1),
56 | nn.BatchNorm2d(base_width*4),
57 | nn.ReLU(inplace=True))
58 | self.mp3 = nn.Sequential(nn.MaxPool2d(2))
59 | self.block4 = nn.Sequential(
60 | nn.Conv2d(base_width*4,base_width*8, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(base_width*8),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
64 | nn.BatchNorm2d(base_width*8),
65 | nn.ReLU(inplace=True))
66 | self.mp4 = nn.Sequential(nn.MaxPool2d(2))
67 | self.block5 = nn.Sequential(
68 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1),
69 | nn.BatchNorm2d(base_width*8),
70 | nn.ReLU(inplace=True),
71 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
72 | nn.BatchNorm2d(base_width*8),
73 | nn.ReLU(inplace=True))
74 |
75 | self.mp5 = nn.Sequential(nn.MaxPool2d(2))
76 | self.block6 = nn.Sequential(
77 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1),
78 | nn.BatchNorm2d(base_width*8),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
81 | nn.BatchNorm2d(base_width*8),
82 | nn.ReLU(inplace=True))
83 |
84 |
85 | def forward(self, x):
86 | b1 = self.block1(x)
87 | mp1 = self.mp1(b1)
88 | b2 = self.block2(mp1)
89 | mp2 = self.mp2(b2)
90 | b3 = self.block3(mp2)
91 | mp3 = self.mp3(b3)
92 | b4 = self.block4(mp3)
93 | mp4 = self.mp4(b4)
94 | b5 = self.block5(mp4)
95 | mp5 = self.mp5(b5)
96 | b6 = self.block6(mp5)
97 | return b1,b2,b3,b4,b5,b6
98 |
99 | class DecoderDiscriminative(nn.Module):
100 | def __init__(self, base_width, out_channels=1):
101 | super(DecoderDiscriminative, self).__init__()
102 |
103 | self.up_b = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
104 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
105 | nn.BatchNorm2d(base_width * 8),
106 | nn.ReLU(inplace=True))
107 | self.db_b = nn.Sequential(
108 | nn.Conv2d(base_width*(8+8), base_width*8, kernel_size=3, padding=1),
109 | nn.BatchNorm2d(base_width*8),
110 | nn.ReLU(inplace=True),
111 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
112 | nn.BatchNorm2d(base_width * 8),
113 | nn.ReLU(inplace=True)
114 | )
115 |
116 |
117 | self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
118 | nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1),
119 | nn.BatchNorm2d(base_width * 4),
120 | nn.ReLU(inplace=True))
121 | self.db1 = nn.Sequential(
122 | nn.Conv2d(base_width*(4+8), base_width*4, kernel_size=3, padding=1),
123 | nn.BatchNorm2d(base_width*4),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1),
126 | nn.BatchNorm2d(base_width * 4),
127 | nn.ReLU(inplace=True)
128 | )
129 |
130 | self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
131 | nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1),
132 | nn.BatchNorm2d(base_width * 2),
133 | nn.ReLU(inplace=True))
134 | self.db2 = nn.Sequential(
135 | nn.Conv2d(base_width*(2+4), base_width*2, kernel_size=3, padding=1),
136 | nn.BatchNorm2d(base_width*2),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1),
139 | nn.BatchNorm2d(base_width * 2),
140 | nn.ReLU(inplace=True)
141 | )
142 |
143 | self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
144 | nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1),
145 | nn.BatchNorm2d(base_width),
146 | nn.ReLU(inplace=True))
147 | self.db3 = nn.Sequential(
148 | nn.Conv2d(base_width*(2+1), base_width, kernel_size=3, padding=1),
149 | nn.BatchNorm2d(base_width),
150 | nn.ReLU(inplace=True),
151 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
152 | nn.BatchNorm2d(base_width),
153 | nn.ReLU(inplace=True)
154 | )
155 |
156 | self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
157 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
158 | nn.BatchNorm2d(base_width),
159 | nn.ReLU(inplace=True))
160 | self.db4 = nn.Sequential(
161 | nn.Conv2d(base_width*2, base_width, kernel_size=3, padding=1),
162 | nn.BatchNorm2d(base_width),
163 | nn.ReLU(inplace=True),
164 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
165 | nn.BatchNorm2d(base_width),
166 | nn.ReLU(inplace=True)
167 | )
168 |
169 |
170 |
171 | self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1))
172 |
173 | def forward(self, b1,b2,b3,b4,b5,b6):
174 | up_b = self.up_b(b6)
175 | cat_b = torch.cat((up_b,b5),dim=1)
176 | db_b = self.db_b(cat_b)
177 |
178 | up1 = self.up1(db_b)
179 | cat1 = torch.cat((up1,b4),dim=1)
180 | db1 = self.db1(cat1)
181 |
182 | up2 = self.up2(db1)
183 | cat2 = torch.cat((up2,b3),dim=1)
184 | db2 = self.db2(cat2)
185 |
186 | up3 = self.up3(db2)
187 | cat3 = torch.cat((up3,b2),dim=1)
188 | db3 = self.db3(cat3)
189 |
190 | up4 = self.up4(db3)
191 | cat4 = torch.cat((up4,b1),dim=1)
192 | db4 = self.db4(cat4)
193 |
194 | out = self.fin_out(db4)
195 | return out
196 |
197 |
198 |
199 | class EncoderReconstructive(nn.Module):
200 | def __init__(self, in_channels, base_width):
201 | super(EncoderReconstructive, self).__init__()
202 | self.block1 = nn.Sequential(
203 | nn.Conv2d(in_channels,base_width, kernel_size=3, padding=1),
204 | nn.BatchNorm2d(base_width),
205 | nn.ReLU(inplace=True),
206 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
207 | nn.BatchNorm2d(base_width),
208 | nn.ReLU(inplace=True))
209 | self.mp1 = nn.Sequential(nn.MaxPool2d(2))
210 | self.block2 = nn.Sequential(
211 | nn.Conv2d(base_width,base_width*2, kernel_size=3, padding=1),
212 | nn.BatchNorm2d(base_width*2),
213 | nn.ReLU(inplace=True),
214 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1),
215 | nn.BatchNorm2d(base_width*2),
216 | nn.ReLU(inplace=True))
217 | self.mp2 = nn.Sequential(nn.MaxPool2d(2))
218 | self.block3 = nn.Sequential(
219 | nn.Conv2d(base_width*2,base_width*4, kernel_size=3, padding=1),
220 | nn.BatchNorm2d(base_width*4),
221 | nn.ReLU(inplace=True),
222 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1),
223 | nn.BatchNorm2d(base_width*4),
224 | nn.ReLU(inplace=True))
225 | self.mp3 = nn.Sequential(nn.MaxPool2d(2))
226 | self.block4 = nn.Sequential(
227 | nn.Conv2d(base_width*4,base_width*8, kernel_size=3, padding=1),
228 | nn.BatchNorm2d(base_width*8),
229 | nn.ReLU(inplace=True),
230 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
231 | nn.BatchNorm2d(base_width*8),
232 | nn.ReLU(inplace=True))
233 | self.mp4 = nn.Sequential(nn.MaxPool2d(2))
234 | self.block5 = nn.Sequential(
235 | nn.Conv2d(base_width*8,base_width*8, kernel_size=3, padding=1),
236 | nn.BatchNorm2d(base_width*8),
237 | nn.ReLU(inplace=True),
238 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
239 | nn.BatchNorm2d(base_width*8),
240 | nn.ReLU(inplace=True))
241 |
242 |
243 | def forward(self, x):
244 | b1 = self.block1(x)
245 | mp1 = self.mp1(b1)
246 | b2 = self.block2(mp1)
247 | mp2 = self.mp2(b2)
248 | b3 = self.block3(mp2)
249 | mp3 = self.mp3(b3)
250 | b4 = self.block4(mp3)
251 | mp4 = self.mp4(b4)
252 | b5 = self.block5(mp4)
253 | return b5
254 |
255 |
256 | class DecoderReconstructive(nn.Module):
257 | def __init__(self, base_width, out_channels=1):
258 | super(DecoderReconstructive, self).__init__()
259 |
260 | self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
261 | nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
262 | nn.BatchNorm2d(base_width * 8),
263 | nn.ReLU(inplace=True))
264 | self.db1 = nn.Sequential(
265 | nn.Conv2d(base_width*8, base_width*8, kernel_size=3, padding=1),
266 | nn.BatchNorm2d(base_width*8),
267 | nn.ReLU(inplace=True),
268 | nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1),
269 | nn.BatchNorm2d(base_width * 4),
270 | nn.ReLU(inplace=True)
271 | )
272 |
273 | self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
274 | nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1),
275 | nn.BatchNorm2d(base_width * 4),
276 | nn.ReLU(inplace=True))
277 | self.db2 = nn.Sequential(
278 | nn.Conv2d(base_width*4, base_width*4, kernel_size=3, padding=1),
279 | nn.BatchNorm2d(base_width*4),
280 | nn.ReLU(inplace=True),
281 | nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1),
282 | nn.BatchNorm2d(base_width * 2),
283 | nn.ReLU(inplace=True)
284 | )
285 |
286 | self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
287 | nn.Conv2d(base_width * 2, base_width*2, kernel_size=3, padding=1),
288 | nn.BatchNorm2d(base_width*2),
289 | nn.ReLU(inplace=True))
290 | # no skip connection here: the reconstructive decoder only upsamples and refines features
291 | self.db3 = nn.Sequential(
292 | nn.Conv2d(base_width*2, base_width*2, kernel_size=3, padding=1),
293 | nn.BatchNorm2d(base_width*2),
294 | nn.ReLU(inplace=True),
295 | nn.Conv2d(base_width*2, base_width*1, kernel_size=3, padding=1),
296 | nn.BatchNorm2d(base_width*1),
297 | nn.ReLU(inplace=True)
298 | )
299 |
300 | self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
301 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
302 | nn.BatchNorm2d(base_width),
303 | nn.ReLU(inplace=True))
304 | self.db4 = nn.Sequential(
305 | nn.Conv2d(base_width*1, base_width, kernel_size=3, padding=1),
306 | nn.BatchNorm2d(base_width),
307 | nn.ReLU(inplace=True),
308 | nn.Conv2d(base_width, base_width, kernel_size=3, padding=1),
309 | nn.BatchNorm2d(base_width),
310 | nn.ReLU(inplace=True)
311 | )
312 |
313 | self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1))
314 | #self.fin_out = nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)
315 |
316 | def forward(self, b5):
317 | up1 = self.up1(b5)
318 | db1 = self.db1(up1)
319 |
320 | up2 = self.up2(db1)
321 | db2 = self.db2(up2)
322 |
323 | up3 = self.up3(db2)
324 | db3 = self.db3(up3)
325 |
326 | up4 = self.up4(db3)
327 | db4 = self.db4(up4)
328 |
329 | out = self.fin_out(db4)
330 | return out
--------------------------------------------------------------------------------
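
The encoder/decoder pairs above follow the DRAEM layout: the discriminative branch is a U-Net with skip connections, while the reconstructive branch is a plain autoencoder with four 2x downsamplings and four bilinear 2x upsamplings. A minimal shape check, assuming model_unet.py is importable and taking base_width=64 as a placeholder width (the real value is set wherever ReconstructiveSubNetwork is built):

import torch
from model_unet import EncoderReconstructive, DecoderReconstructive

# Hypothetical width; used only to verify the spatial bookkeeping of the classes above.
enc = EncoderReconstructive(in_channels=3, base_width=64)
dec = DecoderReconstructive(base_width=64, out_channels=3)

x = torch.randn(1, 3, 256, 256)   # one image at the 256x256 resolution used in train.py/test.py
b5 = enc(x)                       # bottleneck: (1, 512, 16, 16) after four 2x max-pools
y = dec(b5)                       # four bilinear 2x upsamples restore (1, 3, 256, 256)
print(b5.shape, y.shape)
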
/seg_network/perlin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | def lerp_np(x,y,w):
6 | fin_out = (y-x)*w + x
7 | return fin_out
8 |
9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5):
10 | noise = np.zeros(shape)
11 | frequency = 1
12 | amplitude = 1
13 | for _ in range(octaves):
14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1]))
15 | frequency *= 2
16 | amplitude *= persistence
17 | return noise
18 |
19 |
20 | def generate_perlin_noise_2d(shape, res):
21 | def f(t):
22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
23 |
24 | delta = (res[0] / shape[0], res[1] / shape[1])
25 | d = (shape[0] // res[0], shape[1] // res[1])
26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
27 | # Gradients
28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
29 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)
34 | # Ramps
35 | n00 = np.sum(grid * g00, 2)
36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
39 | # Interpolation
40 | t = f(grid)
41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
44 |
45 |
46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
47 | delta = (res[0] / shape[0], res[1] / shape[1])
48 | d = (shape[0] // res[0], shape[1] // res[1])
49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
50 |
51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1)  # note: unused; tile_grads below recomputes this per slice
54 |
55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1)
56 | dot = lambda grad, shift: (
57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
59 |
60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
64 | t = fade(grid[:shape[0], :shape[1]])
65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
66 |
67 |
68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
69 | delta = (res[0] / shape[0], res[1] / shape[1])
70 | d = (shape[0] // res[0], shape[1] // res[1])
71 |
72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
75 |
76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
77 | 0).repeat_interleave(
78 | d[1], 1)
79 | dot = lambda grad, shift: (
80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
82 |
83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
84 |
85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
88 | t = fade(grid[:shape[0], :shape[1]])
89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
90 |
91 |
92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
93 | noise = torch.zeros(shape)
94 | frequency = 1
95 | amplitude = 1
96 | for _ in range(octaves):
97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
98 | frequency *= 2
99 | amplitude *= persistence
100 | return noise
--------------------------------------------------------------------------------
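
These Perlin helpers are presumably what data_loader.py draws its synthetic anomaly masks from for DRAEM-style augmentation. A small, hedged usage sketch (the resolution and threshold below are illustrative, not the values hard-coded in the loader):

import numpy as np
from perlin import rand_perlin_2d_np

# res must divide the target shape evenly; (8, 8) over 256x256 gives 32-pixel noise cells.
noise = rand_perlin_2d_np((256, 256), (8, 8))
mask = (noise > 0.5).astype(np.float32)   # illustrative threshold for a binary anomaly mask
print(noise.shape, mask.mean())
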
/seg_network/tensorboard_visualizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.utils.tensorboard import SummaryWriter
4 | from torchvision import datasets, transforms
5 | import os
6 | # Writer will output to ./runs/ directory by default
7 |
8 | class TensorboardVisualizer:
9 |
10 | def __init__(self,log_dir='./logs/'):
11 | if not os.path.exists(log_dir):
12 | os.makedirs(log_dir)
13 | self.writer = SummaryWriter(log_dir=log_dir)
14 |
15 | def visualize_image_batch(self,image_batch,n_iter,image_name='Image_batch'):
16 | grid = torchvision.utils.make_grid(image_batch)
17 | self.writer.add_image(image_name,grid,n_iter)
18 |
19 | def plot_loss(self, loss_val, n_iter, loss_name='loss'):
20 | self.writer.add_scalar(loss_name, loss_val, n_iter)
21 |
22 |
--------------------------------------------------------------------------------
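
TensorboardVisualizer is a thin wrapper over torch.utils.tensorboard.SummaryWriter. A minimal usage sketch mirroring how train.py calls it, with a random batch and a hypothetical log directory standing in for real data:

import torch
from tensorboard_visualizer import TensorboardVisualizer

vis = TensorboardVisualizer(log_dir='./logs/demo_run/')   # hypothetical log directory
batch = torch.rand(4, 3, 64, 64)                          # any NCHW batch in [0, 1]
vis.visualize_image_batch(batch, n_iter=0, image_name='batch_augmented')
vis.plot_loss(0.123, n_iter=0, loss_name='segment_loss')
# Inspect with: tensorboard --logdir ./logs/
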
/seg_network/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | # from models.diffusion import Model   # unused; seg_network has no local `models` package
4 | # from models.ema import EMAHelper     # unused
5 | from omegaconf import OmegaConf
6 | from PIL import Image
7 | from data_loader import MVTecDRAEMTestDataset, MVTecDRAEMTrainDataset
8 | from torch.utils.data import DataLoader
9 | import numpy as np
10 | from sklearn.metrics import roc_auc_score, average_precision_score
11 | from model_unet import ReconstructiveSubNetwork, DiscriminativeSubNetwork
12 | import os
13 | import cv2
14 | from rec_network.util import instantiate_from_config
15 | from rec_network.models.diffusion.ddim import DDIMSampler
16 |
17 |
18 | def write_results_to_file(run_name, image_auc, pixel_auc, image_ap, pixel_ap):
19 | if not os.path.exists('./outputs/'):
20 | os.makedirs('./outputs/')
21 |
22 | fin_str = "img_auc," + run_name
23 | for i in image_auc:
24 | fin_str += "," + str(np.round(i, 3))
25 | fin_str += "," + str(np.round(np.mean(image_auc), 3))
26 | fin_str += "\n"
27 | fin_str += "pixel_auc," + run_name
28 | for i in pixel_auc:
29 | fin_str += "," + str(np.round(i, 3))
30 | fin_str += "," + str(np.round(np.mean(pixel_auc), 3))
31 | fin_str += "\n"
32 | fin_str += "img_ap," + run_name
33 | for i in image_ap:
34 | fin_str += "," + str(np.round(i, 3))
35 | fin_str += "," + str(np.round(np.mean(image_ap), 3))
36 | fin_str += "\n"
37 | fin_str += "pixel_ap," + run_name
38 | for i in pixel_ap:
39 | fin_str += "," + str(np.round(i, 3))
40 | fin_str += "," + str(np.round(np.mean(pixel_ap), 3))
41 | fin_str += "\n"
42 | fin_str += "--------------------------\n"
43 |
44 | with open("./outputs/results.txt", 'a+') as file:
45 | file.write(fin_str)
46 |
47 |
48 | def test(obj_name, mvtec_path, checkpoint_path, base_model_name):
49 |
50 | img_dim = 256
51 | run_name = base_model_name + "_" + obj_name + '_'
52 |
53 | config = OmegaConf.load("../configs/mvtec.yaml")
54 |
55 | model = instantiate_from_config(config.model)
56 |
57 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
58 | model = model.to(device)
59 | sampler = DDIMSampler(model)
60 |
61 | model_seg = DiscriminativeSubNetwork(in_channels=9, out_channels=2)
62 | model_seg.load_state_dict(torch.load(os.path.join(checkpoint_path, run_name + "seg.pckl"), map_location='cuda:0'))  # same filename that train.py writes
63 | model_seg.cuda()
64 | model_seg.eval()
65 |
66 | dataset = MVTecDRAEMTestDataset(mvtec_path + obj_name + "/test", resize_shape=[img_dim, img_dim])
67 | dataloader = DataLoader(dataset, batch_size=1,
68 | shuffle=False, num_workers=0)
69 |
70 | total_pixel_scores = np.zeros((img_dim * img_dim * len(dataset)))
71 | total_gt_pixel_scores = np.zeros((img_dim * img_dim * len(dataset)))
72 | mask_cnt = 0
73 |
74 | anomaly_score_gt = []
75 | anomaly_score_prediction = []
76 |
77 | os.makedirs(os.path.join('./samples', obj_name, 'test'), exist_ok=True)  # cv2.imwrite does not create directories
78 | cnt_display = 0
79 | for i_batch, sample_batched in enumerate(dataloader):
80 | gray_batch = sample_batched["image"].cuda()
81 |
82 | is_normal = sample_batched["has_anomaly"].detach().numpy()[0, 0]
83 | anomaly_score_gt.append(is_normal)
84 | true_mask = sample_batched["mask"]
85 | true_mask_cv = true_mask.detach().numpy()[0, :, :, :].transpose((1, 2, 0))
86 |
87 | c = model.cond_stage_model.encode(gray_batch)
88 | c = c.mode()
89 | noise = torch.randn_like(c)
90 | t = torch.randint(400, 500, (c.shape[0],), device=device).long()
91 | c_noisy = model.q_sample(x_start=c, t=t, noise=noise)
92 |
93 | shape = c.shape[1:]
94 | samples_ddim, _ = sampler.sample(S=10,
95 | conditioning=c, # or conditioning=c_noisy
96 | batch_size=c.shape[0],
97 | shape=shape,
98 | verbose=False)
99 |
100 | gray_rec = model.decode_first_stage(samples_ddim)
101 |
102 | samples_ddim1 = 0.5 * samples_ddim + 0.5 * c
103 | gray_rec1 = model.decode_first_stage(samples_ddim1)
104 |
105 | joined_in = torch.cat((gray_rec.detach(), gray_rec1.detach(), gray_batch), dim=1)
106 |
107 | out_mask = model_seg(joined_in)
108 | out_mask_sm = torch.softmax(out_mask, dim=1)
109 |
110 | t_mask = out_mask_sm[:, 1:, :, :]
111 |
112 | outpath = os.path.join('./samples', obj_name, 'test', 'rec_images' + str(cnt_display) + '.jpg')
113 | sample = gray_rec.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
114 | cv2.imwrite(outpath, sample)
115 |
116 | outpath = os.path.join('./samples', obj_name, 'test', 'gt_images' + str(cnt_display) + '.jpg')
117 | sample = gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
118 | cv2.imwrite(outpath, sample)
119 |
120 | outpath = os.path.join('./samples', obj_name, 'test', 'out_masks' + str(cnt_display) + '.jpg')
121 | sample = t_mask[0].detach().cpu().numpy()[0] * 255
122 | cv2.imwrite(outpath, sample)
123 |
124 | outpath = os.path.join('./samples', obj_name, 'test', 'in_masks' + str(cnt_display) + '.jpg')
125 | sample = true_mask[0].detach().cpu().numpy()[0] * 255
126 | cv2.imwrite(outpath, sample)
127 |
128 | heatmap = t_mask[0].detach().cpu().numpy()[0]
129 | heatmap = heatmap / np.max(heatmap)
130 | heatmap = np.uint8(255 * heatmap)
131 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
132 | show = heatmap * 0.5 + gray_batch.detach().cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
133 | outpath = os.path.join('./samples', obj_name, 'test', 'heatmap' + str(cnt_display) + '.jpg')
134 | cv2.imwrite(outpath, show)
135 | cnt_display += 1
136 |
137 | out_mask_cv = out_mask_sm[0, 1, :, :].detach().cpu().numpy()
138 | out_mask_averaged = torch.nn.functional.avg_pool2d(out_mask_sm[:, 1:, :, :], 21, stride=1,
139 | padding=21 // 2).cpu().detach().numpy()
140 | image_score = np.max(out_mask_averaged)
141 | anomaly_score_prediction.append(image_score)
142 |
143 | flat_true_mask = true_mask_cv.flatten()
144 | flat_out_mask = out_mask_cv.flatten()
145 | total_pixel_scores[mask_cnt * img_dim * img_dim:(mask_cnt + 1) * img_dim * img_dim] = flat_out_mask
146 | total_gt_pixel_scores[mask_cnt * img_dim * img_dim:(mask_cnt + 1) * img_dim * img_dim] = flat_true_mask
147 | mask_cnt += 1
148 |
149 | anomaly_score_prediction = np.array(anomaly_score_prediction)
150 | anomaly_score_gt = np.array(anomaly_score_gt)
151 | auroc = roc_auc_score(anomaly_score_gt, anomaly_score_prediction)
152 | ap = average_precision_score(anomaly_score_gt, anomaly_score_prediction)
153 |
154 | total_gt_pixel_scores = total_gt_pixel_scores.astype(np.uint8)
155 | total_gt_pixel_scores = total_gt_pixel_scores[:img_dim * img_dim * mask_cnt]
156 | total_pixel_scores = total_pixel_scores[:img_dim * img_dim * mask_cnt]
157 | auroc_pixel = roc_auc_score(total_gt_pixel_scores, total_pixel_scores)
158 | ap_pixel = average_precision_score(total_gt_pixel_scores, total_pixel_scores)
159 |
160 | print(obj_name)
161 | print("AUC Image: " + str(auroc))
162 | print("AP Image: " + str(ap))
163 | print("AUC Pixel: " + str(auroc_pixel))
164 | print("AP Pixel: " + str(ap_pixel))
165 |
166 | write_results_to_file(run_name, [auroc], [auroc_pixel], [ap], [ap_pixel])  # the writer iterates over lists of per-class scores
167 |
168 |
169 | if __name__ == "__main__":
170 | import argparse
171 |
172 | parser = argparse.ArgumentParser()
173 | parser.add_argument('--gpu_id', action='store', type=int, required=True)
174 | parser.add_argument('--base_model_name', action='store', type=str, required=True)
175 | parser.add_argument('--data_path', action='store', type=str, required=True)
176 | parser.add_argument('--checkpoint_path', action='store', type=str, required=True)
177 | parser.add_argument('--ema', action='store_true')
178 |
179 | parser.add_argument("--logit_transform", default=False)
180 | parser.add_argument("--uniform_dequantization", default=False)
181 | parser.add_argument("--gaussian_dequantization", default=False)
182 | parser.add_argument("--random_flip", default=True)
183 | parser.add_argument("--rescaled", default=True)
184 | parser.add_argument("--sample_type", type=str, default="generalized",
185 | help="sampling approach (generalized or ddpm_noisy)")
186 | parser.add_argument("--skip_type", type=str, default="uniform", help="skip according to (uniform or quadratic)")
187 |
188 | args = parser.parse_args()
189 |
190 | obj_list = ['capsule',
191 | 'bottle',
192 | 'carpet',
193 | 'leather',
194 | 'pill',
195 | 'transistor',
196 | 'tile',
197 | 'cable',
198 | 'zipper',
199 | 'toothbrush',
200 | 'metal_nut',
201 | 'hazelnut',
202 | 'screw',
203 | 'grid',
204 | 'wood'
205 | ]
206 | obj_name = 'bottle'  # evaluates a single class; swap in any entry from obj_list above
207 |
208 | with torch.cuda.device(args.gpu_id):
209 | test(obj_name, args.data_path, args.checkpoint_path, args.base_model_name)
210 |
--------------------------------------------------------------------------------
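
test.py reconstructs each test image with the DDIM sampler, feeds the reconstructions and the input to the discriminative sub-network, and scores the image by the maximum of a 21x21 average-pooled anomaly map. Below is that scoring step in isolation, with a random tensor standing in for the segmentation output (numbers are placeholders, not results); the invocation in the comment uses assumed paths and model name:

# python test.py --gpu_id 0 --base_model_name DRAEM_test_0.0001_700_bs8 \
#        --data_path ../datasets/mvtec/ --checkpoint_path ./checkpoints/
import torch
import torch.nn.functional as F

out_mask_sm = torch.softmax(torch.rand(1, 2, 256, 256), dim=1)   # stand-in for model_seg output
smoothed = F.avg_pool2d(out_mask_sm[:, 1:, :, :], 21, stride=1, padding=21 // 2)
image_score = smoothed.max().item()                              # image-level anomaly score
print(image_score)
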
/seg_network/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # from models.diffusion import Model   # unused; seg_network has no local `models` package
3 | # from models.ema import EMAHelper     # unused
4 | from omegaconf import OmegaConf
5 | from data_loader import MVTecDRAEMTrainDataset
6 | from torch.utils.data import DataLoader
7 | from torch import optim
8 | from tensorboard_visualizer import TensorboardVisualizer
9 | from model_unet import DiscriminativeSubNetwork
10 | from loss import FocalLoss, SSIM
11 | from rec_network.util import instantiate_from_config
12 | from rec_network.models.diffusion.ddim import DDIMSampler
13 | import cv2
14 | import os
15 |
16 |
17 | def get_lr(optimizer):
18 | for param_group in optimizer.param_groups:
19 | return param_group['lr']
20 |
21 |
22 | def weights_init(m):
23 | classname = m.__class__.__name__
24 | if classname.find('Conv') != -1:
25 | m.weight.data.normal_(0.0, 0.02)
26 | elif classname.find('BatchNorm') != -1:
27 | m.weight.data.normal_(1.0, 0.02)
28 | m.bias.data.fill_(0)
29 |
30 |
31 | def train_on_device(obj_name, args):
32 | if not os.path.exists(args.checkpoint_path):
33 | os.makedirs(args.checkpoint_path)
34 |
35 | if not os.path.exists(args.log_path):
36 | os.makedirs(args.log_path)
37 |
38 | run_name = 'DRAEM_test_' + str(args.lr) + '_' + str(args.epochs) + '_bs' + str(args.bs) + "_" + obj_name + '_'
39 |
40 | visualizer = TensorboardVisualizer(log_dir=os.path.join(args.log_path, run_name + "/"))
41 |
42 | config = OmegaConf.load("../configs/mvtec.yaml")
43 |
44 | model = instantiate_from_config(config.model)
45 |
46 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47 | model = model.to(device)
48 | sampler = DDIMSampler(model)
49 |
50 | model_seg = DiscriminativeSubNetwork(in_channels=9, out_channels=2)
51 | model_seg.cuda()
52 | model_seg.apply(weights_init)
53 |
54 | optimizer = torch.optim.Adam([
55 | {"params": model_seg.parameters(), "lr": args.lr}])
56 |
57 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [int(args.epochs * 0.8), int(args.epochs * 0.9)], gamma=0.2,
58 | last_epoch=-1)
59 |
60 | loss_focal = FocalLoss()
61 |
62 | dataset = MVTecDRAEMTrainDataset(args.data_path + obj_name + "/train/good/", args.anomaly_source_path,
63 | resize_shape=[256, 256])
64 |
65 | dataloader = DataLoader(dataset, batch_size=args.bs,
66 | shuffle=True, num_workers=0)
67 | os.makedirs(os.path.join('./samples', obj_name, 'train'), exist_ok=True)  # needed by the --visualize image dumps below
68 | n_iter = 0
69 | for epoch in range(args.epochs):
70 | lr = scheduler.get_last_lr()[0]
71 | print("Epoch: " + str(epoch) + " Learning rate: " + str(lr))
72 | for i_batch, sample_batched in enumerate(dataloader):
73 | gray_batch = sample_batched["image"].cuda()
74 | aug_gray_batch = sample_batched["augmented_image"].cuda()
75 | anomaly_mask = sample_batched["anomaly_mask"].cuda()
76 |
77 | c = model.cond_stage_model.encode(aug_gray_batch)
78 | c = c.mode()
79 |
80 | shape = c.shape[1:]
81 | noise = torch.randn_like(c)
82 | t = torch.randint(400, 500, (c.shape[0],), device=device).long()
83 | c_noisy = model.q_sample(x_start=c, t=t, noise=noise)
84 | samples_ddim, _ = sampler.sample(S=50,
85 | conditioning=c, # or conditioning=c_noisy
86 | batch_size=c.shape[0],
87 | shape=shape,
88 | verbose=False)
89 | gray_rec = model.decode_first_stage(samples_ddim)
90 |
91 | samples_ddim1 = 0.5 * samples_ddim + 0.5 * c
92 | gray_rec1 = model.decode_first_stage(samples_ddim1)
93 |
94 | joined_in = torch.cat((gray_rec, gray_rec1, aug_gray_batch), dim=1)
95 |
96 | out_mask = model_seg(joined_in)
97 | out_mask_sm = torch.softmax(out_mask, dim=1)
98 |
99 | segment_loss = loss_focal(out_mask_sm, anomaly_mask)
100 | loss = segment_loss
101 |
102 | optimizer.zero_grad()
103 |
104 | loss.backward()
105 | optimizer.step()
106 |
107 | if args.visualize and n_iter % 200 == 0:
108 | visualizer.plot_loss(segment_loss, n_iter, loss_name='segment_loss')
109 | if args.visualize and n_iter % 100 == 0:
110 | t_mask = out_mask_sm[:, 1:, :, :]
111 |
112 | outpath = os.path.join('./samples',obj_name,'train', 'batch_augmented' + str(n_iter) + '.jpg')
113 | sample = aug_gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
114 | cv2.imwrite(outpath, sample)
115 |
116 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_target' + str(n_iter) + '.jpg')
117 | sample = gray_batch.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
118 | cv2.imwrite(outpath, sample)
119 |
120 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_out' + str(n_iter) + '.jpg')
121 | sample = gray_rec.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
122 | cv2.imwrite(outpath, sample)
123 |
124 | outpath = os.path.join('./samples',obj_name,'train', 'batch_recon_inter' + str(n_iter) + '.jpg')
125 | sample = gray_rec1.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
126 | cv2.imwrite(outpath, sample)
127 |
128 | outpath = os.path.join('./samples',obj_name,'train', 'mask_target' + str(n_iter) + '.jpg')
129 | sample = anomaly_mask[0].detach().cpu().numpy()[0] * 255
130 | cv2.imwrite(outpath, sample)
131 |
132 | outpath = os.path.join('./samples',obj_name,'train', 'mask_out' + str(n_iter) + '.jpg')
133 | sample = t_mask[0].detach().cpu().numpy()[0] * 255
134 | cv2.imwrite(outpath, sample)
135 |
136 | n_iter += 1
137 |
138 | scheduler.step()
139 |
140 | torch.save(model_seg.state_dict(), os.path.join(args.checkpoint_path, run_name + "seg.pckl"))
141 | # if epoch % 100 == 0 :
142 | # torch.save(model_seg.state_dict(),
143 | # os.path.join(args.checkpoint_path, run_name + str(epoch) + "_seg.pckl"))
144 |
145 |
146 | if __name__ == "__main__":
147 | import argparse
148 |
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument('--obj_id', action='store', type=int, required=True)
151 | parser.add_argument('--bs', action='store', type=int, required=True)
152 | parser.add_argument('--lr', action='store', type=float, required=True)
153 | parser.add_argument('--epochs', action='store', type=int, required=True)
154 | parser.add_argument('--gpu_id', action='store', type=int, default=0, required=False)
155 | parser.add_argument('--data_path', action='store', type=str, required=True)
156 | parser.add_argument('--anomaly_source_path', action='store', type=str, required=True)
157 | parser.add_argument('--checkpoint_path', action='store', type=str, required=True)
158 | parser.add_argument('--log_path', action='store', type=str, required=True)
159 | parser.add_argument('--visualize', action='store_true')
160 |
161 | parser.add_argument("--type", type=str, default="simple")
162 | parser.add_argument("--in_channels", type=int, default=3)
163 | parser.add_argument("--out_ch", type=int, default=3)
164 | parser.add_argument("--ch", type=int, default=128)
165 | parser.add_argument("--ch_mult", type=list, default=[1, 1, 2, 2, 4, 4])
166 | parser.add_argument("--num_res_blocks", type=int, default=2)
167 | parser.add_argument("--attn_resolutions", type=list, default=[16, ])
168 | parser.add_argument("--dropout", type=float, default=0.0)
169 | parser.add_argument("--var_type", type=str, default="fixedsmall")
170 | parser.add_argument("--ema_rate", type=float, default=0.999)
171 | parser.add_argument('--ema', action='store_true')
172 | parser.add_argument('--resamp_with_conv', action='store_true')
173 | parser.add_argument("--num_diffusion_timesteps", type=int, default=1000)
174 | parser.add_argument('--ddim_log_path', action='store', type=str, default="../ddim-main/mvtec/logs")
175 | parser.add_argument("--timesteps", type=int, default=10, help="number of steps involved")
176 | parser.add_argument("--eta", type=float, default=0.0)
177 |
178 | parser.add_argument("--logit_transform", default=False)
179 | parser.add_argument("--uniform_dequantization", default=False)
180 | parser.add_argument("--gaussian_dequantization", default=False)
181 | parser.add_argument("--random_flip", default=True)
182 | parser.add_argument("--rescaled", default=True)
183 | parser.add_argument("--sample_type", type=str, default="generalized",
184 | help="sampling approach (generalized or ddpm_noisy)")
185 | parser.add_argument("--skip_type", type=str, default="uniform", help="skip according to (uniform or quadratic)")
186 |
187 | args = parser.parse_args()
188 |
189 | obj_name = 'bottle'
190 | # 'capsule'
191 | # 'carpet'
192 | # 'leather'
193 | # 'pill'
194 | # 'transistor'
195 | # 'tile'
196 | # 'cable'
197 | # 'zipper'
198 | # 'toothbrush'
199 | # 'metal_nut'
200 | # 'hazelnut'
201 | # 'screw'
202 | # 'grid'
203 | # 'wood'
204 |
205 | with torch.cuda.device(args.gpu_id):
206 | train_on_device(obj_name, args)
207 |
208 |
--------------------------------------------------------------------------------
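
train.py keeps the latent diffusion model fixed and trains only the discriminative sub-network on 9-channel inputs (two reconstructions plus the augmented image) under a focal loss. A CPU-only sketch of that step with random tensors in place of the DDIM reconstructions and loader outputs; the command in the comment uses placeholder paths and hyper-parameters:

# python train.py --obj_id 1 --bs 8 --lr 0.0001 --epochs 700 --gpu_id 0 \
#        --data_path ../datasets/mvtec/ --anomaly_source_path ../datasets/dtd/images/ \
#        --checkpoint_path ./checkpoints/ --log_path ./logs/ --visualize
import torch
from model_unet import DiscriminativeSubNetwork
from loss import FocalLoss

model_seg = DiscriminativeSubNetwork(in_channels=9, out_channels=2)
optimizer = torch.optim.Adam(model_seg.parameters(), lr=1e-4)
loss_focal = FocalLoss()

aug_gray_batch = torch.rand(2, 3, 256, 256)    # augmented input images
gray_rec = torch.rand(2, 3, 256, 256)          # stand-in for decode_first_stage(samples_ddim)
gray_rec1 = torch.rand(2, 3, 256, 256)         # stand-in for the interpolated reconstruction
anomaly_mask = torch.zeros(2, 1, 256, 256)     # synthetic ground-truth mask from the loader

joined_in = torch.cat((gray_rec, gray_rec1, aug_gray_batch), dim=1)   # 9 channels, as in train.py
out_mask_sm = torch.softmax(model_seg(joined_in), dim=1)
loss = loss_focal(out_mask_sm, anomaly_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
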