├── .gitignore
├── LICENSE
├── README.md
├── VAE
    ├── VAE_model.py
    ├── VAE_train.py
    └── __init__.py
├── cell_sample.py
├── cell_train.py
├── celltypist_train.py
├── classifier_sample.py
├── classifier_train.py
├── exp_script
    ├── down_stream_analysis_muris.ipynb
    ├── script_description.md
    ├── script_diffusion_interpolation.ipynb
    ├── script_diffusion_multi-condi.ipynb
    ├── script_diffusion_umap.ipynb
    ├── script_random_forest.ipynb
    └── script_static_eval.ipynb
├── guided_diffusion
    ├── __init__.py
    ├── cell_datasets_WOT.py
    ├── cell_datasets_loader.py
    ├── cell_datasets_lung.py
    ├── cell_datasets_muris.py
    ├── cell_datasets_pbmc.py
    ├── cell_datasets_sapiens.py
    ├── cell_model.py
    ├── dist_util.py
    ├── fp16_util.py
    ├── gaussian_diffusion.py
    ├── logger.py
    ├── losses.py
    ├── nn.py
    ├── resample.py
    ├── respace.py
    ├── script_util.py
    └── train_util.py
├── model_archi.png
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__/
3 | VAE/cache
4 | output/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Erpai Luo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## scDiffusion: Conditional Generation of High-Quality Single-Cell Data Using Diffusion Model
2 | Welcome to the code base for scDiffusion, a model for generating scRNA-seq data. It combines a latent diffusion model with a pre-trained model. More details are available at https://doi.org/10.1093/bioinformatics/btae518.
3 |
4 |
5 |
6 |

7 |
8 |
9 |
10 | # Environment
11 | ```
12 | pytorch 1.13.0
13 | numpy 1.23.4
14 | anndata 0.8.0
15 | scanpy 1.9.1
16 | scikit-learn 1.2.2
17 | blobfile 2.0.0
18 | pandas 1.5.1
19 | celltypist 1.3.0
20 | imbalanced-learn 0.11.0
21 | mpi4py 3.1.4
22 | ```
23 |
24 | # Train the scDiffusion model
25 |
26 | **Dataset:**
27 | The training data must be in h5ad format. You can download the datasets used in the paper from https://figshare.com/s/49b29cb24b27ec8b6d72. For other formats, or if your data has already been pre-processed, modify the code in ./guided_diffusion/cell_datasets_loader.py. The load_data function in cell_datasets_loader.py only supports raw count data that has not been pre-processed.
28 |
29 | You can run `train.sh` directly to complete all the training steps. Be sure to change the file paths to your own.
30 |
31 | Below are the complete steps for the training process:
32 |
33 | - Step 1: Train the Autoencoder
34 | Run `VAE/VAE_train.py` from inside the `VAE` directory (`cd VAE`). Set the parameters *data_dir* and *save_dir* to your local paths, and set *num_genes* to the number of genes in your dataset. The pretrained scimilarity weights can be found at https://zenodo.org/records/8286452; we used annotation_model_v1 in this work. Set *state_dict* to the path where you stored the downloaded scimilarity checkpoint. You can also train the autoencoder from scratch, which may need more iteration steps (more than 1.5e5 steps is recommended).
35 |
36 | For example:
37 | `python VAE_train.py --data_dir '/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad' --num_genes 18996 --save_dir '../output/checkpoint/AE/my_VAE' --max_steps 200000`
38 |
39 | - Step 2: Train the diffusion backbone
40 | Run `cell_train.py`: First, set *vae_path* to the path of your trained Autoencoder. Next, set *data_dir*, *model_name* (the folder in which to save the checkpoints), and *save_dir* (the path in which to place the *model_name* folder). We trained the backbone for 6e5 steps.
41 |
42 | For example:
43 | `python cell_train.py --data_dir '/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad' --vae_path 'output/checkpoint/AE/my_VAE/model_seed=0_step=150000.pt' --model_name 'my_diffusion' --save_dir 'output/checkpoint/backbone' --lr_anneal_steps 800000`
44 |
45 | - Step 3: Train the classifier
46 | Run `classifier_train.py`: Again, set *vae_path* to the path of your trained Autoencoder. Set *num_class* to the number of classes in your dataset. Then set *model_path* to the path where you would like to save the checkpoints and execute the file. We trained the classifier for 2e5 steps.
47 |
48 | For example:
49 | `python classifier_train.py --data_dir '/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad' --model_path "output/checkpoint/classifier/my_classifier" --iterations 400000 --vae_path 'output/checkpoint/AE/my_VAE/model_seed=0_step=150000.pt' --num_class=12`
50 |
51 | # Generate new sample
52 |
53 | **Unconditional generation:**
54 |
55 | Run `cell_sample.py`: set *model_path* to the path of the trained backbone model and set *sample_dir* to your local path. *num_samples* is the number of cells to generate, and *batch_size* is the number of cells generated in one reverse diffusion process.
56 |
57 | For example:
58 | `python cell_sample.py --model_path 'output/checkpoint/backbone/my_diffusion/model600000.pt' --sample_dir 'output/simulated_samples/muris' --num_samples 3000 --batch_size 1000`
59 |
60 | Running the file will generate new latent embeddings for the scRNA-seq data and save them in a .npz file. You can decode these latent embeddings and retrieve the complete gene expression data using `exp_script/script_diffusion_umap.ipynb` or `exp_script/script_static_eval.ipynb`.
61 |
62 | **Conditional generation:**
63 |
64 | Run `classifier_sample.py`: set *model_path* and *classifier_path* to the trained backbone model and the trained classifier, respectively. Also set *sample_dir* to your local path. The condition is set in the main() function: its *cell_type* parameter specifies the cell type you want to generate. Running the file will generate new latent embeddings under the given conditions.
65 |
66 | For example:
67 | `python classifier_sample.py --model_path 'output/checkpoint/backbone/my_diffusion/model600000.pt' --classifier_path 'output/checkpoint/classifier/my_classifier/model200000.pt' --sample_dir 'output/simulated_samples/muris' --num_samples 3000 --batch_size 1000`
68 |
69 | You can decode these embeddings the same way as in unconditional generation.
70 |
71 | For multi-conditional generation and gradient interpolation, refer to the comments in the main() and create_argparser() functions (see the comments marked with ***).
72 |
73 | **Reproducing the experiments:**
74 |
75 | The scripts in the exp_script/ directory can be used to reproduce the results presented in the paper. You can follow the process in any of these scripts to rebuild gene expression from the latent space. `exp_script/down_stream_analysis_muris.ipynb` reproduces the marker gene results. `exp_script/script_diffusion_multi-condi.ipynb` reproduces the two-condition generation results. `exp_script/script_diffusion_interpolation.ipynb` reproduces the Gradient Interpolation results. `exp_script/script_diffusion_umap.ipynb` reproduces the UMAPs shown in the paper. `exp_script/script_static_eval.ipynb` reproduces the statistical metrics reported in the paper.
76 |
--------------------------------------------------------------------------------
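The README says the sampling scripts write latent embeddings to a .npz file, and both `save_data` helpers store them under the key `cell_gen` via `np.savez`. The following is a minimal sketch of that round trip, assuming NumPy is installed; the temporary directory and the random array are stand-ins for real sampler output. Note that `np.savez` appends `.npz` when the given path has no such extension, so `--sample_dir output/simulated_samples/muris` ends up as `output/simulated_samples/muris.npz`.

```python
import os
import tempfile

import numpy as np

# Stand-in for sampler output: 6 cells x 128 latent dimensions (illustrative only).
fake_latents = np.random.default_rng(0).normal(size=(6, 128)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    sample_dir = os.path.join(tmp, "muris")       # no extension, as in the README example
    np.savez(sample_dir, cell_gen=fake_latents)    # same call shape as save_data()
    saved_path = sample_dir + ".npz"               # np.savez appended the extension
    file_exists = os.path.exists(saved_path)
    with np.load(saved_path) as npz:
        cell_gen = npz["cell_gen"]                 # array is read into memory here

print(file_exists, cell_gen.shape)
```

The loaded `cell_gen` array is what the notebooks decode; in this repository that corresponds to calling the trained `VAE` with `return_decoded=True` on the latents converted to a tensor.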
/VAE/VAE_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | import os
6 | import anndata as ad
7 | import scanpy as sc
8 | from typing import List
9 |
10 |
11 | class Encoder(nn.Module):
12 |     """A class that encapsulates the encoder."""
13 |     def __init__(
14 |         self,
15 |         n_genes: int,
16 |         latent_dim: int = 128,
17 |         hidden_dim: List[int] = [1024, 1024],
18 |         dropout: float = 0.5,
19 |         input_dropout: float = 0.4,
20 |         residual: bool = False,
21 |     ):
22 |         """Constructor.
23 |
24 |         Parameters
25 |         ----------
26 |         n_genes: int
27 |             The number of genes in the gene space, representing the input dimensions.
28 |         latent_dim: int, default: 128
29 |             The latent space dimensions
30 |         hidden_dim: List[int], default: [1024, 1024]
31 |             A list of hidden layer dimensions, describing the number of layers and their dimensions.
32 |             Hidden layers are constructed in the order of the list for the encoder and in reverse
33 |             for the decoder.
34 |         dropout: float, default: 0.5
35 |             The dropout rate for hidden layers
36 |         input_dropout: float, default: 0.4
37 |             The dropout rate for the input layer
38 |         residual: bool, default: False
39 |             Use residual connections.
40 |         """
41 |         super().__init__()
42 |         self.latent_dim = latent_dim
43 |         self.network = nn.ModuleList()
44 |         self.residual = residual
45 |         if self.residual:
46 |             assert len(set(hidden_dim)) == 1
47 |         for i in range(len(hidden_dim)):
48 |             if i == 0:  # input layer
49 |                 self.network.append(
50 |                     nn.Sequential(
51 |                         nn.Dropout(p=input_dropout),
52 |                         nn.Linear(n_genes, hidden_dim[i]),
53 |                         nn.BatchNorm1d(hidden_dim[i]),
54 |                         nn.PReLU(),
55 |                     )
56 |                 )
57 |             else:  # hidden layers
58 |                 self.network.append(
59 |                     nn.Sequential(
60 |                         nn.Dropout(p=dropout),
61 |                         nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
62 |                         nn.BatchNorm1d(hidden_dim[i]),
63 |                         nn.PReLU(),
64 |                     )
65 |                 )
66 |         # output layer
67 |         self.network.append(nn.Linear(hidden_dim[-1], latent_dim))
68 |
69 |     def forward(self, x) -> torch.Tensor:
70 |         for i, layer in enumerate(self.network):
71 |             if self.residual and (0 < i < len(self.network) - 1):
72 |                 x = layer(x) + x
73 |             else:
74 |                 x = layer(x)
75 |         return F.normalize(x, p=2, dim=1)
76 |
77 |     def save_state(self, filename: str):
78 |         """Save state dictionary.
79 |
80 |         Parameters
81 |         ----------
82 |         filename: str
83 |             Filename to save the state dictionary.
84 |         """
85 |         torch.save({"state_dict": self.state_dict()}, filename)
86 |
87 |     def load_state(self, filename: str, use_gpu: bool = False):
88 |         """Load model state.
89 |
90 |         Parameters
91 |         ----------
92 |         filename: str
93 |             Filename containing the model state.
94 |         use_gpu: bool
95 |             Boolean indicating whether or not to use GPUs.
96 |         """
97 |         if not use_gpu:
98 |             ckpt = torch.load(filename, map_location=torch.device("cpu"))
99 |         else:
100 |             ckpt = torch.load(filename)
101 |         state_dict = ckpt['state_dict']
102 |         first_layer_key = ['network.0.1.weight',
103 |             'network.0.1.bias',
104 |             'network.0.2.weight',
105 |             'network.0.2.bias',
106 |             'network.0.2.running_mean',
107 |             'network.0.2.running_var',
108 |             'network.0.2.num_batches_tracked',
109 |             'network.0.3.weight',]
110 |         for key in first_layer_key:
111 |             if key in state_dict:
112 |                 del state_dict[key]
113 |         self.load_state_dict(state_dict, strict=False)
114 |
115 |
116 | class Decoder(nn.Module):
117 |     """A class that encapsulates the decoder."""
118 |
119 |     def __init__(
120 |         self,
121 |         n_genes: int,
122 |         latent_dim: int = 128,
123 |         hidden_dim: List[int] = [1024, 1024],
124 |         dropout: float = 0.5,
125 |         residual: bool = False,
126 |     ):
127 |         """Constructor.
128 |
129 |         Parameters
130 |         ----------
131 |         n_genes: int
132 |             The number of genes in the gene space, representing the output dimensions.
133 |         latent_dim: int, default: 128
134 |             The latent space dimensions
135 |         hidden_dim: List[int], default: [1024, 1024]
136 |             A list of hidden layer dimensions, describing the number of layers and their dimensions.
137 |             Hidden layers are constructed in the order of the list for the encoder and in reverse
138 |             for the decoder.
139 |         dropout: float, default: 0.5
140 |             The dropout rate for hidden layers
141 |         residual: bool, default: False
142 |             Use residual connections.
143 |         """
144 |         super().__init__()
145 |         self.latent_dim = latent_dim
146 |         self.network = nn.ModuleList()
147 |         self.residual = residual
148 |         if self.residual:
149 |             assert len(set(hidden_dim)) == 1
150 |         for i in range(len(hidden_dim)):
151 |             if i == 0:  # first hidden layer
152 |                 self.network.append(
153 |                     nn.Sequential(
154 |                         nn.Linear(latent_dim, hidden_dim[i]),
155 |                         nn.BatchNorm1d(hidden_dim[i]),
156 |                         nn.PReLU(),
157 |                     )
158 |                 )
159 |             else:  # other hidden layers
160 |                 self.network.append(
161 |                     nn.Sequential(
162 |                         nn.Dropout(p=dropout),
163 |                         nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
164 |                         nn.BatchNorm1d(hidden_dim[i]),
165 |                         nn.PReLU(),
166 |                     )
167 |                 )
168 |         # reconstruction layer
169 |         self.network.append(nn.Linear(hidden_dim[-1], n_genes))
170 |
171 |     def forward(self, x):
172 |         for i, layer in enumerate(self.network):
173 |             if self.residual and (0 < i < len(self.network) - 1):
174 |                 x = layer(x) + x
175 |             else:
176 |                 x = layer(x)
177 |         return x
178 |
179 |     def save_state(self, filename: str):
180 |         """Save state dictionary.
181 |
182 |         Parameters
183 |         ----------
184 |         filename: str
185 |             Filename to save the state dictionary.
186 |         """
187 |         torch.save({"state_dict": self.state_dict()}, filename)
188 |
189 |     def load_state(self, filename: str, use_gpu: bool = False):
190 |         """Load model state.
191 |
192 |         Parameters
193 |         ----------
194 |         filename: str
195 |             Filename containing the model state.
196 |         use_gpu: bool
197 |             Boolean indicating whether to use GPUs.
198 |         """
199 |         if not use_gpu:
200 |             ckpt = torch.load(filename, map_location=torch.device("cpu"))
201 |         else:
202 |             ckpt = torch.load(filename)
203 |         state_dict = ckpt['state_dict']
204 |         last_layer_key = ['network.3.weight',
205 |             'network.3.bias',]
206 |         for key in last_layer_key:
207 |             if key in state_dict:
208 |                 del state_dict[key]
209 |         self.load_state_dict(state_dict, strict=False)
210 |         # self.load_state_dict(ckpt["state_dict"])
211 |
212 | class VAE(torch.nn.Module):
213 |     """
214 |     VAE based on the compositional perturbation autoencoder (CPA).
215 |     """
216 |     def __init__(
217 |         self,
218 |         num_genes,
219 |         device="cuda",
220 |         seed=0,
221 |         loss_ae="gauss",
222 |         decoder_activation="linear",
223 |         hidden_dim=128,
224 |     ):
225 |         super(VAE, self).__init__()
226 |         # set generic attributes
227 |         self.num_genes = num_genes
228 |         self.device = device
229 |         self.seed = seed
230 |         self.loss_ae = loss_ae
231 |         # early-stopping
232 |         self.best_score = -1e3
233 |         self.patience_trials = 0
234 |
235 |         # set hyperparameters
236 |         self.set_hparams_(hidden_dim)
237 |
238 |         # set models
239 |         self.hidden_dim = [1024,1024,1024]
240 |         self.dropout = 0.0
241 |         self.input_dropout = 0.0
242 |         self.residual = False
243 |         self.encoder = Encoder(
244 |             self.num_genes,
245 |             latent_dim=self.hparams["dim"],
246 |             hidden_dim=self.hidden_dim,
247 |             dropout=self.dropout,
248 |             input_dropout=self.input_dropout,
249 |             residual=self.residual,
250 |         )
251 |         self.decoder = Decoder(
252 |             self.num_genes,
253 |             latent_dim=self.hparams["dim"],
254 |             hidden_dim=list(reversed(self.hidden_dim)),
255 |             dropout=self.dropout,
256 |             residual=self.residual,
257 |         )
258 |
259 |         # losses
260 |         self.loss_autoencoder = nn.MSELoss(reduction='mean')
261 |
262 |         self.iteration = 0
263 |
264 |         self.to(self.device)
265 |
266 |         # optimizers
267 |         get_params = lambda model, cond: list(model.parameters()) if cond else []
268 |         _parameters = (
269 |             get_params(self.encoder, True)
270 |             + get_params(self.decoder, True)
271 |         )
272 |         self.optimizer_autoencoder = torch.optim.AdamW(_parameters, lr=self.hparams["autoencoder_lr"], weight_decay=self.hparams["autoencoder_wd"],)
273 |
274 |
275 |     def forward(self, genes, return_latent=False, return_decoded=False):
276 |         """
277 |         If return_latent=True, act as the encoder only. If return_decoded=True, genes should
278 |         be the latent representation and the model acts as the decoder only.
279 |         """
280 |         if return_decoded:
281 |             gene_reconstructions = self.decoder(genes)
282 |             gene_reconstructions = nn.ReLU()(gene_reconstructions)  # ReLU is applied only at inference
283 |             return gene_reconstructions
284 |
285 |         latent_basal = self.encoder(genes)
286 |         if return_latent:
287 |             return latent_basal
288 |
289 |         gene_reconstructions = self.decoder(latent_basal)
290 |
291 |         return gene_reconstructions
292 |
293 |
294 |
295 |     def set_hparams_(self, hidden_dim):
296 |         """
297 |         Set hyper-parameters to default values or values fixed by user.
298 |         """
299 |
300 |         self.hparams = {
301 |             "dim": hidden_dim,
302 |             "autoencoder_width": 5000,
303 |             "autoencoder_depth": 3,
304 |             "adversary_lr": 3e-4,
305 |             "autoencoder_wd": 0.01,
306 |             "autoencoder_lr": 5e-4,
307 |         }
308 |
309 |         return self.hparams
310 |
311 |
312 |     def train_step(self, genes):
313 |         """
314 |         Run one optimization step on a batch of gene expression vectors.
315 |         """
316 |         genes = genes.to(self.device)
317 |         gene_reconstructions = self.forward(genes)
318 |
319 |         reconstruction_loss = self.loss_autoencoder(gene_reconstructions, genes)
320 |
321 |         self.optimizer_autoencoder.zero_grad()
322 |         reconstruction_loss.backward()
323 |         self.optimizer_autoencoder.step()
324 |
325 |         self.iteration += 1
326 |
327 |         return {
328 |             "loss_reconstruction": reconstruction_loss.item(),
329 |         }
330 |
--------------------------------------------------------------------------------
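`Encoder.load_state` and `Decoder.load_state` delete the checkpoint entries whose shapes depend on `n_genes` and then call `load_state_dict(..., strict=False)`, so those layers keep their fresh initialization while the rest is taken from the pretrained weights. The idea can be sketched with plain dictionaries; `drop_keys` and the toy checkpoint below are hypothetical and stand in for real tensors.

```python
def drop_keys(state_dict, keys_to_drop):
    """Return a filtered copy of state_dict and the list of keys actually removed."""
    removed = [key for key in keys_to_drop if key in state_dict]
    kept = {key: value for key, value in state_dict.items() if key not in keys_to_drop}
    return kept, removed


# Toy checkpoint: the string values are placeholders for tensors.
pretrained = {
    "network.0.1.weight": "W0",
    "network.0.1.bias": "b0",
    "network.1.1.weight": "W1",
    "network.1.1.bias": "b1",
}

# The last requested key does not exist in the checkpoint and is silently skipped.
kept, removed = drop_keys(
    pretrained,
    ["network.0.1.weight", "network.0.1.bias", "network.0.9.weight"],
)
print(sorted(kept))
print(removed)
```

Because a non-matching key is skipped without any error, every name in the drop list has to match the checkpoint key exactly for the corresponding parameter to be excluded.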
/VAE/VAE_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from VAE_model import VAE
8 | import sys
9 | sys.path.append("..")
10 | # from guided_diffusion.cell_datasets import load_data
11 | # from guided_diffusion.cell_datasets_sapiens import load_data
12 | # from guided_diffusion.cell_datasets_WOT import load_data
13 | # from guided_diffusion.cell_datasets_muris import load_data
14 | from guided_diffusion.cell_datasets_loader import load_data
15 |
16 | torch.autograd.set_detect_anomaly(True)
17 | import random
18 |
19 | def seed_everything(seed):
20 |     random.seed(seed)
21 |     np.random.seed(seed)
22 |     torch.manual_seed(seed)
23 |     torch.cuda.manual_seed_all(seed)
24 |     torch.backends.cudnn.deterministic = True
25 |
26 |
27 | def prepare_vae(args, state_dict=None):
28 |     """
29 |     Instantiates autoencoder and dataset to run an experiment.
30 |     """
31 |
32 |     device = "cuda" if torch.cuda.is_available() else "cpu"
33 |
34 |     datasets = load_data(
35 |         data_dir=args["data_dir"],
36 |         batch_size=args["batch_size"],
37 |         train_vae=True,
38 |     )
39 |
40 |     autoencoder = VAE(
41 |         num_genes=args["num_genes"],
42 |         device=device,
43 |         seed=args["seed"],
44 |         loss_ae=args["loss_ae"],
45 |         hidden_dim=128,
46 |         decoder_activation=args["decoder_activation"],
47 |     )
48 |     if state_dict is not None:
49 |         print('loading pretrained model from: \n',state_dict)
50 |         use_gpu = device == "cuda"
51 |         autoencoder.encoder.load_state(state_dict["encoder"], use_gpu)
52 |         autoencoder.decoder.load_state(state_dict["decoder"], use_gpu)
53 |
54 |     return autoencoder, datasets
55 |
56 |
57 | def train_vae(args, return_model=False):
58 |     """
59 |     Train an autoencoder.
60 |     """
61 |     if args["state_dict"] is not None:
62 |         filenames = {}
63 |         checkpoint_path = {
64 |             "encoder": os.path.join(
65 |                 args["state_dict"], filenames.get("model", "encoder.ckpt")
66 |             ),
67 |             "decoder": os.path.join(
68 |                 args["state_dict"], filenames.get("model", "decoder.ckpt")
69 |             ),
70 |             "gene_order": os.path.join(
71 |                 args["state_dict"], filenames.get("gene_order", "gene_order.tsv")
72 |             ),
73 |         }
74 |         autoencoder, datasets = prepare_vae(args, checkpoint_path)
75 |     else:
76 |         autoencoder, datasets = prepare_vae(args)
77 |
78 |     args["hparams"] = autoencoder.hparams
79 |
80 |     start_time = time.time()
81 |     for step in range(args["max_steps"]):
82 |
83 |         genes, _ = next(datasets)
84 |
85 |         minibatch_training_stats = autoencoder.train_step(genes)
86 |
87 |         if step % 1000 == 0:
88 |             for key, val in minibatch_training_stats.items():
89 |                 print('step ', step, 'loss ', val)
90 |
91 |         elapsed_minutes = (time.time() - start_time) / 60
92 |
93 |         stop = elapsed_minutes > args["max_minutes"] or (
94 |             step == args["max_steps"] - 1
95 |         )
96 |
97 |         if ((step % args["checkpoint_freq"]) == 0 or stop):
98 |
99 |             os.makedirs(args["save_dir"],exist_ok=True)
100 |             torch.save(
101 |                 autoencoder.state_dict(),
102 |                 os.path.join(
103 |                     args["save_dir"],
104 |                     "model_seed={}_step={}.pt".format(args["seed"], step),
105 |                 ),
106 |             )
107 |
108 |         if stop:
109 |             break
110 |
111 |     if return_model:
112 |         return autoencoder, datasets
113 |
114 |
115 | def parse_arguments():
116 |     """
117 |     Read arguments if this script is called from a terminal.
118 |     """
119 |
120 |     parser = argparse.ArgumentParser(description="Finetune Scimilarity")
121 |     # dataset arguments
122 |     parser.add_argument("--data_dir", type=str, default='/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad')
123 |     parser.add_argument("--loss_ae", type=str, default="mse")
124 |     parser.add_argument("--decoder_activation", type=str, default="ReLU")
125 |
126 |     # AE arguments
127 |     parser.add_argument("--local_rank", type=int, default=0)
128 |     parser.add_argument("--split_seed", type=int, default=1234)
129 |     parser.add_argument("--num_genes", type=int, default=18996)
130 |     parser.add_argument("--seed", type=int, default=0)
131 |     parser.add_argument("--hparams", type=str, default="")
132 |
133 |     # training arguments
134 |     parser.add_argument("--max_steps", type=int, default=200000)
135 |     parser.add_argument("--max_minutes", type=int, default=3000)
136 |     parser.add_argument("--checkpoint_freq", type=int, default=50000)
137 |     parser.add_argument("--batch_size", type=int, default=128)
138 |     parser.add_argument("--state_dict", type=str, default="/data1/lep/Workspace/guided-diffusion/scimilarity-main/models/annotation_model_v1")  # if pretrain
139 |     # parser.add_argument("--state_dict", type=str, default=None)  # if not pretrain
140 |
141 |     parser.add_argument("--save_dir", type=str, default='../output/ae_checkpoint/muris_AE')
142 |     parser.add_argument("--sweep_seeds", type=int, default=200)
143 |     return dict(vars(parser.parse_args()))
144 |
145 |
146 | if __name__ == "__main__":
147 |     seed_everything(1234)
148 |     train_vae(parse_arguments())
149 |
--------------------------------------------------------------------------------
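The checkpoint rule in `train_vae` (save when `step % checkpoint_freq == 0` or on the final step) determines which `model_seed=..._step=....pt` files exist after training, which matters when later steps need a `vae_path`. The sketch below replays only that rule; it deliberately ignores the `max_minutes` time limit, and `saved_checkpoint_steps` is a hypothetical helper, not part of the repository.

```python
def saved_checkpoint_steps(max_steps, checkpoint_freq):
    """Replay the save condition of the training loop and return the saved steps."""
    steps = []
    for step in range(max_steps):
        stop = step == max_steps - 1            # time-based stopping is ignored here
        if step % checkpoint_freq == 0 or stop:
            steps.append(step)
        if stop:
            break
    return steps


# Defaults from parse_arguments(): max_steps=200000, checkpoint_freq=50000, seed=0.
steps = saved_checkpoint_steps(200000, 50000)
filenames = ["model_seed={}_step={}.pt".format(0, step) for step in steps]
print(steps)        # [0, 50000, 100000, 150000, 199999]
print(filenames[3])  # model_seed=0_step=150000.pt
```

With the defaults, the final checkpoint is named after step 199999 rather than 200000, and the `step=150000` file used in the README examples is one of the intermediate saves.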
/VAE/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EperLuo/scDiffusion/d34ef8e560b47159d4500cf4411a7a34e5a12a32/VAE/__init__.py
--------------------------------------------------------------------------------
/cell_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of cell latent embeddings from a trained diffusion model
3 | and save them as a single NumPy .npz archive (key: cell_gen).
4 | """
5 | import argparse
6 |
7 | import numpy as np
8 | import torch as th
9 | import torch.distributed as dist
10 | import random
11 |
12 | from guided_diffusion import dist_util, logger
13 | from guided_diffusion.script_util import (
14 |     NUM_CLASSES,
15 |     model_and_diffusion_defaults,
16 |     create_model_and_diffusion,
17 |     add_dict_to_argparser,
18 |     args_to_dict,
19 | )
20 |
21 |
22 | def save_data(all_cells, traj, data_dir):
23 |     cell_gen = all_cells
24 |     np.savez(data_dir, cell_gen=cell_gen)
25 |     return
26 |
27 | def main():
28 |     setup_seed(1234)
29 |     args = create_argparser().parse_args()
30 |
31 |     dist_util.setup_dist()
32 |     logger.configure(dir='output/checkpoint/sample_logs')
33 |
34 |     logger.log("creating model and diffusion...")
35 |     model, diffusion = create_model_and_diffusion(
36 |         **args_to_dict(args, model_and_diffusion_defaults().keys())
37 |     )
38 |     model.load_state_dict(
39 |         dist_util.load_state_dict(args.model_path, map_location="cpu")
40 |     )
41 |     model.to(dist_util.dev())
42 |     model.eval()
43 |
44 |     logger.log("sampling...")
45 |     all_cells = []
46 |     while len(all_cells) * args.batch_size < args.num_samples:
47 |         model_kwargs = {}
48 |         sample_fn = (
49 |             diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
50 |         )
51 |         sample, traj = sample_fn(
52 |             model,
53 |             (args.batch_size, args.input_dim),
54 |             clip_denoised=args.clip_denoised,
55 |             model_kwargs=model_kwargs,
56 |             start_time=diffusion.betas.shape[0],
57 |         )
58 |
59 |         gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
60 |         dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
61 |         all_cells.extend([sample.cpu().numpy() for sample in gathered_samples])
62 |         logger.log(f"created {len(all_cells) * args.batch_size} samples")
63 |
64 |     arr = np.concatenate(all_cells, axis=0)
65 |     save_data(arr, traj, args.sample_dir)
66 |
67 |     dist.barrier()
68 |     logger.log("sampling complete")
69 |
70 |
71 | def create_argparser():
72 |     defaults = dict(
73 |         clip_denoised=False,
74 |         num_samples=12000,
75 |         batch_size=3000,
76 |         use_ddim=False,
77 |         model_path="output/checkpoint/backbone/open_problem/model800000.pt",
78 |         sample_dir="output/simulated_samples/open_problem"
79 |     )
80 |     defaults.update(model_and_diffusion_defaults())
81 |     parser = argparse.ArgumentParser()
82 |     add_dict_to_argparser(parser, defaults)
83 |     return parser
84 |
85 | def setup_seed(seed):
86 |     th.manual_seed(seed)
87 |     th.cuda.manual_seed_all(seed)
88 |     np.random.seed(seed)
89 |     random.seed(seed)
90 |     th.backends.cudnn.deterministic = True  # set the random seed for reproducibility
91 |
92 |
93 | if __name__ == "__main__":
94 |     main()
95 |
--------------------------------------------------------------------------------
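The `while` loop in `cell_sample.py` keeps sampling until `len(all_cells) * batch_size` reaches `num_samples`, and each pass appends one batch per distributed process. The following is a small sketch of that bookkeeping only (no model involved); `simulate_sampling` is a hypothetical helper that mirrors the loop condition shown in the listing.

```python
def simulate_sampling(num_samples, batch_size, world_size=1):
    """Return (loop iterations, total cells produced) for the sampling loop."""
    chunks = 0        # plays the role of len(all_cells)
    iterations = 0
    while chunks * batch_size < num_samples:
        chunks += world_size   # all_gather contributes one batch per process
        iterations += 1
    return iterations, chunks * batch_size


print(simulate_sampling(3000, 1000))                 # (3, 3000)
print(simulate_sampling(2500, 1000))                 # (3, 3000)
print(simulate_sampling(3000, 1000, world_size=2))   # (2, 4000)
```

Since the script concatenates every gathered batch without truncating, the saved array can contain more than `num_samples` cells whenever `num_samples` is not a multiple of `batch_size * world_size`.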
/cell_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on cell latent embeddings produced by the autoencoder.
3 | """
4 |
5 | import argparse
6 |
7 | from guided_diffusion import dist_util, logger
8 | from guided_diffusion.cell_datasets_loader import load_data
9 | from guided_diffusion.resample import create_named_schedule_sampler
10 | from guided_diffusion.script_util import (
11 |     model_and_diffusion_defaults,
12 |     create_model_and_diffusion,
13 |     args_to_dict,
14 |     add_dict_to_argparser,
15 | )
16 | from guided_diffusion.train_util import TrainLoop
17 |
18 | import torch
19 | import numpy as np
20 | import random
21 |
22 | def main():
23 |     setup_seed(1234)
24 |     args = create_argparser().parse_args()
25 |
26 |     dist_util.setup_dist()
27 |     logger.configure(dir='../output/logs/'+args.model_name)  # log file
28 |
29 |     logger.log("creating model and diffusion...")
30 |     model, diffusion = create_model_and_diffusion(
31 |         **args_to_dict(args, model_and_diffusion_defaults().keys())
32 |     )
33 |     model.to(dist_util.dev())
34 |     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
35 |
36 |     logger.log("creating data loader...")
37 |     data = load_data(
38 |         data_dir=args.data_dir,
39 |         batch_size=args.batch_size,
40 |         vae_path=args.vae_path,
41 |         train_vae=False,
42 |     )
43 |
44 |     logger.log("training...")
45 |     TrainLoop(
46 |         model=model,
47 |         diffusion=diffusion,
48 |         data=data,
49 |         batch_size=args.batch_size,
50 |         microbatch=args.microbatch,
51 |         lr=args.lr,
52 |         ema_rate=args.ema_rate,
53 |         log_interval=args.log_interval,
54 |         save_interval=args.save_interval,
55 |         resume_checkpoint=args.resume_checkpoint,
56 |         use_fp16=args.use_fp16,
57 |         fp16_scale_growth=args.fp16_scale_growth,
58 |         schedule_sampler=schedule_sampler,
59 |         weight_decay=args.weight_decay,
60 |         lr_anneal_steps=args.lr_anneal_steps,
61 |         model_name=args.model_name,
62 |         save_dir=args.save_dir
63 |     ).run_loop()
64 |
65 |
66 | def create_argparser():
67 |     defaults = dict(
68 |         data_dir="/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad",
69 |         schedule_sampler="uniform",
70 |         lr=1e-4,
71 |         weight_decay=0.0001,
72 |         lr_anneal_steps=500000,
73 |         batch_size=128,
74 |         microbatch=-1,  # -1 disables microbatches
75 |         ema_rate="0.9999",  # comma-separated list of EMA values
76 |         log_interval=100,
77 |         save_interval=200000,
78 |         resume_checkpoint="",
79 |         use_fp16=False,
80 |         fp16_scale_growth=1e-3,
81 |         vae_path = 'output/Autoencoder_checkpoint/muris_AE/model_seed=0_step=0.pt',
82 |         model_name="muris_diffusion",
83 |         save_dir='output/diffusion_checkpoint'
84 |     )
85 |     defaults.update(model_and_diffusion_defaults())
86 |     parser = argparse.ArgumentParser()
87 |     add_dict_to_argparser(parser, defaults)
88 |     return parser
89 |
90 |
91 | def setup_seed(seed):
92 |     torch.manual_seed(seed)
93 |     torch.cuda.manual_seed_all(seed)
94 |     np.random.seed(seed)
95 |     random.seed(seed)
96 |     torch.backends.cudnn.deterministic = True
97 |
98 | if __name__ == "__main__":
99 |     main()
100 |
--------------------------------------------------------------------------------
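`create_argparser` builds a dictionary of defaults, merges in `model_and_diffusion_defaults()`, and hands the result to `add_dict_to_argparser`. That helper lives in `guided_diffusion/script_util.py`, which is not shown here, so the sketch below only assumes the usual pattern of registering one `--key` flag per entry with a type inferred from the default value. `add_defaults_to_parser`, `str2bool`, and `shared_defaults` are hypothetical names, and the real helper may differ.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans such as 'true' or '0'."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def add_defaults_to_parser(parser, defaults):
    """Register one --key flag per dictionary entry, typed after its default."""
    for name, value in defaults.items():
        if isinstance(value, bool):
            arg_type = str2bool
        elif value is None:
            arg_type = str
        else:
            arg_type = type(value)
        parser.add_argument(f"--{name}", default=value, type=arg_type)


script_defaults = dict(lr=1e-4, batch_size=128, use_fp16=False, model_name="muris_diffusion")
shared_defaults = dict(batch_size=64)    # hypothetical overlapping key
script_defaults.update(shared_defaults)  # later update wins, as with defaults.update(...)

parser = argparse.ArgumentParser()
add_defaults_to_parser(parser, script_defaults)
args = parser.parse_args(["--lr", "3e-4", "--use_fp16", "true"])
print(args.lr, args.batch_size, args.use_fp16, args.model_name)
```

One consequence of the `update` order is that any key present in both dictionaries takes its default from the dictionary merged last, unless it is overridden on the command line.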
/celltypist_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import anndata as ad
3 | import scanpy as sc
4 | import celltypist
5 | from sklearn.model_selection import train_test_split
6 | from imblearn.over_sampling import RandomOverSampler
7 |
8 | def split_adata(adata, train_ratio=0.8, random_state=42):
9 |     indexes = np.arange(adata.shape[0])
10 |     train_indexes, test_indexes = train_test_split(indexes, train_size=train_ratio, random_state=random_state)
11 |
12 |     train_adata = adata[train_indexes].copy()
13 |     test_adata = adata[test_indexes].copy()
14 |
15 |     return train_adata, test_adata
16 |
17 | adata = sc.read_h5ad('/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad')
18 | sc.pp.filter_genes(adata, min_cells=3)
19 | sc.pp.filter_cells(adata, min_genes=10)
20 | adata.var_names_make_unique()
21 | sc.pp.normalize_total(adata, target_sum=1e4)
22 | sc.pp.log1p(adata)
23 |
24 | # rebalance
25 | adata, test_adata = split_adata(adata, train_ratio=0.8, random_state=42)
26 | celltype = adata.obs['celltype'].values
27 | ros = RandomOverSampler(random_state=42)
28 | X_resampled, y_resampled = ros.fit_resample(adata.X, celltype)
29 | adata_resampled = ad.AnnData(X_resampled[:80000])
30 | adata_resampled.var_names = adata.var_names
31 | print(adata_resampled)
32 | # if you want to save the testset
33 | test_adata.write_h5ad('data/testset_muris_all.h5ad')
34 |
35 | new_model = celltypist.train(adata_resampled, labels = y_resampled[:80000], n_jobs=32)
36 |
37 | new_model.write('output/checkpoint/celltypist/my_celltypist.pkl')
38 |
--------------------------------------------------------------------------------
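`celltypist_train.py` rebalances the training split with `RandomOverSampler` before fitting CellTypist. The idea behind random oversampling can be illustrated with the standard library alone: minority classes are topped up with randomly repeated members until every class matches the largest one. This is a simplified sketch of the technique, not imbalanced-learn's implementation; `random_oversample` and the toy labels are hypothetical.

```python
import random
from collections import Counter


def random_oversample(samples, labels, seed=42):
    """Duplicate random members of smaller classes until all classes are equal in size."""
    rng = random.Random(seed)
    by_label = {}
    for sample, label in zip(samples, labels):
        by_label.setdefault(label, []).append(sample)
    target = max(len(members) for members in by_label.values())

    out_samples, out_labels = list(samples), list(labels)
    for label, members in by_label.items():
        for _ in range(target - len(members)):
            out_samples.append(rng.choice(members))  # sample with replacement
            out_labels.append(label)
    return out_samples, out_labels


cells = ["c0", "c1", "c2", "c3", "c4", "c5"]
types = ["T", "T", "T", "T", "B", "NK"]
X_res, y_res = random_oversample(cells, types)
counts = Counter(y_res)
print(len(X_res), counts["T"], counts["B"], counts["NK"])
```

The script then keeps only the first 80000 resampled rows; whether that subset is still balanced depends on the row order returned by the resampler, which is worth checking before training on a new dataset.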
/classifier_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Like cell_sample.py, but use a classifier trained on noisy latent embeddings to
3 | guide the sampling process towards the requested cell type(s).
4 | """
5 |
6 | import argparse
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 | import torch.nn.functional as F
12 |
13 | from guided_diffusion import dist_util, logger
14 | from guided_diffusion.script_util import (
15 |     NUM_CLASSES,
16 |     model_and_diffusion_defaults,
17 |     classifier_and_diffusion_defaults,
18 |     create_model_and_diffusion,
19 |     create_classifier,
20 |     add_dict_to_argparser,
21 |     args_to_dict,
22 | )
23 | import scanpy as sc
24 | import torch
25 | from VAE.VAE_model import VAE
26 |
27 | def load_VAE(ae_dir, num_gene):
28 |     autoencoder = VAE(
29 |         num_genes=num_gene,
30 |         device='cuda',
31 |         seed=0,
32 |         hidden_dim=128,
33 |         decoder_activation='ReLU',
34 |     )
35 |     autoencoder.load_state_dict(torch.load(ae_dir))
36 |     return autoencoder
37 |
38 | def save_data(all_cells, traj, data_dir):
39 |     cell_gen = all_cells
40 |     np.savez(data_dir, cell_gen=cell_gen)
41 |     return
42 |
43 | def main(cell_type=[0], multi=False, inter=False, weight=[10,10]):
44 |     args = create_argparser(cell_type, weight).parse_args()
45 |
46 |     dist_util.setup_dist()
47 |     logger.configure()
48 |
49 |     logger.log("creating model and diffusion...")
50 |     model, diffusion = create_model_and_diffusion(
51 |         **args_to_dict(args, model_and_diffusion_defaults().keys())
52 |     )
53 |     model.load_state_dict(
54 |         dist_util.load_state_dict(args.model_path, map_location="cpu")
55 |     )
56 |     model.to(dist_util.dev())
57 |     model.eval()
58 |
59 |     logger.log("loading classifier...")
60 |     if multi:
61 |         args.num_class = args.num_class1  # how many classes in this condition
62 |         classifier1 = create_classifier(**args_to_dict(args, (['num_class']+list(classifier_and_diffusion_defaults().keys()))[:3]))
63 |         classifier1.load_state_dict(
64 |             dist_util.load_state_dict(args.classifier_path1, map_location="cpu")
65 |         )
66 |         classifier1.to(dist_util.dev())
67 |         classifier1.eval()
68 |
69 |         args.num_class = args.num_class2  # how many classes in this condition
70 |         classifier2 = create_classifier(**args_to_dict(args, (['num_class']+list(classifier_and_diffusion_defaults().keys()))[:3]))
71 |         classifier2.load_state_dict(
72 |             dist_util.load_state_dict(args.classifier_path2, map_location="cpu")
73 |         )
74 |         classifier2.to(dist_util.dev())
75 |         classifier2.eval()
76 |
77 |     else:
78 |         classifier = create_classifier(**args_to_dict(args, (['num_class']+list(classifier_and_diffusion_defaults().keys()))[:3]))
79 |         classifier.load_state_dict(
80 | dist_util.load_state_dict(args.classifier_path, map_location="cpu")
81 | )
82 | classifier.to(dist_util.dev())
83 | classifier.eval()
84 |
85 | '''
86 | control function for Gradient Interpolation Strategy
87 | '''
88 | def cond_fn_inter(x, t, y=None, init=None, diffusion=None):
89 | assert y is not None
90 | y1 = y[:,0]
91 | y2 = y[:,1]
92 | # xt = diffusion.q_sample(th.tensor(init,device=dist_util.dev()),t*th.ones(init.shape[0],device=dist_util.dev(),dtype=torch.long),)
93 | with th.enable_grad():
94 | x_in = x.detach().requires_grad_(True)
95 | logits = classifier(x_in, t)
96 | log_probs = F.log_softmax(logits, dim=-1)
97 | selected1 = log_probs[range(len(logits)), y1.view(-1)]
98 | selected2 = log_probs[range(len(logits)), y2.view(-1)]
99 |
100 | grad1 = th.autograd.grad(selected1.sum(), x_in, retain_graph=True)[0] * args.classifier_scale1
101 | grad2 = th.autograd.grad(selected2.sum(), x_in, retain_graph=True)[0] * args.classifier_scale2
102 |
103 | # l2_loss = ((x_in-xt)**2).mean()
104 | # grad3 = th.autograd.grad(-l2_loss, x_in, retain_graph=True)[0] * 100
105 |
106 | return grad1+grad2#+grad3
107 |
108 | '''
109 | control function for multi-conditional generation
110 | (two conditions are combined here)
111 | '''
112 | def cond_fn_multi(x, t, y=None):
113 | assert y is not None
114 | y1 = y[:,0]
115 | y2 = y[:,1]
116 | with th.enable_grad():
117 | x_in = x.detach().requires_grad_(True)
118 | logits1 = classifier1(x_in, t)
119 | log_probs1 = F.log_softmax(logits1, dim=-1)
120 | selected1 = log_probs1[range(len(logits1)), y1.view(-1)]
121 |
122 | logits2 = classifier2(x_in, t)
123 | log_probs2 = F.log_softmax(logits2, dim=-1)
124 | selected2 = log_probs2[range(len(logits2)), y2.view(-1)]
125 |
126 | grad1 = th.autograd.grad(selected1.sum(), x_in, retain_graph=True)[0] * args.classifier_scale1
127 | grad2 = th.autograd.grad(selected2.sum(), x_in, retain_graph=True)[0] * args.classifier_scale2
128 |
129 | return grad1+grad2
130 |
131 | '''
132 | control function for one conditional generation
133 | '''
134 | def cond_fn_ori(x, t, y=None):
135 | assert y is not None
136 | with th.enable_grad():
137 | x_in = x.detach().requires_grad_(True)
138 | logits = classifier(x_in, t)
139 | log_probs = F.log_softmax(logits, dim=-1)
140 | selected = log_probs[range(len(logits)), y.view(-1)]
141 | grad = th.autograd.grad(selected.sum(), x_in, retain_graph=True)[0] * args.classifier_scale
142 | return grad
143 |
144 | def model_fn(x, t, y=None, init=None, diffusion=None):
145 | assert y is not None
146 | if args.class_cond:
147 | return model(x, t, y if args.class_cond else None)
148 | else:
149 | return model(x, t)
150 |
151 | if inter:
152 | # load real cell expression data; it is later noised to init_time and used as the starting state
153 | ori_adata = sc.read_h5ad(args.init_cell_path)
154 | sc.pp.normalize_total(ori_adata, target_sum=1e4)
155 | sc.pp.log1p(ori_adata)
156 |
157 | logger.log("sampling...")
158 | all_cell = []
159 | sample_num = 0
160 | while sample_num < args.num_samples:
161 | model_kwargs = {}
162 |
163 | if not multi and not inter:
164 | classes = (cell_type[0])*th.ones((args.batch_size,), device=dist_util.dev(), dtype=th.long)
165 |
166 | if multi:
167 | classes1 = (cell_type[0])*th.ones((args.batch_size,), device=dist_util.dev(), dtype=th.long)
168 | classes2 = (cell_type[1])*th.ones((args.batch_size,), device=dist_util.dev(), dtype=th.long)
169 | # classes3 = ... if more conditions
170 | classes = th.stack((classes1,classes2), dim=1)
171 |
172 | if inter:
173 | classes1 = (cell_type[0])*th.ones((args.batch_size,), device=dist_util.dev(), dtype=th.long)
174 | classes2 = (cell_type[1])*th.ones((args.batch_size,), device=dist_util.dev(), dtype=th.long)
175 | classes = th.stack((classes1,classes2), dim=1)
176 |
177 | model_kwargs["y"] = classes
178 | sample_fn = (
179 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
180 | )
181 |
182 | if inter:
183 | celltype = ori_adata.obs['period'].cat.categories.tolist()[cell_type[0]]
184 | adata = ori_adata[ori_adata.obs['period']==celltype].copy()
185 |
186 | start_x = adata.X
187 | autoencoder = load_VAE(args.ae_dir, args.num_gene)
188 | start_x = autoencoder(torch.tensor(start_x,device=dist_util.dev()),return_latent=True).detach().cpu().numpy()
189 |
190 | n, m = start_x.shape
191 | if n >= args.batch_size:
192 | start_x = start_x[:args.batch_size, :]
193 | else:
194 | repeat_times = args.batch_size // n
195 | remainder = args.batch_size % n
196 | start_x = np.concatenate([start_x] * repeat_times + [start_x[:remainder, :]], axis=0)
197 |
198 | noise = diffusion.q_sample(th.tensor(start_x,device=dist_util.dev()),args.init_time*th.ones(start_x.shape[0],device=dist_util.dev(),dtype=torch.long),)
199 | model_kwargs["init"] = start_x
200 | model_kwargs["diffusion"] = diffusion
201 |
202 | if multi:
203 | sample, traj = sample_fn(
204 | model_fn,
205 | (args.batch_size, args.input_dim),
206 | clip_denoised=args.clip_denoised,
207 | model_kwargs=model_kwargs,
208 | cond_fn=cond_fn_multi,
209 | device=dist_util.dev(),
210 | noise = None,
211 | start_time=diffusion.betas.shape[0],
212 | start_guide_steps=args.start_guide_steps,
213 | )
214 | elif inter:
215 | sample, traj = sample_fn(
216 | model_fn,
217 | (args.batch_size, args.input_dim),
218 | clip_denoised=args.clip_denoised,
219 | model_kwargs=model_kwargs,
220 | cond_fn=cond_fn_inter,
221 | device=dist_util.dev(),
222 | noise = noise,
223 | start_time=diffusion.betas.shape[0],
224 | start_guide_steps=args.start_guide_steps,
225 | )
226 | else:
227 | sample, traj = sample_fn(
228 | model_fn,
229 | (args.batch_size, args.input_dim),
230 | clip_denoised=args.clip_denoised,
231 | model_kwargs=model_kwargs,
232 | cond_fn=cond_fn_ori,
233 | device=dist_util.dev(),
234 | noise = None,
235 | )
236 |
237 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
238 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
239 | if args.filter:
240 | for sample in gathered_samples:
241 | if multi:
242 | logits1 = classifier1(sample, torch.zeros((sample.shape[0]), device=sample.device))
243 | logits2 = classifier2(sample, torch.zeros((sample.shape[0]), device=sample.device))
244 | prob1 = F.softmax(logits1, dim=-1)
245 | prob2 = F.softmax(logits2, dim=-1)
246 | type1 = torch.argmax(prob1, 1)
247 | type2 = torch.argmax(prob2, 1)
248 | select_index = ((type1 == cell_type[0]) & (type2 == cell_type[1]))
249 | all_cell.extend([sample[select_index].cpu().numpy()])
250 | sample_num += select_index.sum().item()
251 | elif inter:
252 | logits = classifier(sample, torch.zeros((sample.shape[0]), device=sample.device))
253 | prob = F.softmax(logits, dim=-1)
254 | left = (prob[:,cell_type[0]] > weight[0]/10-0.15) & (prob[:,cell_type[0]] < weight[0]/10+0.15)
255 | right = (prob[:,cell_type[1]] > weight[1]/10-0.15) & (prob[:,cell_type[1]] < weight[1]/10+0.15)
256 | select_index = left & right
257 | all_cell.extend([sample[select_index].cpu().numpy()])
258 | sample_num += select_index.sum().item()
259 | else:
260 | logits = classifier(sample, torch.zeros((sample.shape[0]), device=sample.device))
261 | prob = F.softmax(logits, dim=-1)
262 | pred_type = torch.argmax(prob, 1)
263 | select_index = (pred_type == cell_type[0])
264 | all_cell.extend([sample[select_index].cpu().numpy()])
265 | sample_num += select_index.sum().item()
266 | logger.log(f"created {sample_num} samples")
267 | else:
268 | all_cell.extend([sample.cpu().numpy() for sample in gathered_samples])
269 | sample_num = len(all_cell) * args.batch_size
270 | logger.log(f"created {len(all_cell) * args.batch_size} samples")
271 |
272 | arr = np.concatenate(all_cell, axis=0)
273 | save_data(arr, traj, args.sample_dir+str(cell_type[0]))
274 |
275 | dist.barrier()
276 | logger.log("sampling complete")
277 |
278 |
279 | def create_argparser(celltype=[0], weight=[10,10]):
280 | defaults = dict(
281 | clip_denoised=True,
282 | num_samples=9000,
283 | batch_size=3000,
284 | use_ddim=False,
285 | class_cond=False,
286 |
287 | model_path="output/diffusion_checkpoint/muris_diffusion/model000000.pt",
288 |
289 | # ***for common conditional generation and gradient interpolation, use this path***
290 | classifier_path="output/classifier_checkpoint/classifier_muris/model000100.pt",
291 | # ***for multi-conditional generation, use these paths. replace them with your own classifiers***
292 | classifier_path1="output/classifier_checkpoint/classifier_muris_ood_type/model200000.pt",
293 | classifier_path2="output/classifier_checkpoint/classifier_muris_ood_organ/model200000.pt",
294 | num_class1 = 2, # set this to the number of classes in your own dataset. this is the first condition (for example cell organ).
295 | num_class2 = 2, # this is the second condition (for example cell type).
296 |
297 | # ***for common conditional generation, use this scale***
298 | classifier_scale=2,
299 | # ***for multi-conditional generation, use these scales. scale1 and scale2 are the weights of the two classifiers***
300 | # ***for Gradient Interpolation, use these scales, too. scale1 and scale2 are the weights of the two gradients***
301 | classifier_scale1=weight[0]*2/10,
302 | classifier_scale2=weight[1]*2/10,
303 |
304 | # ***for gradient interpolation, replace these based on your own setup***
305 | ae_dir='output/Autoencoder_checkpoint/WOT/model_seed=0_step=150000.pt',
306 | num_gene=19423,
307 | init_time = 600, # diffusion step to which the initial cells are noised for interpolation
308 | init_cell_path = 'data/WOT/filted_data.h5ad', # real cells used as the initial state before noising
309 |
310 | sample_dir="output/simulated_samples/muris",
311 | start_guide_steps = 500, # the time to use classifier guidance
312 | filter = False, # drop simulated cells that the classifier assigns to another condition; may take a long time
313 |
314 | )
315 | defaults.update(model_and_diffusion_defaults())
316 | defaults.update(classifier_and_diffusion_defaults())
317 | defaults['num_class']=12
318 | parser = argparse.ArgumentParser()
319 | add_dict_to_argparser(parser, defaults)
320 | return parser
321 |
322 |
323 | if __name__ == "__main__":
324 | # for conditional generation
325 | # main(cell_type=[2])
326 | for type_idx in range(12):
327 | main(cell_type=[type_idx])
328 |
329 | # ***for multi-condition, run***
330 | # muris ood
331 | # for i in [0,1]:
332 | # for j in [0,1]:
333 | # main(cell_type=[i,j],multi=True)
334 |
335 | # ***for Gradient Interpolation, run***
336 | # for i in range(0,11):
337 | # main(cell_type=[6,7], inter=True, weight=[10-i,i])
338 | # for i in range(18):
339 | # main(cell_type=[i,i+1], inter=True, weight=[5,5])
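The gradient interpolation in `cond_fn_inter` adds two class gradients of the same classifier, each scaled by `weight * 2 / 10`. The sketch below reproduces that arithmetic for a toy linear softmax classifier, where the gradient of `log p(y | x)` has a closed form. The linear classifier, `grad_log_prob`, and `interpolated_guidance` are assumptions made for illustration; the script itself differentiates a neural classifier with `torch.autograd.grad`.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def grad_log_prob(weights, x, y):
    """Gradient of log p(y | x) w.r.t. x for a linear softmax classifier (logits = W x)."""
    logits = [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in weights]
    probs = [math.exp(v) for v in log_softmax(logits)]
    dim = len(x)
    # d/dx log p(y|x) = W[y] - sum_k p_k * W[k]
    expected_row = [sum(p * row[d] for p, row in zip(probs, weights)) for d in range(dim)]
    return [weights[y][d] - expected_row[d] for d in range(dim)]


def interpolated_guidance(weights, x, y1, y2, weight):
    """Weighted sum of two class gradients, scaled like classifier_scale1/2 (weight * 2 / 10)."""
    scale1 = weight[0] * 2 / 10
    scale2 = weight[1] * 2 / 10
    g1 = grad_log_prob(weights, x, y1)
    g2 = grad_log_prob(weights, x, y2)
    return [scale1 * a + scale2 * b for a, b in zip(g1, g2)]


W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # toy 3-class, 2-dimensional classifier
x = [0.2, -0.1]
print(interpolated_guidance(W, x, y1=0, y2=1, weight=[7, 3]))
```

With `weight=[10, 0]` the result reduces to twice the gradient for the first class, and sweeping the weights from `[10, 0]` to `[0, 10]` moves the guidance direction from one class towards the other, which matches the loop in the commented-out Gradient Interpolation example.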
--------------------------------------------------------------------------------
/classifier_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a noised cell classifier on VAE latent embeddings of single-cell expression data.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import blobfile as bf
9 | import torch as th
10 | import torch.distributed as dist
11 | import torch.nn.functional as F
12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
13 | from torch.optim import AdamW
14 |
15 | from guided_diffusion import dist_util, logger
16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer
17 | from guided_diffusion.cell_datasets_loader import load_data
18 | from guided_diffusion.resample import create_named_schedule_sampler
19 | from guided_diffusion.script_util import (
20 | add_dict_to_argparser,
21 | args_to_dict,
22 | classifier_and_diffusion_defaults,
23 | create_classifier_and_diffusion,
24 | )
25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict
26 | import torch
27 | import torch.nn as nn
28 | import numpy as np
29 |
30 | def main():
31 | args = create_argparser().parse_args()
32 |
33 | setup_seed(1234)
34 |
35 | dist_util.setup_dist()
36 | logger.configure()
37 |
38 | logger.log("creating model and diffusion...")
39 | model, diffusion = create_classifier_and_diffusion(
40 | **args_to_dict(args, classifier_and_diffusion_defaults().keys())
41 | )
42 | model.to(dist_util.dev())
43 | if args.noised:
44 | schedule_sampler = create_named_schedule_sampler(
45 | args.schedule_sampler, diffusion
46 | )
47 |
48 | resume_step = 0
49 | if args.resume_checkpoint:
50 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
51 | if dist.get_rank() == 0:
52 | logger.log(
53 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
54 | )
55 | model.load_state_dict(
56 | dist_util.load_state_dict(
57 | args.resume_checkpoint, map_location=dist_util.dev()
58 | )
59 | )
60 |
61 | # Needed for creating correct EMAs and fp16 parameters.
62 | dist_util.sync_params(model.parameters())
63 |
64 | mp_trainer = MixedPrecisionTrainer(
65 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
66 | )
67 |
68 | model = DDP(
69 | model,
70 | device_ids=[dist_util.dev()],
71 | output_device=dist_util.dev(),
72 | broadcast_buffers=False,
73 | bucket_cap_mb=128,
74 | find_unused_parameters=True,
75 | )
76 |
77 | logger.log("creating data loader...")
78 | data = load_data(
79 | data_dir=args.data_dir,
80 | batch_size=args.batch_size,
81 | vae_path=args.vae_path,
82 | hidden_dim=args.latent_dim,
83 | train_vae=False,
84 | )
85 | if args.val_data_dir:
86 | val_data = load_data(
87 | data_dir=args.val_data_dir,
88 | batch_size=args.batch_size,
89 | vae_path=args.vae_path,
90 | hidden_dim=args.latent_dim,
91 | train_vae=False,
92 | )
93 | else:
94 | val_data = None
95 |
96 | logger.log(f"creating optimizer...")
97 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
98 | if args.resume_checkpoint:
99 | opt_checkpoint = bf.join(
100 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
101 | )
102 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
103 | opt.load_state_dict(
104 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
105 | )
106 |
107 | logger.log("training classifier model...")
108 |
109 | def forward_backward_log(data_loader, prefix="train"):
110 | batch, extra = next(data_loader)
111 | labels = extra["y"].to(dist_util.dev())
112 |
113 | batch = batch.to(dist_util.dev())
114 | # Noisy cells
115 | if args.noised:
116 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev(), start_guide_time=args.start_guide_time)
117 | batch = diffusion.q_sample(batch, t)
118 | else:
119 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())
120 |
121 | for i, (sub_batch, sub_labels, sub_t) in enumerate(
122 | split_microbatches(args.microbatch, batch, labels, t)
123 | ):
124 | logits = model(sub_batch, sub_t)
125 | loss = F.cross_entropy(logits, sub_labels, reduction="none")
126 |
127 | losses = {}
128 | losses[f"{prefix}_loss"] = loss.detach()
129 | losses[f"{prefix}_acc@1"] = compute_top_k(
130 | logits, sub_labels, k=1, reduction="none"
131 | )
132 |
133 | log_loss_dict(diffusion, sub_t, losses)
134 | del losses
135 | loss = loss.mean()
136 | if loss.requires_grad:
137 | if i == 0:
138 | mp_trainer.zero_grad()
139 | mp_trainer.backward(loss * len(sub_batch) / len(batch))
140 |
141 | model_path = args.model_path
142 | for step in range(args.iterations - resume_step):
143 | logger.logkv("step", step + resume_step)
144 | logger.logkv(
145 | "samples",
146 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
147 | )
148 | if args.anneal_lr:
149 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
150 | forward_backward_log(data)
151 | mp_trainer.optimize(opt)
152 | if val_data is not None and not step % args.eval_interval:
153 | with th.no_grad():
154 | with model.no_sync():
155 | model.eval()
156 | forward_backward_log(val_data, prefix="val")
157 | model.train()
158 | if not step % args.log_interval:
159 | logger.dumpkvs()
160 | if (
161 | step
162 | and dist.get_rank() == 0
163 | and not (step + resume_step) % args.save_interval
164 | ):
165 | logger.log("saving model...")
166 | save_model(mp_trainer, opt, step + resume_step, model_path)
167 |
168 | if dist.get_rank() == 0:
169 | logger.log("saving model...")
170 | save_model(mp_trainer, opt, step + resume_step, model_path)
171 | dist.barrier()
172 |
173 |
174 | def set_annealed_lr(opt, base_lr, frac_done):
175 | lr = base_lr * (1 - frac_done)
176 | for param_group in opt.param_groups:
177 | param_group["lr"] = lr
178 |
179 |
180 | def save_model(mp_trainer, opt, step, model_path):
181 | if dist.get_rank() == 0:
182 | model_dir = model_path
183 | os.makedirs(model_dir,exist_ok=True)
184 | th.save(
185 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params),
186 | os.path.join(model_dir, f"model{step:06d}.pt"),
187 | )
188 | th.save(opt.state_dict(), os.path.join(model_dir, f"opt{step:06d}.pt"))
189 |
190 |
191 | def compute_top_k(logits, labels, k, reduction="mean"):
192 | _, top_ks = th.topk(logits, k, dim=-1)
193 | if reduction == "mean":
194 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
195 | elif reduction == "none":
196 | return (top_ks == labels[:, None]).float().sum(dim=-1)
197 |
198 |
199 | def split_microbatches(microbatch, *args):
200 | bs = len(args[0])
201 | if microbatch == -1 or microbatch >= bs:
202 | yield tuple(args)
203 | else:
204 | for i in range(0, bs, microbatch):
205 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args)
206 |
207 |
208 | def create_argparser():
209 | defaults = dict(
210 | data_dir="/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad",
211 | val_data_dir="",
212 | noised=True,
213 | iterations=500000,
214 | lr=3e-4,
215 | weight_decay=0.0,
216 | anneal_lr=False,
217 | batch_size=128,
218 | microbatch=-1,
219 | schedule_sampler="uniform",
220 | resume_checkpoint="",
221 | log_interval=100,
222 | eval_interval=100,
223 | save_interval=100000,
224 | vae_path='output/Autoencoder_checkpoint/muris_AE/model_seed=0_step=0.pt',
225 | latent_dim=128,
226 | model_path='output/classifier_checkpoint/classifier_muris',
227 | start_guide_time=500,
228 | num_class=12,
229 | )
230 | num_class = defaults['num_class']
231 | defaults.update(classifier_and_diffusion_defaults())
232 | defaults['num_class']= num_class
233 | parser = argparse.ArgumentParser()
234 | add_dict_to_argparser(parser, defaults)
235 | return parser
236 |
237 | def setup_seed(seed):
238 | torch.manual_seed(seed)
239 | torch.cuda.manual_seed_all(seed)
240 | np.random.seed(seed)
241 | torch.backends.cudnn.deterministic = True
242 |
243 | if __name__ == "__main__":
244 | main()
245 |
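In `forward_backward_log`, each microbatch loss is multiplied by `len(sub_batch) / len(batch)` before the backward pass, so the accumulated gradient corresponds to the mean loss over the full batch even when the last microbatch is smaller. The sketch below checks that arithmetic on plain numbers. `accumulated_mean_loss` is a hypothetical helper, and the losses are made-up values.

```python
def split_microbatches(microbatch, *args):
    """Yield aligned slices of each argument, `microbatch` items at a time."""
    batch_size = len(args[0])
    if microbatch == -1 or microbatch >= batch_size:
        yield tuple(args)
    else:
        for i in range(0, batch_size, microbatch):
            yield tuple(x[i : i + microbatch] if x is not None else None for x in args)


def accumulated_mean_loss(per_sample_losses, microbatch):
    """Sum the microbatch means, each weighted by len(sub_batch) / len(batch)."""
    total = 0.0
    for (sub_losses,) in split_microbatches(microbatch, per_sample_losses):
        sub_mean = sum(sub_losses) / len(sub_losses)
        total += sub_mean * len(sub_losses) / len(per_sample_losses)
    return total


losses = [0.5, 1.5, 2.0, 4.0, 3.0]
print(accumulated_mean_loss(losses, microbatch=2))  # weighted microbatch sum
print(sum(losses) / len(losses))                    # plain full-batch mean
```

Without the weighting, the final microbatch of size 1 would count as much as the two microbatches of size 2, and the two printed values would differ.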
--------------------------------------------------------------------------------
/exp_script/script_description.md:
--------------------------------------------------------------------------------
1 | exp_script/down_stream_analysis_muris.ipynb
2 | QQ plot for generated data
3 |
4 | exp_script/script_diffusion_interpolation.ipynb
5 | generate data using gradient interpolation
6 |
7 | exp_script/script_diffusion_multi-condi.ipynb
8 | generate data with more than one condition
9 |
10 | exp_script/script_diffusion_umap.ipynb
11 | plot UMAP for the generated data
12 |
13 | exp_script/script_static_eval.ipynb
14 | statistical evaluation for the generated data (conditional and unconditional)
15 |
16 | exp_script/script_random_forest.ipynb
17 | use random forest to classify real and generated data
--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_WOT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, Dataset
3 |
4 | import scanpy as sc
5 | import torch
6 | import sys
7 | sys.path.append('..')
8 | from VAE.VAE_model import VAE
9 | from sklearn.preprocessing import LabelEncoder
10 |
11 | def load_VAE(vae_path, num_gene, hidden_dim):
12 | autoencoder = VAE(
13 | num_genes=num_gene,
14 | device='cuda',
15 | seed=0,
16 | loss_ae='mse',
17 | hidden_dim=hidden_dim,
18 | decoder_activation='ReLU',
19 | )
20 | autoencoder.load_state_dict(torch.load(vae_path))
21 | return autoencoder
22 |
23 | def load_data(
24 | *,
25 | data_dir,
26 | batch_size,
27 | vae_path=None,
28 | deterministic=False,
29 | train_vae=False,
30 | hidden_dim=128,
31 | ):
32 | """
33 | For a dataset, create a generator over (cells, kwargs) pairs.
34 |
35 | :param data_dir: path to the .h5ad dataset file.
36 | :param batch_size: the batch size of each returned pair.
37 | :param vae_path: path to the autoencoder checkpoint (read here when train_vae is False).
38 | :param deterministic: if True, yield results in a deterministic order.
39 | :param train_vae: if True, yield gene expression for training the autoencoder; if False, encode cells with the pretrained autoencoder.
40 | :param hidden_dim: the dimension of the latent space. Set to 128 when using the pretrained weights.
41 | """
42 | if not data_dir:
43 | raise ValueError("unspecified data directory")
44 |
45 | adata = sc.read_h5ad(data_dir) # the dataset has already been filtered for cells and genes
46 |
47 | sc.pp.normalize_total(adata, target_sum=1e4)
48 | sc.pp.log1p(adata)
49 |
50 | adata = adata[np.where(np.in1d(adata.obs['period'], ['D0','D0.5','D1','D1.5','D2','D2.5','D3','D4.5','D5','D5.5','D6','D6.5','D7','D7.5','D8']))[0]]
51 | print(adata)
52 |
53 | label_encoder = LabelEncoder()
54 | label_encoder.fit(adata.obs['period'])
55 | label_encoder.classes_= np.array(['D0','D0.5','D1','D1.5','D2','D2.5','D3','D4.5','D5','D5.5','D6','D6.5','D7','D7.5','D8'])
56 | classes = label_encoder.transform(adata.obs['period'])
57 | print(label_encoder.classes_)
58 |
59 | cell_data = adata.X
60 |
61 | # when not training the autoencoder, encode the cells into the latent space
62 | if not train_vae:
63 | num_gene = cell_data.shape[1]
64 | autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
65 | cell_data1 = autoencoder(torch.tensor(cell_data)[::2].cuda(),return_latent=True).cpu().detach().numpy()
66 | cell_data2 = autoencoder(torch.tensor(cell_data)[1::2].cuda(),return_latent=True).cpu().detach().numpy()
67 | cell_data = np.concatenate((cell_data1,cell_data2))
68 |
69 | classes = np.concatenate((classes[::2],classes[1::2]))
70 | print(cell_data.shape)
71 |
72 | dataset = CellDataset(
73 | cell_data,
74 | classes
75 | )
76 | if deterministic:
77 | loader = DataLoader(
78 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
79 | )
80 | else:
81 | loader = DataLoader(
82 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
83 | )
84 | while True:
85 | yield from loader
86 |
87 |
88 | class CellDataset(Dataset):
89 | def __init__(
90 | self,
91 | cell_data,
92 | class_name
93 | ):
94 | super().__init__()
95 | self.data = cell_data
96 | self.class_name = class_name
97 |
98 | def __len__(self):
99 | return self.data.shape[0]
100 |
101 | def __getitem__(self, idx):
102 | arr = self.data[idx]
103 | out_dict = {}
104 | if self.class_name is not None:
105 | out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
106 | return arr, out_dict
107 |
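The WOT loader encodes the even-indexed and odd-indexed cells in two separate passes and then reorders the labels with the same slices, so that each latent row keeps the label of its source cell. A minimal sketch of that bookkeeping follows. `encode_in_two_halves` and `toy_encode` are hypothetical stand-ins, with `toy_encode` replacing the call to `autoencoder(..., return_latent=True)`.

```python
def encode_in_two_halves(cells, labels, encode):
    """Encode even-indexed and odd-indexed cells separately, then concatenate.

    The labels are reordered with the same slices so that row i of the
    output still carries the label of the cell it came from.
    """
    latent = encode(cells[::2]) + encode(cells[1::2])
    reordered_labels = labels[::2] + labels[1::2]
    return latent, reordered_labels


def toy_encode(batch):
    # hypothetical stand-in for the autoencoder's latent projection
    return [value * 10 for value in batch]


cells = [0, 1, 2, 3, 4]
labels = ["D0", "D0.5", "D1", "D1.5", "D2"]
latent, reordered = encode_in_two_halves(cells, labels, toy_encode)
print(latent)
print(reordered)
```

If only the data were reordered and the labels left in their original order, every cell after the first would be paired with the wrong time point, which is why the loader applies the same `[::2]` and `[1::2]` slices to `classes`.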
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_loader.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | import numpy as np
7 | from torch.utils.data import DataLoader, Dataset
8 |
9 | import scanpy as sc
10 | import torch
11 | import sys
12 | sys.path.append('..')
13 | from VAE.VAE_model import VAE
14 | from sklearn.preprocessing import LabelEncoder
15 |
16 | def stabilize(expression_matrix):
17 | ''' Use Anscombe's approximation to variance-stabilize negative binomial data
18 | See https://f1000research.com/posters/4-1041 for motivation.
19 | Assumes columns are samples, and rows are genes
20 | '''
21 | from scipy import optimize
22 | phi_hat, _ = optimize.curve_fit(lambda mu, phi: mu + phi * mu ** 2, expression_matrix.mean(1), expression_matrix.var(1))
23 |
24 | return np.log(expression_matrix + 1. / (2 * phi_hat[0]))
25 |
26 | def load_VAE(vae_path, num_gene, hidden_dim):
27 | autoencoder = VAE(
28 | num_genes=num_gene,
29 | device='cuda',
30 | seed=0,
31 | loss_ae='mse',
32 | hidden_dim=hidden_dim,
33 | decoder_activation='ReLU',
34 | )
35 | autoencoder.load_state_dict(torch.load(vae_path))
36 | return autoencoder
37 |
38 |
39 | def load_data(
40 | *,
41 | data_dir,
42 | batch_size,
43 | vae_path=None,
44 | deterministic=False,
45 | train_vae=False,
46 | hidden_dim=128,
47 | ):
48 | """
49 | For a dataset, create a generator over (cells, kwargs) pairs.
50 |
51 | :param data_dir: path to the .h5ad dataset file.
52 | :param batch_size: the batch size of each returned pair.
53 | :param vae_path: path to the autoencoder checkpoint (read here when train_vae is False).
54 | :param deterministic: if True, yield results in a deterministic order.
55 | :param train_vae: if True, yield gene expression for training the autoencoder; if False, encode cells with the pretrained autoencoder.
56 | :param hidden_dim: the dimension of the latent space. Set to 128 when using the pretrained weights.
57 | """
58 | if not data_dir:
59 | raise ValueError("unspecified data directory")
60 |
61 | adata = sc.read_h5ad(data_dir)
62 |
63 | # preprocess the data. modify this part if you use your own dataset. the gene expression must first be normalized to a total of 1e4 per cell, then log1p-transformed
64 | sc.pp.filter_genes(adata, min_cells=3)
65 | sc.pp.filter_cells(adata, min_genes=10)
66 | adata.var_names_make_unique()
67 |
68 | # if generating OOD data, hold out this organ/celltype combination as the OOD data
69 | # selected_cells = (adata.obs['organ'] != 'mammary') | (adata.obs['celltype'] != 'B cell')
70 | # adata = adata[selected_cells, :]
71 |
72 | classes = adata.obs['celltype'].values
73 | label_encoder = LabelEncoder()
74 | labels = classes
75 | label_encoder.fit(labels)
76 | classes = label_encoder.transform(labels)
77 |
78 | sc.pp.normalize_total(adata, target_sum=1e4)
79 | sc.pp.log1p(adata)
80 |
81 | cell_data = adata.X.toarray()
82 |
83 | # encode the gene expression into the latent space. use this when training the diffusion backbone.
84 | if not train_vae:
85 | num_gene = cell_data.shape[1]
86 | autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
87 | cell_data = autoencoder(torch.tensor(cell_data).cuda(),return_latent=True)
88 | cell_data = cell_data.cpu().detach().numpy()
89 |
90 | dataset = CellDataset(
91 | cell_data,
92 | classes
93 | )
94 | if deterministic:
95 | loader = DataLoader(
96 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
97 | )
98 | else:
99 | loader = DataLoader(
100 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
101 | )
102 | while True:
103 | yield from loader
104 |
105 |
106 | class CellDataset(Dataset):
107 | def __init__(
108 | self,
109 | cell_data,
110 | class_name
111 | ):
112 | super().__init__()
113 | self.data = cell_data
114 | self.class_name = class_name
115 |
116 | def __len__(self):
117 | return self.data.shape[0]
118 |
119 | def __getitem__(self, idx):
120 | arr = self.data[idx]
121 | out_dict = {}
122 | if self.class_name is not None:
123 | out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
124 | return arr, out_dict
125 |
126 |
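`load_data` ends with `while True: yield from loader`, which turns a finite `DataLoader` into an endless stream of batches so the training loop can call `next(data)` without tracking epochs. The sketch below shows the same pattern on a plain list. `infinite_batches` is a hypothetical helper, and shuffling is left out for brevity.

```python
from itertools import islice


def infinite_batches(items, batch_size):
    """Yield full batches forever, restarting from the beginning after each pass."""
    if batch_size > len(items):
        raise ValueError("batch_size is larger than the dataset; no full batch exists")

    def one_pass():
        # drop the last incomplete batch, mirroring drop_last=True
        for start in range(0, len(items) - batch_size + 1, batch_size):
            yield items[start : start + batch_size]

    while True:
        yield from one_pass()


stream = infinite_batches([1, 2, 3, 4, 5], batch_size=2)
first_five = list(islice(stream, 5))
print(first_five)
```

Because the stream never ends, the consumer decides when to stop, as the training scripts do with a fixed number of iterations.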
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_lung.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, Dataset
3 |
4 | import scanpy as sc
5 | import pandas as pd
6 | import torch
7 | import sys
8 | sys.path.append('..')
9 | from VAE.VAE_model import VAE
10 |
11 | from sklearn.preprocessing import LabelEncoder
12 |
13 | def load_VAE(vae_path, num_gene, hidden_dim):
14 | autoencoder = VAE(
15 | num_genes=num_gene,
16 | device='cuda',
17 | seed=0,
18 | loss_ae='mse',
19 | hidden_dim=hidden_dim,
20 | decoder_activation='ReLU',
21 | )
22 | autoencoder.load_state_dict(torch.load(vae_path))
23 | return autoencoder
24 |
25 | def load_data(
26 | *,
27 | data_dir,
28 | batch_size,
29 | vae_path=None,
30 | deterministic=False,
31 | train_vae=False,
32 | hidden_dim=128,
33 | ):
34 | """
35 | For a dataset, create a generator over (cells, kwargs) pairs.
36 |
37 | :param data_dir: path to the .h5ad dataset file.
38 | :param batch_size: the batch size of each returned pair.
39 | :param vae_path: path to the autoencoder checkpoint (read here when train_vae is False).
40 | :param deterministic: if True, yield results in a deterministic order.
41 | :param train_vae: if True, yield gene expression for training the autoencoder; if False, encode cells with the pretrained autoencoder.
42 | :param hidden_dim: the dimension of the latent space. Set to 128 when using the pretrained weights.
43 | """
44 | if not data_dir:
45 | raise ValueError("unspecified data directory")
46 |
47 |
48 | adata = sc.read_h5ad(data_dir)
49 | sc.pp.filter_genes(adata, min_cells=3)
50 | sc.pp.filter_cells(adata, min_genes=10)
51 | adata.var_names_make_unique()
52 |
53 | sc.pp.normalize_total(adata, target_sum=1e4)
54 | sc.pp.log1p(adata)
55 |
56 | celltype = adata.obs['celltype']
57 | label_encoder = LabelEncoder()
58 | label_encoder.fit(celltype)
59 | classes = label_encoder.transform(celltype)
60 | print(label_encoder.classes_)
61 |
62 | cell_data = adata.X.toarray()
63 |
64 | # when not training the autoencoder, encode the cells into the latent space
65 | if not train_vae:
66 | num_gene = cell_data.shape[1]
67 | autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
68 | cell_data = autoencoder(torch.tensor(cell_data).cuda(),return_latent=True)
69 | cell_data = cell_data.cpu().detach().numpy()
70 |
71 | dataset = CellDataset(
72 | cell_data,
73 | classes
74 | )
75 | if deterministic:
76 | loader = DataLoader(
77 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
78 | )
79 | else:
80 | loader = DataLoader(
81 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
82 | )
83 | while True:
84 | yield from loader
85 |
86 | class CellDataset(Dataset):
87 | def __init__(
88 | self,
89 | cell_data,
90 | class_name
91 | ):
92 | super().__init__()
93 | self.data = cell_data
94 | self.class_name = class_name
95 |
96 | def __len__(self):
97 | return self.data.shape[0]
98 |
99 | def __getitem__(self, idx):
100 | arr = self.data[idx]
101 | out_dict = {}
102 | if self.class_name is not None:
103 | out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
104 | return arr, out_dict
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_muris.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | import numpy as np
7 | from torch.utils.data import DataLoader, Dataset
8 |
9 | import scanpy as sc
10 | import torch
11 | import sys
12 | sys.path.append('..')
13 | from VAE.VAE_model import VAE
14 | from sklearn.preprocessing import LabelEncoder
15 |
16 | def stabilize(expression_matrix):
17 |     ''' Use Anscombe's approximation to variance-stabilize Negative Binomial data.
18 |     See https://f1000research.com/posters/4-1041 for motivation.
19 |     Assumes columns are samples, and rows are genes
20 |     '''
21 |     from scipy import optimize
22 |     phi_hat, _ = optimize.curve_fit(lambda mu, phi: mu + phi * mu ** 2, expression_matrix.mean(1), expression_matrix.var(1))
23 |
24 |     return np.log(expression_matrix + 1. / (2 * phi_hat[0]))
25 |
26 | def load_VAE(vae_path, num_gene, hidden_dim):
27 |     autoencoder = VAE(
28 |         num_genes=num_gene,
29 |         device='cuda',
30 |         seed=0,
31 |         loss_ae='mse',
32 |         hidden_dim=hidden_dim,
33 |         decoder_activation='ReLU',
34 |     )
35 |     autoencoder.load_state_dict(torch.load(vae_path))
36 |     return autoencoder
37 |
38 |
39 | def load_data(
40 |     *,
41 |     data_dir,
42 |     batch_size,
43 |     vae_path=None,
44 |     deterministic=False,
45 |     train_vae=False,
46 |     hidden_dim=128,
47 | ):
48 |     """
49 |     For a dataset, create a generator over (cells, kwargs) pairs.
50 |
51 |     :param data_dir: path to the dataset (an .h5ad file).
52 |     :param batch_size: the batch size of each returned pair.
53 |     :param vae_path: the path to save autoencoder / read autoencoder checkpoint.
54 |     :param deterministic: if True, yield results in a deterministic order.
55 |     :param train_vae: if True, yield expression data for training the autoencoder; otherwise encode cells with the pretrained autoencoder.
56 |     :param hidden_dim: the dimension of the latent space. If using pretrained weights, set to 128.
57 |     """
58 |     if not data_dir:
59 |         raise ValueError("unspecified data directory")
60 |
61 |     adata = sc.read_h5ad(data_dir)
62 |     sc.pp.filter_genes(adata, min_cells=3)
63 |     sc.pp.filter_cells(adata, min_genes=10)
64 |     adata.var_names_make_unique()
65 |
66 |     # to generate OOD data, uncomment the lines below to hold out mammary B cells as the OOD set
67 |     # selected_cells = (adata.obs['organ'] != 'mammary') | (adata.obs['celltype'] != 'B cell')
68 |     # adata = adata[selected_cells, :]
69 |
70 |     classes = adata.obs['celltype'].values
71 |     label_encoder = LabelEncoder()
72 |     labels = classes
73 |     label_encoder.fit(labels)
74 |     classes = label_encoder.transform(labels)
75 |
76 |     sc.pp.normalize_total(adata, target_sum=1e4)
77 |     sc.pp.log1p(adata)
78 |
79 |     cell_data = adata.X.toarray()
80 |
81 |     # if not training the autoencoder, encode cells into its latent space
82 |     if not train_vae:
83 |         num_gene = cell_data.shape[1]
84 |         autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
85 |         cell_data = autoencoder(torch.tensor(cell_data).cuda(),return_latent=True)
86 |         cell_data = cell_data.cpu().detach().numpy()
87 |
88 |     dataset = CellDataset(
89 |         cell_data,
90 |         classes
91 |     )
92 |     if deterministic:
93 |         loader = DataLoader(
94 |             dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
95 |         )
96 |     else:
97 |         loader = DataLoader(
98 |             dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
99 |         )
100 |     while True:
101 |         yield from loader
102 |
103 |
104 | class CellDataset(Dataset):
105 |     def __init__(
106 |         self,
107 |         cell_data,
108 |         class_name
109 |     ):
110 |         super().__init__()
111 |         self.data = cell_data
112 |         self.class_name = class_name
113 |
114 |     def __len__(self):
115 |         return self.data.shape[0]
116 |
117 |     def __getitem__(self, idx):
118 |         arr = self.data[idx]
119 |         out_dict = {}
120 |         if self.class_name is not None:
121 |             out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
122 |         return arr, out_dict
123 |
124 |
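The `while True: yield from loader` idiom at the end of `load_data` turns a finite loader into an endless stream of batches. The sketch below mimics that behaviour in plain Python so it can be run without PyTorch. The name `infinite_batches` is hypothetical and not part of this repository, and shuffling is omitted. The guard against an oversized batch is an addition of this sketch: with `drop_last=True`, a dataset smaller than one batch would yield nothing per epoch, so the loop would never produce a value.

```python
from itertools import islice


def infinite_batches(items, batch_size):
    """Yield full batches forever, dropping the incomplete last batch of each epoch.

    Plain-Python stand-in for `while True: yield from DataLoader(..., drop_last=True)`.
    """
    if batch_size > len(items):
        # Assumption of this sketch: fail loudly instead of spinning forever.
        raise ValueError("batch_size exceeds dataset size; no full batch exists")

    def one_epoch():
        # Step through the data in strides of batch_size, keeping full batches only.
        for start in range(0, len(items) - batch_size + 1, batch_size):
            yield items[start:start + batch_size]

    while True:
        yield from one_epoch()


# Five items with batch size 2: the fifth item is dropped every epoch.
first_five = list(islice(infinite_batches([1, 2, 3, 4, 5], 2), 5))
```

Because the generator never terminates, a caller draws batches with `next(...)` or `islice(...)` rather than a plain `for` loop over the whole stream.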
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_pbmc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, Dataset
3 |
4 | import scanpy as sc
5 | import pandas as pd
6 | import torch
7 | import sys
8 | sys.path.append('..')
9 | from VAE.VAE_model import VAE
10 |
11 | from sklearn.preprocessing import LabelEncoder
12 |
13 | def load_VAE(vae_path, num_gene, hidden_dim):
14 |     autoencoder = VAE(
15 |         num_genes=num_gene,
16 |         device='cuda',
17 |         seed=0,
18 |         loss_ae='mse',
19 |         hidden_dim=hidden_dim,
20 |         decoder_activation='ReLU',
21 |     )
22 |     autoencoder.load_state_dict(torch.load(vae_path))
23 |     return autoencoder
24 |
25 | def load_data(
26 |     *,
27 |     data_dir,
28 |     batch_size,
29 |     vae_path=None,
30 |     deterministic=False,
31 |     train_vae=False,
32 |     hidden_dim=128,
33 | ):
34 |     """
35 |     For a dataset, create a generator over (cells, kwargs) pairs.
36 |
37 |     :param data_dir: a dataset directory (containing the 10x `.mtx` file).
38 |     :param batch_size: the batch size of each returned pair.
39 |     :param vae_path: the path to save autoencoder / read autoencoder checkpoint.
40 |     :param deterministic: if True, yield results in a deterministic order.
41 |     :param train_vae: if True, yield expression data for training the autoencoder; otherwise encode cells with the pretrained autoencoder.
42 |     :param hidden_dim: the dimension of the latent space. If using pretrained weights, set to 128.
43 |     """
44 |     if not data_dir:
45 |         raise ValueError("unspecified data directory")
46 |
47 |
48 |     adata = sc.read_10x_mtx(
49 |         data_dir,  # the directory with the `.mtx` file
50 |         var_names='gene_symbols',  # use gene symbols for the variable names (variables-axis index)
51 |         cache=True)  # write a cache file for faster subsequent reading
52 |
53 |     adata.var_names_make_unique()
54 |     sc.pp.filter_cells(adata, min_genes=10)
55 |     sc.pp.filter_genes(adata, min_cells=3)
56 |
57 |     sc.pp.normalize_total(adata, target_sum=1e4)
58 |     sc.pp.log1p(adata)
59 |
60 |     celltype = pd.read_csv('/data1/lep/Workspace/guided-diffusion/data/pbmc68k/analysis_csv/68k_pbmc_barcodes_annotation.tsv', sep='\t')['celltype'].values
61 |
62 |     adata.obs['celltype'] = celltype
63 |     label_encoder = LabelEncoder()
64 |     label_encoder.fit(celltype)
65 |     classes = label_encoder.transform(celltype)
66 |
67 |     cell_data = adata.X.toarray()
68 |
69 |     # if not training the autoencoder, encode cells into its latent space
70 |     if not train_vae:
71 |         num_gene = cell_data.shape[1]
72 |         autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
73 |         cell_data = autoencoder(torch.tensor(cell_data).cuda(),return_latent=True)
74 |         cell_data = cell_data.cpu().detach().numpy()
75 |
76 |     dataset = CellDataset(
77 |         cell_data,
78 |         classes
79 |     )
80 |     if deterministic:
81 |         loader = DataLoader(
82 |             dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
83 |         )
84 |     else:
85 |         loader = DataLoader(
86 |             dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
87 |         )
88 |     while True:
89 |         yield from loader
90 |
91 | class CellDataset(Dataset):
92 |     def __init__(
93 |         self,
94 |         cell_data,
95 |         class_name
96 |     ):
97 |         super().__init__()
98 |         self.data = cell_data
99 |         self.class_name = class_name
100 |
101 |     def __len__(self):
102 |         return self.data.shape[0]
103 |
104 |     def __getitem__(self, idx):
105 |         arr = self.data[idx]
106 |         out_dict = {}
107 |         if self.class_name is not None:
108 |             out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
109 |         return arr, out_dict
--------------------------------------------------------------------------------
/guided_diffusion/cell_datasets_sapiens.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, Dataset
3 |
4 | import scanpy as sc
5 | import torch
6 | import sys
7 | sys.path.append('..')
8 | from VAE.VAE_model import VAE
9 | from sklearn.preprocessing import LabelEncoder
10 |
11 |
12 | def load_VAE(vae_path, num_gene, hidden_dim):
13 |     autoencoder = VAE(
14 |         num_genes=num_gene,
15 |         device='cuda',
16 |         seed=0,
17 |         loss_ae='mse',
18 |         hidden_dim=hidden_dim,
19 |         decoder_activation='ReLU',
20 |     )
21 |     autoencoder.load_state_dict(torch.load(vae_path))
22 |     return autoencoder
23 |
24 |
25 | def load_data(
26 |     *,
27 |     data_dir,
28 |     batch_size,
29 |     vae_path=None,
30 |     deterministic=False,
31 |     train_vae=False,
32 |     hidden_dim=128,
33 | ):
34 |     """
35 |     For a dataset, create a generator over (cells, kwargs) pairs.
36 |
37 |     :param data_dir: path to the dataset (an .h5ad file).
38 |     :param batch_size: the batch size of each returned pair.
39 |     :param vae_path: the path to save autoencoder / read autoencoder checkpoint.
40 |     :param deterministic: if True, yield results in a deterministic order.
41 |     :param train_vae: if True, yield expression data for training the autoencoder; otherwise encode cells with the pretrained autoencoder.
42 |     :param hidden_dim: the dimension of the latent space. If using pretrained weights, set to 128.
43 |     """
44 |     if not data_dir:
45 |         raise ValueError("unspecified data directory")
46 |
47 |     adata = sc.read_h5ad(data_dir)
48 |     adata.var_names_make_unique()  # already processed by the SCimilarity code base; no need to filter cells and genes
49 |
50 |     # filter out spleen macrophage cells
51 |     selected_cells = (adata.obs['organ_tissue'] != 'Spleen') | (adata.obs['free_annotation'] != 'macrophage')
52 |     adata = adata[selected_cells, :]
53 |
54 |     # filter out thymus memory B cells
55 |     selected_cells = (adata.obs['organ_tissue'] != 'Thymus') | (adata.obs['free_annotation'] != 'memory b cell')
56 |     adata = adata[selected_cells, :]
57 |
58 |     classes = adata.obs['organ_tissue'].values
59 |     label_encoder = LabelEncoder()
60 |     labels = classes
61 |     label_encoder.fit(labels)
62 |     classes = label_encoder.transform(labels)
63 |     print(label_encoder.classes_)
64 |
65 |     sc.pp.normalize_total(adata, target_sum=1e4)
66 |     sc.pp.log1p(adata)
67 |
68 |     cell_data = adata.X.toarray()
69 |
70 |     # if not training the autoencoder, encode cells into its latent space
71 |     if not train_vae:
72 |         num_gene = cell_data.shape[1]
73 |         autoencoder = load_VAE(vae_path,num_gene,hidden_dim)
74 |         cell_data = autoencoder(torch.tensor(cell_data).cuda(),return_latent=True)
75 |         cell_data = cell_data.cpu().detach().numpy()
76 |
77 |     dataset = CellDataset(
78 |         cell_data,
79 |         classes
80 |     )
81 |     if deterministic:
82 |         loader = DataLoader(
83 |             dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
84 |         )
85 |     else:
86 |         loader = DataLoader(
87 |             dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
88 |         )
89 |     while True:
90 |         yield from loader
91 |
92 | class CellDataset(Dataset):
93 |     def __init__(
94 |         self,
95 |         cell_data,
96 |         class_name
97 |     ):
98 |         super().__init__()
99 |         self.data = cell_data
100 |         self.class_name = class_name
101 |
102 |     def __len__(self):
103 |         return self.data.shape[0]
104 |
105 |     def __getitem__(self, idx):
106 |         arr = self.data[idx]
107 |         out_dict = {}
108 |         if self.class_name is not None:
109 |             out_dict["y"] = np.array(self.class_name[idx], dtype=np.int64)
110 |         return arr, out_dict
111 |
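The two hold-out filters in `load_data` use the mask `(organ != X) | (annotation != Y)`. By De Morgan's law this keeps every cell except those for which both conditions match, so only the (organ, cell type) combination is removed, not the whole organ or the whole cell type. The sketch below shows the same logic on plain tuples. The helper name `keep_cell` and the toy records are hypothetical and are not taken from the dataset.

```python
def keep_cell(organ, annotation, held_organ, held_annotation):
    """Return True unless the cell matches BOTH the held-out organ and annotation.

    Equivalent to `not (organ == held_organ and annotation == held_annotation)`.
    """
    return organ != held_organ or annotation != held_annotation


# Toy (organ, annotation) records, invented for illustration only.
cells = [
    ("Spleen", "macrophage"),   # the held-out combination: removed
    ("Spleen", "memory b cell"),  # same organ, different type: kept
    ("Lung", "macrophage"),     # same type, different organ: kept
    ("Thymus", "t cell"),       # unrelated: kept
]

kept = [c for c in cells if keep_cell(*c, "Spleen", "macrophage")]
```

Applying a second filter to `kept` with a different pair, as the loader does for thymus memory B cells, removes that second combination in the same way.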
--------------------------------------------------------------------------------
/guided_diffusion/cell_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .nn import (
5 |     linear,
6 |     timestep_embedding,
7 | )
8 |
9 | class TimeEmbedding(nn.Module):
10 |     def __init__(self, hidden_dim):
11 |         super(TimeEmbedding, self).__init__()
12 |         self.time_embed = nn.Sequential(
13 |             nn.Linear(hidden_dim, hidden_dim),
14 |             nn.SiLU(),
15 |             nn.Linear(hidden_dim, hidden_dim),
16 |         )
17 |         self.hidden_dim = hidden_dim
18 |
19 |     def forward(self, t):
20 |         return self.time_embed(timestep_embedding(t, self.hidden_dim).squeeze(1))
21 |
22 | class ResidualBlock(nn.Module):
23 |     def __init__(self, in_features, out_features, time_features):
24 |         super(ResidualBlock, self).__init__()
25 |         self.fc = nn.Linear(in_features, out_features)
26 |         self.norm = nn.LayerNorm(out_features)
27 |         self.emb_layer = nn.Sequential(
28 |             nn.SiLU(),
29 |             linear(
30 |                 time_features,
31 |                 out_features,
32 |             ),
33 |         )
34 |         self.act = nn.SiLU()
35 |         self.drop = nn.Dropout(0)
36 |
37 |     def forward(self, x, emb):
38 |         h = self.fc(x)
39 |         h = h + self.emb_layer(emb)
40 |         h = self.norm(h)
41 |         h = self.act(h)
42 |         h = self.drop(h)
43 |         return h
44 |
45 | class Cell_Unet(nn.Module):
46 |     def __init__(self, input_dim=2, hidden_num=[2000,1000,500,500], dropout=0.1):
47 |         super(Cell_Unet, self).__init__()
48 |         self.hidden_num = hidden_num
49 |
50 |         self.time_embedding = TimeEmbedding(hidden_num[0])
51 |
52 |         # Create layers dynamically
53 |         self.layers = nn.ModuleList()
54 |
55 |         self.layers.append(ResidualBlock(input_dim, hidden_num[0], hidden_num[0]))
56 |
57 |         for i in range(len(hidden_num)-1):
58 |             self.layers.append(ResidualBlock(hidden_num[i], hidden_num[i+1], hidden_num[0]))
59 |
60 |         self.reverse_layers = nn.ModuleList()
61 |         for i in reversed(range(len(hidden_num)-1)):
62 |             self.reverse_layers.append(ResidualBlock(hidden_num[i+1], hidden_num[i], hidden_num[0]))
63 |
64 |         self.out1 = nn.Linear(hidden_num[0], int(hidden_num[1]*2))
65 |         self.norm_out = nn.LayerNorm(int(hidden_num[1]*2))
66 |         self.out2 = nn.Linear(int(hidden_num[1]*2), input_dim, bias=True)
67 |
68 |         self.act = nn.SiLU()
69 |         self.drop = nn.Dropout(dropout)
70 |
71 |     def forward(self, x_input, t, y=None):
72 |         emb = self.time_embedding(t)
73 |         x = x_input.float()
74 |
75 |         # Forward pass with history saving
76 |         history = []
77 |         for layer in self.layers:
78 |             x = layer(x, emb)
79 |             history.append(x)
80 |
81 |         history.pop()
82 |
83 |         # Reverse pass with skip connections
84 |         for layer in self.reverse_layers:
85 |             x = layer(x, emb)
86 |             x = x + history.pop()  # Skip connection
87 |
88 |         x = self.out1(x)
89 |         x = self.norm_out(x)
90 |         x = self.act(x)
91 |         x = self.out2(x)
92 |         return x
93 |
94 |
95 | class Cell_classifier(nn.Module):
96 |     def __init__(self, input_dim=2, hidden_num=[2000,1000,500,200], num_class=11, dropout = 0.1):
97 |         super().__init__()
98 |         self.num_class = num_class
99 |         self.input_dim = input_dim
100 |         self.hidden_num = hidden_num
101 |         self.drop_rate = dropout
102 |
103 |         self.time_embed = nn.Sequential(
104 |             linear(hidden_num[0], hidden_num[0]),
105 |             nn.SiLU(),
106 |             linear(hidden_num[0], hidden_num[0]),
107 |         )
108 |
109 |         self.fc1 = nn.Linear(input_dim, hidden_num[0], bias=True)
110 |         self.emb_layers1 = nn.Sequential(
111 |             nn.SiLU(),
112 |             linear(
113 |                 hidden_num[0],
114 |                 hidden_num[0],
115 |             ),
116 |         )
117 |         self.norm1 = nn.BatchNorm1d(hidden_num[0])
118 |
119 |         self.fc2 = nn.Linear(hidden_num[0], hidden_num[1], bias=True)
120 |         self.emb_layers2 = nn.Sequential(
121 |             nn.SiLU(),
122 |             linear(
123 |                 hidden_num[0],
124 |                 hidden_num[1],
125 |             ),
126 |         )
127 |         self.norm2 = nn.BatchNorm1d(hidden_num[1])
128 |
129 |         self.fc3 = nn.Linear(hidden_num[1], hidden_num[2], bias=True)
130 |         self.emb_layers3 = nn.Sequential(
131 |             nn.SiLU(),
132 |             linear(
133 |                 hidden_num[0],
134 |                 hidden_num[2],
135 |             ),
136 |         )
137 |         self.norm3 = nn.BatchNorm1d(hidden_num[2])
138 |
139 |         self.act = torch.nn.SiLU()
140 |         self.drop = nn.Dropout(self.drop_rate)
141 |         self.out = nn.Linear(hidden_num[2], num_class, bias=True)
142 |
143 |
144 |     def forward(self, x_input, t):
145 |         emb = self.time_embed(timestep_embedding(t, self.hidden_num[0]).squeeze(1))
146 |
147 |         x = self.fc1(x_input)
148 |         x = x+self.emb_layers1(emb)
149 |         x = self.norm1(x)
150 |         x = self.act(x)
151 |         x = self.drop(x)
152 |
153 |         x = self.fc2(x)
154 |         x = x+self.emb_layers2(emb)
155 |         x = self.norm2(x)
156 |         x = self.act(x)
157 |         x = self.drop(x)
158 |
159 |         x = self.fc3(x)
160 |         x = self.norm3(x)
161 |         x = self.act(x)
162 |         x = self.drop(x)
163 |
164 |         x = self.out(x)
165 |         return x
166 |
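`Cell_Unet` builds its down and up paths from `hidden_num`, and the skip connection `x + history.pop()` only works if each saved activation has the same width as the output of the matching up block. The sketch below traces those widths in plain Python, without PyTorch, to show why the widths line up. The function names are hypothetical and `input_dim=128` is only an assumed latent size taken from the loaders' default `hidden_dim`.

```python
def unet_layer_widths(input_dim, hidden_num):
    """Return (in, out) widths of the down blocks and up blocks, mirroring Cell_Unet.__init__."""
    down = [(input_dim, hidden_num[0])]
    down += [(hidden_num[i], hidden_num[i + 1]) for i in range(len(hidden_num) - 1)]
    up = [(hidden_num[i + 1], hidden_num[i]) for i in reversed(range(len(hidden_num) - 1))]
    return down, up


def trace_skip_connections(input_dim, hidden_num):
    """Replay the bookkeeping of Cell_Unet.forward on widths only.

    Returns the width fed into `out1`; raises AssertionError on a width mismatch.
    """
    down, up = unet_layer_widths(input_dim, hidden_num)
    history = [out_w for (_, out_w) in down]  # one saved activation per down block
    width = history.pop()                     # the deepest activation is not reused as a skip
    for in_w, out_w in up:
        assert in_w == width, "up block input must match the current width"
        skip_w = history.pop()
        assert skip_w == out_w, "skip connection must match the up block output"
        width = out_w
    assert not history, "every saved activation should be consumed"
    return width


final_width = trace_skip_connections(128, [2000, 1000, 500, 500])
```

With the default `hidden_num`, the trace ends at width `hidden_num[0]`, which is the input size that `out1` expects.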
--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 |     """
23 |     Setup a distributed process group.
24 |     """
25 |     if dist.is_initialized():
26 |         return
27 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
28 |
29 |     comm = MPI.COMM_WORLD
30 |     backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 |     if backend == "gloo":
33 |         hostname = "localhost"
34 |     else:
35 |         hostname = socket.gethostbyname(socket.getfqdn())
36 |     os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 |     os.environ["RANK"] = str(comm.rank)
38 |     os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 |     port = comm.bcast(_find_free_port(), root=0)
41 |     os.environ["MASTER_PORT"] = str(port)
42 |     dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 |     """
47 |     Get the device to use for torch.distributed.
48 |     """
49 |     if th.cuda.is_available():
50 |         return th.device("cuda")
51 |     return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 |     """
56 |     Load a PyTorch file without redundant fetches across MPI ranks.
57 |     """
58 |     chunk_size = 2 ** 30  # MPI has a relatively small size limit
59 |     if MPI.COMM_WORLD.Get_rank() == 0:
60 |         with bf.BlobFile(path, "rb") as f:
61 |             data = f.read()
62 |         num_chunks = len(data) // chunk_size
63 |         if len(data) % chunk_size:
64 |             num_chunks += 1
65 |         MPI.COMM_WORLD.bcast(num_chunks)
66 |         for i in range(0, len(data), chunk_size):
67 |             MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 |     else:
69 |         num_chunks = MPI.COMM_WORLD.bcast(None)
70 |         data = bytes()
71 |         for _ in range(num_chunks):
72 |             data += MPI.COMM_WORLD.bcast(None)
73 |
74 |     return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 |     """
79 |     Synchronize a sequence of Tensors across ranks from rank 0.
80 |     """
81 |     for p in params:
82 |         with th.no_grad():
83 |             dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 |     try:
88 |         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
89 |         s.bind(("", 0))
90 |         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 |         return s.getsockname()[1]
92 |     finally:
93 |         s.close()
94 |
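`load_state_dict` has rank 0 announce a chunk count and then broadcast the checkpoint bytes in fixed-size pieces, while the other ranks concatenate what they receive. The sketch below reproduces only the chunk arithmetic in plain Python, with no MPI involved, to show that the announced count always equals the number of slices sent. The helper names `split_into_chunks` and `reassemble` are hypothetical.

```python
def split_into_chunks(data: bytes, chunk_size: int):
    """Split `data` the way rank 0 does: a ceil-divided count plus fixed-size slices."""
    num_chunks = len(data) // chunk_size
    if len(data) % chunk_size:
        num_chunks += 1  # a trailing partial chunk still needs its own broadcast
    chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    # The receivers loop exactly num_chunks times, so the two must agree.
    assert len(chunks) == num_chunks
    return chunks


def reassemble(chunks):
    """Concatenate chunks in order, as the non-zero ranks do."""
    data = bytes()
    for chunk in chunks:
        data += chunk
    return data


payload = b"abcdefg"
pieces = split_into_chunks(payload, 3)
restored = reassemble(pieces)
```

The repository uses a chunk size of `2 ** 30` bytes; the tiny size here only keeps the example readable.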
--------------------------------------------------------------------------------
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 |     """
17 |     Convert primitive modules to float16.
18 |     """
19 |     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 |         l.weight.data = l.weight.data.half()
21 |         if l.bias is not None:
22 |             l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 |     """
27 |     Convert primitive modules to float32, undoing convert_module_to_f16().
28 |     """
29 |     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 |         l.weight.data = l.weight.data.float()
31 |         if l.bias is not None:
32 |             l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 |     """
37 |     Copy model parameters into a (differently-shaped) list of full-precision
38 |     parameters.
39 |     """
40 |     master_params = []
41 |     for param_group, shape in param_groups_and_shapes:
42 |         master_param = nn.Parameter(
43 |             _flatten_dense_tensors(
44 |                 [param.detach().float() for (_, param) in param_group]
45 |             ).view(shape)
46 |         )
47 |         master_param.requires_grad = True
48 |         master_params.append(master_param)
49 |     return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 |     """
54 |     Copy the gradients from the model parameters into the master parameters
55 |     from make_master_params().
56 |     """
57 |     for master_param, (param_group, shape) in zip(
58 |         master_params, param_groups_and_shapes
59 |     ):
60 |         master_param.grad = _flatten_dense_tensors(
61 |             [param_grad_or_zeros(param) for (_, param) in param_group]
62 |         ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 |     """
67 |     Copy the master parameter data back into the model parameters.
68 |     """
69 |     # Without copying to a list, if a generator is passed, this will
70 |     # silently not copy any parameters.
71 |     for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 |         for (_, param), unflat_master_param in zip(
73 |             param_group, unflatten_master_params(param_group, master_param.view(-1))
74 |         ):
75 |             param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 |     return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 |     named_model_params = list(named_model_params)
84 |     scalar_vector_named_params = (
85 |         [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 |         (-1),
87 |     )
88 |     matrix_named_params = (
89 |         [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 |         (1, -1),
91 |     )
92 |     return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 |     model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 |     if use_fp16:
99 |         state_dict = model.state_dict()
100 |         for master_param, (param_group, _) in zip(
101 |             master_params, param_groups_and_shapes
102 |         ):
103 |             for (name, _), unflat_master_param in zip(
104 |                 param_group, unflatten_master_params(param_group, master_param.view(-1))
105 |             ):
106 |                 assert name in state_dict
107 |                 state_dict[name] = unflat_master_param
108 |     else:
109 |         state_dict = model.state_dict()
110 |         for i, (name, _value) in enumerate(model.named_parameters()):
111 |             assert name in state_dict
112 |             state_dict[name] = master_params[i]
113 |     return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 |     if use_fp16:
118 |         named_model_params = [
119 |             (name, state_dict[name]) for name, _ in model.named_parameters()
120 |         ]
121 |         param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 |         master_params = make_master_params(param_groups_and_shapes)
123 |     else:
124 |         master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 |     return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 |     for param in master_params:
130 |         param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 |     for param in model_params:
135 |         # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 |         if param.grad is not None:
137 |             param.grad.detach_()
138 |             param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 |     if param.grad is not None:
143 |         return param.grad.data.detach()
144 |     else:
145 |         return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 |     def __init__(
150 |         self,
151 |         *,
152 |         model,
153 |         use_fp16=False,
154 |         fp16_scale_growth=1e-3,
155 |         initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 |     ):
157 |         self.model = model
158 |         self.use_fp16 = use_fp16
159 |         self.fp16_scale_growth = fp16_scale_growth
160 |
161 |         self.model_params = list(self.model.parameters())
162 |         self.master_params = self.model_params
163 |         self.param_groups_and_shapes = None
164 |         self.lg_loss_scale = initial_lg_loss_scale
165 |
166 |         if self.use_fp16:
167 |             self.param_groups_and_shapes = get_param_groups_and_shapes(
168 |                 self.model.named_parameters()
169 |             )
170 |             self.master_params = make_master_params(self.param_groups_and_shapes)
171 |             self.model.convert_to_fp16()
172 |
173 |     def zero_grad(self):
174 |         zero_grad(self.model_params)
175 |
176 |     def backward(self, loss: th.Tensor, retain_graph=False):
177 |         if self.use_fp16:
178 |             loss_scale = 2 ** self.lg_loss_scale
179 |             (loss * loss_scale).backward(retain_graph=retain_graph)
180 |         else:
181 |             loss.backward(retain_graph=retain_graph)
182 |
183 |     def optimize(self, opt: th.optim.Optimizer):
184 |         if self.use_fp16:
185 |             return self._optimize_fp16(opt)
186 |         else:
187 |             return self._optimize_normal(opt)
188 |
189 |     def _optimize_fp16(self, opt: th.optim.Optimizer):
190 |         logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 |         model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 |         grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 |         if check_overflow(grad_norm):
194 |             self.lg_loss_scale -= 1
195 |             logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 |             zero_master_grads(self.master_params)
197 |             return False
198 |
199 |         logger.logkv_mean("grad_norm", grad_norm)
200 |         logger.logkv_mean("param_norm", param_norm)
201 |
202 |         for p in self.master_params:
203 |             p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 |         opt.step()
205 |         zero_master_grads(self.master_params)
206 |         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207 |         self.lg_loss_scale += self.fp16_scale_growth
208 |         return True
209 |
210 |     def _optimize_normal(self, opt: th.optim.Optimizer):
211 |         grad_norm, param_norm = self._compute_norms()
212 |         logger.logkv_mean("grad_norm", grad_norm)
213 |         logger.logkv_mean("param_norm", param_norm)
214 |         opt.step()
215 |         return True
216 |
217 |     def _compute_norms(self, grad_scale=1.0):
218 |         grad_norm = 0.0
219 |         param_norm = 0.0
220 |         for p in self.master_params:
221 |             with th.no_grad():
222 |                 param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223 |                 if p.grad is not None:
224 |                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225 |         return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226 |
227 |     def master_params_to_state_dict(self, master_params):
228 |         return master_params_to_state_dict(
229 |             self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230 |         )
231 |
232 |     def state_dict_to_master_params(self, state_dict):
233 |         return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234 |
235 |
236 | def check_overflow(value):
237 |     return (value == float("inf")) or (value == -float("inf")) or (value != value)
238 |
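`MixedPrecisionTrainer._optimize_fp16` adjusts the loss scale in log2 space: an overflowing gradient norm lowers `lg_loss_scale` by 1 and skips the optimizer step, while a finite norm raises it by `fp16_scale_growth`. The sketch below isolates that update rule in plain Python so the dynamics can be followed without a model. The helper names `step_lg_loss_scale` and `simulate` are hypothetical, and the gradient norms are invented inputs.

```python
def check_overflow(value):
    """True for +inf, -inf, or NaN (NaN is the only value not equal to itself)."""
    return value == float("inf") or value == -float("inf") or value != value


def step_lg_loss_scale(lg_loss_scale, grad_norm, scale_growth=1e-3):
    """Return (new_lg_loss_scale, step_applied) following the fp16 update rule."""
    if check_overflow(grad_norm):
        return lg_loss_scale - 1, False  # halve the loss scale and skip this step
    return lg_loss_scale + scale_growth, True  # grow the scale slowly after a good step


def simulate(grad_norms, lg_loss_scale=20.0):
    """Apply the rule over a sequence of gradient norms."""
    applied = []
    for grad_norm in grad_norms:
        lg_loss_scale, ok = step_lg_loss_scale(lg_loss_scale, grad_norm)
        applied.append(ok)
    return lg_loss_scale, applied


final_scale, applied = simulate([1.0, float("nan"), float("inf"), 2.0])
```

The asymmetry is deliberate in this scheme: one overflow costs a full factor of two, while recovery takes on the order of a thousand successful steps per factor of two at the default growth rate.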
--------------------------------------------------------------------------------
/guided_diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | This code started out as a PyTorch port of Ho et al's diffusion models:
3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4 |
5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6 | """
7 |
8 | import enum
9 | import math
10 |
11 | import numpy as np
12 | import torch as th
13 |
14 | from .nn import mean_flat
15 | from .losses import normal_kl, discretized_gaussian_log_likelihood
16 |
17 |
18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
19 |     """
20 |     Get a pre-defined beta schedule for the given name.
21 |
22 |     The beta schedule library consists of beta schedules which remain similar
23 |     in the limit of num_diffusion_timesteps.
24 |     Beta schedules may be added, but should not be removed or changed once
25 |     they are committed to maintain backwards compatibility.
26 |     """
27 |     if schedule_name == "linear":
28 |         # Linear schedule from Ho et al, extended to work for any number of
29 |         # diffusion steps.
30 |         scale = 1000 / num_diffusion_timesteps
31 |         beta_start = scale * 0.0001
32 |         beta_end = scale * 0.02
33 |         return np.linspace(
34 |             beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
35 |         )
36 |     elif schedule_name == "cosine":
37 |         return betas_for_alpha_bar(
38 |             num_diffusion_timesteps,
39 |             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
40 |         )
41 |     else:
42 |         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
43 |
44 |
45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
46 |     """
47 |     Create a beta schedule that discretizes the given alpha_t_bar function,
48 |     which defines the cumulative product of (1-beta) over time from t = [0,1].
49 |
50 |     :param num_diffusion_timesteps: the number of betas to produce.
51 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
52 |                       produces the cumulative product of (1-beta) up to that
53 |                       part of the diffusion process.
54 |     :param max_beta: the maximum beta to use; use values lower than 1 to
55 |                      prevent singularities.
56 |     """
57 |     betas = []
58 |     for i in range(num_diffusion_timesteps):
59 |         t1 = i / num_diffusion_timesteps
60 |         t2 = (i + 1) / num_diffusion_timesteps
61 |         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
62 |     return np.array(betas)
63 |
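`betas_for_alpha_bar` picks each beta so that the running product of `(1 - beta)` follows the target `alpha_bar` curve, clipping at `max_beta` where the curve drops to zero. The sketch below restates the cosine schedule with plain Python lists instead of NumPy and checks that telescoping property on a short schedule. The 10-step length and the variable names are illustrative assumptions, not values used by the repository.

```python
import math


def cosine_alpha_bar(t):
    """Cosine alpha_bar curve from the schedule above, for t in [0, 1]."""
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2


def betas_for_alpha_bar(num_steps, alpha_bar, max_beta=0.999):
    """List-based restatement of the function above (returns a list, not an array)."""
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


num_steps = 10
betas = betas_for_alpha_bar(num_steps, cosine_alpha_bar)

# Rebuild the discrete cumulative product of (1 - beta).
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

# Before clipping kicks in, the product telescopes to alpha_bar(t) / alpha_bar(0).
target_after_9 = cosine_alpha_bar(9 / num_steps) / cosine_alpha_bar(0.0)
```

Only the final step is clipped here, because `alpha_bar(1)` is numerically zero and the unclipped beta would be 1, which would make the posterior terms that divide by `1 - alphas_cumprod` or take `1 / alphas_cumprod` singular.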
64 |
65 | class ModelMeanType(enum.Enum):
66 |     """
67 |     Which type of output the model predicts.
68 |     """
69 |
70 |     PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
71 |     START_X = enum.auto()  # the model predicts x_0
72 |     EPSILON = enum.auto()  # the model predicts epsilon
73 |
74 |
75 | class ModelVarType(enum.Enum):
76 |     """
77 |     What is used as the model's output variance.
78 |
79 |     The LEARNED_RANGE option has been added to allow the model to predict
80 |     values between FIXED_SMALL and FIXED_LARGE, making its job easier.
81 |     """
82 |
83 |     LEARNED = enum.auto()
84 |     FIXED_SMALL = enum.auto()
85 |     FIXED_LARGE = enum.auto()
86 |     LEARNED_RANGE = enum.auto()
87 |
88 |
89 | class LossType(enum.Enum):
90 |     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
91 |     RESCALED_MSE = (
92 |         enum.auto()
93 |     )  # use raw MSE loss (with RESCALED_KL when learning variances)
94 |     KL = enum.auto()  # use the variational lower-bound
95 |     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
96 |
97 |     def is_vb(self):
98 |         return self == LossType.KL or self == LossType.RESCALED_KL
99 |
100 |
101 | class GaussianDiffusion:
102 |     """
103 |     Utilities for training and sampling diffusion models.
104 |
105 |     Ported directly from here, and then adapted over time to further experimentation.
106 |     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
107 |
108 |     :param betas: a 1-D numpy array of betas for each diffusion timestep,
109 |                   starting at T and going to 1.
110 |     :param model_mean_type: a ModelMeanType determining what the model outputs.
111 |     :param model_var_type: a ModelVarType determining how variance is output.
112 |     :param loss_type: a LossType determining the loss function to use.
113 |     :param rescale_timesteps: if True, pass floating point timesteps into the
114 |                               model so that they are always scaled like in the
115 |                               original paper (0 to 1000).
116 |     """
117 |
118 |     def __init__(
119 |         self,
120 |         *,
121 |         betas,
122 |         model_mean_type,
123 |         model_var_type,
124 |         loss_type,
125 |         rescale_timesteps=False,
126 |     ):
127 |         self.model_mean_type = model_mean_type
128 |         self.model_var_type = model_var_type
129 |         self.loss_type = loss_type
130 |         self.rescale_timesteps = rescale_timesteps
131 |
132 |         # Use float64 for accuracy.
133 |         betas = np.array(betas, dtype=np.float64)
134 |         self.betas = betas
135 |         assert len(betas.shape) == 1, "betas must be 1-D"
136 |         assert (betas > 0).all() and (betas <= 1).all()
137 |
138 |         self.num_timesteps = int(betas.shape[0])
139 |
140 |         alphas = 1.0 - betas
141 |         self.alphas_cumprod = np.cumprod(alphas, axis=0)
142 |         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
143 |         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
144 |         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
145 |
146 |         # calculations for diffusion q(x_t | x_{t-1}) and others
147 |         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
148 |         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
149 |         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
150 |         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
151 |         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
152 |
153 |         # calculations for posterior q(x_{t-1} | x_t, x_0)
154 |         self.posterior_variance = (
155 |             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
156 |         )
157 |         # log calculation clipped because the posterior variance is 0 at the
158 |         # beginning of the diffusion chain.
159 |         self.posterior_log_variance_clipped = np.log(
160 |             np.append(self.posterior_variance[1], self.posterior_variance[1:])
161 |         )
162 |         self.posterior_mean_coef1 = (
163 |             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
164 |         )
165 |         self.posterior_mean_coef2 = (
166 |             (1.0 - self.alphas_cumprod_prev)
167 |             * np.sqrt(alphas)
168 |             / (1.0 - self.alphas_cumprod)
169 |         )
170 |
171 |     def q_mean_variance(self, x_start, t):
172 |         """
173 |         Get the distribution q(x_t | x_0).
174 |
175 |         :param x_start: the [N x C x ...] tensor of noiseless inputs.
176 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
177 |         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
178 |         """
179 |         mean = (
180 |             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
181 | )
182 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
183 | log_variance = _extract_into_tensor(
184 | self.log_one_minus_alphas_cumprod, t, x_start.shape
185 | )
186 | return mean, variance, log_variance
187 |
188 | def q_sample(self, x_start, t, noise=None):
189 | """
190 | Diffuse the data for a given number of diffusion steps.
191 |
192 | In other words, sample from q(x_t | x_0).
193 |
194 | :param x_start: the initial data batch.
195 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
196 | :param noise: if specified, the split-out normal noise.
197 | :return: A noisy version of x_start.
198 | """
199 | if noise is None:
200 | noise = th.randn_like(x_start)
201 | assert noise.shape == x_start.shape
202 | return (
203 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
204 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
205 | * noise
206 | )
207 |
208 | def q_posterior_mean_variance(self, x_start, x_t, t):
209 | """
210 | Compute the mean and variance of the diffusion posterior:
211 |
212 | q(x_{t-1} | x_t, x_0)
213 |
214 | """
215 | assert x_start.shape == x_t.shape
216 | posterior_mean = (
217 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
218 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
219 | )
220 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
221 | posterior_log_variance_clipped = _extract_into_tensor(
222 | self.posterior_log_variance_clipped, t, x_t.shape
223 | )
224 | assert (
225 | posterior_mean.shape[0]
226 | == posterior_variance.shape[0]
227 | == posterior_log_variance_clipped.shape[0]
228 | == x_start.shape[0]
229 | )
230 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
231 |
232 | def p_mean_variance(
233 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
234 | ):
235 | """
236 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
237 | the initial x, x_0.
238 |
239 | :param model: the model, which takes a signal and a batch of timesteps
240 | as input.
241 | :param x: the [N x C x ...] tensor at time t.
242 | :param t: a 1-D Tensor of timesteps.
243 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
244 | :param denoised_fn: if not None, a function which applies to the
245 | x_start prediction before it is used to sample. Applies before
246 | clip_denoised.
247 | :param model_kwargs: if not None, a dict of extra keyword arguments to
248 | pass to the model. This can be used for conditioning.
249 | :return: a dict with the following keys:
250 | - 'mean': the model mean output.
251 | - 'variance': the model variance output.
252 | - 'log_variance': the log of 'variance'.
253 | - 'pred_xstart': the prediction for x_0.
254 | """
255 | if model_kwargs is None:
256 | model_kwargs = {}
257 |
258 | B, C = x.shape[:2]
259 | assert t.shape == (B,)
260 | model_output = model(x, self._scale_timesteps(t).unsqueeze(1), **model_kwargs) #MLP
261 | # model_output = model(x.unsqueeze(1).float(), self._scale_timesteps(t), **model_kwargs) #UNet
262 | # model_output = model_output.squeeze(1)
263 |
264 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
265 | assert model_output.shape == (B, C * 2, *x.shape[2:])
266 | model_output, model_var_values = th.split(model_output, C, dim=1)
267 | if self.model_var_type == ModelVarType.LEARNED:
268 | model_log_variance = model_var_values
269 | model_variance = th.exp(model_log_variance)
270 | else:
271 | min_log = _extract_into_tensor(
272 | self.posterior_log_variance_clipped, t, x.shape
273 | )
274 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
275 | # The model_var_values is [-1, 1] for [min_var, max_var].
276 | frac = (model_var_values + 1) / 2
277 | model_log_variance = frac * max_log + (1 - frac) * min_log
278 | model_variance = th.exp(model_log_variance)
279 | else:
280 | model_variance, model_log_variance = {
281 | # for fixedlarge, we set the initial (log-)variance like so
282 | # to get a better decoder log likelihood.
283 | ModelVarType.FIXED_LARGE: (
284 | np.append(self.posterior_variance[1], self.betas[1:]),
285 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
286 | ),
287 | ModelVarType.FIXED_SMALL: (
288 | self.posterior_variance,
289 | self.posterior_log_variance_clipped,
290 | ),
291 | }[self.model_var_type]
292 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
293 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
294 |
295 | def process_xstart(x):
296 | if denoised_fn is not None:
297 | x = denoised_fn(x)
298 | if clip_denoised:
299 | return x.clamp(-1, 1)
300 | return x
301 |
302 | if self.model_mean_type == ModelMeanType.PREVIOUS_X:
303 | pred_xstart = process_xstart(
304 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
305 | )
306 | model_mean = model_output
307 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
308 | if self.model_mean_type == ModelMeanType.START_X:
309 | pred_xstart = process_xstart(model_output)
310 | else:
311 | pred_xstart = process_xstart(
312 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
313 | )
314 | model_mean, _, _ = self.q_posterior_mean_variance(
315 | x_start=pred_xstart, x_t=x, t=t
316 | )
317 | else:
318 | raise NotImplementedError(self.model_mean_type)
319 |
320 | assert (
321 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
322 | )
323 | return {
324 | "mean": model_mean,
325 | "variance": model_variance,
326 | "log_variance": model_log_variance,
327 | "pred_xstart": pred_xstart,
328 | }
329 |
330 | def _predict_xstart_from_eps(self, x_t, t, eps):
331 | assert x_t.shape == eps.shape
332 | return (
333 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
334 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
335 | )
336 |
337 | def _predict_xstart_from_xprev(self, x_t, t, xprev):
338 | assert x_t.shape == xprev.shape
339 | return ( # (xprev - coef2*x_t) / coef1
340 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
341 | - _extract_into_tensor(
342 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
343 | )
344 | * x_t
345 | )
346 |
347 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
348 | return (
349 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
350 | - pred_xstart
351 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
352 |
353 | def _scale_timesteps(self, t):
354 | if self.rescale_timesteps:
355 | return t.float() * (1000.0 / self.num_timesteps)
356 | return t
357 |
358 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359 | """
360 | Compute the mean for the previous step, given a function cond_fn that
361 | computes the gradient of a conditional log probability with respect to
362 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
363 | condition on y.
364 |
365 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
366 | """
367 | gradient = cond_fn(x, self._scale_timesteps(t), **(model_kwargs or {}))
368 | new_mean = (
369 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
370 | )
371 | return new_mean
372 |
373 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
374 | """
375 | Compute what the p_mean_variance output would have been, should the
376 | model's score function be conditioned by cond_fn.
377 |
378 | See condition_mean() for details on cond_fn.
379 |
380 | Unlike condition_mean(), this instead uses the conditioning strategy
381 | from Song et al. (2020).
382 | """
383 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
384 |
385 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
386 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
387 | x, self._scale_timesteps(t), **(model_kwargs or {})
388 | )
389 |
390 | out = p_mean_var.copy()
391 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
392 | out["mean"], _, _ = self.q_posterior_mean_variance(
393 | x_start=out["pred_xstart"], x_t=x, t=t
394 | )
395 | return out
396 |
397 | def p_sample(
398 | self,
399 | model,
400 | x,
401 | t,
402 | clip_denoised=True,
403 | denoised_fn=None,
404 | cond_fn=None,
405 | model_kwargs=None,
406 | nw=0.5,
407 | start_guide_steps=500,
408 | ):
409 | """
410 | Sample x_{t-1} from the model at the given timestep.
411 |
412 | :param model: the model to sample from.
413 | :param x: the current tensor x_t, i.e. the sample at timestep t.
414 | :param t: the value of t, starting at 0 for the first diffusion step.
415 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
416 | :param denoised_fn: if not None, a function which applies to the
417 | x_start prediction before it is used to sample.
418 | :param cond_fn: if not None, this is a gradient function that acts
419 | similarly to the model.
420 | :param model_kwargs: if not None, a dict of extra keyword arguments to
421 | pass to the model. This can be used for conditioning.
422 | :return: a dict containing the following keys:
423 | - 'sample': a random sample from the model.
424 | - 'pred_xstart': a prediction of x_0.
425 | """
426 | out = self.p_mean_variance(
427 | model,
428 | x,
429 | t,
430 | clip_denoised=clip_denoised,
431 | denoised_fn=denoised_fn,
432 | model_kwargs=model_kwargs,
433 | )
434 | noise = th.randn_like(x)#*(0.5**0.5)
435 | nonzero_mask = (
436 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
437 | ) # no noise when t == 0
438 | if cond_fn is not None and t[0] maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log the value of a diagnostic.
215 | Call this once per diagnostic quantity, per iteration.
216 | If called several times in one iteration, the last value is used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 | The same as logkv(), but if called many times, the values are averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 | Write the sequence of args, separated by spaces, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start)
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452 | )
453 | assert isinstance(dir, str)
454 | dir = os.path.expanduser(dir)
455 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
456 |
457 | rank = get_rank_without_mpi_import()
458 | if rank > 0:
459 | log_suffix = log_suffix + "-rank%03i" % rank
460 |
461 | if format_strs is None:
462 | if rank == 0:
463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464 | else:
465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466 | format_strs = filter(None, format_strs)
467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468 |
469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470 | if output_formats:
471 | log("Logging to %s" % dir)
472 |
473 |
474 | def _configure_default_logger():
475 | configure()
476 | Logger.DEFAULT = Logger.CURRENT
477 |
478 |
479 | def reset():
480 | if Logger.CURRENT is not Logger.DEFAULT:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = Logger.DEFAULT
483 | log("Reset logger")
484 |
485 |
486 | @contextmanager
487 | def scoped_configure(dir=None, format_strs=None, comm=None):
488 | prevlogger = Logger.CURRENT
489 | configure(dir=dir, format_strs=format_strs, comm=comm)
490 | try:
491 | yield
492 | finally:
493 | Logger.CURRENT.close()
494 | Logger.CURRENT = prevlogger
495 |
496 |
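The running average kept by `Logger.logkv_mean` and flushed by `Logger.dumpkvs` can be tried without any output formats or MPI. The sketch below is a simplified illustration only: `RunningMeans`, `log_mean` and `dump` are hypothetical names that do not exist in logger.py, and the class copies only the `name2val` / `name2cnt` bookkeeping.

```python
from collections import defaultdict


class RunningMeans:
    """Hypothetical stand-in for the name2val / name2cnt bookkeeping in Logger."""

    def __init__(self):
        self.name2val = defaultdict(float)  # current mean per key
        self.name2cnt = defaultdict(int)    # number of values folded in per key

    def log_mean(self, key, val):
        # Same update rule as Logger.logkv_mean: fold val into the running mean.
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dump(self):
        # Like Logger.dumpkvs: return a snapshot, then reset for the next iteration.
        out = dict(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()
        return out


stats = RunningMeans()
for loss in (1.0, 2.0, 3.0, 4.0):
    stats.log_mean("loss", loss)
snapshot = stats.dump()  # holds the mean of the four values under "loss"
```

Because the counters are cleared on every dump, each reported mean covers only the values logged since the previous dump.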
--------------------------------------------------------------------------------
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcast, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
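The two closed-form expressions in losses.py can be sanity-checked on plain floats. The following is a minimal sketch using only the standard library: `normal_kl_scalar`, `approx_cdf_scalar` and `exact_cdf` are hypothetical scalar helpers written for this illustration, not functions from the repository. The first two copy the formulas of `normal_kl` and `approx_standard_normal_cdf`, and the third is an exact reference based on `math.erf`.

```python
import math


def normal_kl_scalar(mean1, logvar1, mean2, logvar2):
    # Scalar form of the expression returned by normal_kl.
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * math.exp(-logvar2)
    )


def approx_cdf_scalar(x):
    # Scalar form of approx_standard_normal_cdf (tanh-based approximation).
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def exact_cdf(x):
    # Exact standard normal CDF via the error function, used as a reference.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# KL between identical Gaussians, and between N(0, 1) and N(1, 1).
kl_same = normal_kl_scalar(0.3, -1.0, 0.3, -1.0)
kl_shift = normal_kl_scalar(0.0, 0.0, 1.0, 0.0)

# Largest gap between the approximation and the exact CDF on a grid over [-4, 4].
max_cdf_err = max(
    abs(approx_cdf_scalar(x / 10) - exact_cdf(x / 10)) for x in range(-40, 41)
)
```

Identical Gaussians give a KL of zero, a unit mean shift at unit variance gives 0.5 nats, and on this grid the tanh approximation stays within 1e-3 of the exact CDF.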
--------------------------------------------------------------------------------
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
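To make the layout produced by `timestep_embedding` and the update rule in `update_ema` concrete, here is a small standard-library sketch. It works on a single timestep and on lists of floats rather than tensors, and the names `timestep_embedding_scalar` and `ema_step` are hypothetical helpers introduced for this illustration.

```python
import math


def timestep_embedding_scalar(t, dim, max_period=10000):
    # One timestep instead of a batch; same frequency layout as timestep_embedding:
    # all cosine terms first, then all sine terms.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)  # odd dim: pad with a single zero
    return emb


def ema_step(target, source, rate=0.99):
    # Same rule as update_ema, applied to plain floats and returning a new list
    # (the original updates the target tensors in place).
    return [t * rate + s * (1 - rate) for t, s in zip(target, source)]


emb0 = timestep_embedding_scalar(0, 5)  # cosines are 1, sines are 0, plus padding
emb1 = timestep_embedding_scalar(1, 4)  # frequencies 1 and 1/sqrt(max_period)
ema = ema_step([0.0, 1.0], [1.0, 1.0])  # each target moves 1% toward its source
```

With `rate` close to 1 the target changes slowly, which matches the "closer to 1 means slower" note in the `update_ema` docstring.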
--------------------------------------------------------------------------------
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device, start_guide_time=1000):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 |         :param start_guide_time: sample only from the first `start_guide_time` steps.
49 |         :return: a tuple (timesteps, weights): a tensor of timestep indices
50 |                  and a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()[:start_guide_time]
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
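The reweighting used by `ScheduleSampler.sample` and `LossSecondMomentResampler.weights` can be checked with a small pure-Python sketch. It is not the repository's code: the helper names `second_moment_weights` and `importance_weights` and the loss values are hypothetical, and plain lists stand in for the NumPy arrays used above. It illustrates why scaling each loss by `1 / (N * p_t)` leaves the mean of the objective unchanged.

```python
import math


def second_moment_weights(loss_history, uniform_prob=0.001):
    """Sampling probabilities from a [num_timesteps][history] table of losses."""
    # Root-mean-square of the recorded losses for each timestep.
    rms = [math.sqrt(sum(l * l for l in row) / len(row)) for row in loss_history]
    total = sum(rms)
    n = len(rms)
    # Mix in a small uniform component so every timestep keeps nonzero probability.
    return [(r / total) * (1 - uniform_prob) + uniform_prob / n for r in rms]


def importance_weights(probs):
    """Loss multipliers 1 / (N * p_t) matching the sampling probabilities."""
    n = len(probs)
    return [1.0 / (n * p) for p in probs]


# Hypothetical loss history: 4 timesteps, 2 recorded losses each.
history = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
probs = second_moment_weights(history)
scales = importance_weights(probs)

# Expected reweighted loss under `probs` versus the plain mean over timesteps.
per_step_loss = [0.1, 0.2, 0.3, 0.4]
expected = sum(p * w * l for p, w, l in zip(probs, scales, per_step_loss))
uniform_mean = sum(per_step_loss) / len(per_step_loss)
```

Timesteps with larger recent losses are drawn more often but contribute with a proportionally smaller multiplier, so `expected` and `uniform_mean` agree up to floating-point error.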
--------------------------------------------------------------------------------
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
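The beta recomputation in `SpacedDiffusion.__init__` can be illustrated with a minimal pure-Python sketch. The helper `respaced_betas` is hypothetical, and the linear schedule below is an assumed example rather than the output of `get_named_beta_schedule`. The point is that the new betas are chosen so the cumulative product of `1 - beta` at each retained step equals the original `alphas_cumprod` at that step.

```python
def respaced_betas(betas, use_timesteps):
    """Recompute betas for a subset of steps, as SpacedDiffusion.__init__ does."""
    # Cumulative product of (1 - beta) for the original process.
    alphas_cumprod = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas_cumprod.append(running)

    new_betas, timestep_map = [], []
    last_alpha_cumprod = 1.0
    for i, alpha_cumprod in enumerate(alphas_cumprod):
        if i in use_timesteps:
            new_betas.append(1.0 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return new_betas, timestep_map, alphas_cumprod


# Assumed example schedule: 1000 linearly spaced betas (illustrative values only).
num_steps = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (num_steps - 1) for i in range(num_steps)]

# Every 100th step; this is the set space_timesteps(1000, "ddim10") returns.
kept = set(range(0, num_steps, 100))
new_betas, timestep_map, alphas_cumprod = respaced_betas(betas, kept)

# Largest mismatch between the respaced and original cumulative products.
running = 1.0
max_err = 0.0
for new_beta, original_t in zip(new_betas, timestep_map):
    running *= 1.0 - new_beta
    max_err = max(max_err, abs(running - alphas_cumprod[original_t]))
```

`timestep_map` is what `_WrappedModel` uses to translate an index in the shortened process back to the original timestep before calling the model.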
--------------------------------------------------------------------------------
/guided_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | from .cell_model import Cell_classifier, Cell_Unet
7 |
8 | NUM_CLASSES = 11
9 |
10 |
11 | def diffusion_defaults():
12 | """
13 |     Defaults for diffusion and classifier training.
14 | """
15 | return dict(
16 | learn_sigma=False,
17 | diffusion_steps=1000,
18 | noise_schedule="linear",
19 | timestep_respacing="",
20 | use_kl=False,
21 | predict_xstart=False,
22 | rescale_timesteps=False,
23 | rescale_learned_sigmas=False,
24 | class_cond=False,
25 | )
26 |
27 |
28 | def model_and_diffusion_defaults():
29 | """
30 |     Defaults for diffusion backbone (Cell_Unet) training.
31 | """
32 | res = dict(
33 | input_dim = 128,
34 | hidden_dim = [512,512,256,128],
35 | dropout = 0.0
36 | )
37 | res.update(diffusion_defaults())
38 | return res
39 |
40 |
41 | def classifier_and_diffusion_defaults():
42 | res = dict(
43 | input_dim = 128,
44 | hidden_dim = [512,512,256,128],
45 | classifier_use_fp16=False,
46 | dropout = 0.1,
47 | num_class = 11,
48 | )
49 | res.update(diffusion_defaults())
50 | return res
51 |
52 |
53 | def create_model_and_diffusion(
54 | input_dim,
55 | hidden_dim,
56 | class_cond,
57 | learn_sigma,
58 | diffusion_steps,
59 | noise_schedule,
60 | timestep_respacing,
61 | use_kl,
62 | predict_xstart,
63 | rescale_timesteps,
64 | rescale_learned_sigmas,
65 | dropout,
66 | ):
67 | model = create_model(
68 | input_dim,
69 | hidden_dim,
70 | dropout=dropout
71 | )
72 | diffusion = create_gaussian_diffusion(
73 | steps=diffusion_steps,
74 | learn_sigma=learn_sigma,
75 | noise_schedule=noise_schedule,
76 | use_kl=use_kl,
77 | predict_xstart=predict_xstart,
78 | rescale_timesteps=rescale_timesteps,
79 | rescale_learned_sigmas=rescale_learned_sigmas,
80 | timestep_respacing=timestep_respacing,
81 | )
82 | return model, diffusion
83 |
84 |
85 | def create_model(
86 | input_dim,
87 | hidden_dim,
88 | dropout,
89 | ):
90 |
91 | return Cell_Unet(
92 | input_dim,
93 | hidden_dim,
94 | dropout=dropout
95 | )
96 |
97 |
98 | def create_classifier_and_diffusion(
99 | input_dim,
100 | hidden_dim,
101 | classifier_use_fp16,
102 | learn_sigma,
103 | diffusion_steps,
104 | noise_schedule,
105 | timestep_respacing,
106 | use_kl,
107 | predict_xstart,
108 | rescale_timesteps,
109 | rescale_learned_sigmas,
110 | dropout,
111 | num_class,
112 | class_cond,
113 | ):
114 | classifier = create_classifier(
115 | input_dim,
116 | hidden_dim,
117 | dropout=dropout,
118 | num_class=num_class
119 | )
120 | diffusion = create_gaussian_diffusion(
121 | steps=diffusion_steps,
122 | learn_sigma=learn_sigma,
123 | noise_schedule=noise_schedule,
124 | use_kl=use_kl,
125 | predict_xstart=predict_xstart,
126 | rescale_timesteps=rescale_timesteps,
127 | rescale_learned_sigmas=rescale_learned_sigmas,
128 | timestep_respacing=timestep_respacing,
129 | )
130 | return classifier, diffusion
131 |
132 |
133 | def create_classifier(
134 | input_dim,
135 | hidden_dim,
136 | num_class = NUM_CLASSES,
137 | dropout = 0.1,
138 | ):
139 |
140 | return Cell_classifier(
141 | input_dim,
142 | hidden_dim,
143 | num_class,
144 | dropout,
145 | )
146 |
147 | def create_gaussian_diffusion(
148 | *,
149 | steps=1000,
150 | learn_sigma=False,
151 | sigma_small=False,
152 | noise_schedule="linear",
153 | use_kl=False,
154 | predict_xstart=False,
155 | rescale_timesteps=False,
156 | rescale_learned_sigmas=False,
157 | timestep_respacing="",
158 | ):
159 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
160 | if use_kl:
161 | loss_type = gd.LossType.RESCALED_KL
162 | elif rescale_learned_sigmas:
163 | loss_type = gd.LossType.RESCALED_MSE
164 | else:
165 | loss_type = gd.LossType.MSE
166 | if not timestep_respacing:
167 | timestep_respacing = [steps]
168 | return SpacedDiffusion(
169 | use_timesteps=space_timesteps(steps, timestep_respacing),
170 | betas=betas,
171 | model_mean_type=(
172 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
173 | ),
174 | model_var_type=(
175 | (
176 | gd.ModelVarType.FIXED_LARGE
177 | if not sigma_small
178 | else gd.ModelVarType.FIXED_SMALL
179 | )
180 | if not learn_sigma
181 | else gd.ModelVarType.LEARNED_RANGE
182 | ),
183 | loss_type=loss_type,
184 | rescale_timesteps=rescale_timesteps,
185 | )
186 |
187 |
188 | def add_dict_to_argparser(parser, default_dict):
189 | for k, v in default_dict.items():
190 | v_type = type(v)
191 | if v is None:
192 | v_type = str
193 | elif isinstance(v, bool):
194 | v_type = str2bool
195 | parser.add_argument(f"--{k}", default=v, type=v_type)
196 |
197 |
198 | def args_to_dict(args, keys):
199 | return {k: getattr(args, k) for k in keys}
200 |
201 |
202 | def str2bool(v):
203 | """
204 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
205 | """
206 | if isinstance(v, bool):
207 | return v
208 | if v.lower() in ("yes", "true", "t", "y", "1"):
209 | return True
210 | elif v.lower() in ("no", "false", "f", "n", "0"):
211 | return False
212 | else:
213 | raise argparse.ArgumentTypeError("boolean value expected")
214 |
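A short usage sketch of `add_dict_to_argparser` follows. To keep it self-contained, `str2bool` and `add_dict_to_argparser` are restated locally, and the `defaults` dict is a hypothetical subset of the defaults above. It shows how booleans and integers are parsed, and one behaviour to be aware of: a list default such as `hidden_dim` gets `type=list`, so a value passed on the command line is split into characters.

```python
import argparse


def str2bool(v):
    """Parse common textual spellings of a boolean."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """Register one --key option per entry, typed after the default value."""
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


# Hypothetical subset of the defaults, for illustration only.
defaults = dict(
    learn_sigma=False,
    diffusion_steps=1000,
    noise_schedule="linear",
    hidden_dim=[512, 512, 256, 128],
)
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)

args = parser.parse_args(["--learn_sigma", "true", "--diffusion_steps", "500"])

# A list default is only safe when left at its default value:
# argparse calls list("512"), which splits the string into characters.
overridden = parser.parse_args(["--hidden_dim", "512"])
```

If `hidden_dim` needs to be configurable from the command line, one option would be a custom `type` callable that splits on commas; that is a suggestion, not something the scripts above do.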
--------------------------------------------------------------------------------
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 |
16 | # For ImageNet experiments, this was a good default value.
17 | # We found that the lg_loss_scale quickly climbed to
18 | # 20-21 within the first ~1K steps of training.
19 | INITIAL_LOG_LOSS_SCALE = 20.0
20 |
21 |
22 | class TrainLoop:
23 | def __init__(
24 | self,
25 | *,
26 | model,
27 | diffusion,
28 | data,
29 | batch_size,
30 | microbatch,
31 | lr,
32 | ema_rate,
33 | log_interval,
34 | save_interval,
35 | resume_checkpoint,
36 | use_fp16=False,
37 | fp16_scale_growth=1e-3,
38 | schedule_sampler=None,
39 | weight_decay=0.0,
40 | lr_anneal_steps=0,
41 | model_name,
42 | save_dir,
43 | ):
44 | self.model = model
45 | self.diffusion = diffusion
46 | self.data = data
47 | self.batch_size = batch_size
48 | self.microbatch = microbatch if microbatch > 0 else batch_size
49 | self.lr = lr
50 | self.ema_rate = (
51 | [ema_rate]
52 | if isinstance(ema_rate, float)
53 | else [float(x) for x in ema_rate.split(",")]
54 | )
55 | self.log_interval = log_interval
56 | self.save_interval = save_interval
57 | self.resume_checkpoint = resume_checkpoint
58 | self.use_fp16 = use_fp16
59 | self.fp16_scale_growth = fp16_scale_growth
60 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
61 | self.weight_decay = weight_decay
62 | self.lr_anneal_steps = lr_anneal_steps
63 |
64 | self.step = 0
65 | self.resume_step = 0
66 | self.global_batch = self.batch_size * dist.get_world_size()
67 |
68 | self.sync_cuda = th.cuda.is_available()
69 |
70 | self._load_and_sync_parameters()
71 | self.mp_trainer = MixedPrecisionTrainer(
72 | model=self.model,
73 | use_fp16=self.use_fp16,
74 | fp16_scale_growth=fp16_scale_growth,
75 | )
76 |
77 | self.opt = AdamW(
78 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
79 | )
80 | if self.resume_step:
81 | self._load_optimizer_state()
82 | # Model was resumed, either due to a restart or a checkpoint
83 | # being specified at the command line.
84 | self.ema_params = [
85 | self._load_ema_parameters(rate) for rate in self.ema_rate
86 | ]
87 | else:
88 | self.ema_params = [
89 | copy.deepcopy(self.mp_trainer.master_params)
90 | for _ in range(len(self.ema_rate))
91 | ]
92 |
93 | if th.cuda.is_available():
94 | self.use_ddp = True
95 | self.ddp_model = DDP(
96 | self.model,
97 | device_ids=[dist_util.dev()],
98 | output_device=dist_util.dev(),
99 | broadcast_buffers=False,
100 | bucket_cap_mb=128,
101 | find_unused_parameters=False,
102 | )
103 | else:
104 | if dist.get_world_size() > 1:
105 | logger.warn(
106 | "Distributed training requires CUDA. "
107 | "Gradients will not be synchronized properly!"
108 | )
109 | self.use_ddp = False
110 | self.ddp_model = self.model
111 | self.timestamp = model_name #time.strftime("%m-%d-%H:%M",time.gmtime())
112 | self.save_dir = save_dir
113 |
114 | def _load_and_sync_parameters(self):
115 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
116 |
117 | if resume_checkpoint:
118 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
119 | if dist.get_rank() == 0:
120 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
121 | self.model.load_state_dict(
122 | dist_util.load_state_dict(
123 | resume_checkpoint, map_location=dist_util.dev()
124 | )
125 | )
126 |
127 | dist_util.sync_params(self.model.parameters())
128 |
129 | def _load_ema_parameters(self, rate):
130 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
131 |
132 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
133 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
134 | if ema_checkpoint:
135 | if dist.get_rank() == 0:
136 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
137 | state_dict = dist_util.load_state_dict(
138 | ema_checkpoint, map_location=dist_util.dev()
139 | )
140 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
141 |
142 | dist_util.sync_params(ema_params)
143 | return ema_params
144 |
145 | def _load_optimizer_state(self):
146 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
147 | opt_checkpoint = bf.join(
148 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
149 | )
150 | if bf.exists(opt_checkpoint):
151 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
152 | state_dict = dist_util.load_state_dict(
153 | opt_checkpoint, map_location=dist_util.dev()
154 | )
155 | self.opt.load_state_dict(state_dict)
156 |
157 | def run_loop(self):
158 | while (
159 | not self.lr_anneal_steps
160 | or self.step + self.resume_step < self.lr_anneal_steps
161 | ):
162 | batch, cond = next(self.data)
163 | self.run_step(batch, cond)
164 | if self.step % self.log_interval == 0:
165 | logger.dumpkvs()
166 | if self.step % self.save_interval == 0:
167 | self.save()
168 | # Run for a finite amount of time in integration tests.
169 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
170 | return
171 | self.step += 1
172 | # Save the last checkpoint if it wasn't already saved.
173 | if (self.step - 1) % self.save_interval != 0:
174 | self.save()
175 |
176 | def run_step(self, batch, cond):
177 | self.forward_backward(batch, cond)
178 | took_step = self.mp_trainer.optimize(self.opt)
179 | if took_step:
180 | self._update_ema()
181 | self._anneal_lr()
182 | self.log_step()
183 |
184 | def forward_backward(self, batch, cond):
185 | self.mp_trainer.zero_grad()
186 | for i in range(0, batch.shape[0], self.microbatch):
187 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
188 | micro_cond = {
189 | k: v[i : i + self.microbatch].to(dist_util.dev())
190 | for k, v in cond.items()
191 | }
192 | last_batch = (i + self.microbatch) >= batch.shape[0]
193 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
194 |
195 | compute_losses = functools.partial(
196 | self.diffusion.training_losses,
197 | self.ddp_model,
198 | micro,
199 | t,
200 | model_kwargs=micro_cond,
201 | )
202 |
203 | if last_batch or not self.use_ddp:
204 | losses = compute_losses()
205 | else:
206 | with self.ddp_model.no_sync():
207 | losses = compute_losses()
208 |
209 | if isinstance(self.schedule_sampler, LossAwareSampler):
210 | self.schedule_sampler.update_with_local_losses(
211 | t, losses["loss"].detach()
212 | )
213 |
214 | loss = (losses["loss"] * weights).mean()
215 | log_loss_dict(
216 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
217 | )
218 | self.mp_trainer.backward(loss)
219 |
220 | def _update_ema(self):
221 | for rate, params in zip(self.ema_rate, self.ema_params):
222 | update_ema(params, self.mp_trainer.master_params, rate=rate)
223 |
224 | def _anneal_lr(self):
225 | if not self.lr_anneal_steps:
226 | return
227 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
228 | lr = self.lr * (1 - frac_done)
229 | for param_group in self.opt.param_groups:
230 | param_group["lr"] = lr
231 |
232 | def log_step(self):
233 | logger.logkv("step", self.step + self.resume_step)
234 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
235 |
236 | def save(self):
237 | def save_checkpoint(rate, params):
238 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
239 | if dist.get_rank() == 0:
240 | logger.log(f"saving model {rate}...")
241 | if not rate:
242 | filename = f"model{(self.step+self.resume_step):06d}.pt"
243 | else:
244 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
245 | with bf.BlobFile(bf.join(self.save_dir, self.timestamp, filename), "wb") as f:
246 | th.save(state_dict, f)
247 |         # Create the checkpoint directory (and any missing parents) before saving.
248 |         os.makedirs(os.path.join(self.save_dir, self.timestamp), exist_ok=True)
249 | save_checkpoint(0, self.mp_trainer.master_params)
250 | for rate, params in zip(self.ema_rate, self.ema_params):
251 | save_checkpoint(rate, params)
252 |
253 | if dist.get_rank() == 0:
254 | with bf.BlobFile(
255 | bf.join(self.save_dir, self.timestamp, f"opt{(self.step+self.resume_step):06d}.pt"),
256 | "wb",
257 | ) as f:
258 | th.save(self.opt.state_dict(), f)
259 |
260 | dist.barrier()
261 |
262 |
263 | def parse_resume_step_from_filename(filename):
264 | """
265 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
266 | checkpoint's number of steps.
267 | """
268 | split = filename.split("model")
269 | if len(split) < 2:
270 | return 0
271 | split1 = split[-1].split(".")[0]
272 | try:
273 | return int(split1)
274 | except ValueError:
275 | return 0
276 |
277 |
278 | def get_blob_logdir():
279 | # You can change this to be a separate path to save checkpoints to
280 | # a blobstore or some external drive.
281 | return logger.get_dir()
282 |
283 |
284 | def find_resume_checkpoint():
285 | # On your infrastructure, you may want to override this to automatically
286 | # discover the latest checkpoint on your blob storage, etc.
287 | return None
288 |
289 |
290 | def find_ema_checkpoint(main_checkpoint, step, rate):
291 | if main_checkpoint is None:
292 | return None
293 | filename = f"ema_{rate}_{(step):06d}.pt"
294 | path = bf.join(bf.dirname(main_checkpoint), filename)
295 | if bf.exists(path):
296 | return path
297 | return None
298 |
299 |
300 | def log_loss_dict(diffusion, ts, losses):
301 | for key, values in losses.items():
302 | logger.logkv(key, values.mean().item())
303 | # logger.logkv_mean(key, values.mean().item())
304 | # Log the quantiles (four quartiles, in particular).
305 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
306 | quartile = int(4 * sub_t / diffusion.num_timesteps)
307 | logger.logkv(f"{key}_q{quartile}", sub_loss)
308 | # logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
309 |
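The checkpoint naming used by `TrainLoop.save`, the step parsing in `parse_resume_step_from_filename`, and the linear schedule in `_anneal_lr` can be sketched without any dependencies. The helpers `checkpoint_names`, `parse_resume_step` and `annealed_lr` are hypothetical stand-ins written for this illustration, and the paths are examples.

```python
def checkpoint_names(step, ema_rates):
    """File names written for one save: model, optimizer, and one per EMA rate."""
    names = [f"model{step:06d}.pt", f"opt{step:06d}.pt"]
    names += [f"ema_{rate}_{step:06d}.pt" for rate in ema_rates]
    return names


def parse_resume_step(filename):
    """Recover the step count from a path ending in modelNNNNNN.pt, else 0."""
    split = filename.split("model")
    if len(split) < 2:
        return 0
    try:
        return int(split[-1].split(".")[0])
    except ValueError:
        return 0


def annealed_lr(base_lr, step, lr_anneal_steps):
    """Linear decay from base_lr to 0 over lr_anneal_steps; constant if disabled."""
    if not lr_anneal_steps:
        return base_lr
    return base_lr * (1 - step / lr_anneal_steps)


names = checkpoint_names(150000, [0.9999])
resumed = parse_resume_step("output/checkpoint/backbone/open_problem/model150000.pt")

# A name that does not follow modelNNNNNN.pt falls back to step 0.
not_parsed = parse_resume_step("model_seed=0_step=150000.pt")

halfway_lr = annealed_lr(1e-4, 400000, 800000)
```

Because the step is read from the text after the last occurrence of `model`, only checkpoints named in the `modelNNNNNN.pt` form resume with the right step; anything else restarts the step counter at 0.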
--------------------------------------------------------------------------------
/model_archi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EperLuo/scDiffusion/d34ef8e560b47159d4500cf4411a7a34e5a12a32/model_archi.png
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | cd VAE
2 | echo "Training Autoencoder, this might take a long time"
3 | python VAE_train.py --data_dir '/stor/lep/diffusion/multiome/openproblems_RNA_new.h5ad' --num_genes 13431 --save_dir '../output/checkpoint/AE/open_problem' --max_steps 200000
4 | echo "Training Autoencoder done"
5 |
6 | cd ..
7 | echo "Training diffusion backbone"
8 | python cell_train.py --data_dir '/stor/lep/diffusion/multiome/openproblems_RNA_new.h5ad' --vae_path 'output/checkpoint/AE/open_problem/model_seed=0_step=150000.pt' \
9 | --model_name 'open_problem' --lr_anneal_steps 800000 --save_dir 'output/checkpoint/backbone'
10 | echo "Training diffusion backbone done"
11 |
12 | echo "Training classifier"
13 | python classifier_train.py --data_dir '/stor/lep/diffusion/multiome/openproblems_RNA_new.h5ad' --model_path "output/checkpoint/classifier/open_problem_classifier" \
14 |     --iterations 400000 --vae_path 'output/checkpoint/AE/open_problem/model_seed=0_step=150000.pt' --num_class 22
15 | echo "Training classifier done"
--------------------------------------------------------------------------------