├── arts
│   ├── model.jpg
│   ├── mismatch.jpg
│   └── multiple_tasks.jpg
├── constants.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── srgan.py
│   ├── pix2pix.py
│   └── glcic.py
├── requirements.txt
├── configs
│   ├── srgan-celeba_64x64-gaussian-mr1.yaml
│   ├── srgan-celeba_64x64-gaussian-mr2.yaml
│   ├── srgan-celeba_64x64-laplace-mr1.yaml
│   ├── srgan-celeba_64x64-laplace-mr2.yaml
│   ├── srgan-celeba_64x64-laplace-pmr1.yaml
│   ├── srgan-celeba_64x64-gaussian-pmr1.yaml
│   ├── srgan-celeba_64x64-gaussian-pmr2.yaml
│   ├── srgan-celeba_64x64-laplace-pmr2.yaml
│   ├── glcic-celeba_128x128-gaussian-pmr2.yaml
│   ├── pix2pix-cityscapes_256x256-gaussian-mr1.yaml
│   ├── pix2pix-cityscapes_256x256-gaussian-mr2.yaml
│   ├── pix2pix-cityscapes_256x256-gaussian-pmr1.yaml
│   ├── pix2pix-cityscapes_256x256-gaussian-pmr2.yaml
│   ├── pix2pix-cityscapes_256x256-laplace-mr1.yaml
│   ├── pix2pix-cityscapes_256x256-laplace-mr2.yaml
│   ├── pix2pix-cityscapes_256x256-laplace-pmr1.yaml
│   ├── pix2pix-cityscapes_256x256-laplace-pmr2.yaml
│   ├── pix2pix-maps_256x256-laplace-mr1.yaml
│   ├── pix2pix-maps_256x256-laplace-mr2.yaml
│   ├── pix2pix-maps_256x256-laplace-pmr1.yaml
│   ├── pix2pix-maps_256x256-gaussian-mr1.yaml
│   ├── pix2pix-maps_256x256-gaussian-mr2.yaml
│   ├── pix2pix-maps_256x256-gaussian-pmr1.yaml
│   ├── pix2pix-maps_256x256-laplace-pmr2.yaml
│   └── pix2pix-maps_256x256-gaussian-pmr2.yaml
├── scripts
│   ├── preprocess_celeba.py
│   └── preprocess_pix2pix_data.py
├── utils.py
├── .gitignore
├── main.py
├── README.md
├── losses.py
├── LICENSE
├── train.py
└── data.py

/arts/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soochan-lee/MR-GAN/HEAD/arts/model.jpg
--------------------------------------------------------------------------------
/arts/mismatch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soochan-lee/MR-GAN/HEAD/arts/mismatch.jpg
--------------------------------------------------------------------------------
/arts/multiple_tasks.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soochan-lee/MR-GAN/HEAD/arts/multiple_tasks.jpg
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | MODE_BASE = 'base'
2 | MODE_PRED = 'pred'
3 | MODE_MR = 'mr'
4 | 
5 | 
6 | MODES = (
7 |     MODE_BASE,
8 |     MODE_PRED,
9 |     MODE_MR,
10 | )
11 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .pix2pix import Pix2Pix
2 | from .glcic import GLCIC
3 | from .srgan import SRGAN
4 | 
5 | 
6 | MODELS = {
7 |     Pix2Pix.name: Pix2Pix,
8 |     GLCIC.name: GLCIC,
9 |     SRGAN.name: SRGAN,
10 | }
11 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboardX==1.2
2 | torch==0.4.1
3 | torchvision==0.2.1
4 | pyyaml
5 | tensorflow
6 | ipdb
7 | jupyter
8 | jupyterlab
9 | matplotlib
10 | tqdm
11 | opencv-python
12 | scipy
13 | scikit-image
14 | 
--------------------------------------------------------------------------------
/configs/srgan-celeba_64x64-gaussian-mr1.yaml:
--------------------------------------------------------------------------------
1 | mode: null
2 | log_dir: null
3 | 
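# `mode` and `log_dir` are null placeholders in every config; they are
# presumably filled in at launch time rather than edited here. A purely
# hypothetical invocation (flag names are an assumption, not taken from
# main.py):
#   python main.py --config configs/srgan-celeba_64x64-gaussian-mr1.yaml \
#       --mode mr --log_dir runs/srgan-gaussian-mr1
# Valid mode values come from constants.py: base, pred, mr.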
data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: gaussian 19 | options: 20 | order: 1 21 | min_noise: 0.0001 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 0. 32 | mr_2nd_weight: 0. 33 | gan_weight: 1.0 34 | mle_weight: 20.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 24 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-gaussian-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: gaussian 19 | options: 20 | order: 2 21 | min_noise: 0.0001 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 0. 32 | mr_2nd_weight: 0. 33 | gan_weight: 1.0 34 | mle_weight: 20.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 24 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 
92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-laplace-mr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: laplace 19 | options: 20 | order: 1 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 0. 32 | mr_2nd_weight: 0. 33 | gan_weight: 1.0 34 | mle_weight: 30.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 24 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-laplace-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: laplace 19 | options: 20 | order: 2 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 0. 32 | mr_2nd_weight: 0. 
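      # Note: in the "mr1"/"mr2" configs both mr_*_weight values are 0; moment
      # matching is instead carried by the MLE term, whose `mle.options.order`
      # (1 or 2) selects first- or second-moment modeling. This reading is
      # inferred from the config layout, not verified against train.py.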
33 | gan_weight: 1.0 34 | mle_weight: 30.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 24 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-laplace-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: laplace 19 | options: 20 | order: 2 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 2400. 32 | mr_2nd_weight: 0. 33 | gan_weight: 1.0 34 | mle_weight: 0.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 12 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 
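# log_dispersion_{min,max} presumably clamp the predicted log-dispersion
# (log-variance for gaussian, log-scale for laplace) to [-6, 0], i.e. a
# dispersion between exp(-6) ~ 0.0025 and exp(0) = 1.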
93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-gaussian-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: gaussian 19 | options: 20 | order: 2 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 2400. 32 | mr_2nd_weight: 0. 33 | gan_weight: 1.0 34 | mle_weight: 0.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 12 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-gaussian-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: gaussian 19 | options: 20 | order: 2 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 2400. 32 | mr_2nd_weight: 2400. 
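      # The "pmr" configs flip the balance: the (presumably proxy) moment
      # reconstruction terms carry the loss and mle_weight drops to 0.0;
      # pmr1 enables only the first moment, pmr2 both.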
33 | gan_weight: 1.0 34 | mle_weight: 0.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 12 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 93 | 94 | summary: 95 | train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 96 | val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] 97 | -------------------------------------------------------------------------------- /configs/srgan-celeba_64x64-laplace-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: celeba 6 | root: ./data/celeba 7 | num_workers: 16 8 | size: 64 9 | 10 | 11 | model: 12 | name: srgan 13 | 14 | noise_type: gaussian 15 | noise_dim: [16, 16, 16, 16, 16, 16, 16] 16 | 17 | mle: 18 | type: laplace 19 | options: 20 | order: 2 21 | min_noise: 0.01 22 | gan: 23 | type: gan 24 | with_logits: true 25 | losses: 26 | base: 27 | gan_weight: 1.0 28 | recon_weight: 1000 29 | pred: 30 | mr: 31 | mr_1st_weight: 2400. 32 | mr_2nd_weight: 2400. 33 | gan_weight: 1.0 34 | mle_weight: 0.0 35 | 36 | 37 | batch_size: 32 38 | num_mr: 8 39 | num_mr_samples: 12 40 | 41 | d_updates_per_step: 1 42 | g_updates_per_step: 5 43 | 44 | 45 | # Optimizers 46 | d_optimizer: 47 | type: Adam 48 | options: 49 | lr: 0.0001 50 | betas: [0.5, 0.999] 51 | weight_decay: 0.0001 52 | amsgrad: True 53 | clip_grad: 54 | type: value 55 | options: 56 | clip_value: 0.5 57 | 58 | g_optimizer: 59 | type: Adam 60 | options: 61 | lr: 0.0001 62 | betas: [0.5, 0.999] 63 | weight_decay: 0.0001 64 | amsgrad: True 65 | clip_grad: 66 | type: value 67 | options: 68 | clip_value: 0.5 69 | 70 | p_optimizer: 71 | type: Adam 72 | options: 73 | lr: 0.0001 74 | betas: [0.5, 0.999] 75 | weight_decay: 0.0001 76 | amsgrad: True 77 | clip_grad: 78 | type: value 79 | options: 80 | clip_value: 0.5 81 | 82 | # Learning rate schedulers 83 | d_lr_scheduler: 84 | g_lr_scheduler: 85 | p_lr_scheduler: 86 | e_lr_scheduler: 87 | 88 | ckpt_step: 5000 89 | summary_step: 500 90 | 91 | log_dispersion_min: -6. 92 | log_dispersion_max: 0. 
93 | 
94 | summary:
95 |   train_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
96 |   val_samples: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
97 | 
--------------------------------------------------------------------------------
/scripts/preprocess_celeba.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os.path
4 | import torch
5 | from torchvision.transforms import functional as F
6 | from PIL import Image
7 | from tqdm import tqdm
8 | import random
9 | 
10 | 
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--src', required=True)
13 | parser.add_argument('--dst', required=True)
14 | parser.add_argument('--size', required=True, type=int)
15 | parser.add_argument('--val-ratio', default=0.1, type=float)  # without type=float, a CLI-supplied value stays a str and breaks the comparison below
16 | 
17 | 
18 | def main():
19 |     args = parser.parse_args()
20 |     size = args.size
21 |     print('size: {}'.format(size))
22 | 
23 |     src_dir = args.src
24 |     dst_dir = args.dst
25 | 
26 |     items = [
27 |         os.path.join(src_dir, name)
28 |         for name in os.listdir(src_dir)
29 |     ]
30 | 
31 |     train_index = 0
32 |     val_index = 0
33 |     train_dir = os.path.join(dst_dir, 'train')
34 |     val_dir = os.path.join(dst_dir, 'val')
35 |     os.makedirs(train_dir, mode=0o755, exist_ok=False)
36 |     os.makedirs(val_dir, mode=0o755, exist_ok=False)
37 | 
38 |     for item in tqdm(items):
39 |         raw_image = Image.open(item)
40 | 
41 |         # Center-crop the height to a square (CelebA images are taller than wide)
42 |         raw_image = F.to_tensor(raw_image)
43 |         h = raw_image.size(1)
44 |         w = raw_image.size(2)
45 |         y_offset = (h - w) // 2
46 |         square = raw_image[:, y_offset:y_offset + w, :]
47 |         pil_square = F.to_pil_image(square)
48 | 
49 |         # Resize to the high-res target, then 4x-downsample for the low-res input
50 |         pil_hi = F.resize(pil_square, (size, size))
51 |         pil_lo = F.resize(pil_hi, (size // 4, size // 4))
52 | 
53 |         # Pad the low-res image to the same height, then concatenate side by side
54 |         hi = F.to_tensor(pil_hi)
55 |         lo = F.to_tensor(pil_lo)
56 |         pad_size = hi.size(1) - lo.size(1)
57 |         pad = torch.zeros(lo.size(0), pad_size, lo.size(2))
58 |         lo_padded = torch.cat([lo, pad], 1)
59 |         result = torch.cat([hi, lo_padded], 2)
60 | 
61 |         # Save into train/ or val/ according to --val-ratio
62 |         result = F.to_pil_image(result)
63 |         if random.random() > args.val_ratio:
64 |             save_path = os.path.join(train_dir, '%06d.png' % train_index)
65 |             train_index += 1
66 |         else:
67 |             save_path = os.path.join(val_dir, '%06d.png' % val_index)
68 |             val_index += 1
69 | 
70 |         result.save(save_path)
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     main()
75 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from torch import nn
3 | import contextlib
4 | import time
5 | import itertools
6 | import yaml
7 | 
8 | 
9 | # =============
10 | # PyTorch Utils
11 | # =============
12 | 
13 | class Lambda(nn.Module):
14 |     def __init__(self, f=None):
15 |         super().__init__()
16 |         self.f = f if f is not None else (lambda x: x)
17 | 
18 |     def forward(self, *args, **kwargs):
19 |         return self.f(*args, **kwargs)
20 | 
21 | 
22 | # ============
23 | # Config Utils
24 | # ============
25 | 
26 | def load_config(path, as_namedtuple=True):
27 |     config = yaml.safe_load(open(path)) or {}  # safe_load: yaml.load without an explicit Loader is deprecated and unsafe
28 |     return dict_to_namedtuple(config) if as_namedtuple else config
29 | 
30 | 
31 | def dict_to_namedtuple(d):
32 |     if isinstance(d, dict):
33 |         for k, v in d.items():
34 |             d[k] = dict_to_namedtuple(v)
35 |         return namedtuple('d', d.keys())(**d)
36 |     return d
37 | 
38 | 
39 | def namedtuple_to_dict(n):
40 |     def _isnamedtuple(x):
41 |         cls = type(x)
42 |         bases = cls.__bases__
43 |         fields = getattr(cls, '_fields', None)
44 | 
45 |         if len(bases) != 1 or bases[0] != tuple:
46 |             return False
47 | 
48 |         if not isinstance(fields, tuple):
49 |             return False
50 | 
51 |         return all(type(name) == str for name in fields)
52 | 
53 |     d = dict(n._asdict())
54 |     for k, v in d.items():
55 |         if _isnamedtuple(v):
56 |             d[k] = namedtuple_to_dict(v)
57 |     return d
58 | 
59 | 
60 | # ============
61 | # Python Utils
62 | # ============
63 | 
64 | @contextlib.contextmanager
65 | def time_logging_context(start_message, ending_message='done'):
66 |     start = time.perf_counter()  # time.clock() was deprecated and removed in Python 3.8
67 |     print(start_message, end=' ', flush=True)
68 |     yield
69 |     took = time.perf_counter() - start
70 |     print(ending_message + ' (took {:.03f} seconds)'.format(took))
71 | 
72 | 
73 | def ncycle(iterable, n):
74 |     return itertools.chain.from_iterable(itertools.repeat(tuple(iterable), n))
75 | 
76 | 
77 | def updated_nt(nt, path, value):
78 |     try:
79 |         attr_name, attr_names_left = path.split('.', 1)  # ValueError when no '.' remains in path
80 |         attr = getattr(nt, attr_name)
81 |         return _updated_nt(nt, attr_name, updated_nt(
82 |             attr, attr_names_left, value
83 |         ))
84 |     except ValueError:
85 |         return _updated_nt(nt, path, value)
86 | 
87 | 
88 | def _updated_nt(nt, name, value):
89 |     mapping = nt._asdict()
90 |     mapping[name] = value
91 |     return dict_to_namedtuple(mapping)
92 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/vim,linux,python
2 | 
3 | ### Linux ###
4 | *~
5 | 
6 | # temporary files which can be created if a process still has a handle open of a deleted file
7 | .fuse_hidden*
8 | 
9 | # KDE directory preferences
10 | .directory
11 | 
12 | # Linux trash folder which might appear on any partition or disk
13 | .Trash-*
14 | 
15 | # .nfs files are created when an open file is removed but is still being accessed
16 | .nfs*
17 | 
18 | ### Python ###
19 | # Byte-compiled / optimized / DLL files
20 | __pycache__/
21 | *.py[cod]
22 | *$py.class
23 | 
24 | # C extensions
25 | *.so
26 | 
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 | 
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | 124 | ### Python Patch ### 125 | .venv/ 126 | 127 | ### Vim ### 128 | # Swap 129 | [._]*.s[a-v][a-z] 130 | [._]*.sw[a-p] 131 | [._]s[a-rt-v][a-z] 132 | [._]ss[a-gi-z] 133 | [._]sw[a-p] 134 | 135 | # Session 136 | Session.vim 137 | 138 | # Temporary 139 | .netrwhist 140 | # Auto-generated tag files 141 | tags 142 | # Persistent undo 143 | [._]*.un~ 144 | 145 | 146 | # End of https://www.gitignore.io/api/vim,linux,python 147 | # 148 | data 149 | runs 150 | .idea 151 | __mnist 152 | __cifar10 153 | log 154 | generated_data 155 | samples 156 | samples-* 157 | -------------------------------------------------------------------------------- /configs/glcic-celeba_128x128-gaussian-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | 5 | data: 6 | name: celeba 7 | root: ./data/celeba/128x128 8 | category: null 9 | num_workers: 16 10 | size: 128 11 | random_crop: 128 12 | # random mask 13 | local_size: 64 14 | mask_size: [[48, 64], [48, 64]] 15 | # mask overlap 16 | overlap: 0 17 | overlap_weight: 0. 18 | scale: sigmoid 19 | norm: null 20 | real_jitter: 0.0 21 | 22 | 23 | model: 24 | name: glcic 25 | 26 | # Architecture 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | 30 | # Loss 31 | mle: 32 | type: gaussian 33 | options: 34 | order: 2 35 | min_noise: 0.01 36 | gan: 37 | type: gan 38 | with_logits: true 39 | losses: 40 | base: 41 | recon_weight: 100 42 | gan_weight: 1.0 43 | pred: 44 | mr: 45 | mr_1st_weight: 1000. 46 | mr_2nd_weight: 1000. 
47 | gan_weight: 1.0 48 | mle_weight: 0.0 49 | 50 | 51 | g_pretrain_step: 0 52 | d_pretrain_step: 0 53 | 54 | batch_size: 16 55 | num_mr: 8 56 | num_mr_samples: 12 57 | 58 | d_updates_per_step: 1 59 | g_updates_per_step: 3 60 | 61 | 62 | # Optimizers 63 | d_optimizer: 64 | type: Adam 65 | options: 66 | lr: 0.0001 67 | betas: [0.5, 0.999] 68 | weight_decay: 0.0001 69 | amsgrad: True 70 | clip_grad: 71 | type: value 72 | options: 73 | clip_value: 0.5 74 | 75 | g_optimizer: 76 | type: Adam 77 | options: 78 | lr: 0.0001 79 | betas: [0.5, 0.999] 80 | weight_decay: 0.0001 81 | amsgrad: True 82 | clip_grad: 83 | type: value 84 | options: 85 | clip_value: 0.5 86 | 87 | p_optimizer: 88 | type: Adam 89 | options: 90 | lr: 0.0001 91 | betas: [0.5, 0.999] 92 | weight_decay: 0.0001 93 | amsgrad: True 94 | clip_grad: 95 | type: value 96 | options: 97 | clip_value: 0.5 98 | 99 | # Learning rate schedulers 100 | d_lr_scheduler: 101 | g_lr_scheduler: 102 | p_lr_scheduler: 103 | e_lr_scheduler: 104 | 105 | ckpt_step: 3000 106 | summary_step: 500 107 | 108 | log_dispersion_min: -6. 109 | log_dispersion_max: 0. 110 | 111 | summary: 112 | train_samples: [3825, 4957, 6457, 12365, 25817, 25942, 26465, 29138, 42834, 48282, 60148, 73585, 74907, 76942, 88628, 90125, 107117, 108936, 113474, 113614, 125387, 128753, 131210, 134668, 135283, 136666, 136918, 160845, 162794, 167819, 174655, 179805] 113 | val_samples: [1170, 1607, 1762, 3188, 3456, 4951, 6090, 6239, 6307, 7259, 7509, 7806, 7895, 8014, 9866, 10190, 12262, 12431, 13018, 14245, 15698, 17337, 17767, 17975, 18056, 18083, 18110, 18882, 19607, 19817, 19848, 19930] 114 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-gaussian-mr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: gaussian 35 | options: 36 | order: 1 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
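      # The base/pred/mr subsections mirror MODES in constants.py; only the
      # weight set matching the active runtime mode should apply. This mapping
      # is inferred from the config layout, not checked against train.py.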
44 | pred: 45 | mr: 46 | mr_1st_weight: 0.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 10.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-gaussian-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: gaussian 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 0.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 10.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-gaussian-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: gaussian 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 10.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 0.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-gaussian-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: gaussian 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 10.0 47 | mr_2nd_weight: 10.0 48 | gan_weight: 1.0 49 | mle_weight: 0.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-laplace-mr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: laplace 35 | options: 36 | order: 1 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 0.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 10.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-laplace-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: laplace 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 0.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 10.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-laplace-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: laplace 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 1.0 47 | mr_2nd_weight: 0.0 48 | gan_weight: 1.0 49 | mle_weight: 0.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-cityscapes_256x256-laplace-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: cityscapes 6 | root: ./data/cityscapes/256x256 7 | num_workers: 32 8 | height: 256 9 | width: 256 10 | norm: 11 | mean: [0.3946, 0.2806, 0.3847] 12 | std: [0.2302, 0.1883, 0.2214] 13 | 14 | # to control training data size 15 | data_offset: null 16 | data_size: null 17 | 18 | model: 19 | name: pix2pix 20 | 21 | # Architecture 22 | in_channels: 3 23 | out_channels: 3 24 | num_downs: 8 # 1/256 25 | num_features: 64 26 | pred_features: 64 27 | noise_type: gaussian 28 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 29 | predictor: true 30 | norm: batch 31 | 32 | # Loss configurations 33 | mle: 34 | type: laplace 35 | options: 36 | order: 2 37 | gan: 38 | type: gan 39 | with_logits: true 40 | losses: 41 | base: 42 | recon_weight: 100. 43 | gan_weight: 1. 
44 | pred: 45 | mr: 46 | mr_1st_weight: 1.0 47 | mr_2nd_weight: 1.0 48 | gan_weight: 1.0 49 | mle_weight: 0.0 50 | recon_weight: 0.0 # this converts optimize_g to BASE mode 51 | 52 | batch_size: 16 53 | num_mr: 8 54 | num_mr_samples: 10 55 | 56 | d_updates_per_step: 1 57 | g_updates_per_step: 1 58 | 59 | # Optimizers 60 | d_optimizer: 61 | type: Adam 62 | options: 63 | lr: 0.0001 64 | betas: [0.5, 0.999] 65 | weight_decay: 0.0001 66 | amsgrad: True 67 | clip_grad: 68 | type: value 69 | options: 70 | clip_value: 0.5 71 | 72 | g_optimizer: 73 | type: Adam 74 | options: 75 | lr: 0.0001 76 | betas: [0.5, 0.999] 77 | weight_decay: 0.0001 78 | amsgrad: True 79 | clip_grad: 80 | type: value 81 | options: 82 | clip_value: 0.5 83 | 84 | p_optimizer: 85 | type: Adam 86 | options: 87 | lr: 0.0001 88 | betas: [0.5, 0.999] 89 | weight_decay: 0.0001 90 | amsgrad: True 91 | clip_grad: 92 | type: value 93 | options: 94 | clip_value: 0.5 95 | 96 | e_optimizer: 97 | type: Adam 98 | options: 99 | lr: 0.0001 100 | betas: [0.5, 0.999] 101 | weight_decay: 0.0001 102 | amsgrad: True 103 | clip_grad: 104 | type: value 105 | options: 106 | clip_value: 0.5 107 | 108 | # Learning rate schedulers 109 | d_lr_scheduler: 110 | g_lr_scheduler: 111 | p_lr_scheduler: 112 | e_lr_scheduler: 113 | 114 | ckpt_step: 5000 115 | summary_step: 1000 116 | 117 | log_dispersion_min: -6. 118 | log_dispersion_max: 0. 119 | 120 | summary: 121 | train_samples: [466, 508, 788, 854, 1028, 1653, 1857, 2036, 2040, 2088, 2493, 2546, 2551, 2644, 2786, 2801] 122 | val_samples: [45, 114, 119, 160, 176, 206, 225, 367, 369, 398, 409, 422, 430, 431, 440, 458] 123 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-laplace-mr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: laplace 36 | options: 37 | order: 1 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
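      # Across the maps configs the mle_weight below is tuned per noise model
      # (10.0 here for laplace/order 1, 1.0 for laplace/order 2, 100.0 for
      # gaussian/order 1) rather than held fixed.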
45 | pred: 46 | mr: 47 | mr_1st_weight: 0.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 10.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-laplace-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: laplace 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 0.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 1.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-laplace-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: laplace 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 10.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 0.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-gaussian-mr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: gaussian 36 | options: 37 | order: 1 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 0.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 100.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-gaussian-mr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: gaussian 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 0.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 10.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-gaussian-pmr1.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: gaussian 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 300.0 48 | mr_2nd_weight: 0.0 49 | gan_weight: 1.0 50 | mle_weight: 0.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-laplace-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: laplace 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 10.0 48 | mr_2nd_weight: 10.0 49 | gan_weight: 1.0 50 | mle_weight: 0.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /configs/pix2pix-maps_256x256-gaussian-pmr2.yaml: -------------------------------------------------------------------------------- 1 | mode: null 2 | log_dir: null 3 | 4 | data: 5 | name: maps 6 | root: ./data/maps/512x512 7 | num_workers: 32 8 | height: 512 9 | width: 512 10 | train_height: 256 11 | train_width: 256 12 | resize: False 13 | random_flip: False 14 | random_rotate: False 15 | norm: 16 | mean: [0.8767, 0.8899, 0.8668] 17 | std: [0.0926, 0.0601, 0.1292] 18 | 19 | model: 20 | name: pix2pix 21 | 22 | # Architecture 23 | in_channels: 3 24 | out_channels: 3 25 | num_downs: 8 # 1/256 26 | num_features: 64 27 | pred_features: 64 28 | noise_type: gaussian 29 | noise_dim: [0, 0, 0, 0, 0, 0, 32, 32, 32] 30 | predictor: true 31 | norm: batch 32 | 33 | # Loss configurations 34 | mle: 35 | type: gaussian 36 | options: 37 | order: 2 38 | gan: 39 | type: gan 40 | with_logits: true 41 | losses: 42 | base: 43 | recon_weight: 100. 44 | gan_weight: 1. 
45 | pred: 46 | mr: 47 | mr_1st_weight: 1000.0 48 | mr_2nd_weight: 1000.0 49 | gan_weight: 1.0 50 | mle_weight: 0.0 51 | recon_weight: 0.0 # this converts optimize_g to BASE mode 52 | 53 | batch_size: 16 54 | num_mr: 8 55 | num_mr_samples: 10 56 | 57 | d_updates_per_step: 1 58 | g_updates_per_step: 1 59 | 60 | # Optimizers 61 | d_optimizer: 62 | type: Adam 63 | options: 64 | lr: 0.0001 65 | betas: [0.5, 0.999] 66 | weight_decay: 0.0001 67 | amsgrad: True 68 | clip_grad: 69 | type: value 70 | options: 71 | clip_value: 0.5 72 | 73 | g_optimizer: 74 | type: Adam 75 | options: 76 | lr: 0.0001 77 | betas: [0.5, 0.999] 78 | weight_decay: 0.0001 79 | amsgrad: True 80 | clip_grad: 81 | type: value 82 | options: 83 | clip_value: 0.5 84 | 85 | p_optimizer: 86 | type: Adam 87 | options: 88 | lr: 0.0001 89 | betas: [0.5, 0.999] 90 | weight_decay: 0.0001 91 | amsgrad: True 92 | clip_grad: 93 | type: value 94 | options: 95 | clip_value: 0.5 96 | 97 | e_optimizer: 98 | type: Adam 99 | options: 100 | lr: 0.0001 101 | betas: [0.5, 0.999] 102 | weight_decay: 0.0001 103 | amsgrad: True 104 | clip_grad: 105 | type: value 106 | options: 107 | clip_value: 0.5 108 | 109 | # Learning rate schedulers 110 | d_lr_scheduler: 111 | g_lr_scheduler: 112 | p_lr_scheduler: 113 | e_lr_scheduler: 114 | 115 | ckpt_step: 5000 116 | summary_step: 1000 117 | 118 | log_dispersion_min: -6. 119 | log_dispersion_max: 0. 120 | 121 | summary: 122 | train_samples: [97, 140, 189, 251, 308, 319, 379, 418, 552, 733, 809, 879, 925, 954, 1028, 1080] 123 | val_samples: [99, 150, 452, 469, 511, 535, 581, 608, 614, 661, 722, 925, 931, 951, 998, 1051] 124 | -------------------------------------------------------------------------------- /scripts/preprocess_pix2pix_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | import torch 4 | from torchvision.transforms import functional as F 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--src', required=True) 11 | parser.add_argument('--dst', required=True) 12 | parser.add_argument('--size', required=True) 13 | parser.add_argument('--random-flip', action='store_true') 14 | parser.add_argument('--random-rotate', action='store_true') 15 | 16 | 17 | def main(): 18 | args = parser.parse_args() 19 | size = int(args.size) 20 | print('size: {} | random flip: {} | random rotate: {}' 21 | .format(size, args.random_flip, args.random_rotate)) 22 | 23 | for mode in ['train', 'val']: 24 | src_dir = os.path.join(args.src, mode) 25 | dst_dir = os.path.join(args.dst, mode) 26 | 27 | items = [] 28 | for path, dir, files in os.walk(src_dir): 29 | for file in files: 30 | items.append(os.path.join(path, file)) 31 | 32 | index = 0 33 | print('Processing %s...' 
% mode) 34 | os.makedirs(dst_dir, mode=0o755, exist_ok=True) 35 | for item in tqdm(items): 36 | results = [] 37 | raw_image = Image.open(item) 38 | 39 | # Split 40 | raw_image = F.to_tensor(raw_image) 41 | image = raw_image[:, :, :raw_image.size(2)//2] 42 | label = raw_image[:, :, raw_image.size(2)//2:] 43 | pil_image = F.to_pil_image(image) 44 | pil_label = F.to_pil_image(label) 45 | 46 | # Resize 47 | pil_image = F.resize(pil_image, (size, size)) 48 | pil_label = F.resize(pil_label, (size, size)) 49 | results.append((pil_image, pil_label)) 50 | 51 | # Rotation (x4) 52 | if args.random_rotate and mode == 'train': 53 | for degree in [90, 180, 270]: 54 | aug_image = F.rotate(pil_image, degree) 55 | aug_label = F.rotate(pil_label, degree) 56 | results.append((aug_image, aug_label)) 57 | 58 | # Flip and rotate (x4) 59 | if args.random_flip and mode == 'train': 60 | pil_image = F.hflip(pil_image) 61 | pil_label = F.hflip(pil_label) 62 | results.append((pil_image, pil_label)) 63 | 64 | if args.random_rotate: 65 | for degree in [90, 180, 270]: 66 | aug_image = F.rotate(pil_image, degree) 67 | aug_label = F.rotate(pil_label, degree) 68 | results.append((aug_image, aug_label)) 69 | 70 | # Save results 71 | for image, label in results: 72 | merged = torch.cat([ 73 | F.to_tensor(image), F.to_tensor(label) 74 | ], 2) 75 | merged = F.to_pil_image(merged) 76 | merged.save(os.path.join(dst_dir, '%d.jpg' % index)) 77 | index += 1 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | from data import DATASETS 5 | from models import MODELS 6 | from train import train 7 | from utils import dict_to_namedtuple 8 | from constants import MODES, MODE_BASE, MODE_MR 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--log-dir') 12 | parser.add_argument('--config') 13 | parser.add_argument('--mode', choices=MODES, default=MODE_BASE) 14 | parser.add_argument('--pred-ckpt') 15 | parser.add_argument('--resume-ckpt') 16 | parser.add_argument('--options', '-o', default='') 17 | 18 | if __name__ == '__main__': 19 | args = parser.parse_args() 20 | 21 | assert args.mode in MODES, 'Unknown mode %s' % args.mode 22 | if args.mode == MODE_MR: 23 | if not (args.pred_ckpt or args.resume_ckpt): 24 | print('WARNING: Proxy MR-GAN requires ' 25 | 'the checkpoint path of a predictor') 26 | 27 | # Load config 28 | config_path = args.config 29 | if args.resume_ckpt and not args.config: 30 | base_dir = os.path.dirname(os.path.dirname(args.resume_ckpt)) 31 | config_path = os.path.join(base_dir, 'config.yaml') 32 | config = yaml.safe_load(open(config_path)) 33 | 34 | # Override options 35 | config['mode'] = args.mode 36 | for option in args.options.split('|'): 37 | if not option: 38 | continue 39 | address, value = option.split('=') 40 | keys = address.split('.') 41 | here = config 42 | for key in keys[:-1]: 43 | if key not in here: 44 | raise ValueError('{} is not defined in config file. ' 45 | 'Failed to override.'.format(address)) 46 | here = here[key] 47 | if keys[-1] not in here: 48 | raise ValueError('{} is not defined in config file. 
' 49 | 'Failed to override.'.format(address)) 50 | 51 | here[keys[-1]] = yaml.safe_load(value) 52 | 53 | # Set log directory 54 | config['log_dir'] = args.log_dir 55 | if not args.resume_ckpt and args.log_dir and os.path.exists(args.log_dir): 56 | print('WARNING: %s already exists' % args.log_dir) 57 | input('Press enter to continue') 58 | 59 | if args.resume_ckpt and not args.log_dir: 60 | config['log_dir'] = os.path.dirname( 61 | os.path.dirname(args.resume_ckpt) 62 | ) 63 | 64 | # Save config 65 | os.makedirs(config['log_dir'], mode=0o755, exist_ok=True) 66 | if not args.resume_ckpt or args.config: 67 | config_save_path = os.path.join(config['log_dir'], 'config.yaml') 68 | yaml.dump(config, open(config_save_path, 'w')) 69 | print('Config file saved to {}'.format(config_save_path)) 70 | 71 | config = dict_to_namedtuple(config) 72 | 73 | # Instantiate dataset 74 | dataset_factory = DATASETS[config.data.name] 75 | train_dataset, val_dataset = dataset_factory(config) 76 | 77 | model = MODELS[config.model.name](config) 78 | model.cuda() 79 | 80 | if args.resume_ckpt: 81 | print('Resuming checkpoint %s' % args.resume_ckpt) 82 | step = model.load(args.resume_ckpt) 83 | else: 84 | step = 0 85 | if args.pred_ckpt: 86 | print('Loading predictor from %s' % args.pred_ckpt) 87 | model.load_module(model.net_p, args.pred_ckpt) 88 | 89 | train(model, config, train_dataset, val_dataset, step) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MR-GAN 2 | 3 | [**Project**][1] | [**Paper**][2] 4 | 5 | Official PyTorch implementation of ICLR 2019 paper: *Harmonizing Maximum Likelihood with GANs for Multimodal Conditional Generation.* 6 | 7 | Conditional GAN models are often optimized by the joint use of the GAN loss and reconstruction loss. We show that this training recipe, shared by almost all existing methods, is problematic and has a critical side effect: a lack of diversity in output samples. 8 | To achieve both training stability and multimodal output generation, we propose novel training schemes built on a new set of losses, named *moment reconstruction losses*, that simply replace the reconstruction loss. 9 | 10 | ![mismatch](./arts/mismatch.jpg) 11 | 12 | ![model](./arts/model.jpg) 13 | 14 | ![multiple tasks](./arts/multiple_tasks.jpg) 15 | 16 | 17 | ## Reference 18 | 19 | If you use this code or cite the paper, please refer to the following: 20 | ```bibtex 21 | @inproceedings{ 22 | lee2019harmonizing, 23 | title={Harmonizing Maximum Likelihood with {GAN}s for Multimodal Conditional Generation}, 24 | author={Soochan Lee and Junsoo Ha and Gunhee Kim}, 25 | booktitle={International Conference on Learning Representations}, 26 | year={2019}, 27 | url={https://openreview.net/forum?id=HJxyAjRcFX}, 28 | } 29 | ``` 30 | 31 | ## Requirements 32 | * Python >= 3.6 33 | * A GPU with CUDA >= 9.0 support and at least 10 GB of memory 34 | 35 | ## Installation 36 | ```shell 37 | $ pip install -r requirements.txt 38 | ``` 39 | 40 | ## Preprocessing 41 | ### Cityscapes 42 | We expect the original Cityscapes dataset to be located at `data/cityscapes/original`. Please refer to [Cityscapes Dataset](http://www.cityscapes-dataset.net/) and [mcordts/cityscapesScripts](https://github.com/mcordts/cityscapesScripts) for details. 
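Note that `preprocess_pix2pix_data.py` walks the `train` and `val` subdirectories of `--src`, so the source directory must contain both splits. For example: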
43 | ```bash 44 | $ python ./scripts/preprocess_pix2pix_data.py \ 45 | --src data/cityscapes/original/leftImg8bit \ 46 | --dst data/cityscapes/256x256 \ 47 | --size 256 \ 48 | --random-flip 49 | ``` 50 | 51 | ### Maps 52 | We expect the original Maps dataset to be located at `data/maps/original`. We recommend using the [dataset downloading script](https://github.com/junyanz/CycleGAN/blob/master/datasets/download_dataset.sh) from [junyanz/CycleGAN](https://github.com/junyanz/CycleGAN). 53 | ```bash 54 | $ python ./scripts/preprocess_pix2pix_data.py \ 55 | --src data/maps/original \ 56 | --dst data/maps/512x512 \ 57 | --size 512 \ 58 | --random-flip \ 59 | --random-rotate 60 | ``` 61 | 62 | ### CelebA 63 | We expect the original CelebA dataset to be located at `data/celeba/original`, split into `data/celeba/original/train` and `data/celeba/original/val`. 64 | ```bash 65 | # For Super-Resolution 66 | $ python ./scripts/preprocess_celeba.py \ 67 | --src data/celeba/original \ 68 | --dst data/celeba/64x64 \ 69 | --size 64 70 | 71 | # For Inpainting 72 | $ python ./scripts/preprocess_celeba.py \ 73 | --src data/celeba/original \ 74 | --dst data/celeba/128x128 \ 75 | --size 128 76 | ``` 77 | 78 | 79 | 80 | ## Training 81 | The `{model}-{dataset}-{distribution}-{method}` placeholders below correspond to the YAML files under `configs/`, e.g. `pix2pix-maps_256x256-gaussian-mr2.yaml`. 82 | 83 | ### MR-GAN 84 | ```bash 85 | $ python main.py --mode mr --config ./configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/mr 86 | ``` 87 | 88 | ### Proxy MR-GAN 89 | Train a predictor first and determine the checkpoint where the validation loss is minimized. 90 | ```bash 91 | $ python main.py --mode pred --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/predictor 92 | ``` 93 | Pass that checkpoint as `--pred-ckpt` when training the generator. 94 | ```bash 95 | $ python main.py --mode mr --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/pmr --pred-ckpt ./logs/predictor/ckpt/{step}-p.pt 96 | ``` 97 | 98 | 99 | [1]: https://soochanlee.com/publications/mr-gan 100 | [2]: https://openreview.net/pdf?id=HJxyAjRcFX 101 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class L1Loss(nn.Module): 6 | name = 'l1' 7 | 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, x, y, normalizer=None): 12 | if not normalizer: 13 | return (x - y).abs().mean() 14 | else: 15 | return (x - y).abs().sum() / normalizer 16 | 17 | 18 | class MSELoss(nn.Module): 19 | name = 'mse' 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y, normalizer=None): 25 | if not normalizer: 26 | return ((x - y) ** 2).mean() 27 | else: 28 | return ((x - y) ** 2).sum() / normalizer 29 | 30 | 31 | class GANLoss(nn.Module): 32 | name = 'gan' 33 | need_lipschitz_d = False 34 | 35 | def __init__(self, real_label=1.0, fake_label=0.0, with_logits=False): 36 | super().__init__() 37 | self.register_buffer('real_label', torch.tensor(real_label)) 38 | self.register_buffer('fake_label', torch.tensor(fake_label)) 39 | if with_logits: 40 | self.loss = nn.BCEWithLogitsLoss() 41 | else: 42 | self.loss = nn.BCELoss() 43 | 44 | def forward(self, verdict, target_is_real): 45 | if target_is_real: 46 | target = self.real_label 47 | else: 48 | target = self.fake_label 49 | return self.loss(verdict, target.expand_as(verdict)) 50 | 51 | 52 | class WGANLoss(nn.Module): 53 | name = 'wgan' 54 | need_lipschitz_d = True 55 | 56 
| def __init__(self, with_logits=True):  # with_logits is ignored; the critic returns unbounded scores 57 | super().__init__() 58 | 59 | def forward(self, verdict, target_is_real): 60 | return -verdict.mean() if target_is_real else verdict.mean() 61 | 62 | 63 | class LSGANLoss(nn.Module): 64 | name = 'lsgan' 65 | need_lipschitz_d = False 66 | 67 | def __init__(self, real_label=1.0, fake_label=0.0, with_logits=False):  # with_logits is ignored; accepted so the GAN-loss construction in models/base.py works uniformly 68 | super().__init__() 69 | self.register_buffer('real_label', torch.tensor(real_label)) 70 | self.register_buffer('fake_label', torch.tensor(fake_label)) 71 | self.loss = nn.MSELoss() 72 | 73 | def forward(self, verdict, target_is_real): 74 | if target_is_real: 75 | target = self.real_label 76 | else: 77 | target = self.fake_label 78 | return self.loss(verdict, target.expand_as(verdict)) 79 | 80 | 81 | class GaussianMLELoss(nn.Module): 82 | name = 'gaussian' 83 | 84 | def __init__(self, order=2, min_noise=0.): 85 | super().__init__() 86 | self.order = order 87 | self.min_noise = min_noise 88 | 89 | def forward(self, center, dispersion, y, 90 | log_dispersion=True, normalizer=None): 91 | squared = (center - y) ** 2 92 | 93 | if self.order == 1: 94 | return squared.mean() if not normalizer else squared.sum() / normalizer 95 | 96 | if log_dispersion: 97 | var = dispersion.exp() 98 | log_var = dispersion 99 | else: 100 | var = dispersion 101 | log_var = (dispersion + 1e-9).log() 102 | 103 | loss = ((squared + self.min_noise) / (var + 1e-9) + log_var) * 0.5 104 | 105 | if not normalizer: 106 | return loss.mean() 107 | else: 108 | return loss.sum() / normalizer 109 | 110 | 111 | class LaplaceMLELoss(nn.Module): 112 | name = 'laplace' 113 | 114 | def __init__(self, order=2, min_noise=0.): 115 | super().__init__() 116 | self.order = order 117 | self.min_noise = min_noise 118 | 119 | def forward(self, center, dispersion, y, 120 | log_dispersion=True, normalizer=None): 121 | deviation = (center - y).abs() 122 | 123 | if self.order == 1: 124 | return deviation.mean() if not normalizer else deviation.sum() / normalizer 125 | 126 | if log_dispersion: 127 | mad = dispersion.exp() 128 | log_mad = dispersion 129 | else: 130 | mad = dispersion 131 | log_mad = (dispersion + 1e-9).log() 132 | 133 | loss = (deviation + self.min_noise) / (mad + 1e-9) + log_mad 134 | 135 | if not normalizer: 136 | return loss.mean() 137 | else: 138 | return loss.sum() / normalizer 139 | 140 | 141 | LOSSES = { 142 | GANLoss.name: GANLoss, 143 | WGANLoss.name: WGANLoss, 144 | LSGANLoss.name: LSGANLoss, 145 | GaussianMLELoss.name: GaussianMLELoss, 146 | LaplaceMLELoss.name: LaplaceMLELoss, 147 | } 148 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Soochan Lee, Junsoo Ha and Gunhee Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | -------------------------- LICENSE FOR PyTorch pix2pix ----------------------- 25 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 26 | All rights reserved. 27 | 28 | Redistribution and use in source and binary forms, with or without 29 | modification, are permitted provided that the following conditions are met: 30 | 31 | * Redistributions of source code must retain the above copyright notice, this 32 | list of conditions and the following disclaimer. 33 | 34 | * Redistributions in binary form must reproduce the above copyright notice, 35 | this list of conditions and the following disclaimer in the documentation 36 | and/or other materials provided with the distribution. 37 | 38 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 39 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 40 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 41 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 42 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 43 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 44 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 45 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 46 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 47 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 48 | 49 | 50 | --------------------------- LICENSE FOR pix2pix ------------------------------- 51 | BSD License 52 | 53 | For pix2pix software 54 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 55 | All rights reserved. 56 | 57 | Redistribution and use in source and binary forms, with or without 58 | modification, are permitted provided that the following conditions are met: 59 | 60 | * Redistributions of source code must retain the above copyright notice, this 61 | list of conditions and the following disclaimer. 62 | 63 | * Redistributions in binary form must reproduce the above copyright notice, 64 | this list of conditions and the following disclaimer in the documentation 65 | and/or other materials provided with the distribution. 66 | 67 | 68 | ----------------------------- LICENSE FOR DCGAN ------------------------------- 69 | BSD License 70 | 71 | For dcgan.torch software 72 | 73 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 74 | 75 | Redistribution and use in source and binary forms, with or without 76 | modification, are permitted provided that the following conditions are met: 77 | 78 | * Redistributions of source code must retain the above copyright notice, this 79 | list of conditions and the following disclaimer. 80 | 81 | * Redistributions in binary form must reproduce the above copyright notice, 82 | this list of conditions and the following disclaimer in the documentation 83 | and/or other materials provided with the distribution. 84 | 85 | Neither the name Facebook nor the names of its contributors may be used to 86 | endorse or promote products derived from this software without specific prior 87 | written permission. 
88 | 89 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 90 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 91 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 92 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 93 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 94 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 95 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 96 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 97 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 98 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE 99 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | import yaml 3 | import time 4 | import torch 5 | 6 | from models.base import BaseModel 7 | from data import DataIterator, SubsetSequentialSampler, InfiniteRandomSampler 8 | from constants import MODE_PRED 9 | from utils import namedtuple_to_dict 10 | 11 | 12 | def train(model: BaseModel, config, train_dataset, val_dataset, step=0): 13 | train_iterator = DataIterator( 14 | train_dataset, 15 | batch_size=config.batch_size, 16 | num_workers=config.data.num_workers, 17 | sampler=InfiniteRandomSampler(train_dataset) 18 | ) 19 | 20 | # Prepare for summary 21 | writer = SummaryWriter(config.log_dir) 22 | config_str = yaml.dump(namedtuple_to_dict(config)) 23 | writer.add_text('config', config_str) 24 | train_sampler = SubsetSequentialSampler( 25 | train_dataset, config.summary.train_samples) 26 | val_sampler = SubsetSequentialSampler( 27 | val_dataset, config.summary.val_samples 28 | ) 29 | train_sample_iterator = DataIterator( 30 | train_dataset.for_summary(), sampler=train_sampler, num_workers=2) 31 | val_sample_iterator = DataIterator( 32 | val_dataset.for_summary(), sampler=val_sampler, num_workers=2) 33 | 34 | # Training loop 35 | start_time = time.time() 36 | start_step = step 37 | while True: 38 | step += 1 39 | save_summary = step % config.summary_step == 0 40 | d_summary, g_summary, p_summary = None, None, None 41 | if config.mode == MODE_PRED: 42 | if model.lr_sched_p is not None: 43 | model.lr_sched_p.step() 44 | x, y = next(train_iterator) 45 | p_summary = model.optimize_p( 46 | x, y, step=step, summarize=save_summary) 47 | 48 | else: 49 | if model.lr_sched_d is not None: 50 | model.lr_sched_d.step() 51 | 52 | x, y = next(train_iterator) 53 | summarize_d = save_summary and config.d_updates_per_step == 1 54 | d_summary = model.optimize_d( 55 | x, y, step=step, summarize=summarize_d) 56 | for i in range(config.d_updates_per_step - 1): 57 | x, y = next(train_iterator) 58 | summarize_d = save_summary and ( 59 | i == config.d_updates_per_step - 2) 60 | d_summary = model.optimize_d( 61 | x, y, step=step, summarize=summarize_d) 62 | 63 | if model.lr_sched_g is not None: 64 | model.lr_sched_g.step() 65 | 66 | summarize_g = save_summary and config.g_updates_per_step == 1 67 | g_summary = model.optimize_g( 68 | x, y, step=step, summarize=summarize_g) 69 | for i in range(config.g_updates_per_step - 1): 70 | x, y = next(train_iterator) 71 | summarize_g = save_summary and ( 72 | i == config.g_updates_per_step - 2) 73 | g_summary = model.optimize_g( 74 | x, y, step=step, 
summarize=summarize_g) 75 | 76 | # Print status 77 | elapsed_time = time.time() - start_time 78 | elapsed_step = step - start_step 79 | print( 80 | '\r[Step %d] %s' % ( 81 | step, time.strftime('%H:%M:%S', time.gmtime(elapsed_time))), 82 | end='') 83 | if elapsed_time > elapsed_step: 84 | print(' | %.2f s/it' % (elapsed_time / elapsed_step), end='') 85 | else: 86 | print(' | %.2f it/s' % (elapsed_step / elapsed_time), end='') 87 | 88 | if step % config.ckpt_step == 0: 89 | model.save(step) 90 | 91 | if save_summary: 92 | # Save summaries from optimization process 93 | for summary in [p_summary, d_summary, g_summary]: 94 | if summary is None: 95 | continue 96 | model.write_summary(writer, summary, step) 97 | 98 | # Summarize learning rates and gradients 99 | for component, optimizer in [ 100 | ('d', model.optim_d), ('g', model.optim_g), 101 | ('p', model.optim_p), 102 | ]: 103 | if optimizer is None: 104 | continue 105 | 106 | for i, group in enumerate(optimizer.param_groups): 107 | writer.add_scalar( 108 | 'lr/%s/%d' % (component, i), group['lr'], step) 109 | grads = [] 110 | for param in group['params']: 111 | if param.grad is not None: 112 | grads.append(param.grad.data.view([-1])) 113 | if grads: 114 | grads = torch.cat(grads, 0) 115 | writer.add_histogram( 116 | 'grad/%s/%d' % (component, i), grads, step) 117 | 118 | # Custom summaries 119 | model.summarize( 120 | writer, step, train_sample_iterator, val_sample_iterator) 121 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os.path 3 | import numpy as np 4 | import random 5 | import torch 6 | from torchvision.transforms import functional as F 7 | from torch.utils import data 8 | from PIL import Image 9 | from collections.abc import Iterator, Iterable 10 | 11 | 12 | class InfiniteRandomIterator(Iterator): 13 | def __init__(self, data_source): 14 | self.data_source = data_source 15 | self._perms = [  # pre-generate a pool of permutations for cheap reshuffling 16 | torch.randperm(len(self.data_source)).tolist() 17 | for _ in range(30) 18 | ] 19 | self.iterator = iter(random.choice(self._perms)) 20 | 21 | def __next__(self): 22 | try: 23 | idx = next(self.iterator) 24 | except StopIteration: 25 | self.iterator = iter(random.choice(self._perms)) 26 | idx = next(self.iterator) 27 | 28 | return idx 29 | 30 | 31 | class InfiniteRandomSampler(data.Sampler): 32 | def __init__(self, data_source): 33 | super().__init__(data_source) 34 | self.data_source = data_source 35 | 36 | def __iter__(self): 37 | return InfiniteRandomIterator(self.data_source) 38 | 39 | def __len__(self): 40 | return len(self.data_source) 41 | 42 | 43 | class InfiniteSubsetIterator(Iterator): 44 | def __init__(self, indices): 45 | self.indices = indices 46 | self.iterator = iter(self.indices) 47 | 48 | def __next__(self): 49 | try: 50 | idx = next(self.iterator) 51 | except StopIteration: 52 | self.iterator = iter(self.indices) 53 | idx = next(self.iterator) 54 | 55 | return idx 56 | 57 | def __len__(self): 58 | return len(self.indices) 59 | 60 | 61 | class SubsetSequentialSampler(data.sampler.Sampler): 62 | def __init__(self, data_source, indices): 63 | super().__init__(data_source) 64 | if isinstance(indices, Iterable): 65 | self.indices = indices 66 | else: 67 | self.indices = np.random.choice( 68 | len(data_source), 69 | size=indices, 70 | replace=False 71 | ) 72 | 73 | def __iter__(self): 74 | return InfiniteSubsetIterator(self.indices) 75 | 76 | def __len__(self): 77 | return 
len(self.indices) 78 | 79 | 80 | class DataIterator(Iterator): 81 | def __init__(self, dataset: data.Dataset, **kwargs): 82 | self.data_loader = data.DataLoader(dataset, **kwargs) 83 | self.epoch = 0 84 | self.iterator = iter(self.data_loader) 85 | 86 | def __next__(self): 87 | batch = next(self.iterator) 88 | batch = self.data_loader.dataset.vector_preprocess(*batch) 89 | 90 | return batch 91 | 92 | def __len__(self): 93 | return len(self.data_loader) 94 | 95 | 96 | class BaseDataset(data.Dataset): 97 | def __init__(self): 98 | super().__init__() 99 | 100 | def __getitem__(self, index): 101 | raise NotImplementedError 102 | 103 | def __len__(self): 104 | raise NotImplementedError 105 | 106 | def vector_preprocess(self, x, y): 107 | return x, y 108 | 109 | def for_summary(self): 110 | return self 111 | 112 | 113 | # ======================= 114 | # Dataset Implementations 115 | # ======================= 116 | 117 | class Maps(BaseDataset): 118 | def __init__(self, config, mode='train'): 119 | super().__init__() 120 | assert mode in ['train', 'val'] 121 | self.config = config 122 | self.mode = mode 123 | self.resize = config.data.resize 124 | self.random_flip = config.data.random_flip 125 | self.random_rotate = config.data.random_rotate 126 | self.resize_shape = [config.data.height, config.data.width * 2] 127 | if config.data.norm: 128 | self.mean = torch.tensor( 129 | config.data.norm.mean, device='cuda' 130 | ).view(1, 3, 1, 1) 131 | self.std = torch.tensor( 132 | config.data.norm.std, device='cuda' 133 | ).view(1, 3, 1, 1) 134 | else: 135 | self.mean = None 136 | self.std = None 137 | 138 | # Load image paths 139 | image_dir = os.path.join(config.data.root, mode) 140 | self.items = [ 141 | os.path.join(image_dir, name) 142 | for name in os.listdir(image_dir) 143 | ] 144 | self.items = sorted(self.items) 145 | 146 | def __getitem__(self, index): 147 | raw_image = Image.open(self.items[index])  # keep as PIL for now; F.resize expects a PIL image 148 | if self.resize: 149 | raw_image = F.resize(raw_image, self.resize_shape) 150 | raw_image = F.to_tensor(raw_image) 151 | w = self.config.data.width 152 | image = raw_image[:, :, :w] 153 | label = raw_image[:, :, w:] 154 | 155 | th = self.config.data.train_height 156 | tw = self.config.data.train_width 157 | if self.mode == 'train': 158 | max_offset_y = self.config.data.height - th 159 | max_offset_x = self.config.data.width - tw 160 | # torch.randint's upper bound is exclusive, so add 1 to reach the full offset 161 | offset_y = torch.randint( 162 | low=0, high=max_offset_y + 1, size=[], 163 | dtype=torch.long 164 | ) 165 | offset_x = torch.randint( 166 | low=0, high=max_offset_x + 1, size=[], 167 | dtype=torch.long 168 | ) 169 | image = image[:, offset_y:offset_y + th, offset_x:offset_x + tw] 170 | label = label[:, offset_y:offset_y + th, offset_x:offset_x + tw] 171 | 172 | # Data augmentation 173 | if self.random_flip or self.random_rotate: 174 | image = F.to_pil_image(image) 175 | label = F.to_pil_image(label) 176 | 177 | # Random flip 178 | if self.random_flip and random.random() < 0.5: 179 | image = F.vflip(image) 180 | label = F.vflip(label) 181 | 182 | # Random rotation 183 | if self.random_rotate: 184 | degree = random.choice([0, 90, 180, 270]) 185 | if degree != 0: 186 | image = F.rotate(image, degree) 187 | label = F.rotate(label, degree) 188 | 189 | image = F.to_tensor(image) 190 | label = F.to_tensor(label) 191 | else: 192 | image = image[:, :th, :tw] 193 | label = label[:, :th, :tw] 194 | 195 | return label, image 196 | 197 | def __len__(self): 198 | return len(self.items) 199 | 200 | def vector_preprocess(self, x, y): 201 | if self.mean is not None: 202 | x = 
x.cuda().sub_(self.mean).div_(self.std) 203 | else: 204 | x = x.cuda().mul_(2.).sub_(1.) 205 | y = y.cuda().mul_(2.).sub_(1.) 206 | return x, y 207 | 208 | 209 | class Cityscapes(BaseDataset): 210 | def __init__(self, config, mode='train'): 211 | super().__init__() 212 | assert mode in ['train', 'val'] 213 | self.config = config 214 | self.mode = mode 215 | if config.data.norm: 216 | self.mean = torch.tensor( 217 | config.data.norm.mean, device='cuda' 218 | ).view(1, 3, 1, 1) 219 | self.std = torch.tensor( 220 | config.data.norm.std, device='cuda' 221 | ).view(1, 3, 1, 1) 222 | else: 223 | self.mean = None 224 | self.std = None 225 | 226 | # Load image paths 227 | image_dir = os.path.join(config.data.root, mode) 228 | self.items = sorted([ 229 | os.path.join(image_dir, name) 230 | for name in os.listdir(image_dir) 231 | ]) 232 | 233 | data_offset = config.data.data_offset or 0 234 | data_size = config.data.data_size 235 | 236 | if data_size is not None: 237 | self.items = self.items[data_offset: data_offset + data_size] 238 | else: 239 | self.items = self.items[data_offset:] 240 | 241 | def __getitem__(self, index): 242 | raw_image = F.to_tensor(Image.open(self.items[index])) 243 | w = self.config.data.width 244 | image = raw_image[:, :, :w] 245 | label = raw_image[:, :, w:] 246 | return label, image 247 | 248 | def __len__(self): 249 | return len(self.items) 250 | 251 | def vector_preprocess(self, x, y): 252 | if self.mean is not None: 253 | x = x.cuda().sub_(self.mean).div_(self.std) 254 | else: 255 | x = x.cuda().mul_(2.).sub_(1.) 256 | y = y.cuda().mul_(2.).sub_(1.) 257 | return x, y 258 | 259 | 260 | class Edges2Shoes(Cityscapes): 261 | def __getitem__(self, index): 262 | raw_image = F.to_tensor(Image.open(self.items[index])) 263 | w = self.config.data.width 264 | label = raw_image[:, :, :w] 265 | image = raw_image[:, :, w:] 266 | return label, image 267 | 268 | 269 | class CelebA(BaseDataset): 270 | mean_rgb = [130 / 255, 108 / 255, 96 / 255] 271 | 272 | def __init__(self, config, mode='train'): 273 | super().__init__() 274 | assert mode in ['train', 'val'] 275 | self.config = config 276 | self.mode = mode 277 | self.mean = None 278 | self.std = None 279 | self.size = getattr(config.data, 'size', None) 280 | self.local_size = getattr(config.data, 'local_size', None) 281 | self.mask_size = getattr(config.data, 'mask_size', None) 282 | self.summary = False 283 | 284 | # Load image paths 285 | image_dir = os.path.join(config.data.root, mode) 286 | self.items = [ 287 | os.path.join(image_dir, name) 288 | for name in os.listdir(image_dir) 289 | ] 290 | self.items = sorted(self.items) 291 | 292 | def __getitem__(self, index): 293 | raw_image = F.to_tensor(Image.open(self.items[index])) 294 | h = raw_image.size(1) 295 | w = raw_image.size(2) 296 | image = raw_image[:, :, :h] 297 | label = raw_image[:, :w - h, h:] 298 | 299 | return label, image 300 | 301 | def __len__(self): 302 | return len(self.items) 303 | 304 | def for_summary(self): 305 | clone = copy.deepcopy(self) 306 | clone.summary = True 307 | return clone 308 | 309 | def vector_preprocess(self, x, y): 310 | if self.config.model.name == 'glcic': 311 | y = y.cuda() 312 | x = y.clone() 313 | local_boxes = [] 314 | masks = [] 315 | for i in range(len(x)): 316 | if self.summary: 317 | seed = x[i] 318 | else: 319 | seed = None 320 | mask, mask_box, local_box = self._random_mask_in_local_box(seed) 321 | masks.append(mask) 322 | local_boxes.append(local_box) 323 | q1, p1, q2, p2 = mask_box 324 | x[i, 0, q1:q2, p1:p2] = self.mean_rgb[0] 
325 | x[i, 1, q1:q2, p1:p2] = self.mean_rgb[1] 326 | x[i, 2, q1:q2, p1:p2] = self.mean_rgb[2] 327 | # annotate the image tensor 328 | local_boxes = torch.from_numpy(np.stack(local_boxes)).cuda() 329 | masks = torch.from_numpy(np.stack(masks)).float().cuda() 330 | x = torch.cat([x, masks], dim=1) 331 | x.local_boxes = local_boxes 332 | return x, y 333 | 334 | else: 335 | x = x.cuda().mul_(2.).sub_(1.) 336 | y = y.cuda().mul_(2.).sub_(1.) 337 | return x, y 338 | 339 | def _random_mask_in_local_box(self, seed=None): 340 | input_size = self.size 341 | local_size = self.local_size 342 | mh_range, mw_range = self.mask_size 343 | if seed is not None: 344 | np.random.seed(int(seed.sum() * 100)) 345 | 346 | # generate a random mask inside a local box 347 | max_offset = input_size - local_size 348 | y1, x1 = np.random.randint(0, max_offset + 1, 2) 349 | y2, x2 = np.array([y1, x1]) + local_size 350 | h = np.random.randint(mh_range[0], mh_range[1] + 1) 351 | w = np.random.randint(mw_range[0], mw_range[1] + 1) 352 | q1 = y1 + np.random.randint(0, local_size - h + 1) 353 | p1 = x1 + np.random.randint(0, local_size - w + 1) 354 | q2 = q1 + h 355 | p2 = p1 + w 356 | mask = np.zeros([1, input_size, input_size], dtype=np.float32) 357 | mask[:, q1:q2, p1:p2] = 1.0 358 | return mask, np.array([q1, p1, q2, p2]), np.array([y1, x1, y2, x2]) 359 | 360 | 361 | # ================= 362 | # Dataset Factories 363 | # ================= 364 | 365 | 366 | def CITYSCAPES(config): 367 | return Cityscapes(config, 'train'), Cityscapes(config, 'val') 368 | 369 | 370 | def MAPS(config): 371 | return Maps(config, 'train'), Maps(config, 'val') 372 | 373 | 374 | def CELEBA(config): 375 | return CelebA(config, 'train'), CelebA(config, 'val') 376 | 377 | 378 | DATASETS = { 379 | 'cityscapes': CITYSCAPES, 380 | 'maps': MAPS, 381 | 'celeba': CELEBA, 382 | } 383 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from data import DataIterator 5 | import os 6 | from tensorboardX import SummaryWriter 7 | from constants import MODES, MODE_BASE, MODE_PRED, MODE_MR 8 | from losses import LOSSES, LaplaceMLELoss, GaussianMLELoss, L1Loss, MSELoss 9 | 10 | 11 | class BaseModel(nn.Module): 12 | def __init__(self, config, 13 | net_d: nn.Module = None, net_g: nn.Module = None, 14 | net_p: nn.Module = None): 15 | super().__init__() 16 | self.config = config 17 | assert config.mode in MODES 18 | self.mode = config.mode 19 | self.log_dir = config.log_dir 20 | self.loss_config = getattr(config.model.losses, self.mode) 21 | if getattr(config.data, 'norm', None) is not None: 22 | self.input_mean = torch.tensor( 23 | config.data.norm.mean, device='cuda' 24 | ).view(1, 3, 1, 1) 25 | self.input_std = torch.tensor( 26 | config.data.norm.std, device='cuda' 27 | ).view(1, 3, 1, 1) 28 | else: 29 | self.input_mean = None 30 | self.input_std = None 31 | self.device = torch.device('cuda:0') if torch.cuda.is_available() \ 32 | else torch.device('cpu') 33 | self.net_d = net_d 34 | self.net_g = net_g 35 | self.net_p = net_p 36 | 37 | # GAN loss 38 | self.gan_loss = LOSSES[config.model.gan.type]( 39 | with_logits=config.model.gan.with_logits 40 | ) 41 | if self.net_d is not None and self.gan_loss.need_lipschitz_d: 42 | for m in [m for m in self.net_d.modules() if m._parameters]: 43 | nn.utils.spectral_norm(m) 44 | 45 | # MLE loss 46 | self.mle = config.model.mle.type 47 | if 
self.mle == 'gaussian': 48 | self.mle_loss = GaussianMLELoss( 49 | **config.model.mle.options._asdict()) 50 | elif self.mle == 'laplace': 51 | self.mle_loss = LaplaceMLELoss( 52 | **config.model.mle.options._asdict()) 53 | else: 54 | raise ValueError('Invalid MLE loss type: %s' % self.mle) 55 | 56 | # Other losses 57 | self._l1_loss = L1Loss() 58 | self._mse_loss = MSELoss() 59 | 60 | # Build optimizers 61 | self.optim_d, self.optim_g, self.optim_p = None, None, None 62 | if net_d is not None and self.mode != MODE_PRED \ 63 | and self.loss_config.gan_weight: 64 | self.optim_d = self._build_optimizer( 65 | config.d_optimizer, net_d.parameters()) 66 | if net_g is not None and self.mode != MODE_PRED: 67 | self.optim_g = self._build_optimizer( 68 | config.g_optimizer, net_g.parameters()) 69 | if net_p is not None and self.mode == MODE_PRED: 70 | self.optim_p = self._build_optimizer( 71 | config.p_optimizer, net_p.parameters()) 72 | 73 | # Build learning rate schedulers 74 | self.lr_sched_d, self.lr_sched_g, self.lr_sched_p = None, None, None 75 | if self.optim_d is not None and config.d_lr_scheduler is not None: 76 | self.lr_sched_d = self._build_lr_scheduler( 77 | config.d_lr_scheduler, self.optim_d) 78 | if self.optim_g is not None and config.g_lr_scheduler is not None: 79 | self.lr_sched_g = self._build_lr_scheduler( 80 | config.g_lr_scheduler, self.optim_g) 81 | if self.optim_p is not None and config.p_lr_scheduler is not None: 82 | self.lr_sched_p = self._build_lr_scheduler( 83 | config.p_lr_scheduler, self.optim_p) 84 | 85 | # Set train/eval 86 | if self.mode == MODE_BASE: 87 | net_g.train() 88 | net_d.train() 89 | elif self.mode == MODE_PRED: 90 | net_p.train() 91 | elif self.mode == MODE_MR: 92 | net_g.train() 93 | net_d.train() 94 | if net_p is not None: 95 | net_p.eval() 96 | else: 97 | raise NotImplementedError 98 | 99 | def forward(self, x): 100 | if self.mode == MODE_PRED: 101 | return self.net_p(x)[0] 102 | elif self.mode == MODE_MR: 103 | return self.net_g(x)[0] 104 | else: 105 | return self.net_g(x)[0] 106 | 107 | def optimize_d(self, x, y, step, summarize=False): 108 | raise NotImplementedError 109 | 110 | def optimize_g(self, x, y, step, summarize=False): 111 | raise NotImplementedError 112 | 113 | def optimize_p(self, x, y, step, summarize=False): 114 | raise NotImplementedError 115 | 116 | def summarize(self, writer: SummaryWriter, step, 117 | train_sample_iterator: DataIterator, 118 | val_sample_iterator: DataIterator): 119 | raise NotImplementedError 120 | 121 | @staticmethod 122 | def write_summary(writer: SummaryWriter, summary, step): 123 | for summary_type, summary_fn in [ 124 | ('scalar', writer.add_scalar), 125 | ('image', writer.add_image), 126 | ('histogram', writer.add_histogram), 127 | ('text', writer.add_text) 128 | ]: 129 | if summary_type not in summary: 130 | continue 131 | for name, value in summary[summary_type].items(): 132 | summary_fn(name, value, step) 133 | 134 | def save(self, step): 135 | os.makedirs(os.path.join(self.log_dir, 'ckpt'), exist_ok=True) 136 | if self.net_d is not None: 137 | save_path = os.path.join(self.log_dir, 'ckpt', '%d-d.pt' % step) 138 | torch.save({'state': self.net_d.state_dict(), 'step': step}, 139 | save_path) 140 | if self.net_g is not None: 141 | save_path = os.path.join(self.log_dir, 'ckpt', '%d-g.pt' % step) 142 | torch.save({'state': self.net_g.state_dict(), 'step': step}, 143 | save_path) 144 | if self.net_p is not None: 145 | save_path = os.path.join(self.log_dir, 'ckpt', '%d-p.pt' % step) 146 | torch.save({'state': 
self.net_p.state_dict(), 'step': step}, 147 | save_path) 148 | 149 | def load(self, path): 150 | step = 0 151 | if self.net_d is not None: 152 | step = self.load_module(self.net_d, path + '-d.pt') 153 | if self.net_g is not None: 154 | step = self.load_module(self.net_g, path + '-g.pt') 155 | if self.net_p is not None: 156 | step = self.load_module(self.net_p, path + '-p.pt') 157 | return step 158 | 159 | def _build_uncertainty_image(self, log_var): 160 | log_var = log_var.clone() 161 | log_var_min = self.config.log_dispersion_min 162 | log_var_gap = self.config.log_dispersion_max - log_var_min 163 | 164 | # Add legend 165 | h = log_var.size(2) 166 | l_size = min(h // 10, 10) 167 | for i in range(8): 168 | log_var[:, :, i * l_size:(i + 1) * l_size, 0:l_size] = i * -1. 169 | uncertainty = log_var - log_var_min 170 | uncertainty /= (log_var_gap * 0.5) 171 | uncertainty -= 1.0 172 | uncertainty = torch.clamp(uncertainty, min=-1., max=1.) 173 | return uncertainty.expand([-1, 3, -1, -1]) 174 | 175 | @staticmethod 176 | def _build_sigma_offset_images(decode_fn, mean, log_var, index): 177 | stddev = log_var.exp().sqrt() 178 | mean_noffset = mean.clone() 179 | mean_noffset[:, :stddev.size(1), ...] -= stddev 180 | mean_poffset = mean.clone() 181 | mean_poffset[:, :stddev.size(1), ...] += stddev 182 | noffset = decode_fn(mean_noffset, index) 183 | poffset = decode_fn(mean_poffset, index) 184 | return noffset.clamp(-1, 1), poffset.clamp(-1, 1) 185 | 186 | @staticmethod 187 | def load_module(module: nn.Module, path, strict=True): 188 | ckpt = torch.load(path) 189 | module.load_state_dict(ckpt['state'], strict=strict) 190 | return ckpt['step'] 191 | 192 | @staticmethod 193 | def _build_optimizer(optim_config, params) -> torch.optim.Optimizer: 194 | return getattr(torch.optim, optim_config.type)( 195 | params, **optim_config.options._asdict()) 196 | 197 | @staticmethod 198 | def _build_lr_scheduler(lr_config, optimizer): 199 | return getattr(torch.optim.lr_scheduler, lr_config.type)( 200 | optimizer, **lr_config.options._asdict()) 201 | 202 | @staticmethod 203 | def _clip_grad_value(optimizer, clip_value): 204 | for group in optimizer.param_groups: 205 | nn.utils.clip_grad_value_(group['params'], clip_value) 206 | 207 | @staticmethod 208 | def _clip_grad_norm(optimizer, max_norm, norm_type=2): 209 | for group in optimizer.param_groups: 210 | nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type) 211 | 212 | @staticmethod 213 | def clip_grad(optimizer, clip_grad_config=None): 214 | if clip_grad_config is None: 215 | return 216 | 217 | if clip_grad_config.type == 'value': 218 | BaseModel._clip_grad_value( 219 | optimizer, **clip_grad_config.options._asdict() 220 | ) 221 | elif clip_grad_config.type == 'norm': 222 | BaseModel._clip_grad_norm( 223 | optimizer, **clip_grad_config.options._asdict() 224 | ) 225 | else: 226 | raise ValueError('Invalid clip_grad type: {}' 227 | .format(clip_grad_config.type)) 228 | 229 | def undo_norm(self, x): 230 | if self.input_mean is None: 231 | # [-1, 1] -> [-1, 1] 232 | return x 233 | else: 234 | # N(0, 1) -> [-1, 1] 235 | return (x * self.input_std + self.input_mean) * 2. - 1. 
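    # A minimal sketch of what sample_statistics (below) computes, contrasting the
    # Gaussian case with the Laplace case; shapes here are illustrative only:
    #   samples = torch.randn(8, 10, 3, 64, 64)   # (batch, num_samples, C, H, W)
    #   mean = samples.mean(dim=1)                 # Gaussian 1st moment
    #   var = samples.var(dim=1)                   # Gaussian 2nd moment (unbiased)
    #   median = samples.median(dim=1)[0]          # Laplace 1st moment
    #   mad = (samples - median.unsqueeze(1)).abs().mean(dim=1)  # Laplace 2nd moment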
236 | 237 | def sample_statistics(self, samples): 238 | """Calculate sample statistics 239 | 240 | Args: 241 | samples: 5D tensor of shape (N, num_samples, C, H, W) 242 | or 4D tensor of shape (num_samples, C, H, W) 243 | Returns: 244 | sample_1st: 4D tensor of shape (N, C, H, W) 245 | sample_2nd: 4D tensor of shape (N, C, H, W) 246 | """ 247 | if len(samples.size()) == 4: 248 | samples = samples.unsqueeze(0) 249 | if isinstance(self.mle_loss, GaussianMLELoss): 250 | num_mr_samples = samples.size(1) 251 | sample_1st = samples.mean(dim=1, keepdim=True) 252 | # Tensor.std has a bug up to PyTorch 0.4.1 253 | sample_2nd = (samples - sample_1st) ** 2 254 | sample_2nd = sample_2nd.sum(dim=1, keepdim=True) 255 | sample_2nd /= num_mr_samples - 1 256 | 257 | # Laplace statistics 258 | elif isinstance(self.mle_loss, LaplaceMLELoss): 259 | sample_1st, _ = samples.median( 260 | dim=1, keepdim=True) 261 | sample_2nd = torch.abs(samples - sample_1st) 262 | sample_2nd = sample_2nd.mean(dim=1, keepdim=True) 263 | else: 264 | raise RuntimeError('Unknown type of MLE loss') 265 | 266 | return sample_1st.squeeze(1), sample_2nd.squeeze(1) 267 | 268 | def mle_target(self, y): 269 | """Process ground truth y to proper MLE target""" 270 | raise NotImplementedError 271 | 272 | def build_d_input(self, x, samples): 273 | """Build discriminator input""" 274 | raise NotImplementedError 275 | 276 | def accumulate_mr_grad(self, x, y, summarize=False): 277 | # Initialize summaries 278 | scalar = {} 279 | histogram = {} 280 | image = {} 281 | 282 | num_mr = self.config.num_mr 283 | num_mr_samples = self.config.num_mr_samples 284 | 285 | loss = 0. 286 | 287 | # Get predictive mean and variance 288 | if self.net_p is not None: 289 | with torch.no_grad(): 290 | pred_1st, pred_log_2nd = self.net_p(x[:num_mr]) 291 | pred_2nd = torch.exp(pred_log_2nd) 292 | 293 | if summarize: 294 | scalar['variance/pred'] = pred_2nd.detach().mean() 295 | histogram['variance_hist/pred'] = pred_2nd.detach().view(-1) 296 | else: 297 | pred_1st, pred_log_2nd, pred_2nd = None, None, None 298 | 299 | # Get samples 300 | samples, _ = self.net_g(x[:num_mr], num_samples=num_mr_samples) 301 | 302 | # GAN loss 303 | if self.loss_config.gan_weight > 0: 304 | d_input = self.build_d_input(x, samples) 305 | fake_v = self.net_d(d_input) 306 | gan_loss = self.gan_loss(fake_v, True) 307 | loss += self.loss_config.gan_weight * gan_loss 308 | 309 | if summarize: 310 | scalar['loss/g/gan'] = gan_loss.detach() 311 | 312 | # Get sample mean and variance 313 | samples = samples.view( 314 | num_mr, num_mr_samples, *list(samples.size()[1:])) 315 | sample_1st, sample_2nd = self.sample_statistics(samples) 316 | 317 | if summarize: 318 | for i in range(min(16, samples.size(0))): 319 | image['train_samples/%d' % i] = torch.cat( 320 | torch.unbind(samples[i, :5].detach()), 2 321 | ) 322 | scalar['variance/sample'] = sample_2nd.detach().mean() 323 | histogram['variance_hist/sample'] = sample_2nd.detach().view(-1) 324 | 325 | # Direct MLE without predictor 326 | if self.loss_config.mle_weight > 0: 327 | if self.name == 'glcic': 328 | masks = x[:num_mr, -1:, ...] 329 | sample_2nd = ( 330 | sample_2nd * masks + 331 | math.exp(self.config.log_dispersion_min) * (1.
- masks) 332 | ) 333 | normalizer = x[:num_mr, -1:, ...].sum() 334 | else: 335 | normalizer = None 336 | if isinstance(self.mle_loss, GaussianMLELoss): 337 | mle_loss = self.mle_loss( 338 | sample_1st, sample_2nd, self.mle_target(y[:num_mr]), 339 | log_dispersion=False, normalizer=normalizer) 340 | elif isinstance(self.mle_loss, LaplaceMLELoss): 341 | sample_mean = samples.mean(1) 342 | with torch.no_grad(): 343 | deviation = self.mle_target(y[:num_mr]) - sample_1st 344 | mean_target = (deviation + sample_mean).detach() 345 | mle_loss = self.mle_loss( 346 | sample_mean, sample_2nd, mean_target, 347 | log_dispersion=False, normalizer=normalizer 348 | ) 349 | else: 350 | raise RuntimeError('Invalid MLE loss') 351 | 352 | loss += self.loss_config.mle_weight * mle_loss 353 | 354 | if summarize: 355 | scalar['loss/g/mle'] = mle_loss.detach() 356 | 357 | # Moment matching 358 | if self.loss_config.mr_1st_weight or self.loss_config.mr_2nd_weight: 359 | normalizer = ( 360 | x[:num_mr, -1:, ...].sum() if self.name == 'glcic' else None 361 | ) 362 | if isinstance(self.mle_loss, GaussianMLELoss): 363 | mr_1st_loss = self._mse_loss( 364 | sample_1st, pred_1st, normalizer=normalizer 365 | ) 366 | elif isinstance(self.mle_loss, LaplaceMLELoss): 367 | sample_mean = samples.mean(1) 368 | with torch.no_grad(): 369 | mean_target = (pred_1st - sample_1st + sample_mean).detach() 370 | mr_1st_loss = self._mse_loss( 371 | sample_mean, mean_target, normalizer=normalizer 372 | ) 373 | else: 374 | raise RuntimeError('Invalid MLE loss') 375 | mr_2nd_loss = self._mse_loss( 376 | sample_2nd, pred_2nd, normalizer=normalizer 377 | ) 378 | weighted_mr_loss = \ 379 | self.loss_config.mr_1st_weight * mr_1st_loss + \ 380 | self.loss_config.mr_2nd_weight * mr_2nd_loss 381 | loss += weighted_mr_loss 382 | 383 | if summarize: 384 | scalar['loss/g/mr_1st'] = mr_1st_loss.detach() 385 | scalar['loss/g/mr_2nd'] = mr_2nd_loss.detach() 386 | scalar['loss/g/mr'] = weighted_mr_loss.detach() 387 | 388 | if summarize: 389 | scalar['loss/g/total'] = loss.detach() 390 | 391 | loss.backward() 392 | 393 | return {'scalar': scalar, 'histogram': histogram, 'image': image} 394 | 395 | 396 | def generate_noise(noise_type, noise_dim, like): 397 | if noise_type == 'gaussian': 398 | z = torch.randn([ 399 | like.size(0), noise_dim, 400 | like.size(2), like.size(3) 401 | ], device='cuda') 402 | elif noise_type == 'uniform': 403 | z = torch.rand( 404 | [like.size(0), noise_dim, 405 | like.size(2), like.size(3)], 406 | device='cuda' 407 | ) 408 | elif noise_type == 'bernoulli': 409 | z = torch.bernoulli( 410 | torch.ones([ 411 | like.size(0), noise_dim, 412 | like.size(2), like.size(3) 413 | ], device='cuda') * 0.5 414 | ) 415 | elif noise_type == 'categorical': 416 | z = torch.zeros([like.size(0), noise_dim, 417 | like.size(2), like.size(3)], device='cuda') 418 | z_idx = torch.randint( 419 | low=0, high=noise_dim, 420 | size=[like.size(0), 1, like.size(2), like.size(3)], 421 | dtype=torch.long, device='cuda' 422 | ) 423 | z.scatter_(1, z_idx, 1) 424 | else: 425 | raise ValueError('Invalid noise type %s' % noise_type) 426 | 427 | return z 428 | -------------------------------------------------------------------------------- /models/srgan.py: -------------------------------------------------------------------------------- 1 | # This SRGAN implementation is based on https://github.com/zijundeng/SRGAN 2 | import math 3 | import torch.nn.functional as F 4 | import torch 5 | from torch import nn 6 | from tensorboardX import SummaryWriter 7 | from data 
import DataIterator 8 | from constants import MODE_BASE, MODE_PRED, MODE_MR 9 | from .base import BaseModel, generate_noise 10 | 11 | 12 | class ResidualBlock(nn.Module): 13 | def __init__(self, channels, noise_dim=0, noise_type=None): 14 | super(ResidualBlock, self).__init__() 15 | self.noise_dim = noise_dim 16 | self.noise_type = noise_type 17 | self.conv1 = nn.Conv2d( 18 | channels + noise_dim, channels, kernel_size=3, padding=1 19 | ) 20 | self.bn1 = nn.BatchNorm2d(channels) 21 | self.prelu = nn.PReLU() 22 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 23 | self.bn2 = nn.BatchNorm2d(channels) 24 | 25 | def forward(self, x): 26 | x_in = x if self.noise_dim <= 0 else torch.cat([ 27 | x, generate_noise( 28 | self.noise_type, 29 | self.noise_dim, 30 | like=x 31 | ) 32 | ], 1) 33 | residual = self.conv1(x_in) 34 | residual = self.bn1(residual) 35 | residual = self.prelu(residual) 36 | residual = self.conv2(residual) 37 | residual = self.bn2(residual) 38 | return x + residual 39 | 40 | 41 | class UpsampleBLock(nn.Module): 42 | def __init__(self, in_channels, up_scale, noise_dim=0, noise_type=None): 43 | super(UpsampleBLock, self).__init__() 44 | self.noise_dim = noise_dim 45 | self.noise_type = noise_type 46 | self.conv = nn.Conv2d( 47 | in_channels + noise_dim, in_channels * up_scale ** 2, 48 | kernel_size=3, padding=1 49 | ) 50 | self.pixel_shuffle = nn.PixelShuffle(up_scale) 51 | self.prelu = nn.PReLU() 52 | 53 | def forward(self, x): 54 | x = x if self.noise_dim <= 0 else torch.cat([ 55 | x, generate_noise( 56 | self.noise_type, 57 | self.noise_dim, 58 | like=x 59 | ) 60 | ], 1) 61 | x = self.conv(x) 62 | x = self.pixel_shuffle(x) 63 | x = self.prelu(x) 64 | return x 65 | 66 | 67 | class Generator(nn.Module): 68 | def __init__(self, config, as_predictor=False): 69 | self.config = config 70 | self.as_predictor = as_predictor 71 | self.noise_dim = nd = config.model.noise_dim 72 | self.noise_type = config.model.noise_type 73 | upsample_block_num = int(math.log(4, 2)) 74 | 75 | super(Generator, self).__init__() 76 | self.block1 = nn.Sequential( 77 | nn.Conv2d(3, 64, kernel_size=9, padding=4), 78 | nn.PReLU() 79 | ) 80 | self.block2 = ResidualBlock( 81 | 64, nd[6] if nd[6] > 0 and not as_predictor else 0, self.noise_type 82 | ) 83 | self.block3 = ResidualBlock( 84 | 64, nd[5] if nd[5] > 0 and not as_predictor else 0, self.noise_type 85 | ) 86 | self.block4 = ResidualBlock( 87 | 64, nd[4] if nd[4] > 0 and not as_predictor else 0, self.noise_type 88 | ) 89 | self.block5 = ResidualBlock( 90 | 64, nd[3] if nd[3] > 0 and not as_predictor else 0, self.noise_type 91 | ) 92 | self.block6 = ResidualBlock( 93 | 64, nd[2] if nd[2] > 0 and not as_predictor else 0, self.noise_type 94 | ) 95 | self.block7 = nn.Sequential( 96 | nn.Conv2d( 97 | 64 + nd[1] if nd[1] > 0 and not as_predictor else 64, 64, 98 | kernel_size=3, padding=1 99 | ), 100 | nn.PReLU() 101 | ) 102 | self.block8 = nn.Sequential(*[ 103 | UpsampleBLock( 104 | 64, 2, 105 | nd[0] if i == 0 and nd[0] > 0 and not as_predictor else 0, 106 | self.noise_type 107 | ) for i in range(upsample_block_num) 108 | ]) 109 | self.mean_conv = nn.Conv2d(64, 3, kernel_size=9, padding=4) 110 | self.dispersion_conv = nn.Conv2d(64, 3, kernel_size=9, padding=4) 111 | 112 | def forward(self, x, num_samples=1): 113 | if num_samples > 1: 114 | x = x\ 115 | .unsqueeze(1)\ 116 | .expand(-1, num_samples, -1, -1, -1)\ 117 | .contiguous()\ 118 | .view(x.size(0) * num_samples, x.size(1), x.size(2), x.size(3)) 119 | block1 = self.block1(x) 120 | 
block2 = self.block2(block1) 121 | block3 = self.block3(block2) 122 | block4 = self.block4(block3) 123 | block5 = self.block5(block4) 124 | block6 = self.block6(block5) 125 | block6 = block6 if \ 126 | self.noise_dim[1] <= 0 or self.as_predictor else torch.cat([ 127 | block6, generate_noise( 128 | self.noise_type, 129 | self.noise_dim[1], 130 | like=block6 131 | ) 132 | ], 1) 133 | block7 = self.block7(block6) 134 | block8 = self.block8(block1 + block7) 135 | mean = self.mean_conv(block8) 136 | mean = F.tanh(mean) 137 | 138 | if self.as_predictor: 139 | log_dispersion = self.dispersion_conv(block8) 140 | return mean, log_dispersion 141 | else: 142 | return mean, None 143 | 144 | 145 | class Discriminator(nn.Module): 146 | def __init__(self, config): 147 | super(Discriminator, self).__init__() 148 | self.config = config 149 | self.net = nn.Sequential( 150 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 151 | nn.LeakyReLU(0.2), 152 | 153 | nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), 154 | nn.BatchNorm2d(64), 155 | nn.LeakyReLU(0.2), 156 | 157 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 158 | nn.BatchNorm2d(128), 159 | nn.LeakyReLU(0.2), 160 | 161 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), 162 | nn.BatchNorm2d(128), 163 | nn.LeakyReLU(0.2), 164 | 165 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 166 | nn.BatchNorm2d(256), 167 | nn.LeakyReLU(0.2), 168 | 169 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), 170 | nn.BatchNorm2d(256), 171 | nn.LeakyReLU(0.2), 172 | 173 | nn.Conv2d(256, 512, kernel_size=3, padding=1), 174 | nn.BatchNorm2d(512), 175 | nn.LeakyReLU(0.2), 176 | 177 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 178 | nn.BatchNorm2d(512), 179 | nn.LeakyReLU(0.2), 180 | 181 | nn.AdaptiveAvgPool2d(1), 182 | nn.Conv2d(512, 1024, kernel_size=1), 183 | nn.LeakyReLU(0.2), 184 | nn.Conv2d(1024, 1, kernel_size=1) 185 | ) 186 | 187 | def forward(self, x): 188 | return self.net(x) 189 | 190 | 191 | class SRGAN(BaseModel): 192 | name = 'srgan' 193 | 194 | def __init__(self, config): 195 | super().__init__( 196 | config, 197 | Discriminator(config) 198 | if config.mode in (MODE_BASE, MODE_MR) else None, 199 | Generator(config) 200 | if config.mode in (MODE_BASE, MODE_MR) else None, 201 | Generator(config, as_predictor=True) 202 | if config.mode in (MODE_PRED, MODE_MR) else None, 203 | ) 204 | 205 | def optimize_d(self, x, y, step, summarize=False): 206 | assert self.mode in (MODE_BASE, MODE_MR) 207 | if self.loss_config.gan_weight <= 0: 208 | return {} 209 | image = {} 210 | # gan loss (BASELINE || MR) 211 | fake_y = self.net_g(x)[0] 212 | fake_v = self.net_d(fake_y) 213 | real_v = self.net_d(y) 214 | real_loss = self.gan_loss(real_v, True) 215 | fake_loss = self.gan_loss(fake_v, False) 216 | gan_loss = (real_loss + fake_loss) / 2 217 | # backprop & minimize 218 | self.optim_d.zero_grad() 219 | loss = gan_loss * self.loss_config.gan_weight 220 | loss.backward() 221 | self.clip_grad(self.optim_d, self.config.d_optimizer.clip_grad) 222 | self.optim_d.step() 223 | # image summaries 224 | if summarize: 225 | for i in range(y.size(0)): 226 | real_image_id = 'd_inputs/real/{}'.format(i) 227 | fake_image_id = 'd_inputs/fake/{}'.format(i) 228 | image[real_image_id] = (y[i] + 1) / 2 229 | image[fake_image_id] = (fake_y[i] + 1) / 2 230 | return { 231 | 'scalar': { 232 | 'loss/d/total': loss, 233 | 'loss/d/gan_real': real_loss, 234 | 'loss/d/gan_fake': fake_loss, 235 | 'loss/d/gan_total': gan_loss, 236 | }, 237 | 'image': image 238 | } 239 | 240 | def 
optimize_g(self, x, y, step, summarize=False): 241 | assert self.mode in (MODE_BASE, MODE_MR) 242 | # prepare some accumulators 243 | scalar = {'loss/g/total': 0.} 244 | histogram = {} 245 | image = {} 246 | loss = 0. 247 | self.optim_g.zero_grad() 248 | fake_y = self.net_g(x)[0] 249 | # GAN loss (BASELINE) 250 | if self.mode == MODE_BASE and self.loss_config.gan_weight > 0: 251 | fake_v = self.net_d(fake_y) 252 | gan_loss = self.gan_loss(fake_v, True) 253 | weighted_gan_loss = self.loss_config.gan_weight * gan_loss 254 | loss += weighted_gan_loss 255 | if summarize: 256 | scalar['loss/g/gan'] = gan_loss.detach() 257 | scalar['loss/g/total'] += weighted_gan_loss.detach() 258 | # MSE loss (BASELINE) 259 | if self.mode == MODE_BASE and self.loss_config.recon_weight > 0: 260 | mse_loss = self._mse_loss(fake_y, y) 261 | weighted_mse_loss = self.loss_config.recon_weight * mse_loss 262 | loss += weighted_mse_loss 263 | if summarize: 264 | scalar['loss/g/mse'] = mse_loss.detach() 265 | scalar['loss/g/total'] += weighted_mse_loss.detach() 266 | # backprop before accumulating mr gradients 267 | if isinstance(loss, torch.Tensor): 268 | loss.backward() 269 | # MR loss (MR) 270 | if self.mode == MODE_MR: 271 | mr_summaries = self.accumulate_mr_grad(x, y, summarize) 272 | mr_scalar = mr_summaries['scalar'] 273 | mr_histogram = mr_summaries['histogram'] 274 | mr_image = mr_summaries['image'] 275 | torch.cuda.empty_cache() 276 | if summarize: 277 | scalar['loss/g/total'] += mr_scalar['loss/g/total'] 278 | del mr_scalar['loss/g/total'] 279 | scalar.update(mr_scalar) 280 | histogram.update(mr_histogram) 281 | for i in range(min(16, self.config.num_mr)): 282 | image_id = 'train_samples/%d' % i 283 | if image_id in image: 284 | image[image_id] = torch.cat([ 285 | image[image_id], mr_image[image_id] 286 | ], 2) 287 | else: 288 | image[image_id] = mr_image[image_id] 289 | image[image_id] = (image[image_id] + 1) / 2 290 | # Optimize the network 291 | self.clip_grad(self.optim_g, self.config.g_optimizer.clip_grad) 292 | self.optim_g.step() 293 | return {'scalar': scalar, 'histogram': histogram, 'image': image} 294 | 295 | def optimize_p(self, x, y, step, summarize=False): 296 | # MLE loss (PRED) 297 | assert self.mode == MODE_PRED 298 | # MLE loss (PRED) 299 | mean, log_var = self.net_p(x) 300 | loss = self.mle_loss(mean, log_var, y) 301 | # backprop & minimize 302 | self.optim_p.zero_grad() 303 | loss.backward() 304 | self.clip_grad(self.optim_p, self.config.p_optimizer.clip_grad) 305 | self.optim_p.step() 306 | return { 307 | 'scalar': { 308 | 'loss/p/mle': loss 309 | } 310 | } 311 | 312 | def summarize(self, writer: SummaryWriter, step, 313 | train_sample_iterator: DataIterator, 314 | val_sample_iterator: DataIterator): 315 | with torch.no_grad(): 316 | for iter_type, iterator in ( 317 | ('train', train_sample_iterator), 318 | ('val', val_sample_iterator) 319 | ): 320 | mle_loss = 0. 
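# Per-iterator accumulators: sample second moments (BASE mode) and
# predicted log-variances (PRED mode) are collected in the lists below,
# then stacked and reduced to scalar/histogram summaries after the loop.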
321 | sample_2nds = [] 322 | pred_log_2nds = [] 323 | for i in range(len(iterator)): 324 | x, y = next(iterator) 325 | h, w = y.size()[2:4] 326 | if self.mode == MODE_BASE: 327 | self.net_g.eval() 328 | num_samples = 12 329 | g_x = x.expand(num_samples, -1, -1, -1) 330 | samples = self.net_g(g_x)[0] 331 | samples_reshaped = samples.view( 332 | 1, num_samples, *list(samples.size()[1:]) 333 | ) 334 | sample_1st, sample_2nd = self.sample_statistics( 335 | samples_reshaped 336 | ) 337 | zeros = torch.zeros_like(y[:, :3, ...]) 338 | self.net_g.train() 339 | collage = torch.cat([ 340 | torch.cat([ 341 | (F.interpolate(x[:, :3, ...], [h, w]) + 1) / 2, 342 | (y + 1) / 2, zeros, zeros, zeros, zeros 343 | ], dim=3), 344 | torch.cat([ 345 | (fy.unsqueeze(0) + 1) / 2 for fy in 346 | torch.unbind(samples[:num_samples // 2]) 347 | ], dim=3), 348 | torch.cat([ 349 | (fy.unsqueeze(0) + 1) / 2 for fy in 350 | torch.unbind(samples[num_samples // 2:]) 351 | ], dim=3) 352 | ], dim=2) 353 | sample_2nds.append(sample_2nd) 354 | writer.add_image( 355 | 'g/{}/{}'.format(iter_type, i), collage, step 356 | ) 357 | elif self.mode == MODE_MR: 358 | pred_1st, pred_log_2nd = self.net_p(x) 359 | self.net_g.eval() 360 | num_samples = 12 361 | g_x = x.expand(num_samples, -1, -1, -1) 362 | samples = self.net_g(g_x)[0] 363 | self.net_g.train() 364 | sample_1st, sample_2nd = self.sample_statistics(samples) 365 | sample_log_2nd = (sample_2nd + 1e-4).log() 366 | log_2nds = torch.cat([pred_log_2nd, sample_log_2nd], 3) 367 | uncertainty = self._build_uncertainty_image(log_2nds) 368 | collage = [ 369 | torch.cat([ 370 | (F.interpolate(x[:, :3, ...], [h, w]) + 1) / 2, 371 | (y + 1) / 2, 372 | (pred_1st + 1) / 2, 373 | (sample_1st + 1) / 2, 374 | (uncertainty + 1) / 2, 375 | ], dim=3), 376 | torch.cat([ 377 | (fy.unsqueeze(0) + 1) / 2 for fy in 378 | torch.unbind(samples[:num_samples // 2]) 379 | ], dim=3), 380 | torch.cat([ 381 | (fy.unsqueeze(0) + 1) / 2 for fy in 382 | torch.unbind(samples[num_samples // 2:]) 383 | ], dim=3) 384 | ] 385 | collage = torch.cat(collage, dim=2) 386 | writer.add_image( 387 | 'collage/{}/{}'.format(iter_type, i), collage, step 388 | ) 389 | elif self.mode == MODE_PRED: 390 | self.net_p.eval() 391 | pred_1st, pred_log_2nd = self.net_p(x) 392 | self.net_p.train() 393 | uncertainty = self._build_uncertainty_image( 394 | pred_log_2nd 395 | ) 396 | collage = torch.cat([ 397 | torch.cat([ 398 | (F.interpolate(x[:, :3, ...], [h, w]) + 1) / 2, 399 | (y + 1) / 2 400 | ], dim=3), 401 | torch.cat([ 402 | (pred_1st + 1) / 2, 403 | (uncertainty + 1) / 2, 404 | ], dim=3) 405 | ], dim=2) 406 | pred_log_2nds.append(pred_log_2nd) 407 | writer.add_image( 408 | 'p/{}/{}'.format(iter_type, i), collage, step 409 | ) 410 | 411 | # Validation loss 412 | if iter_type == 'val': 413 | mle_loss += self.mle_loss(pred_1st, pred_log_2nd, y) 414 | 415 | if self.mode == MODE_PRED and iter_type == 'val': 416 | pred_log_2nd = torch.stack(pred_log_2nds) 417 | pred_2nd = torch.exp(pred_log_2nd) 418 | writer.add_scalar( 419 | 'variance/pred', pred_2nd.mean(), step 420 | ) 421 | writer.add_histogram( 422 | 'variance_hist/pred', pred_2nd, step 423 | ) 424 | if self.mode == MODE_BASE and iter_type == 'val': 425 | sample_2nd = torch.stack(sample_2nds) 426 | writer.add_scalar( 427 | 'variance/sample', sample_2nd.mean(), step 428 | ) 429 | writer.add_histogram( 430 | 'variance_hist/sample', sample_2nd, step 431 | ) 432 | if self.mode == MODE_PRED and iter_type == 'val': 433 | writer.add_scalar( 434 | 'loss/p/mle_val', mle_loss / len(iterator), 
step 435 | ) 436 | 437 | def build_d_input(self, x, samples): 438 | return samples 439 | 440 | def mle_target(self, y): 441 | return y 442 | -------------------------------------------------------------------------------- /models/pix2pix.py: -------------------------------------------------------------------------------- 1 | # This Pix2Pix implementation is based on 2 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | from torch.optim import lr_scheduler 8 | import functools 9 | from tensorboardX import SummaryWriter 10 | from data import DataIterator 11 | 12 | from models.base import BaseModel, generate_noise 13 | from constants import MODE_BASE, MODE_PRED, MODE_MR 14 | 15 | 16 | class Pix2Pix(BaseModel): 17 | name = 'pix2pix' 18 | 19 | def __init__(self, config): 20 | # Define D & G 21 | if config.mode in [MODE_BASE, MODE_MR]: 22 | net_d = define_d( 23 | config.model.in_channels + config.model.out_channels, 24 | ndf=config.model.num_features, 25 | which_model_net_d='basic', 26 | use_sigmoid=False 27 | ) 28 | net_g = Pix2PixUnet(config.mode, config) 29 | else: 30 | net_d, net_g = None, None 31 | 32 | # Define P 33 | if config.mode in [MODE_PRED, MODE_MR] and config.model.predictor: 34 | net_p = Pix2PixUnet(MODE_PRED, config) 35 | else: 36 | net_p = None 37 | 38 | super().__init__(config, net_d, net_g, net_p) 39 | 40 | def optimize_p(self, x, y, step, summarize=False): 41 | self.optim_p.zero_grad() 42 | scalar_summary = {} 43 | 44 | # Calculate losses 45 | mean, log_var = self.net_p(x) 46 | loss = self.mle_loss(mean, log_var, y) 47 | loss.backward() 48 | 49 | self.clip_grad(self.optim_p, self.config.p_optimizer.clip_grad) 50 | self.optim_p.step() 51 | 52 | # Summarize 53 | if summarize: 54 | scalar_summary['loss/p/mle'] = loss.detach() 55 | summary = {'scalar': scalar_summary} 56 | return summary 57 | 58 | def optimize_d(self, x, y, step, summarize=False): 59 | torch.cuda.empty_cache() 60 | if self.loss_config.gan_weight == 0.: 61 | return {} 62 | 63 | self.optim_d.zero_grad() 64 | with torch.no_grad(): 65 | fake_y = self(x) 66 | fake_v = self.net_d(torch.cat([x, fake_y], dim=1)) 67 | real_v = self.net_d(torch.cat([x, y], dim=1)) 68 | real_loss = self.gan_loss(real_v, True) 69 | fake_loss = self.gan_loss(fake_v, False) 70 | loss = (real_loss + fake_loss) * 0.5 * self.loss_config.gan_weight 71 | 72 | loss.backward() 73 | self.clip_grad(self.optim_d, self.config.d_optimizer.clip_grad) 74 | self.optim_d.step() 75 | 76 | if summarize: 77 | scalar_summary = { 78 | 'loss/d/gan_real': real_loss, 79 | 'loss/d/gan_fake': fake_loss, 80 | } 81 | return {'scalar': scalar_summary} 82 | 83 | def mle_target(self, y): 84 | return y 85 | 86 | def build_d_input(self, x, samples): 87 | num_mr = self.config.num_mr 88 | num_samples = self.config.num_mr_samples 89 | 90 | x_dup = x[:num_mr].unsqueeze(1) 91 | x_dup = x_dup.expand(-1, num_samples, -1, -1, -1) 92 | x_dup = x_dup.contiguous().view( 93 | num_mr * num_samples, *list(x_dup.size()[2:])) 94 | return torch.cat([x_dup, samples], 1) 95 | 96 | def optimize_g(self, x, y, step, summarize=False): 97 | torch.cuda.empty_cache() 98 | scalar = {'loss/g/total': 0.} 99 | histogram = {} 100 | image = {} 101 | loss = 0. 
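# Gradients are accumulated in two phases here: the GAN/L1 terms computed
# on fake_y (when enabled) are back-propagated first, then
# accumulate_mr_grad() adds the moment-matching gradients, and a single
# optimizer step at the end applies both.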
102 | self.optim_g.zero_grad() 103 | 104 | if self.mode == MODE_BASE or self.loss_config.recon_weight > 0: 105 | fake_y, _ = self.net_g(x) 106 | 107 | # GAN loss 108 | if self.loss_config.gan_weight > 0: 109 | fake_v = self.net_d(torch.cat([x, fake_y], dim=1)) 110 | gan_loss = self.gan_loss(fake_v, True) 111 | weighted_gan_loss = self.loss_config.gan_weight * gan_loss 112 | loss += weighted_gan_loss 113 | if summarize: 114 | scalar['loss/g/gan'] = gan_loss.detach() 115 | scalar['loss/g/total'] += weighted_gan_loss.detach() 116 | 117 | # L1 loss 118 | if self.loss_config.recon_weight > 0: 119 | l1_loss = self._l1_loss(fake_y, y) 120 | weighted_l1_loss = self.loss_config.recon_weight * l1_loss 121 | loss += weighted_l1_loss 122 | if summarize: 123 | scalar['loss/g/l1'] = l1_loss.detach() 124 | scalar['loss/g/total'] += weighted_l1_loss.detach() 125 | 126 | # Back-prop 127 | loss.backward() 128 | 129 | # Moment matching loss 130 | if self.mode == MODE_MR and ( 131 | self.loss_config.mr_1st_weight > 0 132 | or self.loss_config.mr_2nd_weight > 0 133 | or self.loss_config.mle_weight > 0 134 | ): 135 | mr_summaries = self.accumulate_mr_grad(x, y, summarize) 136 | mr_scalar = mr_summaries['scalar'] 137 | mr_histogram = mr_summaries['histogram'] 138 | mr_image = mr_summaries['image'] 139 | torch.cuda.empty_cache() 140 | 141 | if summarize: 142 | scalar['loss/g/total'] += mr_scalar['loss/g/total'] 143 | del mr_scalar['loss/g/total'] 144 | scalar.update(mr_scalar) 145 | histogram.update(mr_histogram) 146 | for i in range(min(16, self.config.num_mr)): 147 | image_id = 'train_samples/%d' % i 148 | if image_id in image: 149 | image[image_id] = torch.cat([ 150 | image[image_id], mr_image[image_id] 151 | ], 2) 152 | else: 153 | image[image_id] = mr_image[image_id] 154 | image[image_id] = image[image_id] * 0.5 + 0.5 155 | 156 | # Optimize 157 | self.clip_grad(self.optim_g, self.config.g_optimizer.clip_grad) 158 | self.optim_g.step() 159 | 160 | return {'scalar': scalar, 'histogram': histogram, 'image': image} 161 | 162 | def summarize(self, writer: SummaryWriter, step, 163 | train_sample_iterator: DataIterator, 164 | val_sample_iterator: DataIterator): 165 | for iter_type, iterator in ( 166 | ('train', train_sample_iterator), ('val', val_sample_iterator)): 167 | mle_loss = 0. 
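# The uncertainty panels rendered below come from _build_uncertainty_image,
# which rescales log-variance from [log_dispersion_min, log_dispersion_max]
# into [-1, 1] and stamps an eight-step legend strip along the top-left
# edge of the image.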
168 | for i in range(len(iterator)): 169 | x, y = next(iterator) 170 | if self.mode == MODE_BASE: 171 | with torch.no_grad(): 172 | self.net_g.eval() 173 | fake_y = self.net_g(x)[0] 174 | self.net_g.train() 175 | x_plain = self.undo_norm(x) 176 | collage = torch.cat([x_plain, y, fake_y], dim=3) 177 | collage = 0.5 * collage + 0.5 178 | writer.add_image('g/%s/%d' % (iter_type, i), collage, step) 179 | 180 | elif self.mode == MODE_MR: 181 | with torch.no_grad(): 182 | pred_1st, pred_log_2nd = self.net_p(x) 183 | self.net_g.eval() 184 | num_samples = 12 185 | g_x = x.expand(num_samples, -1, -1, -1) 186 | samples = self.net_g(g_x)[0] 187 | self.net_g.train() 188 | 189 | # Build collage 190 | sample_1st, sample_2nd = self.sample_statistics(samples) 191 | log_2nds = torch.cat([ 192 | pred_log_2nd, (sample_2nd + 1e-10).log() 193 | ], 3) 194 | uncertainty = self._build_uncertainty_image(log_2nds) 195 | x_plain = self.undo_norm(x) 196 | collage = [ 197 | torch.cat([ 198 | x_plain, y, pred_1st, sample_1st, uncertainty, 199 | ], dim=3), 200 | torch.cat([ 201 | fy.unsqueeze(0) 202 | for fy in torch.unbind(samples[:num_samples // 2]) 203 | ], dim=3), 204 | torch.cat([ 205 | fy.unsqueeze(0) 206 | for fy in torch.unbind(samples[num_samples // 2:]) 207 | ], dim=3) 208 | ] 209 | collage = torch.cat(collage, dim=2) 210 | collage = collage * 0.5 + 0.5 211 | writer.add_image('collage/%s/%d' % (iter_type, i), collage, 212 | step) 213 | 214 | elif self.mode == MODE_PRED: 215 | with torch.no_grad(): 216 | self.net_p.eval() 217 | pred_1st, pred_log_2nd = self.net_p(x) 218 | self.net_p.train() 219 | 220 | uncertainty = self._build_uncertainty_image( 221 | pred_log_2nd) 222 | 223 | x_plain = self.undo_norm(x) 224 | collage = torch.cat([ 225 | torch.cat([x_plain, y], dim=3), 226 | torch.cat([pred_1st, uncertainty], dim=3) 227 | ], dim=2) 228 | collage = collage * 0.5 + 0.5 229 | writer.add_image('p/%s/%d' % (iter_type, i), collage, step) 230 | 231 | # Validation loss 232 | if iter_type == 'val': 233 | mle_loss += self.mle_loss(pred_1st, pred_log_2nd, y) 234 | 235 | if self.mode == MODE_PRED and iter_type == 'val': 236 | writer.add_scalar('loss/p/mle_val', mle_loss / len(iterator), 237 | step) 238 | 239 | 240 | def get_norm_layer(norm_type='instance'): 241 | if norm_type == 'batch': 242 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 243 | elif norm_type == 'instance': 244 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, 245 | track_running_stats=True) 246 | elif norm_type == 'none': 247 | norm_layer = None 248 | else: 249 | raise NotImplementedError( 250 | 'normalization layer [%s] is not found' % norm_type) 251 | return norm_layer 252 | 253 | 254 | def get_scheduler(optimizer, opt): 255 | if opt.lr_policy == 'lambda': 256 | def lambda_rule(epoch): 257 | lr_l = 1.0 - max( 258 | 0, epoch + 1 + opt.epoch_count - opt.niter) / float( 259 | opt.niter_decay + 1) 260 | return lr_l 261 | 262 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 263 | elif opt.lr_policy == 'step': 264 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, 265 | gamma=0.1) 266 | elif opt.lr_policy == 'plateau': 267 | scheduler = lr_scheduler.ReduceLROnPlateau( 268 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 269 | else: 270 | raise NotImplementedError( 271 | 'learning rate policy [%s] is not implemented' % opt.lr_policy) 272 | return scheduler 273 | 274 | 275 | def init_weights(net, init_type='normal', gain=0.02): 276 | def init_func(m): 277 | classname =
m.__class__.__name__ 278 | if hasattr(m, 'weight') and ( 279 | classname.find('Conv') != -1 or classname.find('Linear') != -1): 280 | if init_type == 'normal': 281 | init.normal_(m.weight.data, 0.0, gain) 282 | elif init_type == 'xavier': 283 | init.xavier_normal_(m.weight.data, gain=gain) 284 | elif init_type == 'kaiming': 285 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 286 | elif init_type == 'orthogonal': 287 | init.orthogonal_(m.weight.data, gain=gain) 288 | else: 289 | raise NotImplementedError( 290 | 'initialization method [%s] is not implemented' % init_type) 291 | if hasattr(m, 'bias') and m.bias is not None: 292 | init.constant_(m.bias.data, 0.0) 293 | elif classname.find('BatchNorm2d') != -1: 294 | init.normal_(m.weight.data, 1.0, gain) 295 | init.constant_(m.bias.data, 0.0) 296 | 297 | print('initialize network with %s' % init_type) 298 | net.apply(init_func) 299 | 300 | 301 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=()): 302 | if len(gpu_ids) > 0: 303 | assert (torch.cuda.is_available()) 304 | net.to(gpu_ids[0]) 305 | net = torch.nn.DataParallel(net, gpu_ids) 306 | init_weights(net, init_type, gain=init_gain) 307 | return net 308 | 309 | 310 | def define_d(input_nc, ndf, which_model_net_d, 311 | n_layers_d=3, norm='batch', use_sigmoid=False, init_type='normal', 312 | init_gain=0.02, gpu_ids=()): 313 | norm_layer = get_norm_layer(norm_type=norm) 314 | 315 | if which_model_net_d == 'basic': 316 | net_d = NLayerDiscriminator( 317 | input_nc, ndf, n_layers=3, norm_layer=norm_layer, 318 | use_sigmoid=use_sigmoid) 319 | elif which_model_net_d == 'n_layers': 320 | net_d = NLayerDiscriminator( 321 | input_nc, ndf, n_layers_d, norm_layer=norm_layer, 322 | use_sigmoid=use_sigmoid) 323 | elif which_model_net_d == 'pixel': 324 | net_d = PixelDiscriminator( 325 | input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 326 | else: 327 | raise NotImplementedError( 328 | 'Discriminator model name [%s] is not recognized' % 329 | which_model_net_d) 330 | return init_net(net_d, init_type, init_gain, gpu_ids) 331 | 332 | 333 | class Pix2PixUnet(nn.Module): 334 | def __init__(self, mode, config): 335 | super().__init__() 336 | self.mode = mode 337 | 338 | # Config shortcuts 339 | self.in_ch = config.model.in_channels 340 | self.out_ch = config.model.out_channels 341 | self.num_downs = config.model.num_downs 342 | if self.mode == MODE_PRED: 343 | self.num_features = config.model.pred_features 344 | else: 345 | self.num_features = config.model.num_features 346 | self.noise_type = config.model.noise_type 347 | self.noise_dim = config.model.noise_dim 348 | 349 | # Setup normalization layer 350 | if config.model.norm == 'batch': 351 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 352 | elif config.model.norm == 'instance': 353 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True) 354 | else: 355 | raise ValueError('Invalid norm layer type {}' 356 | .format(config.model.norm)) 357 | 358 | down_convs = [] 359 | up_convs = [] 360 | down_norms = [] 361 | up_norms = [] 362 | 363 | # Create down conv 364 | prev_ch = self.in_ch 365 | next_ch = self.num_features 366 | for i in range(self.num_downs): 367 | down_conv = nn.Conv2d(prev_ch, next_ch, kernel_size=4, 368 | stride=2, padding=1, bias=i == 0) 369 | down_convs.append(down_conv) 370 | if i != 0: 371 | down_norm = norm_layer(next_ch) 372 | down_norms.append(down_norm) 373 | else: 374 | down_norms.append(None) 375 | 376 | prev_ch = next_ch 377 | next_ch = min(2 * next_ch, self.num_features * 
8) 378 | 379 | self.down_convs = nn.ModuleList(down_convs) 380 | self.down_norms = nn.ModuleList(down_norms) 381 | 382 | # Create up conv in reverse order 383 | prev_ch = self.out_ch 384 | next_ch = self.num_features 385 | self.mr_ch, self.up_ch = [], [] 386 | for i in range(self.num_downs): 387 | if self.mode == MODE_MR: 388 | noise_dim = self.noise_dim[i + 1] 389 | else: 390 | noise_dim = 0 391 | if i == self.num_downs - 1: 392 | ch = next_ch 393 | else: 394 | ch = 2 * next_ch # 2x for skip connection 395 | 396 | # Up convolution 397 | self.up_ch.append(prev_ch) 398 | up_conv = nn.ConvTranspose2d( 399 | ch + noise_dim, self.up_ch[-1], 400 | kernel_size=4, stride=2, padding=1, bias=i == 0 or False) 401 | up_convs.append(up_conv) 402 | if i != 0: 403 | up_norm = norm_layer(self.up_ch[-1]) 404 | up_norms.append(up_norm) 405 | else: 406 | up_norms.append(None) 407 | 408 | prev_ch = next_ch 409 | next_ch = min(2 * next_ch, self.num_features * 8) 410 | 411 | self.up_convs = nn.ModuleList(up_convs) 412 | self.up_norms = nn.ModuleList(up_norms) 413 | 414 | # Variance prediction 415 | if self.mode == MODE_PRED: 416 | self.dispersion_conv = nn.ConvTranspose2d( 417 | self.num_features * 2, self.out_ch, 418 | kernel_size=4, stride=2, padding=1 419 | ) 420 | else: 421 | self.dispersion_conv = None 422 | 423 | def forward(self, x, num_samples=1): 424 | # Down conv 425 | feat = [x] 426 | h = x 427 | for i, down_conv in enumerate(self.down_convs): 428 | f = down_conv(h) 429 | h = F.leaky_relu(f, 0.2) 430 | if self.down_norms[i] is not None: 431 | h = self.down_norms[i](h) 432 | feat.append(h) 433 | else: 434 | feat.append(h) 435 | 436 | # Duplicate num_samples times 437 | if num_samples > 1: 438 | feat = [ 439 | f.unsqueeze(1).expand( 440 | -1, num_samples, -1, -1, -1 441 | ).contiguous().view( 442 | f.size(0) * num_samples, f.size(1), f.size(2), f.size(3)) 443 | for f in feat 444 | ] 445 | h = h.unsqueeze(1).expand(-1, num_samples, -1, -1, -1).contiguous() 446 | h = h.view(h.size(0) * h.size(1), h.size(2), h.size(3), h.size(4)) 447 | 448 | # Up conv & MR conv 449 | dispersion = None 450 | for i in range(len(self.up_convs))[::-1]: 451 | # Skip connection 452 | if i < len(self.up_convs) - 1: 453 | h = torch.cat([h, feat[i + 1]], dim=1) 454 | 455 | # Predict dispersion 456 | if self.mode == MODE_PRED and i == 0: 457 | dispersion = self.dispersion_conv(h) 458 | 459 | # Mix noise 460 | if self.mode == MODE_MR and self.noise_dim[i + 1] > 0: 461 | z = generate_noise(self.noise_type, self.noise_dim[i + 1], 462 | like=h) 463 | h = torch.cat([h, z], dim=1) 464 | 465 | # Up convolution 466 | h = self.up_convs[i](h) 467 | h = torch.tanh(h) if i == 0 else F.relu(h) 468 | if self.up_norms[i] is not None: 469 | h = self.up_norms[i](h) 470 | 471 | return h, dispersion 472 | 473 | 474 | # Defines the PatchGAN discriminator with the specified arguments. 
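# With the defaults used here (kw=4, three stride-2 convolutions followed
# by two stride-1 convolutions for n_layers=3), this is the standard 70x70
# PatchGAN: tracing the receptive field backwards through the stack with
# r_in = (r_out - 1) * stride + kernel gives
#   1 -> 4 -> 7 -> 16 -> 34 -> 70
# so each output logit judges a 70x70 patch of the input.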
475 | class NLayerDiscriminator(nn.Module): 476 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 477 | use_sigmoid=False): 478 | super(NLayerDiscriminator, self).__init__() 479 | if type(norm_layer) == functools.partial: 480 | use_bias = norm_layer.func == nn.InstanceNorm2d 481 | else: 482 | use_bias = norm_layer == nn.InstanceNorm2d 483 | 484 | kw = 4 485 | padw = 1 486 | sequence = [ 487 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 488 | nn.LeakyReLU(0.2, True) 489 | ] 490 | 491 | nf_mult = 1 492 | for n in range(1, n_layers): 493 | nf_mult_prev = nf_mult 494 | nf_mult = min(2 ** n, 8) 495 | sequence += [ 496 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 497 | kernel_size=kw, stride=2, padding=padw, 498 | bias=use_bias), 499 | norm_layer(ndf * nf_mult), 500 | nn.LeakyReLU(0.2, True) 501 | ] 502 | 503 | nf_mult_prev = nf_mult 504 | nf_mult = min(2 ** n_layers, 8) 505 | sequence += [ 506 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 507 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 508 | norm_layer(ndf * nf_mult), 509 | nn.LeakyReLU(0.2, True) 510 | ] 511 | 512 | sequence += [ 513 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 514 | 515 | if use_sigmoid: 516 | sequence += [nn.Sigmoid()] 517 | 518 | self.model = nn.Sequential(*sequence) 519 | 520 | def forward(self, x): 521 | return self.model(x) 522 | 523 | 524 | class PixelDiscriminator(nn.Module): 525 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, 526 | use_sigmoid=False): 527 | super(PixelDiscriminator, self).__init__() 528 | if type(norm_layer) == functools.partial: 529 | use_bias = norm_layer.func == nn.InstanceNorm2d 530 | else: 531 | use_bias = norm_layer == nn.InstanceNorm2d 532 | 533 | self.net = [ 534 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 535 | nn.LeakyReLU(0.2, True), 536 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, 537 | bias=use_bias), 538 | norm_layer(ndf * 2), 539 | nn.LeakyReLU(0.2, True), 540 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, 541 | bias=use_bias)] 542 | 543 | if use_sigmoid: 544 | self.net.append(nn.Sigmoid()) 545 | 546 | self.net = nn.Sequential(*self.net) 547 | 548 | def forward(self, x): 549 | return self.net(x) 550 | -------------------------------------------------------------------------------- /models/glcic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from tensorboardX import SummaryWriter 5 | from constants import MODE_BASE, MODE_PRED, MODE_MR 6 | from data import DataIterator 7 | from .base import BaseModel, generate_noise 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, config, as_predictor=False): 12 | super().__init__() 13 | self.config = config 14 | self.as_predictor = as_predictor 15 | self.noise_dim = nd = config.model.noise_dim 16 | self.noise_type = config.model.noise_type 17 | 18 | # heads 19 | self.mean_conv = nn.Conv2d( 20 | 32 + nd[0] if nd[0] > 0 and not as_predictor else 32, 3, 21 | 3, 1, 1 22 | ) 23 | self.dispersion_conv = \ 24 | nn.Conv2d(32, 3, 3, 1, 1) if as_predictor else None 25 | 26 | # hidden layers 27 | self.layers = nn.ModuleList([ 28 | # conv1 29 | nn.Sequential( 30 | nn.Conv2d(4, 64, 5, 1, 2), 31 | nn.BatchNorm2d(64), 32 | nn.ReLU(inplace=True), 33 | ), 34 | # conv2 35 | nn.Sequential( 36 | nn.Conv2d(64, 128, 3, 2, 1), 37 | nn.BatchNorm2d(128), 38 | nn.ReLU(inplace=True), 
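# The remaining blocks follow the GLCIC completion network: two stride-2
# convolutions (conv2, conv4) downsample by 4x, four dilated convolutions
# (dilation 2, 4, 8, 16) widen the receptive field at constant resolution,
# and two transposed convolutions restore the original size before the
# 3-channel mean/dispersion heads. In MR mode, noise channels are
# concatenated between blocks; note that noise_dim is indexed from the
# output end of the stack (nd[0] feeds the mean head).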
39 | ), 40 | # conv3 41 | nn.Sequential( 42 | nn.Conv2d(128, 128, 3, 1, 1), 43 | nn.BatchNorm2d(128), 44 | nn.ReLU(inplace=True), 45 | ), 46 | # conv4 47 | nn.Sequential( 48 | nn.Conv2d(128, 256, 3, 2, 1), 49 | nn.BatchNorm2d(256), 50 | nn.ReLU(inplace=True), 51 | ), 52 | # conv5 53 | nn.Sequential( 54 | nn.Conv2d(256, 256, 3, 1, 1), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU(inplace=True), 57 | ), 58 | # conv6 59 | nn.Sequential( 60 | nn.Conv2d( 61 | 256 + nd[11] if nd[11] > 0 and not as_predictor else 256, 62 | 256, 3, 1, 1 63 | ), 64 | nn.BatchNorm2d(256), 65 | nn.ReLU(inplace=True), 66 | ), 67 | # dil1 68 | nn.Sequential( 69 | nn.Conv2d( 70 | 256 + nd[10] if nd[10] > 0 and not as_predictor else 256, 71 | 256, 3, 1, 2, dilation=2 72 | ), 73 | nn.BatchNorm2d(256), 74 | nn.ReLU(inplace=True), 75 | ), 76 | # dil2 77 | nn.Sequential( 78 | nn.Conv2d( 79 | 256 + nd[9] if nd[9] > 0 and not as_predictor else 256, 80 | 256, 3, 1, 4, dilation=4 81 | ), 82 | nn.BatchNorm2d(256), 83 | nn.ReLU(inplace=True), 84 | ), 85 | # dil3 86 | nn.Sequential( 87 | nn.Conv2d( 88 | 256 + nd[8] if nd[8] > 0 and not as_predictor else 256, 89 | 256, 3, 1, 8, dilation=8 90 | ), 91 | nn.BatchNorm2d(256), 92 | nn.ReLU(inplace=True), 93 | ), 94 | # dil4 95 | nn.Sequential( 96 | nn.Conv2d( 97 | 256 + nd[7] if nd[7] > 0 and not as_predictor else 256, 98 | 256, 3, 1, 16, dilation=16 99 | ), 100 | nn.BatchNorm2d(256), 101 | nn.ReLU(inplace=True), 102 | ), 103 | # conv7 104 | nn.Sequential( 105 | nn.Conv2d( 106 | 256 + nd[6] if nd[6] > 0 and not as_predictor else 256, 107 | 256, 3, 1, 1 108 | ), 109 | nn.BatchNorm2d(256), 110 | nn.ReLU(inplace=True), 111 | ), 112 | # conv8 113 | nn.Sequential( 114 | nn.Conv2d( 115 | 256 + nd[5] if nd[5] > 0 and not as_predictor else 256, 116 | 256, 3, 1, 1 117 | ), 118 | nn.BatchNorm2d(256), 119 | nn.ReLU(inplace=True), 120 | ), 121 | # deconv1 122 | nn.Sequential( 123 | nn.ConvTranspose2d( 124 | 256 + nd[4] if nd[4] > 0 and not as_predictor else 256, 125 | 128, 4, 2, 1 126 | ), 127 | nn.BatchNorm2d(128), 128 | nn.ReLU(inplace=True), 129 | ), 130 | # conv9 131 | nn.Sequential( 132 | nn.Conv2d( 133 | 128 + nd[3] if nd[3] > 0 and not as_predictor else 128, 134 | 128, 3, 1, 1 135 | ), 136 | nn.BatchNorm2d(128), 137 | nn.ReLU(inplace=True), 138 | ), 139 | # deconv2 140 | nn.Sequential( 141 | nn.ConvTranspose2d( 142 | 128 + nd[2] if nd[2] > 0 and not as_predictor else 128, 143 | 64, 4, 2, 1 144 | ), 145 | nn.BatchNorm2d(64), 146 | nn.ReLU(inplace=True), 147 | ), 148 | # conv10 149 | nn.Sequential( 150 | nn.Conv2d( 151 | 64 + nd[1] if nd[1] > 0 and not as_predictor else 64, 152 | 32, 3, 1, 1 153 | ), 154 | nn.BatchNorm2d(32), 155 | nn.ReLU(inplace=True), 156 | ), 157 | ]) 158 | 159 | def forward(self, x, num_samples=1): 160 | if num_samples > 1: 161 | x = x\ 162 | .unsqueeze(1)\ 163 | .expand(-1, num_samples, -1, -1, -1)\ 164 | .contiguous()\ 165 | .view(x.size(0) * num_samples, x.size(1), x.size(2), x.size(3)) 166 | 167 | images = x[:, :3, ...] 168 | masks = x[:, -1:, ...] 
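# x packs the masked RGB image in channels 0-2 and the binary inpainting
# mask in its last channel; both slices are reused below to composite the
# predicted pixels into the hole while leaving known pixels untouched.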
169 | net = x 170 | # run through the hidden layers 171 | for i, layer in enumerate(self.layers): 172 | reversed_index = len(self.layers) - i - 1 173 | net = layer(net) 174 | if not self.as_predictor and ( 175 | len(self.noise_dim) > reversed_index and 176 | self.noise_dim[reversed_index] > 0 177 | ): 178 | noise = generate_noise( 179 | self.noise_type, 180 | self.noise_dim[reversed_index], 181 | like=net 182 | ) 183 | net = torch.cat([net, noise], 1) 184 | 185 | mean = F.sigmoid(self.mean_conv(net)) 186 | mean = mean * masks + images * (1. - masks) 187 | 188 | if self.as_predictor: 189 | dispersion = self.dispersion_conv(net) 190 | dispersion = ( 191 | dispersion * masks + 192 | self.config.log_dispersion_min * (1. - masks) 193 | ) 194 | return mean, dispersion 195 | else: 196 | return mean, None 197 | 198 | 199 | class Discriminator(nn.Module): 200 | def __init__(self, config): 201 | super().__init__() 202 | self.global_d = GlobalDiscriminator(config) 203 | self.local_d = LocalDiscriminator(config) 204 | self.linear = nn.Linear(1024 * 2, 1) 205 | 206 | def forward(self, x_y): 207 | x, y = x_y 208 | local_boxes = x.local_boxes 209 | global_output = self.global_d(y) 210 | local_output = self.local_d(self._local(y, local_boxes)) 211 | return self.linear(torch.cat([global_output, local_output], 1)) 212 | 213 | def _local(self, images, local_boxes): 214 | local_images = [] 215 | for image, local_box in zip(images, local_boxes): 216 | y1, x1, y2, x2 = local_box 217 | local_images.append(image[:, y1:y2, x1:x2]) 218 | return torch.stack(local_images) 219 | 220 | 221 | class GlobalDiscriminator(nn.Module): 222 | def __init__(self, config): 223 | super().__init__() 224 | self.linear = nn.Linear( 225 | 512 * (config.data.random_crop // (2 ** 6)) ** 2, 226 | 1024 227 | ) 228 | self.layers = nn.Sequential( 229 | # conv1 230 | nn.Conv2d(3, 64, 5, 2, 2), 231 | nn.BatchNorm2d(64), 232 | nn.ReLU(inplace=True), 233 | # conv2 234 | nn.Conv2d(64, 128, 5, 2, 2), 235 | nn.BatchNorm2d(128), 236 | nn.ReLU(inplace=True), 237 | # conv3 238 | nn.Conv2d(128, 256, 5, 2, 2), 239 | nn.BatchNorm2d(256), 240 | nn.ReLU(inplace=True), 241 | # conv4 242 | nn.Conv2d(256, 512, 5, 2, 2), 243 | nn.BatchNorm2d(512), 244 | nn.ReLU(inplace=True), 245 | # conv5 246 | nn.Conv2d(512, 512, 5, 2, 2), 247 | nn.BatchNorm2d(512), 248 | nn.ReLU(inplace=True), 249 | # conv6 250 | nn.Conv2d(512, 512, 5, 2, 2), 251 | nn.BatchNorm2d(512), 252 | nn.ReLU(inplace=True), 253 | ) 254 | 255 | def forward(self, y): 256 | net = self.layers(y) 257 | net = net.view(net.size(0), -1) 258 | return self.linear(net) 259 | 260 | 261 | class LocalDiscriminator(nn.Module): 262 | def __init__(self, config): 263 | super().__init__() 264 | self.linear = nn.Linear( 265 | 512 * (config.data.local_size // (2 ** 5)) ** 2, 266 | 1024 267 | ) 268 | self.layers = nn.Sequential( 269 | # conv1 270 | nn.Conv2d(3, 64, 5, 2, 2), 271 | nn.BatchNorm2d(64), 272 | nn.ReLU(inplace=True), 273 | # conv2 274 | nn.Conv2d(64, 128, 5, 2, 2), 275 | nn.BatchNorm2d(128), 276 | nn.ReLU(inplace=True), 277 | # conv3 278 | nn.Conv2d(128, 256, 5, 2, 2), 279 | nn.BatchNorm2d(256), 280 | nn.ReLU(inplace=True), 281 | # conv4 282 | nn.Conv2d(256, 512, 5, 2, 2), 283 | nn.BatchNorm2d(512), 284 | nn.ReLU(inplace=True), 285 | # conv5 286 | nn.Conv2d(512, 512, 5, 2, 2), 287 | nn.BatchNorm2d(512), 288 | nn.ReLU(inplace=True), 289 | ) 290 | 291 | def forward(self, y): 292 | net = self.layers(y) 293 | net = net.view(net.size(0), -1) 294 | return self.linear(net) 295 | 296 | 297 | class GLCIC(BaseModel): 
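# Globally and Locally Consistent Image Completion. In the baseline mode,
# training is staged: for the first g_pretrain_step steps only the
# generator's MSE loss is active, for the next d_pretrain_step steps only
# the discriminator trains, and joint adversarial training begins after
# both (see the step checks in optimize_d/optimize_g below).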
298 | name = 'glcic' 299 | 300 | def __init__(self, config): 301 | self.t_g = config.g_pretrain_step 302 | self.t_d = config.d_pretrain_step 303 | super().__init__( 304 | config, 305 | Discriminator(config) 306 | if config.mode in (MODE_BASE, MODE_MR) else None, 307 | Generator(config) 308 | if config.mode in (MODE_BASE, MODE_MR) else None, 309 | Generator(config, as_predictor=True) 310 | if config.mode in (MODE_PRED, MODE_MR) else None, 311 | ) 312 | 313 | def optimize_d(self, x, y, step, summarize=False): 314 | torch.cuda.empty_cache() 315 | image = {} 316 | assert self.mode in (MODE_BASE, MODE_MR) 317 | if step < self.t_g or self.loss_config.gan_weight <= 0: 318 | return {} 319 | # gan loss (BASELINE || MR) 320 | fake_y = self.net_g(x)[0] 321 | fake_v = self.net_d((x, fake_y)) 322 | 323 | # jitter real y 324 | if self.config.data.real_jitter > 0: 325 | mask = x[:, -1:, ...] 326 | jitter = torch.randn(y.size(0), y.size(1), 1, 1, device='cuda') 327 | real_y = y + mask * self.config.data.real_jitter * jitter 328 | real_y = torch.clamp(real_y, 0., 1.) 329 | else: 330 | real_y = y 331 | 332 | real_v = self.net_d((x, real_y)) 333 | real_loss = self.gan_loss(real_v, True) 334 | fake_loss = self.gan_loss(fake_v, False) 335 | gan_loss = (real_loss + fake_loss) / 2 336 | # backprop & minimize 337 | self.optim_d.zero_grad() 338 | loss = gan_loss * self.loss_config.gan_weight 339 | loss.backward() 340 | self.clip_grad(self.optim_d, self.config.d_optimizer.clip_grad) 341 | self.optim_d.step() 342 | # image summaries 343 | if summarize: 344 | for i in range(y.size(0)): 345 | real_image_id = 'd_inputs/real/{}'.format(i) 346 | fake_image_id = 'd_inputs/fake/{}'.format(i) 347 | image[real_image_id] = real_y[i] 348 | image[fake_image_id] = fake_y[i] 349 | return { 350 | 'scalar': { 351 | 'loss/d/total': loss, 352 | 'loss/d/gan_real': real_loss, 353 | 'loss/d/gan_fake': fake_loss, 354 | 'loss/d/gan_total': gan_loss, 355 | }, 356 | 'image': image 357 | } 358 | 359 | def optimize_g(self, x, y, step, summarize=False): 360 | torch.cuda.empty_cache() 361 | assert self.mode in (MODE_BASE, MODE_MR) 362 | # prepare some accumulators 363 | scalar = {'loss/g/total': 0.} 364 | histogram = {} 365 | image = {} 366 | loss = 0. 
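# `loss` starts as a plain float so that the isinstance() check before
# backward() below can skip back-propagation when no baseline loss term
# was added in the current phase (e.g. during discriminator pretraining).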
367 | self.optim_g.zero_grad() 368 | # GAN loss (BASELINE) 369 | fake_y = self.net_g(x)[0] 370 | if self.mode == MODE_BASE and (step > self.t_g + self.t_d): 371 | fake_v = self.net_d((x, fake_y)) 372 | gan_loss = self.gan_loss(fake_v, True) 373 | weighted_gan_loss = self.loss_config.gan_weight * gan_loss 374 | loss += weighted_gan_loss 375 | if summarize: 376 | scalar['loss/g/gan'] = gan_loss.detach() 377 | scalar['loss/g/total'] += weighted_gan_loss.detach() 378 | # MSE loss (BASELINE) 379 | if self.mode == MODE_BASE and self.loss_config.recon_weight > 0 and ( 380 | step > self.t_g + self.t_d or 381 | step < self.t_g 382 | ): 383 | mse_loss = self._mse_loss(fake_y, y) 384 | weighted_mse_loss = self.loss_config.recon_weight * mse_loss 385 | loss += weighted_mse_loss 386 | if summarize: 387 | scalar['loss/g/mse'] = mse_loss.detach() 388 | scalar['loss/g/total'] += weighted_mse_loss.detach() 389 | # backprop before accumulating mr gradients 390 | if isinstance(loss, torch.Tensor): 391 | loss.backward() 392 | 393 | # MR loss (MR) 394 | if self.mode == MODE_MR: 395 | mr_summaries = self.accumulate_mr_grad(x, y, summarize) 396 | mr_scalar = mr_summaries['scalar'] 397 | mr_histogram = mr_summaries['histogram'] 398 | mr_image = mr_summaries['image'] 399 | torch.cuda.empty_cache() 400 | if summarize: 401 | scalar['loss/g/total'] += mr_scalar['loss/g/total'] 402 | del mr_scalar['loss/g/total'] 403 | scalar.update(mr_scalar) 404 | histogram.update(mr_histogram) 405 | for i in range(min(16, self.config.num_mr)): 406 | image_id = 'train_samples/%d' % i 407 | if image_id in image: 408 | image[image_id] = torch.cat([ 409 | image[image_id], mr_image[image_id] 410 | ], 2) 411 | else: 412 | image[image_id] = mr_image[image_id] 413 | image[image_id] = image[image_id] 414 | # Optimize the network 415 | self.clip_grad(self.optim_g, self.config.g_optimizer.clip_grad) 416 | self.optim_g.step() 417 | return {'scalar': scalar, 'histogram': histogram, 'image': image} 418 | 419 | def optimize_p(self, x, y, step, summarize=False): 420 | assert self.mode == MODE_PRED 421 | # MLE loss (PRED) 422 | mean, log_var = self.net_p(x) 423 | loss = self.mle_loss(mean, log_var, y) 424 | # backprop & minimize 425 | self.optim_p.zero_grad() 426 | loss.backward() 427 | self.clip_grad(self.optim_p, self.config.p_optimizer.clip_grad) 428 | self.optim_p.step() 429 | return { 430 | 'scalar': { 431 | 'loss/p/mle': loss 432 | } 433 | } 434 | 435 | def summarize(self, writer: SummaryWriter, step, 436 | train_sample_iterator: DataIterator, 437 | val_sample_iterator: DataIterator): 438 | with torch.no_grad(): 439 | for iter_type, iterator in ( 440 | ('train', train_sample_iterator), 441 | ('val', val_sample_iterator) 442 | ): 443 | mle_loss = 0. 
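# In the MR branch below, the sample log-variance outside the inpainting
# mask is floored to config.log_dispersion_min, so the uncertainty collage
# highlights only the completed region.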
444 | sample_2nds = [] 445 | pred_log_2nds = [] 446 | for i in range(len(iterator)): 447 | x, y = next(iterator) 448 | if self.mode == MODE_BASE: 449 | self.net_g.eval() 450 | num_samples = 12 451 | g_x = x.expand(num_samples, -1, -1, -1) 452 | samples = self.net_g(g_x)[0] 453 | samples_reshaped = samples.view( 454 | 1, num_samples, *list(samples.size()[1:]) 455 | ) 456 | sample_1st, sample_2nd = self.sample_statistics( 457 | samples_reshaped 458 | ) 459 | zeros = torch.zeros_like(x[:, :3, ...]) 460 | self.net_g.train() 461 | collage = torch.cat([ 462 | torch.cat( 463 | [x[:, :3, ...], y, zeros, zeros, zeros, zeros], 464 | dim=3 465 | ), 466 | torch.cat([ 467 | fy.unsqueeze(0) for fy in 468 | torch.unbind(samples[:num_samples // 2]) 469 | ], dim=3), 470 | torch.cat([ 471 | fy.unsqueeze(0) for fy in 472 | torch.unbind(samples[num_samples // 2:]) 473 | ], dim=3) 474 | ], dim=2) 475 | sample_2nds.append(sample_2nd) 476 | writer.add_image( 477 | 'g/{}/{}'.format(iter_type, i), collage, step 478 | ) 479 | elif self.mode == MODE_MR: 480 | mask = x[:, -1:, ...] 481 | pred_1st, pred_log_2nd = self.net_p(x) 482 | self.net_g.eval() 483 | num_samples = 12 484 | g_x = x.expand(num_samples, -1, -1, -1) 485 | samples = self.net_g(g_x)[0] 486 | self.net_g.train() 487 | sample_1st, sample_2nd = self.sample_statistics(samples) 488 | sample_log_2nd = (sample_2nd + 1e-4).log() 489 | sample_log_2nd = ( 490 | sample_log_2nd * mask + 491 | self.config.log_dispersion_min * (1. - mask) 492 | ) 493 | log_2nds = torch.cat([pred_log_2nd, sample_log_2nd], 3) 494 | uncertainty = self._build_uncertainty_image(log_2nds) 495 | uncertainty = (uncertainty + 1.) * 0.5 496 | collage = [ 497 | torch.cat([ 498 | x[:, :3, ...], y, pred_1st, 499 | sample_1st, uncertainty, 500 | ], dim=3), 501 | torch.cat([ 502 | fy.unsqueeze(0) for fy in 503 | torch.unbind(samples[:num_samples // 2]) 504 | ], dim=3), 505 | torch.cat([ 506 | fy.unsqueeze(0) for fy in 507 | torch.unbind(samples[num_samples // 2:]) 508 | ], dim=3) 509 | ] 510 | collage = torch.cat(collage, dim=2) 511 | writer.add_image( 512 | 'collage/{}/{}'.format(iter_type, i), collage, step 513 | ) 514 | elif self.mode == MODE_PRED: 515 | self.net_p.eval() 516 | pred_1st, pred_log_2nd = self.net_p(x) 517 | self.net_p.train() 518 | uncertainty = self._build_uncertainty_image( 519 | pred_log_2nd 520 | ) 521 | uncertainty = (uncertainty + 1.) 
* 0.5 522 | collage = torch.cat([ 523 | torch.cat([x[:, :3, ...], y], dim=3), 524 | torch.cat([pred_1st, uncertainty], dim=3) 525 | ], dim=2) 526 | pred_log_2nds.append(pred_log_2nd) 527 | writer.add_image( 528 | 'p/{}/{}'.format(iter_type, i), collage, step 529 | ) 530 | 531 | # Validation loss 532 | if iter_type == 'val': 533 | mle_loss += self.mle_loss(pred_1st, pred_log_2nd, y) 534 | 535 | if self.mode == MODE_PRED and iter_type == 'val': 536 | pred_log_2nd = torch.stack(pred_log_2nds) 537 | pred_2nd = torch.exp(pred_log_2nd) 538 | writer.add_scalar( 539 | 'variance/pred', pred_2nd.mean(), step 540 | ) 541 | writer.add_histogram( 542 | 'variance_hist/pred', pred_2nd, step 543 | ) 544 | if self.mode == MODE_BASE and iter_type == 'val': 545 | sample_2nd = torch.stack(sample_2nds) 546 | writer.add_scalar( 547 | 'variance/sample', sample_2nd.mean(), step 548 | ) 549 | writer.add_histogram( 550 | 'variance_hist/sample', sample_2nd, step 551 | ) 552 | if self.mode == MODE_PRED and iter_type == 'val': 553 | writer.add_scalar( 554 | 'loss/p/mle_val', mle_loss / len(iterator), step 555 | ) 556 | 557 | def mle_target(self, y): 558 | return y 559 | 560 | def build_d_input(self, x, samples): 561 | num_mr = self.config.num_mr 562 | num_samples = self.config.num_mr_samples 563 | local_box_dup = x.local_boxes[:num_mr].unsqueeze(1) 564 | local_box_dup = local_box_dup.expand(-1, num_samples, -1) 565 | local_box_dup = local_box_dup.contiguous().view( 566 | num_mr * num_samples, local_box_dup.size(2) 567 | ) 568 | 569 | x_dup = x[:num_mr].unsqueeze(1).expand(-1, num_samples, -1, -1, -1) 570 | x_dup = x_dup.contiguous().view( 571 | num_mr * num_samples, *list(x_dup.size()[2:])) 572 | 573 | x_dup.local_boxes = local_box_dup 574 | 575 | return x_dup, samples 576 | --------------------------------------------------------------------------------
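For reference, a minimal usage sketch of the shared `generate_noise` helper from `models/base.py`. This snippet is illustrative and not part of the repository; it assumes a CUDA device is available, since the helper hard-codes `device='cuda'`.

    import torch
    from models.base import generate_noise

    # Any feature map can serve as the `like` argument; only its batch
    # size and spatial dimensions are read.
    like = torch.zeros(2, 64, 16, 16, device='cuda')
    z = generate_noise('categorical', noise_dim=8, like=like)
    assert z.size() == (2, 8, 16, 16)
    # Categorical noise is one-hot across the noise channels.
    assert bool((z.sum(dim=1) == 1).all())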