├── LICENSE
├── asserts
│   ├── DiffFlow_fig1.png
│   └── DiffFlow_tree.png
├── conf
│   ├── config.yaml
│   ├── data
│   │   ├── base_img.yaml
│   │   ├── cifar.yaml
│   │   ├── mnist.yaml
│   │   └── ps.yaml
│   ├── hydra
│   │   └── launcher
│   │       └── joblib.yaml
│   ├── model
│   │   ├── cifar.yaml
│   │   ├── cont_cifar.yaml
│   │   ├── cont_mnist.yaml
│   │   ├── cont_mnist_s.yaml
│   │   ├── mnist.yaml
│   │   └── ps.yaml
│   ├── optimizer
│   │   ├── basic.yaml
│   │   ├── cifar.yaml
│   │   ├── fix_drift.yaml
│   │   ├── mnist.yaml
│   │   └── ps.yaml
│   ├── trainer
│   │   ├── base.yaml
│   │   ├── cifar.yaml
│   │   ├── cifar_ddp.yaml
│   │   ├── mnist.yaml
│   │   └── ps.yaml
│   └── wandb
│       └── base.yaml
├── datasets
│   ├── __init__.py
│   ├── celeba.py
│   ├── ffhq.py
│   ├── image_dataset.py
│   ├── img_tool.py
│   ├── img_transform.py
│   ├── points_dataset.py
│   ├── sierpinski.jpg
│   ├── sierpinski_hard.jpg
│   ├── tree.png
│   ├── utils.py
│   └── vision.py
├── jam_.yaml
├── main.py
├── modules
│   ├── __init__.py
│   ├── optimizers.py
│   ├── sde.py
│   ├── sde_img_fns.py
│   ├── sde_loss.py
│   └── sde_ps_fns.py
├── networks
│   ├── __init__.py
│   ├── base_model.py
│   ├── diff_flow.py
│   ├── fouriermlp.py
│   ├── official_unet.py
│   └── unet.py
├── poetry.lock
├── pyproject.toml
├── readme.md
├── utils
│   ├── __init__.py
│   ├── ddp_trainer.py
│   ├── diagnosis.py
│   ├── img_viz.py
│   ├── scalars.py
│   ├── sdefunction.py
│   └── trainer.py
└── viz
    ├── __init__.py
    ├── img.py
    ├── lines.py
    └── ps.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Qinsheng Zhang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /asserts/DiffFlow_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/asserts/DiffFlow_fig1.png -------------------------------------------------------------------------------- /asserts/DiffFlow_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/asserts/DiffFlow_tree.png -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # - hydra/launcher: joblib 3 | - data: ps 4 | - model: ${data} 5 | - trainer: ${data} 6 | - optimizer: ${data} 7 | - wandb: base 8 | 9 | 10 | is_dist: ${trainer.is_dist} 11 | cuda: true 12 | 13 | 14 | log: false 15 | name: unet 16 | -------------------------------------------------------------------------------- /conf/data/base_img.yaml: -------------------------------------------------------------------------------- 1 | dataloader: 2 | _target_: jamtorch.ddp.ddp_utils.ddp_loaders 3 | batch_size: ?? 4 | pin_memory: false 5 | num_workers: 0 6 | 7 | dataset: ?? 8 | image_size: ?? 9 | channel: ?? 10 | path: 11 | 12 | train_size: 13 | val_size: ?? 14 | eval_n_samples: 8 15 | 16 | random_flip: true 17 | logit_transform: false 18 | uniform_dequantization: false 19 | gaussian_dequantization: false 20 | rescaled: true 21 | image_mean: 22 | image_std: 23 | 24 | preprocess_fn: 25 | _target_: modules.sde_loss.img_preprocess 26 | 27 | fid: 28 | num_samples: 1000 29 | batch_size: 500 30 | -------------------------------------------------------------------------------- /conf/data/cifar.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_img 3 | 4 | dataloader: 5 | batch_size: 100 6 | 7 | dataset: CIFAR10 8 | image_size: 32 9 | channel: 3 10 | 11 | train_size: 12 | val_size: 1000 13 | -------------------------------------------------------------------------------- /conf/data/mnist.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_img 3 | 4 | dataloader: 5 | batch_size: 64 6 | pin_memory: true 7 | 8 | dataset: MNIST 9 | image_size: 32 10 | channel: 1 11 | 12 | train_size: 13 | val_size: 2000 14 | 15 | random_flip: false 16 | -------------------------------------------------------------------------------- /conf/data/ps.yaml: -------------------------------------------------------------------------------- 1 | dataset: "olympic" 2 | 3 | train_size: 500000 4 | val_size: 2000 5 | 6 | batch_size: 5000 7 | 8 | range: 1.0 9 | noise: 0.01 10 | std: 1.0 11 | 12 | iscenter: true 13 | 14 | density: false 15 | 16 | train_data: 17 | _target_: datasets.points_dataset.PointsDataSet 18 | data_name: ${data.dataset} 19 | num_sample: ${data.train_size} 20 | noise: ${data.noise} 21 | dim_range: ${data.range} 22 | iscenter: ${data.iscenter} 23 | 24 | 25 | val_data: 26 | _target_: datasets.points_dataset.PointsDataSet 27 | data_name: ${data.dataset} 28 | num_sample: ${data.val_size} 29 | noise: ${data.noise} 30 | dim_range: ${data.range} 31 | iscenter: ${data.iscenter} 32 | 33 | dataloader: 34 | _target_: jamtorch.ddp.ddp_utils.ddp_loaders 35 | batch_size: ${data.batch_size} 36 | pin_memory: true 37 | 38 | 
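# The `_target_` entries in this file are resolved by Hydra-style instantiation
# at runtime (see `jam_instantiate` in main.py and `get_ps_dataset` in
# datasets/points_dataset.py). A minimal sketch of the equivalent direct
# construction, using the values configured above -- an illustration, not the
# code path the project actually runs:
#
#     from datasets.points_dataset import PointsDataSet
#
#     train_data = PointsDataSet(
#         data_name="olympic", num_sample=500000,
#         noise=0.01, dim_range=1.0, iscenter=True,
#     )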
preprocess_fn: 39 | _target_: modules.sde_loss.point_preprocess 40 | -------------------------------------------------------------------------------- /conf/hydra/launcher/joblib.yaml: -------------------------------------------------------------------------------- 1 | n_jobs: 1 2 | -------------------------------------------------------------------------------- /conf/model/cifar.yaml: -------------------------------------------------------------------------------- 1 | name: sde 2 | quick: true 3 | enable_fid: true 4 | 5 | time_fn: 6 | _target_: utils.scalars.ExpTimer 7 | num_steps: 30 8 | t_start: 0.0001 9 | t_end: 0.1 10 | 11 | cond_fn: ${model.time_fn} 12 | 13 | diff_fn: 14 | _target_: utils.scalars.ExpTimer 15 | num_steps: ${model.time_fn.num_steps} 16 | t_start: 1.0 17 | t_end: ${model.diff_fn.t_start} 18 | 19 | d_in: 20 | - ${data.channel} 21 | - ${data.image_size} 22 | - ${data.image_size} 23 | 24 | drift: 25 | _target_: networks.unet.Unet 26 | dim: 128 27 | dim_mults: 28 | - 1 29 | - 2 30 | - 4 31 | - 8 32 | in_channel: ${data.channel} 33 | 34 | score: 35 | _target_: networks.unet.Unet 36 | dim: 64 37 | dim_mults: 38 | - 1 39 | - 2 40 | - 4 41 | - 8 42 | in_channel: ${data.channel} 43 | 44 | loss_fn: modules.sde_loss.loss_fn_wrapper 45 | trainer_register: modules.sde_img_fns.img_trainer_register 46 | -------------------------------------------------------------------------------- /conf/model/cont_cifar.yaml: -------------------------------------------------------------------------------- 1 | name: sde 2 | quick: true 3 | enable_fid: true 4 | 5 | N_iter: [5000, 15000, 20000, 25000, 30000] 6 | N_values: [5, 10, 20, 30, 50, 75] 7 | 8 | time_fn: 9 | _target_: utils.scalars.STimer 10 | num_steps: 5 11 | t_start: 0.001 12 | t_end: 0.05 13 | 14 | cond_fn: ${model.time_fn} 15 | 16 | diff_fn: 17 | _target_: utils.scalars.ExpTimer 18 | num_steps: ${model.time_fn.num_steps} 19 | t_start: 1.0 20 | t_end: ${model.diff_fn.t_start} 21 | 22 | d_in: 23 | - ${data.channel} 24 | - ${data.image_size} 25 | - ${data.image_size} 26 | 27 | drift: 28 | _target_: networks.official_unet.Model 29 | in_channels: ${data.channel} 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [1, 2, 2, 2] 33 | num_res_blocks: 2 34 | attn_resolutions: [16, ] 35 | dropout: 0.1 36 | resamp_with_conv: true 37 | resolution: 32 38 | 39 | score: 40 | _target_: networks.official_unet.Model 41 | in_channels: ${data.channel} 42 | out_ch: 3 43 | ch: 64 44 | ch_mult: [1, 2, 2, 2] 45 | num_res_blocks: 2 46 | attn_resolutions: [16, ] 47 | dropout: 0.1 48 | resamp_with_conv: true 49 | resolution: 32 50 | 51 | loss_fn: modules.sde_loss.cont_loss_fn_wrapper 52 | trainer_register: modules.sde_img_fns.img_trainer_register 53 | -------------------------------------------------------------------------------- /conf/model/cont_mnist.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mnist 3 | 4 | N_iter: [3000, 5000, 8000] 5 | N_values: [5, 10, 20, 30] 6 | 7 | time_fn: 8 | num_steps: 5 9 | 10 | loss_fn: modules.sde_loss.cont_loss_fn_wrapper -------------------------------------------------------------------------------- /conf/model/cont_mnist_s.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mnist 3 | 4 | N_iter: [6000, 15000, 25000] 5 | N_values: [5, 10, 20, 30] 6 | 7 | time_fn: 8 | _target_: utils.scalars.STimer 9 | num_steps: 5 10 | t_start: 0.001 11 | t_end: 0.05 12 | 13 | cond_fn: ${model.time_fn} 14 | 15 | diff_fn: 16 | _target_: 
utils.scalars.ExpTimer
17 |   num_steps: ${model.time_fn.num_steps}
18 |   t_start: 1.0
19 |   t_end: ${model.diff_fn.t_start}
20 | 
21 | loss_fn: modules.sde_loss.cont_loss_fn_wrapper
--------------------------------------------------------------------------------
/conf/model/mnist.yaml:
--------------------------------------------------------------------------------
1 | name: sde
2 | quick: true
3 | enable_fid: false
4 | 
5 | time_fn:
6 |   _target_: utils.scalars.ExpTimer
7 |   num_steps: 30
8 |   t_start: 0.001
9 |   t_end: 0.1
10 | 
11 | cond_fn: ${model.time_fn}
12 | 
13 | diff_fn:
14 |   _target_: utils.scalars.ExpTimer
15 |   num_steps: ${model.time_fn.num_steps}
16 |   t_start: 1.0
17 |   t_end: ${model.diff_fn.t_start}
18 | 
19 | d_in:
20 |   - 1
21 |   - 32
22 |   - 32
23 | 
24 | drift:
25 |   _target_: networks.unet.Unet
26 |   dim: 64
27 |   dim_mults:
28 |     - 1
29 |     - 2
30 |     - 4
31 |     - 8
32 |   in_channel: ${data.channel}
33 | 
34 | score: ${model.drift}
35 | 
36 | loss_fn: modules.sde_loss.loss_fn_wrapper
37 | trainer_register: modules.sde_img_fns.img_trainer_register
38 | 
--------------------------------------------------------------------------------
/conf/model/ps.yaml:
--------------------------------------------------------------------------------
1 | name: sde
2 | quick: true
3 | 
4 | time_fn:
5 |   _target_: utils.scalars.ExpTimer
6 |   num_steps: 30
7 |   t_start: 0.001
8 |   t_end: 0.05
9 |   exp: 0.9
10 | 
11 | cond_fn: ${model.time_fn}
12 | 
13 | diff_fn:
14 |   _target_: utils.scalars.ExpTimer
15 |   num_steps: ${model.time_fn.num_steps}
16 |   t_start: 0.2
17 |   t_end: ${model.diff_fn.t_start}
18 | 
19 | d_in:
20 |   - 2
21 | score:
22 |   _target_: networks.fouriermlp.FourierMLP
23 |   data_shape: ${model.d_in}
24 |   num_layers: 3
25 |   channels: 128
26 | 
27 | drift: ${model.score}
28 | 
29 | 
30 | loss_fn: modules.sde_loss.loss_fn_wrapper
31 | trainer_register: modules.sde_ps_fns.points_trainer_register
32 | 
--------------------------------------------------------------------------------
/conf/optimizer/basic.yaml:
--------------------------------------------------------------------------------
1 | fn:
2 |   _target_: modules.optimizers.get_tune_optimizer
3 | drift: 2e-4
4 | score: 2e-4
5 | 
--------------------------------------------------------------------------------
/conf/optimizer/cifar.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - basic
3 | 
4 | drift: 2e-4
5 | score: 2e-4
6 | 
--------------------------------------------------------------------------------
/conf/optimizer/fix_drift.yaml:
--------------------------------------------------------------------------------
1 | fn:
2 |   _target_: modules.optimizers.fix_drift_optimizer
3 | score: 2e-4
4 | 
--------------------------------------------------------------------------------
/conf/optimizer/mnist.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - basic
3 | 
4 | drift: 2e-4
5 | score: 2e-4
6 | 
--------------------------------------------------------------------------------
/conf/optimizer/ps.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - basic
3 | 
4 | drift: 1e-2
5 | score: 1e-2
6 | gamma: 0.95
7 | 
--------------------------------------------------------------------------------
/conf/trainer/base.yaml:
--------------------------------------------------------------------------------
1 | is_dist: false
2 | rank: 0
3 | gpu: 0
4 | world_size:
5 | 
6 | resume: false
7 | ckpt:
8 | ckpt_dir:
9 | epochs: ??
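# `clip_grad` and `use_amp` below are consumed by the trainer
# (utils/trainer.py). A minimal sketch of how such a step is typically wired
# with torch.cuda.amp -- hypothetical variable names, not the trainer's actual
# code:
#
#     scaler.scale(loss).backward()
#     scaler.unscale_(optimizer)
#     torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
#     scaler.step(optimizer)
#     scaler.update()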
10 | clip_grad: 1.0
11 | enable_ema: true
12 | use_amp: true
13 | ema:
14 |   _target_: jamtorch.io.EMA
15 |   beta: 0.999
16 |   num_warm: 10
17 |   num_every: 1
18 |   forget_resume: false
19 | 
20 | eval_epoch: 1
21 | eval_iter: -1
22 | 
--------------------------------------------------------------------------------
/conf/trainer/cifar.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - base
3 | 
4 | epochs: 100
5 | eval_epoch: -1
6 | eval_iter: 150
7 | 
--------------------------------------------------------------------------------
/conf/trainer/cifar_ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - base
3 | 
4 | is_dist: true
5 | rank:
6 | gpu: 0
7 | cwd:
8 | world_size:
9 | dist:
10 |   adjust_lr: false
11 |   master_addr: 'localhost'
12 |   master_port: '12355'
13 |   mode: 'nccl'
14 |   syncBN: false
15 |   timeout:
16 | 
17 | epochs: 100
18 | eval_epoch: -1
19 | eval_iter: 150
20 | ema_master:
21 | 
--------------------------------------------------------------------------------
/conf/trainer/mnist.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - base
3 | 
4 | epochs: 100
5 | eval_epoch: 1
6 | eval_iter: -1
7 | 
--------------------------------------------------------------------------------
/conf/trainer/ps.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - base
3 | 
4 | epochs: 30
5 | 
--------------------------------------------------------------------------------
/conf/wandb/base.yaml:
--------------------------------------------------------------------------------
1 | log: ${log}
2 | project: "diff_flow"
3 | name: ${name}
4 | tags:
5 | notes:
6 | entity: "qinsheng"
7 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import datasets.image_dataset as img_dataset
2 | import datasets.points_dataset as ps_dataset
3 | 
4 | from .img_transform import *
5 | 
6 | 
7 | def get_dataset(cfg):
8 |     if cfg.dataset in ps_dataset.skd_func:
9 |         return ps_dataset.get_ps_dataset(cfg)
10 |     if cfg.dataset in ["MNIST", "CIFAR10"]:
11 |         return img_dataset.get_img_dataset(cfg)
12 | 
13 |     raise RuntimeError("Dataset function not found")
14 | 
--------------------------------------------------------------------------------
/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 | 
3 | import os
4 | 
5 | import PIL
6 | import torch
7 | 
8 | from .utils import check_integrity, download_file_from_google_drive
9 | from .vision import VisionDataset
10 | 
11 | 
12 | class CelebA(VisionDataset):
13 |     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
14 | 
15 |     Args:
16 |         root (string): Root directory where images are downloaded to.
17 |         split (string): One of {'train', 'valid', 'test'}.
18 |             The dataset split is selected accordingly.
19 |         target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
20 |             or ``landmarks``. Can also be a list to output a tuple with all specified target types.
21 |             The targets represent:
22 |                 ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
23 |                 ``identity`` (int): label for each person (data points with the same identity are the same person)
24 |                 ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
25 |                 ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
26 |                 righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
27 |             Defaults to ``attr``.
28 |         transform (callable, optional): A function/transform that takes in a PIL image
29 |             and returns a transformed version. E.g., ``transforms.ToTensor``
30 |         target_transform (callable, optional): A function/transform that takes in the
31 |             target and transforms it.
32 |         download (bool, optional): If true, downloads the dataset from the internet and
33 |             puts it in root directory. If dataset is already downloaded, it is not
34 |             downloaded again.
35 |     """
36 | 
37 |     base_folder = "celeba"
38 |     # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
39 |     # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
40 |     # right now.
41 |     file_list = [
42 |         # File ID                          MD5 Hash                            Filename
43 |         (
44 |             "0B7EVK8r0v71pZjFTYXZWM3FlRnM",
45 |             "00d2c5bc6d35e252742224ab0c1e8fcb",
46 |             "img_align_celeba.zip",
47 |         ),
48 |         # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
49 |         # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
50 |         (
51 |             "0B7EVK8r0v71pblRyaVFSWGxPY0U",
52 |             "75e246fa4810816ffd6ee81facbd244c",
53 |             "list_attr_celeba.txt",
54 |         ),
55 |         (
56 |             "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
57 |             "32bd1bd63d3c78cd57e08160ec5ed1e2",
58 |             "identity_CelebA.txt",
59 |         ),
60 |         (
61 |             "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
62 |             "00566efa6fedff7a56946cd1c10f1c16",
63 |             "list_bbox_celeba.txt",
64 |         ),
65 |         (
66 |             "0B7EVK8r0v71pd0FJY3Blby1HUTQ",
67 |             "cc24ecafdb5b50baae59b03474781f8c",
68 |             "list_landmarks_align_celeba.txt",
69 |         ),
70 |         # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
71 |         (
72 |             "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
73 |             "d32c9cbf5e040fd4025c592c306e6668",
74 |             "list_eval_partition.txt",
75 |         ),
76 |     ]
77 | 
78 |     def __init__(
79 |         self,
80 |         root,
81 |         split="train",
82 |         target_type="attr",
83 |         transform=None,
84 |         target_transform=None,
85 |         download=False,
86 |     ):
87 |         import pandas
88 | 
89 |         super(CelebA, self).__init__(root)
90 |         self.split = split
91 |         if isinstance(target_type, list):
92 |             self.target_type = target_type
93 |         else:
94 |             self.target_type = [target_type]
95 |         self.transform = transform
96 |         self.target_transform = target_transform
97 | 
98 |         if download:
99 |             self.download()
100 | 
101 |         if not self._check_integrity():
102 |             raise RuntimeError(
103 |                 "Dataset not found or corrupted."
104 |                 + " You can use download=True to download it"
105 |             )
106 | 
107 |         self.transform = transform
108 |         self.target_transform = target_transform
109 | 
110 |         if split.lower() == "train":
111 |             split = 0
112 |         elif split.lower() == "valid":
113 |             split = 1
114 |         elif split.lower() == "test":
115 |             split = 2
116 |         else:
117 |             raise ValueError(
118 |                 'Wrong split entered!
Please use split="train" ' 119 | 'or split="valid" or split="test"' 120 | ) 121 | 122 | with open( 123 | os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r" 124 | ) as f: 125 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 126 | 127 | with open( 128 | os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r" 129 | ) as f: 130 | self.identity = pandas.read_csv( 131 | f, delim_whitespace=True, header=None, index_col=0 132 | ) 133 | 134 | with open( 135 | os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r" 136 | ) as f: 137 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 138 | 139 | with open( 140 | os.path.join( 141 | self.root, self.base_folder, "list_landmarks_align_celeba.txt" 142 | ), 143 | "r", 144 | ) as f: 145 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 146 | 147 | with open( 148 | os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r" 149 | ) as f: 150 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 151 | 152 | mask = splits[1] == split 153 | self.filename = splits[mask].index.values 154 | self.identity = torch.as_tensor(self.identity[mask].values) 155 | self.bbox = torch.as_tensor(self.bbox[mask].values) 156 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 157 | self.attr = torch.as_tensor(self.attr[mask].values) 158 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 159 | 160 | def _check_integrity(self): 161 | for (_, md5, filename) in self.file_list: 162 | fpath = os.path.join(self.root, self.base_folder, filename) 163 | _, ext = os.path.splitext(filename) 164 | # Allow original archive to be deleted (zip and 7z) 165 | # Only need the extracted images 166 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 167 | return False 168 | 169 | # Should check a hash of the images 170 | return os.path.isdir( 171 | os.path.join(self.root, self.base_folder, "img_align_celeba") 172 | ) 173 | 174 | def download(self): 175 | import zipfile 176 | 177 | if self._check_integrity(): 178 | print("Files already downloaded and verified") 179 | return 180 | 181 | for (file_id, md5, filename) in self.file_list: 182 | download_file_from_google_drive( 183 | file_id, os.path.join(self.root, self.base_folder), filename, md5 184 | ) 185 | 186 | with zipfile.ZipFile( 187 | os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r" 188 | ) as f: 189 | f.extractall(os.path.join(self.root, self.base_folder)) 190 | 191 | def __getitem__(self, index): 192 | X = PIL.Image.open( 193 | os.path.join( 194 | self.root, self.base_folder, "img_align_celeba", self.filename[index] 195 | ) 196 | ) 197 | 198 | target = [] 199 | for t in self.target_type: 200 | if t == "attr": 201 | target.append(self.attr[index, :]) 202 | elif t == "identity": 203 | target.append(self.identity[index, 0]) 204 | elif t == "bbox": 205 | target.append(self.bbox[index, :]) 206 | elif t == "landmarks": 207 | target.append(self.landmarks_align[index, :]) 208 | else: 209 | raise ValueError('Target type "{}" is not recognized.'.format(t)) 210 | target = tuple(target) if len(target) > 1 else target[0] 211 | 212 | if self.transform is not None: 213 | X = self.transform(X) 214 | 215 | if self.target_transform is not None: 216 | target = self.target_transform(target) 217 | 218 | return X, target 219 | 220 | def __len__(self): 221 | return len(self.attr) 222 | 223 | def extra_repr(self): 224 | lines = 
["Target type: {target_type}", "Split: {split}"] 225 | return "\n".join(lines).format(**self.__dict__) 226 | -------------------------------------------------------------------------------- /datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from io import BytesIO 4 | 5 | import lmdb 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class FFHQ(Dataset): 11 | def __init__(self, path, transform, resolution=8): 12 | self.env = lmdb.open( 13 | path, 14 | max_readers=32, 15 | readonly=True, 16 | lock=False, 17 | readahead=False, 18 | meminit=False, 19 | ) 20 | 21 | if not self.env: 22 | raise IOError("Cannot open lmdb dataset", path) 23 | 24 | with self.env.begin(write=False) as txn: 25 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 26 | 27 | self.resolution = resolution 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return self.length 32 | 33 | def __getitem__(self, index): 34 | with self.env.begin(write=False) as txn: 35 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 36 | img_bytes = txn.get(key) 37 | 38 | buffer = BytesIO(img_bytes) 39 | img = Image.open(buffer) 40 | img = self.transform(img) 41 | target = 0 42 | 43 | return img, target 44 | -------------------------------------------------------------------------------- /datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | from jammy.utils.git import git_rootdir 6 | from torch.utils.data import Subset 7 | from torchvision.datasets import CIFAR10, LSUN, MNIST 8 | 9 | from datasets.celeba import CelebA 10 | from datasets.ffhq import FFHQ 11 | 12 | __all__ = ["get_img_dataset"] 13 | 14 | 15 | def mnist_dataset(data_path, img_size=28): 16 | train = MNIST( 17 | data_path, 18 | train=True, 19 | download=True, 20 | transform=transforms.Compose( 21 | [ 22 | transforms.Resize(img_size), 23 | transforms.ToTensor(), 24 | ] 25 | ), 26 | ) 27 | test = MNIST( 28 | data_path, 29 | train=False, 30 | download=True, 31 | transform=transforms.Compose( 32 | [ 33 | transforms.Resize(img_size), 34 | transforms.ToTensor(), 35 | ] 36 | ), 37 | ) 38 | return train, test 39 | 40 | 41 | def init_data_config(config): 42 | if "path" not in config: 43 | config.path = git_rootdir("data") 44 | 45 | 46 | def get_img_dataset(config): # pylint: disable=too-many-branches 47 | init_data_config(config) 48 | if config.dataset == "MNIST": 49 | return mnist_dataset(config.path, config.image_size) 50 | if config.random_flip is False: 51 | tran_transform = test_transform = transforms.Compose( 52 | [transforms.Resize(config.image_size), transforms.ToTensor()] 53 | ) 54 | else: 55 | tran_transform = transforms.Compose( 56 | [ 57 | transforms.Resize(config.image_size), 58 | transforms.RandomHorizontalFlip(p=0.5), 59 | transforms.ToTensor(), 60 | ] 61 | ) 62 | test_transform = transforms.Compose( 63 | [transforms.Resize(config.image_size), transforms.ToTensor()] 64 | ) 65 | 66 | if config.dataset == "CIFAR10": 67 | dataset = CIFAR10( 68 | os.path.join(config.path, "datasets", "cifar10"), 69 | train=True, 70 | download=True, 71 | transform=tran_transform, 72 | ) 73 | test_dataset = CIFAR10( 74 | os.path.join(config.path, "datasets", "cifar10_test"), 75 | train=False, 76 | download=True, 77 | transform=test_transform, 78 | ) 79 | 80 | elif config.dataset == "CELEBA": 81 | 
if config.random_flip: 82 | dataset = CelebA( 83 | root=os.path.join(config.path, "datasets", "celeba"), 84 | split="train", 85 | transform=transforms.Compose( 86 | [ 87 | transforms.CenterCrop(140), 88 | transforms.Resize(config.image_size), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | ] 92 | ), 93 | download=True, 94 | ) 95 | else: 96 | dataset = CelebA( 97 | root=os.path.join(config.path, "datasets", "celeba"), 98 | split="train", 99 | transform=transforms.Compose( 100 | [ 101 | transforms.CenterCrop(140), 102 | transforms.Resize(config.image_size), 103 | transforms.ToTensor(), 104 | ] 105 | ), 106 | download=True, 107 | ) 108 | 109 | test_dataset = CelebA( 110 | root=os.path.join(config.path, "datasets", "celeba_test"), 111 | split="test", 112 | transform=transforms.Compose( 113 | [ 114 | transforms.CenterCrop(140), 115 | transforms.Resize(config.image_size), 116 | transforms.ToTensor(), 117 | ] 118 | ), 119 | download=True, 120 | ) 121 | 122 | elif config.dataset == "LSUN": 123 | train_folder = "{}_train".format(config.category) 124 | val_folder = "{}_val".format(config.category) 125 | if config.random_flip: 126 | dataset = LSUN( 127 | root=os.path.join(config.path, "datasets", "lsun"), 128 | classes=[train_folder], 129 | transform=transforms.Compose( 130 | [ 131 | transforms.Resize(config.image_size), 132 | transforms.CenterCrop(config.image_size), 133 | transforms.RandomHorizontalFlip(p=0.5), 134 | transforms.ToTensor(), 135 | ] 136 | ), 137 | ) 138 | else: 139 | dataset = LSUN( 140 | root=os.path.join(config.path, "datasets", "lsun"), 141 | classes=[train_folder], 142 | transform=transforms.Compose( 143 | [ 144 | transforms.Resize(config.image_size), 145 | transforms.CenterCrop(config.image_size), 146 | transforms.ToTensor(), 147 | ] 148 | ), 149 | ) 150 | 151 | test_dataset = LSUN( 152 | root=os.path.join(config.path, "datasets", "lsun"), 153 | classes=[val_folder], 154 | transform=transforms.Compose( 155 | [ 156 | transforms.Resize(config.image_size), 157 | transforms.CenterCrop(config.image_size), 158 | transforms.ToTensor(), 159 | ] 160 | ), 161 | ) 162 | 163 | elif config.dataset == "FFHQ": 164 | if config.random_flip: 165 | dataset = FFHQ( 166 | path=os.path.join(config.path, "datasets", "FFHQ"), 167 | transform=transforms.Compose( 168 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 169 | ), 170 | resolution=config.image_size, 171 | ) 172 | else: 173 | dataset = FFHQ( 174 | path=os.path.join(config.path, "datasets", "FFHQ"), 175 | transform=transforms.ToTensor(), 176 | resolution=config.image_size, 177 | ) 178 | 179 | num_items = len(dataset) 180 | indices = list(range(num_items)) 181 | random_state = np.random.get_state() 182 | np.random.seed(2019) 183 | np.random.shuffle(indices) 184 | np.random.set_state(random_state) 185 | train_indices, test_indices = ( 186 | indices[: int(num_items * 0.9)], 187 | indices[int(num_items * 0.9) :], 188 | ) 189 | test_dataset = Subset(dataset, test_indices) 190 | dataset = Subset(dataset, train_indices) 191 | 192 | return dataset, test_dataset 193 | -------------------------------------------------------------------------------- /datasets/img_tool.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import numpy as np 3 | from jammy.image import imread 4 | 5 | 6 | def prepare_image( 7 | img_path, crop=None, embed=None, white_cutoff=225, gauss_sigma=5, background=0.0001 8 | ): 9 | """Transforms rgb image array into 2D-density and energy 
10 | 
11 |     Returns
12 |     -------
13 |     density : ndarray(width, height)
14 |         Probability density derived from the image: dark pixels receive
15 |         high density, the map is smoothed with a Gaussian filter of width
16 |         ``gauss_sigma`` and floored by a faint background of weight
17 |         ``background``
18 | 
19 |     """
20 |     img = imread(img_path)
21 | 
22 |     # make one channel
23 |     img = img.mean(axis=2)
24 | 
25 |     # make background white
26 |     img = img.astype(np.float32)
27 |     img[img > white_cutoff] = 255
28 | 
29 |     # normalize
30 |     img /= img.max()
31 | 
32 |     if crop is not None:
33 |         # crop
34 |         img = img[crop[0] : crop[1], crop[2] : crop[3]]
35 | 
36 |     if embed is not None:
37 |         tmp = np.ones((embed[0], embed[1]), dtype=np.float32)
38 |         shift_x = (embed[0] - img.shape[0]) // 2
39 |         shift_y = (embed[1] - img.shape[1]) // 2
40 |         tmp[shift_x : img.shape[0] + shift_x, shift_y : img.shape[1] + shift_y] = img
41 |         img = tmp
42 | 
43 |     # convolve with Gaussian
44 |     from scipy.ndimage import gaussian_filter
45 | 
46 |     # TODO: need to tune the gauss_sigma smoothness of image
47 |     img2 = gaussian_filter(img, sigma=gauss_sigma)
48 | 
49 |     # add background
50 |     background1 = gaussian_filter(img, sigma=10)
51 |     background2 = gaussian_filter(img, sigma=20)
52 |     background3 = gaussian_filter(img, sigma=50)
53 |     density = (1.0 - img2) + background * (background1 + background2 + background3)
54 | 
55 |     return density
56 | 
57 | 
58 | class ImageSampler(object):
59 |     def __init__(self, img_density, mean=[350, 350], scale=[350, 350]):
60 |         """Samples continuous coordinates from image density
61 | 
62 |         Parameters
63 |         ----------
64 |         img_density : ndarray(width, height)
65 |             Image probability density
66 | 
67 |         mean : (int, int)
68 |             center pixel
69 | 
70 |         scale : (int, int)
71 |             number of pixels to scale to 1.0 (in x and y direction)
72 | 
73 |         """
74 |         self.img_density = img_density
75 |         Ix, Iy = np.meshgrid(
76 |             np.arange(img_density.shape[1]), np.arange(img_density.shape[0])
77 |         )
78 |         self.idx = np.vstack([Ix.flatten(), Iy.flatten()]).T
79 | 
80 |         # draw samples from density
81 |         density_normed = img_density.astype(np.float64)
82 |         density_normed /= density_normed.sum()
83 |         self.density_flat = density_normed.flatten()
84 |         self.mean = np.array([mean])
85 |         self.scale = np.array([scale])
86 | 
87 |     def sample(self, nsample):
88 |         # draw random index
89 |         i = np.random.choice(self.idx.shape[0], size=nsample, p=self.density_flat)
90 |         ixy = self.idx[i, :]
91 | 
92 |         # simple dequantization, uniformly sample in the grid
93 |         xy = ixy + np.random.rand(nsample, 2) - 0.5
94 | 
95 |         # normalize shape
96 |         xy = (xy - self.mean) / self.scale
97 | 
98 |         return xy
--------------------------------------------------------------------------------
/datasets/img_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | __all__ = ["logit_transform", "data_transform", "inverse_data_transform"]
4 | 
5 | 
6 | def logit_transform(image, lam=1e-6):
7 |     image = lam + (1 - 2 * lam) * image
8 |     return torch.log(image) - torch.log1p(-image)
9 | 
10 | 
11 | def data_transform(config, x):
12 |     if config.uniform_dequantization:
13 |         x = x / 256.0 * 255.0 + torch.rand_like(x) / 256.0
14 |     if config.gaussian_dequantization:
15 |         x = x + torch.randn_like(x) * 0.01
16 | 
17 |     if config.rescaled:
18 |         x = 2 * x - 1.0
19 |     elif config.logit_transform:
20 |         x = logit_transform(x)
21 | 
22 |     if config.image_mean is not None and config.image_std is not None:
23 |         return (
24 |             x - torch.FloatTensor(config.image_mean).to(x.device)[:, None, None]
25 |         ) / torch.FloatTensor(config.image_std).to(x.device)[:, None, None]
26 |     return x
27 | 
28 | 
29
| def inverse_data_transform(config, x): 30 | if config.image_mean is not None and config.image_std is not None: 31 | x = ( 32 | x * torch.FloatTensor(config.image_std).to(x.device)[:, None, None] 33 | + torch.FloatTensor(config.image_mean).to(x.device)[:, None, None] 34 | ) 35 | 36 | if config.logit_transform: 37 | x = torch.sigmoid(x) 38 | elif config.rescaled: 39 | x = (x + 1.0) / 2.0 40 | 41 | return torch.clamp(x, 0.0, 1.0) 42 | -------------------------------------------------------------------------------- /datasets/points_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import sklearn.datasets as skd 5 | from einops import rearrange 6 | from jammy import jam_instantiate 7 | from jamtorch.utils.meta import as_numpy 8 | from torch.utils.data import Dataset 9 | 10 | from .img_tool import ImageSampler, prepare_image 11 | 12 | # pylint: disable=global-statement, unused-argument, too-many-arguments 13 | DIM_LINSPACE = None 14 | G_MEAN = 0.0 15 | G_STD = 1.0 16 | G_SET_STD = 1.0 17 | 18 | 19 | def swissroll_generate_sample(N, noise=0.25): 20 | data = skd.make_swiss_roll(n_samples=N, noise=noise)[0] 21 | data = data.astype("float32")[:, [0, 2]] 22 | return data 23 | 24 | 25 | def moon_generate_sample(N, noise=0.25): 26 | data = skd.make_moons(n_samples=N, noise=noise)[0] 27 | data = data.astype("float32") 28 | return data 29 | 30 | 31 | def checkerboard_generate_sample(N, noise=0.25): 32 | x1 = np.random.rand(N) * 4 - 2 33 | x2_ = np.random.rand(N) - np.random.randint(0, 2, N) * 2 34 | x2 = x2_ + (np.floor(x1) % 2) 35 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 36 | 37 | 38 | def line_generate_sample(N, noise=0.25): 39 | assert noise <= 1.0 40 | cov = np.array([[1.0, 1 - noise], [1 - noise, 1.0]]) 41 | mean = np.array([0.0, 0.0]) 42 | return np.random.multivariate_normal(mean, cov, N) 43 | 44 | 45 | def circle_generate_sample(N, noise=0.25): 46 | angle = np.random.uniform(high=2 * np.pi, size=N) 47 | random_noise = np.random.normal(scale=np.sqrt(0.2), size=(N, 2)) 48 | pos = np.concatenate([np.cos(angle), np.sin(angle)]) 49 | pos = rearrange(pos, "(b c) -> c b", b=2) 50 | return pos + noise * random_noise 51 | 52 | 53 | def olympic_generate_sample(N, noise=0.25): 54 | w = 3.5 55 | h = 1.5 56 | centers = np.array([[-w, h], [0.0, h], [w, h], [-w * 0.6, -h], [w * 0.6, -h]]) 57 | pos = [ 58 | circle_generate_sample(N // 5, noise) + centers[i : i + 1] / 2 for i in range(5) 59 | ] 60 | return np.concatenate(pos) 61 | 62 | 63 | def four_generate_sample(N, noise=0.25): 64 | w = 3.5 65 | h = 1.5 66 | centers = np.array([[0.0, h], [w, h], [-w * 0.6, -h], [w * 0.6, -h]]) 67 | pos = [ 68 | circle_generate_sample(N // 4, noise) + centers[i : i + 1] / 2 for i in range(4) 69 | ] 70 | return np.concatenate(pos) 71 | 72 | 73 | def dog_sample(N, noise=0.25): 74 | density = prepare_image( 75 | Path(__file__).parent.joinpath("dog.jpg"), 76 | crop=(10, 710, 240, 940), 77 | white_cutoff=225, 78 | gauss_sigma=10 * noise, 79 | ) 80 | sampler = ImageSampler(density[::-1].copy(), mean=[350, 350], scale=[100, 100]) 81 | return sampler.sample(N) 82 | 83 | 84 | def tree_sample(N, noise=0.0): 85 | density = prepare_image( 86 | Path(__file__).parent.joinpath("tree.png"), 87 | # crop=(10, 710, 240, 940), 88 | white_cutoff=225, 89 | gauss_sigma=10 * noise, 90 | ) 91 | sampler = ImageSampler(density[::-1].copy(), mean=[275, 225], scale=[275, 225]) 92 | return sampler.sample(N) 93 | 94 | 95 | def 
sier_sample(N, noise=0.0): 96 | density = prepare_image( 97 | Path(__file__).parent.joinpath("sierpinski.jpg"), 98 | white_cutoff=225, 99 | gauss_sigma=10 * noise, 100 | ) 101 | sampler = ImageSampler(density[::-1].copy(), mean=[365, 365], scale=[365, 365]) 102 | return sampler.sample(N) 103 | 104 | 105 | def sier_hard_sample(N, noise=0.0): 106 | density = prepare_image( 107 | Path(__file__).parent.joinpath("sierpinski_hard.jpg"), 108 | white_cutoff=225, 109 | gauss_sigma=10 * noise, 110 | ) 111 | sampler = ImageSampler(density[::-1].copy(), mean=[365, 365], scale=[365, 365]) 112 | return sampler.sample(N) 113 | 114 | 115 | def word_sample(N, noise=0.25): 116 | density = prepare_image( 117 | Path(__file__).parent.joinpath("word.jpg"), 118 | white_cutoff=225, 119 | gauss_sigma=10 * noise, 120 | ) 121 | sampler = ImageSampler(density[::-1].copy(), mean=[350, 350], scale=[100, 100]) 122 | return sampler.sample(N) 123 | 124 | 125 | def smile_sample(N, noise=0.25): 126 | density = prepare_image( 127 | Path(__file__).parent.joinpath("smile.jpg"), 128 | embed=(1000, 1000), 129 | white_cutoff=225, 130 | gauss_sigma=10 * noise, 131 | ) 132 | sampler = ImageSampler(density[::-1].copy(), mean=[500, 225], scale=[200, 200]) 133 | return sampler.sample(N) 134 | 135 | 136 | def spirals_sample(N, noise=0.25): 137 | n = np.sqrt(np.random.rand(N // 2, 1)) * 540 * (2 * np.pi) / 360 138 | d1x = -np.cos(n) * n + np.random.rand(N // 2, 1) * 0.5 139 | d1y = np.sin(n) * n + np.random.rand(N // 2, 1) * 0.5 140 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 141 | x += np.random.randn(*x.shape) * 0.1 142 | return x 143 | 144 | 145 | def gaussian_sample(N, noise=0.25): 146 | scale = 4.0 147 | centers = [ 148 | (1, 0), 149 | (-1, 0), 150 | (0, 1), 151 | (0, -1), 152 | (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 153 | (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 154 | (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 155 | (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 156 | ] 157 | centers = [(scale * x, scale * y) for x, y in centers] 158 | 159 | dataset = [] 160 | for _ in range(N): 161 | point = np.random.randn(2) * 0.5 162 | idx = np.random.randint(8) 163 | center = centers[idx] 164 | point[0] += center[0] 165 | point[1] += center[1] 166 | dataset.append(point) 167 | dataset = np.array(dataset, dtype="float32") 168 | dataset /= 1.414 169 | return dataset 170 | 171 | 172 | skd_func = { 173 | "swissroll": swissroll_generate_sample, 174 | "checkerboard": checkerboard_generate_sample, 175 | "line": line_generate_sample, 176 | "dog": dog_sample, 177 | "word": word_sample, 178 | "smile": smile_sample, 179 | "circle": circle_generate_sample, 180 | "olympic": olympic_generate_sample, 181 | "moon": moon_generate_sample, 182 | "four": four_generate_sample, 183 | "8gaussian": gaussian_sample, 184 | "2spirals": spirals_sample, 185 | "tree": tree_sample, 186 | "sier": sier_sample, 187 | "sierpp": sier_hard_sample, 188 | } 189 | 190 | 191 | class PointsDataSet(Dataset): 192 | def __init__( 193 | self, data_name, num_sample, noise, dim_range=1.0, iscenter=True, shuffle=True 194 | ): 195 | self.name = data_name 196 | self.num_sample = num_sample 197 | self.noise = noise 198 | self.iscenter = iscenter 199 | global DIM_LINSPACE 200 | DIM_LINSPACE = np.linspace(-dim_range, dim_range, 200) 201 | self.dim_range = dim_range 202 | self.shuffle = shuffle 203 | self.data = self.generate_sample() 204 | 205 | def generate_sample(self): 206 | global skd_func 207 | data = skd_func[self.name](self.num_sample, self.noise) 208 | _max = 
np.max(np.abs(data)) 209 | data = (self.dim_range * 0.85 / _max) * data 210 | if self.shuffle: 211 | np.random.shuffle(data) 212 | return data 213 | 214 | def normalize(self, set_std): 215 | global G_MEAN, G_STD, G_SET_STD 216 | if self.iscenter: 217 | G_MEAN = np.mean(self.data, axis=0, keepdims=True) 218 | G_STD = np.std(self.data, axis=0, keepdims=True) 219 | G_SET_STD = set_std 220 | 221 | self.data = (self.data - G_MEAN) / G_STD * set_std 222 | 223 | def __len__(self): 224 | return self.num_sample 225 | 226 | def __getitem__(self, idx): 227 | return self.data[idx] 228 | 229 | 230 | def restore(data): 231 | global G_MEAN, G_STD, G_SET_STD 232 | data = as_numpy(data) 233 | sample = data / G_SET_STD * G_STD + G_MEAN 234 | return sample 235 | 236 | 237 | def get_ps_dataset(config): 238 | trainset = jam_instantiate(config.train_data) 239 | valset = jam_instantiate(config.val_data) 240 | trainset.normalize(config.std) 241 | valset.normalize(config.std) 242 | return trainset, valset 243 | -------------------------------------------------------------------------------- /datasets/sierpinski.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/datasets/sierpinski.jpg -------------------------------------------------------------------------------- /datasets/sierpinski_hard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/datasets/sierpinski_hard.jpg -------------------------------------------------------------------------------- /datasets/tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/datasets/tree.png -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | import errno 4 | import hashlib 5 | import os 6 | import os.path 7 | 8 | from torch.utils.model_zoo import tqdm 9 | 10 | 11 | def gen_bar_updater(): 12 | pbar = tqdm(total=None) 13 | 14 | def bar_update(count, block_size, total_size): 15 | if pbar.total is None and total_size: 16 | pbar.total = total_size 17 | progress_bytes = count * block_size 18 | pbar.update(progress_bytes - pbar.n) 19 | 20 | return bar_update 21 | 22 | 23 | def check_integrity(fpath, md5=None): 24 | if md5 is None: 25 | return True 26 | if not os.path.isfile(fpath): 27 | return False 28 | md5o = hashlib.md5() 29 | with open(fpath, "rb") as f: 30 | # read in 1MB chunks 31 | for chunk in iter(lambda: f.read(1024 * 1024), b""): 32 | md5o.update(chunk) 33 | md5c = md5o.hexdigest() 34 | if md5c != md5: 35 | return False 36 | return True 37 | 38 | 39 | def makedir_exist_ok(dirpath): 40 | """ 41 | Python2 support for os.makedirs(.., exist_ok=True) 42 | """ 43 | try: 44 | os.makedirs(dirpath) 45 | except OSError as e: 46 | if e.errno == errno.EEXIST: 47 | pass 48 | else: 49 | raise 50 | 51 | 52 | def download_url(url, root, filename=None, md5=None): 53 | """Download a file from a url and place it in root. 54 | 55 | Args: 56 | url (str): URL to download file from 57 | root (str): Directory to place downloaded file in 58 | filename (str, optional): Name to save the file under. 
If None, use the basename of the URL
59 |         md5 (str, optional): MD5 checksum of the download. If None, do not check
60 |     """
61 |     from six.moves import urllib
62 | 
63 |     root = os.path.expanduser(root)
64 |     if not filename:
65 |         filename = os.path.basename(url)
66 |     fpath = os.path.join(root, filename)
67 | 
68 |     makedir_exist_ok(root)
69 | 
70 |     # downloads file
71 |     if os.path.isfile(fpath) and check_integrity(fpath, md5):
72 |         print("Using downloaded and verified file: " + fpath)
73 |     else:
74 |         try:
75 |             print("Downloading " + url + " to " + fpath)
76 |             urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
77 |         except OSError:
78 |             if url[:5] == "https":
79 |                 url = url.replace("https:", "http:")
80 |                 print(
81 |                     "Failed download. Trying https -> http instead."
82 |                     " Downloading " + url + " to " + fpath
83 |                 )
84 |                 urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
85 | 
86 | 
87 | def list_dir(root, prefix=False):
88 |     """List all directories at a given root
89 | 
90 |     Args:
91 |         root (str): Path to directory whose folders need to be listed
92 |         prefix (bool, optional): If true, prepends the path to each result, otherwise
93 |             only returns the name of the directories found
94 |     """
95 |     root = os.path.expanduser(root)
96 |     directories = list(
97 |         filter(lambda p: os.path.isdir(os.path.join(root, p)), os.listdir(root))
98 |     )
99 | 
100 |     if prefix is True:
101 |         directories = [os.path.join(root, d) for d in directories]
102 | 
103 |     return directories
104 | 
105 | 
106 | def list_files(root, suffix, prefix=False):
107 |     """List all files ending with a suffix at a given root
108 | 
109 |     Args:
110 |         root (str): Path to directory whose folders need to be listed
111 |         suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
112 |             It uses the Python "str.endswith" method and is passed directly
113 |         prefix (bool, optional): If true, prepends the path to each result, otherwise
114 |             only returns the name of the files found
115 |     """
116 |     root = os.path.expanduser(root)
117 |     files = list(
118 |         filter(
119 |             lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
120 |             os.listdir(root),
121 |         )
122 |     )
123 | 
124 |     if prefix is True:
125 |         files = [os.path.join(root, d) for d in files]
126 | 
127 |     return files
128 | 
129 | 
130 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
131 |     """Download a file from Google Drive and place it in root.
132 | 
133 |     Args:
134 |         file_id (str): id of file to be downloaded
135 |         root (str): Directory to place downloaded file in
136 |         filename (str, optional): Name to save the file under. If None, use the id of the file.
137 |         md5 (str, optional): MD5 checksum of the download.
If None, do not check 138 | """ 139 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 140 | import requests 141 | 142 | url = "https://docs.google.com/uc?export=download" 143 | 144 | root = os.path.expanduser(root) 145 | if not filename: 146 | filename = file_id 147 | fpath = os.path.join(root, filename) 148 | 149 | makedir_exist_ok(root) 150 | 151 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 152 | print("Using downloaded and verified file: " + fpath) 153 | else: 154 | session = requests.Session() 155 | 156 | response = session.get(url, params={"id": file_id}, stream=True) 157 | token = _get_confirm_token(response) 158 | 159 | if token: 160 | params = {"id": file_id, "confirm": token} 161 | response = session.get(url, params=params, stream=True) 162 | 163 | _save_response_content(response, fpath) 164 | 165 | 166 | def _get_confirm_token(response): 167 | for key, value in response.cookies.items(): 168 | if key.startswith("download_warning"): 169 | return value 170 | 171 | return None 172 | 173 | 174 | def _save_response_content(response, destination, chunk_size=32768): 175 | with open(destination, "wb") as f: 176 | pbar = tqdm(total=None) 177 | progress = 0 178 | for chunk in response.iter_content(chunk_size): 179 | if chunk: # filter out keep-alive new chunks 180 | f.write(chunk) 181 | progress += len(chunk) 182 | pbar.update(progress - pbar.n) 183 | pbar.close() 184 | -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | import os 4 | 5 | import torch 6 | import torch.utils.data as data 7 | 8 | 9 | class VisionDataset(data.Dataset): 10 | _repr_indent = 4 11 | 12 | def __init__(self, root): 13 | if isinstance(root, torch._six.string_classes): 14 | root = os.path.expanduser(root) 15 | self.root = root 16 | 17 | def __getitem__(self, index): 18 | raise NotImplementedError 19 | 20 | def __len__(self): 21 | raise NotImplementedError 22 | 23 | def __repr__(self): 24 | head = "Dataset " + self.__class__.__name__ 25 | body = ["Number of datapoints: {}".format(self.__len__())] 26 | if self.root is not None: 27 | body.append("Root location: {}".format(self.root)) 28 | body += self.extra_repr().splitlines() 29 | if hasattr(self, "transform") and self.transform is not None: 30 | body += self._format_transform_repr(self.transform, "Transforms: ") 31 | if hasattr(self, "target_transform") and self.target_transform is not None: 32 | body += self._format_transform_repr( 33 | self.target_transform, "Target transforms: " 34 | ) 35 | lines = [head] + [" " * self._repr_indent + line for line in body] 36 | return "\n".join(lines) 37 | 38 | def _format_transform_repr(self, transform, head): 39 | lines = transform.__repr__().splitlines() 40 | return ["{}{}".format(head, lines[0])] + [ 41 | "{}{}".format(" " * len(head), line) for line in lines[1:] 42 | ] 43 | 44 | def extra_repr(self): 45 | return "" 46 | -------------------------------------------------------------------------------- /jam_.yaml: -------------------------------------------------------------------------------- 1 | #conda: 2 | #env: base 3 | 4 | system: 5 | envs: 6 | JAM_RANDOM_SEED: 3 7 | JAM_DEBUG: true 8 | JAM_PROJ_PATH: TODO 9 | WANDB_API_KEY: TODO 10 | HYDRA_FULL_ERROR: 1 11 | -------------------------------------------------------------------------------- /main.py: 
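main.py below is launched through `@hydra.main(config_path="conf")`, so a run is assembled from the config groups above: picking `data=cifar` on the command line also selects the matching `model`/`trainer`/`optimizer` groups through the `${data}` interpolations in conf/config.yaml. A minimal sketch of inspecting such a composed config offline with Hydra's compose API (the exact import path differs across Hydra versions; older releases expose these helpers under `hydra.experimental`):

    from hydra import compose, initialize
    from omegaconf import OmegaConf

    with initialize(config_path="conf"):
        cfg = compose(config_name="config", overrides=["data=cifar", "log=false"])

    print(OmegaConf.to_yaml(cfg.model))  # the model group was picked via ${data}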
-------------------------------------------------------------------------------- 1 | import hydra 2 | import jamtorch.prototype as jampt 3 | import torch 4 | import torch.multiprocessing as mp 5 | from jammy import hydpath, jam_instantiate, link_hyd_run, load_class 6 | from jammy.logging import Wandb, get_logger 7 | from jamtorch.data import get_subset 8 | from jamtorch.ddp import ddp_utils 9 | from jamtorch.trainer import check_loss_error, trainer_save_cfg 10 | from omegaconf import OmegaConf 11 | 12 | from datasets import get_dataset 13 | from modules import import_fns 14 | 15 | 16 | def run(cfg): 17 | if ddp_utils.is_master(): 18 | Wandb.launch(cfg, cfg.log, True) 19 | get_logger( 20 | "jam_.log", 21 | clear=True, 22 | format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", 23 | level="DEBUG", 24 | ) 25 | jampt.set_gpu_mode(cfg.cuda, cfg.trainer.gpu) 26 | 27 | init_model, loss_fn_wrapper, trainer_register = import_fns(cfg.model) 28 | 29 | trainer_str = ( 30 | "utils.ddp_trainer.Trainer" if cfg.is_dist else "utils.trainer.Trainer" 31 | ) 32 | trainer = load_class(trainer_str)(cfg.trainer, loss_fn_wrapper(cfg)) 33 | model = init_model(cfg.model) 34 | optimizer = jam_instantiate(cfg.optimizer.fn, cfg.optimizer, model) 35 | trainer.set_model_optim(model, optimizer) 36 | trainer_register(trainer, cfg) 37 | check_loss_error(trainer) 38 | 39 | # data 40 | trainset, valset = get_dataset(cfg.data) 41 | trainset = get_subset(trainset, cfg.data.train_size) 42 | valset = get_subset(valset, cfg.data.val_size) 43 | train_loader, train_sampler, val_loader, val_sampler = jam_instantiate( 44 | cfg.data.dataloader, 45 | trainset, 46 | valset, 47 | rank=cfg.trainer.rank, 48 | world_size=cfg.trainer.world_size, 49 | ) 50 | if cfg.is_dist: 51 | trainer.set_sampler(train_sampler, val_sampler) 52 | trainer.set_dataloader(train_loader, val_loader) 53 | 54 | if ddp_utils.is_master(): 55 | trainer_save_cfg(trainer, cfg) 56 | trainer.set_monitor(cfg.log) 57 | trainer.save_ckpt() 58 | 59 | trainer.train() 60 | 61 | Wandb.finish() 62 | 63 | 64 | @ddp_utils.ddp_runner 65 | def mock_run(cfg): 66 | run(cfg) 67 | 68 | 69 | @hydra.main(config_path="conf", config_name="config.yaml") 70 | def main(cfg): 71 | OmegaConf.set_struct(cfg, False) 72 | link_hyd_run() 73 | cfg.data.path = hydpath("data") # address hyd relative path 74 | if cfg.is_dist: 75 | world_size = torch.cuda.device_count() 76 | ddp_utils.prepare_cfg(cfg) 77 | mp.spawn(mock_run, args=(world_size, None, cfg), nprocs=world_size, join=True) 78 | else: 79 | run(cfg) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() # pylint: disable=no-value-for-parameter 84 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | import jammy.utils.imp as imp 2 | 3 | 4 | def import_fns(cfg): 5 | model_fn = imp.load_class(f"modules.{cfg.name}.init_model") 6 | loss_fn = imp.load_class(cfg.loss_fn) 7 | register_fn = imp.load_class(cfg.trainer_register) 8 | return model_fn, loss_fn, register_fn 9 | -------------------------------------------------------------------------------- /modules/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jammy import jam_instantiate 3 | 4 | 5 | def get_optimizer(opt_cfg, model): 6 | return jam_instantiate(opt_cfg, model.parameters()) 7 | 8 | 9 | def get_tune_optimizer(opt_cfg, model): 10 | return torch.optim.Adam( 11 | [ 12 | {"params": 
model.drift.parameters(), "lr": opt_cfg.drift}, 13 | {"params": model.score.parameters(), "lr": opt_cfg.score}, 14 | ], 15 | lr=2e-4, 16 | ) 17 | 18 | 19 | def fix_drift_optimizer(opt_cfg, model): 20 | for param in model.drift.parameters(): 21 | param.requires_grad = False 22 | return torch.optim.Adam( 23 | [{"params": model.score.parameters(), "lr": opt_cfg.score}], lr=2e-4 24 | ) 25 | -------------------------------------------------------------------------------- /modules/sde.py: -------------------------------------------------------------------------------- 1 | from jammy import hyd_instantiate 2 | 3 | from networks.diff_flow import DiffFlow, QuickDiffFlow 4 | 5 | __all__ = ["init_model"] 6 | 7 | 8 | def _init_model(cfg): 9 | if "_target_" in cfg: 10 | return hyd_instantiate(cfg) 11 | raise RuntimeError 12 | 13 | 14 | def init_model(cfg): 15 | timestamps = hyd_instantiate(cfg.time_fn)() 16 | diffusion = hyd_instantiate(cfg.diff_fn)() 17 | condition = hyd_instantiate(cfg.cond_fn)() 18 | 19 | drift = _init_model(cfg.drift) 20 | score = _init_model(cfg.score) 21 | 22 | module = QuickDiffFlow if cfg.quick else DiffFlow 23 | return module(cfg.d_in, timestamps, diffusion, condition, drift, score) 24 | -------------------------------------------------------------------------------- /modules/sde_img_fns.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch_fidelity 4 | from jammy import io 5 | from jamtorch import ddp, get_logger, no_grad_func 6 | from jamtorch.data import get_batch, num_to_groups 7 | from jamtorch.ddp import ddp_utils 8 | 9 | from datasets import data_transform, inverse_data_transform 10 | from utils.diagnosis import ( 11 | backward_whole_process, 12 | backward_z2x, 13 | fb_whole_process, 14 | recon_x, 15 | ) 16 | from utils.scalars import scalar_helper 17 | from viz.img import ( 18 | check_unnormal_imgs, 19 | save_seperate_imgs, 20 | tensor2imgnd, 21 | viz_img_process, 22 | wandb_write_ndimg, 23 | ) 24 | 25 | logger = get_logger() 26 | 27 | 28 | def image_fidelity(img_dir): 29 | metric = torch_fidelity.calculate_metrics( 30 | input1=img_dir, 31 | input2="cifar10-train", 32 | cuda=True, 33 | isc=False, 34 | fid=True, 35 | kid=False, 36 | verbose=False, 37 | ) 38 | return metric["frechet_inception_distance"] 39 | 40 | 41 | def epoch_start_wrapper(cfg): 42 | def _epoch_start(trainer): 43 | from viz.lines import check_dflow_coef 44 | 45 | model = trainer.mmodel 46 | dataset = trainer.train_loader.dataset 47 | num_grid = cfg.data.eval_n_samples 48 | sample = get_batch(dataset, num_grid * num_grid) 49 | 50 | trainer.test_sample = data_transform(cfg.data, sample) 51 | trainer.test_noise = model.sample_noise(num_grid * num_grid) 52 | 53 | check_unnormal_imgs(cfg, trainer.test_sample, num_grid, 0, "GT-sample") 54 | check_unnormal_imgs(cfg, trainer.test_noise, num_grid, 0, "GT-noise") 55 | 56 | check_dflow_coef(model) 57 | 58 | return _epoch_start 59 | 60 | 61 | def viz_gt_process(model, x, gif_suffix, eval_n, reverse_transform_fn): 62 | forward_kv, backward_kv = fb_whole_process( 63 | model, x, model.timestamps, model.diffusion, model.condition, is_gt=True 64 | ) 65 | viz_img_process( 66 | forward_kv, 67 | f"f_{gif_suffix}", 68 | eval_n, 69 | ["data", "grad", "noise"], 70 | reverse_transform_fn, 71 | ) 72 | viz_img_process( 73 | backward_kv, 74 | f"b_{gif_suffix}", 75 | eval_n, 76 | ["data", "drift", "diff"], 77 | reverse_transform_fn, 78 | ) 79 | return forward_kv, backward_kv 80 | 81 | 82 | def 
viz_sample_process(model, z, gif_suffix, eval_n, reverse_transform_fn): 83 | backward_kv = backward_whole_process( 84 | model, z, model.timestamps, model.diffusion, model.condition 85 | ) 86 | viz_img_process( 87 | backward_kv, 88 | f"s_{gif_suffix}", 89 | eval_n, 90 | ["data", "drift", "diff"], 91 | reverse_transform_fn, 92 | ) 93 | b_imgnd = tensor2imgnd( 94 | reverse_transform_fn(backward_kv["data"][-1]), eval_n, eval_n 95 | ) 96 | return backward_kv, b_imgnd 97 | 98 | 99 | def epoch_after_wrapper(cfg): # pylint: disable=too-many-statements 100 | eval_n = cfg.data.eval_n_samples 101 | reverse_transform_fn = functools.partial(inverse_data_transform, cfg.data) 102 | 103 | ## prepare fid check 104 | n_gpu = ddp_utils.get_world_size() 105 | img_per_gpu = cfg.data.fid.num_samples // n_gpu 106 | sample_fid_path = "sample_fid_imgs" 107 | io.makedirs(sample_fid_path) 108 | gtimg_per_gpu = cfg.data.val_size // n_gpu 109 | sample_gt_path = "fb_fid_imgs" 110 | io.makedirs(sample_gt_path) 111 | check_fid_fn = image_fidelity 112 | 113 | @ddp.master_only 114 | @no_grad_func 115 | def check_gtsample_traj(trainer): 116 | n_epoch, n_iter = trainer.epoch_cnt, trainer.iter_cnt 117 | 118 | model = trainer.ema.model 119 | gt_sample = trainer.test_sample.to(trainer.device) 120 | forward_kv, backward_kv = viz_gt_process( 121 | model, gt_sample, f"{n_epoch}_{n_iter}", eval_n, reverse_transform_fn 122 | ) 123 | f_imgnd = tensor2imgnd( 124 | reverse_transform_fn(forward_kv["data"][-1]), eval_n, eval_n 125 | ) 126 | b_imgnd = tensor2imgnd( 127 | reverse_transform_fn(backward_kv["data"][-1]), eval_n, eval_n 128 | ) 129 | wandb_write_ndimg(f_imgnd, n_iter, "t_f") 130 | wandb_write_ndimg(b_imgnd, n_iter, "t_b") 131 | 132 | @ddp.master_only 133 | def check_sampling(trainer): 134 | logger.info(f"eval {trainer.iter_cnt}") 135 | n_epoch, n_iter = trainer.epoch_cnt, trainer.iter_cnt 136 | 137 | model = trainer.ema.model 138 | z = trainer.test_noise.to(trainer.device) 139 | _, sampling_img = viz_sample_process( 140 | model, z, f"{n_epoch}_{n_iter}", eval_n, reverse_transform_fn 141 | ) 142 | wandb_write_ndimg(sampling_img, n_iter, "sample") 143 | 144 | @no_grad_func 145 | def runtime_fid_sample(trainer): 146 | model = trainer.ema.model 147 | batch_size = cfg.data.fid.batch_size 148 | cnt = trainer.rank * img_per_gpu 149 | for _batch_size in num_to_groups(img_per_gpu, batch_size): 150 | s_n = model.sample_noise(_batch_size) 151 | sample = backward_z2x(model, s_n, *scalar_helper(model)) 152 | sample = inverse_data_transform(cfg.data, sample) 153 | save_seperate_imgs(sample.cpu(), sample_fid_path, cnt) 154 | cnt += _batch_size 155 | ddp_utils.barrier() 156 | 157 | @no_grad_func 158 | def runtime_fid_gt(trainer): 159 | model = trainer.ema.model 160 | cnt = trainer.rank * gtimg_per_gpu 161 | for data in trainer.val_loader: 162 | data = data[0].float().to(trainer.device) 163 | data_trans = data_transform(cfg.data, data) 164 | x = recon_x( 165 | model, data_trans, model.timestamps, model.diffusion, model.condition 166 | ) 167 | data_rec = inverse_data_transform(cfg.data, x) 168 | save_seperate_imgs(data_rec, sample_gt_path, cnt) 169 | cnt += len(data) 170 | ddp_utils.barrier() 171 | 172 | @ddp.master_only 173 | def runtime_check_fid(trainer): 174 | sample_fid = check_fid_fn(sample_fid_path) 175 | gt_fid = check_fid_fn(sample_gt_path) 176 | print(sample_fid, gt_fid) 177 | trainer.cmdviz.update("eval", {"sample_fid": sample_fid, "gt_fid": gt_fid}) 178 | 179 | @no_grad_func 180 | def epoch_after(trainer): 181 | ddp_utils.barrier() 
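# After the barrier syncs all ranks: the master rank renders the sampling and
# ground-truth trajectory gifs below; if FID is enabled, every rank writes its
# shard of generated samples and val-set reconstructions to disk, then the
# master scores both folders against cifar10-train via torch_fidelity.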
182 | check_sampling(trainer) 183 | check_gtsample_traj(trainer) 184 | if cfg.model.enable_fid: 185 | runtime_fid_sample(trainer) 186 | runtime_fid_gt(trainer) 187 | runtime_check_fid(trainer) 188 | 189 | return epoch_after 190 | 191 | 192 | def img_trainer_register(trainer, cfg): 193 | trainer.register_event("epoch:start", epoch_start_wrapper(cfg)) 194 | trainer.register_event("val:start", epoch_after_wrapper(cfg)) 195 | -------------------------------------------------------------------------------- /modules/sde_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from jammy import hyd_instantiate 3 | 4 | from datasets import data_transform 5 | from utils.scalars import instantiate_scaler 6 | from viz.lines import check_dflow_coef 7 | 8 | # pylint: disable=unused-argument, unused-variable 9 | 10 | 11 | def img_preprocess(cfg, feed_dict, device): 12 | feed_dict = feed_dict[0].float().to(device) 13 | return data_transform(cfg.data, feed_dict) 14 | 15 | 16 | def point_preprocess(cfg, feed_dict, device): 17 | return feed_dict.float().to(device) 18 | 19 | 20 | def loss_fn_wrapper(cfg): 21 | preprocess_fn = hyd_instantiate(cfg.data.preprocess_fn, cfg) 22 | 23 | def loss_fn(trainer, feed_dict, is_train): 24 | model = trainer.mmodel 25 | feed_dict = preprocess_fn(feed_dict, trainer.device) 26 | z, logabsdet = model(feed_dict) 27 | 28 | norm_loss = -model.noise_log_prob(z.flatten(start_dim=1)).mean() 29 | det_loss = -logabsdet.mean() 30 | return ( 31 | norm_loss + det_loss, 32 | {}, 33 | { 34 | "norm": norm_loss, 35 | "det_loss": det_loss, 36 | "dim/norm": norm_loss / np.prod(z.shape[1:]), 37 | }, 38 | ) 39 | 40 | return loss_fn 41 | 42 | 43 | def cont_loss_fn_wrapper(cfg): 44 | preprocess_fn = hyd_instantiate(cfg.data.preprocess_fn, cfg) 45 | n_idx = 0 46 | # FIXME: FIX IN CONFIG 47 | n_iters = np.array(cfg.model.N_iter) 48 | n_values = np.array(cfg.model.N_values) 49 | timer, differ, conder = instantiate_scaler(cfg) 50 | from jamtorch import get_logger 51 | 52 | logger = get_logger() 53 | 54 | def update_scalar(iter_cnt): 55 | nonlocal n_idx, timer, differ, conder 56 | cur_idx = np.sum(iter_cnt > n_iters) 57 | if cur_idx > n_idx: 58 | new_num = int(n_values[cur_idx]) 59 | cfg.model.time_fn.num_steps = new_num 60 | cfg.model.diff_fn.num_steps = new_num 61 | cfg.model.cond_fn.num_steps = new_num 62 | n_idx = cur_idx 63 | timer, differ, conder = instantiate_scaler( 64 | cfg 65 | ) # pylint: disable=unused-variable 66 | logger.critical(f"\nIter{iter_cnt}: {n_idx} steps level: {new_num}") 67 | return True 68 | return False 69 | 70 | def loss_fn(trainer, feed_dict, is_train): 71 | model = trainer.mmodel 72 | if update_scalar(trainer.iter_cnt): 73 | model.timestamps = timer().to(trainer.device) 74 | model.diffusion = differ().to(trainer.device) 75 | model.condition = conder().to(trainer.device) 76 | model.delta_t = model.timestamps[1:] - model.timestamps[:-1] 77 | check_dflow_coef(model, prefix_caption=trainer.iter_cnt) 78 | feed_dict = preprocess_fn(feed_dict, trainer.device) 79 | 80 | cur_time = timer.rand() 81 | cur_diff = differ.index(cur_time).to(trainer.device) 82 | cur_cond = conder.index(cur_time).to(trainer.device) 83 | cur_time = cur_time.to(trainer.device) 84 | 85 | z, logabsdet = model.forward_cond(feed_dict, cur_time, cur_diff, cur_cond) 86 | 87 | norm_loss = -model.noise_log_prob(z.flatten(start_dim=1)).mean() 88 | det_loss = -logabsdet.mean() 89 | return ( 90 | norm_loss + det_loss, 91 | {}, 92 | { 93 | "norm": norm_loss, 
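# Change-of-variables NLL: -log p_x(x) = -log p_z(z) - log|det dz/dx|, so
# "norm" is the prior term, "det_loss" the log-determinant term, and
# "dim/norm" reports the prior term normalized per data dimension.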
94 | "det_loss": det_loss, 95 | "dim/norm": norm_loss / np.prod(z.shape[1:]), 96 | }, 97 | ) 98 | 99 | return loss_fn 100 | -------------------------------------------------------------------------------- /modules/sde_ps_fns.py: -------------------------------------------------------------------------------- 1 | import jammy.image as jimg 2 | import jamtorch.prototype as jampt 3 | import matplotlib.pyplot as plt 4 | from jamtorch.trainer import step_lr 5 | from jamtorch.utils import no_grad_func 6 | from torch.optim.lr_scheduler import StepLR 7 | 8 | import utils.diagnosis as dgns 9 | from utils.scalars import scalar_helper 10 | from viz.ps import seqSample2img, viz_sample 11 | 12 | # pylint: disable=unused-argument 13 | 14 | 15 | def epoch_start(trainer): 16 | dataset = trainer.train_loader.dataset 17 | trainer.test_sample = dataset.dataset.data[:50000] 18 | viz_sample(trainer.test_sample, "ground sample", "GT-sample.png") 19 | 20 | 21 | def epoch_after_wrapper(cfg): 22 | @no_grad_func 23 | def check_gtsample_traj(trainer): 24 | model = trainer.model 25 | gt_sample = jampt.from_numpy(trainer.test_sample) 26 | z = dgns.forward_x2z(model, gt_sample, *scalar_helper(model)) 27 | viz_sample( 28 | z, "test forward", f"forward_{trainer.epoch_cnt:02}.png", fix_lim=False 29 | ) 30 | x = dgns.backward_z2x(model, z, *scalar_helper(model)) 31 | viz_sample(x, "test backward", f"backward_{trainer.epoch_cnt:02}.png") 32 | 33 | @no_grad_func 34 | def epoch_after(trainer): 35 | sample = trainer.model.sample(50000) 36 | viz_sample(sample, "sde sample", f"epoch_{trainer.epoch_cnt:02}.png") 37 | check_gtsample_traj(trainer) 38 | 39 | return epoch_after 40 | 41 | 42 | def check_process(cfg): 43 | def _fn(trainer): 44 | model = trainer.model 45 | gt_sample = jampt.from_numpy(trainer.test_sample) 46 | noise = model.sample_noise(5000) 47 | imgs = [] 48 | 49 | # GT Forward 50 | f_process = dgns.forward_whole_process(model, gt_sample, *scalar_helper(model)) 51 | fig = seqSample2img(f_process["data"], 10) 52 | fig.suptitle("GT Forward") 53 | imgs.append(jimg.plt2pil(fig)) 54 | plt.close(fig) 55 | 56 | # GT Backward 57 | gt_noise = f_process["data"][-1] 58 | b_process = dgns.backward_whole_process(model, gt_noise, *scalar_helper(model)) 59 | fig = seqSample2img(b_process["data"][::-1], 10) 60 | fig.suptitle("GT Backward") 61 | imgs.append(jimg.plt2pil(fig)) 62 | plt.close(fig) 63 | 64 | # Noise Backward 65 | b_process = dgns.backward_whole_process(model, noise, *scalar_helper(model)) 66 | fig = seqSample2img(b_process["data"], 10) 67 | fig.suptitle("Noise Backward") 68 | imgs.append(jimg.plt2pil(fig)) 69 | plt.close(fig) 70 | 71 | # Deterministic Backward 72 | d_process = dgns.backward_deterministic_process( 73 | model, noise, *scalar_helper(model) 74 | ) 75 | fig = seqSample2img(d_process["data"], 10) 76 | fig.suptitle("Deterministic Backward") 77 | imgs.append(jimg.plt2pil(fig)) 78 | plt.close(fig) 79 | 80 | # LGV 81 | lgv_process = dgns.langevin_process( 82 | model, noise, -1, 1000, 0.05, model.condition, all_img=True 83 | ) 84 | fig = seqSample2img(lgv_process["data_mean"], 10) 85 | fig.suptitle("LGV process") 86 | imgs.append(jimg.plt2pil(fig)) 87 | plt.close(fig) 88 | 89 | fig = jimg.imgstack(imgs) 90 | jimg.savefig(fig, f"whole-{trainer.epoch_cnt:02}.png") 91 | plt.close(fig) 92 | 93 | return _fn 94 | 95 | 96 | def points_trainer_register(trainer, cfg): 97 | trainer.register_event("val:end", epoch_after_wrapper(cfg)) 98 | trainer.register_event("epoch:end", check_process(cfg)) 99 | 
trainer.register_event("epoch:start", epoch_start) 100 | 101 | scheduler = StepLR(trainer.optimizer, step_size=2, gamma=cfg.optimizer.gamma) 102 | trainer.lr_scheduler = scheduler 103 | trainer.register_event("epoch:after", step_lr) 104 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/networks/__init__.py -------------------------------------------------------------------------------- /networks/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from jamtorch.distributions import StandardNormal 4 | 5 | 6 | def batch_noise_square(noise): 7 | return torch.sum(noise.flatten(start_dim=1) ** 2, dim=1) 8 | 9 | 10 | # FIXME: cond_f, cond_b is scalar, could cause bugs 11 | 12 | # pylint: disable=too-many-arguments 13 | 14 | 15 | class BaseModel(torch.nn.Module): 16 | def __init__(self, data_shape, drift_net, score_net): 17 | super().__init__() 18 | self.data_shape = tuple(data_shape) 19 | self.drift = drift_net 20 | self.score = score_net 21 | self._distribution = StandardNormal([np.prod(data_shape)]) 22 | 23 | def forward_step(self, x, step_size, cond_f, cond_b, diff_f, diff_b): 24 | forward_noise = self._distribution.sample(x.shape[0]).view(x.shape) 25 | z = ( 26 | self.cal_next_nodiffusion(x, step_size, cond_f) 27 | + torch.sqrt(step_size) * diff_f * forward_noise 28 | ) 29 | backward_noise = self.cal_backnoise(x, z, step_size, cond_b, diff_b) 30 | delta_s = -0.5 * ( 31 | batch_noise_square(backward_noise) - batch_noise_square(forward_noise) 32 | ) 33 | return z, delta_s 34 | 35 | def cal_backnoise(self, x, z, step_size, cond_b, diff_b): 36 | f_backward = self.drift(z, cond_b) - diff_b ** 2 * self.score(z, cond_b) 37 | return (x - z + f_backward * step_size) / (diff_b * torch.sqrt(step_size)) 38 | 39 | def cal_forwardnoise(self, x, z, step_size, cond_f, diff_f): 40 | f_backward = self.drift(x, cond_f) 41 | return (z - x - f_backward * step_size) / (diff_f * torch.sqrt(step_size)) 42 | 43 | def cal_next_nodiffusion(self, x, step_size, cond_f): 44 | return x + self.drift(x, cond_f) * step_size 45 | 46 | def cal_prev_nodiffusion(self, z, step_size, cond_b, diff_b): 47 | return ( 48 | z 49 | - (self.drift(z, cond_b) - diff_b ** 2 * self.score(z, cond_b)) * step_size 50 | ) 51 | 52 | def backward_step(self, z, step_size, cond_f, cond_b, diff_f, diff_b): 53 | backward_noise = self._distribution.sample(z.shape[0]).view(z.shape) 54 | x = ( 55 | self.cal_prev_nodiffusion(z, step_size, cond_b, diff_b) 56 | + torch.sqrt(step_size) * diff_b * backward_noise 57 | ) 58 | forward_noise = self.cal_forwardnoise(x, z, step_size, cond_f, diff_f) 59 | delta_s = -0.5 * ( 60 | batch_noise_square(forward_noise) - batch_noise_square(backward_noise) 61 | ) 62 | return x, delta_s 63 | 64 | def sample(self, num_samples, timestamps, diffusion, condition): 65 | z = self._distribution.sample(num_samples).view(-1, *self.data_shape) 66 | x, _ = self.backward(z, timestamps, diffusion, condition) 67 | return x 68 | 69 | def forward(self, x, timestamps, diffusion, condition): 70 | batch_size = x.shape[0] 71 | logabsdet = x.new_zeros(batch_size) 72 | delta_t = timestamps[1:] - timestamps[:-1] 73 | for i_th, cur_delta_t in enumerate(delta_t): 74 | x, new_det = self.forward_step( 75 | x, 76 | cur_delta_t, 77 | 
condition[i_th], 78 | condition[i_th + 1], 79 | diffusion[i_th], 80 | diffusion[i_th + 1], 81 | ) 82 | logabsdet += new_det 83 | return x, logabsdet 84 | 85 | def backward(self, z, timestamps, diffusion, condition): 86 | delta_t = timestamps[1:] - timestamps[:-1] 87 | logabsdet = z.new_zeros(z.shape[0]) 88 | for i_th, cur_delta_t in enumerate(torch.flip(delta_t, (0,))): 89 | z, new_det = self.backward_step( 90 | z, 91 | cur_delta_t, 92 | condition[-i_th - 2], 93 | condition[-i_th - 1], 94 | diffusion[-i_th - 2], 95 | diffusion[-i_th - 1], 96 | ) 97 | logabsdet += new_det 98 | return z, logabsdet 99 | 100 | def forward_list(self, x): 101 | rtn = [x] 102 | for i_th, cur_delta_t in enumerate(self.delta_t): 103 | x, _ = self.forward_step( 104 | x, 105 | cur_delta_t, 106 | self.condition[i_th], 107 | self.condition[i_th + 1], 108 | self.diffusion[i_th], 109 | self.diffusion[i_th + 1], 110 | ) 111 | rtn.append(x) 112 | return rtn 113 | 114 | def backward_list(self, z): 115 | rtn = [z] 116 | for i_th, cur_delta_t in enumerate(torch.flip(self.delta_t, (0,))): 117 | z, _ = self.backward_step( 118 | z, 119 | cur_delta_t, 120 | self.condition[-i_th - 2], 121 | self.condition[-i_th - 1], 122 | self.diffusion[-i_th - 2], 123 | self.diffusion[-i_th - 1], 124 | ) 125 | rtn.append(z) 126 | return rtn 127 | -------------------------------------------------------------------------------- /networks/diff_flow.py: -------------------------------------------------------------------------------- 1 | from utils.sdefunction import SdeF 2 | 3 | from .base_model import BaseModel 4 | 5 | __all__ = ["DiffFlow", "QuickDiffFlow"] 6 | 7 | # pylint: disable=too-many-arguments, arguments-differ 8 | class DiffFlow(BaseModel): 9 | def __init__( 10 | self, data_shape, timestamp, diffusion, condition, drift_net, score_net 11 | ): 12 | super().__init__(data_shape, drift_net, score_net) 13 | self.register_buffer("timestamps", timestamp) 14 | self.register_buffer("diffusion", diffusion) 15 | self.register_buffer("condition", condition) 16 | assert self.timestamps.shape == self.diffusion.shape 17 | self.register_buffer("delta_t", self.timestamps[1:] - self.timestamps[:-1]) 18 | 19 | def forward(self, x): 20 | return super().forward(x, self.timestamps, self.diffusion, self.condition) 21 | 22 | def backward(self, z): 23 | return super().backward(z, self.timestamps, self.diffusion, self.condition) 24 | 25 | def sample(self, n_samples): 26 | z = self._distribution.sample(n_samples).view(-1, *self.data_shape) 27 | x, _ = self.backward(z) 28 | return x 29 | 30 | def sample_noise(self, n_samples): 31 | return self._distribution.sample(n_samples).view(-1, *self.data_shape) 32 | 33 | def noise_log_prob(self, z): 34 | return self._distribution.log_prob(z) 35 | 36 | 37 | class QuickDiffFlow(DiffFlow): 38 | def forward(self, x): 39 | return SdeF.apply( 40 | x, 41 | self, 42 | self.timestamps, 43 | self.diffusion, 44 | self.condition, 45 | *tuple(self.parameters()) 46 | ) 47 | 48 | def forward_cond(self, x, timestamps, diffusion, condition): 49 | return SdeF.apply( 50 | x, self, timestamps, diffusion, condition, *tuple(self.parameters()) 51 | ) 52 | -------------------------------------------------------------------------------- /networks/fouriermlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | __all__ = ["FourierMLP"] 7 | 8 | 9 | class FourierMLP(nn.Module): 10 | def __init__(self, data_shape, 
num_layers=2, channels=128): 11 | super().__init__() 12 | self.data_shape = [data_shape] 13 | 14 | self.register_buffer( 15 | "timestep_coeff", torch.linspace(start=0.1, end=100, steps=channels)[None] 16 | ) 17 | self.timestep_phase = nn.Parameter(torch.randn(channels)[None]) 18 | self.input_embed = nn.Linear(int(np.prod(data_shape)), channels) 19 | self.timestep_embed = nn.Sequential( 20 | nn.Linear(2 * channels, channels), 21 | nn.GELU(), 22 | nn.Linear(channels, channels), 23 | ) 24 | self.layers = nn.Sequential( 25 | nn.GELU(), 26 | *[ 27 | nn.Sequential(nn.Linear(channels, channels), nn.GELU()) 28 | for _ in range(num_layers) 29 | ], 30 | nn.Linear(channels, int(np.prod(data_shape))), 31 | ) 32 | 33 | def forward(self, inputs, cond): 34 | sin_embed_cond = torch.sin( 35 | (self.timestep_coeff * cond.float()) + self.timestep_phase 36 | ) 37 | cos_embed_cond = torch.cos( 38 | (self.timestep_coeff * cond.float()) + self.timestep_phase 39 | ) 40 | embed_cond = self.timestep_embed( 41 | rearrange([sin_embed_cond, cos_embed_cond], "d b w -> b (d w)") 42 | ) 43 | embed_ins = self.input_embed(inputs.view(inputs.shape[0], -1)) 44 | out = self.layers(embed_ins + embed_cond) 45 | return out.view(inputs.shape) 46 | -------------------------------------------------------------------------------- /networks/official_unet.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def get_timestep_embedding(timesteps, embedding_dim): 9 | """ 10 | This matches the implementation in Denoising Diffusion Probabilistic Models: 11 | From Fairseq. 12 | Build sinusoidal embeddings. 13 | This matches the implementation in tensor2tensor, but differs slightly 14 | from the description in Section 3.5 of "Attention Is All You Need". 
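    Concretely, for half_dim = embedding_dim // 2 the frequencies are
    exp(-log(10000) * k / (half_dim - 1)) for k = 0..half_dim-1, and the
    embedding concatenates sin and cos of timestep * frequency
    (zero-padded when embedding_dim is odd).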
15 | """ 16 | assert len(timesteps.shape) == 1 17 | 18 | half_dim = embedding_dim // 2 19 | emb = math.log(10000) / (half_dim - 1) 20 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 21 | emb = emb.to(device=timesteps.device) 22 | emb = timesteps.float()[:, None] * emb[None, :] 23 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 24 | if embedding_dim % 2 == 1: # zero pad 25 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 26 | return emb 27 | 28 | 29 | def nonlinearity(x): 30 | # swish 31 | return x * torch.sigmoid(x) 32 | 33 | 34 | def Normalize(in_channels): 35 | return torch.nn.GroupNorm( 36 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 37 | ) 38 | 39 | 40 | class Upsample(nn.Module): 41 | def __init__(self, in_channels, with_conv): 42 | super().__init__() 43 | self.with_conv = with_conv 44 | if self.with_conv: 45 | self.conv = torch.nn.Conv2d( 46 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 47 | ) 48 | 49 | def forward(self, x): 50 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 51 | if self.with_conv: 52 | x = self.conv(x) 53 | return x 54 | 55 | 56 | class Downsample(nn.Module): 57 | def __init__(self, in_channels, with_conv): 58 | super().__init__() 59 | self.with_conv = with_conv 60 | if self.with_conv: 61 | # no asymmetric padding in torch conv, must do it ourselves 62 | self.conv = torch.nn.Conv2d( 63 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 64 | ) 65 | 66 | def forward(self, x): 67 | if self.with_conv: 68 | pad = (0, 1, 0, 1) 69 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 70 | x = self.conv(x) 71 | else: 72 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 73 | return x 74 | 75 | 76 | class ResnetBlock(nn.Module): 77 | # middle, in_d * 2 + (in_d*3*3+1)*out_d + (t_d+1)*out_d + out_d*2 + (out_d*3*3+1)*out_d + 78 | # (in_d * 1 * 1 + 1)* out_d 79 | def __init__( 80 | self, 81 | *, 82 | in_channels, 83 | out_channels=None, 84 | conv_shortcut=False, 85 | dropout, 86 | temb_channels=512 87 | ): 88 | super().__init__() 89 | self.in_channels = in_channels 90 | out_channels = in_channels if out_channels is None else out_channels 91 | self.out_channels = out_channels 92 | self.use_conv_shortcut = conv_shortcut 93 | 94 | self.norm1 = Normalize(in_channels) 95 | self.conv1 = torch.nn.Conv2d( 96 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 97 | ) 98 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 99 | self.norm2 = Normalize(out_channels) 100 | self.dropout = torch.nn.Dropout(dropout) 101 | self.conv2 = torch.nn.Conv2d( 102 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 103 | ) 104 | if self.in_channels != self.out_channels: 105 | if self.use_conv_shortcut: 106 | self.conv_shortcut = torch.nn.Conv2d( 107 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 108 | ) 109 | else: 110 | self.nin_shortcut = torch.nn.Conv2d( 111 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 112 | ) 113 | 114 | def forward(self, x, temb): 115 | h = x 116 | h = self.norm1(h) 117 | h = nonlinearity(h) 118 | h = self.conv1(h) 119 | 120 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 121 | 122 | h = self.norm2(h) 123 | h = nonlinearity(h) 124 | h = self.dropout(h) 125 | h = self.conv2(h) 126 | 127 | if self.in_channels != self.out_channels: 128 | if self.use_conv_shortcut: 129 | x = self.conv_shortcut(x) 130 | else: 131 | x = self.nin_shortcut(x) 132 | 133 | return x + h 134 
| 135 | 136 | class AttnBlock(nn.Module): 137 | def __init__(self, in_channels): 138 | super().__init__() 139 | self.in_channels = in_channels 140 | 141 | self.norm = Normalize(in_channels) 142 | self.q = torch.nn.Conv2d( 143 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 144 | ) 145 | self.k = torch.nn.Conv2d( 146 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 147 | ) 148 | self.v = torch.nn.Conv2d( 149 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 150 | ) 151 | self.proj_out = torch.nn.Conv2d( 152 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 153 | ) 154 | 155 | def forward(self, x): 156 | h_ = x 157 | h_ = self.norm(h_) 158 | q = self.q(h_) 159 | k = self.k(h_) 160 | v = self.v(h_) 161 | 162 | # compute attention 163 | b, c, h, w = q.shape 164 | q = q.reshape(b, c, h * w) 165 | q = q.permute(0, 2, 1) # b,hw,c 166 | k = k.reshape(b, c, h * w) # b,c,hw 167 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 168 | w_ = w_ * (int(c) ** (-0.5)) 169 | w_ = torch.nn.functional.softmax(w_, dim=2) 170 | 171 | # attend to values 172 | v = v.reshape(b, c, h * w) 173 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 174 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 175 | h_ = torch.bmm(v, w_) 176 | h_ = h_.reshape(b, c, h, w) 177 | 178 | h_ = self.proj_out(h_) 179 | 180 | return x + h_ 181 | 182 | 183 | class Model(nn.Module): 184 | def __init__( 185 | self, 186 | ch, 187 | out_ch, 188 | ch_mult, 189 | num_res_blocks, 190 | attn_resolutions, 191 | dropout, 192 | in_channels, 193 | resolution, 194 | resamp_with_conv, 195 | ): 196 | super().__init__() 197 | ch_mult = tuple(ch_mult) 198 | 199 | self.ch = ch 200 | self.temb_ch = self.ch * 4 201 | self.num_resolutions = len(ch_mult) 202 | self.num_res_blocks = num_res_blocks 203 | self.resolution = resolution 204 | self.in_channels = in_channels 205 | 206 | # timestep embedding 207 | self.temb = nn.Module() 208 | self.temb.dense = nn.ModuleList( 209 | [ 210 | torch.nn.Linear(self.ch, self.temb_ch), 211 | torch.nn.Linear(self.temb_ch, self.temb_ch), 212 | ] 213 | ) 214 | 215 | # downsampling 216 | self.conv_in = torch.nn.Conv2d( 217 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 218 | ) 219 | 220 | curr_res = resolution 221 | in_ch_mult = (1,) + ch_mult 222 | self.down = nn.ModuleList() 223 | block_in = None 224 | for i_level in range(self.num_resolutions): # 4 225 | block = nn.ModuleList() 226 | attn = nn.ModuleList() 227 | block_in = ch * in_ch_mult[i_level] 228 | block_out = ch * ch_mult[i_level] 229 | for i_block in range(self.num_res_blocks): # 2 230 | block.append( 231 | ResnetBlock( 232 | in_channels=block_in, 233 | out_channels=block_out, 234 | temb_channels=self.temb_ch, 235 | dropout=dropout, 236 | ) 237 | ) 238 | block_in = block_out 239 | if curr_res in attn_resolutions: 240 | attn.append(AttnBlock(block_in)) # 241 | down = nn.Module() 242 | down.block = block 243 | down.attn = attn 244 | if i_level != self.num_resolutions - 1: 245 | down.downsample = Downsample(block_in, resamp_with_conv) 246 | curr_res = curr_res // 2 247 | self.down.append(down) 248 | 249 | self.mid = nn.Module() 250 | self.mid.block_1 = ResnetBlock( 251 | in_channels=block_in, 252 | out_channels=block_in, 253 | temb_channels=self.temb_ch, 254 | dropout=dropout, 255 | ) 256 | self.mid.attn_1 = AttnBlock(block_in) 257 | self.mid.block_2 = ResnetBlock( 258 | in_channels=block_in, 259 | out_channels=block_in, 260 | temb_channels=self.temb_ch, 261 | 
dropout=dropout, 262 | ) 263 | 264 | # upsampling 265 | self.up = nn.ModuleList() 266 | for i_level in reversed(range(self.num_resolutions)): 267 | block = nn.ModuleList() 268 | attn = nn.ModuleList() 269 | block_out = ch * ch_mult[i_level] 270 | skip_in = ch * ch_mult[i_level] 271 | for i_block in range(self.num_res_blocks + 1): 272 | if i_block == self.num_res_blocks: 273 | skip_in = ch * in_ch_mult[i_level] 274 | block.append( 275 | ResnetBlock( 276 | in_channels=block_in + skip_in, 277 | out_channels=block_out, 278 | temb_channels=self.temb_ch, 279 | dropout=dropout, 280 | ) 281 | ) 282 | block_in = block_out 283 | if curr_res in attn_resolutions: 284 | attn.append(AttnBlock(block_in)) 285 | up = nn.Module() 286 | up.block = block 287 | up.attn = attn 288 | if i_level != 0: 289 | up.upsample = Upsample(block_in, resamp_with_conv) 290 | curr_res = curr_res * 2 291 | self.up.insert(0, up) # prepend to get consistent order 292 | 293 | # end 294 | self.norm_out = Normalize(block_in) 295 | self.conv_out = torch.nn.Conv2d( 296 | block_in, out_ch, kernel_size=3, stride=1, padding=1 297 | ) 298 | 299 | def forward(self, x, t): 300 | assert x.shape[2] == x.shape[3] == self.resolution 301 | 302 | # timestep embedding 303 | t = t.expand((x.shape[0],)) 304 | temb = get_timestep_embedding(t, self.ch) 305 | temb = self.temb.dense[0](temb) 306 | temb = nonlinearity(temb) 307 | temb = self.temb.dense[1](temb) 308 | 309 | # downsampling 310 | hs = [self.conv_in(x)] 311 | for i_level in range(self.num_resolutions): 312 | for i_block in range(self.num_res_blocks): 313 | h = self.down[i_level].block[i_block](hs[-1], temb) 314 | if len(self.down[i_level].attn) > 0: 315 | h = self.down[i_level].attn[i_block](h) 316 | hs.append(h) 317 | if i_level != self.num_resolutions - 1: 318 | hs.append(self.down[i_level].downsample(hs[-1])) 319 | 320 | # middle 321 | h = hs[-1] 322 | h = self.mid.block_1(h, temb) 323 | h = self.mid.attn_1(h) 324 | h = self.mid.block_2(h, temb) 325 | 326 | # upsampling 327 | for i_level in reversed(range(self.num_resolutions)): 328 | for i_block in range(self.num_res_blocks + 1): 329 | h = self.up[i_level].block[i_block]( 330 | torch.cat([h, hs.pop()], dim=1), temb 331 | ) 332 | if len(self.up[i_level].attn) > 0: 333 | h = self.up[i_level].attn[i_block](h) 334 | if i_level != 0: 335 | h = self.up[i_level].upsample(h) 336 | 337 | # end 338 | h = self.norm_out(h) 339 | h = nonlinearity(h) 340 | h = self.conv_out(h) 341 | return h 342 | -------------------------------------------------------------------------------- /networks/unet.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import math 3 | from inspect import isfunction 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from torch import nn 10 | 11 | # helpers functions 12 | 13 | __all__ = ["Unet"] 14 | 15 | 16 | def exists(x): 17 | return x is not None 18 | 19 | 20 | def default(val, d): 21 | if exists(val): 22 | return val 23 | return d() if isfunction(d) else d 24 | 25 | 26 | def num_to_groups(num, divisor): 27 | groups = num // divisor 28 | remainder = num % divisor 29 | arr = [divisor] * groups 30 | if remainder > 0: 31 | arr.append(remainder) 32 | return arr 33 | 34 | 35 | class Residual(nn.Module): 36 | def __init__(self, fn): 37 | super().__init__() 38 | self.fn = fn 39 | 40 | def forward(self, x, *args, **kwargs): 41 | return self.fn(x, *args, **kwargs) + x 42 | 43 | 44 | class 
SinusoidalPosEmb(nn.Module): 45 | def __init__(self, dim): 46 | super().__init__() 47 | self.dim = dim 48 | 49 | def forward(self, x): 50 | device = x.device 51 | half_dim = self.dim // 2 52 | emb = math.log(10000) / (half_dim - 1) 53 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 54 | emb = x[:, None] * emb[None, :] 55 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 56 | return emb 57 | 58 | 59 | class Mish(nn.Module): 60 | def forward(self, x): 61 | return x * torch.tanh(F.softplus(x)) 62 | 63 | 64 | class Upsample(nn.Module): 65 | def __init__(self, dim): 66 | super().__init__() 67 | self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1) 68 | 69 | def forward(self, x): 70 | return self.conv(x) 71 | 72 | 73 | class Downsample(nn.Module): 74 | def __init__(self, dim): 75 | super().__init__() 76 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 77 | 78 | def forward(self, x): 79 | return self.conv(x) 80 | 81 | 82 | class Rezero(nn.Module): 83 | def __init__(self, fn): 84 | super().__init__() 85 | self.fn = fn 86 | self.g = nn.Parameter(torch.zeros(1)) 87 | 88 | def forward(self, x): 89 | return self.fn(x) * self.g 90 | 91 | 92 | # building block modules 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, dim, dim_out, groups=8): 97 | super().__init__() 98 | self.block = nn.Sequential( 99 | nn.Conv2d(dim, dim_out, 3, padding=1), nn.GroupNorm(groups, dim_out), Mish() 100 | ) 101 | 102 | def forward(self, x): 103 | return self.block(x) 104 | 105 | 106 | class ResnetBlock(nn.Module): 107 | def __init__(self, dim, dim_out, *, time_emb_dim, groups=8): 108 | super().__init__() 109 | self.mlp = nn.Sequential(Mish(), nn.Linear(time_emb_dim, dim_out)) 110 | 111 | self.block1 = Block(dim, dim_out) 112 | self.block2 = Block(dim_out, dim_out) 113 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 114 | 115 | def forward(self, x, time_emb): 116 | h = self.block1(x) 117 | h += self.mlp(time_emb)[:, :, None, None] 118 | h = self.block2(h) 119 | return h + self.res_conv(x) 120 | 121 | 122 | class LinearAttention(nn.Module): 123 | def __init__(self, dim, heads=4, dim_head=32): 124 | super().__init__() 125 | self.heads = heads 126 | hidden_dim = dim_head * heads 127 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 128 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 129 | 130 | def forward(self, x): 131 | b, c, h, w = x.shape 132 | qkv = self.to_qkv(x) 133 | q, k, v = rearrange( 134 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 135 | ) 136 | k = k.softmax(dim=-1) 137 | context = torch.einsum("bhdn,bhen->bhde", k, v) 138 | out = torch.einsum("bhde,bhdn->bhen", context, q) 139 | out = rearrange( 140 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 141 | ) 142 | return self.to_out(out) 143 | 144 | 145 | # model 146 | 147 | 148 | # FIXME: Tune groups 149 | class Unet(nn.Module): 150 | def __init__( 151 | self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), groups=8, in_channel=3 152 | ): 153 | super().__init__() 154 | dims = [in_channel, *map(lambda m: dim * m, dim_mults)] 155 | in_out = list(zip(dims[:-1], dims[1:])) 156 | 157 | self.time_pos_emb = SinusoidalPosEmb(dim) 158 | self.mlp = nn.Sequential( 159 | nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim) 160 | ) 161 | 162 | self.downs = nn.ModuleList([]) 163 | self.ups = nn.ModuleList([]) 164 | num_resolutions = len(in_out) 165 | 166 | for ind, (dim_in, dim_out) in enumerate(in_out): 167 | is_last = ind >= (num_resolutions - 1) 168 | 169 | 
self.downs.append( 170 | nn.ModuleList( 171 | [ 172 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim), 173 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim), 174 | Residual(Rezero(LinearAttention(dim_out))), 175 | Downsample(dim_out) if not is_last else nn.Identity(), 176 | ] 177 | ) 178 | ) 179 | 180 | mid_dim = dims[-1] 181 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 182 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 183 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 184 | 185 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 186 | is_last = ind >= (num_resolutions - 1) 187 | 188 | self.ups.append( 189 | nn.ModuleList( 190 | [ 191 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), 192 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim), 193 | Residual(Rezero(LinearAttention(dim_in))), 194 | Upsample(dim_in) if not is_last else nn.Identity(), 195 | ] 196 | ) 197 | ) 198 | 199 | out_dim = default(out_dim, in_channel) 200 | self.final_conv = nn.Sequential(Block(dim, dim), nn.Conv2d(dim, out_dim, 1)) 201 | 202 | def forward(self, x, time): 203 | time = time.expand((x.shape[0],)) 204 | t = self.time_pos_emb(time) 205 | t = self.mlp(t) 206 | 207 | h = [] 208 | 209 | for resnet, resnet2, attn, downsample in self.downs: 210 | x = resnet(x, t) 211 | x = resnet2(x, t) 212 | x = attn(x) 213 | h.append(x) 214 | x = downsample(x) 215 | 216 | x = self.mid_block1(x, t) 217 | x = self.mid_attn(x) 218 | x = self.mid_block2(x, t) 219 | 220 | for resnet, resnet2, attn, upsample in self.ups: 221 | x = torch.cat((x, h.pop()), dim=1) 222 | x = resnet(x, t) 223 | x = resnet2(x, t) 224 | x = attn(x) 225 | x = upsample(x) 226 | 227 | return self.final_conv(x) 228 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "diffflow" 3 | version = "0.0.1" 4 | description = "" 5 | authors = ["qzhang419@gatech.edu"] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">=3.8,<3.10" 9 | jammy = "0.0.2" 10 | tqdm = "^4.11.0" 11 | scikit-learn = "^0.24.2" 12 | scikit-image = "^0.18.0" 13 | attrs = "^20.3.0" 14 | torch-fidelity = "^0.3.0" 15 | 16 | [tool.poetry.dev-dependencies] 17 | pylint = "^2.8.3" 18 | pre-commit = "^2.13.0" 19 | isort = "^5.8.0" 20 | black = "^21.5b1" 21 | 22 | [build-system] 23 | requires = ["poetry-core>=1.0.0"] 24 | build-backend = "poetry.core.masonry.api" 25 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Diffusion Normalizing Flow (DiffFlow) 2 | 3 | ![DiffFlow](asserts/DiffFlow_fig1.png) 4 | ![DiffFlow](asserts/DiffFlow_tree.png) 5 | 6 | ## Reproduce 7 | 8 | ### setup environment 9 | 10 | The repo heavily depends on [jam](https://github.com/qsh-zh/jam), a personal toolbox developed by [Qsh.zh](https://github.com/qsh-zh). The API may change; check the [jammy](https://jammy.readthedocs.io/en/stable/index.html) version before running the repo. 11 | 12 | *pip* 13 | ```shell 14 | pip install . 
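# installs diffflow together with the dependencies pinned in pyproject.toml above (jammy 0.0.2, torch-fidelity, scikit-learn, ...)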
``` 16 | 17 | *[poetry](https://python-poetry.org/)* 18 | ```shell 19 | curl -fsS -o /tmp/get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py 20 | python3 /tmp/get-poetry.py -y --no-modify-path 21 | export PATH=$HOME/.poetry/bin:$PATH 22 | poetry shell 23 | poetry install 24 | ``` 25 | 26 | ### Run 27 | 28 | ```shell 29 | python main.py trainer.epochs=100 data.dataset=tree 30 | ``` 31 | 32 | The repo supports visualizing results on [wandb](https://wandb.ai/site): 33 | ```shell 34 | python main.py trainer.epochs=100 data.dataset=tree log=true wandb.project=pub_diff wandb.name=tree 35 | ``` 36 | 37 | Some [results](https://wandb.ai/qinsheng/pub_diff?workspace=user-qinsheng) reproduced by the repo are available. 38 | 39 | 40 | ## Reference 41 | 42 | ```tex 43 | @inproceedings{zhang2021diffusion, 44 | author = {Qinsheng Zhang and Yongxin Chen}, 45 | title = {Diffusion Normalizing Flow}, 46 | booktitle = {Advances in Neural Information Processing Systems}, 47 | year = {2021} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/utils/__init__.py -------------------------------------------------------------------------------- /utils/ddp_trainer.py: -------------------------------------------------------------------------------- 1 | from jammy.utils.meta import Singleton 2 | from jamtorch.ddp.ema_trainer import EMATrainer 3 | from jamtorch.trainer import LossException 4 | from jamtorch.trainer.trainer_monitor import TrainerMonitor 5 | from retry.api import retry_call 6 | 7 | 8 | class Trainer(EMATrainer, metaclass=Singleton): 9 | def __init__(self, cfg, loss_fn): 10 | super().__init__(cfg, loss_fn) 11 | self.trainer_monitor = None 12 | 13 | def monitor_update(self): 14 | if self.trainer_monitor: 15 | self.trainer_monitor.update( 16 | { 17 | **self.cur_monitor, # pylint: disable=access-member-before-definition 18 | "epoch": self.epoch_cnt, 19 | "iter": self.iter_cnt, 20 | } 21 | ) 22 | self.cur_monitor = dict() # pylint: disable=attribute-defined-outside-init 23 | 24 | def set_monitor(self, is_wandb, tblogger=False): 25 | """ 26 | Attach a TrainerMonitor that mirrors training metrics to wandb and/or tensorboard. 27 | """ 28 | self.trainer_monitor = TrainerMonitor(is_wandb, tblogger) 29 | 30 | def _impl_load_ckpt(self, state): 31 | # do not overwrite the time coef 32 | state["model"]["timestamps"] = self.model.module.timestamps 33 | state["model"]["diffusion"] = self.model.module.diffusion 34 | state["model"]["condition"] = self.model.module.condition 35 | state["model"]["delta_t"] = self.model.module.delta_t 36 | super()._impl_load_ckpt(state) 37 | 38 | def train_step(self, feed_dict): 39 | retry_call( 40 | super().train_step, fargs=[feed_dict], tries=3, exceptions=LossException 41 | ) 42 | -------------------------------------------------------------------------------- /utils/diagnosis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jamtorch import as_cpu, no_grad_func 3 | 4 | # pylint: disable=too-many-arguments 5 | # pylint: disable=too-many-locals 6 | 7 | 8 | @no_grad_func 9 | def forward_whole_process(model, x, timestamps, diffusion, condition): 10 | x = x.clone() # avoid overwrite the origin data 11 | # TODO: dict can be slow 12 | rtn = {"data": [x.clone()], "grad": [], "noise": [], "step_size": []} 13 | delta_t = timestamps[1:] - 
timestamps[:-1] 14 | for i_th, cur_delta_t in enumerate(delta_t): 15 | cond_f, diff_f = condition[i_th], diffusion[i_th] 16 | grad = model.drift(x, cond_f) 17 | grad_step = grad * cur_delta_t 18 | noise = torch.sqrt(cur_delta_t) * diff_f * model.sample_noise(x.shape[0]) 19 | x += grad_step + noise 20 | rtn["data"].append(x.clone()) 21 | rtn["grad"].append(grad) 22 | rtn["noise"].append((noise / cur_delta_t)) 23 | rtn["step_size"].append(cur_delta_t) 24 | 25 | return rtn 26 | 27 | 28 | @no_grad_func 29 | def backward_whole_process( 30 | model, z, timestamps, diffusion, condition, drift_only=False, score_only=False 31 | ): 32 | rtn = { 33 | "data": [z.clone()], 34 | "grad": [], 35 | "noise": [], 36 | "drift": [], 37 | "diff": [], 38 | "score": [], 39 | "noise_step": [], 40 | "step_size": [], 41 | } 42 | z = z.clone() 43 | delta_t = timestamps[1:] - timestamps[:-1] 44 | for i_th, cur_delta_t in enumerate(torch.flip(delta_t, (0,))): 45 | cond_b, diff_b = condition[-i_th - 1], diffusion[-i_th - 1] 46 | drift = model.drift(z, cond_b) 47 | score = model.score(z, cond_b) 48 | diff = -(diff_b ** 2) * score 49 | if drift_only: 50 | grad = drift 51 | else: 52 | grad = drift + diff 53 | if score_only: 54 | grad = diff 55 | grad_step = grad * cur_delta_t 56 | noise = torch.sqrt(cur_delta_t) * diff_b * model.sample_noise(z.shape[0]) 57 | z_mean = z - grad_step 58 | z = z_mean + noise 59 | rtn["data"].append(z_mean.clone()) 60 | rtn["grad"].append(grad) 61 | rtn["drift"].append(drift) 62 | rtn["diff"].append(diff) 63 | rtn["score"].append(-score) 64 | rtn["noise"].append(noise / cur_delta_t) 65 | rtn["step_size"].append(cur_delta_t) 66 | return rtn 67 | 68 | 69 | def fb_whole_process(model, x, timestamps, diffusion, condition, is_gt=True): 70 | f_process = forward_whole_process(model, x, timestamps, diffusion, condition) 71 | z = f_process["data"][-1] 72 | if not is_gt: 73 | z = torch.randn_like(z) 74 | b_process = backward_whole_process(model, z, timestamps, diffusion, condition) 75 | return f_process, b_process 76 | 77 | 78 | # def fb_whole_process(model, x, timestamps, diffusion, condition, is_gt=True): 79 | # f_process = forward_whole_process(model, x, timestamps, diffusion, condition) 80 | # z = f_process["data"][-1] 81 | # if not is_gt: 82 | # z = torch.randn_like(z) 83 | # b_process = backward_whole_process(model, z, timestamps, diffusion, condition) 84 | 85 | # # convert data device and follow the same order 86 | # f_process = as_cpu(f_process) 87 | # b_process = as_cpu(b_process) 88 | # for _, item in b_process.items(): 89 | # item.reverse() 90 | 91 | # composite = { 92 | # "f_data": f_process["data"], 93 | # "b_data": b_process["data"], 94 | # "f_grad": f_process["grad"], 95 | # "b_grad": b_process["grad"], 96 | # "b_drift": b_process["drift"] 97 | # } 98 | # return composite 99 | 100 | 101 | def ema_whole_process(model, load_ema_fn, z, timestamps, diffusion, condition): 102 | """Only used in check ema in cmp_fb_process""" 103 | z = torch.randn_like(z) 104 | non_ema = backward_whole_process(model, z, timestamps, diffusion, condition) 105 | non_ema = as_cpu(non_ema) 106 | 107 | load_ema_fn(model) 108 | ema = backward_whole_process(model, z, timestamps, diffusion, condition) 109 | ema = as_cpu(ema) 110 | 111 | composite = { 112 | "ema_x": ema["data"], 113 | "non_ema_x": non_ema["data"], 114 | "ema_grad": ema["grad"], 115 | "non_ema_grad": non_ema["grad"], 116 | "ema_drift": ema["drift"], 117 | "non_ema_drift": non_ema["drift"], 118 | } 119 | return composite 120 | 121 | 122 | @no_grad_func 123 | 
def forward_data_process(model, x, timestamps, diffusion, condition): 124 | x = x.clone() # avoid overwrite the origin data 125 | rtn = [x.clone()] 126 | delta_t = timestamps[1:] - timestamps[:-1] 127 | for i_th, cur_delta_t in enumerate(delta_t): 128 | cond_f, diff_f = condition[i_th], diffusion[i_th] 129 | grad = model.drift(x, cond_f) 130 | grad_step = grad * cur_delta_t 131 | noise = torch.sqrt(cur_delta_t) * diff_f * model.sample_noise(x.shape[0]) 132 | x += grad_step + noise 133 | rtn.append(x.clone()) 134 | 135 | return rtn 136 | 137 | 138 | @no_grad_func 139 | def forward_x2z(model, x, timestamps, diffusion, condition): 140 | x = x.clone() # avoid overwrite the origin data 141 | delta_t = timestamps[1:] - timestamps[:-1] 142 | for i_th, cur_delta_t in enumerate(delta_t): 143 | cond_f, diff_f = condition[i_th], diffusion[i_th] 144 | grad = model.drift(x, cond_f) 145 | grad_step = grad * cur_delta_t 146 | noise = torch.sqrt(cur_delta_t) * diff_f * model.sample_noise(x.shape[0]) 147 | x += grad_step + noise 148 | return x 149 | 150 | 151 | @no_grad_func 152 | def backward_z2x(model, z, timestamps, diffusion, condition): 153 | # from noise to data 154 | z = z.clone() 155 | if len(timestamps) < 2: 156 | return z 157 | delta_t = timestamps[1:] - timestamps[:-1] 158 | for i_th, cur_delta_t in enumerate(torch.flip(delta_t, (0,))): 159 | cond_b, diff_b = condition[-i_th - 1], diffusion[-i_th - 1] 160 | drift = model.drift(z, cond_b) 161 | score = model.score(z, cond_b) 162 | diff = -(diff_b ** 2) * score 163 | grad = drift + diff 164 | grad_step = grad * cur_delta_t 165 | noise = torch.sqrt(cur_delta_t) * diff_b * model.sample_noise(z.shape[0]) 166 | z_mean = z - grad_step 167 | z = z_mean + noise 168 | return z_mean 169 | 170 | 171 | def recon_x(model, x, timestamps, diffusion, condition): 172 | z = forward_x2z(model, x, timestamps, diffusion, condition) 173 | return backward_z2x(model, z, timestamps, diffusion, condition) 174 | 175 | 176 | @no_grad_func 177 | def backward_deterministic_process(model, z, timestamps, diffusion, condition): 178 | rtn = { 179 | "data": [z.clone()], 180 | "grad": [], 181 | "grad_step": [], 182 | "step_size": [], 183 | } 184 | z = z.clone() 185 | delta_t = timestamps[1:] - timestamps[:-1] 186 | for i_th, cur_delta_t in enumerate(torch.flip(delta_t, (0,))): 187 | cond_b, diff_b = condition[-i_th - 1], diffusion[-i_th - 1] 188 | grad = model.drift(z, cond_b) - 0.5 * diff_b ** 2 * model.score(z, cond_b) 189 | grad_step = grad * cur_delta_t 190 | z += -grad_step 191 | rtn["data"].append(z.clone()) 192 | rtn["grad"].append(grad) 193 | rtn["grad_step"].append(grad_step) 194 | rtn["step_size"].append(cur_delta_t) 195 | return rtn 196 | 197 | 198 | @no_grad_func 199 | def langevin_process(model, z, idx, steps, snr, condition=None, all_img=False): 200 | condition = model.condition if condition is None else condition 201 | cond_b = condition[-idx - 1] 202 | if all_img: 203 | rtn = { 204 | "data": [z.clone()], 205 | "data_mean": [], 206 | "grad": [], 207 | "step_size": [], 208 | } 209 | z = z.clone() 210 | for _ in range(steps): 211 | noise = model.sample_noise(z.shape[0]) 212 | grad = model.score(z, cond_b) 213 | noise_norm = torch.mean(torch.norm(noise.flatten(start_dim=1), dim=1)) 214 | grad_norm = torch.mean(torch.norm(grad.flatten(start_dim=1), dim=1)) 215 | 216 | step_size = (snr * noise_norm / grad_norm) ** 2 * 2 217 | 218 | z_mean = z + grad * step_size 219 | 220 | z = z_mean + torch.sqrt(2 * step_size) * noise 221 | 222 | if all_img: 223 | 
rtn["data"].append(z.clone()) 224 | rtn["data_mean"].append(z_mean) 225 | rtn["grad"].append(grad) 226 | rtn["step_size"].append(step_size.item()) 227 | 228 | if all_img: 229 | return rtn 230 | return z_mean 231 | -------------------------------------------------------------------------------- /utils/img_viz.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/utils/img_viz.py -------------------------------------------------------------------------------- /utils/scalars.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import numpy as np 3 | import torch 4 | from jammy import jam_instantiate 5 | from scipy.interpolate import interp1d 6 | 7 | # pylint: disable=attribute-defined-outside-init, too-few-public-methods 8 | 9 | 10 | def create_alpha_schedule(num_steps=100, t_start=0.0001, t_end=0.02): 11 | betas = np.linspace(t_start, t_end, num_steps) 12 | result = [1.0] 13 | alpha = 1.0 14 | for beta in betas: 15 | alpha *= 1 - beta 16 | result.append(alpha) 17 | return torch.FloatTensor(result) 18 | 19 | 20 | def timestamp_fn(num_steps=100, t_start=0.0001, t_end=0.02): 21 | betas = np.linspace(t_start, t_end, num_steps) 22 | low_tri = np.tril(np.ones((num_steps, num_steps))) 23 | m = np.concatenate((np.zeros(num_steps).reshape(1, -1), low_tri), axis=0) 24 | times = m @ betas.reshape(-1, 1) 25 | assert times.size == num_steps + 1 26 | return torch.FloatTensor(times).flatten() 27 | 28 | 29 | def diffusion_fn(num_steps=100, t_start=0.0001, t_end=0.02): 30 | betas = np.linspace(t_start, t_end, num_steps + 1) 31 | return torch.sqrt(torch.FloatTensor(betas)) 32 | 33 | 34 | def squareliner_fn(num_steps=100, t_start=0.0001, t_end=0.02): 35 | square = np.linspace(np.sqrt(t_start), np.sqrt(t_end), num_steps + 1) 36 | return torch.pow(torch.FloatTensor(square), 2) 37 | 38 | 39 | def linear_fn(num_steps=100, t_start=0.0001, t_end=0.02): 40 | square = np.linspace(t_start, t_end, num_steps + 1) 41 | return torch.FloatTensor(square) 42 | 43 | 44 | def exp_fn(num_steps=100, t_start=0.0001, t_end=0.02, exp=0.9): 45 | base = np.linspace(t_start ** exp, t_end ** exp, num_steps + 1) 46 | return torch.pow(torch.FloatTensor(base), 1.0 / exp) 47 | 48 | 49 | def g_square_fn(num_steps=100, t_start=0.0001, t_end=0.02, exp=0.9): 50 | avg = (t_end - t_start) / num_steps 51 | time = exp_fn(num_steps, t_start, t_end, exp) 52 | dt = time[1:] - time[:-1] 53 | g = (avg / dt) ** 1.5 54 | return torch.cat([g[-1:] * 1.2, g]) 55 | 56 | 57 | @attr.s 58 | class ExpTimer: 59 | num_steps = attr.ib(100) 60 | t_start = attr.ib(0.0001) 61 | t_end = attr.ib(0.02) 62 | exp = attr.ib(0.5) 63 | 64 | def __attrs_post_init__(self): 65 | self.base = torch.linspace( 66 | self.t_start ** self.exp, self.t_end ** self.exp, self.num_steps 67 | ) 68 | self.fix_x_slot = torch.linspace( 69 | self.t_start ** self.exp, self.t_end ** self.exp, self.num_steps + 1 70 | ) 71 | self.intervals = self.base[1:] - self.base[:-1] 72 | 73 | def __call__(self): 74 | value = torch.pow(self.fix_x_slot, 1.0 / self.exp) 75 | return self.deal_flip(value) 76 | 77 | def deal_flip(self, value): 78 | if self.exp > 1.0: 79 | value = self.t_start + self.t_end - value 80 | value = torch.flip(value, (0,)) 81 | return value 82 | 83 | def rand(self): 84 | ratio = torch.rand(self.num_steps - 1) 85 | mid_times = ratio * self.intervals + self.base[:-1] 86 | times = torch.cat([self.base[:1], mid_times, 
self.base[-1:]]).flatten() 87 | value = torch.pow(times, 1.0 / self.exp) 88 | return self.deal_flip(value) 89 | 90 | def index(self, time): 91 | if np.isclose(self.t_start, self.t_end): 92 | return torch.pow(self.base[-1], 1.0 / self.exp) * torch.ones_like(time) 93 | time = torch.clip(time, self.t_start, self.t_end) 94 | return time 95 | # base = time ** self.exp 96 | # ratio = (base - base[0]) / (base[-1] - base[0]) 97 | # times = ratio * (self.base[-1] - self.base[0]) + self.base[0] 98 | # return torch.pow(times, 1.0/ self.exp) 99 | 100 | 101 | @attr.s 102 | class SCurve: 103 | num_steps = attr.ib(100) 104 | t_start = attr.ib(0.0001) 105 | t_end = attr.ib(0.02) 106 | exp = attr.ib(0.5) 107 | 108 | def __attrs_post_init__(self): 109 | avg = (self.t_end - self.t_start) / self.num_steps 110 | self.ratio_x = [0.0, 0.2, 0.9, 1] 111 | # self.ratio_y = [0., 0.1, 0.1, 1] 112 | self.ratio_y = [0.0, avg, avg, 0.2] 113 | int_y = np.interp( 114 | np.linspace(0, 1, self.num_steps + 1), self.ratio_x, self.ratio_y 115 | ) 116 | # delta = int_y * avg / 0.1 117 | delta = int_y 118 | self._time = np.cumsum(delta) + self.t_start 119 | 120 | def __call__(self): 121 | return torch.from_numpy(self._time).float() 122 | 123 | def rand(self): 124 | midpoints = np.linspace(0, 1, self.num_steps) 125 | delta_t = midpoints[1:] - midpoints[:-1] 126 | ratio = np.random.rand(self.num_steps - 1) 127 | mid_timestamps = delta_t * ratio + midpoints[:-1] 128 | timestamps = np.concatenate([[0], mid_timestamps, [1]]).flatten() 129 | return torch.from_numpy( 130 | np.interp(timestamps, np.linspace(0, 1, self.num_steps + 1), self._time) 131 | ).float() 132 | 133 | def index(self, time): 134 | if np.isclose(self.t_start, self.t_end): 135 | return torch.ones_like(time) * self.t_start 136 | return time 137 | 138 | 139 | @attr.s 140 | class FTimer: 141 | num_steps = attr.ib(100) 142 | t_start = attr.ib(0.0001) 143 | t_end = attr.ib(0.02) 144 | exp = attr.ib(0.5) 145 | 146 | def __attrs_post_init__(self): 147 | first_p = int(self.num_steps * 0.9) 148 | self.t1 = ExpTimer(first_p, self.t_start, self.t_end, self.exp) 149 | self.t2 = ExpTimer(self.num_steps - first_p, self.t_end, 0.5, 1.0 / self.exp) 150 | 151 | def __call__(self): 152 | return torch.cat([self.t1()[:-1], self.t2()]) 153 | 154 | 155 | @attr.s 156 | class STimer: # pylint: disable= too-many-instance-attributes 157 | num_steps = attr.ib(50) 158 | t_start = attr.ib(0.0001) 159 | t_end = attr.ib(0.02) 160 | up = attr.ib(True) # pylint: disable= invalid-name 161 | 162 | def __attrs_post_init__(self): 163 | if self.up: 164 | x = [0, 0.1, 0.20, 0.4, 0.70, 0.9, 1.0] 165 | y = [1e-4, 0.05, 0.10, 0.6, 0.93, 0.98, 1.0] 166 | else: 167 | x = [0, 0.1, 0.30, 0.6, 0.93, 0.98, 1.0] 168 | y = [1e-4, 0.2, 0.35, 0.4, 0.70, 0.9, 1.0] 169 | self.interp1d_fn = interp1d(x, y, kind="cubic") 170 | fix_x_slot = np.linspace(1e-4, 1.0, self.num_steps + 1) 171 | dt = self.interp1d_fn(fix_x_slot) 172 | dt_sum = np.sum(dt) + self.t_start 173 | self.scale = self.t_end / dt_sum 174 | self.fix_t_slot = (np.cumsum(dt) + self.t_start) * self.scale 175 | self.time_fn = interp1d(fix_x_slot, self.fix_t_slot) 176 | self.fix_x_slot = fix_x_slot 177 | 178 | # for dealing with t 179 | random_x_slot = np.linspace(1e-4, 1.0, self.num_steps) 180 | self.random_t_slot = torch.from_numpy(self.time_fn(random_x_slot)).float() 181 | self.random_t_interval = self.random_t_slot[1:] - self.random_t_slot[:-1] 182 | 183 | def __call__(self): 184 | return torch.from_numpy(self.fix_t_slot).float() 185 | 186 | def rand(self): 
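# Stratified jitter: draw one uniform sample inside each interval of the
# precomputed time grid, so every call returns randomized timestamps that
# still cover the schedule with the first and last grid points pinned.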
187 | ratio = torch.rand(self.num_steps - 1) 188 | mid_times = ratio * self.random_t_interval + self.random_t_slot[:-1] 189 | return torch.cat( 190 | [self.random_t_slot[:1], mid_times, self.random_t_slot[-1:]] 191 | ).flatten() 192 | 193 | def index(self, time): 194 | if np.isclose(self.t_start, self.t_end): 195 | return self.t_start * torch.ones_like(time) 196 | time = torch.clip(time, self.t_start, self.t_end) 197 | return time 198 | 199 | 200 | def instantiate_scaler(cfg): 201 | timer = jam_instantiate(cfg.model.time_fn) 202 | differ = jam_instantiate(cfg.model.diff_fn) 203 | conder = jam_instantiate(cfg.model.cond_fn) 204 | return timer, differ, conder 205 | 206 | 207 | def scalar_helper(model): 208 | return model.timestamps, model.diffusion, model.condition 209 | -------------------------------------------------------------------------------- /utils/sdefunction.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import jamtorch.ddp.ddp_utils as ddp_utils 3 | import numpy as np 4 | import torch 5 | import torch.cuda.amp as amp 6 | from jamtorch.utils.meta import as_float 7 | 8 | # if ddp_utils.is_master(): 9 | 10 | # def trainer_stat(trainer, stat): 11 | # trainer.cur_monitor.update(stat) 12 | 13 | 14 | # else: 15 | 16 | # def trainer_stat(trainer, state): 17 | # pass 18 | 19 | 20 | # trainer = None 21 | 22 | 23 | class SdeF(torch.autograd.Function): 24 | @staticmethod 25 | @amp.custom_fwd 26 | def forward(ctx, x, model, timestamps, diffusion, condition, *model_parameter): 27 | shapes = [y0_.shape for y0_ in model_parameter] 28 | 29 | def _flatten(parameter): 30 | # flatten the gradient dict and parameter dict 31 | return torch.cat( 32 | [ 33 | param.flatten() if param is not None else x.new_zeros(shape.numel()) 34 | for param, shape in zip(parameter, shapes) 35 | ] 36 | ) 37 | 38 | def _unflatten(tensor, length): 39 | # return object like parameter groups 40 | tensor_list = [] 41 | total = 0 42 | for shape in shapes: 43 | next_total = total + shape.numel() 44 | # It's important that this be view((...)), not view(...). Else when length=(), shape=() it fails. 
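# Slice the flat vector back into per-parameter chunks, reshaping each
# chunk to (*length, *shape) so the returned tuple lines up with the
# ordering of model.parameters().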
45 | tensor_list.append( 46 | tensor[..., total:next_total].view((*length, *shape)) 47 | ) 48 | total = next_total 49 | return tuple(tensor_list) 50 | 51 | history_x_state = x.new_zeros(len(timestamps) - 1, *x.shape) 52 | rtn_logabsdet = x.new_zeros(x.shape[0]) 53 | delta_t = timestamps[1:] - timestamps[:-1] 54 | new_x = x 55 | with torch.no_grad(): 56 | for i_th, cur_delta_t in enumerate(delta_t): 57 | history_x_state[i_th] = new_x 58 | new_x, new_logabsdet = model.forward_step( 59 | new_x, 60 | cur_delta_t, 61 | condition[i_th], 62 | condition[i_th + 1], 63 | diffusion[i_th], 64 | diffusion[i_th + 1], 65 | ) 66 | rtn_logabsdet += new_logabsdet 67 | ctx.model = model 68 | ctx._flatten = _flatten 69 | ctx._unflatten = _unflatten 70 | ctx.nparam = np.sum([shape.numel() for shape in shapes]) 71 | ctx.save_for_backward( 72 | history_x_state.clone(), new_x.clone(), timestamps, diffusion, condition 73 | ) 74 | return new_x, rtn_logabsdet 75 | 76 | @staticmethod 77 | @amp.custom_bwd 78 | def backward(ctx, dL_dz, dL_logabsdet): 79 | history_x_state, z, timestamps, diffusion, condition = ctx.saved_tensors 80 | dL_dparameter = dL_dz.new_zeros((1, ctx.nparam)) 81 | 82 | model, _flatten, _unflatten = ctx.model, ctx._flatten, ctx._unflatten 83 | model_parameter = tuple(model.parameters()) 84 | delta_t = timestamps[1:] - timestamps[:-1] 85 | b_noise = {} 86 | with torch.no_grad(): 87 | for bi_th, cur_delta_t in enumerate(torch.flip(delta_t, (0,))): 88 | bi_th += 1 89 | with torch.set_grad_enabled(True): 90 | x = history_x_state[-bi_th].requires_grad_(True) 91 | z = z.requires_grad_(True) 92 | noise_b = model.cal_backnoise( 93 | x, z, cur_delta_t, condition[-bi_th], diffusion[-bi_th] 94 | ) 95 | 96 | cur_delta_s = -0.5 * ( 97 | torch.sum(noise_b.flatten(start_dim=1) ** 2, dim=1) 98 | ) 99 | dl_dprev_state, dl_dnext_state, *dl_model_b = torch.autograd.grad( 100 | (cur_delta_s), 101 | (x, z) + model_parameter, 102 | grad_outputs=(dL_logabsdet), 103 | allow_unused=True, 104 | retain_graph=True, 105 | ) 106 | dl_dx, *dl_model_f = torch.autograd.grad( 107 | ( 108 | model.cal_next_nodiffusion( 109 | x, cur_delta_t, condition[-bi_th - 1] 110 | ) 111 | ), 112 | (x,) + model_parameter, 113 | grad_outputs=(dl_dnext_state + dL_dz), 114 | allow_unused=True, 115 | retain_graph=True, 116 | ) 117 | del x, z, dl_dnext_state 118 | b_noise[f"stat/{bi_th}"] = -1 * cur_delta_s.mean() 119 | z = history_x_state[-bi_th] 120 | dL_dz = dl_dx + dl_dprev_state 121 | dL_dparameter += _flatten(dl_model_b).unsqueeze(0) + _flatten( 122 | dl_model_f 123 | ).unsqueeze(0) 124 | 125 | # trainer_stat(trainer, as_float(b_noise)) 126 | 127 | return (dL_dz, None, None, None, None, *_unflatten(dL_dparameter, (1,))) 128 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | from jammy import Singleton 2 | from jammy.utils.retry import retry_call 3 | from jamtorch.trainer import LossException 4 | from jamtorch.trainer.ema_trainer import EMATrainer 5 | 6 | 7 | class Trainer(EMATrainer, metaclass=Singleton): 8 | def __init__(self, cfg, loss_fn): 9 | super().__init__(cfg, loss_fn) 10 | self.rank = 0 11 | self.is_master = True # for sync ddp_trainer 12 | 13 | def _impl_load_ckpt(self, state): 14 | # do not overwrite the time coef 15 | state["model"]["timestamps"] = self.model.timestamps 16 | state["model"]["diffusion"] = self.model.diffusion 17 | state["model"]["condition"] = self.model.condition 18 | 
state["model"]["delta_t"] = self.model.delta_t 19 | super()._impl_load_ckpt(state) 20 | 21 | def train_step(self, feed_dict): 22 | retry_call( 23 | super().train_step, fargs=[feed_dict], tries=3, exceptions=LossException 24 | ) 25 | -------------------------------------------------------------------------------- /viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsh-zh/DiffFlow/c45af9dad20bb63da46c0ed9209a6b168eea2430/viz/__init__.py -------------------------------------------------------------------------------- /viz/img.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import imageio 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from jammy import stmap 9 | from jammy.image import imwrite, nd2pil, plt2nd 10 | from jammy.logging import Wandb 11 | from torchvision.utils import make_grid, save_image 12 | 13 | from datasets import inverse_data_transform 14 | 15 | plt.switch_backend("agg") 16 | 17 | 18 | def tensor2imgnd(tensor, n_rows, n_cols): # pylint: disable=unused-argument 19 | grid = make_grid(tensor, n_rows) 20 | ndarr = ( 21 | grid.mul(255) 22 | .add_(0.5) 23 | .clamp_(0, 255) 24 | .permute(1, 2, 0) 25 | .to("cpu", torch.uint8) 26 | .numpy() 27 | ) 28 | return ndarr 29 | 30 | 31 | def kv_img2gif( 32 | kv_tensor_imgs, fname, img_row, img_col, keys 33 | ): # pylint: disable=too-many-locals 34 | save_imgs = [] 35 | input_imgs = {key: torch.stack(kv_tensor_imgs[key]) for key in keys} 36 | length = min([input_imgs[key].shape[0] for key in keys]) 37 | num_keys = len(keys) 38 | dpi = 128 39 | img_size = 400 40 | for i in range(length): 41 | fig, axs = plt.subplots( 42 | 1, num_keys, figsize=(num_keys * img_size / dpi, img_size / dpi), dpi=dpi 43 | ) 44 | for j in range(num_keys): 45 | cur_img = input_imgs[keys[j]][i] # j-th key, i-th image 46 | img2show = tensor2imgnd(cur_img, img_row, img_col) 47 | axs[j].imshow(img2show) 48 | axs[j].axes.xaxis.set_visible(False) 49 | axs[j].axes.yaxis.set_visible(False) 50 | 51 | fig.suptitle(f"t={i:03d} {' '.join(keys)}") 52 | save_imgs.append(np.asarray(plt2nd(fig))) 53 | plt.close() 54 | imageio.mimsave(f"{fname}.gif", save_imgs + ([save_imgs[-1]] * 5), fps=1) 55 | 56 | 57 | def viz_img_process(procss_kv, gif_file, num_grid, keys, reverse_transform_fn): 58 | imgs = stmap(reverse_transform_fn, procss_kv) 59 | kv_img2gif(imgs, gif_file, num_grid, num_grid, list(keys)) 60 | 61 | 62 | def wandb_write_ndimg(img, epoch_cnt, naming): 63 | if Wandb.IS_ACTIVE: 64 | wandb.log( 65 | {naming: wandb.Image(nd2pil(img), caption=f"{naming}_{epoch_cnt:05}.png")} 66 | ) 67 | imwrite(f"{naming}_{epoch_cnt:03}.png", img) 68 | 69 | 70 | def save_seperate_imgs(sample, sample_path, cnt): 71 | batch_size = len(sample) 72 | for i in range(batch_size): 73 | save_image(sample[i], osp.join(sample_path, f"{cnt:07d}.png")) 74 | cnt += 1 75 | 76 | 77 | def check_unnormal_imgs(cfg, x, num_grid, num_iter, fname): 78 | trans_x = inverse_data_transform(cfg.data, x) 79 | noise_nd = tensor2imgnd(trans_x, num_grid, num_grid) 80 | wandb_write_ndimg(noise_nd, num_iter, fname) 81 | -------------------------------------------------------------------------------- /viz/lines.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | from jammy.logging import wandb_plt 4 | from jamtorch.utils import as_numpy 5 | 6 | 7 
7 | @wandb_plt
8 | def draw_line(data, title, caption=None):
9 |     fig, axs = plt.subplots(1, 1)
10 |     axs.plot(data)
11 |     fig.suptitle(title)
12 |     fig.savefig(f"{title}.png")
13 |     if caption is None:
14 |         caption = title
15 |     return fig, caption  # return the resolved caption instead of dropping it
16 | 
17 | 
18 | def check_dflow_coef(model, prefix_caption=None):
19 |     if hasattr(model, "module"):  # unwrap DDP-wrapped models
20 |         model = model.module
21 |     for name in ["timestamps", "diffusion", "condition", "delta_t"]:
22 |         if hasattr(model, name):
23 |             if prefix_caption is None:
24 |                 caption = name
25 |             else:
26 |                 caption = f"{name}_{prefix_caption}"
27 |             draw_line(as_numpy(getattr(model, name)), name, caption)
28 | 
29 | 
30 | def plt_scalars(scalars, names):
31 |     """Plot each scalar sequence in its own subplot.
32 | 
33 |     :param scalars: scalar sequences to plot, one per subplot
34 |     :type scalars: List[Tensor,ndarray]
35 |     :param names: names of scalars, used as subplot titles
36 |     :type names: List[string]
37 |     :return: matplotlib figure
38 |     """
39 |     length = len(names)
40 |     if isinstance(scalars[0], torch.Tensor):
41 |         scalars = as_numpy(scalars)
42 |     fig, axs = plt.subplots(1, length, figsize=(length * 7, 1 * 7))
43 |     for i_th, cur_data in enumerate(scalars):
44 |         axs[i_th].plot(cur_data)
45 |         axs[i_th].set_title(names[i_th])
46 |     return fig
47 | 
48 | 
49 | def plt_model_scalars(model):
50 |     if isinstance(model, dict):  # checkpoint-style dict; implicitly returns None otherwise
51 |         timestamps = model["timestamps"].cpu().numpy()
52 |         diffusion = model["diffusion"].cpu().numpy()
53 |         condition = model["condition"].cpu().numpy()
54 |         delta_t = timestamps[1:] - timestamps[:-1]
55 |         scalars = [timestamps, diffusion, condition, delta_t]
56 |         labels = ["timestamps", "diffusion", "condition", "delta_t"]
57 |         return plt_scalars(scalars, labels)
58 | 
--------------------------------------------------------------------------------
/viz/ps.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from jammy.io import get_name
4 | from jammy.logging import wandb_plt
5 | 
6 | import datasets.points_dataset as ps_dataset
7 | 
8 | 
9 | # pylint: disable=no-member
10 | def fix_ax_lim(ax):
11 |     ax.set_xlim(ps_dataset.DIM_LINSPACE[0], ps_dataset.DIM_LINSPACE[-1])
12 |     ax.set_ylim(ps_dataset.DIM_LINSPACE[0], ps_dataset.DIM_LINSPACE[-1])
13 | 
14 | 
15 | @wandb_plt
16 | def viz_sample(sample, title_name, fig_name, sample_num=50000, fix_lim=True):
17 |     sample = ps_dataset.restore(sample)
18 |     fig, ax = plt.subplots(1, 2, figsize=(14, 7))
19 |     ax[0].set_title(title_name)
20 |     ax[0].hist2d(
21 |         sample[:sample_num, 0],
22 |         sample[:sample_num, 1],
23 |         bins=ps_dataset.DIM_LINSPACE,
24 |         cmap=plt.cm.jet,
25 |     )
26 |     ax[0].set_facecolor(plt.cm.jet(0.0))
27 |     ax[1].plot(
28 |         sample[:sample_num, 0],
29 |         sample[:sample_num, 1],
30 |         linewidth=0,
31 |         marker=".",
32 |         markersize=1,
33 |     )
34 |     if fix_lim:
35 |         fix_ax_lim(ax[1])
36 |     fig.suptitle(title_name)
37 |     fig.savefig(fig_name)
38 |     plt.axis("off")
39 |     return fig, get_name(fig_name)
40 | 
41 | 
42 | # @wandb_fig
43 | # def check_density(density, title_name, fig_name):
44 | #     global DIM_LINSPACE, G_MEAN, G_STD, G_SET_STD
45 | #     sample = sample / G_SET_STD * G_STD + G_MEAN
46 | #     fig, ax = plt.subplots(1, 1, figsize=(7, 7))
47 | #     ax.set_title(title_name)
48 | #     x = DIM_LINSPACE
49 | #     yy, xx = np.meshgrid(x, x)
50 | #     ax.pcolor(xx, yy, density.reshape([yy.shape[0], yy.shape[1]]))
51 | #     fig.suptitle(title_name)
52 | #     fig.savefig(fig_name)
53 | #     return fig, get_name(fig_name)
54 | 
55 | 
56 | # def plot_sample(sample, title_name, fig_name):
57 | #     sample = ps_dataset.restore(sample)
58 | #     fig, ax = plt.subplots(1, 1, figsize=(7, 7))
59 | #     ax.hist2d(
60 | #         sample[:, 0],
61 | #         sample[:, 1],
62 | #         bins=ps_dataset.DIM_LINSPACE,
63 | #         cmap=plt.cm.jet,
64 | #     )
65 | #     ax.get_xaxis().set_ticks([])
66 | #     ax.get_yaxis().set_ticks([])
67 | #     fix_ax_lim(ax)
68 | #     ax.set_facecolor(plt.cm.jet(0.0))
69 | #     plt.savefig(fig_name)
70 | #     plt.close()
71 | 
72 | 
73 | # def plot_white_sample(sample, title_name, fig_name, sample_num=10000):
74 | #     sample = ps_dataset.restore(sample)
75 | #     fig, ax = plt.subplots(1, 1, figsize=(7, 7))
76 | #     ax.plot(
77 | #         sample[:sample_num, 0],
78 | #         sample[:sample_num, 1],
79 | #         linewidth=0,
80 | #         marker=".",
81 | #         markersize=1,
82 | #         alpha=0.5,
83 | #     )
84 | #     ax.get_xaxis().set_ticks([])
85 | #     ax.get_yaxis().set_ticks([])
86 | #     fix_ax_lim(ax)
87 | #     plt.savefig(fig_name)
88 | #     plt.close()
89 | 
90 | 
91 | def seqSample2img(list_x, n):
92 |     length = len(list_x)
93 |     idxes = np.linspace(0, length - 1, n, dtype=int)
94 |     with plt.style.context("img"):
95 |         fig, axs = plt.subplots(1, n, figsize=(3 * n, 3))
96 |         for i_th, idx in enumerate(idxes):
97 |             data = list_x[idx].cpu().numpy()
98 |             axs[i_th].plot(
99 |                 data[:, 0], data[:, 1], linewidth=0, marker=".", markersize=1, alpha=0.5
100 |             )
101 |             axs[i_th].set_xlim(-2, 2)
102 |             axs[i_th].set_ylim(-2, 2)
103 |             axs[i_th].set_title(idx)
104 |     return fig
105 | 
--------------------------------------------------------------------------------
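For orientation, here is a minimal sketch of how `SdeF` from `utils/sdefunction.py` might be driven: the forward pass maps data `x` to a latent endpoint `z` while accumulating the log-determinant, so a likelihood-style loss only needs a base density on `z`. Only the `SdeF.apply` signature and the model attributes (`timestamps`, `diffusion`, `condition`, as returned by `scalar_helper`) come from the code above; the standard-normal base density and the `nll_loss` name are assumptions for illustration, not part of this repo.

import torch

from utils.sdefunction import SdeF


def nll_loss(model, x):
    # Map data to the latent endpoint while accumulating log|det| of the flow.
    z, logabsdet = SdeF.apply(
        x,
        model,
        model.timestamps,
        model.diffusion,
        model.condition,
        *model.parameters(),
    )
    # log p(x) = log p(z) + log|det|; the standard-normal base density below
    # is an assumption for illustration, not taken from the repo.
    base_logp = -0.5 * torch.sum(z.flatten(start_dim=1) ** 2, dim=1)
    return -(base_logp + logabsdet).mean()

Passing `*model.parameters()` explicitly is what lets the custom `backward` return per-parameter gradients via `_unflatten`, so autograd never has to retain the graph of the whole forward trajectory.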