├── .gitignore ├── DDPM ├── FIDScorer.py ├── InceptionV3.py ├── client.py ├── data_utils.py ├── model.py ├── precision_recall.py ├── run.sh ├── sample.py ├── server.py └── utils.py ├── DDPM_2 ├── DiffusionCondition.py ├── FIDScorer.py ├── InceptionV3.py ├── Scheduler.py ├── TrainCondition.py ├── client.py ├── config.py ├── data_utils.py ├── model.py ├── run.sh ├── server.py └── utils.py ├── DiT ├── README.md ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py ├── respace.py └── timestep_sampler.py ├── LICENSE ├── README.md ├── U-ViT ├── LICENSE ├── README.md ├── UViT_ImageNet_demo.ipynb ├── client.py ├── common.py ├── configs │ ├── celeba64_uvit_small.py │ ├── cifar10_uvit_small.py │ ├── imagenet256_uvit_huge.py │ ├── imagenet256_uvit_large.py │ ├── imagenet512_uvit_huge.py │ ├── imagenet512_uvit_large.py │ ├── imagenet64_uvit_large.py │ ├── imagenet64_uvit_mid.py │ └── mscoco_uvit_small.py ├── dataset_utils.py ├── datasets.py ├── dpm_solver_pp.py ├── dpm_solver_pytorch.py ├── eval.py ├── eval_ldm.py ├── eval_ldm_discrete.py ├── eval_t2i_discrete.py ├── flower_utils.py ├── libs │ ├── __init__.py │ ├── autoencoder.py │ ├── clip.py │ ├── timm.py │ ├── uvit.py │ └── uvit_t2i.py ├── main.py ├── pyproject.toml ├── requirements.txt ├── run.sh ├── sample.png ├── sample_t2i_discrete.py ├── scripts │ ├── extract_empty_feature.py │ ├── extract_imagenet_feature.py │ ├── extract_mscoco_feature.py │ └── extract_test_prompt_feature.py ├── sde.py ├── skip_im.png ├── tools │ ├── fid_score.py │ └── inception.py ├── train.py ├── train_ldm.py ├── train_ldm_discrete.py ├── train_t2i_discrete.py ├── utils.py └── uvit.png ├── checkpoints └── README.md └── jobscript.sh /.gitignore: -------------------------------------------------------------------------------- 1 | /dataset/ 2 | /checkpoints/* 3 | !checkpoints/README.md 4 | # Tensorboard event files 5 | /runs/* 6 | !runs/README.md 7 | /DDPM/synthetic 8 | /DDPM/data 9 | /dataset 10 | /data 11 | /DDPM_2/runs 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
171 | #.idea/ 172 | -------------------------------------------------------------------------------- /DDPM/FIDScorer.py: -------------------------------------------------------------------------------- 1 | from scipy import linalg 2 | 3 | from InceptionV3 import InceptionV3 4 | import numpy as np 5 | from torch.nn.functional import adaptive_avg_pool2d 6 | import torch 7 | 8 | 9 | class FIDScorer: 10 | def __init__(self): 11 | self.model = InceptionV3() 12 | 13 | def calculate_activation_statistics(self, images, device='cpu'): 14 | 15 | model = self.model.to(device) 16 | model.eval() 17 | 18 | activations = [] 19 | for (batch, _) in images: 20 | # if len(batch.shape) < 4: 21 | # batch = torch.unsqueeze(batch, 1) 22 | 23 | if batch.size(1) == 1: 24 | # greyscale 25 | batch = batch.repeat(1, 3, 1, 1) 26 | 27 | batch = batch.to(device) 28 | pred = model(batch)[0] 29 | 30 | # If model output is not scalar, apply global spatial average pooling. 31 | # This happens if you choose a dimensionality not equal 2048. 32 | if pred.size(2) != 1 or pred.size(3) != 1: 33 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 34 | 35 | act = pred.cpu().data.numpy().reshape(pred.size(0), -1) 36 | activations.append(act) 37 | 38 | act = np.concatenate(activations, axis=0) 39 | mu = np.mean(act, axis=0) 40 | sigma = np.cov(act, rowvar=False) 41 | return mu, sigma 42 | 43 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 44 | """Numpy implementation of the Frechet Distance. 45 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 46 | and X_2 ~ N(mu_2, C_2) is 47 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 48 | """ 49 | 50 | mu1 = np.atleast_1d(mu1) 51 | mu2 = np.atleast_1d(mu2) 52 | 53 | sigma1 = np.atleast_2d(sigma1) 54 | sigma2 = np.atleast_2d(sigma2) 55 | 56 | assert mu1.shape == mu2.shape, \ 57 | 'Training and test mean vectors have different lengths' 58 | assert sigma1.shape == sigma2.shape, \ 59 | 'Training and test covariances have different dimensions' 60 | 61 | diff = mu1 - mu2 62 | 63 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 64 | if not np.isfinite(covmean).all(): 65 | msg = ('fid calculation produces singular product; ' 66 | 'adding %s to diagonal of cov estimates') % eps 67 | print(msg) 68 | offset = np.eye(sigma1.shape[0]) * eps 69 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 70 | 71 | if np.iscomplexobj(covmean): 72 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 73 | m = np.max(np.abs(covmean.imag)) 74 | raise ValueError('Imaginary component {}'.format(m)) 75 | covmean = covmean.real 76 | 77 | tr_covmean = np.trace(covmean) 78 | 79 | return (diff.dot(diff) + np.trace(sigma1) + 80 | np.trace(sigma2) - 2 * tr_covmean) 81 | 82 | def calculate_fid(self, images_real, images_fake, device): 83 | mu_1, std_1 = self.calculate_activation_statistics(images_real, device=device) 84 | mu_2, std_2 = self.calculate_activation_statistics(images_fake, device=device) 85 | 86 | """get fretched distance""" 87 | fid_value = self.calculate_frechet_distance(mu_1, std_1, mu_2, std_2) 88 | return fid_value -------------------------------------------------------------------------------- /DDPM/InceptionV3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | 6 | 7 | class InceptionV3(nn.Module): 8 | """Pretrained InceptionV3 network returning feature 
maps""" 9 | 10 | # Index of default block of inception to return, 11 | # corresponds to output of final average pooling 12 | DEFAULT_BLOCK_INDEX = 3 13 | 14 | # Maps feature dimensionality to their output blocks indices 15 | BLOCK_INDEX_BY_DIM = { 16 | 64: 0, # First max pooling features 17 | 192: 1, # Second max pooling featurs 18 | 768: 2, # Pre-aux classifier features 19 | 2048: 3 # Final average pooling features 20 | } 21 | 22 | def __init__(self, 23 | output_blocks=[DEFAULT_BLOCK_INDEX], 24 | resize_input=True, 25 | normalize_input=True, 26 | requires_grad=False): 27 | 28 | super(InceptionV3, self).__init__() 29 | 30 | self.resize_input = resize_input 31 | self.normalize_input = normalize_input 32 | self.output_blocks = sorted(output_blocks) 33 | self.last_needed_block = max(output_blocks) 34 | 35 | assert self.last_needed_block <= 3, \ 36 | 'Last possible output block index is 3' 37 | 38 | self.blocks = nn.ModuleList() 39 | 40 | inception = models.inception_v3(weights='Inception_V3_Weights.IMAGENET1K_V1') 41 | 42 | # Block 0: input to maxpool1 43 | block0 = [ 44 | inception.Conv2d_1a_3x3, 45 | inception.Conv2d_2a_3x3, 46 | inception.Conv2d_2b_3x3, 47 | nn.MaxPool2d(kernel_size=3, stride=2) 48 | ] 49 | self.blocks.append(nn.Sequential(*block0)) 50 | 51 | # Block 1: maxpool1 to maxpool2 52 | if self.last_needed_block >= 1: 53 | block1 = [ 54 | inception.Conv2d_3b_1x1, 55 | inception.Conv2d_4a_3x3, 56 | nn.MaxPool2d(kernel_size=3, stride=2) 57 | ] 58 | self.blocks.append(nn.Sequential(*block1)) 59 | 60 | # Block 2: maxpool2 to aux classifier 61 | if self.last_needed_block >= 2: 62 | block2 = [ 63 | inception.Mixed_5b, 64 | inception.Mixed_5c, 65 | inception.Mixed_5d, 66 | inception.Mixed_6a, 67 | inception.Mixed_6b, 68 | inception.Mixed_6c, 69 | inception.Mixed_6d, 70 | inception.Mixed_6e, 71 | ] 72 | self.blocks.append(nn.Sequential(*block2)) 73 | 74 | # Block 3: aux classifier to final avgpool 75 | if self.last_needed_block >= 3: 76 | block3 = [ 77 | inception.Mixed_7a, 78 | inception.Mixed_7b, 79 | inception.Mixed_7c, 80 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 81 | ] 82 | self.blocks.append(nn.Sequential(*block3)) 83 | 84 | for param in self.parameters(): 85 | param.requires_grad = requires_grad 86 | 87 | def forward(self, inp): 88 | """Get Inception feature maps 89 | Parameters 90 | ---------- 91 | inp : torch.autograd.Variable 92 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 93 | range (0, 1) 94 | Returns 95 | ------- 96 | List of torch.autograd.Variable, corresponding to the selected output 97 | block, sorted ascending by index 98 | """ 99 | outp = [] 100 | x = inp 101 | 102 | if self.resize_input: 103 | x = F.interpolate(x, 104 | size=(299, 299), 105 | mode='bilinear', 106 | align_corners=False) 107 | 108 | if self.normalize_input: 109 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 110 | 111 | for idx, block in enumerate(self.blocks): 112 | x = block(x) 113 | if idx in self.output_blocks: 114 | outp.append(x) 115 | 116 | if idx == self.last_needed_block: 117 | break 118 | 119 | return outp -------------------------------------------------------------------------------- /DDPM/client.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import timeit 3 | import warnings 4 | 5 | import torch 6 | import torchvision 7 | 8 | import flwr as fl 9 | from flwr.common import ( 10 | EvaluateIns, 11 | EvaluateRes, 12 | FitIns, 13 | FitRes, 14 | NDArrays, 15 | GetParametersRes, 16 | GetParametersIns, 17 | Status, 18 | Code, 19 | ) 20 | from model import Diffusion 21 | from copy import deepcopy 22 | from torchvision.datasets import ImageFolder 23 | from data_utils import load_data 24 | from utils import ema_update, eval_mode, train_mode, train, test 25 | 26 | warnings.filterwarnings("ignore", category=UserWarning) 27 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | class DiffusionClient(fl.client.Client): 31 | """Flower client implementing CIFAR-10 image classification using PyTorch.""" 32 | 33 | def __init__( 34 | self, 35 | cid: str, 36 | model: Diffusion, 37 | trainset: torchvision.datasets.folder.ImageFolder, 38 | testset: torchvision.datasets.folder.ImageFolder, 39 | ) -> None: 40 | self.cid = cid 41 | self.model = model 42 | self.model_ema = deepcopy(model) 43 | self.trainset = trainset 44 | self.testset = testset 45 | self.scaler = torch.cuda.amp.GradScaler() 46 | # Training settings 47 | self.seed = 0 48 | self.epoch = 0 49 | # Use a low discrepancy quasi-random sequence to sample uniformly distributed 50 | # timesteps. This considerably reduces the between-batch variance of the loss. 51 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 52 | self.ema_decay = 0.998 53 | # The number of timesteps to use when sampling 54 | self.steps = 500 55 | # The amount of noise to add each timestep when sampling 56 | self.eta = 1. 
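        # The SobolEngine above is consumed in utils.eval_loss(), which draws
        # timesteps as `t = rng.draw(batch_size)[:, 0]`: a low-discrepancy
        # sequence covers [0, 1) more evenly within each batch than i.i.d.
        # uniform draws, which is what reduces the between-batch loss variance
        # mentioned in the comment above.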
57 | 58 | 59 | def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: 60 | print(f"Client {self.cid}: get_parameters") 61 | 62 | weights: NDArrays = self.model_ema.get_weights() 63 | parameters = fl.common.ndarrays_to_parameters(weights) 64 | return GetParametersRes(status=Status(code=Code.OK, message="Success"), 65 | parameters=parameters) 66 | 67 | def fit(self, ins: FitIns) -> FitRes: 68 | print(f"Client {self.cid}: fit") 69 | 70 | weights: NDArrays = fl.common.parameters_to_ndarrays(ins.parameters) 71 | config = ins.config 72 | fit_begin = timeit.default_timer() 73 | 74 | # Get training config 75 | epochs = int(config["epochs"]) 76 | batch_size = int(config["batch_size"]) 77 | 78 | # Set model parameters 79 | self.model.set_weights(weights) 80 | self.model_ema.set_weights(weights) 81 | 82 | # Train model 83 | trainloader = torch.utils.data.DataLoader( 84 | self.trainset, batch_size=batch_size, shuffle=True 85 | ) 86 | train(self.model, self.model_ema, trainloader, epochs, self.epoch, self.rng, self.scaler, self.ema_decay, DEVICE) 87 | 88 | # Return the refined weights and the number of examples used for training 89 | weights_prime: NDArrays = self.model_ema.get_weights() 90 | params_prime = fl.common.ndarrays_to_parameters(weights_prime) 91 | num_examples_train = len(self.trainset) 92 | metrics = {"duration": timeit.default_timer() - fit_begin} 93 | return FitRes( 94 | parameters=params_prime, 95 | num_examples=num_examples_train, 96 | metrics=metrics, 97 | status=Status(code=Code.OK, message="Success"), 98 | ) 99 | 100 | def evaluate(self, ins: EvaluateIns) -> EvaluateRes: 101 | print(f"Client {self.cid}: evaluate") 102 | 103 | weights = fl.common.parameters_to_ndarrays(ins.parameters) 104 | 105 | # Use provided weights to update the local model 106 | self.model.set_weights(weights) 107 | 108 | # Evaluate the updated model on the local dataset 109 | testloader = torch.utils.data.DataLoader( 110 | self.testset, batch_size=100, shuffle=False 111 | ) 112 | # loss = (eval_mode(self.model_ema))(test(self.model_ema, testloader, device=DEVICE)) 113 | loss = test(self.model_ema, testloader, device=DEVICE) 114 | metrics = {} 115 | # Return the number of evaluation examples and the evaluation result (loss) 116 | return EvaluateRes( 117 | loss=loss, num_examples=len(self.testset), metrics=metrics, status=Status(code=Code.OK, message="Success") 118 | ) 119 | 120 | def main() -> None: 121 | """Load data, create and start CifarClient.""" 122 | parser = argparse.ArgumentParser(description="Flower") 123 | parser.add_argument( 124 | "--server-address", 125 | type=str, 126 | default="127.0.0.1:8080", 127 | help=f"gRPC server address (default: 127.0.0.1:8080)", 128 | ) 129 | parser.add_argument( 130 | "--cid", type=str, required=True, help="Client CID (no default)" 131 | ) 132 | parser.add_argument( 133 | "--dataset-path", type=str, help="Path to dataset (no default)" 134 | ) 135 | parser.add_argument( 136 | "--log_host", 137 | type=str, 138 | help="Logserver address (no default)", 139 | ) 140 | parser.add_argument( 141 | "--dataset", type=str, choices=["emnist","cinic10"], default="emnist" 142 | ) 143 | args = parser.parse_args() 144 | 145 | # Configure logger 146 | fl.common.logger.configure(f"client_{args.cid}", host=args.log_host) 147 | 148 | # Load model and data 149 | model = Diffusion(1).to(DEVICE) 150 | # model_ema = deepcopy(model) 151 | if args.dataset == "emnist": 152 | trainset, testset = load_data(args.dataset, args.cid) 153 | else: 154 | trainset, testset = 
load_data(args.dataset, args.cid, args.dataset_path) 155 | 156 | # Start client 157 | client = DiffusionClient(args.cid, model, trainset, testset) 158 | fl.client.start_client( 159 | server_address="127.0.0.1:8080", 160 | client=client) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() -------------------------------------------------------------------------------- /DDPM/data_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder, EMNIST 2 | from torchvision.transforms import Compose, Normalize, ToTensor, Resize 3 | from torch.utils.data import DataLoader, Subset 4 | import torchvision.transforms.functional as TF 5 | import numpy as np 6 | 7 | def balanced_split(dataset, num_splits, client_id): 8 | """ 9 | Splits training data into client datasets with balanced classes 10 | 11 | Args: 12 | dataset (torch.utils.data.Dataset): training data 13 | num_splits (int): number of client datasets to split into 14 | Returns: 15 | client_data (list): list of client datasets 16 | """ 17 | samples_per_class = len(dataset) // num_splits 18 | remainder = len(dataset) % num_splits 19 | num_classes = 10 20 | class_counts = [0] * num_classes # number of samples per class 21 | subset_indices = [[] for _ in range(num_splits)] # indices of samples per subset 22 | for i, (data, target) in enumerate(dataset): 23 | # Add sample to subset if number of samples per class is less than samples_per_class 24 | if class_counts[target] < samples_per_class: 25 | subset_indices[i % num_splits].append(i) 26 | class_counts[target] += 1 27 | elif remainder > 0: 28 | subset_indices[i % num_splits].append(i) 29 | class_counts[target] += 1 30 | remainder -= 1 31 | 32 | return Subset(dataset, subset_indices[int(client_id)]) 33 | 34 | def dirichlet_split(dataset, num_splits, client_id, beta=0.1): 35 | """ 36 | Splits training data into client datasets based Dirichlet distribution 37 | 38 | Args: 39 | dataset (torch.utils.data.Dataset): training data 40 | num_splits (int): number of client datasets to split into 41 | beta (float): concentration parameter of Dirichlet distribution 42 | Returns: 43 | client_data (list): list of client datasets 44 | """ 45 | # set seed for reproducibility 46 | np.random.seed(42) 47 | 48 | label_distributions = [] 49 | # Generate label distributions for each class using Dirichlet distribution 50 | for y in range(len(dataset.classes)): 51 | label_distributions.append(np.random.dirichlet(np.repeat(beta, num_splits))) 52 | 53 | labels = np.array(dataset.targets).astype(int) 54 | client_idx_map = {i: {} for i in range(num_splits)} 55 | client_size_map = {i: {} for i in range(num_splits)} 56 | 57 | for y in range(len(dataset.classes)): 58 | label_y_idx = np.where(labels == y)[0] 59 | label_y_size = len(label_y_idx) 60 | 61 | # Sample number of samples for each client from label distribution 62 | sample_size = (label_distributions[y] * label_y_size).astype(int) 63 | sample_size[num_splits - 1] += len(label_y_idx) - np.sum(sample_size) 64 | for i in range(num_splits): 65 | client_size_map[i][y] = sample_size[i] 66 | 67 | np.random.shuffle(label_y_idx) 68 | sample_interval = np.cumsum(sample_size) 69 | for i in range(num_splits): 70 | client_idx_map[i][y] = label_y_idx[(sample_interval[i - 1] if i > 0 else 0):sample_interval[i]] 71 | 72 | 73 | client_i_idx = np.concatenate(list(client_idx_map[int(client_id)].values())) 74 | np.random.shuffle(client_i_idx) 75 | return Subset(dataset, client_i_idx) 76 | 77 | def 
get_mean_std(dataset_id): 78 | """Get mean and std for normalization.""" 79 | if (dataset_id == "cifar10"): 80 | return [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616] 81 | elif (dataset_id == "cifar100"): 82 | return [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761] 83 | elif (dataset_id == "cinic10"): 84 | return [0.47889522, 0.47227842, 0.43047404], [0.24205776, 0.23828046, 0.25874835] 85 | elif (dataset_id == "emnist"): 86 | return [0.5], [0.5] 87 | 88 | def load_data(dataset, client_id, path=None): 89 | """Load training and test set.""" 90 | mean, std = get_mean_std(dataset) 91 | 92 | if dataset == "cinic10": 93 | transform = Compose([ToTensor(), Normalize(mean, std)]) 94 | if client_id: 95 | trainset = ImageFolder(path + "/train", transform=transform) 96 | trainset = dirichlet_split(trainset, 5, client_id) 97 | # trainset = ImageFolder(path + "/train/client_" + str(client_id), transform=transform) 98 | else: 99 | trainset = None 100 | 101 | testset = ImageFolder(path + "/test", transform=transform) 102 | # reduce to 5000 images 103 | testset = balanced_split(testset, 18, 0) 104 | 105 | else: 106 | transform = Compose([lambda img: TF.rotate(img, -90), 107 | lambda img: TF.hflip(img), 108 | Resize(32), ToTensor(), Normalize(mean, std)]) 109 | if client_id: 110 | trainset = EMNIST(root='./data', train=True, download=True, transform=transform, split='digits') 111 | trainset = dirichlet_split(trainset, 5, client_id) 112 | else: 113 | trainset = None 114 | 115 | testset = EMNIST(root='./data', train=False, download=True, transform=transform, split='digits') 116 | # Reduce to 5000 images 117 | testset = balanced_split(testset, 8, 0) 118 | 119 | return trainset, testset -------------------------------------------------------------------------------- /DDPM/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import flwr as fl 5 | from collections import OrderedDict 6 | 7 | class ResidualBlock(nn.Module): 8 | def __init__(self, main, skip=None): 9 | super().__init__() 10 | self.main = nn.Sequential(*main) 11 | self.skip = skip if skip else nn.Identity() 12 | 13 | def forward(self, input): 14 | return self.main(input) + self.skip(input) 15 | 16 | 17 | class ResConvBlock(ResidualBlock): 18 | def __init__(self, c_in, c_mid, c_out, dropout_last=True): 19 | skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False) 20 | super().__init__([ 21 | nn.Conv2d(c_in, c_mid, 3, padding=1), 22 | nn.Dropout2d(0.1, inplace=True), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(c_mid, c_out, 3, padding=1), 25 | nn.Dropout2d(0.1, inplace=True) if dropout_last else nn.Identity(), 26 | nn.ReLU(inplace=True), 27 | ], skip) 28 | 29 | 30 | class SkipBlock(nn.Module): 31 | def __init__(self, main, skip=None): 32 | super().__init__() 33 | self.main = nn.Sequential(*main) 34 | self.skip = skip if skip else nn.Identity() 35 | 36 | def forward(self, input): 37 | return torch.cat([self.main(input), self.skip(input)], dim=1) 38 | 39 | 40 | class FourierFeatures(nn.Module): 41 | def __init__(self, in_features, out_features, std=1.): 42 | super().__init__() 43 | assert out_features % 2 == 0 44 | self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std) 45 | 46 | def forward(self, input): 47 | f = 2 * math.pi * input @ self.weight.T 48 | return torch.cat([f.cos(), f.sin()], dim=-1) 49 | 50 | 51 | def expand_to_planes(input, shape): 52 | return input[..., None, None].repeat([1, 1, shape[2], 
shape[3]]) 53 | 54 | 55 | class Diffusion(nn.Module): 56 | def __init__(self, num_channels=3): 57 | super().__init__() 58 | c = 64 # The base channel count 59 | 60 | # The inputs to timestep_embed will approximately fall into the range 61 | # -10 to 10, so use std 0.2 for the Fourier Features. 62 | self.timestep_embed = FourierFeatures(1, 16, std=0.2) 63 | self.class_embed = nn.Embedding(10, 4) 64 | 65 | self.net = nn.Sequential( # 32x32 66 | ResConvBlock(num_channels + 16 + 4, c, c), 67 | ResConvBlock(c, c, c), 68 | SkipBlock([ 69 | nn.AvgPool2d(2), # 32x32 -> 16x16 70 | ResConvBlock(c, c * 2, c * 2), 71 | ResConvBlock(c * 2, c * 2, c * 2), 72 | SkipBlock([ 73 | nn.AvgPool2d(2), # 16x16 -> 8x8 74 | ResConvBlock(c * 2, c * 4, c * 4), 75 | ResConvBlock(c * 4, c * 4, c * 4), 76 | SkipBlock([ 77 | nn.AvgPool2d(2), # 8x8 -> 4x4 78 | ResConvBlock(c * 4, c * 8, c * 8), 79 | ResConvBlock(c * 8, c * 8, c * 8), 80 | ResConvBlock(c * 8, c * 8, c * 8), 81 | ResConvBlock(c * 8, c * 8, c * 4), 82 | nn.Upsample(scale_factor=2), 83 | ]), # 4x4 -> 8x8 84 | ResConvBlock(c * 8, c * 4, c * 4), 85 | ResConvBlock(c * 4, c * 4, c * 2), 86 | nn.Upsample(scale_factor=2), 87 | ]), # 8x8 -> 16x16 88 | ResConvBlock(c * 4, c * 2, c * 2), 89 | ResConvBlock(c * 2, c * 2, c), 90 | nn.Upsample(scale_factor=2), 91 | ]), # 16x16 -> 32x32 92 | ResConvBlock(c * 2, c, c), 93 | ResConvBlock(c, c, num_channels, dropout_last=False), 94 | ) 95 | 96 | def forward(self, input, log_snrs, cond): 97 | timestep_embed = expand_to_planes(self.timestep_embed(log_snrs[:, None]), input.shape) 98 | class_embed = expand_to_planes(self.class_embed(cond), input.shape) 99 | return self.net(torch.cat([input, class_embed, timestep_embed], dim=1)) 100 | 101 | def get_weights(self) -> fl.common.NDArray: 102 | """Get model weights as a list of NumPy ndarrays.""" 103 | return [val.cpu().numpy() for _, val in self.state_dict().items()] 104 | 105 | def set_weights(self, weights: fl.common.NDArray) -> None: 106 | """Set model weights from a list of NumPy ndarrays.""" 107 | state_dict = OrderedDict( 108 | {k: torch.tensor(v) for k, v in zip(self.state_dict().keys(), weights)} 109 | ) 110 | self.load_state_dict(state_dict, strict=True) 111 | 112 | def load_model(num_channels) -> Diffusion: 113 | """Load diffusion model.""" 114 | return Diffusion(num_channels) -------------------------------------------------------------------------------- /DDPM/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################ 4 | # Help # 5 | ############################################################ 6 | Help() 7 | { 8 | # Display Help 9 | echo "Flower Federated Learning parameters:" 10 | echo 11 | echo "Syntax: scriptTemplate [-g|h|v|V]" 12 | echo "options:" 13 | echo "c Number of clients." 14 | echo "r Number of rounds." 15 | echo "e Number of epochs." 16 | echo "d Training dataset path." 17 | echo "s Server Address." 18 | echo "h Help message." 
19 | echo 20 | } 21 | 22 | ############################################################ 23 | ############################################################ 24 | # Main program # 25 | ############################################################ 26 | ############################################################ 27 | 28 | # Default values 29 | server_address="localhost:8080" 30 | num_clients=2 31 | data_path="C:\Users\ColinLaganier\Documents\UCL\Dissertation\Testing\data\cinic-10\federated\5" 32 | num_epochs=1 33 | dataset="emnist" 34 | 35 | # Get the options 36 | while getopts c:r:s:d:e:h: flag 37 | do 38 | case "${flag}" in 39 | c) num_clients=${OPTARG};; 40 | r) num_rounds=${OPTARG};; 41 | s) server_address=${OPTARG};; 42 | d) data_path=${OPTARG};; 43 | e) num_epochs=${OPTARG};; 44 | h) Help 45 | exit;; 46 | esac 47 | done 48 | 49 | set -e 50 | # --dataset-path $data_path 51 | python server.py --dataset $dataset --num-clients $num_clients --rounds $num_rounds --epochs $num_epochs --server-address $server_address& 52 | sleep 3 # Sleep for 3s to give the server enough time to start 53 | 54 | echo "Starting $num_clients clients." 55 | for ((i = 0; i < $num_clients; i++)) 56 | do 57 | echo "Starting client $i" 58 | python client.py \ 59 | --cid $i \ 60 | --server-address $server_address \ 61 | --dataset $dataset & 62 | # --dataset-path $data_path \ 63 | done 64 | echo "Started $num_clients clients." 65 | 66 | # Enable CTRL+C to stop all background processes 67 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM 68 | # Wait for all background processes to complete 69 | wait 70 | -------------------------------------------------------------------------------- /DDPM/sample.py: -------------------------------------------------------------------------------- 1 | from utils import sample 2 | from model import load_model 3 | import torch 4 | # from torchvision.utils import save_image 5 | from torchvision.transforms import Compose, Normalize, ToTensor, Resize 6 | from torchvision.transforms import functional as TF 7 | from torchvision import transforms 8 | import torchvision 9 | import random 10 | from FIDScorer import FIDScorer 11 | from torchvision.datasets import EMNIST 12 | from data_utils import balanced_split 13 | 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | #labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] 18 | labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 19 | 20 | counter = [0] * 10 21 | 22 | # CINIC-10 23 | # mean, std [0.47889522, 0.47227842, 0.43047404], [0.24205776, 0.23828046, 0.25874835] 24 | # transform = Compose([ToTensor(), Normalize(mean, std)]) 25 | # testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor()) 26 | #testset = torchvision.datasets.ImageFolder(root="../../Testing/data/cinic-10/federated/5/train/client_0", transform=transforms.ToTensor()) 27 | # if test set divide by 9 -- 90,000 28 | 29 | # EMNIST 30 | mean, std = [0.5], [0.5] 31 | transform = Compose([lambda img: TF.rotate(img, -90), lambda img: TF.hflip(img), Resize(32), ToTensor(), Normalize(mean, std)]) 32 | testset = EMNIST(root='./data', train=False, download=True, transform=transform, split='digits') 33 | testset = balanced_split(testset, 3, 0) 34 | 35 | model = load_model(1) 36 | checkpoint = torch.load("../checkpoints/20230825-164926/model_100.pth") 37 | model.load_state_dict(checkpoint) 38 | model.to("cuda:0") 39 | total_samples = 10000 40 | 
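# The generation block below spreads the samples evenly over the 10 classes:
# torch.arange(10).repeat_interleave(num_samples // 10, 0) labels the first
# num_samples // 10 images as class 0, the next block as class 1, and so on,
# which is why the save loop recovers each sample's class as idx // (num_samples // 10).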
num_samples = 10000 41 | num_channels = 1 42 | #for i in range(1): 43 | noise = torch.randn(num_samples, num_channels, 32, 32).to(device) 44 | fakes_classes = torch.arange(10, device=device).repeat_interleave(num_samples // 10, 0) 45 | fakes = sample(model, noise, 500, 1., fakes_classes) 46 | 47 | for idx, fake in enumerate(fakes): 48 | cls = idx // (num_samples // 10) 49 | fake = TF.to_pil_image(fake.cpu().add(1).div(2).clamp(0, 1)).save("./data/synthetic/centralized/10K/{}/{}.png".format(labels[cls],counter[cls])) 50 | counter[cls] += 1 51 | 52 | 53 | # # Evaluate FID 54 | real_num = len(testset) 55 | subset = torch.utils.data.Subset(testset, random.sample(range(real_num), min(total_samples, real_num))) 56 | real_loader = torch.utils.data.DataLoader(subset, batch_size=100) 57 | fake_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(fakes, fakes_classes), batch_size=100) 58 | fid = FIDScorer().calculate_fid(real_loader, fake_loader, device=device) 59 | fake_saved = torchvision.datasets.ImageFolder(root="./data/synthetic/centralized/10K", transform=transforms.ToTensor()) 60 | fake_saved = torch.utils.data.DataLoader(fake_saved, batch_size=100) 61 | fid_2 = FIDScorer().calculate_fid(real_loader, fake_saved, device=device) 62 | print("FID: {}".format(fid)) 63 | print("FID_2: {}".format(fid_2)) 64 | -------------------------------------------------------------------------------- /DDPM/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Callable, Dict, List, Optional, Tuple, Union 3 | from flwr.common import (Parameters, Scalar) 4 | 5 | import os 6 | import time 7 | import torch 8 | import torchvision 9 | import flwr as fl 10 | import random 11 | from collections import OrderedDict 12 | from FIDScorer import FIDScorer 13 | from data_utils import load_data 14 | from utils import test, eval_mode, sample 15 | from model import load_model 16 | import numpy as np 17 | from torch.utils.tensorboard import SummaryWriter 18 | 19 | 20 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | class SaveModelStrategy(fl.server.strategy.FedAvg): 23 | """Federated Averaging strategy with save model functionality.""" 24 | def aggregate_fit( 25 | self, 26 | server_round: int, 27 | results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]], 28 | failures: List[Union[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes], BaseException]], 29 | ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: 30 | 31 | # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics 32 | aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures) 33 | 34 | 35 | # Save the model 36 | if aggregated_parameters is not None: 37 | model = load_model(1) 38 | aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters) 39 | params_dict = zip(model.state_dict().keys(), aggregated_ndarrays) 40 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 41 | model.load_state_dict(state_dict, strict=True) 42 | torch.save(model.state_dict(), checkpoint_path + f"/model.pth") 43 | # Save the model with round number 44 | if (server_round in [25, 50, 75, 100, 150, 200, 250]): 45 | torch.save(model.state_dict(), checkpoint_path + f"/model_{server_round}.pth") 46 | 47 | return aggregated_parameters, aggregated_metrics 48 | 49 | # class SaveProxStrategy(fl.server.strategy.FedProx): 50 | 51 | 52 | 
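# A minimal sketch (an assumed helper, not used elsewhere in this repo) of the
# Parameters -> state_dict round trip that SaveModelStrategy.aggregate_fit
# performs above; the name `parameters_to_model` is hypothetical.
def parameters_to_model(aggregated_parameters: Parameters):
    model = load_model(1)
    aggregated_ndarrays = fl.common.parameters_to_ndarrays(aggregated_parameters)
    params_dict = zip(model.state_dict().keys(), aggregated_ndarrays)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
    return model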
def main() -> None: 53 | """Start server and train five rounds.""" 54 | global num_epochs 55 | parser = argparse.ArgumentParser(description="Flower") 56 | parser.add_argument( 57 | "--server-address", 58 | type=str, 59 | default="0.0.0.0:8080", 60 | help=f"gRPC server address (default: 0.0.0.0:8080)", 61 | ) 62 | parser.add_argument( 63 | "--rounds", 64 | type=int, 65 | default=1, 66 | help="Number of rounds of federated learning (default: 1)", 67 | ) 68 | parser.add_argument( 69 | "--sample_fraction", 70 | type=float, 71 | default=1.0, 72 | help="Fraction of available clients used for fit/evaluate (default: 1.0)", 73 | ) 74 | parser.add_argument( 75 | "--min_sample_size", 76 | type=int, 77 | default=2, 78 | help="Minimum number of clients used for fit/evaluate (default: 2)", 79 | ) 80 | parser.add_argument( 81 | "--min_num_clients", 82 | type=int, 83 | default=2, 84 | help="Minimum number of available clients required for sampling (default: 2)", 85 | ) 86 | parser.add_argument( 87 | "--log_host", 88 | type=str, 89 | help="Logserver address (no default)", 90 | ) 91 | parser.add_argument( 92 | "--dataset-path", type=str, help="Path to dataset (no default)" 93 | ) 94 | parser.add_argument( 95 | "--num-clients", type=int, required=True, help="Number of clients (no default)" 96 | ) 97 | parser.add_argument( 98 | "--epochs", type=int, default=1, help="Number of epochs (default: 1)", 99 | ) 100 | parser.add_argument( 101 | "--dataset", type=str, choices=["emnist", "cinic10"], default="emnist" 102 | ) 103 | args = parser.parse_args() 104 | num_epochs = args.epochs 105 | 106 | # Load evaluation data 107 | if args.dataset == "emnist": 108 | _, testset = load_data(args.dataset, None) 109 | else: 110 | _, testset = load_data(args.dataset, None, args.dataset_path) 111 | # _, testset = load_data(args.dataset_path, 0) 112 | 113 | # Create strategy 114 | strategy = SaveModelStrategy( 115 | fraction_fit=args.sample_fraction, 116 | min_fit_clients=args.num_clients, 117 | min_available_clients=args.num_clients, 118 | evaluate_fn=get_evaluate_fn(testset), 119 | on_fit_config_fn=fit_config, 120 | ) 121 | 122 | 123 | # Configure logger and start server 124 | fl.common.logger.configure("server", host=args.log_host) 125 | fl.server.start_server( 126 | server_address=args.server_address, 127 | config=fl.server.ServerConfig(num_rounds=args.rounds), 128 | strategy=strategy) 129 | 130 | 131 | def fit_config(server_round: int) -> Dict[str, fl.common.Scalar]: 132 | """Return a configuration with static batch size and (local) epochs.""" 133 | config: Dict[str, fl.common.Scalar] = { 134 | "epoch_global": str(server_round), 135 | "epochs": str(num_epochs), 136 | "batch_size": str(128), 137 | } 138 | return config 139 | 140 | 141 | def get_evaluate_fn( 142 | testset: torchvision.datasets.folder.ImageFolder, 143 | ) -> Callable[[fl.common.NDArray], Optional[Tuple[float, float]]]: 144 | """Return an evaluation function for centralized evaluation.""" 145 | 146 | def evaluate(server_round, parameters: fl.common.NDArray, config) -> Optional[Tuple[float, float]]: 147 | """Use the entire CIFAR-10 test set for evaluation.""" 148 | model = load_model(1) 149 | model.set_weights(parameters) 150 | model.to(DEVICE) 151 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False) 152 | # loss = test(model, testloader, device=DEVICE) 153 | loss = 0 154 | real_num = len(testset) 155 | num_samples = 5000 156 | steps = 500 157 | eta = 1. 
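        # steps and eta mirror the sampling settings used on the clients: 500
        # reverse-diffusion steps with eta = 1. (full fresh noise added per
        # step, see utils.sample). The FID evaluation below only runs every
        # 10th round.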
158 | metrics = {} 159 | 160 | if server_round % 10 == 0: 161 | with torch.no_grad(): 162 | # Generate fake images 163 | noise = torch.randn([num_samples, 1, 32, 32], device=DEVICE) 164 | fakes_classes = torch.arange(10, device=DEVICE).repeat_interleave(500, 0) 165 | fakes = sample(model, noise, steps, eta, fakes_classes) 166 | 167 | subset = torch.utils.data.Subset(testset, random.sample(range(real_num), min(num_samples, real_num))) 168 | real_loader = torch.utils.data.DataLoader(subset, batch_size=100) 169 | fake_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(fakes, fakes_classes), batch_size=100) 170 | fid = FIDScorer().calculate_fid(real_loader, fake_loader, device=DEVICE) 171 | 172 | logger.add_scalar("fid", fid, server_round) 173 | 174 | metrics = {"fid" : float(fid)} 175 | return loss, metrics 176 | 177 | return evaluate 178 | 179 | 180 | if __name__ == "__main__": 181 | global checkpoint_path, logger 182 | # Create checkpoint directory 183 | checkpoint_path = "../checkpoints/" + time.strftime("%Y%m%d-%H%M%S") 184 | os.makedirs(f"{checkpoint_path}", exist_ok=True) 185 | # Create tensorboard writer 186 | logger = SummaryWriter() 187 | 188 | main() -------------------------------------------------------------------------------- /DDPM/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import trange 3 | from contextlib import contextmanager 4 | from tqdm import tqdm 5 | 6 | 7 | def get_alphas_sigmas(log_snrs): 8 | """Returns the scaling factors for the clean image (alpha) and for the 9 | noise (sigma), given the log SNR for a timestep.""" 10 | return log_snrs.sigmoid().sqrt(), log_snrs.neg().sigmoid().sqrt() 11 | 12 | 13 | def get_ddpm_schedule(t): 14 | """Returns log SNRs for the noise schedule from the DDPM paper.""" 15 | return -torch.special.expm1(1e-4 + 10 * t**2).log() 16 | 17 | 18 | @torch.no_grad() 19 | def sample(model, x, steps, eta, classes): 20 | """Draws samples from a model given starting noise.""" 21 | ts = x.new_ones([x.shape[0]]) 22 | 23 | # Create the noise schedule 24 | t = torch.linspace(1, 0, steps + 1)[:-1] 25 | log_snrs = get_ddpm_schedule(t) 26 | alphas, sigmas = get_alphas_sigmas(log_snrs) 27 | 28 | # The sampling loop 29 | for i in trange(steps): 30 | 31 | # Get the model output (v, the predicted velocity) 32 | with torch.cuda.amp.autocast(): 33 | v = model(x, ts * log_snrs[i], classes).float() 34 | 35 | # Predict the noise and the denoised image 36 | pred = x * alphas[i] - v * sigmas[i] 37 | eps = x * sigmas[i] + v * alphas[i] 38 | 39 | # If we are not on the last timestep, compute the noisy image for the 40 | # next timestep. 
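        # (With eta = 0. the update below reduces to the deterministic DDIM
        # step; eta = 1. adds the maximum amount of fresh noise per step.)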
41 | if i < steps - 1: 42 | # If eta > 0, adjust the scaling factor for the predicted noise 43 | # downward according to the amount of additional noise to add 44 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ 45 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() 46 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() 47 | 48 | # Recombine the predicted noise and predicted denoised image in the 49 | # correct proportions for the next step 50 | x = pred * alphas[i + 1] + eps * adjusted_sigma 51 | 52 | # Add the correct amount of fresh noise 53 | if eta: 54 | x += torch.randn_like(x) * ddim_sigma 55 | 56 | # If we are on the last timestep, output the denoised image 57 | return pred 58 | 59 | def eval_loss(model, rng, reals, classes, device): 60 | # Draw uniformly distributed continuous timesteps 61 | t = rng.draw(reals.shape[0])[:, 0].to(device) 62 | 63 | # Calculate the noise schedule parameters for those timesteps 64 | log_snrs = get_ddpm_schedule(t) 65 | alphas, sigmas = get_alphas_sigmas(log_snrs) 66 | weights = log_snrs.exp() / log_snrs.exp().add(1) 67 | 68 | # Combine the ground truth images and the noise 69 | alphas = alphas[:, None, None, None] 70 | sigmas = sigmas[:, None, None, None] 71 | noise = torch.randn_like(reals) 72 | noised_reals = reals * alphas + noise * sigmas 73 | targets = noise * alphas - reals * sigmas 74 | 75 | # Compute the model output and the loss. 76 | with torch.cuda.amp.autocast(): 77 | v = model(noised_reals, log_snrs, classes) 78 | return (v - targets).pow(2).mean([1, 2, 3]).mul(weights).mean() 79 | 80 | @contextmanager 81 | def train_mode(model, mode=True): 82 | """A context manager that places a model into training mode and restores 83 | the previous mode on exit.""" 84 | modes = [module.training for module in model.modules()] 85 | try: 86 | yield model.train(mode) 87 | finally: 88 | for i, module in enumerate(model.modules()): 89 | module.training = modes[i] 90 | 91 | 92 | def eval_mode(model): 93 | """A context manager that places a model into evaluation mode and restores 94 | the previous mode on exit.""" 95 | return train_mode(model, False) 96 | 97 | 98 | @torch.no_grad() 99 | def ema_update(model, averaged_model, decay): 100 | """Incorporates updated model parameters into an exponential moving averaged 101 | version of a model. 
It should be called after each optimizer step.""" 102 | model_params = dict(model.named_parameters()) 103 | averaged_params = dict(averaged_model.named_parameters()) 104 | assert model_params.keys() == averaged_params.keys() 105 | 106 | for name, param in model_params.items(): 107 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 108 | 109 | model_buffers = dict(model.named_buffers()) 110 | averaged_buffers = dict(averaged_model.named_buffers()) 111 | assert model_buffers.keys() == averaged_buffers.keys() 112 | 113 | for name, buf in model_buffers.items(): 114 | averaged_buffers[name].copy_(buf) 115 | 116 | 117 | 118 | def train(model, model_ema, trainloader, num_epoch, curr_epoch, rng, scaler, ema_decay, device): 119 | """Train the model on the training set.""" 120 | optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) 121 | for _ in range(num_epoch): 122 | curr_epoch += 1 123 | for reals, classes in tqdm(trainloader): 124 | optimizer.zero_grad() 125 | reals = reals.to(device) 126 | classes = classes.to(device) 127 | 128 | # Evaluate the loss 129 | loss = eval_loss(model, rng, reals, classes, device) 130 | 131 | # Do the optimizer step and EMA update 132 | scaler.scale(loss).backward() 133 | scaler.step(optimizer) 134 | ema_update(model, model_ema, 0.95 if curr_epoch < 20 else ema_decay) 135 | scaler.update() 136 | 137 | @torch.no_grad() 138 | @torch.random.fork_rng() 139 | # @eval_mode(model_ema) 140 | def test(model, testloader, device): 141 | """Validate the model on the test set.""" 142 | torch.manual_seed(42) 143 | eval_mode(model) 144 | rng = torch.quasirandom.SobolEngine(1, scramble=True) 145 | total_loss = 0 146 | count = 0 147 | for i, (reals, classes) in enumerate(tqdm(testloader)): 148 | reals = reals.to(device) 149 | classes = classes.to(device) 150 | 151 | loss = eval_loss(model, rng, reals, classes, device) 152 | 153 | total_loss += loss.item() * len(reals) 154 | count += len(reals) 155 | loss = total_loss / count 156 | train_mode(model) 157 | return loss 158 | 159 | -------------------------------------------------------------------------------- /DDPM_2/DiffusionCondition.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import numpy as np 7 | 8 | 9 | def extract(v, t, x_shape): 10 | """ 11 | Extract some coefficients at specified timesteps, then reshape to 12 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 13 | """ 14 | device = t.device 15 | out = torch.gather(v, index=t, dim=0).float().to(device) 16 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 17 | 18 | 19 | class GaussianDiffusionTrainer(nn.Module): 20 | def __init__(self, model, beta_1, beta_T, T): 21 | super().__init__() 22 | 23 | self.model = model 24 | self.T = T 25 | 26 | self.register_buffer( 27 | 'betas', torch.linspace(beta_1, beta_T, T).double()) 28 | alphas = 1. - self.betas 29 | alphas_bar = torch.cumprod(alphas, dim=0) 30 | 31 | # calculations for diffusion q(x_t | x_{t-1}) and others 32 | self.register_buffer( 33 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) 34 | self.register_buffer( 35 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) 36 | 37 | def forward(self, x_0, labels): 38 | """ 39 | Algorithm 1. 
40 | """ 41 | t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) 42 | noise = torch.randn_like(x_0) 43 | x_t = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \ 44 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise 45 | loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none') 46 | return loss 47 | 48 | 49 | class GaussianDiffusionSampler(nn.Module): 50 | def __init__(self, model, beta_1, beta_T, T, w = 0.): 51 | super().__init__() 52 | 53 | self.model = model 54 | self.T = T 55 | ### In the classifier free guidence paper, w is the key to control the gudience. 56 | ### w = 0 and with label = 0 means no guidence. 57 | ### w > 0 and label > 0 means guidence. Guidence would be stronger if w is bigger. 58 | self.w = w 59 | 60 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double()) 61 | alphas = 1. - self.betas 62 | alphas_bar = torch.cumprod(alphas, dim=0) 63 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] 64 | self.register_buffer('coeff1', torch.sqrt(1. / alphas)) 65 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar)) 66 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar)) 67 | 68 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps): 69 | assert x_t.shape == eps.shape 70 | return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps 71 | 72 | def p_mean_variance(self, x_t, t, labels): 73 | # below: only log_variance is used in the KL computations 74 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]]) 75 | var = extract(var, t, x_t.shape) 76 | eps = self.model(x_t, t, labels) 77 | nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device)) 78 | eps = (1. + self.w) * eps - self.w * nonEps 79 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps) 80 | return xt_prev_mean, var 81 | 82 | def forward(self, x_T, labels): 83 | """ 84 | Algorithm 2. 85 | """ 86 | x_t = x_T 87 | for time_step in reversed(range(self.T)): 88 | print(time_step) 89 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step 90 | mean, var= self.p_mean_variance(x_t=x_t, t=t, labels=labels) 91 | if time_step > 0: 92 | noise = torch.randn_like(x_t) 93 | else: 94 | noise = 0 95 | x_t = mean + torch.sqrt(var) * noise 96 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor." 97 | x_0 = x_t 98 | return torch.clip(x_0, -1, 1) 99 | 100 | 101 | -------------------------------------------------------------------------------- /DDPM_2/FIDScorer.py: -------------------------------------------------------------------------------- 1 | from scipy import linalg 2 | 3 | from InceptionV3 import InceptionV3 4 | import numpy as np 5 | from torch.nn.functional import adaptive_avg_pool2d 6 | import torch 7 | 8 | 9 | class FIDScorer: 10 | def __init__(self): 11 | self.model = InceptionV3() 12 | 13 | def calculate_activation_statistics(self, images, device='cpu'): 14 | 15 | model = self.model.to(device) 16 | model.eval() 17 | 18 | activations = [] 19 | for (batch, _) in images: 20 | # if len(batch.shape) < 4: 21 | # batch = torch.unsqueeze(batch, 1) 22 | 23 | # if batch.size(1) == 1: 24 | # # greyscale 25 | # batch = batch.repeat(1, 3, 1, 1) 26 | 27 | batch = batch.to(device) 28 | pred = model(batch)[0] 29 | 30 | # If model output is not scalar, apply global spatial average pooling. 31 | # This happens if you choose a dimensionality not equal 2048. 
32 | if pred.size(2) != 1 or pred.size(3) != 1: 33 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 34 | 35 | act = pred.cpu().data.numpy().reshape(pred.size(0), -1) 36 | activations.append(act) 37 | 38 | act = np.concatenate(activations, axis=0) 39 | mu = np.mean(act, axis=0) 40 | sigma = np.cov(act, rowvar=False) 41 | return mu, sigma 42 | 43 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 44 | """Numpy implementation of the Frechet Distance. 45 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 46 | and X_2 ~ N(mu_2, C_2) is 47 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 48 | """ 49 | 50 | mu1 = np.atleast_1d(mu1) 51 | mu2 = np.atleast_1d(mu2) 52 | 53 | sigma1 = np.atleast_2d(sigma1) 54 | sigma2 = np.atleast_2d(sigma2) 55 | 56 | assert mu1.shape == mu2.shape, \ 57 | 'Training and test mean vectors have different lengths' 58 | assert sigma1.shape == sigma2.shape, \ 59 | 'Training and test covariances have different dimensions' 60 | 61 | diff = mu1 - mu2 62 | 63 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 64 | if not np.isfinite(covmean).all(): 65 | msg = ('fid calculation produces singular product; ' 66 | 'adding %s to diagonal of cov estimates') % eps 67 | print(msg) 68 | offset = np.eye(sigma1.shape[0]) * eps 69 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 70 | 71 | if np.iscomplexobj(covmean): 72 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 73 | m = np.max(np.abs(covmean.imag)) 74 | raise ValueError('Imaginary component {}'.format(m)) 75 | covmean = covmean.real 76 | 77 | tr_covmean = np.trace(covmean) 78 | 79 | return (diff.dot(diff) + np.trace(sigma1) + 80 | np.trace(sigma2) - 2 * tr_covmean) 81 | 82 | def calculate_fid(self, images_real, images_fake, device): 83 | mu_1, std_1 = self.calculate_activation_statistics(images_real, device=device) 84 | mu_2, std_2 = self.calculate_activation_statistics(images_fake, device=device) 85 | 86 | """get fretched distance""" 87 | fid_value = self.calculate_frechet_distance(mu_1, std_1, mu_2, std_2) 88 | return fid_value -------------------------------------------------------------------------------- /DDPM_2/InceptionV3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | 6 | 7 | class InceptionV3(nn.Module): 8 | """Pretrained InceptionV3 network returning feature maps""" 9 | 10 | # Index of default block of inception to return, 11 | # corresponds to output of final average pooling 12 | DEFAULT_BLOCK_INDEX = 3 13 | 14 | # Maps feature dimensionality to their output blocks indices 15 | BLOCK_INDEX_BY_DIM = { 16 | 64: 0, # First max pooling features 17 | 192: 1, # Second max pooling featurs 18 | 768: 2, # Pre-aux classifier features 19 | 2048: 3 # Final average pooling features 20 | } 21 | 22 | def __init__(self, 23 | output_blocks=[DEFAULT_BLOCK_INDEX], 24 | resize_input=True, 25 | normalize_input=True, 26 | requires_grad=False): 27 | 28 | super(InceptionV3, self).__init__() 29 | 30 | self.resize_input = resize_input 31 | self.normalize_input = normalize_input 32 | self.output_blocks = sorted(output_blocks) 33 | self.last_needed_block = max(output_blocks) 34 | 35 | assert self.last_needed_block <= 3, \ 36 | 'Last possible output block index is 3' 37 | 38 | self.blocks = nn.ModuleList() 39 | 40 | inception = 
models.inception_v3(weights='Inception_V3_Weights.IMAGENET1K_V1') 41 | 42 | # Block 0: input to maxpool1 43 | block0 = [ 44 | inception.Conv2d_1a_3x3, 45 | inception.Conv2d_2a_3x3, 46 | inception.Conv2d_2b_3x3, 47 | nn.MaxPool2d(kernel_size=3, stride=2) 48 | ] 49 | self.blocks.append(nn.Sequential(*block0)) 50 | 51 | # Block 1: maxpool1 to maxpool2 52 | if self.last_needed_block >= 1: 53 | block1 = [ 54 | inception.Conv2d_3b_1x1, 55 | inception.Conv2d_4a_3x3, 56 | nn.MaxPool2d(kernel_size=3, stride=2) 57 | ] 58 | self.blocks.append(nn.Sequential(*block1)) 59 | 60 | # Block 2: maxpool2 to aux classifier 61 | if self.last_needed_block >= 2: 62 | block2 = [ 63 | inception.Mixed_5b, 64 | inception.Mixed_5c, 65 | inception.Mixed_5d, 66 | inception.Mixed_6a, 67 | inception.Mixed_6b, 68 | inception.Mixed_6c, 69 | inception.Mixed_6d, 70 | inception.Mixed_6e, 71 | ] 72 | self.blocks.append(nn.Sequential(*block2)) 73 | 74 | # Block 3: aux classifier to final avgpool 75 | if self.last_needed_block >= 3: 76 | block3 = [ 77 | inception.Mixed_7a, 78 | inception.Mixed_7b, 79 | inception.Mixed_7c, 80 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 81 | ] 82 | self.blocks.append(nn.Sequential(*block3)) 83 | 84 | for param in self.parameters(): 85 | param.requires_grad = requires_grad 86 | 87 | def forward(self, inp): 88 | """Get Inception feature maps 89 | Parameters 90 | ---------- 91 | inp : torch.autograd.Variable 92 | Input tensor of shape Bx3xHxW. Values are expected to be in 93 | range (0, 1) 94 | Returns 95 | ------- 96 | List of torch.autograd.Variable, corresponding to the selected output 97 | block, sorted ascending by index 98 | """ 99 | outp = [] 100 | x = inp 101 | 102 | if self.resize_input: 103 | x = F.interpolate(x, 104 | size=(299, 299), 105 | mode='bilinear', 106 | align_corners=False) 107 | 108 | if self.normalize_input: 109 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 110 | 111 | for idx, block in enumerate(self.blocks): 112 | x = block(x) 113 | if idx in self.output_blocks: 114 | outp.append(x) 115 | 116 | if idx == self.last_needed_block: 117 | break 118 | 119 | return outp -------------------------------------------------------------------------------- /DDPM_2/Scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | class GradualWarmupScheduler(_LRScheduler): 4 | def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None): 5 | self.multiplier = multiplier 6 | self.total_epoch = warm_epoch 7 | self.after_scheduler = after_scheduler 8 | self.finished = False 9 | self.last_epoch = None 10 | self.base_lrs = None 11 | super().__init__(optimizer) 12 | 13 | def get_lr(self): 14 | if self.last_epoch > self.total_epoch: 15 | if self.after_scheduler: 16 | if not self.finished: 17 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 18 | self.finished = True 19 | return self.after_scheduler.get_lr() 20 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 21 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 22 | 23 | 24 | def step(self, epoch=None, metrics=None): 25 | if self.finished and self.after_scheduler: 26 | if epoch is None: 27 | self.after_scheduler.step(None) 28 | else: 29 | self.after_scheduler.step(epoch - self.total_epoch) 30 | else: 31 | return super(GradualWarmupScheduler, self).step(epoch) -------------------------------------------------------------------------------- /DDPM_2/TrainCondition.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | from typing import Dict 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | from torchvision.datasets import CIFAR10 13 | from torchvision.utils import save_image 14 | 15 | from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer 16 | from DiffusionFreeGuidence.ModelCondition import UNet 17 | from Scheduler import GradualWarmupScheduler 18 | 19 | 20 | def train(modelConfig: Dict): 21 | device = torch.device(modelConfig["device"]) 22 | # dataset 23 | dataset = CIFAR10( 24 | root='./CIFAR10', train=True, download=True, 25 | transform=transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 28 | ])) 29 | dataloader = DataLoader( 30 | dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True) 31 | 32 | # model setup 33 | net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], 34 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device) 35 | if modelConfig["training_load_weight"] is not None: 36 | net_model.load_state_dict(torch.load(os.path.join( 37 | modelConfig["save_dir"], modelConfig["training_load_weight"]), map_location=device), strict=False) 38 | print("Model weight load down.") 39 | optimizer = torch.optim.AdamW( 40 | net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4) 41 | cosineScheduler = optim.lr_scheduler.CosineAnnealingLR( 42 | optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1) 43 | warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"], 44 | warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler) 45 | trainer = GaussianDiffusionTrainer( 46 | net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device) 47 | 48 | # start training 49 | for e in range(modelConfig["epoch"]): 50 | with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader: 51 | for images, labels in tqdmDataLoader: 52 | # train 53 | b = images.shape[0] 54 | optimizer.zero_grad() 55 | x_0 = images.to(device) 56 | labels = labels.to(device) + 1 57 | if np.random.rand() < 0.1: 58 | labels = torch.zeros_like(labels).to(device) 59 | loss = trainer(x_0, labels).sum() / b ** 2. 
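# Classifier-free guidance training: class ids were shifted by +1 above so that index 0
# is reserved as the unconditional ("null") label, and with probability 0.1 the whole
# batch is trained with that null label. The summed per-pixel loss is then divided by
# b ** 2, i.e. by the square of the batch size.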
60 | loss.backward() 61 | torch.nn.utils.clip_grad_norm_( 62 | net_model.parameters(), modelConfig["grad_clip"]) 63 | optimizer.step() 64 | tqdmDataLoader.set_postfix(ordered_dict={ 65 | "epoch": e, 66 | "loss: ": loss.item(), 67 | "img shape: ": x_0.shape, 68 | "LR": optimizer.state_dict()['param_groups'][0]["lr"] 69 | }) 70 | warmUpScheduler.step() 71 | torch.save(net_model.state_dict(), os.path.join( 72 | modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt")) 73 | 74 | 75 | def eval(modelConfig: Dict): 76 | device = torch.device(modelConfig["device"]) 77 | # load model and evaluate 78 | with torch.no_grad(): 79 | step = int(modelConfig["batch_size"] // 10) 80 | labelList = [] 81 | k = 0 82 | for i in range(1, modelConfig["batch_size"] + 1): 83 | labelList.append(torch.ones(size=[1]).long() * k) 84 | if i % step == 0: 85 | if k < 10 - 1: 86 | k += 1 87 | labels = torch.cat(labelList, dim=0).long().to(device) + 1 88 | print("labels: ", labels) 89 | model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], 90 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device) 91 | ckpt = torch.load(os.path.join( 92 | modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device) 93 | model.load_state_dict(ckpt) 94 | print("model load weight done.") 95 | model.eval() 96 | sampler = GaussianDiffusionSampler( 97 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device) 98 | # Sampled from standard normal distribution 99 | noisyImage = torch.randn( 100 | size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device) 101 | saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1) 102 | save_image(saveNoisy, os.path.join( 103 | modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"]) 104 | sampledImgs = sampler(noisyImage, labels) 105 | sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1] 106 | print(sampledImgs) 107 | save_image(sampledImgs, os.path.join( 108 | modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"]) -------------------------------------------------------------------------------- /DDPM_2/client.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import timeit 3 | import warnings 4 | from typing import Dict 5 | 6 | import torch 7 | import torchvision 8 | import torch.optim as optim 9 | 10 | import flwr as fl 11 | from flwr.common import ( 12 | EvaluateIns, 13 | EvaluateRes, 14 | FitIns, 15 | FitRes, 16 | NDArrays, 17 | GetParametersRes, 18 | GetParametersIns, 19 | Status, 20 | Code, 21 | ) 22 | 23 | from data_utils import load_data 24 | from utils import train 25 | from DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer 26 | from model import load_model 27 | from Scheduler import GradualWarmupScheduler 28 | from config import modelConfig 29 | 30 | warnings.filterwarnings("ignore", category=UserWarning) 31 | # DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | 33 | class DiffusionClient(fl.client.Client): 34 | """Flower client implementing CIFAR-10 image classification using PyTorch.""" 35 | 36 | def __init__( 37 | self, 38 | cid: str, 39 | device: int, 40 | trainset: torchvision.datasets.folder.ImageFolder, 41 | testset: torchvision.datasets.folder.ImageFolder, 42 | ) -> None: 43 | self.cid = cid 44 | self.device = torch.device(f"cuda:{device}" 
if torch.cuda.is_available() else "cpu") 45 | 46 | # Initialize model 47 | self.model = load_model(modelConfig).to(self.device) 48 | 49 | # Set data 50 | self.trainset = trainset 51 | self.testset = testset 52 | 53 | # Training settings 54 | self.epoch = 0 55 | self.grad_clip = modelConfig["grad_clip"] 56 | 57 | self.optimizer = torch.optim.AdamW( 58 | self.model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4) 59 | 60 | self.cosineScheduler = optim.lr_scheduler.CosineAnnealingLR( 61 | optimizer= self.optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1) 62 | 63 | self.warmUpScheduler = GradualWarmupScheduler(optimizer=self.optimizer, multiplier=modelConfig["multiplier"], 64 | warm_epoch=modelConfig["epoch"] // 10, after_scheduler=self.cosineScheduler) 65 | 66 | self.trainer = GaussianDiffusionTrainer( 67 | self.model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(self.device) 68 | 69 | # self.evaluater = GaussianDiffusionTrainer() 70 | 71 | 72 | def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: 73 | print(f"Client {self.cid}: get_parameters") 74 | 75 | weights: NDArrays = self.model.get_weights() 76 | parameters = fl.common.ndarrays_to_parameters(weights) 77 | return GetParametersRes(status=Status(code=Code.OK, message="Success"), 78 | parameters=parameters) 79 | 80 | def fit(self, ins: FitIns) -> FitRes: 81 | print(f"Client {self.cid}: fit") 82 | 83 | weights: NDArrays = fl.common.parameters_to_ndarrays(ins.parameters) 84 | config = ins.config 85 | fit_begin = timeit.default_timer() 86 | 87 | # Get training config 88 | epochs = int(config["epochs"]) 89 | batch_size = int(config["batch_size"]) 90 | 91 | # Set model parameters 92 | self.model.set_weights(weights) 93 | 94 | # Train model 95 | trainloader = torch.utils.data.DataLoader( 96 | self.trainset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True) 97 | 98 | train(self.model, self.trainer, self.optimizer, self.warmUpScheduler, self.grad_clip, trainloader, epochs, self.device) 99 | 100 | # Return the refined weights and the number of examples used for training 101 | weights_prime: NDArrays = self.model.get_weights() 102 | params_prime = fl.common.ndarrays_to_parameters(weights_prime) 103 | num_examples_train = len(self.trainset) 104 | metrics = {"duration": timeit.default_timer() - fit_begin} 105 | return FitRes( 106 | parameters=params_prime, 107 | num_examples=num_examples_train, 108 | metrics=metrics, 109 | status=Status(code=Code.OK, message="Success"), 110 | ) 111 | 112 | def evaluate(self, ins: EvaluateIns) -> EvaluateRes: 113 | print(f"Client {self.cid}: evaluate") 114 | 115 | weights = fl.common.parameters_to_ndarrays(ins.parameters) 116 | 117 | # Use provided weights to update the local model 118 | # self.model.set_weights(weights) 119 | 120 | # Evaluate the updated model on the local dataset 121 | # testloader = torch.utils.data.DataLoader( 122 | # self.testset, batch_size=100, shuffle=False 123 | # ) 124 | # loss = (eval_mode(self.model_ema))(test(self.model_ema, testloader, device=DEVICE)) 125 | loss = 0 126 | metrics = {} 127 | # Return the number of evaluation examples and the evaluation result (loss) 128 | return EvaluateRes( 129 | loss=loss, num_examples=len(self.testset), metrics=metrics, status=Status(code=Code.OK, message="Success") 130 | ) 131 | 132 | def main() -> None: 133 | """Load data, create and start CifarClient.""" 134 | parser = argparse.ArgumentParser(description="Flower") 135 | parser.add_argument( 136 | 
"--server_address", 137 | type=str, 138 | default="127.0.0.1:8080", 139 | help=f"gRPC server address (default: 127.0.0.1:8080)", 140 | ) 141 | parser.add_argument( 142 | "--cid", type=str, required=True, help="Client CID (no default)" 143 | ) 144 | parser.add_argument( 145 | "--dataset-path", type=str, required=True, help="Path to dataset (no default)" 146 | ) 147 | parser.add_argument( 148 | "--log_host", 149 | type=str, 150 | help="Logserver address (no default)", 151 | ) 152 | parser.add_argument( 153 | "--device", type=int, default=0, help="Device (default: 0)" 154 | ) 155 | args = parser.parse_args() 156 | 157 | # Configure logger 158 | fl.common.logger.configure(f"client_{args.cid}", host=args.log_host) 159 | 160 | # Load data 161 | trainset, testset = load_data(args.dataset_path, args.cid) 162 | 163 | # Start client 164 | client = DiffusionClient(args.cid, args.device, trainset, testset) 165 | fl.client.start_client( 166 | server_address="127.0.0.1:8080", 167 | client=client) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() -------------------------------------------------------------------------------- /DDPM_2/config.py: -------------------------------------------------------------------------------- 1 | 2 | modelConfig = { 3 | "state": "train", # or eval 4 | "epoch": 70, 5 | "batch_size": 80, 6 | "T": 500, 7 | "channel": 128, 8 | "channel_mult": [1, 2, 2, 2], 9 | "num_res_blocks": 2, 10 | "dropout": 0.15, 11 | "lr": 1e-4, 12 | "multiplier": 2.5, 13 | "beta_1": 1e-4, 14 | "beta_T": 0.028, 15 | "img_size": 32, 16 | "grad_clip": 1., 17 | "device": "cuda:0", 18 | "w": 1.8, 19 | "save_dir": "./CheckpointsCondition/", 20 | "training_load_weight": None, 21 | "test_load_weight": "ckpt_63_.pt", 22 | "sampled_dir": "./SampledImgs/", 23 | "sampledNoisyImgName": "NoisyGuidenceImgs.png", 24 | "sampledImgName": "SampledGuidenceImgs.png", 25 | "nrow": 8 26 | } -------------------------------------------------------------------------------- /DDPM_2/data_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | from torchvision.transforms import Compose, Normalize, ToTensor 3 | 4 | def get_mean_std(dataset_id): 5 | """Get mean and std for normalization.""" 6 | if (dataset_id == "cifar10"): 7 | return [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616] 8 | elif (dataset_id == "cifar100"): 9 | return [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761] 10 | elif (dataset_id == "cinic10"): 11 | return [0.47889522, 0.47227842, 0.43047404], [0.24205776, 0.23828046, 0.25874835] 12 | 13 | def load_data(path, client_id): 14 | """Load training and test set.""" 15 | mean, std = get_mean_std("cinic10") 16 | transform = Compose([ToTensor(), Normalize(mean, std)]) 17 | trainset = ImageFolder(path + "/train/client_" + str(client_id), transform=transform) 18 | testset = ImageFolder(path + "/test", transform=transform) 19 | return trainset, testset -------------------------------------------------------------------------------- /DDPM_2/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from telnetlib import PRAGMA_HEARTBEAT 3 | import torch 4 | from torch import nn 5 | from torch.nn import init 6 | from torch.nn import functional as F 7 | import flwr as fl 8 | from collections import OrderedDict 9 | 10 | def drop_connect(x, drop_ratio): 11 | keep_ratio = 1.0 - drop_ratio 12 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 13 | 
mask.bernoulli_(p=keep_ratio) 14 | x.div_(keep_ratio) 15 | x.mul_(mask) 16 | return x 17 | 18 | class Swish(nn.Module): 19 | def forward(self, x): 20 | return x * torch.sigmoid(x) 21 | 22 | 23 | class TimeEmbedding(nn.Module): 24 | def __init__(self, T, d_model, dim): 25 | assert d_model % 2 == 0 26 | super().__init__() 27 | emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) 28 | emb = torch.exp(-emb) 29 | pos = torch.arange(T).float() 30 | emb = pos[:, None] * emb[None, :] 31 | assert list(emb.shape) == [T, d_model // 2] 32 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) 33 | assert list(emb.shape) == [T, d_model // 2, 2] 34 | emb = emb.view(T, d_model) 35 | 36 | self.timembedding = nn.Sequential( 37 | nn.Embedding.from_pretrained(emb, freeze=False), 38 | nn.Linear(d_model, dim), 39 | Swish(), 40 | nn.Linear(dim, dim), 41 | ) 42 | 43 | def forward(self, t): 44 | emb = self.timembedding(t) 45 | return emb 46 | 47 | 48 | class ConditionalEmbedding(nn.Module): 49 | def __init__(self, num_labels, d_model, dim): 50 | assert d_model % 2 == 0 51 | super().__init__() 52 | self.condEmbedding = nn.Sequential( 53 | nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0), 54 | nn.Linear(d_model, dim), 55 | Swish(), 56 | nn.Linear(dim, dim), 57 | ) 58 | 59 | def forward(self, t): 60 | emb = self.condEmbedding(t) 61 | return emb 62 | 63 | 64 | class DownSample(nn.Module): 65 | def __init__(self, in_ch): 66 | super().__init__() 67 | self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) 68 | self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2) 69 | 70 | def forward(self, x, temb, cemb): 71 | x = self.c1(x) + self.c2(x) 72 | return x 73 | 74 | 75 | class UpSample(nn.Module): 76 | def __init__(self, in_ch): 77 | super().__init__() 78 | self.c = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) 79 | self.t = nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1) 80 | 81 | def forward(self, x, temb, cemb): 82 | _, _, H, W = x.shape 83 | x = self.t(x) 84 | x = self.c(x) 85 | return x 86 | 87 | 88 | class AttnBlock(nn.Module): 89 | def __init__(self, in_ch): 90 | super().__init__() 91 | self.group_norm = nn.GroupNorm(32, in_ch) 92 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 93 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 94 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 95 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 96 | 97 | def forward(self, x): 98 | B, C, H, W = x.shape 99 | h = self.group_norm(x) 100 | q = self.proj_q(h) 101 | k = self.proj_k(h) 102 | v = self.proj_v(h) 103 | 104 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 105 | k = k.view(B, C, H * W) 106 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 107 | assert list(w.shape) == [B, H * W, H * W] 108 | w = F.softmax(w, dim=-1) 109 | 110 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 111 | h = torch.bmm(w, v) 112 | assert list(h.shape) == [B, H * W, C] 113 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 114 | h = self.proj(h) 115 | 116 | return x + h 117 | 118 | 119 | 120 | class ResBlock(nn.Module): 121 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=True): 122 | super().__init__() 123 | self.block1 = nn.Sequential( 124 | nn.GroupNorm(32, in_ch), 125 | Swish(), 126 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 127 | ) 128 | self.temb_proj = nn.Sequential( 129 | Swish(), 130 | nn.Linear(tdim, out_ch), 131 | ) 132 | self.cond_proj = nn.Sequential( 133 | Swish(), 134 | nn.Linear(tdim, out_ch), 135 | ) 136 | self.block2 = 
nn.Sequential( 137 | nn.GroupNorm(32, out_ch), 138 | Swish(), 139 | nn.Dropout(dropout), 140 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 141 | ) 142 | if in_ch != out_ch: 143 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 144 | else: 145 | self.shortcut = nn.Identity() 146 | if attn: 147 | self.attn = AttnBlock(out_ch) 148 | else: 149 | self.attn = nn.Identity() 150 | 151 | 152 | def forward(self, x, temb, labels): 153 | h = self.block1(x) 154 | h += self.temb_proj(temb)[:, :, None, None] 155 | h += self.cond_proj(labels)[:, :, None, None] 156 | h = self.block2(h) 157 | 158 | h = h + self.shortcut(x) 159 | h = self.attn(h) 160 | return h 161 | 162 | 163 | class UNet(nn.Module): 164 | def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout): 165 | super().__init__() 166 | tdim = ch * 4 167 | self.time_embedding = TimeEmbedding(T, ch, tdim) 168 | self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim) 169 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) 170 | self.downblocks = nn.ModuleList() 171 | chs = [ch] # record output channel when dowmsample for upsample 172 | now_ch = ch 173 | for i, mult in enumerate(ch_mult): 174 | out_ch = ch * mult 175 | for _ in range(num_res_blocks): 176 | self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout)) 177 | now_ch = out_ch 178 | chs.append(now_ch) 179 | if i != len(ch_mult) - 1: 180 | self.downblocks.append(DownSample(now_ch)) 181 | chs.append(now_ch) 182 | 183 | self.middleblocks = nn.ModuleList([ 184 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True), 185 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False), 186 | ]) 187 | 188 | self.upblocks = nn.ModuleList() 189 | for i, mult in reversed(list(enumerate(ch_mult))): 190 | out_ch = ch * mult 191 | for _ in range(num_res_blocks + 1): 192 | self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False)) 193 | now_ch = out_ch 194 | if i != 0: 195 | self.upblocks.append(UpSample(now_ch)) 196 | assert len(chs) == 0 197 | 198 | self.tail = nn.Sequential( 199 | nn.GroupNorm(32, now_ch), 200 | Swish(), 201 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) 202 | ) 203 | 204 | 205 | def forward(self, x, t, labels): 206 | # Timestep embedding 207 | temb = self.time_embedding(t) 208 | cemb = self.cond_embedding(labels) 209 | # Downsampling 210 | h = self.head(x) 211 | hs = [h] 212 | for layer in self.downblocks: 213 | h = layer(h, temb, cemb) 214 | hs.append(h) 215 | # Middle 216 | for layer in self.middleblocks: 217 | h = layer(h, temb, cemb) 218 | # Upsampling 219 | for layer in self.upblocks: 220 | if isinstance(layer, ResBlock): 221 | h = torch.cat([h, hs.pop()], dim=1) 222 | h = layer(h, temb, cemb) 223 | h = self.tail(h) 224 | 225 | assert len(hs) == 0 226 | return h 227 | 228 | def get_weights(self) -> fl.common.NDArray: 229 | """Get model weights as a list of NumPy ndarrays.""" 230 | return [val.cpu().numpy() for _, val in self.state_dict().items()] 231 | 232 | def set_weights(self, weights: fl.common.NDArray) -> None: 233 | """Set model weights from a list of NumPy ndarrays.""" 234 | state_dict = OrderedDict( 235 | {k: torch.tensor(v) for k, v in zip(self.state_dict().keys(), weights)} 236 | ) 237 | self.load_state_dict(state_dict, strict=True) 238 | 239 | def load_model(modelConfig) -> UNet: 240 | """Load diffusion model.""" 241 | return UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], 242 | 
num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]) -------------------------------------------------------------------------------- /DDPM_2/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################ 4 | # Help # 5 | ############################################################ 6 | Help() 7 | { 8 | # Display Help 9 | echo "Flower Federated Learning parameters:" 10 | echo 11 | echo "Syntax: run.sh [-c|r|e|d|s|h]" 12 | echo "options:" 13 | echo "c Number of clients." 14 | echo "r Number of rounds." 15 | echo "e Number of epochs." 16 | echo "d Training dataset path." 17 | echo "s Server Address." 18 | echo "h Help message." 19 | echo 20 | } 21 | 22 | ############################################################ 23 | ############################################################ 24 | # Main program # 25 | ############################################################ 26 | ############################################################ 27 | 28 | # Default values 29 | server_address="[::]:8080" 30 | num_clients=2 31 | data_path="C:\Users\ColinLaganier\Documents\UCL\Dissertation\Testing\data\cinic-10\federated\5" 32 | num_epochs=1 33 | dev=0 34 | device_count=(0 0 0 0) 35 | 36 | # Get the options 37 | while getopts c:r:s:d:e:h flag 38 | do 39 | case "${flag}" in 40 | c) num_clients=${OPTARG};; 41 | r) num_rounds=${OPTARG};; 42 | s) server_address=${OPTARG};; 43 | d) data_path=${OPTARG};; 44 | e) num_epochs=${OPTARG};; 45 | h) Help 46 | exit;; 47 | esac 48 | done 49 | 50 | set -e 51 | 52 | python server.py --dataset-path $data_path --num-clients $num_clients --rounds $num_rounds --epochs $num_epochs --device $dev & 53 | sleep 3 # Sleep for 3s to give the server enough time to start 54 | # Increment device count 55 | ((device_count[$dev]++)) 56 | ((dev++)) 57 | 58 | echo "Starting $num_clients clients." 59 | for ((i = 0; i < $num_clients; i++)) 60 | do 61 | if [[ $((device_count[$dev])) == 2 ]] 62 | then 63 | ((dev++)) 64 | fi 65 | if [[ $dev == 4 ]] 66 | then 67 | dev=0 68 | fi 69 | echo "Starting client $i" 70 | python client.py \ 71 | --dataset-path $data_path \ 72 | --cid $i \ 73 | --device $dev & 74 | # --server_address=$SERVER_ADDRESS & 75 | 76 | ((device_count[$dev]++)) 77 | ((dev++)) 78 | done 79 | echo "Started $num_clients clients." 
80 | 81 | # Enable CTRL+C to stop all background processes 82 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM 83 | # Wait for all background processes to complete 84 | wait 85 | sh run.sh c /home/ec2-user/FedKDD/dataset/cinic-10/5 86 | sh run.sh -c 2 -r 1 -e 1 -d /home/ec2-user/FedKDD/dataset/cinic-10/5 -------------------------------------------------------------------------------- /DDPM_2/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Callable, Dict, List, Optional, Tuple, Union 3 | from flwr.common import (Parameters, Scalar) 4 | 5 | import os 6 | import time 7 | import torch 8 | import torchvision 9 | import flwr as fl 10 | import random 11 | from collections import OrderedDict 12 | import numpy as np 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from config import modelConfig 16 | from DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer 17 | from FIDScorer import FIDScorer 18 | from data_utils import load_data 19 | from model import load_model 20 | 21 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | class SaveModelStrategy(fl.server.strategy.FedAvg): 24 | """Federated Averaging strategy with save model functionality.""" 25 | def aggregate_fit( 26 | self, 27 | server_round: int, 28 | results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]], 29 | failures: List[Union[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes], BaseException]], 30 | ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: 31 | 32 | # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics 33 | aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures) 34 | 35 | # Save the model 36 | if aggregated_parameters is not None: 37 | model = load_model(modelConfig) 38 | aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters) 39 | params_dict = zip(model.state_dict().keys(), aggregated_ndarrays) 40 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 41 | model.load_state_dict(state_dict, strict=True) 42 | torch.save(model.state_dict(), checkpoint_path + f"/model.pth") 43 | 44 | return aggregated_parameters, aggregated_metrics 45 | 46 | 47 | def main() -> None: 48 | """Start server and train five rounds.""" 49 | global num_epochs 50 | parser = argparse.ArgumentParser(description="Flower") 51 | parser.add_argument( 52 | "--server_address", 53 | type=str, 54 | default="0.0.0.0:8080", 55 | help=f"gRPC server address (default: 0.0.0.0:8080)", 56 | ) 57 | parser.add_argument( 58 | "--rounds", 59 | type=int, 60 | default=1, 61 | help="Number of rounds of federated learning (default: 1)", 62 | ) 63 | parser.add_argument( 64 | "--sample_fraction", 65 | type=float, 66 | default=1.0, 67 | help="Fraction of available clients used for fit/evaluate (default: 1.0)", 68 | ) 69 | parser.add_argument( 70 | "--min_sample_size", 71 | type=int, 72 | default=1, 73 | help="Minimum number of clients used for fit/evaluate (default: 2)", 74 | ) 75 | parser.add_argument( 76 | "--min_num_clients", 77 | type=int, 78 | default=1, 79 | help="Minimum number of available clients required for sampling (default: 2)", 80 | ) 81 | parser.add_argument( 82 | "--log_host", 83 | type=str, 84 | help="Logserver address (no default)", 85 | ) 86 | parser.add_argument( 87 | "--dataset-path", type=str, required=True, help="Path to dataset (no default)" 
88 | ) 89 | parser.add_argument( 90 | "--num-clients", type=int, required=True, help="Number of clients (no default)" 91 | ) 92 | parser.add_argument( 93 | "--epochs", type=int, default=1, help="Number of epochs (default: 1)", 94 | ) 95 | parser.add_argument( 96 | "--device", type=int, default=0, help="Device (default: 0)" 97 | ) 98 | args = parser.parse_args() 99 | num_epochs = args.epochs 100 | device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") 101 | 102 | # Load evaluation data 103 | _, testset = load_data(args.dataset_path, 0) 104 | 105 | # Create strategy 106 | strategy = SaveModelStrategy( 107 | fraction_fit=args.sample_fraction, 108 | min_fit_clients=args.num_clients, 109 | min_available_clients=args.num_clients, 110 | evaluate_fn=get_evaluate_fn(testset, device), 111 | on_fit_config_fn=fit_config, 112 | ) 113 | 114 | # Configure logger and start server 115 | fl.common.logger.configure("server", host=args.log_host) 116 | fl.server.start_server( 117 | server_address=args.server_address, 118 | config=fl.server.ServerConfig(num_rounds=args.rounds), 119 | strategy=strategy) 120 | 121 | 122 | def fit_config(server_round: int) -> Dict[str, fl.common.Scalar]: 123 | """Return a configuration with static batch size and (local) epochs.""" 124 | config: Dict[str, fl.common.Scalar] = { 125 | "epoch_global": str(server_round), 126 | "epochs": str(num_epochs), 127 | "batch_size": str(80), 128 | } 129 | return config 130 | 131 | def get_evaluate_fn( 132 | testset: torchvision.datasets.folder.ImageFolder, 133 | device: torch.device, 134 | ) -> Callable[[fl.common.NDArray], Optional[Tuple[float, float]]]: 135 | """Return an evaluation function for centralized evaluation.""" 136 | 137 | def evaluate(server_round, weights: fl.common.NDArray, config) -> Optional[Tuple[float, float]]: 138 | """Use the entire CIFAR-10 test set for evaluation.""" 139 | # Load model and set weights 140 | model = load_model(modelConfig) 141 | model.set_weights(weights) 142 | model.to(device) 143 | model.eval() 144 | 145 | loss = 0 146 | real_num = len(testset) 147 | num_samples = 1000 148 | num_batches = 5 149 | batch_size = num_samples // num_batches 150 | if server_round % 5 == 0: 151 | with torch.no_grad(): 152 | # Generate fake labels 153 | 154 | # Store fake images generated 155 | fakes = [] 156 | fakes_classes = torch.arange(1,11).repeat_interleave(num_samples // 10, 0).to(device) 157 | 158 | for idx in range(num_batches): 159 | sampler = GaussianDiffusionSampler( 160 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device) 161 | # Sampled from standard normal distribution 162 | noise = torch.randn( 163 | size=[batch_size, 3, modelConfig["img_size"], modelConfig["img_size"]], device=device) 164 | fakes_batch = sampler(noise, fakes_classes[idx * batch_size : (idx + 1) * batch_size]) 165 | fakes_batch = fakes_batch * 0.5 + 0.5 # [0 ~ 1] 166 | fakes.append(fakes_batch) 167 | 168 | fakes = torch.cat(fakes, dim=0) 169 | print(fakes.shape) 170 | print(fakes_classes.shape) 171 | 172 | subset = torch.utils.data.Subset(testset, random.sample(range(real_num), min(num_samples, real_num))) 173 | real_loader = torch.utils.data.DataLoader(subset, batch_size=100) 174 | fake_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(fakes, fakes_classes), batch_size=100) 175 | fid = FIDScorer().calculate_fid(real_loader, fake_loader, device=device) 176 | logger.add_scalar("fid", fid, server_round) 177 | 178 | metrics = {"fid" : float(fid)} 179 | 
else: 180 | metrics = {} 181 | 182 | return loss, metrics 183 | 184 | return evaluate 185 | 186 | 187 | if __name__ == "__main__": 188 | global checkpoint_path, logger 189 | # Create checkpoint directory 190 | checkpoint_path = "../checkpoints/" + time.strftime("%Y%m%d-%H%M%S") 191 | os.makedirs(f"{checkpoint_path}", exist_ok=True) 192 | # Create tensorboard writer 193 | logger = SummaryWriter() 194 | 195 | main() -------------------------------------------------------------------------------- /DDPM_2/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import trange 3 | from contextlib import contextmanager 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | def train(model, trainer, optimizer, warmUpScheduler, grad_clip, trainloader, num_epoch, device): 8 | """Train the model on the training set.""" 9 | # optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) 10 | for epoch in range(num_epoch): 11 | with tqdm(trainloader, dynamic_ncols=True) as tqdmDataLoader: 12 | for images, labels in tqdmDataLoader: 13 | # train 14 | b = images.shape[0] 15 | optimizer.zero_grad() 16 | images, labels = images.to(device), labels.to(device) + 1 17 | if np.random.rand() < 0.1: 18 | labels = torch.zeros_like(labels).to(device) 19 | loss = trainer(images, labels).sum() / b ** 2. 20 | loss.backward() 21 | torch.nn.utils.clip_grad_norm_( 22 | model.parameters(), grad_clip) 23 | optimizer.step() 24 | tqdmDataLoader.set_postfix(ordered_dict={ 25 | "epoch": epoch, 26 | "loss: ": loss.item(), 27 | "img shape: ": images.shape, 28 | "LR": optimizer.state_dict()['param_groups'][0]["lr"] 29 | }) 30 | warmUpScheduler.step() 31 | # Save checkpoints 32 | # torch.save(net_model.state_dict(), os.path.join( 33 | # modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt")) 34 | 35 | # def eval(modelConfig: Dict, model): 36 | # device = torch.device(device) 37 | # model.eval() 38 | 39 | # with torch.no_grad(): 40 | # # Generate fake labels 41 | # labels = torch.arange(1,11).repeat_interleave(10, 0) 42 | 43 | # sampler = GaussianDiffusionSampler( 44 | # model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device) 45 | # # Sampled from standard normal distribution 46 | # noisyImage = torch.randn( 47 | # size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device) 48 | # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1) 49 | # # save_image(saveNoisy, os.path.join( 50 | # # modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"]) 51 | # sampledImgs = sampler(noisyImage, labels) 52 | # sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1] 53 | # print(sampledImgs) 54 | # # save_image(sampledImgs, os.path.join( 55 | # # modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"]) 56 | 57 | 58 | 59 | 60 | # loss = 0 61 | # real_num = len(testset) 62 | # num_samples = 1000 63 | # steps = 500 64 | # eta = 1. 
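# (A similar centralized FID evaluation is implemented in server.py's get_evaluate_fn.)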
65 | 66 | # # Generate fake images 67 | # noise = torch.randn([num_samples, 3, 32, 32], device=DEVICE) 68 | # fakes_classes = torch.arange(10, device=DEVICE).repeat_interleave(100, 0) 69 | # fakes = sample(model, noise, steps, eta, fakes_classes) 70 | 71 | # subset = torch.utils.data.Subset(testset, random.sample(range(real_num), min(num_samples, real_num))) 72 | # real_loader = torch.utils.data.DataLoader(subset, batch_size=100) 73 | # fake_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(fakes, fakes_classes), batch_size=100) 74 | # fid = FIDScorer().calculate_fid(real_loader, fake_loader, device=DEVICE) 75 | # logger.add_scalar("fid", fid, server_round) 76 | 77 | # metrics = {"fid" : float(fid)} 78 | # # return loss, metrics 79 | 80 | # with tqdm(testloader, dynamic_ncols=True) as tqdmDataLoader: 81 | # for images, labels in tqdmDataLoader: 82 | 83 | -------------------------------------------------------------------------------- /DiT/README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Model 2 | 3 | The proposed implementation relies on DiT (Scalable Diffusion Models with Transformers) to create a diffusion model for the clients to generate data. -------------------------------------------------------------------------------- /DiT/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /DiT/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy 
as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
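Note: each pixel's probability mass is the Gaussian CDF integrated over its bin
[x - 1/255, x + 1/255]; the leftmost (x < -0.999) and rightmost (x > 0.999) bins
are extended to -inf and +inf respectively to absorb the tails.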
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /DiT/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /DiT/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 
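# Each timestep keeps a fixed-size FIFO of its last `history_per_term` losses:
# once full, the oldest entry is dropped and the newest appended.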
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Colin Laganier 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FederatedDiffusionModels 2 | 3 | Federated Learning of diffusion models using the Flower framework. Currently supports training of DDPM on CINIC-10 dataset using FedAvg strategy. 4 | 5 | TODO: 6 | - [ ] Add FedProx strategy and compare results 7 | - [ ] Add DiT model training 8 | 9 | -------------------------------------------------------------------------------- /U-ViT/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Fan Bao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /U-ViT/client.py: -------------------------------------------------------------------------------- 1 | # Flower client, adapted from Pytorch quickstart example 2 | class FlowerClient(fl.client.NumPyClient): 3 | def __init__(self, cid: str, fed_dir_data: str): 4 | self.cid = cid 5 | self.fed_dir = Path(fed_dir_data) 6 | self.properties: Dict[str, Scalar] = {"tensor_type": "numpy.ndarray"} 7 | 8 | # Instantiate model 9 | self.net = Net() 10 | 11 | # Determine device 12 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | 14 | def get_parameters(self, config): 15 | return get_params(self.net) 16 | 17 | def fit(self, parameters, config): 18 | set_params(self.net, parameters) 19 | 20 | # Load data for this client and get trainloader 21 | num_workers = int(ray.get_runtime_context().get_assigned_resources()["CPU"]) 22 | trainloader = get_dataloader( 23 | self.fed_dir, 24 | self.cid, 25 | is_train=True, 26 | batch_size=config["batch_size"], 27 | workers=num_workers, 28 | ) 29 | 30 | # Send model to device 31 | self.net.to(self.device) 32 | 33 | # Train 34 | train(self.net, trainloader, epochs=config["epochs"], device=self.device) 35 | 36 | # Return local model and statistics 37 | return get_params(self.net), len(trainloader.dataset), {} 38 | 39 | def evaluate(self, parameters, config): 40 | set_params(self.net, parameters) 41 | 42 | # Load data for this client and get trainloader 43 | num_workers = int(ray.get_runtime_context().get_assigned_resources()["CPU"]) 44 | valloader = get_dataloader( 45 | self.fed_dir, self.cid, is_train=False, batch_size=50, workers=num_workers 46 | ) 47 | 48 | # Send model to device 49 | self.net.to(self.device) 50 | 51 | # Evaluate 52 | loss, accuracy = test(self.net, valloader, device=self.device) 53 | 54 | # Return statistics 55 | return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} -------------------------------------------------------------------------------- /U-ViT/configs/celeba64_uvit_small.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=500000, 17 | batch_size=128, 18 | mode='uncond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0002, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.99), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=5000 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=64, 39 | patch_size=4, 40 | embed_dim=512, 41 | depth=12, 42 | num_heads=8, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=-1, 47 | ) 48 | 49 | config.dataset = d( 50 | name='celeba', 51 | path='assets/datasets/celeba', 52 | resolution=64, 53 | ) 54 | 55 | config.sample = d( 56 | sample_steps=1000, 57 | n_samples=50000, 58 | mini_batch_size=500, 59 | algorithm='euler_maruyama_sde', 60 | path='' 61 | ) 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /U-ViT/configs/cifar10_uvit_small.py: 
-------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=500000, 17 | batch_size=128, 18 | mode='uncond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0002, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.999), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=2500 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=32, 39 | patch_size=2, 40 | embed_dim=512, 41 | depth=12, 42 | num_heads=8, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=-1, 47 | ) 48 | 49 | config.dataset = d( 50 | name='cifar10', 51 | path='assets/datasets/cifar10', 52 | random_flip=True, 53 | ) 54 | 55 | config.sample = d( 56 | sample_steps=1000, 57 | n_samples=50000, 58 | mini_batch_size=500, 59 | algorithm='euler_maruyama_sde', 60 | path='' 61 | ) 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet256_uvit_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet256_features', 59 | path='assets/datasets/imagenet256_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | sample_steps=50, 66 | n_samples=50000, 67 | mini_batch_size=50, # the decoder is large 68 | algorithm='dpm_solver', 69 | cfg=True, 70 | scale=0.4, 71 | path='' 72 | ) 73 | 74 | return config 75 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet256_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | 
config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=300000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1024, 47 | depth=20, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True 54 | ) 55 | 56 | config.dataset = d( 57 | name='imagenet256_features', 58 | path='assets/datasets/imagenet256_features', 59 | cfg=True, 60 | p_uncond=0.15 61 | ) 62 | 63 | config.sample = d( 64 | sample_steps=50, 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.4, 70 | path='' 71 | ) 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet512_uvit_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=64, 44 | patch_size=4, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet512_features', 59 | path='assets/datasets/imagenet512_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | sample_steps=50, 66 | n_samples=50000, 67 | mini_batch_size=50, # the decoder is large 68 | algorithm='dpm_solver', 69 | cfg=True, 70 | scale=0.7, 71 | path='' 72 | ) 73 | 74 | return config 75 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet512_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 
| mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=64, 44 | patch_size=4, 45 | in_chans=4, 46 | embed_dim=1024, 47 | depth=20, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True 54 | ) 55 | 56 | config.dataset = d( 57 | name='imagenet512_features', 58 | path='assets/datasets/imagenet512_features', 59 | cfg=True, 60 | p_uncond=0.15 61 | ) 62 | 63 | config.sample = d( 64 | sample_steps=50, 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.7, 70 | path='' 71 | ) 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet64_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=300000, 17 | batch_size=1024, 18 | mode='cond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0003, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.99), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=5000 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=64, 39 | patch_size=4, 40 | embed_dim=1024, 41 | depth=20, 42 | num_heads=16, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=1000, 47 | use_checkpoint=True 48 | ) 49 | 50 | config.dataset = d( 51 | name='imagenet', 52 | path='assets/datasets/ImageNet', 53 | resolution=64, 54 | ) 55 | 56 | config.sample = d( 57 | sample_steps=50, 58 | n_samples=50000, 59 | mini_batch_size=200, 60 | algorithm='dpm_solver', 61 | path='' 62 | ) 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet64_uvit_mid.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=300000, 17 | batch_size=1024, 18 | mode='cond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0003, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.99), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=5000 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=64, 39 | patch_size=4, 40 | embed_dim=768, 41 | depth=16, 42 | num_heads=12, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=1000, 47 | use_checkpoint=True 48 | ) 49 | 50 | 
config.dataset = d( 51 | name='imagenet', 52 | path='assets/datasets/ImageNet', 53 | resolution=64, 54 | ) 55 | 56 | config.sample = d( 57 | sample_steps=50, 58 | n_samples=50000, 59 | mini_batch_size=200, 60 | algorithm='dpm_solver', 61 | path='' 62 | ) 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /U-ViT/configs/mscoco_uvit_small.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.z_shape = (4, 32, 32) 14 | 15 | config.autoencoder = d( 16 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth', 17 | scale_factor=0.23010 18 | ) 19 | 20 | config.train = d( 21 | n_steps=1000000, 22 | batch_size=256, 23 | log_interval=10, 24 | eval_interval=5000, 25 | save_interval=50000, 26 | ) 27 | 28 | config.optimizer = d( 29 | name='adamw', 30 | lr=0.0002, 31 | weight_decay=0.03, 32 | betas=(0.9, 0.9), 33 | ) 34 | 35 | config.lr_scheduler = d( 36 | name='customized', 37 | warmup_steps=5000 38 | ) 39 | 40 | config.nnet = d( 41 | name='uvit_t2i', 42 | img_size=32, 43 | in_chans=4, 44 | patch_size=2, 45 | embed_dim=512, 46 | depth=12, 47 | num_heads=8, 48 | mlp_ratio=4, 49 | qkv_bias=False, 50 | mlp_time_embed=False, 51 | clip_dim=768, 52 | num_clip_token=77 53 | ) 54 | 55 | config.dataset = d( 56 | name='mscoco256_features', 57 | path='assets/datasets/coco256_features', 58 | cfg=True, 59 | p_uncond=0.1 60 | ) 61 | 62 | config.sample = d( 63 | sample_steps=50, 64 | n_samples=30000, 65 | mini_batch_size=50, 66 | cfg=True, 67 | scale=1., 68 | path='' 69 | ) 70 | 71 | return config 72 | -------------------------------------------------------------------------------- /U-ViT/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | from torchvision import datasets, transforms 5 | from torch.utils.data import DataLoader 6 | import shutil 7 | from PIL import Image 8 | from torchvision.datasets import VisionDataset 9 | from typing import Callable, Optional, Tuple, Any 10 | from common import create_lda_partitions 11 | 12 | 13 | def get_dataset(path_to_data: Path, cid: str, partition: str): 14 | # generate path to cid's data 15 | path_to_data = path_to_data / cid / (partition + ".pt") 16 | 17 | return TorchVision_FL(path_to_data, transform=cifar10Transformation()) 18 | 19 | 20 | def get_dataloader( 21 | path_to_data: str, cid: str, is_train: bool, batch_size: int, workers: int 22 | ): 23 | """Generates trainset/valset object and returns appropiate dataloader.""" 24 | 25 | partition = "train" if is_train else "val" 26 | dataset = get_dataset(Path(path_to_data), cid, partition) 27 | 28 | # we use as number of workers all the cpu cores assigned to this actor 29 | kwargs = {"num_workers": workers, "pin_memory": True, "drop_last": False} 30 | return DataLoader(dataset, batch_size=batch_size, **kwargs) 31 | 32 | 33 | def get_random_id_splits(total: int, val_ratio: float, shuffle: bool = True): 34 | """Splits a list of length `total` into two following a 35 | (1-val_ratio):val_ratio partitioning. 36 | 37 | By default the indices are shuffled before creating the split and 38 | returning. 
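    Example (doctest-style sketch; only the lengths are deterministic because
    the indices are shuffled):

        >>> train_idx, val_idx = get_random_id_splits(10, 0.2)
        >>> len(train_idx), len(val_idx)
        (8, 2)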
39 | """ 40 | 41 | if isinstance(total, int): 42 | indices = list(range(total)) 43 | else: 44 | indices = total 45 | 46 | split = int(np.floor(val_ratio * len(indices))) 47 | # print(f"Users left out for validation (ratio={val_ratio}) = {split} ") 48 | if shuffle: 49 | np.random.shuffle(indices) 50 | return indices[split:], indices[:split] 51 | 52 | 53 | def do_fl_partitioning(path_to_dataset, pool_size, alpha, num_classes, val_ratio=0.0): 54 | """Torchvision (e.g. CIFAR-10) datasets using LDA.""" 55 | 56 | images, labels = torch.load(path_to_dataset) 57 | idx = np.array(range(len(images))) 58 | dataset = [idx, labels] 59 | partitions, _ = create_lda_partitions( 60 | dataset, num_partitions=pool_size, concentration=alpha, accept_imbalanced=True 61 | ) 62 | 63 | # Show label distribution for first partition (purely informative) 64 | partition_zero = partitions[0][1] 65 | hist, _ = np.histogram(partition_zero, bins=list(range(num_classes + 1))) 66 | print( 67 | f"Class histogram for 0-th partition (alpha={alpha}, {num_classes} classes): {hist}" 68 | ) 69 | 70 | # now save partitioned dataset to disk 71 | # first delete dir containing splits (if exists), then create it 72 | splits_dir = path_to_dataset.parent / "federated" 73 | if splits_dir.exists(): 74 | shutil.rmtree(splits_dir) 75 | Path.mkdir(splits_dir, parents=True) 76 | 77 | for p in range(pool_size): 78 | labels = partitions[p][1] 79 | image_idx = partitions[p][0] 80 | imgs = images[image_idx] 81 | 82 | # create dir 83 | Path.mkdir(splits_dir / str(p)) 84 | 85 | if val_ratio > 0.0: 86 | # split data according to val_ratio 87 | train_idx, val_idx = get_random_id_splits(len(labels), val_ratio) 88 | val_imgs = imgs[val_idx] 89 | val_labels = labels[val_idx] 90 | 91 | with open(splits_dir / str(p) / "val.pt", "wb") as f: 92 | torch.save([val_imgs, val_labels], f) 93 | 94 | # remaining images for training 95 | imgs = imgs[train_idx] 96 | labels = labels[train_idx] 97 | 98 | with open(splits_dir / str(p) / "train.pt", "wb") as f: 99 | torch.save([imgs, labels], f) 100 | 101 | return splits_dir 102 | 103 | 104 | def cifar10Transformation(): 105 | return transforms.Compose( 106 | [ 107 | transforms.ToTensor(), 108 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 109 | ] 110 | ) 111 | 112 | 113 | class TorchVision_FL(VisionDataset): 114 | """This is just a trimmed down version of torchvision.datasets.MNIST. 115 | 116 | Use this class by either passing a path to a torch file (.pt) 117 | containing (data, targets) or pass the data, targets directly 118 | instead. 
119 | """ 120 | 121 | def __init__( 122 | self, 123 | path_to_data=None, 124 | data=None, 125 | targets=None, 126 | transform: Optional[Callable] = None, 127 | ) -> None: 128 | path = path_to_data.parent if path_to_data else None 129 | super(TorchVision_FL, self).__init__(path, transform=transform) 130 | self.transform = transform 131 | 132 | if path_to_data: 133 | # load data and targets (path_to_data points to an specific .pt file) 134 | self.data, self.targets = torch.load(path_to_data) 135 | else: 136 | self.data = data 137 | self.targets = targets 138 | 139 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 140 | img, target = self.data[index], int(self.targets[index]) 141 | 142 | # doing this so that it is consistent with all other datasets 143 | # to return a PIL Image 144 | if not isinstance(img, Image.Image): # if not PIL image 145 | if not isinstance(img, np.ndarray): # if torch tensor 146 | img = img.numpy() 147 | 148 | img = Image.fromarray(img) 149 | 150 | if self.transform is not None: 151 | img = self.transform(img) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return img, target 157 | 158 | def __len__(self) -> int: 159 | return len(self.data) 160 | 161 | 162 | def get_cifar_10(path_to_data="./data"): 163 | """Downloads CIFAR10 dataset and generates a unified training set (it will 164 | be partitioned later using the LDA partitioning mechanism.""" 165 | 166 | # download dataset and load train set 167 | train_set = datasets.CIFAR10(root=path_to_data, train=True, download=True) 168 | 169 | # fuse all data splits into a single "training.pt" 170 | data_loc = Path(path_to_data) / "cifar-10-batches-py" 171 | training_data = data_loc / "training.pt" 172 | print("Generating unified CIFAR dataset") 173 | torch.save([train_set.data, np.array(train_set.targets)], training_data) 174 | 175 | test_set = datasets.CIFAR10( 176 | root=path_to_data, train=False, transform=cifar10Transformation() 177 | ) 178 | 179 | # returns path where training data is and testset 180 | return training_data, test_set 181 | -------------------------------------------------------------------------------- /U-ViT/eval.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | 14 | 15 | def evaluate(config): 16 | if config.get('benchmark', False): 17 | torch.backends.cudnn.benchmark = True 18 | torch.backends.cudnn.deterministic = False 19 | 20 | mp.set_start_method('spawn') 21 | accelerator = accelerate.Accelerator() 22 | device = accelerator.device 23 | accelerate.utils.set_seed(config.seed, device_specific=True) 24 | logging.info(f'Process {accelerator.process_index} using device: {device}') 25 | 26 | config.mixed_precision = accelerator.mixed_precision 27 | config = ml_collections.FrozenConfigDict(config) 28 | if accelerator.is_main_process: 29 | utils.set_logger(log_level='info', fname=config.output_path) 30 | else: 31 | utils.set_logger(log_level='error') 32 | builtins.print = lambda *args: None 33 | 34 | dataset = get_dataset(**config.dataset) 35 | 36 | nnet = utils.get_nnet(**config.nnet) 37 | nnet = accelerator.prepare(nnet) 38 | 
logging.info(f'load nnet from {config.nnet_path}') 39 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 40 | nnet.eval() 41 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 42 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 43 | def cfg_nnet(x, timesteps, y): 44 | _cond = nnet(x, timesteps, y=y) 45 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 46 | return _cond + config.sample.scale * (_cond - _uncond) 47 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 48 | else: 49 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 50 | 51 | 52 | logging.info(config.sample) 53 | assert os.path.exists(dataset.fid_stat) 54 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 55 | 56 | def sample_fn(_n_samples): 57 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 58 | if config.train.mode == 'uncond': 59 | kwargs = dict() 60 | elif config.train.mode == 'cond': 61 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 62 | else: 63 | raise NotImplementedError 64 | 65 | if config.sample.algorithm == 'euler_maruyama_sde': 66 | rsde = sde.ReverseSDE(score_model) 67 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 68 | elif config.sample.algorithm == 'euler_maruyama_ode': 69 | rsde = sde.ODE(score_model) 70 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 71 | elif config.sample.algorithm == 'dpm_solver': 72 | noise_schedule = NoiseScheduleVP(schedule='linear') 73 | model_fn = model_wrapper( 74 | score_model.noise_pred, 75 | noise_schedule, 76 | time_input_type='0', 77 | model_kwargs=kwargs 78 | ) 79 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 80 | return dpm_solver.sample( 81 | x_init, 82 | steps=config.sample.sample_steps, 83 | eps=1e-4, 84 | adaptive_step_size=False, 85 | fast_version=True, 86 | ) 87 | else: 88 | raise NotImplementedError 89 | 90 | with tempfile.TemporaryDirectory() as temp_path: 91 | path = config.sample.path or temp_path 92 | if accelerator.is_main_process: 93 | os.makedirs(path, exist_ok=True) 94 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 95 | if accelerator.is_main_process: 96 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 97 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 98 | 99 | 100 | from absl import flags 101 | from absl import app 102 | from ml_collections import config_flags 103 | import os 104 | 105 | 106 | FLAGS = flags.FLAGS 107 | config_flags.DEFINE_config_file( 108 | "config", None, "Training configuration.", lock_config=False) 109 | flags.mark_flags_as_required(["config"]) 110 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 111 | flags.DEFINE_string("output_path", None, "The path to output log.") 112 | 113 | 114 | def main(argv): 115 | config = FLAGS.config 116 | config.nnet_path = FLAGS.nnet_path 117 | config.output_path = FLAGS.output_path 118 | evaluate(config) 119 | 120 | 121 | if __name__ == "__main__": 122 | app.run(main) 123 | -------------------------------------------------------------------------------- /U-ViT/eval_ldm.py: 
-------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | import libs.autoencoder 14 | 15 | 16 | def evaluate(config): 17 | if config.get('benchmark', False): 18 | torch.backends.cudnn.benchmark = True 19 | torch.backends.cudnn.deterministic = False 20 | 21 | mp.set_start_method('spawn') 22 | accelerator = accelerate.Accelerator() 23 | device = accelerator.device 24 | accelerate.utils.set_seed(config.seed, device_specific=True) 25 | logging.info(f'Process {accelerator.process_index} using device: {device}') 26 | 27 | config.mixed_precision = accelerator.mixed_precision 28 | config = ml_collections.FrozenConfigDict(config) 29 | if accelerator.is_main_process: 30 | utils.set_logger(log_level='info', fname=config.output_path) 31 | else: 32 | utils.set_logger(log_level='error') 33 | builtins.print = lambda *args: None 34 | 35 | dataset = get_dataset(**config.dataset) 36 | 37 | nnet = utils.get_nnet(**config.nnet) 38 | nnet = accelerator.prepare(nnet) 39 | logging.info(f'load nnet from {config.nnet_path}') 40 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 41 | nnet.eval() 42 | 43 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 44 | autoencoder.to(device) 45 | 46 | @torch.cuda.amp.autocast() 47 | def encode(_batch): 48 | return autoencoder.encode(_batch) 49 | 50 | @torch.cuda.amp.autocast() 51 | def decode(_batch): 52 | return autoencoder.decode(_batch) 53 | 54 | def decode_large_batch(_batch): 55 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 56 | xs = [] 57 | pt = 0 58 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 59 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 60 | pt += _decode_mini_batch_size 61 | xs.append(x) 62 | xs = torch.concat(xs, dim=0) 63 | assert xs.size(0) == _batch.size(0) 64 | return xs 65 | 66 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 67 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 68 | def cfg_nnet(x, timesteps, y): 69 | _cond = nnet(x, timesteps, y=y) 70 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 71 | return _cond + config.sample.scale * (_cond - _uncond) 72 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 73 | else: 74 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 75 | 76 | logging.info(config.sample) 77 | assert os.path.exists(dataset.fid_stat) 78 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 79 | 80 | def sample_fn(_n_samples): 81 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 82 | if config.train.mode == 'uncond': 83 | kwargs = dict() 84 | elif config.train.mode == 'cond': 85 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 86 | else: 87 | raise NotImplementedError 88 | 89 | if config.sample.algorithm == 'euler_maruyama_sde': 90 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), 
_z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 91 | elif config.sample.algorithm == 'euler_maruyama_ode': 92 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 93 | elif config.sample.algorithm == 'dpm_solver': 94 | noise_schedule = NoiseScheduleVP(schedule='linear') 95 | model_fn = model_wrapper( 96 | score_model.noise_pred, 97 | noise_schedule, 98 | time_input_type='0', 99 | model_kwargs=kwargs 100 | ) 101 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 102 | _z = dpm_solver.sample( 103 | _z_init, 104 | steps=config.sample.sample_steps, 105 | eps=1e-4, 106 | adaptive_step_size=False, 107 | fast_version=True, 108 | ) 109 | else: 110 | raise NotImplementedError 111 | return decode_large_batch(_z) 112 | 113 | with tempfile.TemporaryDirectory() as temp_path: 114 | path = config.sample.path or temp_path 115 | if accelerator.is_main_process: 116 | os.makedirs(path, exist_ok=True) 117 | logging.info(f'Samples are saved in {path}') 118 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 119 | if accelerator.is_main_process: 120 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 121 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 122 | 123 | 124 | from absl import flags 125 | from absl import app 126 | from ml_collections import config_flags 127 | import os 128 | 129 | 130 | FLAGS = flags.FLAGS 131 | config_flags.DEFINE_config_file( 132 | "config", None, "Training configuration.", lock_config=False) 133 | flags.mark_flags_as_required(["config"]) 134 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 135 | flags.DEFINE_string("output_path", None, "The path to output log.") 136 | 137 | 138 | def main(argv): 139 | config = FLAGS.config 140 | config.nnet_path = FLAGS.nnet_path 141 | config.output_path = FLAGS.output_path 142 | evaluate(config) 143 | 144 | 145 | if __name__ == "__main__": 146 | app.run(main) 147 | -------------------------------------------------------------------------------- /U-ViT/eval_ldm_discrete.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | from datasets import get_dataset 8 | import tempfile 9 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 10 | from absl import logging 11 | import builtins 12 | import libs.autoencoder 13 | 14 | 15 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 16 | _betas = ( 17 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 18 | ) 19 | return _betas.numpy() 20 | 21 | 22 | def evaluate(config): 23 | if config.get('benchmark', False): 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.deterministic = False 26 | 27 | mp.set_start_method('spawn') 28 | accelerator = accelerate.Accelerator() 29 | device = accelerator.device 30 | accelerate.utils.set_seed(config.seed, device_specific=True) 31 | logging.info(f'Process {accelerator.process_index} using device: {device}') 32 | 33 | config.mixed_precision = accelerator.mixed_precision 34 | config = ml_collections.FrozenConfigDict(config) 35 | if accelerator.is_main_process: 36 | utils.set_logger(log_level='info', fname=config.output_path) 37 | else: 
38 | utils.set_logger(log_level='error') 39 | builtins.print = lambda *args: None 40 | 41 | dataset = get_dataset(**config.dataset) 42 | 43 | nnet = utils.get_nnet(**config.nnet) 44 | nnet = accelerator.prepare(nnet) 45 | logging.info(f'load nnet from {config.nnet_path}') 46 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 47 | nnet.eval() 48 | 49 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 50 | autoencoder.to(device) 51 | 52 | @torch.cuda.amp.autocast() 53 | def encode(_batch): 54 | return autoencoder.encode(_batch) 55 | 56 | @torch.cuda.amp.autocast() 57 | def decode(_batch): 58 | return autoencoder.decode(_batch) 59 | 60 | def decode_large_batch(_batch): 61 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 62 | xs = [] 63 | pt = 0 64 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 65 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 66 | pt += _decode_mini_batch_size 67 | xs.append(x) 68 | xs = torch.concat(xs, dim=0) 69 | assert xs.size(0) == _batch.size(0) 70 | return xs 71 | 72 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 73 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 74 | def cfg_nnet(x, timesteps, y): 75 | _cond = nnet(x, timesteps, y=y) 76 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 77 | return _cond + config.sample.scale * (_cond - _uncond) 78 | else: 79 | def cfg_nnet(x, timesteps, y): 80 | _cond = nnet(x, timesteps, y=y) 81 | return _cond 82 | 83 | logging.info(config.sample) 84 | assert os.path.exists(dataset.fid_stat) 85 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 86 | 87 | _betas = stable_diffusion_beta_schedule() 88 | N = len(_betas) 89 | 90 | def sample_z(_n_samples, _sample_steps, **kwargs): 91 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 92 | 93 | if config.sample.algorithm == 'dpm_solver': 94 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 95 | 96 | def model_fn(x, t_continuous): 97 | t = t_continuous * N 98 | eps_pre = cfg_nnet(x, t, **kwargs) 99 | return eps_pre 100 | 101 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 102 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.) 
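            # The discrete schedule above reuses the Stable Diffusion betas
            # (linear in sqrt(beta) from 0.00085 to 0.0120 over N=1000 steps);
            # `model_fn` maps DPM-Solver's continuous time t in [1/N, 1] back to
            # the discrete step index t * N expected by the nnet, and
            # predict_x0=True selects the data-prediction (x0) form of the solver.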
103 | 104 | else: 105 | raise NotImplementedError 106 | 107 | return _z 108 | 109 | def sample_fn(_n_samples): 110 | if config.train.mode == 'uncond': 111 | kwargs = dict() 112 | elif config.train.mode == 'cond': 113 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 114 | else: 115 | raise NotImplementedError 116 | _z = sample_z(_n_samples, _sample_steps=config.sample.sample_steps, **kwargs) 117 | return decode_large_batch(_z) 118 | 119 | with tempfile.TemporaryDirectory() as temp_path: 120 | path = config.sample.path or temp_path 121 | if accelerator.is_main_process: 122 | os.makedirs(path, exist_ok=True) 123 | logging.info(f'Samples are saved in {path}') 124 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 125 | if accelerator.is_main_process: 126 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 127 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 128 | 129 | 130 | from absl import flags 131 | from absl import app 132 | from ml_collections import config_flags 133 | import os 134 | 135 | 136 | FLAGS = flags.FLAGS 137 | config_flags.DEFINE_config_file( 138 | "config", None, "Training configuration.", lock_config=False) 139 | flags.mark_flags_as_required(["config"]) 140 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 141 | flags.DEFINE_string("output_path", None, "The path to output log.") 142 | 143 | 144 | def main(argv): 145 | config = FLAGS.config 146 | config.nnet_path = FLAGS.nnet_path 147 | config.output_path = FLAGS.output_path 148 | evaluate(config) 149 | 150 | 151 | if __name__ == "__main__": 152 | app.run(main) 153 | -------------------------------------------------------------------------------- /U-ViT/eval_t2i_discrete.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | from torch.utils.data import DataLoader 7 | import utils 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | import einops 14 | import libs.autoencoder 15 | 16 | 17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 18 | _betas = ( 19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 20 | ) 21 | return _betas.numpy() 22 | 23 | 24 | def evaluate(config): 25 | if config.get('benchmark', False): 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.deterministic = False 28 | 29 | mp.set_start_method('spawn') 30 | accelerator = accelerate.Accelerator() 31 | device = accelerator.device 32 | accelerate.utils.set_seed(config.seed, device_specific=True) 33 | logging.info(f'Process {accelerator.process_index} using device: {device}') 34 | 35 | config.mixed_precision = accelerator.mixed_precision 36 | config = ml_collections.FrozenConfigDict(config) 37 | if accelerator.is_main_process: 38 | utils.set_logger(log_level='info', fname=config.output_path) 39 | else: 40 | utils.set_logger(log_level='error') 41 | builtins.print = lambda *args: None 42 | 43 | dataset = get_dataset(**config.dataset) 44 | test_dataset = dataset.get_split(split='test', labeled=True) # for sampling 45 | test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, 
shuffle=True, 46 | drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True) 47 | 48 | nnet = utils.get_nnet(**config.nnet) 49 | nnet, test_dataset_loader = accelerator.prepare(nnet, test_dataset_loader) 50 | logging.info(f'load nnet from {config.nnet_path}') 51 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 52 | nnet.eval() 53 | 54 | def cfg_nnet(x, timesteps, context): 55 | _cond = nnet(x, timesteps, context=context) 56 | if config.sample.scale == 0: 57 | return _cond 58 | _empty_context = torch.tensor(dataset.empty_context, device=device) 59 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 60 | _uncond = nnet(x, timesteps, context=_empty_context) 61 | return _cond + config.sample.scale * (_cond - _uncond) 62 | 63 | autoencoder = libs.autoencoder.get_model(**config.autoencoder) 64 | autoencoder.to(device) 65 | 66 | @torch.cuda.amp.autocast() 67 | def encode(_batch): 68 | return autoencoder.encode(_batch) 69 | 70 | @torch.cuda.amp.autocast() 71 | def decode(_batch): 72 | return autoencoder.decode(_batch) 73 | 74 | def decode_large_batch(_batch): 75 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 76 | xs = [] 77 | pt = 0 78 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 79 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 80 | pt += _decode_mini_batch_size 81 | xs.append(x) 82 | xs = torch.concat(xs, dim=0) 83 | assert xs.size(0) == _batch.size(0) 84 | return xs 85 | 86 | def get_context_generator(): 87 | while True: 88 | for data in test_dataset_loader: 89 | _, _context = data 90 | yield _context 91 | 92 | context_generator = get_context_generator() 93 | 94 | _betas = stable_diffusion_beta_schedule() 95 | N = len(_betas) 96 | 97 | logging.info(config.sample) 98 | assert os.path.exists(dataset.fid_stat) 99 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}') 100 | 101 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs): 102 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 103 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 104 | 105 | def model_fn(x, t_continuous): 106 | t = t_continuous * N 107 | return cfg_nnet(x, t, **kwargs) 108 | 109 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 110 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.) 
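        # `_z` holds latents of shape config.z_shape; decode_large_batch below
        # runs them through the Stable Diffusion autoencoder decoder in small
        # mini-batches before the decoded images are saved and scored with FID.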
111 | return decode_large_batch(_z) 112 | 113 | def sample_fn(_n_samples): 114 | _context = next(context_generator) 115 | assert _context.size(0) == _n_samples 116 | return dpm_solver_sample(_n_samples, config.sample.sample_steps, context=_context) 117 | 118 | with tempfile.TemporaryDirectory() as temp_path: 119 | path = config.sample.path or temp_path 120 | if accelerator.is_main_process: 121 | os.makedirs(path, exist_ok=True) 122 | logging.info(f'Samples are saved in {path}') 123 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 124 | if accelerator.is_main_process: 125 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 126 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 127 | 128 | 129 | from absl import flags 130 | from absl import app 131 | from ml_collections import config_flags 132 | import os 133 | 134 | 135 | FLAGS = flags.FLAGS 136 | config_flags.DEFINE_config_file( 137 | "config", None, "Training configuration.", lock_config=False) 138 | flags.mark_flags_as_required(["config"]) 139 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 140 | flags.DEFINE_string("output_path", None, "The path to output log.") 141 | 142 | 143 | def main(argv): 144 | config = FLAGS.config 145 | config.nnet_path = FLAGS.nnet_path 146 | config.output_path = FLAGS.output_path 147 | evaluate(config) 148 | 149 | 150 | if __name__ == "__main__": 151 | app.run(main) 152 | -------------------------------------------------------------------------------- /U-ViT/flower_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz') 7 | # borrowed from Pytorch quickstart example 8 | class Net(nn.Module): 9 | def __init__(self) -> None: 10 | super(Net, self).__init__() 11 | self.conv1 = nn.Conv2d(3, 6, 5) 12 | self.pool = nn.MaxPool2d(2, 2) 13 | self.conv2 = nn.Conv2d(6, 16, 5) 14 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 15 | self.fc2 = nn.Linear(120, 84) 16 | self.fc3 = nn.Linear(84, 10) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | x = self.pool(F.relu(self.conv1(x))) 20 | x = self.pool(F.relu(self.conv2(x))) 21 | x = x.view(-1, 16 * 5 * 5) 22 | x = F.relu(self.fc1(x)) 23 | x = F.relu(self.fc2(x)) 24 | x = self.fc3(x) 25 | return x 26 | 27 | 28 | # borrowed from Pytorch quickstart example 29 | def train(net, trainloader, epochs, device: str): 30 | """Train the network on the training set.""" 31 | criterion = torch.nn.CrossEntropyLoss() 32 | optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9) 33 | net.train() 34 | for _ in range(epochs): 35 | for images, labels in trainloader: 36 | images, labels = images.to(device), labels.to(device) 37 | optimizer.zero_grad() 38 | loss = criterion(net(images), labels) 39 | loss.backward() 40 | optimizer.step() 41 | 42 | 43 | # borrowed from Pytorch quickstart example 44 | def test(net, testloader, device: str): 45 | """Validate the network on the entire test set.""" 46 | criterion = torch.nn.CrossEntropyLoss() 47 | correct, loss = 0, 0.0 48 | net.eval() 49 | with torch.no_grad(): 50 | for data in testloader: 51 | images, labels = data[0].to(device), data[1].to(device) 52 | outputs = net(images) 53 | loss += criterion(outputs, labels).item() 54 | _, predicted = torch.max(outputs.data, 1) 55 | correct += (predicted == labels).sum().item() 56 | accuracy = 
correct / len(testloader.dataset) 57 | return loss, accuracy 58 | -------------------------------------------------------------------------------- /U-ViT/libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /U-ViT/libs/clip.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class FrozenCLIPEmbedder(AbstractEncoder): 14 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 16 | super().__init__() 17 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 18 | self.transformer = CLIPTextModel.from_pretrained(version) 19 | self.device = device 20 | self.max_length = max_length 21 | self.freeze() 22 | 23 | def freeze(self): 24 | self.transformer = self.transformer.eval() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, text): 29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 31 | tokens = batch_encoding["input_ids"].to(self.device) 32 | outputs = self.transformer(input_ids=tokens) 33 | 34 | z = outputs.last_hidden_state 35 | return z 36 | 37 | def encode(self, text): 38 | return self(text) 39 | -------------------------------------------------------------------------------- /U-ViT/libs/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 
29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
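    With probability `drop_prob` a sample's residual branch is zeroed; surviving
    samples are rescaled by 1 / (1 - drop_prob) in `drop_path` above, so the
    expected output is unchanged and no rescaling is needed at eval time.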
87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 
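    Example (illustrative): timestep_embedding(torch.tensor([0.]), 4) gives
    tensor([[1., 1., 0., 0.]]): the cosines of both frequencies come first,
    followed by the sines.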
29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplemented 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + 
self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 141 | use_checkpoint=False, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.num_classes = num_classes 145 | self.in_chans = in_chans 146 | 147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 148 | num_patches = (img_size // patch_size) ** 2 149 | 150 | self.time_embed = nn.Sequential( 151 | nn.Linear(embed_dim, 4 * embed_dim), 152 | nn.SiLU(), 153 | nn.Linear(4 * embed_dim, embed_dim), 154 | ) if mlp_time_embed else nn.Identity() 155 | 156 | if self.num_classes > 0: 157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 158 | self.extras = 2 159 | else: 160 | self.extras = 1 161 | 162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 163 | 164 | self.in_blocks = nn.ModuleList([ 165 | Block( 166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 167 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 168 | for _ in range(depth // 2)]) 169 | 170 | self.mid_block = Block( 171 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 172 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 173 | 174 | self.out_blocks = nn.ModuleList([ 175 | Block( 176 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 177 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 178 | for _ in range(depth // 2)]) 179 | 180 | self.norm = norm_layer(embed_dim) 181 | self.patch_dim = patch_size ** 2 * in_chans 182 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 183 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 184 | 185 | trunc_normal_(self.pos_embed, std=.02) 186 | self.apply(self._init_weights) 187 | 188 | def _init_weights(self, m): 189 | if isinstance(m, nn.Linear): 190 | trunc_normal_(m.weight, std=.02) 191 | if isinstance(m, nn.Linear) and m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, nn.LayerNorm): 194 | nn.init.constant_(m.bias, 0) 195 | nn.init.constant_(m.weight, 1.0) 196 | 197 | @torch.jit.ignore 198 | def no_weight_decay(self): 199 | return {'pos_embed'} 200 | 201 | def forward(self, x, timesteps, y=None): 202 | x = self.patch_embed(x) 203 | B, L, D = x.shape 204 | 205 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 206 | time_token = time_token.unsqueeze(dim=1) 207 | x = torch.cat((time_token, x), dim=1) 208 | if y is not None: 209 | label_emb = self.label_emb(y) 210 | label_emb = label_emb.unsqueeze(dim=1) 211 | x = 
torch.cat((label_emb, x), dim=1) 212 | x = x + self.pos_embed 213 | 214 | skips = [] 215 | for blk in self.in_blocks: 216 | x = blk(x) 217 | skips.append(x) 218 | 219 | x = self.mid_block(x) 220 | 221 | for blk in self.out_blocks: 222 | x = blk(x, skips.pop()) 223 | 224 | x = self.norm(x) 225 | x = self.decoder_pred(x) 226 | assert x.size(1) == self.extras + L 227 | x = x[:, self.extras:, :] 228 | x = unpatchify(x, self.in_chans) 229 | x = self.final_layer(x) 230 | return x 231 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = 
einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplemented 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False, 141 | clip_dim=768, num_clip_token=77, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.in_chans = in_chans 145 | 146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = (img_size // patch_size) ** 2 148 | 149 | self.time_embed = nn.Sequential( 150 | nn.Linear(embed_dim, 4 * embed_dim), 151 | nn.SiLU(), 152 | nn.Linear(4 * embed_dim, embed_dim), 153 | ) if mlp_time_embed else nn.Identity() 154 | 155 | self.context_embed = nn.Linear(clip_dim, embed_dim) 156 | 157 | self.extras = 1 + num_clip_token 158 | 159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 160 | 161 | self.in_blocks = nn.ModuleList([ 162 | Block( 163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 165 | for _ in range(depth // 2)]) 166 | 167 | self.mid_block = Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | norm_layer=norm_layer, 
use_checkpoint=use_checkpoint) 170 | 171 | self.out_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 175 | for _ in range(depth // 2)]) 176 | 177 | self.norm = norm_layer(embed_dim) 178 | self.patch_dim = patch_size ** 2 * in_chans 179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | @torch.jit.ignore 195 | def no_weight_decay(self): 196 | return {'pos_embed'} 197 | 198 | def forward(self, x, timesteps, context): 199 | x = self.patch_embed(x) 200 | B, L, D = x.shape 201 | 202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 203 | time_token = time_token.unsqueeze(dim=1) 204 | context_token = self.context_embed(context) 205 | x = torch.cat((time_token, context_token, x), dim=1) 206 | x = x + self.pos_embed 207 | 208 | skips = [] 209 | for blk in self.in_blocks: 210 | x = blk(x) 211 | skips.append(x) 212 | 213 | x = self.mid_block(x) 214 | 215 | for blk in self.out_blocks: 216 | x = blk(x, skips.pop()) 217 | 218 | x = self.norm(x) 219 | x = self.decoder_pred(x) 220 | assert x.size(1) == self.extras + L 221 | x = x[:, self.extras:, :] 222 | x = unpatchify(x, self.in_chans) 223 | x = self.final_layer(x) 224 | return x 225 | -------------------------------------------------------------------------------- /U-ViT/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import flwr as fl 3 | from flwr.common.typing import Scalar 4 | import ray 5 | import torch 6 | import torchvision 7 | import numpy as np 8 | from collections import OrderedDict 9 | from pathlib import Path 10 | from typing import Dict, Callable, Optional, Tuple, List 11 | from dataset_utils import get_cifar_10, do_fl_partitioning, get_dataloader 12 | from utils import Net, train, test 13 | 14 | 15 | parser = argparse.ArgumentParser(description="Flower Simulation with PyTorch") 16 | 17 | parser.add_argument("--num_client_cpus", type=int, default=1) 18 | parser.add_argument("--num_rounds", type=int, default=5) 19 | 20 | 21 | # Flower client, adapted from Pytorch quickstart example 22 | class FlowerClient(fl.client.NumPyClient): 23 | def __init__(self, cid: str, fed_dir_data: str): 24 | self.cid = cid 25 | self.fed_dir = Path(fed_dir_data) 26 | self.properties: Dict[str, Scalar] = {"tensor_type": "numpy.ndarray"} 27 | 28 | # Instantiate model 29 | self.net = Net() 30 | 31 | # Determine device 32 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | 34 | def get_parameters(self, config): 35 | return get_params(self.net) 36 | 37 | def fit(self, parameters, config): 38 | set_params(self.net, parameters) 39 | 40 | # Load data for this client and get trainloader 41 | num_workers = int(ray.get_runtime_context().get_assigned_resources()["CPU"]) 42 | trainloader = get_dataloader( 43 | self.fed_dir, 44 | self.cid, 45 | is_train=True, 
46 | batch_size=config["batch_size"], 47 | workers=num_workers, 48 | ) 49 | 50 | # Send model to device 51 | self.net.to(self.device) 52 | 53 | # Train 54 | train(self.net, trainloader, epochs=config["epochs"], device=self.device) 55 | 56 | # Return local model and statistics 57 | return get_params(self.net), len(trainloader.dataset), {} 58 | 59 | def evaluate(self, parameters, config): 60 | set_params(self.net, parameters) 61 | 62 | # Load data for this client and get valloader 63 | num_workers = int(ray.get_runtime_context().get_assigned_resources()["CPU"]) 64 | valloader = get_dataloader( 65 | self.fed_dir, self.cid, is_train=False, batch_size=50, workers=num_workers 66 | ) 67 | 68 | # Send model to device 69 | self.net.to(self.device) 70 | 71 | # Evaluate 72 | loss, accuracy = test(self.net, valloader, device=self.device) 73 | 74 | # Return statistics 75 | return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} 76 | 77 | 78 | def fit_config(server_round: int) -> Dict[str, Scalar]: 79 | """Return a configuration with static batch size and (local) epochs.""" 80 | config = { 81 | "epochs": 5, # number of local epochs 82 | "batch_size": 64, 83 | } 84 | return config 85 | 86 | 87 | def get_params(model: torch.nn.ModuleList) -> List[np.ndarray]: 88 | """Get model weights as a list of NumPy ndarrays.""" 89 | return [val.cpu().numpy() for _, val in model.state_dict().items()] 90 | 91 | 92 | def set_params(model: torch.nn.ModuleList, params: List[np.ndarray]): 93 | """Set model weights from a list of NumPy ndarrays.""" 94 | params_dict = zip(model.state_dict().keys(), params) 95 | state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict}) 96 | model.load_state_dict(state_dict, strict=True) 97 | 98 | 99 | def get_evaluate_fn( 100 | testset: torchvision.datasets.CIFAR10, 101 | ) -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: 102 | """Return an evaluation function for centralized evaluation.""" 103 | 104 | def evaluate( 105 | server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] 106 | ) -> Optional[Tuple[float, float]]: 107 | """Use the entire CIFAR-10 test set for evaluation.""" 108 | 109 | # determine device 110 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 111 | 112 | model = Net() 113 | set_params(model, parameters) 114 | model.to(device) 115 | 116 | testloader = torch.utils.data.DataLoader(testset, batch_size=50) 117 | loss, accuracy = test(model, testloader, device=device) 118 | 119 | # return statistics 120 | return loss, {"accuracy": accuracy} 121 | 122 | return evaluate 123 | 124 | 125 | # Start simulation (a _default server_ will be created) 126 | # This example does: 127 | # 1. Downloads CIFAR-10 128 | # 2. Partitions the dataset into N splits, where N is the total number of 129 | # clients. We refer to this as `pool_size`. The partition can be IID or non-IID 130 | # 3. Starts a simulation where a % of clients are sampled each round. 131 | # 4. After the M rounds end, the global model is evaluated on the entire testset. 132 | # Also, the global model is evaluated on the valset partition residing in each 133 | # client. This is useful to get a sense of how well the global model can generalise 134 | # to each client's data.
135 | if __name__ == "__main__": 136 | # parse input arguments 137 | args = parser.parse_args() 138 | 139 | pool_size = 100 # number of dataset partitions (= number of total clients) 140 | client_resources = { 141 | "num_cpus": args.num_client_cpus 142 | } # each client will get allocated `num_client_cpus` CPUs 143 | 144 | # Download CIFAR-10 dataset 145 | train_path, testset = get_cifar_10() 146 | 147 | # partition dataset (use a large `alpha` to make it IID; 148 | # a small value (e.g. 1) will make it non-IID) 149 | # This will create a new directory called "federated" in the directory where 150 | # CIFAR-10 lives. Inside it, there will be N=pool_size sub-directories each with 151 | # its own train/val split. 152 | fed_dir = do_fl_partitioning( 153 | train_path, pool_size=pool_size, alpha=1000, num_classes=10, val_ratio=0.1 154 | ) 155 | 156 | # configure the strategy 157 | strategy = fl.server.strategy.FedAvg( 158 | fraction_fit=0.1, 159 | fraction_evaluate=0.1, 160 | min_fit_clients=10, 161 | min_evaluate_clients=10, 162 | min_available_clients=pool_size, # All clients should be available 163 | on_fit_config_fn=fit_config, 164 | evaluate_fn=get_evaluate_fn(testset), # centralised evaluation of global model 165 | ) 166 | 167 | def client_fn(cid: str): 168 | # create a single client instance 169 | return FlowerClient(cid, fed_dir) 170 | 171 | # (optional) specify Ray config 172 | ray_init_args = {"include_dashboard": False} 173 | 174 | # start simulation 175 | fl.simulation.start_simulation( 176 | client_fn=client_fn, 177 | num_clients=pool_size, 178 | client_resources=client_resources, 179 | config=fl.server.ServerConfig(num_rounds=args.num_rounds), 180 | strategy=strategy, 181 | ray_init_args=ray_init_args, 182 | ) 183 | -------------------------------------------------------------------------------- /U-ViT/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.4.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "simulation_pytorch" 7 | version = "0.1.0" 8 | description = "Federated Learning Simulation with Flower and PyTorch" 9 | authors = ["The Flower Authors "] 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.8,<3.11" 13 | flwr = {version = ">=1.0,<2.0", extras = ["simulation"]} 14 | torch = "1.13.1" 15 | torchvision = "0.14.1" -------------------------------------------------------------------------------- /U-ViT/requirements.txt: -------------------------------------------------------------------------------- 1 | flwr[simulation]>=1.0, <2.0 2 | torch==1.13.1 3 | torchvision==0.14.1 -------------------------------------------------------------------------------- /U-ViT/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ 4 | 5 | # Download the CIFAR-10 dataset 6 | python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" 7 | 8 | echo "Start simulation" 9 | python main.py & 10 | 11 | # Enable CTRL+C to stop all background processes 12 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM 13 | # Wait for all background processes to complete 14 | wait 15 | -------------------------------------------------------------------------------- /U-ViT/sample.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/colinlaganier/FederatedDiffusionModels/d86aac66c7ffdbc599fe3a581fac5b6c09247f1f/U-ViT/sample.png -------------------------------------------------------------------------------- /U-ViT/sample_t2i_discrete.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | from torch import multiprocessing as mp 4 | import accelerate 5 | import utils 6 | from datasets import get_dataset 7 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 8 | from absl import logging 9 | import builtins 10 | import einops 11 | import libs.autoencoder 12 | import libs.clip 13 | from torchvision.utils import save_image 14 | 15 | 16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 17 | _betas = ( 18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 19 | ) 20 | return _betas.numpy() 21 | 22 | 23 | def evaluate(config): 24 | if config.get('benchmark', False): 25 | torch.backends.cudnn.benchmark = True 26 | torch.backends.cudnn.deterministic = False 27 | 28 | mp.set_start_method('spawn') 29 | accelerator = accelerate.Accelerator() 30 | device = accelerator.device 31 | accelerate.utils.set_seed(config.seed, device_specific=True) 32 | logging.info(f'Process {accelerator.process_index} using device: {device}') 33 | 34 | config.mixed_precision = accelerator.mixed_precision 35 | config = ml_collections.FrozenConfigDict(config) 36 | if accelerator.is_main_process: 37 | utils.set_logger(log_level='info') 38 | else: 39 | utils.set_logger(log_level='error') 40 | builtins.print = lambda *args: None 41 | 42 | dataset = get_dataset(**config.dataset) 43 | 44 | with open(config.input_path, 'r') as f: 45 | prompts = f.read().strip().split('\n') 46 | 47 | print(prompts) 48 | 49 | clip = libs.clip.FrozenCLIPEmbedder() 50 | clip.eval() 51 | clip.to(device) 52 | 53 | contexts = clip.encode(prompts) 54 | 55 | nnet = utils.get_nnet(**config.nnet) 56 | nnet = accelerator.prepare(nnet) 57 | logging.info(f'load nnet from {config.nnet_path}') 58 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 59 | nnet.eval() 60 | 61 | def cfg_nnet(x, timesteps, context): 62 | _cond = nnet(x, timesteps, context=context) 63 | if config.sample.scale == 0: 64 | return _cond 65 | _empty_context = torch.tensor(dataset.empty_context, device=device) 66 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 67 | _uncond = nnet(x, timesteps, context=_empty_context) 68 | return _cond + config.sample.scale * (_cond - _uncond) 69 | 70 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 71 | autoencoder.to(device) 72 | 73 | @torch.cuda.amp.autocast() 74 | def encode(_batch): 75 | return autoencoder.encode(_batch) 76 | 77 | @torch.cuda.amp.autocast() 78 | def decode(_batch): 79 | return autoencoder.decode(_batch) 80 | 81 | _betas = stable_diffusion_beta_schedule() 82 | N = len(_betas) 83 | 84 | logging.info(config.sample) 85 | logging.info(f'mixed_precision={config.mixed_precision}') 86 | logging.info(f'N={N}') 87 | 88 | z_init = torch.randn(contexts.size(0), *config.z_shape, device=device) 89 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 90 | 91 | def model_fn(x, t_continuous): 92 | t = t_continuous * N 93 | return cfg_nnet(x, t, context=contexts) 94 | 95 | dpm_solver = DPM_Solver(model_fn, noise_schedule, 
predict_x0=True, thresholding=False) 96 | z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.) 97 | samples = dataset.unpreprocess(decode(z)) 98 | 99 | os.makedirs(config.output_path, exist_ok=True) 100 | for sample, prompt in zip(samples, prompts): 101 | save_image(sample, os.path.join(config.output_path, f"{prompt}.png")) 102 | 103 | 104 | 105 | from absl import flags 106 | from absl import app 107 | from ml_collections import config_flags 108 | import os 109 | 110 | 111 | FLAGS = flags.FLAGS 112 | config_flags.DEFINE_config_file( 113 | "config", None, "Training configuration.", lock_config=False) 114 | flags.mark_flags_as_required(["config"]) 115 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 116 | flags.DEFINE_string("output_path", None, "The path to output images.") 117 | flags.DEFINE_string("input_path", None, "The path to input texts.") 118 | 119 | 120 | def main(argv): 121 | config = FLAGS.config 122 | config.nnet_path = FLAGS.nnet_path 123 | config.output_path = FLAGS.output_path 124 | config.input_path = FLAGS.input_path 125 | evaluate(config) 126 | 127 | 128 | if __name__ == "__main__": 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /U-ViT/scripts/extract_empty_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | '', 14 | ] 15 | 16 | device = 'cuda' 17 | clip = libs.clip.FrozenCLIPEmbedder() 18 | clip.eval() 19 | clip.to(device) 20 | 21 | save_dir = f'assets/datasets/coco256_features' 22 | latent = clip.encode(prompts) 23 | print(latent.shape) 24 | c = latent[0].detach().cpu().numpy() 25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c) 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /U-ViT/scripts/extract_imagenet_feature.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | from datasets import ImageNet 5 | from torch.utils.data import DataLoader 6 | from libs.autoencoder import get_model 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 24 | model = nn.DataParallel(model) 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model.to(device) 27 | 28 | # features = [] 29 | # labels = [] 30 | 31 | idx = 0 32 | for batch in tqdm(train_dataset_loader): 33 | img, label = batch 34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 35 | img = img.to(device) 36 | moments = model(img, fn='encode_moments') 37 | moments = moments.detach().cpu().numpy() 38 | 39 | label = torch.cat([label, label], 
dim=0) 40 | label = label.detach().cpu().numpy() 41 | 42 | for moment, lb in zip(moments, label): 43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb)) 44 | idx += 1 45 | 46 | print(f'save {idx} files') 47 | 48 | # features = np.concatenate(features, axis=0) 49 | # labels = np.concatenate(labels, axis=0) 50 | # print(f'features.shape={features.shape}') 51 | # print(f'labels.shape={labels.shape}') 52 | # np.save(f'imagenet{resolution}_features.npy', features) 53 | # np.save(f'imagenet{resolution}_labels.npy', labels) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /U-ViT/scripts/extract_mscoco_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(resolution=256): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--split', default='train') 14 | args = parser.parse_args() 15 | print(args) 16 | 17 | 18 | if args.split == "train": 19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014', 20 | annFile='assets/datasets/coco/annotations/captions_train2014.json', 21 | size=resolution) 22 | save_dir = f'assets/datasets/coco{resolution}_features/train' 23 | elif args.split == "val": 24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014', 25 | annFile='assets/datasets/coco/annotations/captions_val2014.json', 26 | size=resolution) 27 | save_dir = f'assets/datasets/coco{resolution}_features/val' 28 | else: 29 | raise NotImplementedError("ERROR!") 30 | 31 | device = "cuda" 32 | os.makedirs(save_dir) 33 | 34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth') 35 | autoencoder.to(device) 36 | clip = libs.clip.FrozenCLIPEmbedder() 37 | clip.eval() 38 | clip.to(device) 39 | 40 | with torch.no_grad(): 41 | for idx, data in tqdm(enumerate(datas)): 42 | x, captions = data 43 | 44 | if len(x.shape) == 3: 45 | x = x[None, ...] 
46 | x = torch.tensor(x, device=device) 47 | moments = autoencoder(x, fn='encode_moments').squeeze(0) 48 | moments = moments.detach().cpu().numpy() 49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments) 50 | 51 | latent = clip.encode(captions) 52 | for i in range(len(latent)): 53 | c = latent[i].detach().cpu().numpy() 54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /U-ViT/scripts/extract_test_prompt_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | 'A green train is coming down the tracks.', 14 | 'A group of skiers are preparing to ski down a mountain.', 15 | 'A small kitchen with a low ceiling.', 16 | 'A group of elephants walking in muddy water.', 17 | 'A living area with a television and a table.', 18 | 'A road with traffic lights, street lights and cars.', 19 | 'A bus driving in a city area with traffic signs.', 20 | 'A bus pulls over to the curb close to an intersection.', 21 | 'A group of people are walking and one is holding an umbrella.', 22 | 'A baseball player taking a swing at an incoming ball.', 23 | 'A city street line with brick buildings and trees.', 24 | 'A close up of a plate of broccoli and sauce.', 25 | ] 26 | 27 | device = 'cuda' 28 | clip = libs.clip.FrozenCLIPEmbedder() 29 | clip.eval() 30 | clip.to(device) 31 | 32 | save_dir = f'assets/datasets/coco256_features/run_vis' 33 | latent = clip.encode(prompts) 34 | for i in range(len(latent)): 35 | c = latent[i].detach().cpu().numpy() 36 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c)) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /U-ViT/skip_im.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colinlaganier/FederatedDiffusionModels/d86aac66c7ffdbc599fe3a581fac5b6c09247f1f/U-ViT/skip_im.png -------------------------------------------------------------------------------- /U-ViT/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | from torchvision.utils import save_image 7 | from absl import logging 8 | 9 | 10 | def set_logger(log_level='info', fname=None): 11 | import logging as _logging 12 | handler = logging.get_absl_handler() 13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') 14 | handler.setFormatter(formatter) 15 | logging.set_verbosity(log_level) 16 | if fname is not None: 17 | handler = _logging.FileHandler(fname) 18 | handler.setFormatter(formatter) 19 | logging.get_absl_logger().addHandler(handler) 20 | 21 | 22 | def dct2str(dct): 23 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 24 | 25 | 26 | def get_nnet(name, **kwargs): 27 | if name == 'uvit': 28 | from libs.uvit import UViT 29 | return UViT(**kwargs) 30 | elif name == 'uvit_t2i': 31 | from libs.uvit_t2i import UViT 32 | return UViT(**kwargs) 33 | else: 34 | raise NotImplementedError(name) 35 | 36 | 37 | def set_seed(seed: int): 38 | if seed is not None: 39 | 
torch.manual_seed(seed) 40 | np.random.seed(seed) 41 | 42 | 43 | def get_optimizer(params, name, **kwargs): 44 | if name == 'adam': 45 | from torch.optim import Adam 46 | return Adam(params, **kwargs) 47 | elif name == 'adamw': 48 | from torch.optim import AdamW 49 | return AdamW(params, **kwargs) 50 | else: 51 | raise NotImplementedError(name) 52 | 53 | 54 | def customized_lr_scheduler(optimizer, warmup_steps=-1): 55 | from torch.optim.lr_scheduler import LambdaLR 56 | def fn(step): 57 | if warmup_steps > 0: 58 | return min(step / warmup_steps, 1) 59 | else: 60 | return 1 61 | return LambdaLR(optimizer, fn) 62 | 63 | 64 | def get_lr_scheduler(optimizer, name, **kwargs): 65 | if name == 'customized': 66 | return customized_lr_scheduler(optimizer, **kwargs) 67 | elif name == 'cosine': 68 | from torch.optim.lr_scheduler import CosineAnnealingLR 69 | return CosineAnnealingLR(optimizer, **kwargs) 70 | else: 71 | raise NotImplementedError(name) 72 | 73 | 74 | def ema(model_dest: nn.Module, model_src: nn.Module, rate): 75 | param_dict_src = dict(model_src.named_parameters()) 76 | for p_name, p_dest in model_dest.named_parameters(): 77 | p_src = param_dict_src[p_name] 78 | assert p_src is not p_dest 79 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) 80 | 81 | 82 | class TrainState(object): 83 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): 84 | self.optimizer = optimizer 85 | self.lr_scheduler = lr_scheduler 86 | self.step = step 87 | self.nnet = nnet 88 | self.nnet_ema = nnet_ema 89 | 90 | def ema_update(self, rate=0.9999): 91 | if self.nnet_ema is not None: 92 | ema(self.nnet_ema, self.nnet, rate) 93 | 94 | def save(self, path): 95 | os.makedirs(path, exist_ok=True) 96 | torch.save(self.step, os.path.join(path, 'step.pth')) 97 | for key, val in self.__dict__.items(): 98 | if key != 'step' and val is not None: 99 | torch.save(val.state_dict(), os.path.join(path, f'{key}.pth')) 100 | 101 | def load(self, path): 102 | logging.info(f'load from {path}') 103 | self.step = torch.load(os.path.join(path, 'step.pth')) 104 | for key, val in self.__dict__.items(): 105 | if key != 'step' and val is not None: 106 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')) 107 | 108 | def resume(self, ckpt_root, step=None): 109 | if not os.path.exists(ckpt_root): 110 | return 111 | if step is None: 112 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root))) 113 | if not ckpts: 114 | return 115 | steps = map(lambda x: int(x.split(".")[0]), ckpts) 116 | step = max(steps) 117 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt') 118 | logging.info(f'resume from {ckpt_path}') 119 | self.load(ckpt_path) 120 | 121 | def to(self, device): 122 | for key, val in self.__dict__.items(): 123 | if isinstance(val, nn.Module): 124 | val.to(device) 125 | 126 | 127 | def cnt_params(model): 128 | return sum(param.numel() for param in model.parameters()) 129 | 130 | 131 | def initialize_train_state(config, device): 132 | params = [] 133 | 134 | nnet = get_nnet(**config.nnet) 135 | params += nnet.parameters() 136 | nnet_ema = get_nnet(**config.nnet) 137 | nnet_ema.eval() 138 | logging.info(f'nnet has {cnt_params(nnet)} parameters') 139 | 140 | optimizer = get_optimizer(params, **config.optimizer) 141 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler) 142 | 143 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0, 144 | nnet=nnet, nnet_ema=nnet_ema) 145 | train_state.ema_update(0) 146 | 
train_state.to(device) 147 | return train_state 148 | 149 | 150 | def amortize(n_samples, batch_size): 151 | k = n_samples // batch_size 152 | r = n_samples % batch_size 153 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 154 | 155 | 156 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None): 157 | os.makedirs(path, exist_ok=True) 158 | idx = 0 159 | batch_size = mini_batch_size * accelerator.num_processes 160 | 161 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 162 | samples = unpreprocess_fn(sample_fn(mini_batch_size)) 163 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 164 | if accelerator.is_main_process: 165 | for sample in samples: 166 | save_image(sample, os.path.join(path, f"{idx}.png")) 167 | idx += 1 168 | 169 | 170 | def grad_norm(model): 171 | total_norm = 0. 172 | for p in model.parameters(): 173 | param_norm = p.grad.data.norm(2) 174 | total_norm += param_norm.item() ** 2 175 | total_norm = total_norm ** (1. / 2) 176 | return total_norm 177 | -------------------------------------------------------------------------------- /U-ViT/uvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colinlaganier/FederatedDiffusionModels/d86aac66c7ffdbc599fe3a581fac5b6c09247f1f/U-ViT/uvit.png -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | Saved Model Checkpoints -------------------------------------------------------------------------------- /jobscript.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | 4 | # Batch script to run a GPU job under SGE. 5 | 6 | # Request a number of GPU cards, in this case 1 7 | #$ -l gpu=1 8 | 9 | # Request 24 hours of wallclock time (format hours:minutes:seconds). 10 | #$ -l h_rt=24:00:0 11 | 12 | # Request 32 gigabytes of RAM (must be an integer followed by M, G, or T) 13 | #$ -l mem=32G 14 | 15 | # Request 32 gigabytes of TMPDIR space (default is 10 GB) 16 | #$ -l tmpfs=32G 17 | 18 | # Set the name of the job. 19 | #$ -N DDPM_Fed_10_100 20 | 21 | # Set the working directory to somewhere in your scratch space 22 | #$ -wd /home/ucabcuf/Scratch/FederatedDiffusionModels 23 | 24 | # Change into temporary directory to run work 25 | # cd $TMPDIR 26 | 27 | # Load the required modules 28 | module purge 29 | module load default-modules 30 | module unload compilers mpi 31 | module load gcc-libs/4.9.2 32 | module load python/miniconda3/4.10.3 33 | 34 | # Activate conda environment 35 | source $UCL_CONDA_PATH/etc/profile.d/conda.sh 36 | conda activate FedKDD 37 | 38 | # Run the application 39 | nvidia-smi 40 | cd DDPM 41 | sh run.sh -c 5 -r 100 -e 1 --------------------------------------------------------------------------------