├── models ├── s ├── types_.py ├── base.py ├── __init__.py ├── cvae.py ├── beta_vae.py ├── logcosh_vae.py ├── dip_vae.py ├── vampvae.py ├── FullyVAE_.py ├── twostage_vae.py ├── miwae.py ├── swae.py ├── dfcvae.py ├── cat_vae.py ├── vq_vae.py └── fvae.py ├── Task_Split ├── s ├── sinkhorn.py └── MMD_Tool.py ├── datasets ├── d ├── __init__.py ├── Data Augmentation_.py ├── cifar10.py ├── lsun_bedroom.py ├── README.md ├── CIFAR100.py ├── MNIST32.py ├── CelebA_Load.py ├── cv2_ImageProcess.py ├── myCIFAR100.py └── MyCIFAR10.py ├── results └── d ├── NetworkModels ├── a ├── TeacherStudent_.py ├── Balance_TeacherStudent_.py ├── MMD_Lib.py ├── MyClassifier.py ├── DynamicTeacherStudent_.py ├── Balance_TeacherStudent_NoMPI_.py ├── TFCL_DynamicVAEModel_.py ├── TeacherEnsembleFramework_.py └── DynamicMixture256_.py ├── improved_diffusion ├── s ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── losses.py ├── image_datasets.py ├── respace.py ├── nn.py ├── resample.py └── train_util_all.py ├── GraphMemory_Structure_1.jpg ├── cifar10_GraphMemory2000.png ├── mnist_GraphMemory_WDistance2200.png ├── fashion_GraphMemory_WDistance2000.png └── README.md /models/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Task_Split/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/d: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /results/d: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /NetworkModels/a: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /improved_diffusion/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /GraphMemory_Structure_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dtuzi123/DCM/HEAD/GraphMemory_Structure_1.jpg -------------------------------------------------------------------------------- /cifar10_GraphMemory2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dtuzi123/DCM/HEAD/cifar10_GraphMemory2000.png -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /improved_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 
3 | """ 4 | -------------------------------------------------------------------------------- /mnist_GraphMemory_WDistance2200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dtuzi123/DCM/HEAD/mnist_GraphMemory_WDistance2200.png -------------------------------------------------------------------------------- /fashion_GraphMemory_WDistance2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dtuzi123/DCM/HEAD/fashion_GraphMemory_WDistance2000.png -------------------------------------------------------------------------------- /models/types_.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Union, Any, TypeVar, Tuple 2 | # from torch import tensor as Tensor 3 | 4 | Tensor = TypeVar('torch.tensor') 5 | -------------------------------------------------------------------------------- /NetworkModels/TeacherStudent_.py: -------------------------------------------------------------------------------- 1 | from NetworkModels.Teacher_Model_ import Teacher 2 | from NetworkModels.VAE_Model_ import StudentModel 3 | import torch.nn as nn 4 | 5 | class TeacherStudent(nn.Module): 6 | 7 | def __init__(self,name,device,input_size): 8 | super(TeacherStudent, self).__init__() 9 | 10 | self.input_size = input_size 11 | self.teacher = Teacher(input_size) 12 | self.student = StudentModel(device,input_size) 13 | self.device = device 14 | 15 | def Train(self,Tepoch,Sepoch,data): 16 | self.teacher.Train_Self(Tepoch,data) 17 | self.student.Train_Self(Sepoch,data) 18 | 19 | def Train_StudentOnly(self,Tepoch,Sepoch,data): 20 | self.student.Train_Self(Sepoch,data) 21 | 22 | 23 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | from .types_ import * 2 | from torch import nn 3 | from abc import abstractmethod 4 | 5 | class BaseVAE(nn.Module): 6 | 7 | def __init__(self) -> None: 8 | super(BaseVAE, self).__init__() 9 | 10 | def encode(self, input: Tensor) -> List[Tensor]: 11 | raise NotImplementedError 12 | 13 | def decode(self, input: Tensor) -> Any: 14 | raise NotImplementedError 15 | 16 | def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor: 17 | raise NotImplementedError 18 | 19 | def generate(self, x: Tensor, **kwargs) -> Tensor: 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def forward(self, *inputs: Tensor) -> Tensor: 24 | pass 25 | 26 | @abstractmethod 27 | def loss_function(self, *inputs: Any, **kwargs) -> Tensor: 28 | pass 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /datasets/Data Augmentation_.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | 6 | 7 | class Data_Augmentation: 8 | def RandomTransfer(x1): 9 | 10 | inputSize = np.shape(x1)[2] 11 | 12 | transform1 = transforms.Compose([ 13 | transforms.ToTensor(), 14 | transforms.ToPILImage(), 15 | transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5) 16 | ]) 17 | 18 | transform2 = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.ToPILImage(), 21 | transforms.RandomCrop(inputSize, pad_if_needed=True), 22 | ]) 23 | 24 | transform3 = transforms.Compose([ 25 | 
transforms.Compose(), 26 | transforms.ToPILImage(), 27 | transforms.RandomRotation(60), 28 | ]) 29 | 30 | r = transform2(x1) 31 | return r 32 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import torchvision 5 | from tqdm.auto import tqdm 6 | 7 | CLASSES = ( 8 | "plane", 9 | "car", 10 | "bird", 11 | "cat", 12 | "deer", 13 | "dog", 14 | "frog", 15 | "horse", 16 | "ship", 17 | "truck", 18 | ) 19 | 20 | 21 | def main(): 22 | for split in ["train", "test"]: 23 | out_dir = f"cifar_{split}" 24 | if os.path.exists(out_dir): 25 | print(f"skipping split {split} since {out_dir} already exists.") 26 | continue 27 | 28 | print("downloading...") 29 | with tempfile.TemporaryDirectory() as tmp_dir: 30 | dataset = torchvision.datasets.CIFAR10( 31 | root=tmp_dir, train=split == "train", download=True 32 | ) 33 | 34 | print("dumping images...") 35 | os.mkdir(out_dir) 36 | for i in tqdm(range(len(dataset))): 37 | image, label = dataset[i] 38 | filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png") 39 | image.save(filename) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .vanilla_vae import * 3 | from .gamma_vae import * 4 | from .beta_vae import * 5 | from .wae_mmd import * 6 | from .cvae import * 7 | from .hvae import * 8 | from .vampvae import * 9 | from .iwae import * 10 | from .dfcvae import * 11 | from .mssim_vae import MSSIMVAE 12 | from .fvae import * 13 | from .cat_vae import * 14 | from .joint_vae import * 15 | from .info_vae import * 16 | # from .twostage_vae import * 17 | from .lvae import LVAE 18 | from .logcosh_vae import * 19 | from .swae import * 20 | from .miwae import * 21 | from .vq_vae import * 22 | from .betatc_vae import * 23 | from .dip_vae import * 24 | 25 | 26 | # Aliases 27 | VAE = VanillaVAE 28 | GaussianVAE = VanillaVAE 29 | CVAE = ConditionalVAE 30 | GumbelVAE = CategoricalVAE 31 | 32 | vae_models = {'HVAE':HVAE, 33 | 'LVAE':LVAE, 34 | 'IWAE':IWAE, 35 | 'SWAE':SWAE, 36 | 'MIWAE':MIWAE, 37 | 'VQVAE':VQVAE, 38 | 'DFCVAE':DFCVAE, 39 | 'DIPVAE':DIPVAE, 40 | 'BetaVAE':BetaVAE, 41 | 'InfoVAE':InfoVAE, 42 | 'WAE_MMD':WAE_MMD, 43 | 'VampVAE': VampVAE, 44 | 'GammaVAE':GammaVAE, 45 | 'MSSIMVAE':MSSIMVAE, 46 | 'JointVAE':JointVAE, 47 | 'BetaTCVAE':BetaTCVAE, 48 | 'FactorVAE':FactorVAE, 49 | 'LogCoshVAE':LogCoshVAE, 50 | 'VanillaVAE':VanillaVAE, 51 | 'ConditionalVAE':ConditionalVAE, 52 | 'CategoricalVAE':CategoricalVAE} 53 | -------------------------------------------------------------------------------- /NetworkModels/Balance_TeacherStudent_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 4 | from NetworkModels.VAE_Model_ import StudentModel,Balance_StudentModel 5 | import torch.nn as nn 6 | 7 | class Balance_TeacherStudent(nn.Module): 8 | 9 | def __init__(self,name,device,input_size): 10 | super(Balance_TeacherStudent, self).__init__() 11 | 12 | self.input_size = input_size 13 | self.teacher = Balance_Teacher(input_size) 14 | self.student = Balance_StudentModel(device,input_size) 15 | self.device = device 16 | 17 | def 
Train(self,Tepoch,Sepoch,data,generatedData): 18 | 19 | if np.shape(generatedData)[0] == 0: 20 | self.teacher.Train_Self(Tepoch,data) 21 | self.student.Train_Self(Sepoch,data) 22 | else: 23 | self.teacher.Train_Self_(Tepoch,data,generatedData) 24 | self.student.Train_Self_(Sepoch,data,generatedData) 25 | 26 | 27 | def Train_ByLoadData_Single(self,Tepoch,Sepoch,data): 28 | #self.teacher.Train_Self_ByDataLoad_Single(Tepoch, data) 29 | self.student.Train_Self_ByDataLoad_Single(Sepoch, data) 30 | 31 | 32 | def Train_ByLoadData(self,Tepoch,Sepoch,data,generatedData): 33 | 34 | if np.shape(generatedData)[0] == 0: 35 | self.teacher.Train_Self_ByDataLoad(Tepoch,data) 36 | self.student.Train_Self(Sepoch,data) 37 | else: 38 | self.teacher.Train_Self_ByDataLoad(Tepoch,data,generatedData) 39 | self.student.Train_Self_ByDataLoad(Sepoch,data,generatedData) 40 | 41 | 42 | -------------------------------------------------------------------------------- /datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images. 3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /improved_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | import argparse 9 | 10 | import blobfile as bf 11 | #from mpi4py import MPI 12 | import torch as th 13 | import torch.distributed as dist 14 | 15 | # Change this to reflect your cluster layout. 16 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 17 | GPUS_PER_NODE = 8 18 | 19 | SETUP_RETRY_COUNT = 3 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 
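The master address and a free port are chosen on rank 0 and broadcast to the other ranks over MPI before dist.init_process_group() is called. Note that this code path relies on mpi4py (MPI.COMM_WORLD); the "from mpi4py import MPI" import above is commented out, so it must be restored for this function to run.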
24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | port = comm.bcast(_find_free_port(), root=0) 40 | os.environ["MASTER_PORT"] = str(port) 41 | dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | 49 | ''' 50 | if th.cuda.is_available(): 51 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 52 | return th.device("cpu") 53 | ''' 54 | if th.cuda.is_available(): 55 | return th.device("cuda:0") 56 | return th.device("cpu") 57 | 58 | def load_state_dict(path, **kwargs): 59 | """ 60 | Load a PyTorch file without redundant fetches across MPI ranks. 61 | """ 62 | if MPI.COMM_WORLD.Get_rank() == 0: 63 | with bf.BlobFile(path, "rb") as f: 64 | data = f.read() 65 | else: 66 | data = None 67 | data = MPI.COMM_WORLD.bcast(data) 68 | return th.load(io.BytesIO(data), **kwargs) 69 | 70 | 71 | def sync_params(params): 72 | """ 73 | Synchronize a sequence of Tensors across ranks from rank 0. 74 | """ 75 | for p in params: 76 | with th.no_grad(): 77 | dist.broadcast(p, 0) 78 | 79 | 80 | def _find_free_port(): 81 | try: 82 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 83 | s.bind(("", 0)) 84 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 85 | return s.getsockname()[1] 86 | finally: 87 | s.close() 88 | -------------------------------------------------------------------------------- /improved_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 
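The master parameters live in a single flattened full-precision tensor (see make_master_params), so they are first unflattened back into the shapes of model_params and then copied in.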
53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /NetworkModels/MMD_Lib.py: -------------------------------------------------------------------------------- 1 | # Compute MMD (maximum mean discrepancy) using numpy and scikit-learn. 2 | 3 | import numpy as np 4 | from sklearn import metrics 5 | 6 | 7 | def mmd_linear(X, Y): 8 | """MMD using linear kernel (i.e., k(x,y) = ) 9 | Note that this is not the original linear MMD, only the reformulated and faster version. 10 | The original version is: 11 | def mmd_linear(X, Y): 12 | XX = np.dot(X, X.T) 13 | YY = np.dot(Y, Y.T) 14 | XY = np.dot(X, Y.T) 15 | return XX.mean() + YY.mean() - 2 * XY.mean() 16 | 17 | Arguments: 18 | X {[n_sample1, dim]} -- [X matrix] 19 | Y {[n_sample2, dim]} -- [Y matrix] 20 | 21 | Returns: 22 | [scalar] -- [MMD value] 23 | """ 24 | delta = X.mean(0) - Y.mean(0) 25 | return delta.dot(delta.T) 26 | 27 | 28 | def mmd_rbf(X, Y, gamma=1.0): 29 | """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2)) 30 | 31 | Arguments: 32 | X {[n_sample1, dim]} -- [X matrix] 33 | Y {[n_sample2, dim]} -- [Y matrix] 34 | 35 | Keyword Arguments: 36 | gamma {float} -- [kernel parameter] (default: {1.0}) 37 | 38 | Returns: 39 | [scalar] -- [MMD value] 40 | """ 41 | XX = metrics.pairwise.rbf_kernel(X, X, gamma) 42 | YY = metrics.pairwise.rbf_kernel(Y, Y, gamma) 43 | XY = metrics.pairwise.rbf_kernel(X, Y, gamma) 44 | return XX.mean() + YY.mean() - 2 * XY.mean() 45 | 46 | 47 | def mmd_poly(X, Y, degree=2, gamma=1, coef0=0): 48 | """MMD using polynomial kernel (i.e., k(x,y) = (gamma + coef0)^degree) 49 | 50 | Arguments: 51 | X {[n_sample1, dim]} -- [X matrix] 52 | Y {[n_sample2, dim]} -- [Y matrix] 53 | 54 | Keyword Arguments: 55 | degree {int} -- [degree] (default: {2}) 56 | gamma {int} -- [gamma] (default: {1}) 57 | coef0 {int} -- [constant item] (default: {0}) 58 | 59 | Returns: 60 | [scalar] -- [MMD value] 61 | """ 62 | XX = metrics.pairwise.polynomial_kernel(X, X, degree, gamma, coef0) 63 | YY = metrics.pairwise.polynomial_kernel(Y, Y, degree, gamma, coef0) 64 | XY = metrics.pairwise.polynomial_kernel(X, Y, degree, gamma, coef0) 65 | return XX.mean() + YY.mean() - 2 * XY.mean() 66 | 67 | 68 | ''' 69 | if __name__ == '__main__': 70 | a = np.arange(1, 10).reshape(3, 3) 71 | b = [[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]] 72 | b = np.array(b) 73 | print(a) 74 | print(b) 75 | print(mmd_linear(a, b)) # 6.0 76 | print(mmd_rbf(a, b)) # 0.5822 77 | print(mmd_poly(a, b)) # 2436.5 78 | 79 | ''' -------------------------------------------------------------------------------- /NetworkModels/MyClassifier.py: -------------------------------------------------------------------------------- 
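# Usage sketch (illustrative only): the network below is a standard ResNet-18 built
# from ResidualBlock and assumes CIFAR-style 3x32x32 RGB inputs, since conv1 takes
# 3 channels and the forward pass ends with a 4x4 average pooling over the final
# feature map.
#
#   import torch
#   model = ResNet18()             # 10-way classifier
#   x = torch.randn(8, 3, 32, 32)  # a batch of 8 images
#   logits = model(x)              # shape (8, 10)
#   model100 = ResNet18_100(100)   # same backbone with a 100-way output layer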
1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, inchannel, outchannel, stride=1): 8 | super(ResidualBlock, self).__init__() 9 | self.left = nn.Sequential( 10 | nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False), 11 | nn.BatchNorm2d(outchannel), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False), 14 | nn.BatchNorm2d(outchannel) 15 | ) 16 | self.shortcut = nn.Sequential() 17 | if stride != 1 or inchannel != outchannel: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False), 20 | nn.BatchNorm2d(outchannel) 21 | ) 22 | 23 | def forward(self, x): 24 | out = self.left(x) 25 | out = out + self.shortcut(x) 26 | out = F.relu(out) 27 | 28 | return out 29 | 30 | 31 | class ResNet(nn.Module): 32 | def __init__(self, ResidualBlock, num_classes=10): 33 | super(ResNet, self).__init__() 34 | self.inchannel = 64 35 | self.conv1 = nn.Sequential( 36 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False), 37 | nn.BatchNorm2d(64), 38 | nn.ReLU() 39 | ) 40 | self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1) 41 | self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2) 42 | self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2) 43 | self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2) 44 | self.fc = nn.Linear(512, num_classes) 45 | 46 | def make_layer(self, block, channels, num_blocks, stride): 47 | strides = [stride] + [1] * (num_blocks - 1) 48 | layers = [] 49 | for stride in strides: 50 | layers.append(block(self.inchannel, channels, stride)) 51 | self.inchannel = channels 52 | return nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | out = self.conv1(x) 56 | out = self.layer1(out) 57 | out = self.layer2(out) 58 | out = self.layer3(out) 59 | out = self.layer4(out) 60 | out = F.avg_pool2d(out, 4) 61 | out = out.view(out.size(0), -1) 62 | out = self.fc(out) 63 | return out 64 | def ResNet18(): 65 | return ResNet(ResidualBlock) 66 | 67 | def ResNet18_100(a): 68 | return ResNet(ResidualBlock,a) -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet, LSUN bedrooms, and CIFAR-10 for use in this codebase. 4 | 5 | ## ImageNet-64 6 | 7 | To download unconditional ImageNet-64, go to [this page on image-net.org](http://www.image-net.org/small/download.php) and click on "Train (64x64)". Simply download the file and unzip it, and use the resulting directory as the data directory (the `--data_dir` argument for the training script). 8 | 9 | ## Class-conditional ImageNet 10 | 11 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 12 | 13 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. 
You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 14 | 15 | ``` 16 | for file in *.tar; do tar xf "$file"; rm "$file"; done 17 | ``` 18 | 19 | This will extract and remove each tar file in turn. 20 | 21 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. 22 | 23 | ## CIFAR-10 24 | 25 | For CIFAR-10, we created a script [cifar10.py](cifar10.py) that creates `cifar_train` and `cifar_test` directories. These directories contain files named like `truck_49997.png`, so that the class name is discernable to the data loader. 26 | 27 | The `cifar_train` and `cifar_test` directories can be passed directly to the training scripts via the `--data_dir` argument. 28 | 29 | ## LSUN bedroom 30 | 31 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 32 | 33 | ``` 34 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 35 | ``` 36 | 37 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 38 | -------------------------------------------------------------------------------- /datasets/CIFAR100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle as pickle 4 | import glob 5 | import matplotlib.pyplot as plt 6 | 7 | data_dir = "data" 8 | #data_dir_cifar10 = os.path.join(data_dir, "cifar-10-batches-py") 9 | data_dir_cifar100 = os.path.join(data_dir, "cifar-100-python") 10 | 11 | #class_names_cifar10 = np.load(os.path.join(data_dir_cifar10, "batches.meta")) 12 | class_names_cifar100 = np.load(os.path.join(data_dir_cifar100, "meta")) 13 | 14 | 15 | def one_hot(x, n): 16 | """ 17 | convert index representation to one-hot representation 18 | """ 19 | x = np.array(x) 20 | assert x.ndim == 1 21 | return np.eye(n)[x] 22 | 23 | 24 | def _load_batch_cifar10(filename, dtype='float64'): 25 | """ 26 | load a batch in the CIFAR-10 format 27 | """ 28 | path = os.path.join(data_dir_cifar10, filename) 29 | batch = np.load(path) 30 | data = batch['data'] / 255.0 # scale between [0, 1] 31 | labels = one_hot(batch['labels'], n=10) # convert labels to one-hot representation 32 | return data.astype(dtype), labels.astype(dtype) 33 | 34 | 35 | def _grayscale(a): 36 | print 37 | a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1) 38 | return a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1) 39 | 40 | 41 | def cifar10(dtype='float64', grayscale=True): 42 | # train 43 | x_train = [] 44 | t_train = [] 45 | for k in xrange(5): 46 | x, t = _load_batch_cifar10("data_batch_%d" % (k + 1), dtype=dtype) 47 | x_train.append(x) 48 | t_train.append(t) 49 | 50 | x_train = np.concatenate(x_train, axis=0) 51 | t_train = np.concatenate(t_train, axis=0) 52 | 53 | # test 54 | x_test, t_test = 
_load_batch_cifar10("test_batch", dtype=dtype) 55 | 56 | if grayscale: 57 | x_train = _grayscale(x_train) 58 | x_test = _grayscale(x_test) 59 | 60 | return x_train, t_train, x_test, t_test 61 | 62 | 63 | def _load_batch_cifar100(filename, dtype='float64'): 64 | """ 65 | load a batch in the CIFAR-100 format 66 | """ 67 | path = os.path.join(data_dir_cifar100, filename) 68 | #batch = np.load(path,allow_pickle=True,fix_imports=True,encoding='latin1') 69 | batch = np.load(np.loadtxt,allow_pickle=True) 70 | data = batch['data'] / 255.0 71 | labels = one_hot(batch['fine_labels'], n=100) 72 | return data.astype(dtype), labels.astype(dtype) 73 | 74 | 75 | def cifar100(dtype='float64', grayscale=True): 76 | x_train, t_train = _load_batch_cifar100("train", dtype=dtype) 77 | x_test, t_test = _load_batch_cifar100("test", dtype=dtype) 78 | 79 | if grayscale: 80 | x_train = _grayscale(x_train) 81 | x_test = _grayscale(x_test) 82 | 83 | return x_train, t_train, x_test, t_test -------------------------------------------------------------------------------- /improved_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
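Each uint8 value corresponds to a bin of width 2/255 in [-1, 1], so a pixel's probability is the Gaussian CDF evaluated at x + 1/255 minus the CDF at x - 1/255; for the outermost bins (x below -0.999 or above 0.999) the full lower or upper tail is used instead.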
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCM 2 | 3 | >📋 This is the implementation of Online Task-Free Continual Generative and Discriminative Learning via Dynamic Cluster Memory 4 | 5 | >📋 Accepted by CVPR 2024 6 | 7 | # Title : Online Task-Free Continual Generative and Discriminative Learning via Dynamic Cluster Memory 8 | 9 | # Paper link : https://openaccess.thecvf.com/content/CVPR2024/html/Ye_Online_Task-Free_Continual_Generative_and_Discriminative_Learning_via_Dynamic_Cluster_CVPR_2024_paper.html 10 | 11 | 12 | # Abstract 13 | 14 | Online Task-Free Continual Learning (OTFCL) aims to learn novel concepts from streaming data without accessing task information. Memory-based approaches have shown remarkable results in OTFCL, but most require accessing supervised signals to implement their sample selection mechanisms, limiting their applicability for unsupervised learning. In this study, we address this issue by proposing a novel memory management approach, namely the Dynamic Cluster Memory (DCM), which builds new memory clusters to capture distribution shifts over time without accessing any supervised signal. 15 | DCM introduces a novel memory expansion mechanism based on the knowledge discrepancy criterion, which evaluates the novelty of the incoming data as the signal for the memory expansion, ensuring a compact memory capacity. We also propose a new sample selection approach that automatically stores incoming data samples with similar semantic information in the same memory cluster, facilitating the knowledge diversity among memory clusters. Furthermore, a novel memory pruning approach is proposed to automatically remove overlapping memory clusters through a graph relation evaluation, ensuring a fixed memory capacity while maintaining the diversity among the samples stored in the memory. The proposed DCM is model-free, plug-and-play, and can be used in both supervised and unsupervised learning without modifications. Empirical results on OTFCL experiments show that the proposed DCM outperforms the state-of-the-art while memorizing fewer data samples. 16 | 17 | ![image](https://github.com/dtuzi123/DCM/blob/main/GraphMemory_Structure_1.jpg 18 | ) 19 | # Environment 20 | 21 | 1. Pytorch 1.12 22 | 2. Python 3.7 23 | 24 | Our code is based on the improved diffusion model ("https://github.com/openai/improved-diffusion") 25 | 26 | # Training and evaluation 27 | 28 | >📋 Python xxx.py, the model will be automatically trained and then report the results after the training. 29 | 30 | >📋 Different parameter settings of DCM would lead different results and we also provide different settings used in our experiments. 
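
>📋 The expansion criterion can be summarised by the following minimal sketch, which follows `CheckModelExpansion_TaskKnown` in `NetworkModels/TeacherEnsembleFramework_.py`; the method name `GenerateImages`, the threshold of 140, and the 1000-sample evaluation size are the defaults used there and are tunable settings rather than fixed constants:

```python
# Minimal sketch of the expansion check; calculate_fid_given_paths_Byimages is the
# FID helper used elsewhere in this repository (datasets/Fid_evaluation.py).
from datasets.Fid_evaluation import calculate_fid_given_paths_Byimages

def check_expansion(new_data, teachers, device, expansion_threshold=140, n_samples=1000):
    # Compare the incoming samples with generations from every existing component.
    scores = []
    for teacher in teachers:
        generated = teacher.GenerateImages(n_samples)   # samples drawn from one component
        scores.append(calculate_fid_given_paths_Byimages(
            new_data[:n_samples], generated, 50, device, 2048))
    if min(scores) > expansion_threshold:
        return None                    # no component fits the new data: expand the memory
    return scores.index(min(scores))   # otherwise reuse the closest component
```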
31 | 32 | 33 | # Visual results 34 | 35 | >📋 Split MNIST, Split Fashion and Split CIFAR10 36 | 37 | ![image](https://github.com/dtuzi123/DCM/blob/main/mnist_GraphMemory_WDistance2200.png) ![image](https://github.com/dtuzi123/DCM/blob/main/fashion_GraphMemory_WDistance2000.png) ![image](https://github.com/dtuzi123/DCM/blob/main/cifar10_GraphMemory2000.png) 38 | 39 | 40 | # BibTex 41 | >📋 If you use our code, please cite our paper as: 42 | >@inproceedings{ye2024online, 43 | title={Online Task-Free Continual Generative and Discriminative Learning via Dynamic Cluster Memory}, 44 | author={Ye, Fei and Bors, Adrian G}, 45 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 46 | pages={26202--26212}, 47 | year={2024} 48 | } 49 | -------------------------------------------------------------------------------- /NetworkModels/DynamicTeacherStudent_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 4 | from NetworkModels.VAE_Model_ import StudentModel,Balance_StudentModel 5 | import torch.nn as nn 6 | from datasets.Data_Loading import * 7 | from datasets.Fid_evaluation import * 8 | import torch as th 9 | 10 | class DynamicTeacherStudent(nn.Module): 11 | 12 | def __init__(self,name,device,input_size): 13 | super(DynamicTeacherStudent, self).__init__() 14 | 15 | self.expansion_threshold = 100 16 | self.selection_score = 0 17 | self.selectedTeacher = 0 18 | 19 | self.input_size = input_size 20 | self.teacher = Balance_Teacher(input_size) 21 | self.student = Balance_StudentModel(device,input_size) 22 | self.device = device 23 | 24 | self.teacherArr = [] 25 | self.teacherArr.append(self.teacher) 26 | self.selectedTeacher = self.teacher 27 | 28 | def CheckExpansion(self,newTask): 29 | maxCount = 500 30 | dataset = newTask[0:maxCount] 31 | 32 | arr = [] 33 | for i in range(np.shape(self.teacherArr)[0]): 34 | g1 = self.teacherArr.Give_Generation(maxCount) 35 | fid = calculate_fid_given_paths_Byimages(g1, dataset, 50, self.device, 2048) 36 | arr.append(fid) 37 | 38 | minScore = np.min(arr) 39 | minIndex = np.argmin(arr) 40 | 41 | if minScore > self.expansion_threshold:#Check the expansion 42 | #Perform the expansion 43 | newTeacher = Balance_Teacher(self.input_size) 44 | self.teacherArr.append(newTeacher) 45 | self.selectedTeacher = newTeacher 46 | self.selection_score = 0 47 | else: 48 | #Perform the component selection 49 | self.selectedTeacher = self.teacherArr[minIndex] 50 | self.selection_score = 1 51 | 52 | def Train(self,Tepoch,Sepoch,data,generatedData): 53 | 54 | #Check the expansion 55 | dataX = data 56 | 57 | if self.selection_score == 0: 58 | self.selectedTeacher.Train_Self(Tepoch,data) 59 | #train the student 60 | count = np.shape(self.dynamicTeacher.teacherArr)[0] 61 | batchSize = int(self.batch_size / count) 62 | totalCount = np.shape(dataX)[0] / batchSize 63 | 64 | for i in range(totalCount): 65 | arr2 = dataX[i * batchSize:(i + 1) * batchSize] 66 | for j in range(np.shape(self.dynamicTeacher.teacherArr)[0]): 67 | newa = self.dynamicTeacher.teacherArr[j].Give_Generation(batchSize) 68 | arr2 = th.cat([arr2, newa], dim=0) 69 | self.selectedTeacher.student.Train_One(arr2) 70 | else: 71 | self.selectedTeacher.Train_Self(Tepoch, data) 72 | # train the student 73 | count = np.shape(self.dynamicTeacher.teacherArr)[0] 74 | batchSize = int(self.batch_size / count) 75 | totalCount = np.shape(dataX)[0] / batchSize 76 | 77 | for i 
in range(totalCount): 78 | arr2 = dataX[i * batchSize:(i + 1) * batchSize] 79 | for j in range(np.shape(self.dynamicTeacher.teacherArr)[0]): 80 | newa = self.dynamicTeacher.teacherArr[j].Give_Generation(batchSize) 81 | arr2 = th.cat([arr2, newa], dim=0) 82 | self.selectedTeacher.student.Train_One(arr2) 83 | 84 | if np.shape(generatedData)[0] == 0: 85 | self.teacher.Train_Self(Tepoch,data) 86 | self.student.Train_Self(Sepoch,data) 87 | else: 88 | self.teacher.Train_Self_(Tepoch,data,generatedData) 89 | self.student.Train_Self_(Sepoch,data,generatedData) 90 | 91 | 92 | -------------------------------------------------------------------------------- /NetworkModels/Balance_TeacherStudent_NoMPI_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | #from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 4 | from NetworkModels.Teacher_Model_NoMPI_ import Balance_Teacher_NoMPI 5 | from NetworkModels.VAE_Model_ import Balance_StudentModel 6 | import torch.nn as nn 7 | 8 | class Balance_TeacherStudent_NoMPI(nn.Module): 9 | 10 | def __init__(self,name,device,input_size): 11 | super(Balance_TeacherStudent_NoMPI, self).__init__() 12 | 13 | self.input_size = input_size 14 | self.teacher = Balance_Teacher_NoMPI(input_size) 15 | self.student = Balance_StudentModel(device,input_size) 16 | self.device = device 17 | 18 | def Train_WithBeta_Cpu_ForStudent(self,Tepoch,Sepoch,data,generatedData,beta): 19 | 20 | if np.shape(generatedData)[0] == 0: 21 | #self.teacher.train_self_Single_Cpu(Tepoch,data) 22 | self.student.Train_Self_WithBeta_Single_Cpu(Sepoch,data,beta) 23 | else: 24 | #self.teacher.Train_Self_Cpu(Tepoch,data,generatedData) 25 | self.student.Train_Self_WithBeta_Cpu(Sepoch,data,generatedData,beta) 26 | 27 | 28 | def Train_WithBeta_Cpu(self,Tepoch,Sepoch,data,generatedData,beta): 29 | 30 | if np.shape(generatedData)[0] == 0: 31 | self.teacher.train_self_Single_Cpu(Tepoch,data) 32 | self.student.Train_Self_WithBeta_Single_Cpu(Sepoch,data,beta) 33 | else: 34 | self.teacher.Train_Self_Cpu(Tepoch,data,generatedData) 35 | self.student.Train_Self_WithBeta_Cpu(Sepoch,data,generatedData,beta) 36 | 37 | 38 | def Train_WithBeta_DatLoad(self,Tepoch,Sepoch,data,generatedData,beta): 39 | 40 | if np.shape(generatedData)[0] == 0: 41 | self.teacher.train_self_Single_Cpu(Tepoch,data) 42 | self.student.Train_Self_WithBeta_Single_Cpu(Sepoch,data,beta) 43 | else: 44 | self.teacher.Train_Self_Cpu(Tepoch,data,generatedData) 45 | self.student.Train_Self_WithBeta_Cpu(Sepoch,data,generatedData,beta) 46 | 47 | 48 | def Train_WithBeta_Cpu_2(self,Tepoch,Sepoch,data,generatedData,beta): 49 | 50 | if np.shape(generatedData)[0] == 0: 51 | self.teacher.train_self_Single_Cpu(Tepoch,data) 52 | self.student.Train_Self_WithBeta_Single_Cpu(Sepoch,data,beta) 53 | else: 54 | self.teacher.Train_Self_Cpu(Tepoch,data,generatedData) 55 | self.student.Train_Self_WithBeta_Cpu(Sepoch,data,generatedData,beta) 56 | 57 | 58 | def Train_WithBeta(self,Tepoch,Sepoch,data,generatedData,beta): 59 | 60 | if np.shape(generatedData)[0] == 0: 61 | self.teacher.Train_Self(Tepoch,data) 62 | self.student.Train_Self_WithBeta_Single(Sepoch,data,beta) 63 | else: 64 | self.teacher.Train_Self_(Tepoch,data,generatedData) 65 | self.student.Train_Self_WithBeta(Sepoch,data,generatedData,beta) 66 | 67 | def Train(self,Tepoch,Sepoch,data,generatedData): 68 | 69 | if np.shape(generatedData)[0] == 0: 70 | self.teacher.Train_Self(Tepoch,data) 71 | self.student.Train_Self(Sepoch,data) 72 | else: 73 | 
self.teacher.Train_Self_(Tepoch,data,generatedData) 74 | self.student.Train_Self_(Sepoch,data,generatedData) 75 | 76 | 77 | def Train_ByLoadData_Single(self,Tepoch,Sepoch,data): 78 | #self.teacher.Train_Self_ByDataLoad_Single(Tepoch, data) 79 | self.student.Train_Self_ByDataLoad_Single(Sepoch, data) 80 | 81 | 82 | def Train_ByLoadData(self,Tepoch,Sepoch,data,generatedData): 83 | 84 | if np.shape(generatedData)[0] == 0: 85 | self.teacher.Train_Self_ByDataLoad(Tepoch,data) 86 | self.student.Train_Self(Sepoch,data) 87 | else: 88 | self.teacher.Train_Self_ByDataLoad(Tepoch,data,generatedData) 89 | self.student.Train_Self_ByDataLoad(Sepoch,data,generatedData) 90 | 91 | 92 | -------------------------------------------------------------------------------- /NetworkModels/TFCL_DynamicVAEModel_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | #from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 4 | from NetworkModels.Teacher_Model_NoMPI_ import Balance_Teacher_NoMPI 5 | from NetworkModels.VAE_Model_ import Balance_StudentModel 6 | import torch.nn as nn 7 | from NetworkModels.TFCL_Teacher_ import * 8 | from improved_diffusion.train_util_balance_NoMPI_MultiGPU import * 9 | from NetworkModels.VAE_Model_ import * 10 | 11 | class TFCL_DynamicVAEModel(nn.Module): 12 | def __init__(self,name,device,input_size): 13 | super(TFCL_DynamicVAEModel, self).__init__() 14 | 15 | self.input_size = input_size 16 | self.currentComponent = TFCL_StudentModel(device,input_size) 17 | self.componentList = [] 18 | self.componentList.append(self.currentComponent) 19 | 20 | self.device = device 21 | self.trainingCount = 0 22 | self.trainingUpdate = 4 23 | self.GeneratingBatchSampleSize = 64 24 | self.batch_size = 64 25 | self.isTrainer = 0 26 | 27 | def AddNewComponent(self): 28 | a = TFCL_StudentModel(self.device,self.input_size) 29 | self.currentComponent = a 30 | self.componentList.append(a) 31 | 32 | def SelectComponent_BySample(self,singleData): 33 | arr = [] 34 | for i in range(np.shape(self.componentList)[0]): 35 | loss1 = self.componentList[i].ComputerLoss(singleData) 36 | loss1 = loss1['loss'] 37 | loss1 = loss1.cpu().detach().numpy() 38 | arr.append(loss1) 39 | minindex = np.argmin(arr) 40 | return minindex 41 | 42 | def Train_SelectedComponent(self,epoch,data,component): 43 | component.Train_Self(epoch,data) 44 | 45 | def GiveReconstructionBatch(self,batch): 46 | arr = [] 47 | for i in range(np.shape(batch)[0]): 48 | sample = batch[i] 49 | sample = torch.reshape(sample, (1, 3, self.input_size, self.input_size)) 50 | sample = torch.cat([sample, sample], 0) 51 | 52 | index = self.SelectComponent_BySample(sample) 53 | 54 | reco = self.componentList[index].Give_ReconstructionSingle(sample) 55 | reco = reco[0] 56 | reco = torch.reshape(reco, (1, 3, self.input_size, self.input_size)) 57 | if np.shape(arr)[0] == 0: 58 | arr = reco 59 | else: 60 | arr = torch.cat([arr,reco],0) 61 | return arr 62 | 63 | def GiveReconstructionFromOriginalImages(self,data): 64 | arr = [] 65 | count = int(np.shape(data)[0]/self.batch_size) 66 | for i in range(count): 67 | batch = data[i*self.batch_size:(i+1)*self.batch_size] 68 | batch = th.tensor(batch).cuda().to(device=self.device, dtype=th.float) 69 | reco = self.GiveReconstructionBatch(batch) 70 | if np.shape(arr)[0] == 0: 71 | arr = reco 72 | else: 73 | arr = th.cat([arr, reco], 0) 74 | return arr 75 | 76 | def GiveReconstruction(self,data): 77 | arr = [] 78 | 79 | count = 
int(np.shape(data)[0]/self.batch_size) 80 | for i in range(count): 81 | batch = data[i*self.batch_size:(i+1)*self.batch_size] 82 | reco = self.GiveReconstructionBatch(batch) 83 | if np.shape(arr)[0] == 0: 84 | arr = reco 85 | else: 86 | arr = torch.cat([arr,reco],0) 87 | return arr 88 | 89 | 90 | def GiveMixGeneration(self,num): 91 | count = int(num / self.batch_size) 92 | t2 = int(self.batch_size/np.shape(self.componentList)[0]) 93 | 94 | arr = [] 95 | for i in range(count): 96 | for j in range(np.shape(self.componentList)[0]): 97 | x1 = self.componentList[j].Generation(t2) 98 | if np.shape(arr)[0] == 0: 99 | arr = x1 100 | else: 101 | arr = torch.cat([arr,x1],0) 102 | 103 | return arr 104 | 105 | -------------------------------------------------------------------------------- /NetworkModels/TeacherEnsembleFramework_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from improved_diffusion import dist_util, logger 3 | 4 | #from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 5 | from NetworkModels.Teacher_Model_NoMPI_ import Balance_Teacher_NoMPI 6 | from NetworkModels.VAE_Model_ import Balance_StudentModel 7 | import torch.nn as nn 8 | from NetworkModels.TFCL_Teacher_ import * 9 | from improved_diffusion.train_util_balance_NoMPI_MultiGPU import * 10 | from NetworkModels.VAE_Model_ import * 11 | import random 12 | import torch.distributions as td 13 | from models.VAE256 import * 14 | from NetworkModels.DynamicDiffusionMixture_ import * 15 | from NetworkModels.MMD_Lib import * 16 | from models.VAE256 import * 17 | from NetworkModels.MyClassifier import * 18 | from NetworkModels.MemoryUnitFramework_ import * 19 | import numpy as np 20 | from datasets.Fid_evaluation import * 21 | 22 | class TeacherEnsembleFramework(MemoryUnitFramework): 23 | def __init__(self,name,device,input_size): 24 | super(MemoryUnitFramework, self).__init__(name,device,input_size) 25 | 26 | self.input_size = input_size 27 | self.device = device 28 | self.trainingCount = 0 29 | self.trainingUpdate = 4 30 | self.GeneratingBatchSampleSize = 64 31 | self.batchTrainStudent_size = 64 32 | self.isTrainer = 0 33 | self.teacherArray = [] 34 | self.autoencoderArr = [] 35 | 36 | self.isExpansion = True 37 | 38 | self.student = StudentModel(device,input_size)#Autoencoder(device, input_size) 39 | self.autoencoderArr.append(self.student) 40 | 41 | self.currentComponent = 0 42 | self.batch_size = 64 43 | 44 | self.resultMatrix = [] 45 | self.unitCount = 0 46 | 47 | self.memoryUnits = [] 48 | self.currentMemory = [] 49 | self.threshold = 0.02 50 | 51 | self.maxMemorySize = 2000 52 | self.diversityThreshold = 0.02 53 | self.expansionThreshold = 140 54 | self.memoryUnitSize = 64 55 | self.currentTrainingIndex = 0 56 | self.maxTrainingTime = 200 57 | self.currentTrainingTime = 0 58 | 59 | def Transfer_To_Numpy(self,sample): 60 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 61 | sample = sample.permute(0, 2, 3, 1) 62 | sample = sample.contiguous() 63 | mySamples = sample.unsqueeze(0).cuda().cpu() 64 | mySamples = np.array(mySamples) 65 | mySamples = mySamples[0] 66 | return mySamples 67 | 68 | def CheckModelExpansion_TaskKnown(self,newDataset): 69 | isExpansion = True 70 | index = 0 71 | 72 | genCount = 1000 73 | newDataset2 = newDataset[0:genCount] 74 | newDataset2 = torch.tensor(newDataset2).cuda().to(device=self.device, dtype=torch.float) 75 | 76 | arr = [] 77 | for i in range(np.shape(self.teacherArray)[0]): 78 | gan1 = 
self.teacherArray[i].GenerateImages(genCount) 79 | #gan1 = self.Transfer_To_Numpy(gan1) 80 | 81 | fid1 = calculate_fid_given_paths_Byimages(newDataset2, gan1, 50, self.device, 2048) 82 | arr.append(fid1) 83 | 84 | minscore = np.min(arr) 85 | if minscore > self.expansionThreshold: 86 | #Perform the expansion 87 | newComponent = self.Create_NewComponent() 88 | self.currentComponent = newComponent 89 | else: 90 | #Perform the expert selection 91 | index = np.argmin(arr) 92 | self.currentComponent = self.teacherArray[index] 93 | isExpansion = False 94 | 95 | return minscore,index,arr,isExpansion 96 | 97 | def TrainStudent_Numpy(self,epoch,memory): 98 | self.student.Train_Self_Single_Beta3_Numpy(epoch,memory) 99 | 100 | def TrainStudent_Balance_Numpy(self,epoch,generatedData,memory): 101 | self.student.Train_Self_Single_Beta3_Balance_Numpy(epoch,generatedData,memory) 102 | 103 | 104 | def TrainStudent(self, epoch, memory): 105 | # using the KD 106 | self.student.training_step2_WithBeta() 107 | self.student.Train_Self_Single_Beta3(epoch, memory) 108 | -------------------------------------------------------------------------------- /improved_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | #from mpi4py import MPI 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def load_data( 9 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False 10 | ): 11 | """ 12 | For a dataset, create a generator over (images, kwargs) pairs. 13 | 14 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 15 | more keys, each of which map to a batched Tensor of their own. 16 | The kwargs dict can be used for class labels, in which case the key is "y" 17 | and the values are integer tensors of class labels. 18 | 19 | :param data_dir: a dataset directory. 20 | :param batch_size: the batch size of each returned pair. 21 | :param image_size: the size to which images are resized. 22 | :param class_cond: if True, include a "y" key in returned dicts for class 23 | label. If classes are not available and this is true, an 24 | exception will be raised. 25 | :param deterministic: if True, yield results in a deterministic order. 26 | """ 27 | if not data_dir: 28 | raise ValueError("unspecified data directory") 29 | all_files = _list_image_files_recursively(data_dir) 30 | classes = None 31 | if class_cond: 32 | # Assume classes are the first part of the filename, 33 | # before an underscore. 34 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 35 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 36 | classes = [sorted_classes[x] for x in class_names] 37 | dataset = ImageDataset( 38 | image_size, 39 | all_files, 40 | classes=classes, 41 | shard=MPI.COMM_WORLD.Get_rank(), 42 | num_shards=MPI.COMM_WORLD.Get_size(), 43 | ) 44 | if deterministic: 45 | loader = DataLoader( 46 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 47 | ) 48 | else: 49 | loader = DataLoader( 50 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 51 | ) 52 | while True: 53 | yield from loader 54 | 55 | 56 | def _list_image_files_recursively(data_dir): 57 | results = [] 58 | for entry in sorted(bf.listdir(data_dir)): 59 | full_path = bf.join(data_dir, entry) 60 | ext = entry.split(".")[-1] 61 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 62 | results.append(full_path) 63 | elif bf.isdir(full_path): 64 | results.extend(_list_image_files_recursively(full_path)) 65 | return results 66 | 67 | 68 | class ImageDataset(Dataset): 69 | def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1): 70 | super().__init__() 71 | self.resolution = resolution 72 | self.local_images = image_paths[shard:][::num_shards] 73 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 74 | 75 | def __len__(self): 76 | return len(self.local_images) 77 | 78 | def __getitem__(self, idx): 79 | path = self.local_images[idx] 80 | with bf.BlobFile(path, "rb") as f: 81 | pil_image = Image.open(f) 82 | pil_image.load() 83 | 84 | # We are not on a new enough PIL to support the `reducing_gap` 85 | # argument, which uses BOX downsampling at powers of two first. 86 | # Thus, we do it by hand to improve downsample quality. 87 | while min(*pil_image.size) >= 2 * self.resolution: 88 | pil_image = pil_image.resize( 89 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 90 | ) 91 | 92 | scale = self.resolution / min(*pil_image.size) 93 | pil_image = pil_image.resize( 94 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 95 | ) 96 | 97 | arr = np.array(pil_image.convert("RGB")) 98 | crop_y = (arr.shape[0] - self.resolution) // 2 99 | crop_x = (arr.shape[1] - self.resolution) // 2 100 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 101 | arr = arr.astype(np.float32) / 127.5 - 1 102 | 103 | out_dict = {} 104 | if self.local_classes is not None: 105 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 106 | return np.transpose(arr, [2, 0, 1]), out_dict 107 | -------------------------------------------------------------------------------- /NetworkModels/DynamicMixture256_.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from improved_diffusion import dist_util, logger 4 | 5 | #from NetworkModels.Teacher_Model_ import Teacher,Balance_Teacher 6 | from NetworkModels.Teacher_Model_NoMPI_ import Balance_Teacher_NoMPI 7 | from NetworkModels.VAE_Model_ import Balance_StudentModel 8 | import torch.nn as nn 9 | from NetworkModels.TFCL_Teacher_ import * 10 | from improved_diffusion.train_util_balance_NoMPI_MultiGPU import * 11 | from NetworkModels.VAE_Model_ import * 12 | import random 13 | import torch.distributions as td 14 | from models.VAE256 import * 15 | from NetworkModels.DynamicDiffusionMixture_ import * 16 | from NetworkModels.MMD_Lib import * 17 | from models.VAE256 import * 18 | 19 | 20 | class DynamicMixture256(DynamicDiffusionMixture): 21 | def __init__(self,name,device,input_size,modelType,originalInputSize): 22 | super(DynamicDiffusionMixture, self).__init__() 23 | 24 | self.input_size = input_size 25 | self.device = device 26 | self.trainingCount = 0 27 | self.trainingUpdate = 4 28 | self.GeneratingBatchSampleSize = 64 29 | self.batchTrainStudent_size = 64 30 | self.isTrainer = 0 31 | self.teacherArray = [] 32 | self.autoencoderArr = [] 33 | self.memorybuffer = [] 34 | self.OriginalInputSize = originalInputSize 35 | self.currentTrainingTime = 0 36 | 37 | if modelType == "GAN": 38 | print("GAN") 39 | else: 40 | teacher = Autoencoder(device, input_size) 41 | teacher.OriginalInputSize = self.OriginalInputSize 42 | self.currentComponent = teacher 43 | self.teacherArray.append(teacher) 44 | self.student = 
Autoencoder(device, input_size) 45 | self.student.OriginalInputSize = self.OriginalInputSize 46 | #print(self.currentComponent.input_size) 47 | 48 | def Create_NewTeacher(self): 49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | teacher = Autoencoder(device, self.input_size) 51 | self.currentComponent = teacher 52 | self.teacherArray.append(teacher) 53 | 54 | def Check_Expansion_Cpu(self): 55 | 56 | arr = [] 57 | for i in range(np.shape(self.teacherArr)[0]-1): 58 | print("build") 59 | memory = self.memorybuffer[0:1000] 60 | memory = memory 61 | memory = torch.FloatTensor(memory) 62 | 63 | generated = self.teacherArr[i].GiveGeneration_Cpu(np.shape(memory)[0]) 64 | #generated = generated.cpu() 65 | 66 | fid = calculate_fid_given_paths_Byimages(memory, generated, 50, self.device, 2048) 67 | #fid = fid.cpu().detach().numpy() 68 | arr.append(fid) 69 | 70 | minvalue = np.min(arr) 71 | print(minvalue) 72 | 73 | if minvalue > self.threshold: 74 | print("Build") 75 | self.Create_NewTeacher() 76 | 77 | 78 | def TrainTeacher(self,epoch,memoryBuffer): 79 | self.currentComponent.Train_Cpu_WithFiles(epoch, memoryBuffer) 80 | 81 | def TrainStudent_Cpu(self,epoch,memory): 82 | # using the KD 83 | self.student.Train_Self_Single_Beta3_Cpu(epoch,memory) 84 | 85 | 86 | def GiveGenerationByTeacher(self, num): 87 | count = np.shape(self.teacherArray)[0] 88 | t = int(num / 2) 89 | arr = [] 90 | for i in range(t): 91 | index = random.randint(1, count) - 1 92 | new1 = self.teacherArray[index].vae.sample_with_noise(2,self.device) 93 | if np.shape(arr)[0] == 0: 94 | arr = new1 95 | else: 96 | arr = torch.cat([arr, new1], 0) 97 | return arr 98 | 99 | 100 | def Give_GenerationFromTeacher_Cpu(self,num): 101 | 102 | with torch.no_grad(): 103 | count = np.shape(self.teacherArray)[0] 104 | t = int(num / 2) 105 | arr = [] 106 | for i in range(t): 107 | index = random.randint(1,count) - 1 108 | new1 = self.teacherArray[index].Give_GenerationsWithN(2) 109 | 110 | new1 = new1.unsqueeze(0).cuda().cpu() 111 | new1 = np.array(new1) 112 | new1 = new1[0] 113 | 114 | if np.shape(arr)[0] == 0: 115 | arr = new1 116 | else: 117 | #arr = torch.cat([arr,new1],0) 118 | arr = np.concatenate((arr,new1),0) 119 | return arr 120 | -------------------------------------------------------------------------------- /datasets/MNIST32.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from torch.utils.data import Dataset 6 | from torch.utils.data import DataLoader 7 | from datasets.Data_Loading import * 8 | 9 | class Generated_Dataset_Unsupervised(Dataset): 10 | def __init__(self ,filepath,datax,transform_train): 11 | #z = np.loadtxt(filepath ,dtype=np.float32 ,delimiter=',') 12 | 13 | self.transform = transform_train 14 | 15 | #mnist_train_x = mnist_train_x * 255.0 16 | #mnist_test = mnist_test * 255.0 17 | 18 | #mnist_train_x = np.transpose(mnist_train_x, (0, 3, 1, 2)) 19 | #mnist_test = np.transpose(mnist_test, (0, 3, 1, 2)) 20 | #mnist_train_x = mnist_train_x / 127.5 - 1 21 | #mnist_test = mnist_test / 127.5 - 1 22 | 23 | #self.x_data = torch.from_numpy(mnist_train_x) 24 | self.x_data = datax 25 | self.len = np.shape(self.x_data)[0] 26 | 27 | def GetSourceData(self): 28 | return self.x_data 29 | 30 | def SetData(self,datax): 31 | self.x_data = datax 32 | self.len = np.shape(self.x_data)[0] 33 | 34 | def GetLength(self): 35 | return self.len 36 | 37 | def __len__(self): 38 | return self.len 39 | def 
__getitem__(self, item): 40 | xx = self.transform(self.x_data[item]) 41 | xx = xx.type(torch.float) 42 | return xx 43 | 44 | 45 | class Generated_Dataset(Dataset): 46 | def __init__(self ,filepath,datax,datay,transform_train): 47 | #z = np.loadtxt(filepath ,dtype=np.float32 ,delimiter=',') 48 | 49 | self.transform = transform_train 50 | 51 | #mnist_train_x = mnist_train_x * 255.0 52 | #mnist_test = mnist_test * 255.0 53 | 54 | #mnist_train_x = np.transpose(mnist_train_x, (0, 3, 1, 2)) 55 | #mnist_test = np.transpose(mnist_test, (0, 3, 1, 2)) 56 | #mnist_train_x = mnist_train_x / 127.5 - 1 57 | #mnist_test = mnist_test / 127.5 - 1 58 | 59 | #self.x_data = torch.from_numpy(mnist_train_x) 60 | self.x_data = datax 61 | mnist_train_label = np.argmax(datay,1) 62 | self.y_data = torch.from_numpy(mnist_train_label) 63 | self.len = self.x_data.shape[0] 64 | 65 | def SetData(self,datax,datay): 66 | self.x_data = datax 67 | mnist_train_label = np.argmax(datay, 1) 68 | self.y_data = torch.from_numpy(mnist_train_label) 69 | self.len = self.x_data.shape[0] 70 | 71 | def __len__(self): 72 | return self.len 73 | def __getitem__(self, item): 74 | xx = self.transform(self.x_data[item]) 75 | xx = xx.type(torch.float) 76 | return xx,self.y_data[item] 77 | 78 | class MNIST32_Train(Dataset): 79 | def __init__(self ,filepath,transform_train): 80 | #z = np.loadtxt(filepath ,dtype=np.float32 ,delimiter=',') 81 | 82 | self.transform = transform_train 83 | mnist_train_x, mnist_train_label, mnist_test, mnist_label_test, x_train, y_train, x_test, y_test = GiveMNIST_SVHN() 84 | 85 | #mnist_train_x = mnist_train_x * 255.0 86 | #mnist_test = mnist_test * 255.0 87 | 88 | #mnist_train_x = np.transpose(mnist_train_x, (0, 3, 1, 2)) 89 | #mnist_test = np.transpose(mnist_test, (0, 3, 1, 2)) 90 | #mnist_train_x = mnist_train_x / 127.5 - 1 91 | #mnist_test = mnist_test / 127.5 - 1 92 | 93 | #self.x_data = torch.from_numpy(mnist_train_x) 94 | self.x_data = mnist_train_x 95 | mnist_train_label = np.argmax(mnist_train_label,1) 96 | self.y_data = torch.from_numpy(mnist_train_label) 97 | self.len = self.x_data.shape[0] 98 | 99 | def __len__(self): 100 | return self.len 101 | def __getitem__(self, item): 102 | xx = self.transform(self.x_data[item]) 103 | xx = xx.type(torch.float) 104 | return xx,self.y_data[item] 105 | 106 | 107 | class MNIST32_Test(Dataset): 108 | def __init__(self ,filepath,transform_train): 109 | #z = np.loadtxt(filepath ,dtype=np.float32 ,delimiter=',') 110 | 111 | self.transform = transform_train 112 | mnist_train_x, mnist_train_label, mnist_test, mnist_label_test, x_train, y_train, x_test, y_test = GiveMNIST_SVHN() 113 | 114 | #mnist_train_x = mnist_train_x * 255.0 115 | #mnist_test = mnist_test * 255.0 116 | 117 | #mnist_train_x = np.transpose(mnist_train_x, (0, 3, 1, 2)) 118 | #mnist_test = np.transpose(mnist_test, (0, 3, 1, 2)) 119 | #mnist_train_x = mnist_train_x / 127.5 - 1 120 | #mnist_test = mnist_test / 127.5 - 1 121 | 122 | #self.x_data = torch.from_numpy(mnist_test) 123 | self.x_data = mnist_test 124 | mnist_label_test = np.argmax(mnist_label_test,1) 125 | self.y_data = torch.from_numpy(mnist_label_test) 126 | self.len = self.x_data.shape[0] 127 | 128 | def __len__(self): 129 | return self.len 130 | def __getitem__(self, item): 131 | xx = self.transform(self.x_data[item]) 132 | xx = xx.type(torch.float) 133 | return xx ,self.y_data[item] -------------------------------------------------------------------------------- /datasets/CelebA_Load.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset #Dataset的包 2 | import os #路径需要这个 3 | import cv2 # 需要读取图片,最好用opencv-python,当然也可以用PIL只是我不顺手 4 | import random 5 | import numpy as np 6 | import os 7 | import gzip 8 | import cv2 9 | import scipy.io as sio 10 | from cv2_imageProcess import * 11 | import glob 12 | import imageio 13 | import scipy 14 | from PIL import Image 15 | import torchvision.datasets as datasets 16 | import torch.utils.data as data 17 | import torchvision.transforms as transforms 18 | import torchvision 19 | 20 | class CelebA_dataset(Dataset): #我定义的这个类 21 | def __init__(self, root_dir, transform=None): 22 | #下面需要使用的变量,在__init__定义好, 23 | self.transform = transform 24 | 25 | img_path = glob.glob('../img_celeba2/*.jpg') # 获取新文件夹下所有图片 26 | # img_path = glob('C:/CommonData/img_celeba2/*.jpg') # 获取新文件夹下所有图片 27 | data_files = img_path 28 | data_files = sorted(data_files) 29 | data_files = np.array(data_files) # for tl.iterate.minibatches 30 | celebaFiles = data_files 31 | maxSize = 50000 32 | maxTestingSize = 5000 33 | 34 | # maxSize = 100 35 | # maxTestingSize = 100 36 | 37 | # maxSize = 128 38 | # maxTestingSize = 50 39 | 40 | celebaTraining = celebaFiles[0:maxSize] 41 | self.img_path = celebaTraining #得到整体图片的路径(可取其中的一张一张的图像的名字) 42 | 43 | def __getitem__(self, idx): 44 | # 改写__getitem__(self,item)函数,最后得到图像,标签 45 | #获取具体的一幅图像的名字 46 | img_name = self.img_path[idx] 47 | #获取一幅图像的详细地址 48 | img_item_path = img_name 49 | #用opencv来读取图像 50 | 51 | ''' 52 | img = cv2.imread(img_item_path) 53 | #获取标签(这里简单写了aligned与original) 54 | label = self.label_dir 55 | return img, label 56 | ''' 57 | 58 | ''' 59 | celebaTrainSet = [GetImage_cv( 60 | sample_file, 61 | input_height=128, 62 | input_width=128, 63 | resize_height=64, 64 | resize_width=64, 65 | crop=True) 66 | for sample_file in img_item_path] 67 | ''' 68 | 69 | results = GetImage_cv( 70 | img_name, 71 | input_height=128, 72 | input_width=128, 73 | resize_height=64, 74 | resize_width=64, 75 | crop=True) 76 | 77 | results = np.transpose(results, (2, 0, 1)) 78 | #celebaTrainSet = np.array(celebaTrainSet) 79 | #celebaTrainSet = np.transpose(celebaTrainSet, (0, 3, 1, 2)) 80 | 81 | return results 82 | 83 | def __len__(self): 84 | #改写整体图像的大小 85 | return len(self.img_path) 86 | 87 | 88 | def file_name2(file_dir): 89 | t1 = [] 90 | for root, dirs, files in os.walk(file_dir): 91 | for a1 in dirs: 92 | b1 = "../rendered_chairs/" + a1 + "/renders/*.png" 93 | img_path = glob.glob(b1) 94 | t1.append(img_path) 95 | 96 | cc = [] 97 | 98 | for i in range(len(t1)): 99 | a1 = t1[i] 100 | for p1 in a1: 101 | cc.append(p1) 102 | return cc 103 | 104 | 105 | class Chair_dataset(Dataset): #我定义的这个类 106 | def __init__(self, root_dir, transform=None): 107 | #下面需要使用的变量,在__init__定义好, 108 | self.transform = transform 109 | 110 | file_dir = "../rendered_chairs/" 111 | files = file_name2(file_dir) 112 | data_files = files 113 | data_files = sorted(data_files) 114 | data_files = np.array(data_files) # for tl.iterate.minibatches 115 | chairFiles = data_files 116 | 117 | maxSize = 50000 118 | maxTestingSize = 5000 119 | 120 | # maxSize = 100 121 | # maxTestingSize = 100 122 | 123 | # maxSize = 128 124 | # maxTestingSize = 50 125 | 126 | chairTraining = chairFiles[0:maxSize] 127 | 128 | #img_path = glob.glob('../img_celeba2/*.jpg') # 获取新文件夹下所有图片 129 | 130 | self.img_path = chairTraining #得到整体图片的路径(可取其中的一张一张的图像的名字) 131 | 132 | def __getitem__(self, idx): 133 | # 改写__getitem__(self,item)函数,最后得到图像,标签 134 | 
#获取具体的一幅图像的名字 135 | img_name = self.img_path[idx] 136 | #获取一幅图像的详细地址 137 | img_item_path = img_name 138 | #用opencv来读取图像 139 | 140 | ''' 141 | img = cv2.imread(img_item_path) 142 | #获取标签(这里简单写了aligned与original) 143 | label = self.label_dir 144 | return img, label 145 | ''' 146 | 147 | ''' 148 | celebaTrainSet = [GetImage_cv( 149 | sample_file, 150 | input_height=128, 151 | input_width=128, 152 | resize_height=64, 153 | resize_width=64, 154 | crop=True) 155 | for sample_file in img_item_path] 156 | ''' 157 | 158 | results = get_image2(img_item_path, 300, is_crop=True, resize_w=image_size, is_grayscale=0) 159 | 160 | results = np.transpose(results, (2, 0, 1)) 161 | #celebaTrainSet = np.array(celebaTrainSet) 162 | #celebaTrainSet = np.transpose(celebaTrainSet, (0, 3, 1, 2)) 163 | 164 | return results 165 | 166 | def __len__(self): 167 | #改写整体图像的大小 168 | return len(self.img_path) 169 | -------------------------------------------------------------------------------- /improved_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
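    A hedged illustration (added note, not part of the original docstring):
    with the example above, space_timesteps(300, [10, 15, 20]) returns a set
    of 10 + 15 + 20 = 45 step indices, while space_timesteps(1000, "ddim50")
    selects the fixed DDIM striding that yields exactly 50 steps.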
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel( 102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 103 | ) 104 | 105 | def _scale_timesteps(self, t): 106 | # Scaling is done by the wrapped model. 
107 | return t 108 | 109 | 110 | class _WrappedModel: 111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 112 | self.model = model 113 | self.timestep_map = timestep_map 114 | self.rescale_timesteps = rescale_timesteps 115 | self.original_num_steps = original_num_steps 116 | 117 | def __call__(self, x, ts, **kwargs): 118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 119 | new_ts = map_tensor[ts] 120 | if self.rescale_timesteps: 121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 122 | return self.model(x, new_ts, **kwargs) 123 | -------------------------------------------------------------------------------- /datasets/cv2_ImageProcess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import skimage.io as io 4 | from skimage import io, transform 5 | 6 | def center_crop_cv(x, crop_h, crop_w, 7 | resize_h=64, resize_w=64): 8 | if crop_w is None: 9 | crop_w = crop_h 10 | 11 | h, w = x.shape[:2] 12 | 13 | j = int(round((h - crop_h) / 2.)) 14 | i = int(round((w - crop_w) / 2.)) 15 | 16 | tt = x[j:j + crop_h, i:i + crop_w] 17 | tt = np.array(tt) 18 | return cv2.resize(src=tt, dsize=(resize_h, resize_w)) 19 | 20 | def ims_cv(file,image): 21 | image = (image + 1) * 127.5 22 | #cv2.imwrite(file, image) 23 | io.imsave(file, image) 24 | 25 | def ims_cv_255(file,image): 26 | image = image * 255.0 27 | #cv2.imwrite(file, image) 28 | io.imsave(file, image) 29 | 30 | def center_crop2(x, crop_h, crop_w=None, resize_w=64): 31 | # crop the images to [crop_h,crop_w,3] then resize to [resize_h,resize_w,3] 32 | if crop_w is None: 33 | crop_w = crop_h # the width and height after cropped 34 | h, w = x.shape[:2] 35 | 36 | j = int(round((h - crop_h)/2.)) 37 | i = int(round((w - crop_w)/2.)) 38 | return cv2.resize(src=x[j:j+crop_h, i:i+crop_w], dsize=(resize_w, resize_w)) 39 | 40 | 41 | def GetImage_cv2_255(file, image_size, is_crop=True, resize_w=64, is_grayscale = False): 42 | im = cv2.imread(file,cv2.IMREAD_COLOR) 43 | 44 | if is_crop == True: 45 | im = center_crop2(im,image_size, resize_w=resize_w) 46 | else: 47 | im = cv2.resize(src=im, dsize=(resize_w, resize_w)) 48 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 49 | im = im / 255.0 50 | return im 51 | 52 | def GetImage_cv2(file, image_size, is_crop=True, resize_w=64, is_grayscale = False): 53 | #im = cv2.imread(file,cv2) 54 | im = io.imread(file) 55 | 56 | if is_crop == True: 57 | im = center_crop2(im,image_size, resize_w=resize_w) 58 | else: 59 | im = cv2.resize(src=im, dsize=(resize_w, resize_w)) 60 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 61 | im = im / 127.5 -1 62 | return im 63 | 64 | 65 | 66 | def GetImage_cv_01_Low(file,input_height=128, 67 | input_width=128, 68 | resize_height=64, 69 | resize_width=64, 70 | crop=True): 71 | im = io.imread(file,False) 72 | #im = cv2.imread(file,cv2.IMREAD_COLOR) 73 | 74 | if crop == True: 75 | im = center_crop_cv(im,input_height, input_width, resize_height,resize_width) 76 | else: 77 | im = cv2.resize(src=im, dsize=(resize_height, resize_width)) 78 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 79 | im = im / 127.5 -1 80 | return im 81 | 82 | 83 | def GetImage_cv_255_Low(file,input_height=128, 84 | input_width=128, 85 | resize_height=64, 86 | resize_width=64, 87 | crop=True): 88 | im = io.imread(file,False) 89 | #im = cv2.imread(file,cv2.IMREAD_COLOR) 90 | 91 | if crop == 
True: 92 | im = center_crop_cv(im,input_height, input_width, resize_height,resize_width) 93 | else: 94 | im = cv2.resize(src=im, dsize=(resize_height, resize_width)) 95 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 96 | im = im / 255.0 97 | return im 98 | 99 | import imageio 100 | 101 | def GetImage_cv_255_Low_Specific(file,input_height=128, 102 | input_width=128, 103 | resize_height=64, 104 | resize_width=64, 105 | crop=True): 106 | #im = io.imread(file,False) 107 | im = cv2.imread(file,cv2.IMREAD_GRAYSCALE) 108 | print(file) 109 | #im = imageio.imread(file) 110 | 111 | ''' 112 | if crop == True: 113 | im = center_crop_cv(im,input_height, input_width, resize_height,resize_width) 114 | else: 115 | im = cv2.resize(src=im, dsize=(resize_height, resize_width)) 116 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 117 | ''' 118 | im = im / 255.0 119 | return im 120 | 121 | 122 | def GetImage_cv_255(file,input_height=128, 123 | input_width=128, 124 | resize_height=64, 125 | resize_width=64, 126 | crop=True): 127 | im = io.imread(file) 128 | #im = cv2.imread(file,cv2.IMREAD_COLOR) 129 | 130 | if crop == True: 131 | im = center_crop_cv(im,input_height, input_width, resize_height,resize_width) 132 | else: 133 | im = cv2.resize(src=im, dsize=(resize_height, resize_width)) 134 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 135 | im = im / 255.0 136 | return im 137 | 138 | def GetImage_cv(file,input_height=128, 139 | input_width=128, 140 | resize_height=64, 141 | resize_width=64, 142 | crop=True): 143 | #im = cv2.imread(file) 144 | im = io.imread(file) 145 | 146 | if crop == True: 147 | im = center_crop_cv(im,input_height, input_width, resize_height,resize_width) 148 | else: 149 | im = cv2.resize(src=im, dsize=(resize_height, resize_width)) 150 | # im = cv2.resize(src=im, dsize=(64,64), interpolation=cv2.INTER_LINEAR) 151 | im = im / 127.5 -1 152 | 153 | #im = im[:, :, ::-1] 154 | return im 155 | -------------------------------------------------------------------------------- /improved_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
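    A hedged usage sketch (added for illustration): `dims` selects between
    nn.AvgPool1d/2d/3d and every other argument is forwarded unchanged, e.g.

        pool = avg_pool_nd(2, kernel_size=2, stride=2)   # -> nn.AvgPool2d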
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /datasets/myCIFAR100.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Provides access to the CIFAR-10 dataset, including simple data augmentation. 6 | 7 | Author: Jan Schlüter 8 | """ 9 | import os 10 | import sys 11 | 12 | import numpy as np 13 | import pickle 14 | 15 | def download_dataset(path, source='https://www.cs.toronto.edu/~kriz/' 16 | 'cifar-100-python.tar.gz'): 17 | """ 18 | Downloads and extracts the dataset, if needed. 19 | """ 20 | files = ['train', 'test'] 21 | for fn in files: 22 | if not os.path.exists(os.path.join(path, "cifar-100-python", fn)): 23 | break # at least one file is missing 24 | else: 25 | return # dataset is already complete 26 | 27 | print("Downloading and extracting %s into %s..." 
% (source, path)) 28 | if sys.version_info[0] == 2: 29 | from urllib import urlopen 30 | else: 31 | from urllib.request import urlopen 32 | import tarfile 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | u = urlopen(source) 36 | with tarfile.open(fileobj=u, mode='r|gz') as f: 37 | f.extractall(path=path) 38 | u.close() 39 | 40 | def load_dataset(path): 41 | download_dataset(path) 42 | 43 | # training data 44 | data = pickle.load(open(os.path.join(path, "cifar-100-python", "train"), 'rb'), encoding='latin1') 45 | X_train = data['data'] 46 | y_train = np.asarray(data['fine_labels'], np.int8) 47 | 48 | # test data 49 | data = pickle.load(open(os.path.join(path, 'cifar-100-python', 'test'), 'rb'), encoding='latin1') 50 | X_test = data['data'] 51 | y_test = np.asarray(data['fine_labels'], np.int8) 52 | 53 | # reshape 54 | X_train = X_train.reshape(-1, 3, 32, 32) 55 | X_test = X_test.reshape(-1, 3, 32, 32) 56 | 57 | # normalize 58 | try: 59 | mean_std = np.load(os.path.join(path, 'cifar-100-mean_std.npz')) 60 | mean = mean_std['mean'] 61 | std = mean_std['std'] 62 | except IOError: 63 | mean = X_train.mean(axis=(0,2,3), keepdims=True).astype(np.float32) 64 | std = X_train.std(axis=(0,2,3), keepdims=True).astype(np.float32) 65 | np.savez(os.path.join(path, 'cifar-100-mean_std.npz'), 66 | mean=mean, std=std) 67 | X_train = (X_train - mean) / std 68 | X_test = (X_test - mean) / std 69 | 70 | return X_train, y_train, X_test, y_test 71 | 72 | 73 | def load_dataset2(path): 74 | download_dataset(path) 75 | 76 | # training data 77 | data = pickle.load(open(os.path.join(path, "cifar-100-python", "train"), 'rb'), encoding='latin1') 78 | X_train = data['data'] 79 | y_train = np.asarray(data['fine_labels'], np.int8) 80 | 81 | # test data 82 | data = pickle.load(open(os.path.join(path, 'cifar-100-python', 'test'), 'rb'), encoding='latin1') 83 | X_test = data['data'] 84 | y_test = np.asarray(data['fine_labels'], np.int8) 85 | 86 | # reshape 87 | X_train = X_train.reshape(-1, 3, 32, 32) 88 | X_test = X_test.reshape(-1, 3, 32, 32) 89 | 90 | ''' 91 | # normalize 92 | try: 93 | mean_std = np.load(os.path.join(path, 'cifar-100-mean_std.npz')) 94 | mean = mean_std['mean'] 95 | std = mean_std['std'] 96 | except IOError: 97 | mean = X_train.mean(axis=(0,2,3), keepdims=True).astype(np.float32) 98 | std = X_train.std(axis=(0,2,3), keepdims=True).astype(np.float32) 99 | np.savez(os.path.join(path, 'cifar-100-mean_std.npz'), 100 | mean=mean, std=std) 101 | X_train = (X_train - mean) / std 102 | X_test = (X_test - mean) / std 103 | ''' 104 | 105 | return X_train, y_train, X_test, y_test 106 | 107 | 108 | def iterate_minibatches(inputs, targets, batchsize, shuffle=False): 109 | """ 110 | Generates one epoch of batches of inputs and targets, optionally shuffled. 111 | """ 112 | assert len(inputs) == len(targets) 113 | if shuffle: 114 | indices = np.arange(len(inputs)) 115 | np.random.shuffle(indices) 116 | for start_idx in range(0, len(inputs) - batchsize + 1, batchsize): 117 | if shuffle: 118 | excerpt = indices[start_idx:start_idx + batchsize] 119 | else: 120 | excerpt = slice(start_idx, start_idx + batchsize) 121 | yield inputs[excerpt], targets[excerpt] 122 | 123 | 124 | def augment_minibatches(minibatches, flip=0.5, trans=4): 125 | """ 126 | Randomly augments images by horizontal flipping with a probability of 127 | `flip` and random translation of up to `trans` pixels in both directions. 
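    A hedged usage sketch (added for illustration), assuming X_train / y_train
    come from load_dataset above and reusing iterate_minibatches from this file:

        batches = iterate_minibatches(X_train, y_train, 128, shuffle=True)
        for inputs, targets in augment_minibatches(batches, flip=0.5, trans=4):
            pass  # train on the augmented minibatch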
128 | """ 129 | for inputs, targets in minibatches: 130 | batchsize, c, h, w = inputs.shape 131 | if flip: 132 | coins = np.random.rand(batchsize) < flip 133 | inputs = [inp[:, :, ::-1] if coin else inp 134 | for inp, coin in zip(inputs, coins)] 135 | if not trans: 136 | inputs = np.asarray(inputs) 137 | outputs = inputs 138 | if trans: 139 | outputs = np.empty((batchsize, c, h, w), inputs[0].dtype) 140 | shifts = np.random.randint(-trans, trans, (batchsize, 2)) 141 | for outp, inp, (x, y) in zip(outputs, inputs, shifts): 142 | if x > 0: 143 | outp[:, :x] = 0 144 | outp = outp[:, x:] 145 | inp = inp[:, :-x] 146 | elif x < 0: 147 | outp[:, x:] = 0 148 | outp = outp[:, :x] 149 | inp = inp[:, -x:] 150 | if y > 0: 151 | outp[:, :, :y] = 0 152 | outp = outp[:, :, y:] 153 | inp = inp[:, :, :-y] 154 | elif y < 0: 155 | outp[:, :, y:] = 0 156 | outp = outp[:, :, :y] 157 | inp = inp[:, :, -y:] 158 | outp[:] = inp 159 | yield outputs, targets -------------------------------------------------------------------------------- /improved_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 
74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 
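                # (Added note) _loss_history[t] behaves as a fixed-length FIFO:
                # once history_per_term losses are recorded for timestep t, the
                # oldest entry is dropped and the newest appended, so weights()
                # always uses only the most recent history_per_term values.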
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /datasets/MyCIFAR10.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | import pickle 7 | import random 8 | import numpy as np 9 | 10 | class_num = 10 11 | image_size = 32 12 | img_channels = 3 13 | 14 | 15 | # ========================================================== # 16 | # ├─ prepare_data() 17 | # ├─ download training data if not exist by download_data() 18 | # ├─ load data by load_data() 19 | # └─ shuffe and return data 20 | # ========================================================== # 21 | 22 | 23 | def download_data(): 24 | dirname = 'cifar-10-batches-py' 25 | origin = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 26 | fname = 'cifar-10-python.tar.gz' 27 | fpath = './data/' + dirname 28 | 29 | download = False 30 | if os.path.exists(fpath) or os.path.isfile(fname): 31 | download = False 32 | print("DataSet aready exist!") 33 | else: 34 | download = True 35 | if download: 36 | print('Downloading data from', origin) 37 | import urllib.request 38 | import tarfile 39 | 40 | def reporthook(count, block_size, total_size): 41 | global start_time 42 | if count == 0: 43 | start_time = time.time() 44 | return 45 | duration = time.time() - start_time 46 | progress_size = int(count * block_size) 47 | speed = int(progress_size / (1024 * duration)) 48 | percent = min(int(count * block_size * 100 / total_size), 100) 49 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 50 | (percent, progress_size / (1024 * 1024), speed, duration)) 51 | sys.stdout.flush() 52 | 53 | urllib.request.urlretrieve(origin, fname, reporthook) 54 | print('Download finished. Start extract!', origin) 55 | if (fname.endswith("tar.gz")): 56 | tar = tarfile.open(fname, "r:gz") 57 | tar.extractall() 58 | tar.close() 59 | elif (fname.endswith("tar")): 60 | tar = tarfile.open(fname, "r:") 61 | tar.extractall() 62 | tar.close() 63 | 64 | 65 | def unpickle(file): 66 | with open(file, 'rb') as fo: 67 | dict = pickle.load(fo, encoding='bytes') 68 | return dict 69 | 70 | 71 | def load_data_one(file): 72 | batch = unpickle(file) 73 | data = batch[b'data'] 74 | labels = batch[b'labels'] 75 | print("Loading %s : %d." 
% (file, len(data))) 76 | return data, labels 77 | 78 | 79 | def load_data(files, data_dir, label_count): 80 | global image_size, img_channels 81 | data, labels = load_data_one(data_dir + '/' + files[0]) 82 | for f in files[1:]: 83 | data_n, labels_n = load_data_one(data_dir + '/' + f) 84 | data = np.append(data, data_n, axis=0) 85 | labels = np.append(labels, labels_n, axis=0) 86 | labels = np.array([[float(i == label) for i in range(label_count)] for label in labels]) 87 | data = data.reshape([-1, img_channels, image_size, image_size]) 88 | data = data.transpose([0, 2, 3, 1]) 89 | return data, labels 90 | 91 | 92 | def prepare_data(): 93 | print("======Loading data======") 94 | download_data() 95 | data_dir = './data/cifar-10-batches-py' 96 | image_dim = image_size * image_size * img_channels 97 | meta = unpickle(data_dir + '/batches.meta') 98 | 99 | label_names = meta[b'label_names'] 100 | label_count = len(label_names) 101 | train_files = ['data_batch_%d' % d for d in range(1, 6)] 102 | train_data, train_labels = load_data(train_files, data_dir, label_count) 103 | test_data, test_labels = load_data(['test_batch'], data_dir, label_count) 104 | 105 | print("Train data:", np.shape(train_data), np.shape(train_labels)) 106 | print("Test data :", np.shape(test_data), np.shape(test_labels)) 107 | print("======Load finished======") 108 | 109 | print("======Shuffling data======") 110 | indices = np.random.permutation(len(train_data)) 111 | train_data = train_data[indices] 112 | train_labels = train_labels[indices] 113 | print("======Prepare Finished======") 114 | 115 | return train_data, train_labels, test_data, test_labels 116 | 117 | 118 | # ========================================================== # 119 | # ├─ _random_crop() 120 | # ├─ _random_flip_leftright() 121 | # ├─ data_augmentation() 122 | # └─ color_preprocessing() 123 | # ========================================================== # 124 | 125 | def _random_crop(batch, crop_shape, padding=None): 126 | oshape = np.shape(batch[0]) 127 | 128 | if padding: 129 | oshape = (oshape[0] + 2 * padding, oshape[1] + 2 * padding) 130 | new_batch = [] 131 | npad = ((padding, padding), (padding, padding), (0, 0)) 132 | for i in range(len(batch)): 133 | new_batch.append(batch[i]) 134 | if padding: 135 | new_batch[i] = np.lib.pad(batch[i], pad_width=npad, 136 | mode='constant', constant_values=0) 137 | nh = random.randint(0, oshape[0] - crop_shape[0]) 138 | nw = random.randint(0, oshape[1] - crop_shape[1]) 139 | new_batch[i] = new_batch[i][nh:nh + crop_shape[0], 140 | nw:nw + crop_shape[1]] 141 | return new_batch 142 | 143 | 144 | def _random_flip_leftright(batch): 145 | for i in range(len(batch)): 146 | if bool(random.getrandbits(1)): 147 | batch[i] = np.fliplr(batch[i]) 148 | return batch 149 | 150 | 151 | def color_preprocessing(x_train, x_test): 152 | x_train = x_train.astype('float32') 153 | x_test = x_test.astype('float32') 154 | x_train[:, :, :, 0] = (x_train[:, :, :, 0] - np.mean(x_train[:, :, :, 0])) / np.std(x_train[:, :, :, 0]) 155 | x_train[:, :, :, 1] = (x_train[:, :, :, 1] - np.mean(x_train[:, :, :, 1])) / np.std(x_train[:, :, :, 1]) 156 | x_train[:, :, :, 2] = (x_train[:, :, :, 2] - np.mean(x_train[:, :, :, 2])) / np.std(x_train[:, :, :, 2]) 157 | 158 | x_test[:, :, :, 0] = (x_test[:, :, :, 0] - np.mean(x_test[:, :, :, 0])) / np.std(x_test[:, :, :, 0]) 159 | x_test[:, :, :, 1] = (x_test[:, :, :, 1] - np.mean(x_test[:, :, :, 1])) / np.std(x_test[:, :, :, 1]) 160 | x_test[:, :, :, 2] = (x_test[:, :, :, 2] - np.mean(x_test[:, :, :, 2])) / 
np.std(x_test[:, :, :, 2]) 161 | 162 | return x_train, x_test 163 | 164 | 165 | def data_augmentation(batch): 166 | batch = _random_flip_leftright(batch) 167 | batch = _random_crop(batch, [32, 32], 4) 168 | return batch -------------------------------------------------------------------------------- /improved_diffusion/train_util_all.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | import cv2 12 | import os 13 | from skimage import io, transform 14 | from cv2_imageProcess import * 15 | 16 | from .nn import update_ema 17 | from .resample import LossAwareSampler, UniformSampler 18 | from .train_util import TrainLoop 19 | 20 | #new add 21 | from improved_diffusion import dist_util, logger 22 | 23 | 24 | # For ImageNet experiments, this was a good default value. 25 | # We found that the lg_loss_scale quickly climbed to 26 | # 20-21 within the first ~1K steps of training. 27 | INITIAL_LOG_LOSS_SCALE = 20.0 28 | 29 | class CelebATrainLoop(TrainLoop): 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | use_fp16=False, 44 | fp16_scale_growth=1e-3, 45 | schedule_sampler=None, 46 | weight_decay=0.0, 47 | lr_anneal_steps=0, 48 | ): 49 | self.model = model 50 | self.diffusion = diffusion 51 | self.data = data 52 | self.batch_size = batch_size 53 | self.microbatch = microbatch if microbatch > 0 else batch_size 54 | self.lr = lr 55 | self.ema_rate = ( 56 | [ema_rate] 57 | if isinstance(ema_rate, float) 58 | else [float(x) for x in ema_rate.split(",")] 59 | ) 60 | self.log_interval = log_interval 61 | self.save_interval = save_interval 62 | self.resume_checkpoint = resume_checkpoint 63 | self.use_fp16 = use_fp16 64 | self.fp16_scale_growth = fp16_scale_growth 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 66 | self.weight_decay = weight_decay 67 | self.lr_anneal_steps = lr_anneal_steps 68 | 69 | self.step = 0 70 | self.resume_step = 0 71 | self.global_batch = self.batch_size * dist.get_world_size() 72 | 73 | self.model_params = list(self.model.parameters()) 74 | self.master_params = self.model_params 75 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 76 | self.sync_cuda = th.cuda.is_available() 77 | 78 | self._load_and_sync_parameters() 79 | if self.use_fp16: 80 | self._setup_fp16() 81 | 82 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 83 | if self.resume_step: 84 | self._load_optimizer_state() 85 | # Model was resumed, either due to a restart or a checkpoint 86 | # being specified at the command line. 
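            # (Added note) On resume, each EMA rate restores its parameters from
            # the matching EMA checkpoint; otherwise (else branch below) every EMA
            # copy starts as a deep copy of the current master parameters.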
87 | self.ema_params = [ 88 | self._load_ema_parameters(rate) for rate in self.ema_rate 89 | ] 90 | else: 91 | self.ema_params = [ 92 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 93 | ] 94 | 95 | if th.cuda.is_available(): 96 | self.use_ddp = True 97 | self.ddp_model = DDP( 98 | self.model, 99 | device_ids=[dist_util.dev()], 100 | output_device=dist_util.dev(), 101 | broadcast_buffers=False, 102 | bucket_cap_mb=128, 103 | find_unused_parameters=False, 104 | ) 105 | else: 106 | if dist.get_world_size() > 1: 107 | logger.warn( 108 | "Distributed training requires CUDA. " 109 | "Gradients will not be synchronized properly!" 110 | ) 111 | self.use_ddp = False 112 | self.ddp_model = self.model 113 | 114 | 115 | def run_loop(self): 116 | while ( 117 | not self.lr_anneal_steps 118 | or self.step + self.resume_step < self.lr_anneal_steps 119 | ): 120 | batch, cond = next(self.data) 121 | self.run_step(batch, cond) 122 | if self.step % self.log_interval == 0: 123 | logger.dumpkvs() 124 | if self.step % self.save_interval == 0: 125 | self.save() 126 | # Run for a finite amount of time in integration tests. 127 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 128 | return 129 | self.step += 1 130 | 131 | if self.step % 500 == 0: 132 | #show generations 133 | self.Generate_Images(self.step) 134 | 135 | # Save the last checkpoint if it wasn't already saved. 136 | if (self.step - 1) % self.save_interval != 0: 137 | self.save() 138 | 139 | 140 | def Generate_Images(self,step): 141 | 142 | model_kwargs = None 143 | if False: 144 | NUM_CLASSES = 10 145 | classes = th.randint( 146 | low=0, high=NUM_CLASSES, size=(64,), device=dist_util.dev() 147 | ) 148 | model_kwargs["y"] = classes 149 | 150 | use_ddim = True 151 | sample_fn = ( 152 | self.diffusion.p_sample_loop if not use_ddim else self.diffusion.ddim_sample_loop 153 | ) 154 | 155 | sample = sample_fn( 156 | self.model, 157 | (10, 3, 64, 64), 158 | clip_denoised=True, 159 | model_kwargs=model_kwargs, 160 | ) 161 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 162 | sample = sample.permute(0, 2, 3, 1) 163 | sample = sample.contiguous() 164 | mySamples = sample.unsqueeze(0).cuda().cpu() 165 | mySamples = np.array(mySamples) 166 | 167 | mySamples = mySamples[0] 168 | 169 | out1 = merge2(mySamples[:10], [1, 10]) 170 | print(np.shape(out1)) 171 | 172 | path = "results/" 173 | #cv2.imwrite(os.path.join(path, 'waka.jpg'), mySamples) 174 | name = "CelebA_generated_" + str(step) + ".png" 175 | cv2.imwrite("/scratch/fy689/improved-diffusion-main/results/"+name, out1) 176 | cv2.waitKey(0) 177 | #io.imsave("/scratch/fy689/improved-diffusion-main/results/aa.png", mySamples[0]) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /models/cvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class ConditionalVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | num_classes: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | img_size:int = 64, 16 | **kwargs) -> None: 17 | super(ConditionalVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.img_size = img_size 21 | 22 | self.embed_class = nn.Linear(num_classes, img_size * img_size) 23 | self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1) 24 | 25 | modules = [] 26 
| if hidden_dims is None: 27 | hidden_dims = [32, 64, 128, 256, 512] 28 | 29 | in_channels += 1 # To account for the extra label channel 30 | # Build Encoder 31 | for h_dim in hidden_dims: 32 | modules.append( 33 | nn.Sequential( 34 | nn.Conv2d(in_channels, out_channels=h_dim, 35 | kernel_size= 3, stride= 2, padding = 1), 36 | nn.BatchNorm2d(h_dim), 37 | nn.LeakyReLU()) 38 | ) 39 | in_channels = h_dim 40 | 41 | self.encoder = nn.Sequential(*modules) 42 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 43 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 44 | 45 | 46 | # Build Decoder 47 | modules = [] 48 | 49 | self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4) 50 | 51 | hidden_dims.reverse() 52 | 53 | for i in range(len(hidden_dims) - 1): 54 | modules.append( 55 | nn.Sequential( 56 | nn.ConvTranspose2d(hidden_dims[i], 57 | hidden_dims[i + 1], 58 | kernel_size=3, 59 | stride = 2, 60 | padding=1, 61 | output_padding=1), 62 | nn.BatchNorm2d(hidden_dims[i + 1]), 63 | nn.LeakyReLU()) 64 | ) 65 | 66 | 67 | 68 | self.decoder = nn.Sequential(*modules) 69 | 70 | self.final_layer = nn.Sequential( 71 | nn.ConvTranspose2d(hidden_dims[-1], 72 | hidden_dims[-1], 73 | kernel_size=3, 74 | stride=2, 75 | padding=1, 76 | output_padding=1), 77 | nn.BatchNorm2d(hidden_dims[-1]), 78 | nn.LeakyReLU(), 79 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 80 | kernel_size= 3, padding= 1), 81 | nn.Tanh()) 82 | 83 | def encode(self, input: Tensor) -> List[Tensor]: 84 | """ 85 | Encodes the input by passing through the encoder network 86 | and returns the latent codes. 87 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 88 | :return: (Tensor) List of latent codes 89 | """ 90 | result = self.encoder(input) 91 | result = torch.flatten(result, start_dim=1) 92 | 93 | # Split the result into mu and var components 94 | # of the latent Gaussian distribution 95 | mu = self.fc_mu(result) 96 | log_var = self.fc_var(result) 97 | 98 | return [mu, log_var] 99 | 100 | def decode(self, z: Tensor) -> Tensor: 101 | result = self.decoder_input(z) 102 | result = result.view(-1, 512, 2, 2) 103 | result = self.decoder(result) 104 | result = self.final_layer(result) 105 | return result 106 | 107 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 108 | """ 109 | Will a single z be enough ti compute the expectation 110 | for the loss?? 
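        (Added note) Here the usual single-sample VAE estimate is used: one
        z = mu + std * eps with std = exp(0.5 * logvar) and eps ~ N(0, I) is
        drawn per input, giving a one-sample Monte Carlo estimate of the
        expectation in the loss.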
111 | :param mu: (Tensor) Mean of the latent Gaussian 112 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 113 | :return: 114 | """ 115 | std = torch.exp(0.5 * logvar) 116 | eps = torch.randn_like(std) 117 | return eps * std + mu 118 | 119 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 120 | y = kwargs['labels'].float() 121 | embedded_class = self.embed_class(y) 122 | embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1) 123 | embedded_input = self.embed_data(input) 124 | 125 | x = torch.cat([embedded_input, embedded_class], dim = 1) 126 | mu, log_var = self.encode(x) 127 | 128 | z = self.reparameterize(mu, log_var) 129 | 130 | z = torch.cat([z, y], dim = 1) 131 | return [self.decode(z), input, mu, log_var] 132 | 133 | def loss_function(self, 134 | *args, 135 | **kwargs) -> dict: 136 | recons = args[0] 137 | input = args[1] 138 | mu = args[2] 139 | log_var = args[3] 140 | 141 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 142 | recons_loss =F.mse_loss(recons, input) 143 | 144 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 145 | 146 | loss = recons_loss + kld_weight * kld_loss 147 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 148 | 149 | def sample(self, 150 | num_samples:int, 151 | current_device: int, 152 | **kwargs) -> Tensor: 153 | """ 154 | Samples from the latent space and return the corresponding 155 | image space map. 156 | :param num_samples: (Int) Number of samples 157 | :param current_device: (Int) Device to run the model 158 | :return: (Tensor) 159 | """ 160 | y = kwargs['labels'].float() 161 | z = torch.randn(num_samples, 162 | self.latent_dim) 163 | 164 | z = z.to(current_device) 165 | 166 | z = torch.cat([z, y], dim=1) 167 | samples = self.decode(z) 168 | return samples 169 | 170 | def generate(self, x: Tensor, **kwargs) -> Tensor: 171 | """ 172 | Given an input image x, returns the reconstructed image 173 | :param x: (Tensor) [B x C x H x W] 174 | :return: (Tensor) [B x C x H x W] 175 | """ 176 | 177 | return self.forward(x, **kwargs)[0] -------------------------------------------------------------------------------- /models/beta_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class BetaVAE(BaseVAE): 9 | 10 | num_iter = 0 # Global static variable to keep track of iterations 11 | 12 | def __init__(self, 13 | in_channels: int, 14 | latent_dim: int, 15 | hidden_dims: List = None, 16 | beta: int = 4, 17 | gamma:float = 1000., 18 | max_capacity: int = 25, 19 | Capacity_max_iter: int = 1e5, 20 | loss_type:str = 'B', 21 | **kwargs) -> None: 22 | super(BetaVAE, self).__init__() 23 | 24 | self.latent_dim = latent_dim 25 | self.beta = beta 26 | self.gamma = gamma 27 | self.loss_type = loss_type 28 | self.C_max = torch.Tensor([max_capacity]) 29 | self.C_stop_iter = Capacity_max_iter 30 | 31 | modules = [] 32 | if hidden_dims is None: 33 | hidden_dims = [32, 64, 128, 256, 512] 34 | 35 | # Build Encoder 36 | for h_dim in hidden_dims: 37 | modules.append( 38 | nn.Sequential( 39 | nn.Conv2d(in_channels, out_channels=h_dim, 40 | kernel_size= 3, stride= 2, padding = 1), 41 | nn.BatchNorm2d(h_dim), 42 | nn.LeakyReLU()) 43 | ) 44 | in_channels = h_dim 45 | 46 | self.encoder = nn.Sequential(*modules) 47 | self.fc_mu = 
nn.Linear(hidden_dims[-1]*4, latent_dim) 48 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 49 | 50 | 51 | # Build Decoder 52 | modules = [] 53 | 54 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 55 | 56 | hidden_dims.reverse() 57 | 58 | for i in range(len(hidden_dims) - 1): 59 | modules.append( 60 | nn.Sequential( 61 | nn.ConvTranspose2d(hidden_dims[i], 62 | hidden_dims[i + 1], 63 | kernel_size=3, 64 | stride = 2, 65 | padding=1, 66 | output_padding=1), 67 | nn.BatchNorm2d(hidden_dims[i + 1]), 68 | nn.LeakyReLU()) 69 | ) 70 | 71 | 72 | 73 | self.decoder = nn.Sequential(*modules) 74 | 75 | self.final_layer = nn.Sequential( 76 | nn.ConvTranspose2d(hidden_dims[-1], 77 | hidden_dims[-1], 78 | kernel_size=3, 79 | stride=2, 80 | padding=1, 81 | output_padding=1), 82 | nn.BatchNorm2d(hidden_dims[-1]), 83 | nn.LeakyReLU(), 84 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 85 | kernel_size= 3, padding= 1), 86 | nn.Tanh()) 87 | 88 | def encode(self, input: Tensor) -> List[Tensor]: 89 | """ 90 | Encodes the input by passing through the encoder network 91 | and returns the latent codes. 92 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 93 | :return: (Tensor) List of latent codes 94 | """ 95 | result = self.encoder(input) 96 | result = torch.flatten(result, start_dim=1) 97 | 98 | # Split the result into mu and var components 99 | # of the latent Gaussian distribution 100 | mu = self.fc_mu(result) 101 | log_var = self.fc_var(result) 102 | 103 | return [mu, log_var] 104 | 105 | def decode(self, z: Tensor) -> Tensor: 106 | result = self.decoder_input(z) 107 | result = result.view(-1, 512, 2, 2) 108 | result = self.decoder(result) 109 | result = self.final_layer(result) 110 | return result 111 | 112 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 113 | """ 114 | Will a single z be enough ti compute the expectation 115 | for the loss?? 
116 | :param mu: (Tensor) Mean of the latent Gaussian 117 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 118 | :return: 119 | """ 120 | std = torch.exp(0.5 * logvar) 121 | eps = torch.randn_like(std) 122 | return eps * std + mu 123 | 124 | def forward(self, input: Tensor, **kwargs) -> Tensor: 125 | mu, log_var = self.encode(input) 126 | z = self.reparameterize(mu, log_var) 127 | return [self.decode(z), input, mu, log_var] 128 | 129 | def loss_function(self, 130 | *args, 131 | **kwargs) -> dict: 132 | self.num_iter += 1 133 | recons = args[0] 134 | input = args[1] 135 | mu = args[2] 136 | log_var = args[3] 137 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 138 | 139 | recons_loss =F.mse_loss(recons, input) 140 | 141 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 142 | 143 | if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl 144 | loss = recons_loss + self.beta * kld_weight * kld_loss 145 | elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf 146 | self.C_max = self.C_max.to(input.device) 147 | C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0]) 148 | loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs() 149 | else: 150 | raise ValueError('Undefined loss type.') 151 | 152 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss} 153 | 154 | def sample(self, 155 | num_samples:int, 156 | current_device: int, **kwargs) -> Tensor: 157 | """ 158 | Samples from the latent space and return the corresponding 159 | image space map. 160 | :param num_samples: (Int) Number of samples 161 | :param current_device: (Int) Device to run the model 162 | :return: (Tensor) 163 | """ 164 | z = torch.randn(num_samples, 165 | self.latent_dim) 166 | 167 | z = z.to(current_device) 168 | 169 | samples = self.decode(z) 170 | return samples 171 | 172 | def generate(self, x: Tensor, **kwargs) -> Tensor: 173 | """ 174 | Given an input image x, returns the reconstructed image 175 | :param x: (Tensor) [B x C x H x W] 176 | :return: (Tensor) [B x C x H x W] 177 | """ 178 | 179 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/logcosh_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models import BaseVAE 4 | from torch import nn 5 | from .types_ import * 6 | 7 | 8 | class LogCoshVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | alpha: float = 100., 15 | beta: float = 10., 16 | **kwargs) -> None: 17 | super(LogCoshVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.alpha = alpha 21 | self.beta = beta 22 | 23 | modules = [] 24 | if hidden_dims is None: 25 | hidden_dims = [32, 64, 128, 256, 512] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | 46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 47 | 48 | 
hidden_dims.reverse() 49 | 50 | for i in range(len(hidden_dims) - 1): 51 | modules.append( 52 | nn.Sequential( 53 | nn.ConvTranspose2d(hidden_dims[i], 54 | hidden_dims[i + 1], 55 | kernel_size=3, 56 | stride = 2, 57 | padding=1, 58 | output_padding=1), 59 | nn.BatchNorm2d(hidden_dims[i + 1]), 60 | nn.LeakyReLU()) 61 | ) 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | def encode(self, input: Tensor) -> List[Tensor]: 79 | """ 80 | Encodes the input by passing through the encoder network 81 | and returns the latent codes. 82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 83 | :return: (Tensor) List of latent codes 84 | """ 85 | result = self.encoder(input) 86 | result = torch.flatten(result, start_dim=1) 87 | 88 | # Split the result into mu and var components 89 | # of the latent Gaussian distribution 90 | mu = self.fc_mu(result) 91 | log_var = self.fc_var(result) 92 | 93 | return [mu, log_var] 94 | 95 | def decode(self, z: Tensor) -> Tensor: 96 | """ 97 | Maps the given latent codes 98 | onto the image space. 99 | :param z: (Tensor) [B x D] 100 | :return: (Tensor) [B x C x H x W] 101 | """ 102 | result = self.decoder_input(z) 103 | result = result.view(-1, 512, 2, 2) 104 | result = self.decoder(result) 105 | result = self.final_layer(result) 106 | return result 107 | 108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 109 | """ 110 | Reparameterization trick to sample from N(mu, var) from 111 | N(0,1). 112 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 113 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 114 | :return: (Tensor) [B x D] 115 | """ 116 | std = torch.exp(0.5 * logvar) 117 | eps = torch.randn_like(std) 118 | return eps * std + mu 119 | 120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 121 | mu, log_var = self.encode(input) 122 | z = self.reparameterize(mu, log_var) 123 | return [self.decode(z), input, mu, log_var] 124 | 125 | def loss_function(self, 126 | *args, 127 | **kwargs) -> dict: 128 | """ 129 | Computes the VAE loss function. 130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 131 | :param args: 132 | :param kwargs: 133 | :return: 134 | """ 135 | recons = args[0] 136 | input = args[1] 137 | mu = args[2] 138 | log_var = args[3] 139 | 140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 141 | t = recons - input 142 | # recons_loss = F.mse_loss(recons, input) 143 | # cosh = torch.cosh(self.alpha * t) 144 | # recons_loss = (1./self.alpha * torch.log(cosh)).mean() 145 | 146 | recons_loss = self.alpha * t + \ 147 | torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \ 148 | torch.log(torch.tensor(2.0)) 149 | # print(self.alpha* t.max(), self.alpha*t.min()) 150 | recons_loss = (1. 
/ self.alpha) * recons_loss.mean() 151 | 152 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 153 | 154 | loss = recons_loss + self.beta * kld_weight * kld_loss 155 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 156 | 157 | def sample(self, 158 | num_samples:int, 159 | current_device: int, **kwargs) -> Tensor: 160 | """ 161 | Samples from the latent space and return the corresponding 162 | image space map. 163 | :param num_samples: (Int) Number of samples 164 | :param current_device: (Int) Device to run the model 165 | :return: (Tensor) 166 | """ 167 | z = torch.randn(num_samples, 168 | self.latent_dim) 169 | 170 | z = z.to(current_device) 171 | 172 | samples = self.decode(z) 173 | return samples 174 | 175 | def generate(self, x: Tensor, **kwargs) -> Tensor: 176 | """ 177 | Given an input image x, returns the reconstructed image 178 | :param x: (Tensor) [B x C x H x W] 179 | :return: (Tensor) [B x C x H x W] 180 | """ 181 | 182 | return self.forward(x)[0] -------------------------------------------------------------------------------- /Task_Split/sinkhorn.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import pykeops.torch as keops 4 | import torch 5 | 6 | import tqdm 7 | 8 | def sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2, 9 | w_x: Union[torch.Tensor, None] = None, 10 | w_y: Union[torch.Tensor, None] = None, 11 | eps: float = 1e-3, 12 | max_iters: int = 100, stop_thresh: float = 1e-5, 13 | verbose=False): 14 | """ 15 | Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds 16 | using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors. 17 | Note that this algorithm can be backpropped through 18 | (though this may be slow if using many iterations). 19 | 20 | :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row) 21 | :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row) 22 | :param p: Which norm to use. Must be an integer greater than 0. 23 | :param w_x: A [n,] shaped tensor of optional weights for the points x (None for uniform weights). Note that these must sum to the same value as w_y. Default is None. 24 | :param w_y: A [m,] shaped tensor of optional weights for the points y (None for uniform weights). Note that these must sum to the same value as w_y. Default is None. 25 | :param eps: The reciprocal of the sinkhorn entropy regularization parameter. 26 | :param max_iters: The maximum number of Sinkhorn iterations to perform. 27 | :param stop_thresh: Stop if the maximum change in the parameters is below this amount 28 | :param verbose: Print iterations 29 | :return: a triple (d, corrs_x_to_y, corr_y_to_x) where: 30 | * d is the approximate p-wasserstein distance between point clouds x and y 31 | * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair) 32 | * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[i] is the index of the approximate correspondence in point cloud x of point y[j] (i.e. 
y[j] and x[corrs_y_to_x[j]] are a corresponding pair) 33 | """ 34 | 35 | if not isinstance(p, int): 36 | raise TypeError(f"p must be an integer greater than 0, got {p}") 37 | if p <= 0: 38 | raise ValueError(f"p must be an integer greater than 0, got {p}") 39 | 40 | if eps <= 0: 41 | raise ValueError("Entropy regularization term eps must be > 0") 42 | 43 | if not isinstance(max_iters, int): 44 | raise TypeError(f"max_iters must be an integer > 0, got {max_iters}") 45 | if max_iters <= 0: 46 | raise ValueError(f"max_iters must be an integer > 0, got {max_iters}") 47 | 48 | if not isinstance(stop_thresh, float): 49 | raise TypeError(f"stop_thresh must be a float, got {stop_thresh}") 50 | 51 | if len(x.shape) != 2: 52 | raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}") 53 | if len(y.shape) != 2: 54 | raise ValueError(f"y must be an [m, d] tensor but got shape {y.shape}") 55 | if x.shape[1] != y.shape[1]: 56 | raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], " 57 | f"y.shape=[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}") 58 | 59 | if w_x is not None: 60 | if w_y is None: 61 | raise ValueError("If w_x is not None, w_y must also be not None") 62 | if len(w_x.shape) > 1: 63 | w_x = w_x.squeeze() 64 | if len(w_x.shape) != 1: 65 | raise ValueError(f"w_x must have shape [n,] or [n, 1] " 66 | f"where x.shape = [n, d], but got w_x.shape = {w_x.shape}") 67 | if w_x.shape[0] != x.shape[0]: 68 | raise ValueError(f"w_x must match the shape of x in dimension 0 but got " 69 | f"x.shape = {x.shape} and w_x.shape = {w_x.shape}") 70 | if w_y is not None: 71 | if w_x is None: 72 | raise ValueError("If w_y is not None, w_x must also be not None") 73 | if len(w_y.shape) > 1: 74 | w_y = w_y.squeeze() 75 | if len(w_y.shape) != 1: 76 | raise ValueError(f"w_y must have shape [m,] or [m, 1] " 77 | f"where y.shape = [m, d], but got w_y.shape = {w_y.shape}") 78 | if w_y.shape[0] != y.shape[0]: 79 | raise ValueError(f"w_y must match the shape of y in dimension 0 but got " 80 | f"y.shape = {y.shape} and w_y.shape = {w_y.shape}") 81 | 82 | 83 | # Distance matrix [n, m] 84 | x_i = keops.Vi(x) # [n, 1, d] 85 | y_j = keops.Vj(y) # [1, m, d] 86 | if p == 1: 87 | M_ij = ((x_i - y_j) ** p).abs().sum(dim=2) # [n, m] 88 | else: 89 | M_ij = ((x_i - y_j) ** p).sum(dim=2) ** (1.0 / p) # [n, m] 90 | 91 | # Weights [n,] and [m,] 92 | if w_x is None and w_y is None: 93 | w_x = torch.ones(x.shape[0]).to(x) / x.shape[0] 94 | w_y = torch.ones(y.shape[0]).to(x) / y.shape[0] 95 | w_y *= (w_x.shape[0] / w_y.shape[0]) 96 | 97 | sum_w_x = w_x.sum().item() 98 | sum_w_y = w_y.sum().item() 99 | if abs(sum_w_x - sum_w_y) > 1e-5: 100 | raise ValueError(f"Weights w_x and w_y do not sum to the same value, " 101 | f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} " 102 | f"(absolute difference = {abs(sum_w_x - sum_w_y)})") 103 | 104 | log_a = torch.log(w_x) # [n] 105 | log_b = torch.log(w_y) # [m] 106 | 107 | # Initialize the iteration with the change of variable 108 | u = torch.zeros_like(w_x) 109 | v = eps * torch.log(w_y) 110 | 111 | u_i = keops.Vi(u.unsqueeze(-1)) 112 | v_j = keops.Vj(v.unsqueeze(-1)) 113 | 114 | if verbose: 115 | pbar = tqdm.trange(max_iters) 116 | else: 117 | pbar = range(max_iters) 118 | 119 | for _ in pbar: 120 | u_prev = u 121 | v_prev = v 122 | 123 | summand_u = (-M_ij + v_j) / eps 124 | u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze()) 125 | u_i = keops.Vi(u.unsqueeze(-1)) 126 | 127 | summand_v = (-M_ij + u_i) / eps 128 | v = eps * (log_b - 
summand_v.logsumexp(dim=0).squeeze()) 129 | v_j = keops.Vj(v.unsqueeze(-1)) 130 | 131 | max_err_u = torch.max(torch.abs(u_prev-u)) 132 | max_err_v = torch.max(torch.abs(v_prev-v)) 133 | if verbose: 134 | pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()}) 135 | if max_err_u < stop_thresh and max_err_v < stop_thresh: 136 | break 137 | 138 | P_ij = ((-M_ij + u_i + v_j) / eps).exp() 139 | 140 | approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1) 141 | approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1) 142 | 143 | if u.shape[0] > v.shape[0]: 144 | distance = (P_ij * M_ij).sum(dim=1).sum() 145 | else: 146 | distance = (P_ij * M_ij).sum(dim=0).sum() 147 | return distance, approx_corr_1, approx_corr_2 148 | -------------------------------------------------------------------------------- /Task_Split/MMD_Tool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | import torch 6 | 7 | min_var_est = 1e-8 8 | 9 | 10 | # Consider linear time MMD with a linear kernel: 11 | # K(f(x), f(y)) = f(x)^Tf(y) 12 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 13 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 14 | # 15 | # f_of_X: batch_size * k 16 | # f_of_Y: batch_size * k 17 | def linear_mmd2(f_of_X, f_of_Y): 18 | loss = 0.0 19 | delta = f_of_X - f_of_Y 20 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 21 | return loss 22 | 23 | 24 | # Consider linear time MMD with a polynomial kernel: 25 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d 26 | # f_of_X: batch_size * k 27 | # f_of_Y: batch_size * k 28 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): 29 | K_XX = (alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c) 30 | K_XX_mean = torch.mean(K_XX.pow(d)) 31 | 32 | K_YY = (alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c) 33 | K_YY_mean = torch.mean(K_YY.pow(d)) 34 | 35 | K_XY = (alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c) 36 | K_XY_mean = torch.mean(K_XY.pow(d)) 37 | 38 | K_YX = (alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c) 39 | K_YX_mean = torch.mean(K_YX.pow(d)) 40 | 41 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean 42 | 43 | 44 | def _mix_rbf_kernel(X, Y, sigma_list): 45 | assert(X.size(0) == Y.size(0)) 46 | m = X.size(0) 47 | 48 | Z = torch.cat((X, Y), 0) 49 | ZZT = torch.mm(Z, Z.t()) 50 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 51 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 52 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 53 | 54 | K = 0.0 55 | for sigma in sigma_list: 56 | gamma = 1.0 / (2 * sigma**2) 57 | K += torch.exp(-gamma * exponent) 58 | 59 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 60 | 61 | 62 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 63 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 64 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 65 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 66 | 67 | 68 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): 69 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 70 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 71 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 72 | 73 | 74 | ################################################################################ 75 | # Helper functions to compute variances based on kernel matrices 76 | ################################################################################ 77 | 78 | 79 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, 
biased=False): 80 | m = K_XX.size(0) # assume X, Y are same shape 81 | 82 | # Get the various sums of kernels that we'll use 83 | # Kts drop the diagonal, but we don't need to compute them explicitly 84 | if const_diagonal is not False: 85 | diag_X = diag_Y = const_diagonal 86 | sum_diag_X = sum_diag_Y = m * const_diagonal 87 | else: 88 | diag_X = torch.diag(K_XX) # (m,) 89 | diag_Y = torch.diag(K_YY) # (m,) 90 | sum_diag_X = torch.sum(diag_X) 91 | sum_diag_Y = torch.sum(diag_Y) 92 | 93 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 94 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 95 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 96 | 97 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 98 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 99 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 100 | 101 | if biased: 102 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 103 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 104 | - 2.0 * K_XY_sum / (m * m)) 105 | else: 106 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 107 | + Kt_YY_sum / (m * (m - 1)) 108 | - 2.0 * K_XY_sum / (m * m)) 109 | 110 | return mmd2 111 | 112 | 113 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 114 | mmd2, var_est = _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased) 115 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) 116 | return loss, mmd2, var_est 117 | 118 | 119 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 120 | m = K_XX.size(0) # assume X, Y are same shape 121 | 122 | # Get the various sums of kernels that we'll use 123 | # Kts drop the diagonal, but we don't need to compute them explicitly 124 | if const_diagonal is not False: 125 | diag_X = diag_Y = const_diagonal 126 | sum_diag_X = sum_diag_Y = m * const_diagonal 127 | sum_diag2_X = sum_diag2_Y = m * const_diagonal**2 128 | else: 129 | diag_X = torch.diag(K_XX) # (m,) 130 | diag_Y = torch.diag(K_YY) # (m,) 131 | sum_diag_X = torch.sum(diag_X) 132 | sum_diag_Y = torch.sum(diag_Y) 133 | sum_diag2_X = diag_X.dot(diag_X) 134 | sum_diag2_Y = diag_Y.dot(diag_Y) 135 | 136 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 137 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 138 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 139 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e 140 | 141 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 142 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 143 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 144 | 145 | Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 146 | Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 147 | K_XY_2_sum = (K_XY ** 2).sum() # \| K_{XY} \|_F^2 148 | 149 | if biased: 150 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 151 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 152 | - 2.0 * K_XY_sum / (m * m)) 153 | else: 154 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 155 | + Kt_YY_sum / (m * (m - 1)) 156 | - 2.0 * K_XY_sum / (m * m)) 157 | 158 | var_est = ( 159 | 2.0 / (m**2 * (m - 1.0)**2) * (2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum) 160 | - (4.0*m - 6.0) / (m**3 * (m - 1.0)**3) * (Kt_XX_sum**2 + Kt_YY_sum**2) 161 | + 4.0*(m - 2.0) / (m**3 * (m - 1.0)**2) * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) 162 | - 4.0*(m - 3.0) / (m**3 * (m - 1.0)**2) * (K_XY_2_sum) - (8 * m - 12) / (m**5 * (m - 
1)) * K_XY_sum**2 163 | + 8.0 / (m**3 * (m - 1.0)) * ( 164 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 165 | - Kt_XX_sums.dot(K_XY_sums_1) 166 | - Kt_YY_sums.dot(K_XY_sums_0)) 167 | ) 168 | return mmd2, var_est -------------------------------------------------------------------------------- /models/dip_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class DIPVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | lambda_diag: float = 10., 15 | lambda_offdiag: float = 5., 16 | **kwargs) -> None: 17 | super(DIPVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.lambda_diag = lambda_diag 21 | self.lambda_offdiag = lambda_offdiag 22 | 23 | modules = [] 24 | if hidden_dims is None: 25 | hidden_dims = [32, 64, 128, 256, 512] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | 46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 47 | 48 | hidden_dims.reverse() 49 | 50 | for i in range(len(hidden_dims) - 1): 51 | modules.append( 52 | nn.Sequential( 53 | nn.ConvTranspose2d(hidden_dims[i], 54 | hidden_dims[i + 1], 55 | kernel_size=3, 56 | stride = 2, 57 | padding=1, 58 | output_padding=1), 59 | nn.BatchNorm2d(hidden_dims[i + 1]), 60 | nn.LeakyReLU()) 61 | ) 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | def encode(self, input: Tensor) -> List[Tensor]: 79 | """ 80 | Encodes the input by passing through the encoder network 81 | and returns the latent codes. 82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 83 | :return: (Tensor) List of latent codes 84 | """ 85 | result = self.encoder(input) 86 | result = torch.flatten(result, start_dim=1) 87 | 88 | # Split the result into mu and var components 89 | # of the latent Gaussian distribution 90 | mu = self.fc_mu(result) 91 | log_var = self.fc_var(result) 92 | 93 | return [mu, log_var] 94 | 95 | def decode(self, z: Tensor) -> Tensor: 96 | """ 97 | Maps the given latent codes 98 | onto the image space. 99 | :param z: (Tensor) [B x D] 100 | :return: (Tensor) [B x C x H x W] 101 | """ 102 | result = self.decoder_input(z) 103 | result = result.view(-1, 512, 2, 2) 104 | result = self.decoder(result) 105 | result = self.final_layer(result) 106 | return result 107 | 108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 109 | """ 110 | Reparameterization trick to sample from N(mu, var) from 111 | N(0,1). 
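        Concretely, z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I).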
112 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 113 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 114 | :return: (Tensor) [B x D] 115 | """ 116 | std = torch.exp(0.5 * logvar) 117 | eps = torch.randn_like(std) 118 | return eps * std + mu 119 | 120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 121 | mu, log_var = self.encode(input) 122 | z = self.reparameterize(mu, log_var) 123 | return [self.decode(z), input, mu, log_var] 124 | 125 | def loss_function(self, 126 | *args, 127 | **kwargs) -> dict: 128 | """ 129 | Computes the VAE loss function. 130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 131 | :param args: 132 | :param kwargs: 133 | :return: 134 | """ 135 | recons = args[0] 136 | input = args[1] 137 | mu = args[2] 138 | log_var = args[3] 139 | 140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 141 | recons_loss =F.mse_loss(recons, input, reduction='sum') 142 | 143 | 144 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 145 | 146 | # DIP Loss 147 | centered_mu = mu - mu.mean(dim=1, keepdim = True) # [B x D] 148 | cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D X D] 149 | 150 | # Add Variance for DIP Loss II 151 | cov_z = cov_mu + torch.mean(torch.diagonal((2. * log_var).exp(), dim1 = 0), dim = 0) # [D x D] 152 | # For DIp Loss I 153 | # cov_z = cov_mu 154 | 155 | cov_diag = torch.diag(cov_z) # [D] 156 | cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D] 157 | dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \ 158 | self.lambda_diag * torch.sum((cov_diag - 1) ** 2) 159 | 160 | loss = recons_loss + kld_weight * kld_loss + dip_loss 161 | return {'loss': loss, 162 | 'Reconstruction_Loss':recons_loss, 163 | 'KLD':-kld_loss, 164 | 'DIP_Loss':dip_loss} 165 | 166 | def sample(self, 167 | num_samples:int, 168 | current_device: int, **kwargs) -> Tensor: 169 | """ 170 | Samples from the latent space and return the corresponding 171 | image space map. 
172 | :param num_samples: (Int) Number of samples 173 | :param current_device: (Int) Device to run the model 174 | :return: (Tensor) 175 | """ 176 | z = torch.randn(num_samples, 177 | self.latent_dim) 178 | 179 | z = z.to(current_device) 180 | 181 | samples = self.decode(z) 182 | return samples 183 | 184 | def generate(self, x: Tensor, **kwargs) -> Tensor: 185 | """ 186 | Given an input image x, returns the reconstructed image 187 | :param x: (Tensor) [B x C x H x W] 188 | :return: (Tensor) [B x C x H x W] 189 | """ 190 | 191 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/vampvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class VampVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | num_components: int = 50, 15 | **kwargs) -> None: 16 | super(VampVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | self.num_components = num_components 20 | 21 | modules = [] 22 | if hidden_dims is None: 23 | hidden_dims = [32, 64, 128, 256, 512] 24 | 25 | # Build Encoder 26 | for h_dim in hidden_dims: 27 | modules.append( 28 | nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels=h_dim, 30 | kernel_size= 3, stride= 2, padding = 1), 31 | nn.BatchNorm2d(h_dim), 32 | nn.LeakyReLU()) 33 | ) 34 | in_channels = h_dim 35 | 36 | self.encoder = nn.Sequential(*modules) 37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | 40 | 41 | # Build Decoder 42 | modules = [] 43 | 44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 45 | 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | 61 | 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | self.pseudo_input = torch.eye(self.num_components, requires_grad= False) 79 | self.embed_pseudo = nn.Sequential(nn.Linear(self.num_components, 12288), 80 | nn.Hardtanh(0.0, 1.0)) # 3x64x64 = 12288 81 | 82 | def encode(self, input: Tensor) -> List[Tensor]: 83 | """ 84 | Encodes the input by passing through the encoder network 85 | and returns the latent codes. 
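        The same encoder is also applied to the learned pseudo-inputs in loss_function() to form the VampPrior mixture.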
86 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 87 | :return: (Tensor) List of latent codes 88 | """ 89 | result = self.encoder(input) 90 | result = torch.flatten(result, start_dim=1) 91 | 92 | # Split the result into mu and var components 93 | # of the latent Gaussian distribution 94 | mu = self.fc_mu(result) 95 | log_var = self.fc_var(result) 96 | 97 | return [mu, log_var] 98 | 99 | def decode(self, z: Tensor) -> Tensor: 100 | result = self.decoder_input(z) 101 | result = result.view(-1, 512, 2, 2) 102 | result = self.decoder(result) 103 | result = self.final_layer(result) 104 | return result 105 | 106 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 107 | """ 108 | Will a single z be enough ti compute the expectation 109 | for the loss?? 110 | :param mu: (Tensor) Mean of the latent Gaussian 111 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 112 | :return: 113 | """ 114 | std = torch.exp(0.5 * logvar) 115 | eps = torch.randn_like(std) 116 | return eps * std + mu 117 | 118 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 119 | mu, log_var = self.encode(input) 120 | z = self.reparameterize(mu, log_var) 121 | return [self.decode(z), input, mu, log_var, z] 122 | 123 | def loss_function(self, 124 | *args, 125 | **kwargs) -> dict: 126 | recons = args[0] 127 | input = args[1] 128 | mu = args[2] 129 | log_var = args[3] 130 | z = args[4] 131 | 132 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 133 | recons_loss =F.mse_loss(recons, input) 134 | 135 | E_log_q_z = torch.mean(torch.sum(-0.5 * (log_var + (z - mu) ** 2)/ log_var.exp(), 136 | dim = 1), 137 | dim = 0) 138 | 139 | # Original Prior 140 | # E_log_p_z = torch.mean(torch.sum(-0.5 * (z ** 2), dim = 1), dim = 0) 141 | 142 | # Vamp Prior 143 | M, C, H, W = input.size() 144 | curr_device = input.device 145 | self.pseudo_input = self.pseudo_input.cuda(curr_device) 146 | x = self.embed_pseudo(self.pseudo_input) 147 | x = x.view(-1, C, H, W) 148 | prior_mu, prior_log_var = self.encode(x) 149 | 150 | z_expand = z.unsqueeze(1) 151 | prior_mu = prior_mu.unsqueeze(0) 152 | prior_log_var = prior_log_var.unsqueeze(0) 153 | 154 | E_log_p_z = torch.sum(-0.5 * 155 | (prior_log_var + (z_expand - prior_mu) ** 2)/ prior_log_var.exp(), 156 | dim = 2) - torch.log(torch.tensor(self.num_components).float()) 157 | 158 | # dim = 0) 159 | E_log_p_z = torch.logsumexp(E_log_p_z, dim = 1) 160 | E_log_p_z = torch.mean(E_log_p_z, dim = 0) 161 | 162 | # KLD = E_q log q - E_q log p 163 | kld_loss = -(E_log_p_z - E_log_q_z) 164 | # print(E_log_p_z, E_log_q_z) 165 | 166 | 167 | loss = recons_loss + kld_weight * kld_loss 168 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 169 | 170 | def sample(self, 171 | num_samples:int, 172 | current_device: int, **kwargs) -> Tensor: 173 | """ 174 | Samples from the latent space and return the corresponding 175 | image space map. 
176 | :param num_samples: (Int) Number of samples 177 | :param current_device: (Int) Device to run the model 178 | :return: (Tensor) 179 | """ 180 | z = torch.randn(num_samples, 181 | self.latent_dim) 182 | 183 | z = z.cuda(current_device) 184 | 185 | samples = self.decode(z) 186 | return samples 187 | 188 | def generate(self, x: Tensor, **kwargs) -> Tensor: 189 | """ 190 | Given an input image x, returns the reconstructed image 191 | :param x: (Tensor) [B x C x H x W] 192 | :return: (Tensor) [B x C x H x W] 193 | """ 194 | 195 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/FullyVAE_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from models import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | import numpy as np 7 | 8 | 9 | class FullVAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | input_size: int, 15 | beta: int, 16 | hidden_dims: List = None, 17 | **kwargs) -> None: 18 | super(FullVAE, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.input_size = input_size 22 | self.beta = beta 23 | 24 | modules = [] 25 | self.label1 = nn.Sequential( 26 | nn.Linear(input_size*input_size*3, 2000), 27 | nn.LeakyReLU()) 28 | modules.append(self.label1) 29 | 30 | self.label2 = nn.Sequential( 31 | nn.Linear(2000, 1000), 32 | nn.LeakyReLU()) 33 | modules.append(self.label2) 34 | 35 | self.label3 = nn.Sequential( 36 | nn.Linear(1000, 800), 37 | nn.LeakyReLU()) 38 | modules.append(self.label3) 39 | 40 | self.label4 = nn.Sequential( 41 | nn.Linear(800, 500), 42 | nn.LeakyReLU()) 43 | modules.append(self.label4) 44 | self.encoder = nn.Sequential(*modules) 45 | 46 | self.fc_mu = nn.Linear(500, latent_dim) 47 | self.fc_var = nn.Linear(500, latent_dim) 48 | 49 | #decoder 50 | modules = [] 51 | self.decoder_label1 = nn.Sequential( 52 | nn.Linear(latent_dim, 500), 53 | nn.LeakyReLU()) 54 | modules.append(self.decoder_label1) 55 | 56 | self.decoder_label2 = nn.Sequential( 57 | nn.Linear(500, 800), 58 | nn.LeakyReLU()) 59 | modules.append(self.decoder_label2) 60 | 61 | self.decoder_label3 = nn.Sequential( 62 | nn.Linear(800, 1000), 63 | nn.LeakyReLU()) 64 | modules.append(self.decoder_label3) 65 | 66 | self.decoder_label4 = nn.Sequential( 67 | nn.Linear(1000, 2000), 68 | nn.LeakyReLU()) 69 | modules.append(self.decoder_label4) 70 | self.decoder = nn.Sequential(*modules) 71 | 72 | self.final_layer = nn.Sequential( 73 | nn.Linear(2000, input_size * input_size * 3), 74 | nn.Tanh()) 75 | 76 | def GiveFeatures(self,x): 77 | with torch.no_grad(): 78 | mu, log_var = self.encode(x) 79 | z = self.reparameterize(mu,log_var) 80 | return z 81 | 82 | def training_step(self,batch): 83 | results = self.forward(batch) 84 | 85 | train_loss = self.loss_function(*results) 86 | 87 | return train_loss['loss'] 88 | 89 | def encode(self, input: Tensor) -> List[Tensor]: 90 | """ 91 | Encodes the input by passing through the encoder network 92 | and returns the latent codes. 
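        The [N x C x H x W] input is flattened to [N x (3 * input_size * input_size)] before the fully connected layers.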
93 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 94 | :return: (Tensor) List of latent codes 95 | """ 96 | 97 | input = input.view(-1,self.input_size*self.input_size*3) 98 | result = self.encoder(input) 99 | result = torch.flatten(result, start_dim=1) 100 | 101 | # Split the result into mu and var components 102 | # of the latent Gaussian distribution 103 | mu = self.fc_mu(result) 104 | log_var = self.fc_var(result) 105 | 106 | return [mu, log_var] 107 | 108 | def decode(self, z: Tensor) -> Tensor: 109 | """ 110 | Maps the given latent codes 111 | onto the image space. 112 | :param z: (Tensor) [B x D] 113 | :return: (Tensor) [B x C x H x W] 114 | """ 115 | #result = self.decoder_input(z) 116 | #result = result.view(-1, 512, 2, 2) 117 | result = self.decoder(z) 118 | result = self.final_layer(result) 119 | result = result.view(-1,3,self.input_size,self.input_size) 120 | 121 | return result 122 | 123 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 124 | """ 125 | Reparameterization trick to sample from N(mu, var) from 126 | N(0,1). 127 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 128 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 129 | :return: (Tensor) [B x D] 130 | """ 131 | std = torch.exp(0.5 * logvar) 132 | eps = torch.randn_like(std) 133 | return eps * std + mu 134 | 135 | def Give_MeanAndVar(self,tensor): 136 | return self.encode(tensor) 137 | 138 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 139 | mu, log_var = self.encode(input) 140 | z = self.reparameterize(mu, log_var) 141 | return [self.decode(z), input, mu, log_var] 142 | 143 | def loss_function(self, 144 | *args, 145 | **kwargs) -> dict: 146 | """ 147 | Computes the VAE loss function. 148 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 149 | :param args: 150 | :param kwargs: 151 | :return: 152 | """ 153 | recons = args[0] 154 | input = args[1] 155 | mu = args[2] 156 | log_var = args[3] 157 | 158 | #kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 159 | kld_weight = 1.0 160 | #kld_weight = 1 161 | #recons_loss =F.mse_loss(recons, input) 162 | recons_loss = F.mse_loss(recons, input, size_average=False) / input.size(0) 163 | 164 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 165 | 166 | loss = recons_loss + kld_weight * kld_loss 167 | return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()} 168 | 169 | def sample(self, 170 | num_samples:int, 171 | current_device: int, **kwargs) -> Tensor: 172 | """ 173 | Samples from the latent space and return the corresponding 174 | image space map. 
175 | :param num_samples: (Int) Number of samples 176 | :param current_device: (Int) Device to run the model 177 | :return: (Tensor) 178 | """ 179 | z = torch.randn(num_samples, 180 | self.latent_dim) 181 | 182 | z = z.to(current_device) 183 | 184 | samples = self.decode(z) 185 | return samples 186 | 187 | def Give_FullSampleByLatent(self,latentSample): 188 | 189 | with torch.no_grad(): 190 | arr = [] 191 | batchsize = 64 192 | count = int(np.shape(latentSample)[0] / batchsize) 193 | for i in range(count): 194 | latentBatch = latentSample[i*batchsize:(i+1)*batchsize] 195 | resultBatch = self.decode(latentBatch) 196 | if np.shape(arr)[0] == 0: 197 | arr = resultBatch 198 | else: 199 | arr = torch.cat([arr,resultBatch],0) 200 | return arr 201 | 202 | def generate(self, x: Tensor, **kwargs) -> Tensor: 203 | """ 204 | Given an input image x, returns the reconstructed image 205 | :param x: (Tensor) [B x C x H x W] 206 | :return: (Tensor) [B x C x H x W] 207 | """ 208 | 209 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/twostage_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class TwoStageVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | hidden_dims2: List = None, 15 | **kwargs) -> None: 16 | super(TwoStageVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | 20 | modules = [] 21 | if hidden_dims is None: 22 | hidden_dims = [32, 64, 128, 256, 512] 23 | 24 | if hidden_dims2 is None: 25 | hidden_dims2 = [1024, 1024] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | self.decoder = nn.Sequential(*modules) 61 | 62 | self.final_layer = nn.Sequential( 63 | nn.ConvTranspose2d(hidden_dims[-1], 64 | hidden_dims[-1], 65 | kernel_size=3, 66 | stride=2, 67 | padding=1, 68 | output_padding=1), 69 | nn.BatchNorm2d(hidden_dims[-1]), 70 | nn.LeakyReLU(), 71 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 72 | kernel_size= 3, padding= 1), 73 | nn.Tanh()) 74 | 75 | #---------------------- Second VAE ---------------------------# 76 | encoder2 = [] 77 | in_channels = self.latent_dim 78 | for h_dim in hidden_dims2: 79 | encoder2.append(nn.Sequential( 80 | nn.Linear(in_channels, h_dim), 81 | nn.BatchNorm1d(h_dim), 82 | nn.LeakyReLU())) 83 | in_channels = h_dim 84 | self.encoder2 = nn.Sequential(*encoder2) 85 | self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim) 86 | self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim) 87 | 88 | decoder2 = 
[] 89 | hidden_dims2.reverse() 90 | 91 | in_channels = self.latent_dim 92 | for h_dim in hidden_dims2: 93 | decoder2.append(nn.Sequential( 94 | nn.Linear(in_channels, h_dim), 95 | nn.BatchNorm1d(h_dim), 96 | nn.LeakyReLU())) 97 | in_channels = h_dim 98 | self.decoder2 = nn.Sequential(*decoder2) 99 | 100 | def encode(self, input: Tensor) -> List[Tensor]: 101 | """ 102 | Encodes the input by passing through the encoder network 103 | and returns the latent codes. 104 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 105 | :return: (Tensor) List of latent codes 106 | """ 107 | result = self.encoder(input) 108 | result = torch.flatten(result, start_dim=1) 109 | 110 | # Split the result into mu and var components 111 | # of the latent Gaussian distribution 112 | mu = self.fc_mu(result) 113 | log_var = self.fc_var(result) 114 | 115 | return [mu, log_var] 116 | 117 | def decode(self, z: Tensor) -> Tensor: 118 | """ 119 | Maps the given latent codes 120 | onto the image space. 121 | :param z: (Tensor) [B x D] 122 | :return: (Tensor) [B x C x H x W] 123 | """ 124 | result = self.decoder_input(z) 125 | result = result.view(-1, 512, 2, 2) 126 | result = self.decoder(result) 127 | result = self.final_layer(result) 128 | return result 129 | 130 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 131 | """ 132 | Reparameterization trick to sample from N(mu, var) from 133 | N(0,1). 134 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 135 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 136 | :return: (Tensor) [B x D] 137 | """ 138 | std = torch.exp(0.5 * logvar) 139 | eps = torch.randn_like(std) 140 | return eps * std + mu 141 | 142 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 143 | mu, log_var = self.encode(input) 144 | z = self.reparameterize(mu, log_var) 145 | 146 | return [self.decode(z), input, mu, log_var] 147 | 148 | def loss_function(self, 149 | *args, 150 | **kwargs) -> dict: 151 | """ 152 | Computes the VAE loss function. 153 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 154 | :param args: 155 | :param kwargs: 156 | :return: 157 | """ 158 | recons = args[0] 159 | input = args[1] 160 | mu = args[2] 161 | log_var = args[3] 162 | 163 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 164 | recons_loss =F.mse_loss(recons, input) 165 | 166 | 167 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 168 | 169 | loss = recons_loss + kld_weight * kld_loss 170 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 171 | 172 | def sample(self, 173 | num_samples:int, 174 | current_device: int, **kwargs) -> Tensor: 175 | """ 176 | Samples from the latent space and return the corresponding 177 | image space map. 
178 | :param num_samples: (Int) Number of samples 179 | :param current_device: (Int) Device to run the model 180 | :return: (Tensor) 181 | """ 182 | z = torch.randn(num_samples, 183 | self.latent_dim) 184 | 185 | z = z.to(current_device) 186 | 187 | samples = self.decode(z) 188 | return samples 189 | 190 | def generate(self, x: Tensor, **kwargs) -> Tensor: 191 | """ 192 | Given an input image x, returns the reconstructed image 193 | :param x: (Tensor) [B x C x H x W] 194 | :return: (Tensor) [B x C x H x W] 195 | """ 196 | 197 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/miwae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | from torch.distributions import Normal 7 | 8 | 9 | class MIWAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | num_samples: int = 5, 16 | num_estimates: int = 5, 17 | **kwargs) -> None: 18 | super(MIWAE, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.num_samples = num_samples # K 22 | self.num_estimates = num_estimates # M 23 | 24 | modules = [] 25 | if hidden_dims is None: 26 | hidden_dims = [32, 64, 128, 256, 512] 27 | 28 | # Build Encoder 29 | for h_dim in hidden_dims: 30 | modules.append( 31 | nn.Sequential( 32 | nn.Conv2d(in_channels, out_channels=h_dim, 33 | kernel_size= 3, stride= 2, padding = 1), 34 | nn.BatchNorm2d(h_dim), 35 | nn.LeakyReLU()) 36 | ) 37 | in_channels = h_dim 38 | 39 | self.encoder = nn.Sequential(*modules) 40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 42 | 43 | 44 | # Build Decoder 45 | modules = [] 46 | 47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 48 | 49 | hidden_dims.reverse() 50 | 51 | for i in range(len(hidden_dims) - 1): 52 | modules.append( 53 | nn.Sequential( 54 | nn.ConvTranspose2d(hidden_dims[i], 55 | hidden_dims[i + 1], 56 | kernel_size=3, 57 | stride = 2, 58 | padding=1, 59 | output_padding=1), 60 | nn.BatchNorm2d(hidden_dims[i + 1]), 61 | nn.LeakyReLU()) 62 | ) 63 | 64 | 65 | 66 | self.decoder = nn.Sequential(*modules) 67 | 68 | self.final_layer = nn.Sequential( 69 | nn.ConvTranspose2d(hidden_dims[-1], 70 | hidden_dims[-1], 71 | kernel_size=3, 72 | stride=2, 73 | padding=1, 74 | output_padding=1), 75 | nn.BatchNorm2d(hidden_dims[-1]), 76 | nn.LeakyReLU(), 77 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 78 | kernel_size= 3, padding= 1), 79 | nn.Tanh()) 80 | 81 | def encode(self, input: Tensor) -> List[Tensor]: 82 | """ 83 | Encodes the input by passing through the encoder network 84 | and returns the latent codes. 85 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 86 | :return: (Tensor) List of latent codes 87 | """ 88 | result = self.encoder(input) 89 | result = torch.flatten(result, start_dim=1) 90 | 91 | # Split the result into mu and var components 92 | # of the latent Gaussian distribution 93 | mu = self.fc_mu(result) 94 | log_var = self.fc_var(result) 95 | 96 | return [mu, log_var] 97 | 98 | def decode(self, z: Tensor) -> Tensor: 99 | """ 100 | Maps the given latent codes of S samples 101 | onto the image space. 
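        The estimate (M) and sample (S) dimensions are folded into the batch dimension before decoding and unfolded again afterwards.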
102 | :param z: (Tensor) [B x S x D] 103 | :return: (Tensor) [B x S x C x H x W] 104 | """ 105 | B, M,S, D = z.size() 106 | z = z.contiguous().view(-1, self.latent_dim) #[BMS x D] 107 | result = self.decoder_input(z) 108 | result = result.view(-1, 512, 2, 2) 109 | result = self.decoder(result) 110 | result = self.final_layer(result) #[BMS x C x H x W ] 111 | result = result.view([B, M, S,result.size(-3), result.size(-2), result.size(-1)]) #[B x M x S x C x H x W] 112 | return result 113 | 114 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 115 | """ 116 | :param mu: (Tensor) Mean of the latent Gaussian 117 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 118 | :return: 119 | """ 120 | std = torch.exp(0.5 * logvar) 121 | eps = torch.randn_like(std) 122 | return eps * std + mu 123 | 124 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 125 | mu, log_var = self.encode(input) 126 | mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D] 127 | log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D] 128 | z = self.reparameterize(mu, log_var) # [B x M x S x D] 129 | eps = (z - mu) / log_var # Prior samples 130 | return [self.decode(z), input, mu, log_var, z, eps] 131 | 132 | def loss_function(self, 133 | *args, 134 | **kwargs) -> dict: 135 | """ 136 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 137 | :param args: 138 | :param kwargs: 139 | :return: 140 | """ 141 | recons = args[0] 142 | input = args[1] 143 | mu = args[2] 144 | log_var = args[3] 145 | z = args[4] 146 | eps = args[5] 147 | 148 | input = input.repeat(self.num_estimates, 149 | self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W] 150 | 151 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 152 | 153 | log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S] 154 | 155 | kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S] 156 | # Get importance weights 157 | log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data 158 | 159 | # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1 160 | weight = F.softmax(log_weight, dim = -1) # [B x M x S] 161 | 162 | loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0) 163 | 164 | return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()} 165 | 166 | def sample(self, 167 | num_samples:int, 168 | current_device: int, **kwargs) -> Tensor: 169 | """ 170 | Samples from the latent space and return the corresponding 171 | image space map. 172 | :param num_samples: (Int) Number of samples 173 | :param current_device: (Int) Device to run the model 174 | :return: (Tensor) 175 | """ 176 | z = torch.randn(num_samples, 1, 1, 177 | self.latent_dim) 178 | 179 | z = z.to(current_device) 180 | 181 | samples = self.decode(z).squeeze() 182 | return samples 183 | 184 | def generate(self, x: Tensor, **kwargs) -> Tensor: 185 | """ 186 | Given an input image x, returns the reconstructed image. 
187 | Returns only the first reconstructed sample 188 | :param x: (Tensor) [B x C x H x W] 189 | :return: (Tensor) [B x C x H x W] 190 | """ 191 | 192 | return self.forward(x)[0][:, 0, 0, :] 193 | -------------------------------------------------------------------------------- /models/swae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch import distributions as dist 6 | from .types_ import * 7 | 8 | 9 | class SWAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | reg_weight: int = 100, 16 | wasserstein_deg: float= 2., 17 | num_projections: int = 50, 18 | projection_dist: str = 'normal', 19 | **kwargs) -> None: 20 | super(SWAE, self).__init__() 21 | 22 | self.latent_dim = latent_dim 23 | self.reg_weight = reg_weight 24 | self.p = wasserstein_deg 25 | self.num_projections = num_projections 26 | self.proj_dist = projection_dist 27 | 28 | modules = [] 29 | if hidden_dims is None: 30 | hidden_dims = [32, 64, 128, 256, 512] 31 | 32 | # Build Encoder 33 | for h_dim in hidden_dims: 34 | modules.append( 35 | nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels=h_dim, 37 | kernel_size= 3, stride= 2, padding = 1), 38 | nn.BatchNorm2d(h_dim), 39 | nn.LeakyReLU()) 40 | ) 41 | in_channels = h_dim 42 | 43 | self.encoder = nn.Sequential(*modules) 44 | self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim) 45 | 46 | 47 | # Build Decoder 48 | modules = [] 49 | 50 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 51 | 52 | hidden_dims.reverse() 53 | 54 | for i in range(len(hidden_dims) - 1): 55 | modules.append( 56 | nn.Sequential( 57 | nn.ConvTranspose2d(hidden_dims[i], 58 | hidden_dims[i + 1], 59 | kernel_size=3, 60 | stride = 2, 61 | padding=1, 62 | output_padding=1), 63 | nn.BatchNorm2d(hidden_dims[i + 1]), 64 | nn.LeakyReLU()) 65 | ) 66 | 67 | 68 | 69 | self.decoder = nn.Sequential(*modules) 70 | 71 | self.final_layer = nn.Sequential( 72 | nn.ConvTranspose2d(hidden_dims[-1], 73 | hidden_dims[-1], 74 | kernel_size=3, 75 | stride=2, 76 | padding=1, 77 | output_padding=1), 78 | nn.BatchNorm2d(hidden_dims[-1]), 79 | nn.LeakyReLU(), 80 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 81 | kernel_size= 3, padding= 1), 82 | nn.Tanh()) 83 | 84 | def encode(self, input: Tensor) -> Tensor: 85 | """ 86 | Encodes the input by passing through the encoder network 87 | and returns the latent codes. 
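        Unlike the Gaussian VAEs above, this encoder is deterministic: it returns a single latent code per input rather than a (mu, log_var) pair.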
88 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 89 | :return: (Tensor) List of latent codes 90 | """ 91 | result = self.encoder(input) 92 | result = torch.flatten(result, start_dim=1) 93 | 94 | # Split the result into mu and var components 95 | # of the latent Gaussian distribution 96 | z = self.fc_z(result) 97 | return z 98 | 99 | def decode(self, z: Tensor) -> Tensor: 100 | result = self.decoder_input(z) 101 | result = result.view(-1, 512, 2, 2) 102 | result = self.decoder(result) 103 | result = self.final_layer(result) 104 | return result 105 | 106 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 107 | z = self.encode(input) 108 | return [self.decode(z), input, z] 109 | 110 | def loss_function(self, 111 | *args, 112 | **kwargs) -> dict: 113 | recons = args[0] 114 | input = args[1] 115 | z = args[2] 116 | 117 | batch_size = input.size(0) 118 | bias_corr = batch_size * (batch_size - 1) 119 | reg_weight = self.reg_weight / bias_corr 120 | 121 | recons_loss_l2 = F.mse_loss(recons, input) 122 | recons_loss_l1 = F.l1_loss(recons, input) 123 | 124 | swd_loss = self.compute_swd(z, self.p, reg_weight) 125 | 126 | loss = recons_loss_l2 + recons_loss_l1 + swd_loss 127 | return {'loss': loss, 'Reconstruction_Loss':(recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss} 128 | 129 | def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor: 130 | """ 131 | Returns random samples from latent distribution's (Gaussian) 132 | unit sphere for projecting the encoded samples and the 133 | distribution samples. 134 | 135 | :param latent_dim: (Int) Dimensionality of the latent space (D) 136 | :param num_samples: (Int) Number of samples required (S) 137 | :return: Random projections from the latent unit sphere 138 | """ 139 | if self.proj_dist == 'normal': 140 | rand_samples = torch.randn(num_samples, latent_dim) 141 | elif self.proj_dist == 'cauchy': 142 | rand_samples = dist.Cauchy(torch.tensor([0.0]), 143 | torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze() 144 | else: 145 | raise ValueError('Unknown projection distribution.') 146 | 147 | rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1,1) 148 | return rand_proj # [S x D] 149 | 150 | 151 | def compute_swd(self, 152 | z: Tensor, 153 | p: float, 154 | reg_weight: float) -> Tensor: 155 | """ 156 | Computes the Sliced Wasserstein Distance (SWD) - which consists of 157 | randomly projecting the encoded and prior vectors and computing 158 | their Wasserstein distance along those projections. 
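        Roughly, the distance is estimated by sorting the projected codes and the projected prior samples along the batch dimension and averaging the p-th power of their element-wise differences across all projections.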
159 | 160 | :param z: Latent samples # [N x D] 161 | :param p: Value for the p^th Wasserstein distance 162 | :param reg_weight: 163 | :return: 164 | """ 165 | prior_z = torch.randn_like(z) # [N x D] 166 | device = z.device 167 | 168 | proj_matrix = self.get_random_projections(self.latent_dim, 169 | num_samples=self.num_projections).transpose(0,1).to(device) 170 | 171 | latent_projections = z.matmul(proj_matrix) # [N x S] 172 | prior_projections = prior_z.matmul(proj_matrix) # [N x S] 173 | 174 | # The Wasserstein distance is computed by sorting the two projections 175 | # across the batches and computing their element-wise l2 distance 176 | w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \ 177 | torch.sort(prior_projections.t(), dim=1)[0] 178 | w_dist = w_dist.pow(p) 179 | return reg_weight * w_dist.mean() 180 | 181 | def sample(self, 182 | num_samples:int, 183 | current_device: int, **kwargs) -> Tensor: 184 | """ 185 | Samples from the latent space and return the corresponding 186 | image space map. 187 | :param num_samples: (Int) Number of samples 188 | :param current_device: (Int) Device to run the model 189 | :return: (Tensor) 190 | """ 191 | z = torch.randn(num_samples, 192 | self.latent_dim) 193 | 194 | z = z.to(current_device) 195 | 196 | samples = self.decode(z) 197 | return samples 198 | 199 | def generate(self, x: Tensor, **kwargs) -> Tensor: 200 | """ 201 | Given an input image x, returns the reconstructed image 202 | :param x: (Tensor) [B x C x H x W] 203 | :return: (Tensor) [B x C x H x W] 204 | """ 205 | 206 | return self.forward(x)[0] 207 | -------------------------------------------------------------------------------- /models/dfcvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torchvision.models import vgg19_bn 5 | from torch.nn import functional as F 6 | from .types_ import * 7 | 8 | 9 | class DFCVAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | alpha:float = 1, 16 | beta:float = 0.5, 17 | **kwargs) -> None: 18 | super(DFCVAE, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.alpha = alpha 22 | self.beta = beta 23 | 24 | modules = [] 25 | if hidden_dims is None: 26 | hidden_dims = [32, 64, 128, 256, 512] 27 | 28 | # Build Encoder 29 | for h_dim in hidden_dims: 30 | modules.append( 31 | nn.Sequential( 32 | nn.Conv2d(in_channels, out_channels=h_dim, 33 | kernel_size= 3, stride= 2, padding = 1), 34 | nn.BatchNorm2d(h_dim), 35 | nn.LeakyReLU()) 36 | ) 37 | in_channels = h_dim 38 | 39 | self.encoder = nn.Sequential(*modules) 40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 42 | 43 | 44 | # Build Decoder 45 | modules = [] 46 | 47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 48 | 49 | hidden_dims.reverse() 50 | 51 | for i in range(len(hidden_dims) - 1): 52 | modules.append( 53 | nn.Sequential( 54 | nn.ConvTranspose2d(hidden_dims[i], 55 | hidden_dims[i + 1], 56 | kernel_size=3, 57 | stride = 2, 58 | padding=1, 59 | output_padding=1), 60 | nn.BatchNorm2d(hidden_dims[i + 1]), 61 | nn.LeakyReLU()) 62 | ) 63 | 64 | 65 | 66 | self.decoder = nn.Sequential(*modules) 67 | 68 | self.final_layer = nn.Sequential( 69 | nn.ConvTranspose2d(hidden_dims[-1], 70 | hidden_dims[-1], 71 | kernel_size=3, 72 | stride=2, 73 | padding=1, 74 | output_padding=1), 75 | nn.BatchNorm2d(hidden_dims[-1]), 
76 | nn.LeakyReLU(), 77 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 78 | kernel_size= 3, padding= 1), 79 | nn.Tanh()) 80 | 81 | self.feature_network = vgg19_bn(pretrained=True) 82 | 83 | # Freeze the pretrained feature network 84 | for param in self.feature_network.parameters(): 85 | param.requires_grad = False 86 | 87 | self.feature_network.eval() 88 | 89 | 90 | def encode(self, input: Tensor) -> List[Tensor]: 91 | """ 92 | Encodes the input by passing through the encoder network 93 | and returns the latent codes. 94 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 95 | :return: (Tensor) List of latent codes 96 | """ 97 | result = self.encoder(input) 98 | result = torch.flatten(result, start_dim=1) 99 | 100 | # Split the result into mu and var components 101 | # of the latent Gaussian distribution 102 | mu = self.fc_mu(result) 103 | log_var = self.fc_var(result) 104 | 105 | return [mu, log_var] 106 | 107 | def decode(self, z: Tensor) -> Tensor: 108 | """ 109 | Maps the given latent codes 110 | onto the image space. 111 | :param z: (Tensor) [B x D] 112 | :return: (Tensor) [B x C x H x W] 113 | """ 114 | result = self.decoder_input(z) 115 | result = result.view(-1, 512, 2, 2) 116 | result = self.decoder(result) 117 | result = self.final_layer(result) 118 | return result 119 | 120 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 121 | """ 122 | Reparameterization trick to sample from N(mu, var) from 123 | N(0,1). 124 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 125 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 126 | :return: (Tensor) [B x D] 127 | """ 128 | std = torch.exp(0.5 * logvar) 129 | eps = torch.randn_like(std) 130 | return eps * std + mu 131 | 132 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 133 | mu, log_var = self.encode(input) 134 | z = self.reparameterize(mu, log_var) 135 | recons = self.decode(z) 136 | 137 | recons_features = self.extract_features(recons) 138 | input_features = self.extract_features(input) 139 | 140 | return [recons, input, recons_features, input_features, mu, log_var] 141 | 142 | def extract_features(self, 143 | input: Tensor, 144 | feature_layers: List = None) -> List[Tensor]: 145 | """ 146 | Extracts the features from the pretrained model 147 | at the layers indicated by feature_layers. 148 | :param input: (Tensor) [B x C x H x W] 149 | :param feature_layers: List of string of IDs 150 | :return: List of the extracted features 151 | """ 152 | if feature_layers is None: 153 | feature_layers = ['14', '24', '34', '43'] 154 | features = [] 155 | result = input 156 | for (key, module) in self.feature_network.features._modules.items(): 157 | result = module(result) 158 | if(key in feature_layers): 159 | features.append(result) 160 | 161 | return features 162 | 163 | def loss_function(self, 164 | *args, 165 | **kwargs) -> dict: 166 | """ 167 | Computes the VAE loss function. 
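        Here loss = beta * (reconstruction_loss + feature_loss) + alpha * kld_weight * KL, where feature_loss compares VGG19 feature maps of the input and the reconstruction.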
168 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 169 | :param args: 170 | :param kwargs: 171 | :return: 172 | """ 173 | recons = args[0] 174 | input = args[1] 175 | recons_features = args[2] 176 | input_features = args[3] 177 | mu = args[4] 178 | log_var = args[5] 179 | 180 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 181 | recons_loss =F.mse_loss(recons, input) 182 | 183 | feature_loss = 0.0 184 | for (r, i) in zip(recons_features, input_features): 185 | feature_loss += F.mse_loss(r, i) 186 | 187 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 188 | 189 | loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss 190 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 191 | 192 | def sample(self, 193 | num_samples:int, 194 | current_device: int, **kwargs) -> Tensor: 195 | """ 196 | Samples from the latent space and return the corresponding 197 | image space map. 198 | :param num_samples: (Int) Number of samples 199 | :param current_device: (Int) Device to run the model 200 | :return: (Tensor) 201 | """ 202 | z = torch.randn(num_samples, 203 | self.latent_dim) 204 | 205 | z = z.to(current_device) 206 | 207 | samples = self.decode(z) 208 | return samples 209 | 210 | def generate(self, x: Tensor, **kwargs) -> Tensor: 211 | """ 212 | Given an input image x, returns the reconstructed image 213 | :param x: (Tensor) [B x C x H x W] 214 | :return: (Tensor) [B x C x H x W] 215 | """ 216 | 217 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/cat_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models import BaseVAE 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from .types_ import * 7 | 8 | 9 | class CategoricalVAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | categorical_dim: int = 40, # Num classes 15 | hidden_dims: List = None, 16 | temperature: float = 0.5, 17 | anneal_rate: float = 3e-5, 18 | anneal_interval: int = 100, # every 100 batches 19 | alpha: float = 30., 20 | **kwargs) -> None: 21 | super(CategoricalVAE, self).__init__() 22 | 23 | self.latent_dim = latent_dim 24 | self.categorical_dim = categorical_dim 25 | self.temp = temperature 26 | self.min_temp = temperature 27 | self.anneal_rate = anneal_rate 28 | self.anneal_interval = anneal_interval 29 | self.alpha = alpha 30 | 31 | modules = [] 32 | if hidden_dims is None: 33 | hidden_dims = [32, 64, 128, 256, 512] 34 | 35 | # Build Encoder 36 | for h_dim in hidden_dims: 37 | modules.append( 38 | nn.Sequential( 39 | nn.Conv2d(in_channels, out_channels=h_dim, 40 | kernel_size= 3, stride= 2, padding = 1), 41 | nn.BatchNorm2d(h_dim), 42 | nn.LeakyReLU()) 43 | ) 44 | in_channels = h_dim 45 | 46 | self.encoder = nn.Sequential(*modules) 47 | self.fc_z = nn.Linear(hidden_dims[-1]*4, 48 | self.latent_dim * self.categorical_dim) 49 | 50 | # Build Decoder 51 | modules = [] 52 | 53 | self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim 54 | , hidden_dims[-1] * 4) 55 | 56 | hidden_dims.reverse() 57 | 58 | for i in range(len(hidden_dims) - 1): 59 | modules.append( 60 | nn.Sequential( 61 | nn.ConvTranspose2d(hidden_dims[i], 62 | hidden_dims[i + 1], 63 | kernel_size=3, 64 | stride = 2, 65 | padding=1, 66 | output_padding=1), 67 | 
nn.BatchNorm2d(hidden_dims[i + 1]), 68 | nn.LeakyReLU()) 69 | ) 70 | 71 | 72 | 73 | self.decoder = nn.Sequential(*modules) 74 | 75 | self.final_layer = nn.Sequential( 76 | nn.ConvTranspose2d(hidden_dims[-1], 77 | hidden_dims[-1], 78 | kernel_size=3, 79 | stride=2, 80 | padding=1, 81 | output_padding=1), 82 | nn.BatchNorm2d(hidden_dims[-1]), 83 | nn.LeakyReLU(), 84 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 85 | kernel_size= 3, padding= 1), 86 | nn.Tanh()) 87 | self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1))) 88 | 89 | def encode(self, input: Tensor) -> List[Tensor]: 90 | """ 91 | Encodes the input by passing through the encoder network 92 | and returns the latent codes. 93 | :param input: (Tensor) Input tensor to encoder [B x C x H x W] 94 | :return: (Tensor) Latent code [B x D x Q] 95 | """ 96 | result = self.encoder(input) 97 | result = torch.flatten(result, start_dim=1) 98 | 99 | # Split the result into mu and var components 100 | # of the latent Gaussian distribution 101 | z = self.fc_z(result) 102 | z = z.view(-1, self.latent_dim, self.categorical_dim) 103 | return [z] 104 | 105 | def decode(self, z: Tensor) -> Tensor: 106 | """ 107 | Maps the given latent codes 108 | onto the image space. 109 | :param z: (Tensor) [B x D x Q] 110 | :return: (Tensor) [B x C x H x W] 111 | """ 112 | result = self.decoder_input(z) 113 | result = result.view(-1, 512, 2, 2) 114 | result = self.decoder(result) 115 | result = self.final_layer(result) 116 | return result 117 | 118 | def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor: 119 | """ 120 | Gumbel-softmax trick to sample from Categorical Distribution 121 | :param z: (Tensor) Latent Codes [B x D x Q] 122 | :return: (Tensor) [B x D] 123 | """ 124 | # Sample from Gumbel 125 | u = torch.rand_like(z) 126 | g = - torch.log(- torch.log(u + eps) + eps) 127 | 128 | # Gumbel-Softmax sample 129 | s = F.softmax((z + g) / self.temp, dim=-1) 130 | s = s.view(-1, self.latent_dim * self.categorical_dim) 131 | return s 132 | 133 | 134 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 135 | q = self.encode(input)[0] 136 | z = self.reparameterize(q) 137 | return [self.decode(z), input, q] 138 | 139 | def loss_function(self, 140 | *args, 141 | **kwargs) -> dict: 142 | """ 143 | Computes the VAE loss function. 144 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 145 | :param args: 146 | :param kwargs: 147 | :return: 148 | """ 149 | recons = args[0] 150 | input = args[1] 151 | q = args[2] 152 | 153 | q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities 154 | 155 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 156 | batch_idx = kwargs['batch_idx'] 157 | 158 | # Anneal the temperature at regular intervals 159 | if batch_idx % self.anneal_interval == 0 and self.training: 160 | self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx), 161 | self.min_temp) 162 | 163 | recons_loss =F.mse_loss(recons, input, reduction='mean') 164 | 165 | # KL divergence between gumbel-softmax distribution 166 | eps = 1e-7 167 | 168 | # Entropy of the logits 169 | h1 = q_p * torch.log(q_p + eps) 170 | 171 | # Cross entropy with the categorical distribution 172 | h2 = q_p * np.log(1. 
/ self.categorical_dim + eps) 173 | kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0) 174 | 175 | # kld_weight = 1.2 176 | loss = self.alpha * recons_loss + kld_weight * kld_loss 177 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 178 | 179 | def sample(self, 180 | num_samples:int, 181 | current_device: int, **kwargs) -> Tensor: 182 | """ 183 | Samples from the latent space and return the corresponding 184 | image space map. 185 | :param num_samples: (Int) Number of samples 186 | :param current_device: (Int) Device to run the model 187 | :return: (Tensor) 188 | """ 189 | # [S x D x Q] 190 | 191 | M = num_samples * self.latent_dim 192 | np_y = np.zeros((M, self.categorical_dim), dtype=np.float32) 193 | np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1 194 | np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim]) 195 | z = torch.from_numpy(np_y) 196 | 197 | # z = self.sampling_dist.sample((num_samples * self.latent_dim, )) 198 | z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device) 199 | samples = self.decode(z) 200 | return samples 201 | 202 | def generate(self, x: Tensor, **kwargs) -> Tensor: 203 | """ 204 | Given an input image x, returns the reconstructed image 205 | :param x: (Tensor) [B x C x H x W] 206 | :return: (Tensor) [B x C x H x W] 207 | """ 208 | 209 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/vq_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | class VectorQuantizer(nn.Module): 8 | """ 9 | Reference: 10 | [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py 11 | """ 12 | def __init__(self, 13 | num_embeddings: int, 14 | embedding_dim: int, 15 | beta: float = 0.25): 16 | super(VectorQuantizer, self).__init__() 17 | self.K = num_embeddings 18 | self.D = embedding_dim 19 | self.beta = beta 20 | 21 | self.embedding = nn.Embedding(self.K, self.D) 22 | self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) 23 | 24 | def forward(self, latents: Tensor) -> Tensor: 25 | latents = latents.permute(0, 2, 3, 1).contiguous() # [B x D x H x W] -> [B x H x W x D] 26 | latents_shape = latents.shape 27 | flat_latents = latents.view(-1, self.D) # [BHW x D] 28 | 29 | # Compute L2 distance between latents and embedding weights 30 | dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \ 31 | torch.sum(self.embedding.weight ** 2, dim=1) - \ 32 | 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHW x K] 33 | 34 | # Get the encoding that has the min distance 35 | encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1) # [BHW, 1] 36 | 37 | # Convert to one-hot encodings 38 | device = latents.device 39 | encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) 40 | encoding_one_hot.scatter_(1, encoding_inds, 1) # [BHW x K] 41 | 42 | # Quantize the latents 43 | quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D] 44 | quantized_latents = quantized_latents.view(latents_shape) # [B x H x W x D] 45 | 46 | # Compute the VQ Losses 47 | commitment_loss = F.mse_loss(quantized_latents.detach(), latents) 48 | embedding_loss = F.mse_loss(quantized_latents, latents.detach()) 49 | 50 | vq_loss = commitment_loss * self.beta + embedding_loss 51 | 52 | # Add the 
residue back to the latents 53 | quantized_latents = latents + (quantized_latents - latents).detach() 54 | 55 | return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W] 56 | 57 | class ResidualLayer(nn.Module): 58 | 59 | def __init__(self, 60 | in_channels: int, 61 | out_channels: int): 62 | super(ResidualLayer, self).__init__() 63 | self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels, 64 | kernel_size=3, padding=1, bias=False), 65 | nn.ReLU(True), 66 | nn.Conv2d(out_channels, out_channels, 67 | kernel_size=1, bias=False)) 68 | 69 | def forward(self, input: Tensor) -> Tensor: 70 | return input + self.resblock(input) 71 | 72 | 73 | class VQVAE(BaseVAE): 74 | 75 | def __init__(self, 76 | in_channels: int, 77 | embedding_dim: int, 78 | num_embeddings: int, 79 | hidden_dims: List = None, 80 | beta: float = 0.25, 81 | img_size: int = 64, 82 | device: List = None, 83 | **kwargs) -> None: 84 | super(VQVAE, self).__init__() 85 | 86 | self.embedding_dim = embedding_dim 87 | self.num_embeddings = num_embeddings 88 | self.img_size = img_size 89 | self.beta = beta 90 | 91 | modules = [] 92 | if hidden_dims is None: 93 | hidden_dims = [128, 256] 94 | 95 | # Build Encoder 96 | for h_dim in hidden_dims: 97 | modules.append( 98 | nn.Sequential( 99 | nn.Conv2d(in_channels, out_channels=h_dim, 100 | kernel_size=4, stride=2, padding=1), 101 | nn.LeakyReLU()) 102 | ) 103 | in_channels = h_dim 104 | 105 | modules.append( 106 | nn.Sequential( 107 | nn.Conv2d(in_channels, in_channels, 108 | kernel_size=3, stride=1, padding=1), 109 | nn.LeakyReLU()) 110 | ) 111 | 112 | for _ in range(6): 113 | modules.append(ResidualLayer(in_channels, in_channels)) 114 | modules.append(nn.LeakyReLU()) 115 | 116 | modules.append( 117 | nn.Sequential( 118 | nn.Conv2d(in_channels, embedding_dim, 119 | kernel_size=1, stride=1), 120 | nn.LeakyReLU()) 121 | ) 122 | 123 | self.encoder = nn.Sequential(*modules) 124 | 125 | self.device = device 126 | self.encoder.to(device) 127 | 128 | self.vq_layer = VectorQuantizer(num_embeddings, 129 | embedding_dim, 130 | self.beta) 131 | self.vq_layer.to(device) 132 | 133 | # Build Decoder 134 | modules = [] 135 | modules.append( 136 | nn.Sequential( 137 | nn.Conv2d(embedding_dim, 138 | hidden_dims[-1], 139 | kernel_size=3, 140 | stride=1, 141 | padding=1), 142 | nn.LeakyReLU()) 143 | ) 144 | 145 | for _ in range(6): 146 | modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1])) 147 | 148 | modules.append(nn.LeakyReLU()) 149 | 150 | hidden_dims.reverse() 151 | 152 | for i in range(len(hidden_dims) - 1): 153 | modules.append( 154 | nn.Sequential( 155 | nn.ConvTranspose2d(hidden_dims[i], 156 | hidden_dims[i + 1], 157 | kernel_size=4, 158 | stride=2, 159 | padding=1), 160 | nn.LeakyReLU()) 161 | ) 162 | 163 | modules.append( 164 | nn.Sequential( 165 | nn.ConvTranspose2d(hidden_dims[-1], 166 | out_channels=3, 167 | kernel_size=4, 168 | stride=2, padding=1), 169 | nn.Tanh())) 170 | 171 | self.decoder = nn.Sequential(*modules) 172 | self.decoder.to(device) 173 | 174 | def encode(self, input: Tensor) -> List[Tensor]: 175 | """ 176 | Encodes the input by passing through the encoder network 177 | and returns the latent codes. 178 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 179 | :return: (Tensor) List of latent codes 180 | """ 181 | result = self.encoder(input) 182 | return [result] 183 | 184 | def decode(self, z: Tensor) -> Tensor: 185 | """ 186 | Maps the given latent codes 187 | onto the image space. 
188 | :param z: (Tensor) [B x D x H x W] 189 | :return: (Tensor) [B x C x H x W] 190 | """ 191 | 192 | result = self.decoder(z) 193 | return result 194 | 195 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 196 | encoding = self.encode(input)[0] 197 | quantized_inputs, vq_loss = self.vq_layer(encoding) 198 | return [self.decode(quantized_inputs), input, vq_loss] 199 | 200 | def loss_function(self, 201 | *args, 202 | **kwargs) -> dict: 203 | """ 204 | :param args: 205 | :param kwargs: 206 | :return: 207 | """ 208 | recons = args[0] 209 | input = args[1] 210 | vq_loss = args[2] 211 | 212 | recons_loss = F.mse_loss(recons, input) 213 | 214 | loss = recons_loss + vq_loss 215 | return {'loss': loss, 216 | 'Reconstruction_Loss': recons_loss, 217 | 'VQ_Loss':vq_loss} 218 | 219 | def sample(self, 220 | num_samples: int, 221 | current_device: Union[int, str], **kwargs) -> Tensor: 222 | raise Warning('VQVAE sampler is not implemented.') 223 | 224 | def generate(self, x: Tensor, **kwargs) -> Tensor: 225 | """ 226 | Given an input image x, returns the reconstructed image 227 | :param x: (Tensor) [B x C x H x W] 228 | :return: (Tensor) [B x C x H x W] 229 | """ 230 | 231 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/fvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class FactorVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | gamma: float = 40., 15 | **kwargs) -> None: 16 | super(FactorVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | self.gamma = gamma 20 | 21 | modules = [] 22 | if hidden_dims is None: 23 | hidden_dims = [32, 64, 128, 256, 512] 24 | 25 | # Build Encoder 26 | for h_dim in hidden_dims: 27 | modules.append( 28 | nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels=h_dim, 30 | kernel_size= 3, stride= 2, padding = 1), 31 | nn.BatchNorm2d(h_dim), 32 | nn.LeakyReLU()) 33 | ) 34 | in_channels = h_dim 35 | 36 | self.encoder = nn.Sequential(*modules) 37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | 40 | 41 | # Build Decoder 42 | modules = [] 43 | 44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 45 | 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | 61 | 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | # Discriminator network for the Total Correlation (TC) loss 79 | self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000), 80 | nn.BatchNorm1d(1000), 81 | nn.LeakyReLU(0.2), 82 | nn.Linear(1000, 1000), 83 | nn.BatchNorm1d(1000), 84 | nn.LeakyReLU(0.2), 85 | nn.Linear(1000, 1000), 86 | nn.BatchNorm1d(1000), 87 | 
nn.LeakyReLU(0.2), 88 | nn.Linear(1000, 2)) 89 | self.D_z_reserve = None 90 | 91 | 92 | def encode(self, input: Tensor) -> List[Tensor]: 93 | """ 94 | Encodes the input by passing through the encoder network 95 | and returns the latent codes. 96 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 97 | :return: (Tensor) List of latent codes 98 | """ 99 | result = self.encoder(input) 100 | result = torch.flatten(result, start_dim=1) 101 | 102 | # Split the result into mu and var components 103 | # of the latent Gaussian distribution 104 | mu = self.fc_mu(result) 105 | log_var = self.fc_var(result) 106 | 107 | return [mu, log_var] 108 | 109 | def decode(self, z: Tensor) -> Tensor: 110 | """ 111 | Maps the given latent codes 112 | onto the image space. 113 | :param z: (Tensor) [B x D] 114 | :return: (Tensor) [B x C x H x W] 115 | """ 116 | result = self.decoder_input(z) 117 | result = result.view(-1, 512, 2, 2) 118 | result = self.decoder(result) 119 | result = self.final_layer(result) 120 | return result 121 | 122 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 123 | """ 124 | Reparameterization trick to sample from N(mu, var) from 125 | N(0,1). 126 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 127 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 128 | :return: (Tensor) [B x D] 129 | """ 130 | std = torch.exp(0.5 * logvar) 131 | eps = torch.randn_like(std) 132 | return eps * std + mu 133 | 134 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 135 | mu, log_var = self.encode(input) 136 | z = self.reparameterize(mu, log_var) 137 | return [self.decode(z), input, mu, log_var, z] 138 | 139 | def permute_latent(self, z: Tensor) -> Tensor: 140 | """ 141 | Permutes each of the latent codes in the batch 142 | :param z: [B x D] 143 | :return: [B x D] 144 | """ 145 | B, D = z.size() 146 | 147 | # Returns a shuffled inds for each latent code in the batch 148 | inds = torch.cat([(D *i) + torch.randperm(D) for i in range(B)]) 149 | return z.view(-1)[inds].view(B, D) 150 | 151 | def loss_function(self, 152 | *args, 153 | **kwargs) -> dict: 154 | """ 155 | Computes the VAE loss function. 
156 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 157 | :param args: 158 | :param kwargs: 159 | :return: 160 | """ 161 | recons = args[0] 162 | input = args[1] 163 | mu = args[2] 164 | log_var = args[3] 165 | z = args[4] 166 | 167 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 168 | optimizer_idx = kwargs['optimizer_idx'] 169 | 170 | # Update the VAE 171 | if optimizer_idx == 0: 172 | recons_loss =F.mse_loss(recons, input) 173 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 174 | 175 | self.D_z_reserve = self.discriminator(z) 176 | vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean() 177 | 178 | loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss 179 | 180 | # print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}') 181 | return {'loss': loss, 182 | 'Reconstruction_Loss':recons_loss, 183 | 'KLD':-kld_loss, 184 | 'VAE_TC_Loss': vae_tc_loss} 185 | 186 | # Update the Discriminator 187 | elif optimizer_idx == 1: 188 | device = input.device 189 | true_labels = torch.ones(input.size(0), dtype= torch.long, 190 | requires_grad=False).to(device) 191 | false_labels = torch.zeros(input.size(0), dtype= torch.long, 192 | requires_grad=False).to(device) 193 | 194 | z = z.detach() # Detach so that VAE is not trained again 195 | z_perm = self.permute_latent(z) 196 | D_z_perm = self.discriminator(z_perm) 197 | D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) + 198 | F.cross_entropy(D_z_perm, true_labels)) 199 | # print(f'D_TC: {D_tc_loss}') 200 | return {'loss': D_tc_loss, 201 | 'D_TC_Loss':D_tc_loss} 202 | 203 | def sample(self, 204 | num_samples:int, 205 | current_device: int, **kwargs) -> Tensor: 206 | """ 207 | Samples from the latent space and return the corresponding 208 | image space map. 209 | :param num_samples: (Int) Number of samples 210 | :param current_device: (Int) Device to run the model 211 | :return: (Tensor) 212 | """ 213 | z = torch.randn(num_samples, 214 | self.latent_dim) 215 | 216 | z = z.to(current_device) 217 | 218 | samples = self.decode(z) 219 | return samples 220 | 221 | def generate(self, x: Tensor, **kwargs) -> Tensor: 222 | """ 223 | Given an input image x, returns the reconstructed image 224 | :param x: (Tensor) [B x C x H x W] 225 | :return: (Tensor) [B x C x H x W] 226 | """ 227 | 228 | return self.forward(x)[0] --------------------------------------------------------------------------------
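Usage note (illustrative, not part of the repository): the VAE variants above share the BaseVAE interface — forward() returns the list of tensors that loss_function() unpacks positionally, plus model-specific keyword arguments (M_N for the KL weight, batch_idx for CategoricalVAE's temperature annealing, optimizer_idx for FactorVAE's two-pass VAE/discriminator update). Below is a minimal smoke-test sketch for DFCVAE on dummy 64x64 RGB inputs, the spatial size implied by result.view(-1, 512, 2, 2) in decode(); the batch size, learning rate, and M_N value are illustrative assumptions, not values taken from this repository.

# Hypothetical usage sketch (assumptions: 64x64 RGB inputs, Adam, M_N ~ batch_size / dataset_size).
import torch
from models.dfcvae import DFCVAE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DFCVAE(in_channels=3, latent_dim=128).to(device)   # __init__ loads a pretrained, frozen VGG19-BN for the feature loss
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),  # skip the frozen feature network
    lr=5e-4)

x = torch.randn(16, 3, 64, 64, device=device)               # dummy batch; real inputs should be scaled to [-1, 1] to match the Tanh output
outputs = model(x)                                           # [recons, input, recons_features, input_features, mu, log_var]
losses = model.loss_function(*outputs, M_N=16 / 50000)       # KL weight, roughly batch_size / dataset_size

optimizer.zero_grad()
losses['loss'].backward()
optimizer.step()
print({k: v.item() for k, v in losses.items()})              # 'loss', 'Reconstruction_Loss', 'KLD'

The same pattern applies to the other models in models/, substituting each model's own forward() outputs and loss_function() keyword arguments (e.g. FactorVAE additionally needs optimizer_idx and a second optimizer for its discriminator).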