├── .gitignore ├── 3D-SR-Unet ├── __init__.py ├── data.py ├── main.py ├── model.py └── train.py ├── LICENSE ├── RAFT ├── .gitignore ├── LICENSE ├── RAFT.png ├── README.md ├── __init__.py ├── alt_cuda_corr │ ├── correlation.cpp │ ├── correlation_kernel.cu │ └── setup.py ├── chairs_split.txt ├── core │ ├── __init__.py │ ├── align_functions.py │ ├── corr.py │ ├── datasets.py │ ├── extractor.py │ ├── raft.py │ ├── raftConfig.py │ ├── register.py │ ├── register_custom.py │ ├── super_res_register.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py ├── demo-frames │ ├── frame_0016.png │ ├── frame_0017.png │ ├── frame_0018.png │ ├── frame_0019.png │ ├── frame_0020.png │ ├── frame_0021.png │ ├── frame_0022.png │ ├── frame_0023.png │ ├── frame_0024.png │ └── frame_0025.png ├── demo.py ├── download_models.sh ├── evaluate.py ├── models │ └── raft-things.pth ├── train.py ├── train_mixed.sh └── train_standard.sh ├── README.md ├── config ├── EMDiffuse-n-big.json ├── EMDiffuse-n-transfer.json ├── EMDiffuse-n.json ├── EMDiffuse-r.json ├── vEMDiffuse-a.json └── vEMDiffuse-i.json ├── core ├── __pycache__ │ ├── base_model.cpython-37.pyc │ ├── base_network.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── praser.cpython-37.pyc │ └── util.cpython-37.pyc ├── base_dataset.py ├── base_model.py ├── base_network.py ├── calibration.py ├── logger.py ├── praser.py └── util.py ├── crop_single_file.py ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── dataset.cpython-37.pyc ├── dataset.py └── util │ ├── auto_augment.py │ └── mask.py ├── demo ├── denoise_demo.tif ├── microns_demo │ ├── 0.tif │ └── 1.tif ├── mouse_liver_demo │ ├── 0.tif │ └── 1.tif └── super_res_demo.tif ├── emdiffuse_conifg.py ├── example ├── denoise │ ├── prediction.ipynb │ └── training.ipynb ├── super-res │ ├── prediction.ipynb │ └── training.ipynb ├── vEMDiffuse-a │ ├── prediction.ipynb │ └── training.ipynb └── vEMDiffuse-i │ ├── prediction.ipynb │ └── training.ipynb ├── models ├── EMDiffuse_model.py ├── EMDiffuse_network.py ├── __init__.py ├── __pycache__ │ ├── EMDiffuse_model.cpython-37.pyc │ ├── EMDiffuse_network.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── loss.cpython-37.pyc │ └── metric.cpython-37.pyc ├── guided_diffusion_modules │ ├── __pycache__ │ │ ├── nn.cpython-37.pyc │ │ ├── unet.cpython-37.pyc │ │ └── unet_jit2.cpython-37.pyc │ ├── nn.py │ ├── unet.py │ ├── unet_3d.py │ ├── unet_3d_aleatoric.py │ ├── unet_aleatoric.py │ ├── unet_jit.py │ └── unet_jit2.py ├── loss.py ├── metric.py ├── unet.py ├── vEMDiffuse_model.py └── vEMDiffuse_network.py ├── requirements.txt ├── run.py ├── test_pre.py ├── vEM_test_pre.py └── vEMa_pre.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_STORE 4 | -------------------------------------------------------------------------------- /3D-SR-Unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/3D-SR-Unet/__init__.py -------------------------------------------------------------------------------- /3D-SR-Unet/data.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import transforms 3 | from PIL import Image, ImageOps, ImageFilter 4 | import os 5 | import numpy as np 6 | import 
torchvision.transforms.functional as TF 7 | import random 8 | import torch 9 | from scipy.ndimage import zoom 10 | from scipy.ndimage import gaussian_filter1d, gaussian_filter 11 | from tifffile import imread 12 | from scipy.ndimage import zoom 13 | class KidneySRUData(data.Dataset): 14 | def __init__(self, data_root): 15 | self.data_root = data_root 16 | 17 | self.volume_list = self.read_dataset(data_root=self.data_root) 18 | 19 | def __getitem__(self, index): 20 | ret = {} 21 | gt = imread(self.volume_list[index]) 22 | # print(gt.shape) 23 | img = gt[::6, :, :] 24 | img_upsampled = zoom(img, (6, 1,1 ), order=3) 25 | img, gt, img_upsampled = self.aug(img, gt, img_upsampled) 26 | img = img / 255. 27 | gt = gt / 255. 28 | img_upsampled = img_upsampled / 255. 29 | img = self.norm(img) 30 | gt = self.norm(gt) 31 | img_upsampled = self.norm(img_upsampled) 32 | img = torch.tensor(img, dtype=torch.float32).unsqueeze_(dim=0) 33 | gt = torch.tensor(gt, dtype=torch.float32).unsqueeze_(dim=0) 34 | img_upsampled = torch.tensor(img_upsampled, dtype=torch.float32).unsqueeze_(dim=0) 35 | return img, gt,img_upsampled 36 | 37 | def norm(self, img): 38 | img = img.astype(np.float32) 39 | img = (img - 0.5) / 0.5 40 | return img 41 | 42 | def __len__(self): 43 | return len(self.volume_list) 44 | 45 | def aug(self, img, gt, img_up): 46 | if random.random() < 0.5: 47 | img = np.flip(img, axis=2) 48 | gt = np.flip(gt, axis=2) 49 | img_up = np.flip(img_up, axis=2) 50 | if random.random() < 0.5: 51 | img = np.rot90(img, k=1, axes=(1, 2)) 52 | gt = np.rot90(gt, k=1, axes=(1, 2)) 53 | img_up = np.rot90(img_up, k=1, axes=(1, 2)) 54 | return img, gt, img_up 55 | 56 | def read_dataset(self, data_root): 57 | volume_list = [] 58 | for i in range(2000): 59 | if os.path.exists(os.path.join(data_root, str(i) + '.tif')): 60 | volume_list.append(os.path.join(data_root, str(i) + '.tif')) 61 | return volume_list 62 | -------------------------------------------------------------------------------- /3D-SR-Unet/main.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import argparse 3 | import os 4 | import argparse 5 | import torch 6 | import numpy as np 7 | from torch import Generator, randperm 8 | import random 9 | from torch.utils.data import DataLoader, Subset 10 | from train import train_cnn 11 | from model import SRUNet, CubicWeightedPSNRLoss 12 | from data import KidneySRUData 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | def train_distributed(args): 18 | model = SRUNet(up_scale=6).cuda() 19 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 20 | criterion = CubicWeightedPSNRLoss().cuda() 21 | dataset_train = KidneySRUData(data_root='/data/cxlu/srunet_liver_training_large/srunet_training') 22 | data_len = len(dataset_train) 23 | valid_len = int(data_len * 0.1) 24 | data_len -= valid_len 25 | dataset_train, dataset_val = subset_split(dataset_train, lengths=[data_len, valid_len], 26 | generator=Generator().manual_seed(args.seed)) 27 | train_loader = DataLoader(dataset=dataset_train, num_workers=args.num_worker, batch_size=args.b, pin_memory=True, 28 | shuffle=True) 29 | val_loader = DataLoader(dataset=dataset_val, num_workers=args.num_worker, batch_size=args.b, pin_memory=True, 30 | ) 31 | train_cnn(train_generator=train_loader, valid_generator=val_loader, args=args, optimizer=optimizer, model=model, 32 | criterion=criterion) 33 | 34 | def subset_split(dataset, lengths, generator): 35 | """ 36 | """ 37 | indices = randperm(sum(lengths), 
generator=generator).tolist() 38 | Subsets = [] 39 | for offset, length in zip(np.add.accumulate(lengths), lengths): 40 | if length == 0: 41 | Subsets.append(None) 42 | else: 43 | Subsets.append(Subset(dataset, indices[offset - length: offset])) 44 | return Subsets 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser(description='Self Training benchmark') 49 | parser.add_argument('--b', default=16, type=int, help='batch size') 50 | parser.add_argument('--epoch', default=100, type=int, help='epochs to train') 51 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 52 | parser.add_argument('--output', default='./model_genesis_pretrain', type=str, help='output path') 53 | parser.add_argument('--gpus', default='0,1,2,3', type=str, help='gpu indexs') 54 | parser.add_argument('--seed', default=42, type=int) 55 | parser.add_argument('--num_worker', type=int, default=8) 56 | args = parser.parse_args() 57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 58 | seed = args.seed 59 | torch.manual_seed(seed) 60 | torch.cuda.manual_seed_all(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | torch.backends.cudnn.deterministic = True 64 | torch.backends.cudnn.benchmark = False 65 | if not os.path.exists(args.output): 66 | os.makedirs(args.output) 67 | print(args) 68 | train_distributed(args) 69 | -------------------------------------------------------------------------------- /3D-SR-Unet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.ndimage import zoom 5 | 6 | 7 | class CubicWeightedPSNRLoss(nn.Module): 8 | def __init__(self): 9 | super(CubicWeightedPSNRLoss, self).__init__() 10 | 11 | def forward(self, upsampled_input, pred, target): 12 | # Perform cubic upsampling on the input 13 | error = (upsampled_input - target) ** 2 14 | weight = error / (error.max() * 2) + 0.5 15 | # Compute the pixel-wise cubic-weighted MSE loss 16 | weighted_mse = ((pred - target) ** 2 * weight).mean() 17 | # Compute the cubic-weighted PSNR loss 18 | return weighted_mse 19 | 20 | 21 | def conv3x3(in_channels, out_channels, stride=1, 22 | padding=1, bias=True, groups=1): 23 | return nn.Conv2d( 24 | in_channels, 25 | out_channels, 26 | kernel_size=3, 27 | stride=stride, 28 | padding=padding, 29 | bias=bias, 30 | groups=groups) 31 | 32 | 33 | def conv3x3x3(in_channels, out_channels, stride=1, 34 | padding=1, bias=True, groups=1): 35 | return nn.Conv3d( 36 | in_channels, 37 | out_channels, 38 | kernel_size=3, 39 | stride=stride, 40 | padding=padding, 41 | bias=bias, 42 | groups=groups) 43 | 44 | 45 | class SRUNet(nn.Module): 46 | def __init__(self, up_scale=6): 47 | super().__init__() 48 | self.up_scale = up_scale 49 | self.conv1_1 = conv3x3x3(1, 32) 50 | self.conv1_2 = conv3x3x3(32, 32) 51 | self.conv1_3 = conv3x3x3(32, 32) 52 | self.fracconv1 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, 53 | stride=(self.up_scale, 1, 1), padding=1) 54 | self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 55 | self.conv2_1 = conv3x3x3(32, 64) 56 | self.conv2_2 = conv3x3x3(64, 64) 57 | self.conv2_3 = conv3x3x3(64, 64) 58 | self.fracconv2 = nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=3, stride=(2, 1, 1), padding=1) 59 | self.conv3_1 = conv3x3x3(64, 128) 60 | self.conv3_2 = conv3x3x3(128, 128) 61 | self.conv3_3 = conv3x3x3(128, 128) 62 | self.fracconv3 = nn.ConvTranspose3d(in_channels=128, 
out_channels=64, kernel_size=3, stride=(2, 2, 2), 63 | padding=1) 64 | self.conv2_4 = conv3x3x3(128, 64) 65 | self.conv2_5 = conv3x3x3(64, 64) 66 | self.conv2_6 = conv3x3x3(64, 64) 67 | self.fracconv4 = nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=3, 68 | stride=(self.up_scale // 2, 2, 2), padding=1) 69 | self.conv1_4 = conv3x3x3(64, 32) 70 | self.conv1_5 = conv3x3x3(32, 32) 71 | self.conv1_6 = conv3x3x3(32, 32) 72 | self.final_conv = conv3x3x3(32, 1) 73 | 74 | def forward(self, x): 75 | x_1_1 = F.relu(self.conv1_1(x)) 76 | x_1_2 = F.relu(self.conv1_2(x_1_1)) 77 | x_1_3 = F.relu(self.conv1_3(x_1_2)) 78 | b, c, d, h, w = x_1_3.shape 79 | x_frac1 = self.fracconv1(x_1_3, output_size=(b, c, d * self.up_scale, h, w)) 80 | # print(x_frac1.shape) 81 | x_2_1 = self.pool(x_1_3) 82 | x_2_2 = F.relu(self.conv2_1(x_2_1)) 83 | 84 | x_2_3 = F.relu(self.conv2_2(x_2_2)) 85 | x_2_4 = F.relu(self.conv2_3(x_2_3)) 86 | b, c, d, h, w = x_2_4.shape 87 | x_frac2 = self.fracconv2(x_2_4, output_size=(b, c, d * 2, h, w)) 88 | # print(x_frac2.shape) 89 | x_3_1 = self.pool(x_2_4) 90 | x_3_2 = F.relu(self.conv3_1(x_3_1)) 91 | x_3_3 = F.relu(self.conv3_2(x_3_2)) 92 | x_3_4 = F.relu(self.conv3_3(x_3_3)) 93 | b, c, d, h, w = x_3_4.shape 94 | x_frac3 = self.fracconv3(x_3_4, output_size=(b, c, d * 2, h * 2, w * 2)) 95 | # print(x_frac3.shape) 96 | x_merge_2 = torch.concatenate([x_frac3, x_frac2], dim=1) 97 | x_2_5 = F.relu(self.conv2_4(x_merge_2)) 98 | x_2_6 = F.relu(self.conv2_5(x_2_5)) 99 | x_2_7 = F.relu(self.conv2_6(x_2_6)) 100 | b, c, d, h, w = x_2_7.shape 101 | x_frac4 = self.fracconv4(x_2_7, output_size=(b, c, d * self.up_scale // 2, h * 2, w * 2)) 102 | # print(x_frac4.shape) 103 | x_merge_1 = torch.concatenate([x_frac1, x_frac4], dim=1) 104 | x_1_4 = F.relu(self.conv1_4(x_merge_1)) 105 | x_1_5 = F.relu(self.conv1_5(x_1_4)) 106 | x_1_6 = F.relu(self.conv1_6(x_1_5)) 107 | out = self.final_conv(x_1_6) 108 | return out 109 | 110 | 111 | if __name__ == '__main__': 112 | model = SRUNet(up_scale=6) 113 | 114 | # smoke test: SRUNet upsamples the axial (depth) dimension by up_scale and keeps H, W unchanged 115 | test_input = torch.rand((1, 1, 16, 128, 128)) 116 | test_gt = torch.rand((1, 1, 96, 128, 128)) 117 | test_upsampled = torch.rand((1, 1, 96, 128, 128)) # stands in for the cubic-upsampled low-res volume 118 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 119 | test_out = model(test_input) # -> (1, 1, 96, 128, 128) 120 | loss_function = CubicWeightedPSNRLoss() 121 | loss = loss_function(test_upsampled, test_out, test_gt) 122 | optimizer.zero_grad() 123 | loss.backward() 124 |
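The weighting inside CubicWeightedPSNRLoss is easy to miss: the squared error of the cubic-upsampled input against the target is rescaled into [0.5, 1.0], so voxels that plain interpolation already reconstructs well receive half the weight of the voxels it gets most wrong. A minimal numeric illustration (not part of the repository; it assumes model.py is on the import path):

```python
import torch
from model import CubicWeightedPSNRLoss

criterion = CubicWeightedPSNRLoss()
target = torch.zeros(1, 1, 2, 2, 2)
# half the voxels are reconstructed exactly by upsampling, half are maximally wrong
upsampled = torch.tensor([0.0, 1.0]).repeat(4).reshape(1, 1, 2, 2, 2)
pred = torch.full((1, 1, 2, 2, 2), 0.1)
# weight = err / (2 * err.max()) + 0.5, i.e. 0.5 on easy voxels and 1.0 on hard ones,
# so the loss is mean(0.01 * weight) = (4 * 0.005 + 4 * 0.01) / 8 = 0.0075
print(criterion(upsampled, pred, target))  # tensor(0.0075)
```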
-------------------------------------------------------------------------------- /3D-SR-Unet/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sys 5 | import os 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | 11 | def train_cnn(optimizer, model, train_generator, valid_generator, criterion, args): 12 | n_iteration_per_epoch = len(train_generator) 13 | 14 | tb_logger = SummaryWriter(log_dir=args.output) 15 | print(n_iteration_per_epoch) 16 | step_size = 100 # grow the LR multiplier every 100 scheduler steps 17 | # LambdaLR multiplies the optimizer's base lr (args.lr) by the returned factor, 18 | lambda_lr = lambda step: ((step // step_size + 1) ** 0.5) if step > 0 else 1.0 19 | scheduler = LambdaLR(optimizer, lambda_lr) 20 | train_losses = [] 21 | valid_losses_l1 = [] 22 | avg_train_losses = [] 23 | best_loss = 100000 24 | num_epoch_no_improvement = 0 25 | for epoch in range(args.epoch + 1): 26 | model.train() 27 | 28 | iteration = 0 29 | total_step = 0 30 | for idx, (image, gt, img_upsampled) in enumerate(train_generator): 31 | # scheduler.step() 32 | total_step += args.b 33 | img = image.cuda(non_blocking=True).float() 34 | gt = gt.cuda(non_blocking=True).float() 35 | img_upsampled = img_upsampled.cuda(non_blocking=True).float() 36 | pred = model(img) 37 | loss = criterion(img_upsampled, pred, gt) 38 | iteration += 1 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | scheduler.step(epoch * n_iteration_per_epoch + iteration) 43 | train_losses.append(round(loss.item(), 2)) 44 | if (iteration + 1) % 20 == 0: 45 | print('Epoch [{}/{}], iteration {}, loss: {:.6f}, avg loss: {:.6f}, learning rate: {:.6f}' 46 | .format(epoch + 1, args.epoch, iteration + 1, loss.sum().item(), np.average(train_losses), 47 | optimizer.state_dict()['param_groups'][0]['lr'])) 48 | sys.stdout.flush() 49 | 50 | with torch.no_grad(): 51 | model.eval() 52 | print("validating....") 53 | for i, (image, gt, img_upsampled) in enumerate(valid_generator): 54 | image_scale = image.cuda(non_blocking=True).float() 55 | gt_scale = gt.cuda(non_blocking=True).float() 56 | pred = model(image_scale) 57 | loss = criterion(img_upsampled.cuda(non_blocking=True).float(), pred, gt_scale) 58 | valid_losses_l1.append(loss.sum().item()) 59 | # logging 60 | train_loss = np.average(train_losses) 61 | valid_loss_l1 = np.average(valid_losses_l1) 62 | valid_loss = valid_loss_l1 63 | tb_logger.add_scalar('valid loss', valid_loss_l1, epoch) 64 | avg_train_losses.append(train_loss) 65 | print("Epoch {}, validation loss is {:.4f}, training loss is {:.4f}".format(epoch + 1, valid_loss, 66 | train_loss)) 67 | train_losses = [] 68 | valid_losses_l1 = [] 69 | 70 | if valid_loss < best_loss: 71 | print("Validation loss decreases from {:.4f} to {:.4f}".format(best_loss, valid_loss)) 72 | best_loss = valid_loss 73 | num_epoch_no_improvement = 0 74 | # save model 75 | # save all the weight for 3d unet 76 | torch.save({ 77 | 'args': args, 78 | 'epoch': epoch + 1, 79 | 'state_dict': model.state_dict(), 80 | 'optimizer_state_dict': optimizer.state_dict() 81 | }, os.path.join(args.output, 82 | 'best' + '.pt')) 83 | print("Saving model ", 84 | os.path.join(args.output, 85 | 'best' + '.pt')) 86 | else: 87 | if epoch % 10 == 0: 88 | torch.save({ 89 | 'args': args, 90 | 'epoch': epoch + 1, 91 | 'state_dict': model.state_dict(), 92 | 'optimizer_state_dict': optimizer.state_dict() 93 | }, os.path.join(args.output, 94 | 'epoch_' + str(epoch) + '.pt')) 95 | print("Saving model ", 96 | os.path.join(args.output, 97 | 'epoch_' + str(epoch) + '.pt')) 98 | print("Validation loss does not decrease from {:.4f}, num_epoch_no_improvement {}".format(best_loss, 99 | num_epoch_no_improvement)) 100 | num_epoch_no_improvement += 1 101 | if num_epoch_no_improvement > 10: 102 | break 103 | sys.stdout.flush() 104 | tb_logger.close() 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Luchixiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /RAFT/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | pytorch_env 6 | build 7 | correlation.egg-info 8 | -------------------------------------------------------------------------------- /RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/RAFT.png -------------------------------------------------------------------------------- /RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | <img src="RAFT.png"> 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and CUDA 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder. 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using TensorBoard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficient Implementation 76 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
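For example, the demo above can be rerun with the memory-efficient kernel once the extension is compiled (a minimal invocation, assuming the pretrained model from the Demos section):
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames --alternate_corr
```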
81 | -------------------------------------------------------------------------------- /RAFT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/__init__.py -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/core/__init__.py -------------------------------------------------------------------------------- /RAFT/core/align_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tifffile import imread, imwrite 3 | import shutil 4 | import cv2 5 | import imutils 6 | import numpy as np 7 | import image_registration 8 | from scipy.ndimage import shift 9 | 10 | 11 | def mkdir(path): 12 | if os.path.exists(path): 13 | shutil.rmtree(path) 14 | os.mkdir(path) 15 | 16 | 17 | def delete_outlier(points1, points2, move=0, outlier_percent=0.3): 18 | """ 19 | Delete the outliers/mismatches based on the angle and distance.
20 | Args: 21 | points1: key points detected in frame1 22 | points2: key points detected in frame2 23 | move: move for a small distance to avoid points appear at the same location 24 | outlier_percent: how many outliers are removed 25 | 26 | Returns: indexes of selected key points 27 | """ 28 | # angle 29 | points1_mv = points1.copy() 30 | points1_mv[:, 0] = points1[:, 0] - move 31 | vecs = points2 - points1_mv 32 | norms = np.linalg.norm(vecs, axis=1, keepdims=True) 33 | vec_norms = vecs / (norms + 1e-6) 34 | vec_means = np.mean(vec_norms, axis=0).reshape((2, 1)) 35 | cross_angles = vec_norms.dot(vec_means)[:, 0] 36 | index = np.argsort(-cross_angles) 37 | num_select = int(len(index) * (1 - outlier_percent)) 38 | index_selected = index[0:num_select] 39 | 40 | # distance 41 | index1 = np.argsort(norms[:, 0]) 42 | # print(index1) 43 | index1_selected = index1[0:num_select] 44 | 45 | index_selected = list(set(index1_selected) & set(index_selected)) 46 | 47 | return index_selected, np.mean(norms[index_selected]) 48 | 49 | 50 | def align_images(imageGray, templateGray, maxFeatures=500, keepPercent=0.2, 51 | debug=False, outlier=True, sup_img=None): 52 | # convert both the input image and template to grayscale 53 | 54 | orb = cv2.ORB_create(maxFeatures) 55 | (kpsA, descsA) = orb.detectAndCompute(imageGray, None) 56 | (kpsB, descsB) = orb.detectAndCompute(templateGray, None) 57 | # match the features 58 | method = cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING 59 | matcher = cv2.DescriptorMatcher_create(method) 60 | matches = matcher.match(descsA, descsB, None) 61 | matches = sorted(matches, key=lambda x: x.distance) 62 | # keep only the top matches 63 | keep = int(len(matches) * keepPercent) 64 | matches = matches[:keep] 65 | # check to see if we should visualize the matched keypoints 66 | ptsA = np.zeros((len(matches), 2), dtype="float") 67 | ptsB = np.zeros((len(matches), 2), dtype="float") 68 | # loop over the top matches 69 | for (i, m) in enumerate(matches): 70 | # indicate that the two keypoints in the respective images 71 | # map to each other 72 | ptsA[i] = kpsA[m.queryIdx].pt 73 | ptsB[i] = kpsB[m.trainIdx].pt 74 | if outlier: 75 | index, distance = delete_outlier(ptsA, ptsB) 76 | 77 | matches = list(np.array(matches)[index]) 78 | ptsA = ptsA[index, :] 79 | ptsB = ptsB[index, :] 80 | if len(matches) < 10: 81 | return None 82 | (H, mask) = cv2.findHomography(ptsA, ptsB, method=cv2.RANSAC) 83 | return H 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in 
range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = 
SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 67 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /RAFT/core/raftConfig.py: -------------------------------------------------------------------------------- 1 | class RaftConfig: 2 | def __init__(self, path, patch_size=256, 
border=32, tissue='Brain', overlap=0.125): 3 | self.path = path 4 | self.patch_size = patch_size 5 | self.border = border 6 | self.tissue = tissue 7 | self.small = False 8 | self.model = 'RAFT/models/raft-things.pth' 9 | self.overlap = overlap 10 | self.mixed_precision = False 11 | self.alternate_corr = False 12 | self.occlusion = False 13 | 14 | def __getattr__(self, item): 15 | # This method is called when an attribute access is attempted. 16 | try: 17 | return self.__dict__[item] 18 | except KeyError: 19 | return None 20 | 21 | def __setattr__(self, key, value): 22 | # This method allows setting attributes directly. 23 | self.__dict__[key] = value 24 | 25 | def __contains__(self, item): 26 | # This enables the use of 'in' to check for attribute existence. 27 | return item in self.__dict__ 28 | 29 | -------------------------------------------------------------------------------- /RAFT/core/super_res_register.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import glob 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from raft import RAFT 14 | from utils import flow_viz 15 | from utils.utils import InputPadder 16 | from align_functions import * 17 | import torch.nn.functional as F 18 | import os 19 | from tifffile import imwrite 20 | 21 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 22 | DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' 23 | 24 | 25 | def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA): 26 | # initialize the dimensions of the image to be resized and 27 | # grab the image size 28 | dim = None 29 | (h, w) = image.shape[:2] 30 | 31 | # if both the width and height are None, then return the 32 | # original image 33 | if width is None and height is None: 34 | return image 35 | 36 | # check to see if the width is None 37 | if width is None: 38 | # calculate the ratio of the height and construct the 39 | # dimensions 40 | r = height / float(h) 41 | dim = (int(w * r), height) 42 | 43 | # otherwise, the height is None 44 | else: 45 | # calculate the ratio of the width and construct the 46 | # dimensions 47 | r = width / float(w) 48 | dim = (width, int(h * r)) 49 | 50 | # resize the image 51 | resized = cv2.resize(image, dim) 52 | 53 | # return the resized image 54 | return resized 55 | 56 | 57 | ###################### 58 | ## Image form trans### 59 | ###################### 60 | def img2tensor(img): 61 | img_t = np.expand_dims(img.transpose(2, 0, 1), axis=0) 62 | img_t = torch.from_numpy(img_t.astype(np.float32)) 63 | 64 | return img_t 65 | 66 | 67 | def tensor2img(img_t): 68 | img = img_t[0].detach().to("cpu").numpy() 69 | img = np.transpose(img, (1, 2, 0)) 70 | 71 | return img 72 | 73 | 74 | ###################### 75 | # occlusion detection# 76 | ###################### 77 | 78 | def warp(x, flo): 79 | """ 80 | warp an image/tensor (im2) back to im1, according to the optical flow 81 | x: [B, C, H, W] (im2) 82 | flo: [B, 2, H, W] flow 83 | """ 84 | B, C, H, W = x.size() 85 | # mesh grid 86 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 87 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 88 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 89 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 90 | grid = torch.cat((xx, yy), 1).float() 91 | 92 | if x.is_cuda: 93 | grid = grid.cuda() 94 | vgrid = grid + flo 95 | # scale grid to [-1,1] 96 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / 
max(W - 1, 1) - 1.0 97 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 98 | 99 | vgrid = vgrid.permute(0, 2, 3, 1) 100 | output = F.grid_sample(x, vgrid, align_corners=False) 101 | return output 102 | 103 | 104 | 105 | ########################### 106 | ## raft functions 107 | ########################### 108 | def load_image(img): 109 | img = np.stack([img, img, img], axis=2) 110 | 111 | img = torch.from_numpy(img).permute(2, 0, 1).float() 112 | return img[None].to(DEVICE) 113 | 114 | def process_pair(wf_img, gt_img, save_wf_path, save_gt_path, sup_wf_img=None, patch_size=256, stride=224, model=None, 115 | border=32): 116 | wf_img_origin = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY) 117 | gt_img_origin = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY) 118 | wf_img = wf_img_origin[gt_img_origin.shape[0] // 2 - gt_img_origin.shape[0] // 4: gt_img_origin.shape[0] // 2 + 119 | gt_img_origin.shape[0] // 4, 120 | gt_img_origin.shape[1] // 2 - gt_img_origin.shape[1] // 4: gt_img_origin.shape[1] // 2 + 121 | gt_img_origin.shape[1] // 4] 122 | gt_img = cv2.resize(gt_img_origin, (gt_img_origin.shape[1] // 2, gt_img_origin.shape[0] // 2)) 123 | x_offset, y_offset, _, _ = image_registration.chi2_shift(wf_img, gt_img, 0.1, return_error=True) 124 | wf_img = shift(wf_img_origin, (y_offset, x_offset))[ 125 | gt_img_origin.shape[0] // 2 - gt_img_origin.shape[0] // 4 - 8: gt_img_origin.shape[0] // 2 + 126 | gt_img_origin.shape[0] // 4 + 8, 127 | gt_img_origin.shape[1] // 2 - gt_img_origin.shape[1] // 4 - 8: gt_img_origin.shape[1] // 2 + 128 | gt_img_origin.shape[1] // 4 + 8] 129 | 130 | H = align_images(wf_img, gt_img, debug=False) 131 | h, w = gt_img.shape 132 | aligned = cv2.warpPerspective(wf_img, H, (w, h)) 133 | x = border 134 | x_end = wf_img.shape[0] - border 135 | y_end = wf_img.shape[0] - border 136 | count = 1 137 | while x + patch_size < x_end: 138 | y = border 139 | while y + patch_size < y_end: 140 | crop_wf_img = aligned[x - border: x + patch_size + border, y - border: y + patch_size + border] 141 | crop_gt_img = gt_img[x - border: x + patch_size + border, y - border: y + patch_size + border] 142 | H_sub = align_images(crop_wf_img, crop_gt_img) 143 | if H_sub is None: 144 | count += 1 145 | y += stride 146 | continue 147 | else: 148 | (h_sub, w_sub) = crop_gt_img.shape[:2] 149 | crop_wf_img = cv2.warpPerspective(crop_wf_img, H_sub, (w_sub, h_sub)) 150 | if np.sum(crop_wf_img[border:-border, border:-border] == 0) > 10: 151 | count += 1 152 | y += stride 153 | continue 154 | image1 = load_image(crop_gt_img) 155 | image2 = load_image(crop_wf_img) 156 | 157 | padder = InputPadder(image1.shape) 158 | image1, image2 = padder.pad(image1, image2) 159 | 160 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 161 | image_warped = warp(image2 / 255.0, flow_up) 162 | crop_wf_img = image_warped[0].permute(1, 2, 0).cpu().numpy() 163 | crop_wf_img = np.uint8(crop_wf_img[:, :, 0] * 255) 164 | if np.sum(crop_wf_img[border:-border, border:-border] == 0) > 10: 165 | count += 1 166 | y += stride 167 | continue 168 | imwrite(os.path.join(save_wf_path, str(count) + '.tif'), crop_wf_img[border:-border, border:-border]) 169 | imwrite(os.path.join(save_gt_path, str(count) + '.tif'), 170 | gt_img_origin[2 * x: 2 * x + 2 * patch_size, 2 * y: 2 * y + 2 * patch_size]) 171 | count += 1 172 | y += stride 173 | x += stride 174 | 175 | 176 | def registration(args): 177 | model = torch.nn.DataParallel(RAFT(args)) 178 | model.load_state_dict(torch.load(args.model, map_location='cpu')) 179 | 
180 | model = model.module 181 | model.to(DEVICE) 182 | model.eval() 183 | 184 | with torch.no_grad(): 185 | task = 'zoom' 186 | path = args.path 187 | target_path = os.path.join(path, task) 188 | mkdir(target_path) 189 | image_types = ['Brain__2w_01.tif', 'Brain__2w_02.tif', 'Brain__2w_03.tif'] 190 | train_wf_path = os.path.join(target_path, 'train_wf') 191 | train_gt_path = os.path.join(target_path, 'train_gt') 192 | mkdir(train_wf_path) 193 | mkdir(train_gt_path) 194 | for i in range(100): 195 | if not os.path.exists(os.path.join(path, str(i), 'Brain__4w_09.tif')): 196 | continue 197 | roi_wf_path = os.path.join(train_wf_path, str(i)) 198 | roi_gt_path = os.path.join(train_gt_path, str(i)) 199 | 200 | mkdir(roi_wf_path) 201 | mkdir(roi_gt_path) 202 | for type in image_types: 203 | print(f'processing image {i}, {type}') 204 | save_wf_path = os.path.join(roi_wf_path, type[:-4]) 205 | save_gt_path = os.path.join(roi_gt_path, type[:-4]) 206 | mkdir(save_wf_path) 207 | mkdir(save_gt_path) 208 | gt_file_img = cv2.imread(os.path.join(path, str(i), 'Brain__4w_09.tif')) 209 | wf_file_img = cv2.imread(os.path.join(path, str(i), type)) 210 | sup_wf_img = None 211 | # print(wf_file_img.min()) 212 | process_pair(wf_file_img, gt_file_img, save_wf_path, save_gt_path, sup_wf_img=sup_wf_img, model=model, 213 | patch_size=args.patch_size, border=args.border, stride=int(args.patch_size * (1-args.overlap))) 214 | 215 | 216 | if __name__ == '__main__': 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument('--model', default="../models/raft-things.pth") 219 | parser.add_argument('--path', help="dataset for evaluation") 220 | parser.add_argument('--category', help="save warped images") 221 | parser.add_argument('--small', action='store_true', help='use small model') 222 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 223 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 224 | parser.add_argument('--occlusion', action='store_true', help='predict occlusion masks') 225 | parser.add_argument('--patch_size', default=128, type=int) 226 | parser.add_argument('--border', default=32, type=int) 227 | parser.add_argument('--overlap', default=0.125, type=float) 228 | 229 | args = parser.parse_args() 230 | 231 | registration(args) 232 | -------------------------------------------------------------------------------- /RAFT/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = 
torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = 
SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balance gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT/core/utils/utils.py: 
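
writeFlow and readFlow above are inverses over the Middlebury .flo format (a float32 little-endian stream guarded by the 202021.25 magic tag). A quick round-trip check, assuming both helpers are imported and using a hypothetical temp path:

import numpy as np

uv = np.random.rand(48, 64, 2).astype(np.float32)  # [H, W, 2] flow field
writeFlow('/tmp/roundtrip.flo', uv)
back = readFlow('/tmp/roundtrip.flo')               # returns [H, W, 2]
assert back.shape == (48, 64, 2)
assert np.allclose(back, uv)                        # float32 values survive exactly
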
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0016.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0017.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0018.png: -------------------------------------------------------------------------------- 
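
InputPadder in utils.py above is the standard entry/exit adapter around the network: it pads both frames so height and width are divisible by 8 (RAFT predicts flow at 1/8 resolution), then crops the prediction back. A small usage sketch with Sintel-sized frames:

import torch

img1 = torch.randn(1, 3, 436, 1024)       # 436 % 8 == 4, so padding is needed
img2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(img1.shape)          # mode='sintel': symmetric height padding
img1_p, img2_p = padder.pad(img1, img2)
assert img1_p.shape[-2:] == (440, 1024)   # both dims now divisible by 8

flow = torch.zeros(1, 2, 440, 1024)       # stand-in for the network output
assert padder.unpad(flow).shape[-2:] == (436, 1024)
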
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0018.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0019.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0020.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0021.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0022.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0023.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0024.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0025.png -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | model = 
torch.nn.DataParallel(RAFT(args)) 44 | model.load_state_dict(torch.load(args.model)) 45 | 46 | model = model.module 47 | model.to(DEVICE) 48 | model.eval() 49 | 50 | with torch.no_grad(): 51 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 52 | glob.glob(os.path.join(args.path, '*.jpg')) 53 | 54 | images = sorted(images) 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | image1 = load_image(imfile1) 57 | image2 = load_image(imfile2) 58 | 59 | padder = InputPadder(image1.shape) 60 | image1, image2 = padder.pad(image1, image2) 61 | 62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 63 | viz(image1, flow_up) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model', help="restore checkpoint") 69 | parser.add_argument('--path', help="dataset for evaluation") 70 | parser.add_argument('--small', action='store_true', help='use small model') 71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 73 | args = parser.parse_args() 74 | 75 | demo(args) 76 | -------------------------------------------------------------------------------- /RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /RAFT/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | from raft import RAFT 18 | from utils.utils import InputPadder, forward_interpolate 19 | 20 | 21 | @torch.no_grad() 22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): 23 | """ Create submission for the Sintel leaderboard """ 24 | model.eval() 25 | for dstype in ['clean', 'final']: 26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 27 | 28 | flow_prev, sequence_prev = None, None 29 | for test_id in range(len(test_dataset)): 30 | image1, image2, (sequence, frame) = test_dataset[test_id] 31 | if sequence != sequence_prev: 32 | flow_prev = None 33 | 34 | padder = InputPadder(image1.shape) 35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 36 | 37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 39 | 40 | if warm_start: 41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 42 | 43 | output_dir = os.path.join(output_path, dstype, sequence) 44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 45 | 46 | if not os.path.exists(output_dir): 47 | os.makedirs(output_dir) 48 | 49 | frame_utils.writeFlow(output_file, flow) 50 | sequence_prev = sequence 51 | 52 | 53 | @torch.no_grad() 54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 55 | """ Create submission for the Sintel leaderboard """ 56 | model.eval() 57 | test_dataset = 
datasets.KITTI(split='testing', aug_params=None) 58 | 59 | if not os.path.exists(output_path): 60 | os.makedirs(output_path) 61 | 62 | for test_id in range(len(test_dataset)): 63 | image1, image2, (frame_id, ) = test_dataset[test_id] 64 | padder = InputPadder(image1.shape, mode='kitti') 65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 66 | 67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 69 | 70 | output_filename = os.path.join(output_path, frame_id) 71 | frame_utils.writeFlowKITTI(output_filename, flow) 72 | 73 | 74 | @torch.no_grad() 75 | def validate_chairs(model, iters=24): 76 | """ Perform evaluation on the FlyingChairs (test) split """ 77 | model.eval() 78 | epe_list = [] 79 | 80 | val_dataset = datasets.FlyingChairs(split='validation') 81 | for val_id in range(len(val_dataset)): 82 | image1, image2, flow_gt, _ = val_dataset[val_id] 83 | image1 = image1[None].cuda() 84 | image2 = image2[None].cuda() 85 | 86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 88 | epe_list.append(epe.view(-1).numpy()) 89 | 90 | epe = np.mean(np.concatenate(epe_list)) 91 | print("Validation Chairs EPE: %f" % epe) 92 | return {'chairs': epe} 93 | 94 | 95 | @torch.no_grad() 96 | def validate_sintel(model, iters=32): 97 | """ Peform validation using the Sintel (train) split """ 98 | model.eval() 99 | results = {} 100 | for dstype in ['clean', 'final']: 101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype) 102 | epe_list = [] 103 | 104 | for val_id in range(len(val_dataset)): 105 | image1, image2, flow_gt, _ = val_dataset[val_id] 106 | image1 = image1[None].cuda() 107 | image2 = image2[None].cuda() 108 | 109 | padder = InputPadder(image1.shape) 110 | image1, image2 = padder.pad(image1, image2) 111 | 112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 113 | flow = padder.unpad(flow_pr[0]).cpu() 114 | 115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 116 | epe_list.append(epe.view(-1).numpy()) 117 | 118 | epe_all = np.concatenate(epe_list) 119 | epe = np.mean(epe_all) 120 | px1 = np.mean(epe_all<1) 121 | px3 = np.mean(epe_all<3) 122 | px5 = np.mean(epe_all<5) 123 | 124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 125 | results[dstype] = np.mean(epe_list) 126 | 127 | return results 128 | 129 | 130 | @torch.no_grad() 131 | def validate_kitti(model, iters=24): 132 | """ Peform validation using the KITTI-2015 (train) split """ 133 | model.eval() 134 | val_dataset = datasets.KITTI(split='training') 135 | 136 | out_list, epe_list = [], [] 137 | for val_id in range(len(val_dataset)): 138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 139 | image1 = image1[None].cuda() 140 | image2 = image2[None].cuda() 141 | 142 | padder = InputPadder(image1.shape, mode='kitti') 143 | image1, image2 = padder.pad(image1, image2) 144 | 145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 146 | flow = padder.unpad(flow_pr[0]).cpu() 147 | 148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 149 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 150 | 151 | epe = epe.view(-1) 152 | mag = mag.view(-1) 153 | val = valid_gt.view(-1) >= 0.5 154 | 155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 156 | epe_list.append(epe[val].mean().item()) 157 | out_list.append(out[val].cpu().numpy()) 158 | 159 | epe_list = 
np.array(epe_list) 160 | out_list = np.concatenate(out_list) 161 | 162 | epe = np.mean(epe_list) 163 | f1 = 100 * np.mean(out_list) 164 | 165 | print("Validation KITTI: %f, %f" % (epe, f1)) 166 | return {'kitti-epe': epe, 'kitti-f1': f1} 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--model', help="restore checkpoint") 172 | parser.add_argument('--dataset', help="dataset for evaluation") 173 | parser.add_argument('--small', action='store_true', help='use small model') 174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 175 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 176 | args = parser.parse_args() 177 | 178 | model = torch.nn.DataParallel(RAFT(args)) 179 | model.load_state_dict(torch.load(args.model)) 180 | 181 | model.cuda() 182 | model.eval() 183 | 184 | # create_sintel_submission(model.module, warm_start=True) 185 | # create_kitti_submission(model.module) 186 | 187 | with torch.no_grad(): 188 | if args.dataset == 'chairs': 189 | validate_chairs(model.module) 190 | 191 | elif args.dataset == 'sintel': 192 | validate_sintel(model.module) 193 | 194 | elif args.dataset == 'kitti': 195 | validate_kitti(model.module) 196 | 197 | 198 | -------------------------------------------------------------------------------- /RAFT/models/raft-things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/models/raft-things.pth -------------------------------------------------------------------------------- /RAFT/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from torch.utils.data import DataLoader 18 | from raft import RAFT 19 | import evaluate 20 | import datasets 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | try: 25 | from torch.cuda.amp import GradScaler 26 | except: 27 | # dummy GradScaler for PyTorch < 1.6 28 | class GradScaler: 29 | def __init__(self): 30 | pass 31 | def scale(self, loss): 32 | return loss 33 | def unscale_(self, optimizer): 34 | pass 35 | def step(self, optimizer): 36 | optimizer.step() 37 | def update(self): 38 | pass 39 | 40 | 41 | # exclude extremly large displacements 42 | MAX_FLOW = 400 43 | SUM_FREQ = 100 44 | VAL_FREQ = 5000 45 | 46 | 47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 48 | """ Loss function defined over sequence of flow predictions """ 49 | 50 | n_predictions = len(flow_preds) 51 | flow_loss = 0.0 52 | 53 | # exlude invalid pixels and extremely large diplacements 54 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 55 | valid = (valid >= 0.5) & (mag < max_flow) 56 | 57 | for i in range(n_predictions): 58 | i_weight = gamma**(n_predictions - i - 1) 59 | i_loss = (flow_preds[i] - flow_gt).abs() 60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 61 | 62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 63 | epe = epe.view(-1)[valid.view(-1)] 64 | 65 | metrics = { 66 | 'epe': epe.mean().item(), 
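# 'epe' above is the mean end-point error over valid pixels; the Npx entries
# below report the fraction of valid pixels whose error stays under N pixels.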
67 | '1px': (epe < 1).float().mean().item(), 68 | '3px': (epe < 3).float().mean().item(), 69 | '5px': (epe < 5).float().mean().item(), 70 | } 71 | 72 | return flow_loss, metrics 73 | 74 | 75 | def count_parameters(model): 76 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 77 | 78 | 79 | def fetch_optimizer(args, model): 80 | """ Create the optimizer and learning rate scheduler """ 81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 82 | 83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 85 | 86 | return optimizer, scheduler 87 | 88 | 89 | class Logger: 90 | def __init__(self, model, scheduler): 91 | self.model = model 92 | self.scheduler = scheduler 93 | self.total_steps = 0 94 | self.running_loss = {} 95 | self.writer = None 96 | 97 | def _print_training_status(self): 98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] 99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) 100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 101 | 102 | # print the training status 103 | print(training_str + metrics_str) 104 | 105 | if self.writer is None: 106 | self.writer = SummaryWriter() 107 | 108 | for k in self.running_loss: 109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) 110 | self.running_loss[k] = 0.0 111 | 112 | def push(self, metrics): 113 | self.total_steps += 1 114 | 115 | for key in metrics: 116 | if key not in self.running_loss: 117 | self.running_loss[key] = 0.0 118 | 119 | self.running_loss[key] += metrics[key] 120 | 121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 122 | self._print_training_status() 123 | self.running_loss = {} 124 | 125 | def write_dict(self, results): 126 | if self.writer is None: 127 | self.writer = SummaryWriter() 128 | 129 | for key in results: 130 | self.writer.add_scalar(key, results[key], self.total_steps) 131 | 132 | def close(self): 133 | self.writer.close() 134 | 135 | 136 | def train(args): 137 | 138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus) 139 | print("Parameter Count: %d" % count_parameters(model)) 140 | 141 | if args.restore_ckpt is not None: 142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False) 143 | 144 | model.cuda() 145 | model.train() 146 | 147 | if args.stage != 'chairs': 148 | model.module.freeze_bn() 149 | 150 | train_loader = datasets.fetch_dataloader(args) 151 | optimizer, scheduler = fetch_optimizer(args, model) 152 | 153 | total_steps = 0 154 | scaler = GradScaler(enabled=args.mixed_precision) 155 | logger = Logger(model, scheduler) 156 | 157 | VAL_FREQ = 5000 158 | add_noise = True 159 | 160 | should_keep_training = True 161 | while should_keep_training: 162 | 163 | for i_batch, data_blob in enumerate(train_loader): 164 | optimizer.zero_grad() 165 | image1, image2, flow, valid = [x.cuda() for x in data_blob] 166 | 167 | if args.add_noise: 168 | stdv = np.random.uniform(0.0, 5.0) 169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) 170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0) 171 | 172 | flow_predictions = model(image1, image2, iters=args.iters) 173 | 174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) 175 | scaler.scale(loss).backward() 176 | scaler.unscale_(optimizer) 177 | 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 178 | 179 | scaler.step(optimizer) 180 | scheduler.step() 181 | scaler.update() 182 | 183 | logger.push(metrics) 184 | 185 | if total_steps % VAL_FREQ == VAL_FREQ - 1: 186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) 187 | torch.save(model.state_dict(), PATH) 188 | 189 | results = {} 190 | for val_dataset in args.validation: 191 | if val_dataset == 'chairs': 192 | results.update(evaluate.validate_chairs(model.module)) 193 | elif val_dataset == 'sintel': 194 | results.update(evaluate.validate_sintel(model.module)) 195 | elif val_dataset == 'kitti': 196 | results.update(evaluate.validate_kitti(model.module)) 197 | 198 | logger.write_dict(results) 199 | 200 | model.train() 201 | if args.stage != 'chairs': 202 | model.module.freeze_bn() 203 | 204 | total_steps += 1 205 | 206 | if total_steps > args.num_steps: 207 | should_keep_training = False 208 | break 209 | 210 | logger.close() 211 | PATH = 'checkpoints/%s.pth' % args.name 212 | torch.save(model.state_dict(), PATH) 213 | 214 | return PATH 215 | 216 | 217 | if __name__ == '__main__': 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument('--name', default='raft', help="name your experiment") 220 | parser.add_argument('--stage', help="determines which dataset to use for training") 221 | parser.add_argument('--restore_ckpt', help="restore checkpoint") 222 | parser.add_argument('--small', action='store_true', help='use small model') 223 | parser.add_argument('--validation', type=str, nargs='+') 224 | 225 | parser.add_argument('--lr', type=float, default=0.00002) 226 | parser.add_argument('--num_steps', type=int, default=100000) 227 | parser.add_argument('--batch_size', type=int, default=6) 228 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512]) 229 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1]) 230 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 231 | 232 | parser.add_argument('--iters', type=int, default=12) 233 | parser.add_argument('--wdecay', type=float, default=.00005) 234 | parser.add_argument('--epsilon', type=float, default=1e-8) 235 | parser.add_argument('--clip', type=float, default=1.0) 236 | parser.add_argument('--dropout', type=float, default=0.0) 237 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 238 | parser.add_argument('--add_noise', action='store_true') 239 | args = parser.parse_args() 240 | 241 | torch.manual_seed(1234) 242 | np.random.seed(1234) 243 | 244 | if not os.path.isdir('checkpoints'): 245 | os.mkdir('checkpoints') 246 | 247 | train(args) -------------------------------------------------------------------------------- /RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 
--mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 7 | -------------------------------------------------------------------------------- /config/EMDiffuse-n-big.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EMDiffuse-r", // experiments name 3 | "norm": true, 4 | "percent": false, 5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0 6 | "seed" : -1, // random seed, seed <0 represents randomization not used 7 | "finetune_norm": false, // find the parameters to optimize 8 | "task" : "denoise", 9 | "path": { //set every part file path 10 | "base_dir": "experiments", // base path for all log except resume_state 11 | "code": "code", // code backup 12 | "tb_logger": "tb_logger", // path of tensorboard logger 13 | "results": "results", 14 | "checkpoint": "checkpoint", 15 | "resume_state": "experiments/train_EMDiffuse-n-large_240125_221819/5180" // checkpoint path, set to null if used for training 16 | // "resume_state": "experiments/EMDiffuse-n/2720" // checkpoint path, set to null if used for training 17 | // "resume_state": null // checkpoint path, set to null if used for training 18 | }, 19 | 20 | "datasets": { // train or test 21 | "train": { 22 | "which_dataset": { // import designated dataset using arguments 23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 24 | "args":{ // arguments to initialize dataset 25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf", 26 | "data_len": -1, 27 | "norm": true, 28 | "percent": false, 29 | "image_size": [768, 768] 30 | } 31 | }, 32 | "dataloader":{ 33 | "validation_split": 2, // percent or number 34 | "args":{ // arguments to initialize train_dataloader 35 | "batch_size": 3, // batch size in each gpu 36 | "num_workers": 4, 37 | "shuffle": true, 38 | "pin_memory": true, 39 | "drop_last": true 40 | }, 41 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 42 | "batch_size": 1, // batch size in each gpu 43 | "num_workers": 4, 44 | "shuffle": false, 45 | "pin_memory": true, 46 | "drop_last": false 47 | } 48 | } 49 | }, 50 | "test": { 51 | 
"which_dataset": { 52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file 53 | "args":{ 54 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/test_wf", 55 | "norm":true, 56 | "percent": false, 57 | "phase": "val", 58 | "image_size": [768, 768] 59 | } 60 | }, 61 | "dataloader":{ 62 | "args":{ 63 | "batch_size": 8, 64 | "num_workers": 0, 65 | "pin_memory": true 66 | } 67 | } 68 | } 69 | }, 70 | 71 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 72 | "which_model": { // import designated model(trainer) using arguments 73 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 74 | "args": { 75 | "sample_num": 1, // process of each image 76 | "task": "denoise", 77 | "ema_scheduler": { 78 | "ema_start": 1, 79 | "ema_iter": 1, 80 | "ema_decay": 0.9999 81 | }, 82 | "optimizers": [ 83 | { "lr": 5e-5, "weight_decay": 0} 84 | ] 85 | } 86 | }, 87 | "which_networks": [ // import designated list of networks using arguments 88 | { 89 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 90 | "args": { // arguments to initialize network 91 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 92 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 93 | "norm": true, 94 | "unet": { 95 | "in_channel": 2, 96 | "out_channel": 1, 97 | "inner_channel": 32, 98 | "channel_mults": [ 99 | 1, 100 | 2, 101 | 4, 102 | 8 103 | ], 104 | "attn_res": [ 105 | // 32, 106 | 16 107 | // 8 108 | ], 109 | "num_head_channels": 32, 110 | "res_blocks": 2, 111 | "dropout": 0.2, 112 | "image_size": 256 113 | }, 114 | "beta_schedule": { 115 | "train": { 116 | "schedule": "linear", 117 | "n_timestep": 2000, 118 | // "n_timestep": 5, // debug 119 | "linear_start": 1e-6, 120 | "linear_end": 0.01 121 | }, 122 | "test": { 123 | "schedule": "linear", 124 | "n_timestep": 500, 125 | "linear_start": 1e-4, 126 | "linear_end": 0.09 127 | } 128 | 129 | } 130 | } 131 | } 132 | ], 133 | "which_losses": [ // import designated list of losses without arguments 134 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 135 | ], 136 | "which_metrics": [ // import designated list of metrics without arguments 137 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 138 | ] 139 | }, 140 | 141 | "train": { // arguments for basic training 142 | "n_epoch": 1e8, // max epochs, not limited now 143 | "n_iter": 1e8, // max interations 144 | "val_epoch": 20, // valdation every specified number of epochs 145 | "save_checkpoint_epoch": 20, 146 | "log_iter": 1e4, // log every specified number of iterations 147 | "tensorboard" : true // tensorboardX enable 148 | }, 149 | 150 | "debug": { // arguments in debug mode, which will replace arguments in train 151 | "val_epoch": 1, 152 | "save_checkpoint_epoch": 1, 153 | "log_iter": 10, 154 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
155 | } 156 | } 157 | -------------------------------------------------------------------------------- /config/EMDiffuse-n-transfer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EMDiffuse-n-NCtransfer_unfinetune", // experiments name 3 | "norm": true, 4 | "percent": false, 5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0 6 | "seed" : -1, // random seed, seed <0 represents randomization not used 7 | "finetune_norm": false, // find the parameters to optimize 8 | "task" : "denoise", 9 | "path": { //set every part file path 10 | "base_dir": "experiments", // base path for all log except resume_state 11 | "code": "code", // code backup 12 | "tb_logger": "tb_logger", // path of tensorboard logger 13 | "results": "results", 14 | "checkpoint": "checkpoint", 15 | "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training 16 | // "resume_state": "experiments/train_EMDiffuse-n-NCtransfer_231025_151833/checkpoint/3700" // checkpoint path, set to null if usedkua for training 17 | // "resume_state": "experiments/EMDiffuse-n/2720" // checkpoint path, set to null if used for training 18 | // "resume_state": null // checkpoint path, set to null if used for training 19 | }, 20 | "datasets": { // train or test 21 | "train": { 22 | "which_dataset": { // import designated dataset using arguments 23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 24 | "args":{ // arguments to initialize dataset 25 | "data_root": "/data/cxlu/transfer/NC/denoise/train_wf", 26 | "data_len": -1, 27 | "norm": true, 28 | "percent": false 29 | } 30 | }, 31 | "dataloader":{ 32 | "validation_split": 20, // percent or number 33 | "args":{ // arguments to initialize train_dataloader 34 | "batch_size": 3, // batch size in each gpu 35 | "num_workers": 4, 36 | "shuffle": true, 37 | "pin_memory": true, 38 | "drop_last": true 39 | }, 40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 41 | "batch_size": 10, // batch size in each gpu 42 | "num_workers": 4, 43 | "shuffle": false, 44 | "pin_memory": true, 45 | "drop_last": false 46 | } 47 | } 48 | }, 49 | "test": { 50 | "which_dataset": { 51 | 52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file 53 | "args":{ 54 | "data_root": "/data/cxlu/transfer/NC/denoise_test/test_wf", 55 | "norm":true, 56 | "percent": false, 57 | "phase": "val" 58 | } 59 | }, 60 | "dataloader":{ 61 | "args":{ 62 | "batch_size": 8, 63 | "num_workers": 4, 64 | "pin_memory": true 65 | } 66 | } 67 | } 68 | }, 69 | 70 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 71 | "which_model": { // import designated model(trainer) using arguments 72 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 73 | "args": { 74 | "sample_num": 8, // process of each image 75 | "task": "denoise", 76 | "ema_scheduler": { 77 | "ema_start": 1, 78 | "ema_iter": 1, 79 | "ema_decay": 0.9999 80 | }, 81 | "optimizers": [ 82 | { "lr": 5e-5, "weight_decay": 0} 83 | ] 84 | } 85 | }, 86 | "which_networks": [ // import designated list of networks using arguments 87 | { 88 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / 
function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 89 | "args": { // arguments to initialize network 90 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 91 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 92 | "norm": true, 93 | "unet": { 94 | "in_channel": 2, 95 | "out_channel": 1, 96 | "inner_channel": 32, 97 | "channel_mults": [ 98 | 1, 99 | 2, 100 | 4, 101 | 8 102 | ], 103 | "attn_res": [ 104 | // 32, 105 | 16 106 | // 8 107 | ], 108 | "num_head_channels": 32, 109 | "res_blocks": 2, 110 | "dropout": 0.2, 111 | "image_size": 256 112 | }, 113 | "beta_schedule": { 114 | "train": { 115 | "schedule": "linear", 116 | "n_timestep": 2000, 117 | // "n_timestep": 5, // debug 118 | "linear_start": 1e-6, 119 | "linear_end": 0.01 120 | }, 121 | "test": { 122 | "schedule": "linear", 123 | "n_timestep": 1000, 124 | "linear_start": 1e-4, 125 | "linear_end": 0.09 126 | } 127 | 128 | } 129 | } 130 | } 131 | ], 132 | "which_losses": [ // import designated list of losses without arguments 133 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 134 | ], 135 | "which_metrics": [ // import designated list of metrics without arguments 136 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 137 | ] 138 | }, 139 | 140 | "train": { // arguments for basic training 141 | "n_epoch": 1e8, // max epochs, not limited now 142 | "n_iter": 1e8, // max interations 143 | "val_epoch": 100, // valdation every specified number of epochs 144 | "save_checkpoint_epoch": 20, 145 | "log_iter": 1e4, // log every specified number of iterations 146 | "tensorboard" : true // tensorboardX enable 147 | }, 148 | 149 | "debug": { // arguments in debug mode, which will replace arguments in train 150 | "val_epoch": 1, 151 | "save_checkpoint_epoch": 1, 152 | "log_iter": 10, 153 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
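
Throughout these configs, a "name" is either a ["module.path", "ClassName"] pair or a bare name resolved against a default module, and "args" holds the constructor keywords. A sketch of the dynamic instantiation this implies (init_obj here is a hypothetical stand-in for the logic in core/praser.py):

import importlib

def init_obj(spec, default_module='data.dataset', **extra_args):
    name = spec['name']
    module_name, attr = name if isinstance(name, list) else (default_module, name)
    cls = getattr(importlib.import_module(module_name), attr)
    return cls(**{**spec.get('args', {}), **extra_args})

# e.g. dataset = init_obj(opt['datasets']['train']['which_dataset'])
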
154 | } 155 | } 156 | -------------------------------------------------------------------------------- /config/EMDiffuse-n.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EMDiffuse-n", // experiments name 3 | "norm": true, 4 | "percent": false, 5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0 6 | "seed" : -1, // random seed, seed <0 represents randomization not used 7 | "finetune_norm": false, // find the parameters to optimize 8 | "task" : "denoise", 9 | "path": { //set every part file path 10 | "base_dir": "experiments", // base path for all log except resume_state 11 | "code": "code", // code backup 12 | "tb_logger": "tb_logger", // path of tensorboard logger 13 | "results": "results", 14 | "checkpoint": "checkpoint", 15 | // "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training 16 | "resume_state": "experiments/EMDiffuse-n/best" // checkpoint path, set to null if used for training 17 | // "resume_state": null // checkpoint path, set to null if used for training 18 | }, 19 | 20 | "datasets": { // train or test 21 | "train": { 22 | "which_dataset": { // import designated dataset using arguments 23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 24 | "args":{ // arguments to initialize dataset 25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf", 26 | "data_len": -1, 27 | "norm": true, 28 | "percent": false 29 | } 30 | }, 31 | "dataloader":{ 32 | "validation_split": 0.1, // percent or number 33 | "args":{ // arguments to initialize train_dataloader 34 | "batch_size": 3, // batch size in each gpu 35 | "num_workers": 4, 36 | "shuffle": true, 37 | "pin_memory": true, 38 | "drop_last": true 39 | }, 40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 41 | "batch_size": 10, // batch size in each gpu 42 | "num_workers": 4, 43 | "shuffle": false, 44 | "pin_memory": true, 45 | "drop_last": false 46 | } 47 | } 48 | }, 49 | "test": { 50 | "which_dataset": { 51 | 52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file 53 | "args":{ 54 | "data_root": "/data/cxlu/denoise_single", 55 | "norm":true, 56 | "percent": false, 57 | "phase": "val" 58 | } 59 | }, 60 | "dataloader":{ 61 | "args":{ 62 | "batch_size": 8, 63 | "num_workers": 0, 64 | "pin_memory": true 65 | } 66 | } 67 | } 68 | }, 69 | 70 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 71 | "which_model": { // import designated model(trainer) using arguments 72 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 73 | "args": { 74 | "sample_num": 8, // process of each image 75 | "task": "denoise", 76 | "ema_scheduler": { 77 | "ema_start": 1, 78 | "ema_iter": 1, 79 | "ema_decay": 0.9999 80 | }, 81 | "optimizers": [ 82 | { "lr": 5e-5, "weight_decay": 0} 83 | ] 84 | } 85 | }, 86 | "which_networks": [ // import designated list of networks using arguments 87 | { 88 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 89 | "args": { // arguments to initialize network 90 | "init_type": "kaiming", // method 
can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 91 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 92 | "norm": true, 93 | "unet": { 94 | "in_channel": 2, 95 | "out_channel": 1, 96 | "inner_channel": 32, 97 | "channel_mults": [ 98 | 1, 99 | 2, 100 | 4, 101 | 8 102 | ], 103 | "attn_res": [ 104 | // 32, 105 | 16 106 | // 8 107 | ], 108 | "num_head_channels": 32, 109 | "res_blocks": 2, 110 | "dropout": 0.2, 111 | "image_size": 256 112 | }, 113 | "beta_schedule": { 114 | "train": { 115 | "schedule": "linear", 116 | "n_timestep": 2000, 117 | // "n_timestep": 5, // debug 118 | "linear_start": 1e-6, 119 | "linear_end": 0.01 120 | }, 121 | "test": { 122 | "schedule": "linear", 123 | "n_timestep": 1000, 124 | "linear_start": 1e-4, 125 | "linear_end": 0.09 126 | } 127 | 128 | } 129 | } 130 | } 131 | ], 132 | "which_losses": [ // import designated list of losses without arguments 133 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 134 | ], 135 | "which_metrics": [ // import designated list of metrics without arguments 136 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 137 | ] 138 | }, 139 | 140 | "train": { // arguments for basic training 141 | "n_epoch": 1e8, // max epochs, not limited now 142 | "n_iter": 1e8, // max interations 143 | "val_epoch": 20, // valdation every specified number of epochs 144 | "save_checkpoint_epoch": 20, 145 | "log_iter": 1e4, // log every specified number of iterations 146 | "tensorboard" : true // tensorboardX enable 147 | }, 148 | 149 | "debug": { // arguments in debug mode, which will replace arguments in train 150 | "val_epoch": 1, 151 | "save_checkpoint_epoch": 1, 152 | "log_iter": 10, 153 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
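
The beta_schedule blocks above fully determine the diffusion noising process: a 2000-step linear schedule from 1e-6 to 0.01 for training, and a shorter, more aggressive schedule for sampling. A sketch of what those numbers imply, assuming plain linear interpolation (models/EMDiffuse_network.py defines the exact variant used):

import numpy as np

def linear_beta_schedule(n_timestep, linear_start, linear_end):
    betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
    alphas_cumprod = np.cumprod(1.0 - betas)  # fraction of signal kept after t steps
    return betas, alphas_cumprod

betas, acp = linear_beta_schedule(2000, 1e-6, 0.01)  # the "train" settings above
print(acp[-1])  # ~4.5e-5: essentially pure noise at the final timestep
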
154 | } 155 | } 156 | -------------------------------------------------------------------------------- /config/EMDiffuse-r.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EMDiffuse-r", // experiments name 3 | "norm": true, 4 | "percent": false, 5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0 6 | "seed" : -1, // random seed, seed <0 represents randomization not used 7 | "finetune_norm": false, // find the parameters to optimize 8 | "task" : "denoise", 9 | "path": { //set every part file path 10 | "base_dir": "experiments", // base path for all log except resume_state 11 | "code": "code", // code backup 12 | "tb_logger": "tb_logger", // path of tensorboard logger 13 | "results": "results", 14 | "checkpoint": "checkpoint", 15 | // "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training 16 | "resume_state": "experiments/EMDiffuse-n/best" // checkpoint path, set to null if used for training 17 | // "resume_state": null // checkpoint path, set to null if used for training 18 | }, 19 | 20 | "datasets": { // train or test 21 | "train": { 22 | "which_dataset": { // import designated dataset using arguments 23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 24 | "args":{ // arguments to initialize dataset 25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf", 26 | "data_len": -1, 27 | "norm": true, 28 | "percent": false 29 | } 30 | }, 31 | "dataloader":{ 32 | "validation_split": 0.1, // percent or number 33 | "args":{ // arguments to initialize train_dataloader 34 | "batch_size": 3, // batch size in each gpu 35 | "num_workers": 4, 36 | "shuffle": true, 37 | "pin_memory": true, 38 | "drop_last": true 39 | }, 40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 41 | "batch_size": 10, // batch size in each gpu 42 | "num_workers": 4, 43 | "shuffle": false, 44 | "pin_memory": true, 45 | "drop_last": false 46 | } 47 | } 48 | }, 49 | "test": { 50 | "which_dataset": { 51 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file 52 | "args":{ 53 | "data_root": "/data/cxlu/denoise_single", 54 | "norm":true, 55 | "percent": false, 56 | "phase": "val" 57 | } 58 | }, 59 | "dataloader":{ 60 | "args":{ 61 | "batch_size": 8, 62 | "num_workers": 0, 63 | "pin_memory": true 64 | } 65 | } 66 | } 67 | }, 68 | 69 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 70 | "which_model": { // import designated model(trainer) using arguments 71 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 72 | "args": { 73 | "sample_num": 8, // process of each image 74 | "task": "denoise", 75 | "ema_scheduler": { 76 | "ema_start": 1, 77 | "ema_iter": 1, 78 | "ema_decay": 0.9999 79 | }, 80 | "optimizers": [ 81 | { "lr": 5e-5, "weight_decay": 0} 82 | ] 83 | } 84 | }, 85 | "which_networks": [ // import designated list of networks using arguments 86 | { 87 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 88 | "args": { // arguments to initialize network 89 | "init_type": "kaiming", // method can be 
[normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 90 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 91 | "norm": true, 92 | "unet": { 93 | "in_channel": 2, 94 | "out_channel": 1, 95 | "inner_channel": 32, 96 | "channel_mults": [ 97 | 1, 98 | 2, 99 | 4, 100 | 8 101 | ], 102 | "attn_res": [ 103 | // 32, 104 | 16 105 | // 8 106 | ], 107 | "num_head_channels": 32, 108 | "res_blocks": 2, 109 | "dropout": 0.2, 110 | "image_size": 256 111 | }, 112 | "beta_schedule": { 113 | "train": { 114 | "schedule": "linear", 115 | "n_timestep": 2000, 116 | // "n_timestep": 5, // debug 117 | "linear_start": 1e-6, 118 | "linear_end": 0.01 119 | }, 120 | "test": { 121 | "schedule": "linear", 122 | "n_timestep": 1000, 123 | "linear_start": 1e-4, 124 | "linear_end": 0.09 125 | } 126 | 127 | } 128 | } 129 | } 130 | ], 131 | "which_losses": [ // import designated list of losses without arguments 132 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 133 | ], 134 | "which_metrics": [ // import designated list of metrics without arguments 135 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 136 | ] 137 | }, 138 | 139 | "train": { // arguments for basic training 140 | "n_epoch": 1e8, // max epochs, not limited now 141 | "n_iter": 1e8, // max interations 142 | "val_epoch": 20, // valdation every specified number of epochs 143 | "save_checkpoint_epoch": 20, 144 | "log_iter": 1e4, // log every specified number of iterations 145 | "tensorboard" : true // tensorboardX enable 146 | }, 147 | 148 | "debug": { // arguments in debug mode, which will replace arguments in train 149 | "val_epoch": 1, 150 | "save_checkpoint_epoch": 1, 151 | "log_iter": 10, 152 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
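
The ema_scheduler block keeps an exponential moving average of the network weights (decay 0.9999, updated every ema_iter steps once training passes ema_start); the averaged copy is typically the one used for sampling. A minimal sketch of the update rule, assuming ema_model starts as a deep copy of the network:

import torch

@torch.no_grad()
def ema_step(ema_model, model, decay=0.9999):
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        # ema = decay * ema + (1 - decay) * current
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
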
153 | } 154 | } 155 | -------------------------------------------------------------------------------- /config/vEMDiffuse-a.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "vEMDiffuse-a", 3 | // experiments name 4 | "norm": true, 5 | "percent": false, 6 | "gpu_ids": [ 7 | 0, 8 | 1 9 | ], 10 | // gpu ids list, default is single 0 11 | "seed": -1, 12 | // random seed, seed <0 represents randomization not used 13 | "finetune_norm": false, 14 | // find the parameters to optimize 15 | "task": "3d_reconstruction", 16 | "path": { 17 | //set every part file path 18 | "base_dir": "experiments", 19 | // base path for all log except resume_state 20 | "code": "code", 21 | // code backup 22 | "tb_logger": "tb_logger", 23 | // path of tensorboard logger 24 | "results": "results", 25 | "checkpoint": "checkpoint", 26 | // "resume_state": "experiments/emdiffusie-a-phlep/4860" 27 | "resume_state": "experiments/vEMDiffuse-a/best" 28 | 29 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 30 | }, 31 | "datasets": { 32 | // train or test 33 | "train": { 34 | "which_dataset": { 35 | // import designated dataset using arguments 36 | "name": [ 37 | "data.dataset", 38 | "vEMDiffuseTrainingDatasetVolume" 39 | ], 40 | // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 41 | "args": { 42 | // arguments to initialize dataset 43 | "data_root": "/data/cxlu/phelps_test_patches_6144/", 44 | "data_len": -1, 45 | "norm": true, 46 | "percent": false, 47 | "z_times": 10, 48 | "method": "vEMDiffuse-a", 49 | "image_size": [256, 256] 50 | } 51 | }, 52 | "dataloader": { 53 | "validation_split": 20, 54 | // percent or number 55 | "args": { 56 | // arguments to initialize train_dataloader 57 | "batch_size": 3, 58 | // batch size in each gpu 59 | "num_workers": 2, 60 | "shuffle": true, 61 | "pin_memory": false, 62 | "drop_last": true 63 | }, 64 | "val_args": { 65 | // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 66 | "batch_size": 10, 67 | // batch size in each gpu 68 | "num_workers": 2, 69 | "shuffle": false, 70 | "pin_memory": false, 71 | "drop_last": false 72 | } 73 | } 74 | }, 75 | "test": { 76 | "which_dataset": { 77 | "name": "vEMDiffuseTestAnIsotropic", 78 | // import Dataset() class / function(not recommend) from default file 79 | "args": { 80 | // "data_root": "/data/cxlu/phelps_test_patches_2048/", 81 | "data_root": "/mnt/sdb/cxlu/phelps_test_patches_6144/", 82 | "norm": true, 83 | "percent": false, 84 | "phase": "val", 85 | "z_times": 10 86 | } 87 | }, 88 | "dataloader": { 89 | "args": { 90 | "batch_size": 8, 91 | "num_workers": 0, 92 | "pin_memory": true 93 | } 94 | } 95 | } 96 | }, 97 | "model": { 98 | // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 99 | "which_model": { 100 | // import designated model(trainer) using arguments 101 | "name": [ 102 | "models.vEMDiffuse_model", 103 | "DiReP" 104 | ], 105 | // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 106 | "args": { 107 | "sample_num": 8, 108 | // process of each image 109 | "task": "3d_reconstruct", 110 | "ema_scheduler": { 111 | "ema_start": 1, 112 | "ema_iter": 1, 113 | "ema_decay": 0.9999 114 | }, 115 | "optimizers": [ 116 | { 117 | "lr": 5e-5, 118 | "weight_decay": 0 119 | } 120 | ] 121 | } 122 | }, 123 | "which_networks": [ 124 | // import designated list of 
networks using arguments 125 | { 126 | "name": [ 127 | "models.vEMDiffuse_network", 128 | "Network" 129 | ], 130 | // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 131 | "args": { 132 | // arguments to initialize network 133 | "init_type": "kaiming", 134 | // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 135 | "module_name": "guided_diffusion_3d_2d", 136 | // sr3 | guided_diffusion 137 | "norm": true, 138 | "unet": { 139 | "in_channel": 3, 140 | "out_channel": 1, 141 | "inner_channel": 32, 142 | "channel_mults": [ 143 | 1, 144 | 2, 145 | 4, 146 | 8 147 | ], 148 | "attn_res": [ 149 | // 32, 150 | 16 151 | // 8 152 | ], 153 | "num_head_channels": 32, 154 | "res_blocks": 2, 155 | "dropout": 0.2, 156 | "image_size": 256 157 | }, 158 | "beta_schedule": { 159 | "train": { 160 | "schedule": "linear", 161 | "n_timestep": 2000, 162 | // "n_timestep": 5, // debug 163 | "linear_start": 1e-6, 164 | "linear_end": 0.01 165 | }, 166 | "test": { 167 | "schedule": "linear", 168 | "n_timestep": 1000, 169 | // "n_timestep": 5, // debug 170 | "linear_start": 1e-4, 171 | "linear_end": 0.09 172 | } 173 | } 174 | } 175 | } 176 | ], 177 | "which_losses": [ 178 | // import designated list of losses without arguments 179 | "mse_loss" 180 | // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 181 | ], 182 | "which_metrics": [ 183 | // import designated list of metrics without arguments 184 | "mae" 185 | // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 186 | ] 187 | }, 188 | "train": { 189 | // arguments for basic training 190 | "n_epoch": 1e8, 191 | // max epochs, not limited now 192 | "n_iter": 1e8, 193 | // max interations 194 | "val_epoch": 20, 195 | // valdation every specified number of epochs 196 | "save_checkpoint_epoch": 20, 197 | "log_iter": 1e4, 198 | // log every specified number of iterations 199 | "tensorboard": true 200 | // tensorboardX enable 201 | }, 202 | "debug": { 203 | // arguments in debug mode, which will replace arguments in train 204 | "val_epoch": 1, 205 | "save_checkpoint_epoch": 1, 206 | "log_iter": 10, 207 | "debug_split": 50 208 | // percent or number, change the size of dataloder to debug_split. 
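
Here z_times is the axial gap the model must bridge: acquired slices sit z_times apart relative to the target resolution (10 for this config, 6 in vEMDiffuse-i below). When an isotropic volume is available for training, a pair can be formed by subsampling it along Z; a toy sketch of that pairing (axial_pair is a hypothetical helper; data/dataset.py implements the real sampling):

import numpy as np

def axial_pair(volume, z_times=10):
    sparse = volume[::z_times]  # the slices a microscope would actually acquire
    dense = volume              # the dense stack the network learns to recover
    return sparse, dense

vol = np.random.rand(100, 256, 256).astype(np.float32)  # [Z, Y, X] toy volume
sparse, dense = axial_pair(vol)
print(sparse.shape, dense.shape)  # (10, 256, 256) (100, 256, 256)
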
209 | } 210 | } 211 | -------------------------------------------------------------------------------- /config/vEMDiffuse-i.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "vEMDiffuse-i", 3 | // experiments name 4 | "norm": true, 5 | "percent": false, 6 | "gpu_ids": [ 7 | 0, 8 | 1 9 | ], 10 | // gpu ids list, default is single 0 11 | "seed": -1, 12 | // random seed, seed <0 represents randomization not used 13 | "finetune_norm": false, 14 | // find the parameters to optimize 15 | "task": "3d_reconstruction", 16 | "path": { 17 | //set every part file path 18 | "base_dir": "experiments", 19 | // base path for all log except resume_state 20 | "code": "code", 21 | // code backup 22 | "tb_logger": "tb_logger", 23 | // path of tensorboard logger 24 | "results": "results", 25 | "checkpoint": "checkpoint", 26 | "resume_state": "experiments/vEMDiffuse-i/best" 27 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 28 | }, 29 | "datasets": { 30 | // train or test 31 | "train": { 32 | "which_dataset": { 33 | // import designated dataset using arguments 34 | "name": [ 35 | "data.dataset", 36 | "vEMDiffuseTrainingDatasetVolume" 37 | ], 38 | // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 39 | "args": { 40 | // arguments to initialize dataset 41 | "data_root": "/data/cxlu/liver_3d_inter3_continous_clean/train_img", 42 | "data_len": -1, 43 | "norm": true, 44 | "percent": false, 45 | "z_times": 6, 46 | "method": "vEMDiffuse-i" 47 | } 48 | }, 49 | "dataloader": { 50 | "validation_split": 20, 51 | // percent or number 52 | "args": { 53 | // arguments to initialize train_dataloader 54 | "batch_size": 3, 55 | // batch size in each gpu 56 | "num_workers": 2, 57 | "shuffle": true, 58 | "pin_memory": false, 59 | "drop_last": true 60 | }, 61 | "val_args": { 62 | // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 63 | "batch_size": 10, 64 | // batch size in each gpu 65 | "num_workers": 2, 66 | "shuffle": false, 67 | "pin_memory": false, 68 | "drop_last": false 69 | } 70 | } 71 | }, 72 | "test": { 73 | "which_dataset": { 74 | "name": "vEMDiffuseTestAnIsotropic", 75 | // import Dataset() class / function(not recommend) from default file 76 | "args": { 77 | "data_root": "/data/cxlu/liver_3d_test_patches", 78 | // "data_root": "/lustre1/g/chem_jianglab/cxlu/kai_3d/test_patches", 79 | "norm": true, 80 | "percent": false, 81 | "phase": "val", 82 | "z_times": 6 83 | } 84 | }, 85 | "dataloader": { 86 | "args": { 87 | "batch_size": 8, 88 | "num_workers": 0, 89 | "pin_memory": true 90 | } 91 | } 92 | } 93 | }, 94 | "model": { 95 | // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 96 | "which_model": { 97 | // import designated model(trainer) using arguments 98 | "name": [ 99 | "models.vEMDiffuse_model", 100 | "DiReP" 101 | ], 102 | // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py]) 103 | "args": { 104 | "sample_num": 8, 105 | // process of each image 106 | "task": "3d_reconstruct", 107 | "ema_scheduler": { 108 | "ema_start": 1, 109 | "ema_iter": 1, 110 | "ema_decay": 0.9999 111 | }, 112 | "optimizers": [ 113 | { 114 | "lr": 5e-5, 115 | "weight_decay": 0 116 | } 117 | ] 118 | } 119 | }, 120 | "which_networks": [ 121 | // import designated list of networks using arguments 122 | { 123 | "name": [ 124 | "models.vEMDiffuse_network", 
125 | "Network" 126 | ], 127 | // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py]) 128 | "args": { 129 | // arguments to initialize network 130 | "init_type": "kaiming", 131 | // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 132 | "module_name": "guided_diffusion_3d_2d", 133 | // sr3 | guided_diffusion 134 | "norm": true, 135 | "unet": { 136 | "in_channel": 3, 137 | "out_channel": 1, 138 | "inner_channel": 32, 139 | "channel_mults": [ 140 | 1, 141 | 2, 142 | 4, 143 | 8 144 | ], 145 | "attn_res": [ 146 | // 32, 147 | 16 148 | // 8 149 | ], 150 | "num_head_channels": 32, 151 | "res_blocks": 2, 152 | "dropout": 0.2, 153 | "image_size": 256 154 | }, 155 | "beta_schedule": { 156 | "train": { 157 | "schedule": "linear", 158 | "n_timestep": 2000, 159 | // "n_timestep": 5, // debug 160 | "linear_start": 1e-6, 161 | "linear_end": 0.01 162 | }, 163 | "test": { 164 | "schedule": "linear", 165 | "n_timestep": 1000, 166 | // "n_timestep": 5, // debug 167 | "linear_start": 1e-4, 168 | "linear_end": 0.09 169 | } 170 | } 171 | } 172 | } 173 | ], 174 | "which_losses": [ 175 | // import designated list of losses without arguments 176 | "mse_loss" 177 | // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 178 | ], 179 | "which_metrics": [ 180 | // import designated list of metrics without arguments 181 | "mae" 182 | // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 183 | ] 184 | }, 185 | "train": { 186 | // arguments for basic training 187 | "n_epoch": 1e8, 188 | // max epochs, not limited now 189 | "n_iter": 1e8, 190 | // max interations 191 | "val_epoch": 20, 192 | // valdation every specified number of epochs 193 | "save_checkpoint_epoch": 20, 194 | "log_iter": 1e4, 195 | // log every specified number of iterations 196 | "tensorboard": true 197 | // tensorboardX enable 198 | }, 199 | "debug": { 200 | // arguments in debug mode, which will replace arguments in train 201 | "val_epoch": 1, 202 | "save_checkpoint_epoch": 1, 203 | "log_iter": 10, 204 | "debug_split": 50 205 | // percent or number, change the size of dataloder to debug_split. 
206 | } 207 | } 208 | -------------------------------------------------------------------------------- /core/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/base_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/base_network.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/praser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/praser.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /core/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP','.tif' 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | if os.path.isfile(dir): 17 | images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')] 18 | else: 19 | images = [] 20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in sorted(fnames): 23 | if is_image_file(fname): 24 | path = os.path.join(root, fname) 25 | images.append(path) 26 | 27 | return images 28 | 29 | def pil_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class BaseDataset(data.Dataset): 33 | def __init__(self, data_root, image_size=[256, 256], loader=pil_loader): 34 | self.imgs = make_dataset(data_root) 35 | self.tfs = transforms.Compose([ 36 | transforms.Resize((image_size[0], image_size[1])), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 39 | ]) 40 | self.loader = loader 41 | 42 | def __getitem__(self, index): 43 | path = self.imgs[index] 44 | img = self.tfs(self.loader(path)) 45 | return img 46 | 47 | def __len__(self): 48 | return len(self.imgs) 49 | -------------------------------------------------------------------------------- /core/base_model.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | from functools import partial 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | import core.util as Util 11 | CustomResult = collections.namedtuple('CustomResult', 'name result') 12 | 13 | class BaseModel(): 14 | 15 | def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): 16 | """ init model with basic input, which are from __init__(**kwargs) function in inherited class """ 17 | self.opt = opt 18 | self.phase = opt['phase'] 19 | self.set_device = partial(Util.set_device, rank=opt['global_rank']) 20 | self.mean = opt['mean'] if 'mean' in opt.keys() else 1 21 | ''' optimizers and schedulers ''' 22 | self.schedulers = [] 23 | self.optimizers = [] 24 | ''' process record ''' 25 | self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size'] 26 | self.epoch = 0 27 | self.transfer_epoch = 0 28 | self.iter = 0 29 | self.phase_loader = phase_loader 30 | self.val_loader = val_loader 31 | self.metrics = metrics 32 | 33 | ''' logger to log file, which only works on GPU 0. writer to tensorboard and result file ''' 34 | self.logger = logger 35 | self.writer = writer 36 | self.results_dict = CustomResult([],[]) # {"name":[], "result":[]} 37 | 38 | def train(self): 39 | # val_log = self.val_step() 40 | while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']: 41 | self.epoch += 1 42 | if self.opt['distributed']: 43 | ''' sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch ''' 44 | self.phase_loader.sampler.set_epoch(self.epoch) 45 | 46 | train_log = self.train_step() 47 | 48 | ''' save logged information into the log dict ''' 49 | print('epoch {}: training step finished'.format(self.epoch)) 50 | train_log.update({'epoch': self.epoch, 'iters': self.iter}) 51 | 52 | ''' print logged information to the screen and tensorboard ''' 53 | for key, value in train_log.items(): 54 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 55 | if self.epoch < 500: 56 | if self.epoch % 100 == 0: 57 | self.logger.info('Saving the model at the end of epoch {:.0f}'.format(self.epoch)) 58 | self.save_everything() 59 | if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0 and self.epoch > 500: 60 | self.logger.info('Saving the model at the end of epoch {:.0f}'.format(self.epoch)) 61 | self.save_everything() 62 | 63 | if self.epoch % self.opt['train']['val_epoch'] == 0 and self.epoch > 500: 64 | self.logger.info("\n\n\n------------------------------Validation Start------------------------------") 65 | if self.val_loader is None: 66 | self.logger.warning('Validation skipped: dataloader is None.') 67 | else: 68 | val_log = self.val_step() 69 | for key, value in val_log.items(): 70 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 71 | self.logger.info("\n------------------------------Validation End------------------------------\n\n") 72 | self.logger.info('Number of epochs has reached the limit, ending training.') 73 | 74 | def test(self): 75 | pass 76 | 77 | @abstractmethod 78 | def train_step(self): 79 | raise NotImplementedError('You must specify how to train your networks.') 80 | 81 | @abstractmethod 82 | def val_step(self): 83 | raise NotImplementedError('You must specify how to do validation on your networks.') 84 | 85 | def test_step(self): 86 | pass 87 | 88 | def print_network(self,
network): 89 | """ print network structure, only works on GPU 0 """ 90 | if self.opt['global_rank'] != 0: 91 | return 92 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 93 | network = network.module 94 | 95 | s, n = str(network), sum(map(lambda x: x.numel(), network.parameters())) 96 | net_struc_str = '{}'.format(network.__class__.__name__) 97 | self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 98 | self.logger.info(s) 99 | 100 | def save_network(self, network, network_label): 101 | """ save network weights, only works on GPU 0 """ 102 | if self.opt['global_rank'] != 0: 103 | return 104 | save_filename = '{}_{}.pth'.format(self.epoch, network_label) 105 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 106 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 107 | network = network.module 108 | state_dict = network.state_dict() 109 | for key, param in state_dict.items(): 110 | state_dict[key] = param.cpu() 111 | torch.save(state_dict, save_path) 112 | 113 | def load_network(self, network, network_label, strict=True): 114 | if self.opt['path']['resume_state'] is None: 115 | return 116 | self.logger.info('Begin loading pretrained model [{:s}] ...'.format(network_label)) 117 | 118 | model_path = "{}_{}.pth".format(self.opt['path']['resume_state'], network_label) 119 | 120 | if not os.path.exists(model_path): 121 | self.logger.warning('Pretrained model [{:s}] does not exist, skipping it'.format(model_path)) 122 | return 123 | 124 | self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path)) 125 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 126 | network = network.module 127 | network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict) 128 | 129 | def save_training_state(self): 130 | """ saves training state during training, only works on GPU 0 """ 131 | if self.opt['global_rank'] != 0: 132 | return 133 | 134 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 135 | state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []} 136 | for s in self.schedulers: 137 | state['schedulers'].append(s.state_dict()) 138 | for o in self.optimizers: 139 | state['optimizers'].append(o.state_dict()) 140 | save_filename = '{}.state'.format(self.epoch) 141 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 142 | torch.save(state, save_path) 143 | 144 | def resume_training(self): 145 | """ resume the optimizers and schedulers for training; only works when phase is train and resume_state is set """ 146 | if self.phase != 'train' or self.opt['path']['resume_state'] is None: 147 | return 148 | self.logger.info('Begin loading training states') 149 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 150 | 151 | state_path = "{}.state".format(self.opt['path']['resume_state']) 152 | 153 | if not os.path.exists(state_path): 154 | self.logger.warning('Training state [{:s}] does not exist, skipping it'.format(state_path)) 155 | return 156 | 157 | self.logger.info('Loading training state from [{:s}] ...'.format(state_path)) 158 | resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage)) 159 | 160 | resume_optimizers = resume_state['optimizers'] 161 | resume_schedulers = resume_state['schedulers'] 162 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers)) 163 | # assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers)) 164 | # for i, o in enumerate(resume_optimizers): 165 | if len(resume_schedulers) == len(self.schedulers): 166 | # self.optimizers[i].load_state_dict(o) 167 | for i, s in enumerate(resume_schedulers): 168 | self.schedulers[i].load_state_dict(s) 169 | 170 | self.epoch = resume_state['epoch'] 171 | self.transfer_epoch = resume_state['epoch'] 172 | self.iter = resume_state['iter'] 173 | 174 | def load_everything(self): 175 | pass 176 | 177 | @abstractmethod 178 | def save_everything(self): 179 | raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.') 180 | -------------------------------------------------------------------------------- /core/base_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class BaseNetwork(nn.Module): 4 | def __init__(self, init_type='kaiming', gain=0.02): 5 | super(BaseNetwork, self).__init__() 6 | self.init_type = init_type 7 | self.gain = gain 8 | 9 | 10 | def init_weights(self): 11 | """ 12 | initialize network's weights 13 | init_type: normal | xavier | xavier_uniform | kaiming | orthogonal 14 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 15 | """ 16 | 17 | def init_func(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('InstanceNorm2d') != -1: 20 | if hasattr(m, 'weight') and m.weight is not None: 21 | nn.init.constant_(m.weight.data, 1.0) 22 | if hasattr(m, 'bias') and m.bias is not None: 23 | nn.init.constant_(m.bias.data, 0.0) 24 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 25 | if self.init_type == 'normal': 26 | nn.init.normal_(m.weight.data, 0.0, self.gain) 27 | elif self.init_type == 'xavier': 28 | nn.init.xavier_normal_(m.weight.data, gain=self.gain) 29 | elif self.init_type == 'xavier_uniform': 30 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 31 | elif self.init_type == 'kaiming': 32 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 33 | elif self.init_type == 'orthogonal': 34 | nn.init.orthogonal_(m.weight.data, gain=self.gain) 35 | elif self.init_type == 'none': # uses pytorch's default init method 36 | m.reset_parameters() 37 | else: 38 | raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type) 39 | if hasattr(m, 'bias') and m.bias is not None: 40 | nn.init.constant_(m.bias.data, 0.0) 41 | 42 | self.apply(init_func) 43 | # propagate to children 44 | for m in self.children(): 45 | if hasattr(m, 'init_weights'): 46 | m.init_weights(self.init_type, self.gain) 47 | 48 | 49 | 50 | --------------------------------------------------------------------------------
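For reference, a minimal standalone sketch of the 'kaiming' branch of init_weights above, applied with nn.Module.apply the same way BaseNetwork does. The two-layer module is a hypothetical stand-in; only PyTorch is assumed:

import torch.nn as nn

def kaiming_init(m):
    # Mirrors the 'kaiming' branch of BaseNetwork.init_weights: Conv/Linear
    # weights get Kaiming-normal init, biases are zeroed.
    classname = m.__class__.__name__
    if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)

net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.Linear(8, 1))  # hypothetical module
net.apply(kaiming_init)  # apply() recurses over every submodule, as in init_weights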
/core/calibration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.optimize import brentq 4 | from scipy.stats import binom 5 | 6 | 7 | def fraction_missed_loss(lower_bound, upper_bound, ground_truth, avg_channels=True): 8 | misses = (lower_bound > ground_truth).float() + (upper_bound < ground_truth).float() 9 | misses[misses > 1.0] = 1.0 10 | if avg_channels: 11 | return misses.mean(), misses 12 | else: 13 | return misses.mean(dim=(1, 2)), misses 14 | 15 | 16 | def get_rcps_losses_from_outputs(cal_l, cal_u, ground_truth, lam): 17 | risk, misses = fraction_missed_loss(cal_l / lam, cal_u * lam, ground_truth) 18 | return risk 19 | 20 | 21 | def h1(y, mu): 22 | return y * np.log(y / mu) + (1 - y) * np.log((1 - y) / (1 - mu)) 23 | 24 | 25 | ### Log tail inequalities of mean 26 | def hoeffding_plus(mu, x, n): 27 | return -n * h1(np.maximum(mu, x), mu) 28 | 29 | 30 | def bentkus_plus(mu, x, n): 31 | return np.log(max(binom.cdf(np.floor(n * x), n, mu), 1e-10)) + 1 32 | 33 | 34 | def HB_mu_plus(muhat, n, delta, maxiters=1000): 35 | def _tailprob(mu): 36 | hoeffding_mu = hoeffding_plus(mu, muhat, n) 37 | bentkus_mu = bentkus_plus(mu, muhat, n) 38 | return min(hoeffding_mu, bentkus_mu) - np.log(delta) 39 | 40 | if _tailprob(1 - 1e-10) > 0: 41 | return 1 42 | else: 43 | try: 44 | return brentq(_tailprob, muhat, 1 - 1e-10, maxiter=maxiters) 45 | except: 46 | print(f"BRENTQ RUNTIME ERROR at muhat={muhat}") 47 | return 1.0 48 | 49 | 50 | def calibrate_model(cal_l, cal_u, ground_truth): 51 | alpha = 0.1 52 | delta = 0.1 53 | minimum_lambda = 0.9 54 | maximum_lambda = 1.3 55 | num_lambdas = 1000 56 | 57 | lambdas = torch.linspace(minimum_lambda, maximum_lambda, num_lambdas) 58 | dlambda = lambdas[1] - lambdas[0] 59 | lambda_hat = (lambdas[-1] + dlambda - 1e-9) 60 | 61 | for lam in reversed(lambdas): 62 | losses = get_rcps_losses_from_outputs(cal_l, cal_u, ground_truth, lam=(lam - dlambda)) 63 | 64 | Rhat = losses 65 | # print(cal_l.shape) 66 | RhatPlus = HB_mu_plus(Rhat.item(), cal_l.shape[0] * cal_l.shape[1] * cal_l.shape[2] * cal_l.shape[3], delta) 67 | 68 | print(f"\rLambda: {lam:.4f} | Rhat: {Rhat:.4f} | RhatPlus: {RhatPlus:.4f} ", end='') 69 | if Rhat >= alpha or RhatPlus > alpha: 70 | lambda_hat = lam 71 | print(f"Model's lambda_hat is {lambda_hat}") 72 | break 73 | return lambda_hat, (cal_l / lambda_hat).clamp(0., 1.), (cal_u * lambda_hat).clamp(0., 1.) 74 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import importlib 4 | from datetime import datetime 5 | import logging 6 | import pandas as pd 7 | 8 | import core.util as Util 9 | 10 | class InfoLogger(): 11 | """ 12 | use logging to record log, only work on GPU 0 by judging global_rank 13 | """ 14 | def __init__(self, opt): 15 | self.opt = opt 16 | self.rank = opt['global_rank'] 17 | self.phase = opt['phase'] 18 | 19 | self.setup_logger(None, opt['path']['experiments_root'], opt['phase'], level=logging.INFO, screen=False) 20 | self.logger = logging.getLogger(opt['phase']) 21 | self.infologger_ftns = {'info', 'warning', 'debug'} 22 | 23 | def __getattr__(self, name): 24 | if self.rank != 0: # info only print on GPU 0. 
25 | def wrapper(info, *args, **kwargs): 26 | pass 27 | return wrapper 28 | if name in self.infologger_ftns: 29 | print_info = getattr(self.logger, name, None) 30 | def wrapper(info, *args, **kwargs): 31 | print_info(info, *args, **kwargs) 32 | return wrapper 33 | 34 | @staticmethod 35 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 36 | """ set up logger """ 37 | l = logging.getLogger(logger_name) 38 | formatter = logging.Formatter( 39 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 40 | log_file = os.path.join(root, '{}.log'.format(phase)) 41 | fh = logging.FileHandler(log_file, mode='a+') 42 | fh.setFormatter(formatter) 43 | l.setLevel(level) 44 | l.addHandler(fh) 45 | if screen: 46 | sh = logging.StreamHandler() 47 | sh.setFormatter(formatter) 48 | l.addHandler(sh) 49 | 50 | class VisualWriter(): 51 | """ 52 | use tensorboard to record visuals, supporting the 'add_scalar', 'add_scalars', 'add_image', 'add_images', etc. functions. 53 | Also integrated with a save-results function. 54 | """ 55 | def __init__(self, opt, logger): 56 | log_dir = opt['path']['tb_logger'] 57 | self.result_dir = opt['path']['results'] 58 | enabled = opt['train']['tensorboard'] 59 | self.rank = opt['global_rank'] 60 | self.task = opt['task'] 61 | 62 | self.writer = None 63 | self.selected_module = "" 64 | 65 | if enabled and self.rank == 0: 66 | log_dir = str(log_dir) 67 | 68 | # Retrieve visualization writer. 69 | succeeded = False 70 | for module in ["tensorboardX", "torch.utils.tensorboard"]: 71 | try: 72 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 73 | succeeded = True 74 | break 75 | except ImportError: 76 | succeeded = False 77 | self.selected_module = module 78 | 79 | if not succeeded: 80 | message = "Warning: visualization (TensorBoard) is enabled in the config, but it is not installed on " \ 81 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 82 | "version >= 1.1 to use 'torch.utils.tensorboard', or turn off the option in the 'config.json' file." 83 | logger.warning(message) 84 | 85 | self.epoch = 0 86 | self.iter = 0 87 | self.phase = '' 88 | 89 | self.tb_writer_ftns = { 90 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 91 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 92 | } 93 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 94 | self.custom_ftns = {'close'} 95 | self.timer = datetime.now() 96 | 97 | def set_iter(self, epoch, iter, phase='train'): 98 | self.phase = phase 99 | self.epoch = epoch 100 | self.iter = iter 101 | 102 | def save_images(self, results, norm=True, percent=False): 103 | result_path = os.path.join(self.result_dir, self.phase) 104 | os.makedirs(result_path, exist_ok=True) 105 | result_path = os.path.join(result_path, str(self.epoch)) 106 | os.makedirs(result_path, exist_ok=True) 107 | from tifffile import imwrite 108 | import numpy as np 109 | ''' get names and corresponding images from results[OrderedDict] ''' 110 | try: 111 | names = results['name'] 112 | outputs = Util.postprocess(results['result'], out_type=np.uint8, min_max=(-1, 1), norm=norm) 113 | for i in range(len(names)): 114 | Image.fromarray(outputs[i]).save(os.path.join(result_path, names[i])) 115 | except: 116 | raise NotImplementedError('You must specify the name and result entries in the save_current_results function of the model.') 117 | 118 | def close(self): 119 | self.writer.close() 120 | print('Close the Tensorboard SummaryWriter.') 121 | 122 | 123 | def __getattr__(self, name): 124 | """ 125 | If visualization is enabled: 126 | return add_data() methods of tensorboard with additional information (step, tag) added. 127 | Otherwise: 128 | return a blank function handle that does nothing 129 | """ 130 | if name in self.tb_writer_ftns: 131 | add_data = getattr(self.writer, name, None) 132 | def wrapper(tag, data, *args, **kwargs): 133 | if add_data is not None: 134 | # add phase(train/valid) tag 135 | if name not in self.tag_mode_exceptions: 136 | tag = '{}/{}'.format(self.phase, tag) 137 | add_data(tag, data, self.iter, *args, **kwargs) 138 | return wrapper 139 | else: 140 | # default action for returning methods defined in this class, set_iter() for instance. 141 | try: 142 | attr = object.__getattribute__(self, name) # fix: object has no __getattr__; use __getattribute__ 143 | except AttributeError: 144 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 145 | return attr 146 | 147 | 148 | class LogTracker: 149 | """ 150 | record training numerical indicators.
151 | """ 152 | def __init__(self, *keys, phase='train'): 153 | self.phase = phase 154 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 155 | self.reset() 156 | 157 | def reset(self): 158 | for col in self._data.columns: 159 | self._data[col].values[:] = 0 160 | 161 | def update(self, key, value, n=1): 162 | self._data.total[key] += value * n 163 | self._data.counts[key] += n 164 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 165 | 166 | def avg(self, key): 167 | return self._data.average[key] 168 | 169 | def result(self): 170 | return {'{}/{}'.format(self.phase, k):v for k, v in dict(self._data.average).items()} 171 | -------------------------------------------------------------------------------- /core/praser.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import json 4 | from pathlib import Path 5 | from datetime import datetime 6 | from functools import partial 7 | import importlib 8 | from types import FunctionType 9 | import shutil 10 | 11 | 12 | def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network', 13 | **modify_kwargs): 14 | """ 15 | finds a function handle with the name given as 'name' in config, 16 | and returns the instance initialized with corresponding args. 17 | """ 18 | if opt is None or len(opt) < 1: 19 | logger.info('Option is None when initialize {}'.format(init_type)) 20 | return None 21 | 22 | ''' default format is dict with name key ''' 23 | if isinstance(opt, str): 24 | opt = {'name': opt} 25 | logger.warning('Config is a str, converts to a dict {}'.format(opt)) 26 | 27 | name = opt['name'] 28 | ''' name can be list, indicates the file and class name of function ''' 29 | if isinstance(name, list): 30 | file_name, class_name = name[0], name[1] 31 | else: 32 | file_name, class_name = default_file_name, name 33 | 34 | if given_module is not None: 35 | module = given_module 36 | else: 37 | module = importlib.import_module(file_name) 38 | 39 | attr = getattr(module, class_name) 40 | kwargs = opt.get('args', {}) 41 | kwargs.update(modify_kwargs) 42 | ''' import class or function with args ''' 43 | if isinstance(attr, type): 44 | ret = attr(*args, **kwargs) 45 | ret.__name__ = ret.__class__.__name__ 46 | elif isinstance(attr, FunctionType): 47 | ret = partial(attr, *args, **kwargs) 48 | ret.__name__ = attr.__name__ 49 | # ret = attr 50 | logger.info('{} [{:s}() form {:s}] is created.'.format(init_type, class_name, file_name)) 51 | 52 | return ret 53 | 54 | 55 | def mkdirs(paths): 56 | if isinstance(paths, str): 57 | os.makedirs(paths, exist_ok=True) 58 | else: 59 | for path in paths: 60 | os.makedirs(path, exist_ok=True) 61 | 62 | 63 | def get_timestamp(): 64 | return datetime.now().strftime('%y%m%d_%H%M%S') 65 | 66 | 67 | def write_json(content, fname): 68 | fname = Path(fname) 69 | with fname.open('wt') as handle: 70 | json.dump(content, handle, indent=4, sort_keys=False) 71 | 72 | 73 | class NoneDict(dict): 74 | def __missing__(self, key): 75 | return None 76 | 77 | 78 | def dict_to_nonedict(opt): 79 | """ convert to NoneDict, which return None for missing key. 
""" 80 | if isinstance(opt, dict): 81 | new_opt = dict() 82 | for key, sub_opt in opt.items(): 83 | new_opt[key] = dict_to_nonedict(sub_opt) 84 | return NoneDict(**new_opt) 85 | elif isinstance(opt, list): 86 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 87 | else: 88 | return opt 89 | 90 | 91 | def dict2str(opt, indent_l=1): 92 | """ dict to string for logger """ 93 | msg = '' 94 | for k, v in opt.items(): 95 | if isinstance(v, dict): 96 | msg += ' ' * (indent_l * 2) + k + ':[\n' 97 | msg += dict2str(v, indent_l + 1) 98 | msg += ' ' * (indent_l * 2) + ']\n' 99 | else: 100 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 101 | return msg 102 | 103 | 104 | def parse(args): 105 | json_str = '' 106 | with open(args.config, 'r') as f: 107 | for line in f: 108 | line = line.split('//')[0] + '\n' 109 | json_str += line 110 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 111 | 112 | ''' replace the config context using args ''' 113 | opt['phase'] = args.phase 114 | if args.gpu is not None: 115 | opt['gpu_ids'] = [int(id) for id in args.gpu.split(',')] 116 | if args.batch is not None: 117 | opt['datasets'][opt['phase']]['dataloader']['args']['batch_size'] = args.batch 118 | if args.path is not None: 119 | opt['datasets'][opt['phase']]['which_dataset']['args']['data_root'] = args.path 120 | if args.z_times is not None: 121 | opt['datasets'][opt['phase']]['which_dataset']['args']['z_times'] = args.z_times 122 | if args.lr is not None: 123 | opt['model']['which_model']['args']['optimizers'][0]['lr'] = args.lr 124 | if args.step is not None: 125 | opt['model']['which_networks'][0]['args']['beta_schedule'][opt['phase']]['n_timestep'] = args.step 126 | ''' set cuda environment ''' 127 | if len(opt['gpu_ids']) > 1: 128 | opt['distributed'] = True 129 | else: 130 | opt['distributed'] = False 131 | 132 | ''' update name ''' 133 | if args.debug: 134 | opt['name'] = 'debug_{}'.format(opt['name']) 135 | elif opt['finetune_norm']: 136 | opt['name'] = 'finetune_{}'.format(opt['name']) 137 | else: 138 | opt['name'] = '{}_{}'.format(opt['phase'], opt['name']) 139 | 140 | ''' set log directory ''' 141 | experiments_root = os.path.join(opt['path']['base_dir'], '{}_{}'.format(opt['name'], get_timestamp())) 142 | mkdirs(experiments_root) 143 | print('results and model will be saved in {}'.format(experiments_root)) 144 | ''' save json ''' 145 | write_json(opt, '{}/config.json'.format(experiments_root)) 146 | 147 | ''' change folder relative hierarchy ''' 148 | opt['path']['experiments_root'] = experiments_root 149 | for key, path in opt['path'].items(): 150 | if 'resume' not in key and 'base' not in key and 'root' not in key: 151 | opt['path'][key] = os.path.join(experiments_root, path) 152 | mkdirs(opt['path'][key]) 153 | if args.resume is not None: 154 | opt['path']['resume_state'] = args.resume 155 | 156 | ''' debug mode ''' 157 | if 'debug' in opt['name']: 158 | opt['train'].update(opt['debug']) 159 | 160 | ''' code backup ''' 161 | for name in os.listdir('.'): 162 | if name in ['config', 'models', 'core', 'slurm', 'data']: 163 | dst = os.path.join(opt['path']['code'], name) 164 | if os.path.exists(dst): 165 | shutil.rmtree(dst) 166 | shutil.copytree(name, dst) 167 | # shutil.copytree(name,dst , ignore=shutil.ignore_patterns("*.pyc", "__pycache__")) 168 | if '.py' in name or '.sh' in name: 169 | shutil.copy(name, opt['path']['code']) 170 | opt['mean'] = args.mean 171 | return dict_to_nonedict(opt) 172 | -------------------------------------------------------------------------------- 
/core/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import math 4 | import torch 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | from torchvision.utils import make_grid 7 | import os 8 | import cv2 9 | 10 | 11 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1), norm=True): 12 | ''' 13 | Converts a torch Tensor into an image Numpy array 14 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 15 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 16 | ''' 17 | tensor = tensor.clamp_(*min_max) # clamp 18 | n_dim = tensor.dim() 19 | if n_dim == 4: 20 | n_img = len(tensor) 21 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 23 | elif n_dim == 3: 24 | img_np = tensor.numpy() 25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 26 | elif n_dim == 2: 27 | img_np = tensor.numpy() 28 | else: 29 | raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 30 | img_np = ((img_np + 1) * 127.5).round() 31 | return img_np.astype(out_type).squeeze() 32 | 33 | 34 | def postprocess(images, out_type=np.uint8, min_max=(-1, 1), norm=True): 35 | return [tensor2img(image, out_type, min_max, norm) for image in images] 36 | 37 | 38 | def normalize_tensor(tensor, min_max=(-1, 1)): 39 | tensor = tensor.float().clamp_(*min_max) # clamp 40 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 41 | return tensor 42 | 43 | 44 | def set_seed(seed, gl_seed=0): 45 | """ set random seed, gl_seed used in worker_init_fn function """ 46 | if seed >= 0 and gl_seed >= 0: 47 | seed += gl_seed 48 | torch.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | 53 | ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. 
54 | speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' 55 | if seed >= 0 and gl_seed >= 0: # slower, more reproducible 56 | torch.backends.cudnn.deterministic = True 57 | torch.backends.cudnn.benchmark = False 58 | else: # faster, less reproducible 59 | torch.backends.cudnn.deterministic = False 60 | torch.backends.cudnn.benchmark = True 61 | 62 | 63 | def set_gpu(args, distributed=False, rank=0): 64 | """ set parameter to gpu or ddp """ 65 | if args is None: 66 | return None 67 | if distributed and isinstance(args, torch.nn.Module): 68 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, 69 | find_unused_parameters=True) 70 | else: 71 | return args.cuda() 72 | 73 | 74 | def set_device(args, distributed=False, rank=0): 75 | """ set parameter to gpu or cpu """ 76 | if torch.cuda.is_available(): 77 | if isinstance(args, list): 78 | return (set_gpu(item, distributed, rank) for item in args) 79 | elif isinstance(args, dict): 80 | return {key: set_gpu(args[key], distributed, rank) for key in args} 81 | else: 82 | args = set_gpu(args, distributed, rank) 83 | return args 84 | 85 | 86 | -------------------------------------------------------------------------------- /crop_single_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import numpy as np 5 | import cv2 6 | from tifffile import imwrite 7 | 8 | def mkdir(path): 9 | if os.path.exists(path): 10 | shutil.rmtree(path) 11 | os.mkdir(path) 12 | 13 | 14 | def crop(wf_img, save_wf_path, patch_size=256, overlap=0.125): 15 | if len(wf_img.shape) > 2: 16 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY) 17 | if wf_img.dtype == np.uint16: # convert to 8 bit 18 | cmin = wf_img.min() 19 | cmax = wf_img.max() 20 | cscale = cmax - cmin 21 | if cscale == 0: 22 | cscale = 1 23 | scale = 255.0 / cscale 24 | wf_img = (wf_img - cmin) * scale 25 | wf_img = np.clip(wf_img, 0, 255) + 0.5 26 | wf_img = wf_img.astype(np.uint8) 27 | stride = int(patch_size * (1 - overlap)) 28 | if len(wf_img.shape) > 2: # defensive; the image is already grayscale at this point 29 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY) 30 | border = 0 31 | 32 | x = border 33 | x_end = wf_img.shape[0] - border 34 | y_end = wf_img.shape[1] - border # fix: width bounds the column sweep for non-square images 35 | row = 0 36 | while x + patch_size < x_end: 37 | y = border 38 | col = 0 39 | while y + patch_size < y_end: 40 | crop_wf_img = wf_img[x: x + patch_size, y: y + patch_size] 41 | imwrite(os.path.join(save_wf_path, str(row) + '_' + str(col) + '.tif'), 42 | crop_wf_img) 43 | col += 1 44 | y += stride 45 | row += 1 46 | x += stride 47 | 48 | 49 | def test_pre(data_root, task='denoise'): 50 | target_path = os.path.join(data_root, task + '_test_crop_patches') 51 | mkdir(target_path) 52 | for file in os.listdir(data_root): 53 | if not file.endswith('tif'): 54 | continue 55 | mkdir(os.path.join(target_path, file[:-4])) 56 | save_wf_path = os.path.join(target_path, file[:-4], '0') 57 | mkdir(save_wf_path) 58 | wf_file_img = cv2.imread(os.path.join(data_root, file)) 59 | if task == 'denoise': 60 | crop(wf_file_img, save_wf_path, patch_size=256, overlap=0.125) 61 | else: 62 | crop(wf_file_img, save_wf_path, patch_size=128, overlap=0.125) 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--task', default="denoise") 68 | parser.add_argument('--path', help="dataset for evaluation") 69 | args = parser.parse_args() 70 | test_pre(args.path, args.task) 71 |
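A quick sanity check of the cropping geometry above: with patch_size=256 and overlap=0.125 the stride is 256 * (1 - 0.125) = 224 px, and because the loop uses a strict '<', a patch is emitted only if it fits strictly inside the image, so any trailing strip (even a patch flush with the edge) is dropped. A minimal sketch reproducing just the loop bound:

def n_patches(length, patch_size=256, overlap=0.125):
    # Same bound as crop() above: count patches that fit strictly inside.
    stride = int(patch_size * (1 - overlap))
    count, pos = 0, 0
    while pos + patch_size < length:
        count += 1
        pos += stride
    return count

print(n_patches(2048))  # -> 8 per axis, i.e. 64 patches for a 2048x2048 image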
-------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch import Generator, randperm 6 | from torch.utils.data import DataLoader, Subset 7 | 8 | import core.util as Util 9 | from core.praser import init_obj 10 | from vEM_test_pre import recon_pre 11 | 12 | 13 | 14 | def define_dataloader(logger, opt): 15 | """ create train/test dataloader and validation dataloader, validation dataloader is None when phase is test or not GPU 0 """ 16 | '''create dataset and set random seed''' 17 | dataloader_args = opt['datasets'][opt['phase']]['dataloader']['args'] 18 | worker_init_fn = partial(Util.set_seed, gl_seed=opt['seed']) 19 | 20 | phase_dataset, val_dataset = define_dataset(logger, opt) 21 | 22 | '''create datasampler''' 23 | data_sampler = None 24 | if opt['distributed']: 25 | data_sampler = DistributedSampler(phase_dataset, shuffle=dataloader_args.get('shuffle', False), 26 | num_replicas=opt['world_size'], rank=opt['global_rank']) 27 | dataloader_args.update({'shuffle': False}) # sampler option is mutually exclusive with shuffle 28 | ''' create dataloader and validation dataloader ''' 29 | dataloader = DataLoader(phase_dataset, sampler=data_sampler, worker_init_fn=worker_init_fn, **dataloader_args) 30 | 31 | ''' val_dataloader don't use DistributedSampler to run only GPU 0! ''' 32 | if opt['global_rank'] == 0 and val_dataset is not None: 33 | dataloader_args.update(opt['datasets'][opt['phase']]['dataloader'].get('val_args', {})) 34 | val_dataloader = DataLoader(val_dataset, worker_init_fn=worker_init_fn, **dataloader_args) 35 | else: 36 | val_dataloader = None 37 | return dataloader, val_dataloader 38 | 39 | 40 | def define_dataset(logger, opt): 41 | ''' loading Dataset() class from given file's name ''' 42 | dataset_opt = opt['datasets'][opt['phase']]['which_dataset'] 43 | if opt['phase'] != 'train': 44 | if opt['task'] == '3d_reconstruction': 45 | dataset_opt['args']['data_root'] = recon_pre(dataset_opt['args']['data_root']) 46 | 47 | phase_dataset = init_obj(dataset_opt, logger, default_file_name='data.dataset', init_type='Dataset') 48 | val_dataset = None 49 | 50 | valid_len = 0 51 | data_len = len(phase_dataset) 52 | if 'debug' in opt['name']: 53 | debug_split = opt['debug'].get('debug_split', 1.0) 54 | if isinstance(debug_split, int): 55 | data_len = debug_split 56 | else: 57 | data_len *= debug_split 58 | 59 | dataloder_opt = opt['datasets'][opt['phase']]['dataloader'] 60 | valid_split = dataloder_opt.get('validation_split', 0) 61 | 62 | ''' divide validation dataset, valid_split==0 when phase is test or validation_split is 0. ''' 63 | if valid_split > 0.0 or 'debug' in opt['name']: 64 | if isinstance(valid_split, int): 65 | assert valid_split < data_len, "Validation set size is configured to be larger than entire dataset." 
66 | valid_len = valid_split 67 | else: 68 | valid_len = int(data_len * valid_split) 69 | data_len -= valid_len 70 | phase_dataset, val_dataset = subset_split(dataset=phase_dataset, lengths=[data_len, valid_len], 71 | generator=Generator().manual_seed(opt['seed'])) 72 | 73 | logger.info('Dataset for {} have {} samples.'.format(opt['phase'], data_len)) 74 | if opt['phase'] == 'train': 75 | logger.info('Dataset for {} have {} samples.'.format('val', valid_len)) 76 | return phase_dataset, val_dataset 77 | 78 | 79 | def subset_split(dataset, lengths, generator): 80 | """ 81 | split a dataset into non-overlapping new datasets of given lengths. main code is from random_split function in pytorch 82 | """ 83 | indices = randperm(sum(lengths), generator=generator).tolist() 84 | Subsets = [] 85 | for offset, length in zip(np.add.accumulate(lengths), lengths): 86 | if length == 0: 87 | Subsets.append(None) 88 | else: 89 | Subsets.append(Subset(dataset, indices[offset - length: offset])) 90 | return Subsets 91 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /demo/denoise_demo.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/denoise_demo.tif -------------------------------------------------------------------------------- /demo/microns_demo/0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/microns_demo/0.tif -------------------------------------------------------------------------------- /demo/microns_demo/1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/microns_demo/1.tif -------------------------------------------------------------------------------- /demo/mouse_liver_demo/0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/mouse_liver_demo/0.tif -------------------------------------------------------------------------------- /demo/mouse_liver_demo/1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/mouse_liver_demo/1.tif -------------------------------------------------------------------------------- /demo/super_res_demo.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/super_res_demo.tif -------------------------------------------------------------------------------- /emdiffuse_conifg.py: -------------------------------------------------------------------------------- 1 | class EMDiffuseConfig(): 2 | 3 | def __init__(self, config, path, phase, batch_size, lr=5e-5, resume=None, gpu='0', subsample=None, port='21012', mean=2, step=None): 4 | self.path = path 5 | self.config = config 6 | self.phase = phase 7 | self.batch = batch_size 8 | self.gpu = gpu 9 | self.debug = False 10 | self.z_times = subsample 11 | self.port = port 12 | self.resume = resume 13 | self.mean = mean 14 | self.lr = lr 15 | self.step=step 16 | 17 | def __getattr__(self, item): 18 | # This method is called when an attribute access is attempted. 19 | try: 20 | return self.__dict__[item] 21 | except KeyError: 22 | return None 23 | 24 | def __setattr__(self, key, value): 25 | # This method allows setting attributes directly. 26 | self.__dict__[key] = value 27 | 28 | def __contains__(self, item): 29 | # This enables the use of 'in' to check for attribute existence. 30 | return item in self.__dict__ 31 | -------------------------------------------------------------------------------- /models/EMDiffuse_network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from inspect import isfunction 4 | from functools import partial 5 | import numpy as np 6 | from tqdm import tqdm 7 | from core.base_network import BaseNetwork 8 | 9 | 10 | class Network(BaseNetwork): 11 | def __init__(self, unet, beta_schedule, norm=True, module_name='sr3', **kwargs): 12 | super(Network, self).__init__(**kwargs) 13 | 14 | from .guided_diffusion_modules.unet import UNet 15 | self.denoise_fn = UNet(**unet) 16 | self.beta_schedule = beta_schedule 17 | self.norm = norm 18 | 19 | def set_loss(self, loss_fn): 20 | self.loss_fn = loss_fn 21 | 22 | def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'): 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 25 | betas = make_beta_schedule(**self.beta_schedule[phase]) 26 | betas = betas.detach().cpu().numpy() if isinstance( 27 | betas, torch.Tensor) else betas 28 | alphas = 1. - betas 29 | 30 | timesteps, = betas.shape 31 | self.num_timesteps = int(timesteps) 32 | 33 | gammas = np.cumprod(alphas, axis=0) 34 | gammas_prev = np.append(1., gammas[:-1]) 35 | 36 | # calculations for diffusion q(x_t | x_{t-1}) and others 37 | self.register_buffer('gammas', to_torch(gammas)) 38 | self.register_buffer('sqrt_recip_gammas', to_torch(np.sqrt(1. / gammas))) 39 | self.register_buffer('sqrt_recipm1_gammas', to_torch(np.sqrt(1. / gammas - 1))) 40 | 41 | # calculations for posterior q(x_{t-1} | x_t, x_0) 42 | posterior_variance = betas * (1. - gammas_prev) / (1. - gammas) 43 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 44 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 45 | self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(gammas_prev) / (1. - gammas))) 46 | self.register_buffer('posterior_mean_coef2', to_torch((1. - gammas_prev) * np.sqrt(alphas) / (1. 
- gammas))) 47 | 48 | def predict_start_from_noise(self, y_t, t, noise): 49 | return ( 50 | extract(self.sqrt_recip_gammas, t, y_t.shape) * y_t - 51 | extract(self.sqrt_recipm1_gammas, t, y_t.shape) * noise 52 | ) 53 | 54 | def q_posterior(self, y_0_hat, y_t, t): 55 | posterior_mean = ( 56 | extract(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat + 57 | extract(self.posterior_mean_coef2, t, y_t.shape) * y_t 58 | ) 59 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, y_t.shape) 60 | return posterior_mean, posterior_log_variance_clipped 61 | 62 | def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None): 63 | noise_level = extract(self.gammas, t, x_shape=(1, 1)).to(y_t.device) 64 | y_0_hat = self.predict_start_from_noise( 65 | y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)) 66 | 67 | if clip_denoised: # todo: clip 68 | if self.norm: 69 | y_0_hat.clamp_(-1., 1.) 70 | else: 71 | y_0_hat.clamp_(0., 1.) 72 | 73 | model_mean, posterior_log_variance = self.q_posterior( 74 | y_0_hat=y_0_hat, y_t=y_t, t=t) 75 | return model_mean, posterior_log_variance, y_0_hat 76 | 77 | def q_sample(self, y_0, sample_gammas, noise=None): 78 | noise = default(noise, lambda: torch.randn_like(y_0)) 79 | return ( 80 | sample_gammas.sqrt() * y_0 + 81 | (1 - sample_gammas).sqrt() * noise 82 | ) 83 | 84 | @torch.no_grad() 85 | def p_sample(self, y_t, t, clip_denoised=True, y_cond=None, adjust=False): 86 | model_mean, model_log_variance, y_0_hat = self.p_mean_variance( 87 | y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond) 88 | 89 | noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t) 90 | if adjust: 91 | if t[0] < (self.num_timesteps * 0.2): 92 | mean_diff = model_mean.view(model_mean.size(0), -1).mean(1) - y_cond.view(y_cond.size(0), -1).mean(1) 93 | mean_diff = mean_diff.view(model_mean.size(0), 1, 1, 1) 94 | model_mean = model_mean - 0.5 * mean_diff.repeat( 95 | (1, model_mean.shape[1], model_mean.shape[2], model_mean.shape[3])) 96 | return model_mean + noise * (0.5 * model_log_variance).exp(), y_0_hat 97 | 98 | @torch.no_grad() 99 | def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8, adjust=False): 100 | b, *_ = y_cond.shape 101 | 102 | assert self.num_timesteps > sample_num, 'num_timesteps must greater than sample_num' 103 | sample_inter = (self.num_timesteps // sample_num) 104 | if y_0 is not None: 105 | y_t = default(y_t, lambda: torch.randn_like(y_0)) 106 | else: 107 | y_t = default(y_t, lambda: torch.randn_like(y_cond)) 108 | ret_arr = y_t 109 | 110 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 111 | t = torch.full((b,), i, device=y_cond.device, dtype=torch.long) 112 | y_t, y_0_hat = self.p_sample(y_t, t, y_cond=y_cond, adjust=adjust) 113 | if mask is not None: 114 | y_t = y_0 * (1. 
- mask) + mask * y_t 115 | if i % sample_inter == 0: 116 | ret_arr = torch.cat([ret_arr, y_0_hat], dim=0) 117 | return y_t, ret_arr 118 | 119 | def forward(self, y_0, y_cond=None, mask=None, noise=None): 120 | # sampling from p(gammas) 121 | b, _, _, _ = y_0.shape 122 | t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long() 123 | gamma_t1 = extract(self.gammas, t - 1, x_shape=(1, 1)) 124 | sqrt_gamma_t2 = extract(self.gammas, t, x_shape=(1, 1)) 125 | sample_gammas = (sqrt_gamma_t2 - gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1 # Todo: why 126 | sample_gammas = sample_gammas.view(b, -1) 127 | if noise is None: 128 | noise = torch.randn_like(y_0) 129 | # noise = default(noise, lambda: torch.randn_like(y_0)) 130 | y_noisy = self.q_sample( 131 | y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise) 132 | 133 | if mask is not None: 134 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy * mask + (1. - mask) * y_0], dim=1), sample_gammas) 135 | loss = self.loss_fn(mask * noise, mask * noise_hat) 136 | else: 137 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas) 138 | loss = self.loss_fn(noise_hat, noise) 139 | return loss 140 | 141 | 142 | # gaussian diffusion trainer class 143 | def exists(x): 144 | return x is not None 145 | 146 | 147 | def default(val, d): 148 | if exists(val): 149 | return val 150 | return d() if isfunction(d) else d 151 | 152 | 153 | def extract(a, t, x_shape=(1, 1, 1, 1)): 154 | b, *_ = t.shape 155 | out = a.gather(-1, t) 156 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 157 | 158 | 159 | # beta_schedule function 160 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 161 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 162 | warmup_time = int(n_timestep * warmup_frac) 163 | betas[:warmup_time] = np.linspace( 164 | linear_start, linear_end, warmup_time, dtype=np.float64) 165 | return betas 166 | 167 | 168 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3): 169 | if schedule == 'quad': 170 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 171 | n_timestep, dtype=np.float64) ** 2 172 | elif schedule == 'linear': 173 | betas = np.linspace(linear_start, linear_end, 174 | n_timestep, dtype=np.float64) 175 | elif schedule == 'warmup10': 176 | 177 | betas = _warmup_beta(linear_start, linear_end, 178 | n_timestep, 0.1) 179 | elif schedule == 'warmup50': 180 | betas = _warmup_beta(linear_start, linear_end, 181 | n_timestep, 0.5) 182 | elif schedule == 'const': 183 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 184 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 185 | betas = 1. 
/ np.linspace(n_timestep, 186 | 1, n_timestep, dtype=np.float64) 187 | elif schedule == "cosine": 188 | timesteps = ( 189 | torch.arange(n_timestep + 1, dtype=torch.float64) / 190 | n_timestep + cosine_s 191 | ) 192 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 193 | alphas = torch.cos(alphas).pow(2) 194 | alphas = alphas / alphas[0] 195 | betas = 1 - alphas[1:] / alphas[:-1] 196 | betas = betas.clamp(max=0.999) 197 | else: 198 | raise NotImplementedError(schedule) 199 | return betas 200 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from core.praser import init_obj 2 | import torch 3 | import warnings 4 | from core.logger import VisualWriter, InfoLogger 5 | import core.praser as Praser 6 | import core.util as Util 7 | from data import define_dataloader 8 | 9 | def create_model(**cfg_model): 10 | """ create_model """ 11 | opt = cfg_model['opt'] 12 | logger = cfg_model['logger'] 13 | 14 | model_opt = opt['model']['which_model'] 15 | model_opt['args'].update(cfg_model) 16 | model = init_obj(model_opt, logger, default_file_name='models.model', init_type='Model') 17 | 18 | return model 19 | 20 | 21 | def define_network(logger, opt, network_opt): 22 | """ define network with weights initialization """ 23 | net = init_obj(network_opt, logger, default_file_name='models.network', init_type='Network') 24 | 25 | if opt['phase'] == 'train': 26 | logger.info('Network [{}] weights initialize using [{:s}] method.'.format(net.__class__.__name__, 27 | network_opt['args'].get('init_type', 28 | 'default'))) 29 | net.init_weights() 30 | return net 31 | 32 | 33 | def define_loss(logger, loss_opt): 34 | return init_obj(loss_opt, logger, default_file_name='models.loss', init_type='Loss') 35 | 36 | 37 | def define_metric(logger, metric_opt): 38 | return init_obj(metric_opt, logger, default_file_name='models.metric', init_type='Metric') 39 | 40 | 41 | def create_EMDiffuse(opt): 42 | gpu=0 43 | if 'local_rank' not in opt: 44 | opt['local_rank'] = opt['global_rank'] = gpu 45 | if opt['distributed']: 46 | torch.cuda.set_device(int(opt['local_rank'])) 47 | print('using GPU {} for training'.format(int(opt['local_rank']))) 48 | torch.distributed.init_process_group(backend='nccl', 49 | init_method=opt['init_method'], 50 | world_size=opt['world_size'], 51 | rank=opt['global_rank'], 52 | group_name='mtorch' 53 | ) 54 | '''set seed and and cuDNN environment ''' 55 | torch.backends.cudnn.enabled = False 56 | # warnings.warn('You have chosen to use cudnn for accleration. torch.backends.cudnn.enabled=True') 57 | Util.set_seed(opt['seed']) 58 | 59 | ''' set logger ''' 60 | phase_logger = InfoLogger(opt) 61 | phase_writer = VisualWriter(opt, phase_logger) 62 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root'])) 63 | 64 | '''set networks and dataset''' 65 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test. 
66 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']] 67 | 68 | ''' set metrics, loss, optimizer and schedulers ''' 69 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']] 70 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']] 71 | 72 | model = create_model( 73 | opt=opt, 74 | networks=networks, 75 | phase_loader=phase_loader, 76 | val_loader=val_loader, 77 | losses=losses, 78 | metrics=metrics, 79 | logger=phase_logger, 80 | writer=phase_writer 81 | ) 82 | return model 83 | 84 |
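For orientation, a minimal single-GPU driver for create_EMDiffuse. This is a sketch only: launch_emdiffuse is a hypothetical helper, opt is assumed to be the parsed config dict that run.py builds via core.praser.parse(args), and the train/test dispatch mirrors run.py:

from models import create_EMDiffuse

def launch_emdiffuse(opt):
    # single-process path: skips the torch.distributed branch inside create_EMDiffuse
    opt['distributed'] = False
    model = create_EMDiffuse(opt)  # wires dataloaders, networks, losses and metrics together
    if opt['phase'] == 'train':
        model.train()
    else:
        model.test()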
-------------------------------------------------------------------------------- /models/guided_diffusion_modules/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class GroupNorm32(nn.GroupNorm): 12 | def forward(self, x): 13 | return super().forward(x.float()).type(x.dtype) 14 | 15 | 16 | def zero_module(module): 17 | """ 18 | Zero out the parameters of a module and return it. 19 | """ 20 | for p in module.parameters(): 21 | p.detach().zero_() 22 | return module 23 | 24 | 25 | def scale_module(module, scale): 26 | """ 27 | Scale the parameters of a module and return it. 28 | """ 29 | for p in module.parameters(): 30 | p.detach().mul_(scale) 31 | return module 32 | 33 | 34 | def mean_flat(tensor): 35 | """ 36 | Take the mean over all non-batch dimensions. 37 | """ 38 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 39 | 40 | 41 | def normalization(channels, group_num): 42 | """ 43 | Make a standard group normalization layer. 44 | :param channels: number of input channels. 45 | :param group_num: number of groups. 46 | :return: an nn.Module for normalization. 47 | """ 48 | 49 | # return GroupNorm32(group_num, channels) 50 | return nn.GroupNorm(group_num, channels) # plain GroupNorm in place of the fp32 GroupNorm32 wrapper above 51 | 52 | def Layernormalization(channels): 53 | """ 54 | Make a layer normalization layer. 55 | 56 | :param channels: number of input channels. 57 | :return: an nn.Module for normalization. 58 | """ 59 | 60 | return nn.LayerNorm(channels) 61 | 62 | def checkpoint(func, inputs, params, flag): 63 | """ 64 | Evaluate a function without caching intermediate activations, allowing for 65 | reduced memory at the expense of extra compute in the backward pass. 66 | 67 | :param func: the function to evaluate. 68 | :param inputs: the argument sequence to pass to `func`. 69 | :param params: a sequence of parameters `func` depends on but does not 70 | explicitly take as arguments. 71 | :param flag: if False, disable gradient checkpointing. 72 | """ 73 | if flag: 74 | args = tuple(inputs) + tuple(params) 75 | return CheckpointFunction.apply(func, len(inputs), *args) 76 | else: 77 | return func(*inputs) 78 | 79 | 80 | class CheckpointFunction(torch.autograd.Function): 81 | @staticmethod 82 | def forward(ctx, run_function, length, *args): 83 | ctx.run_function = run_function 84 | ctx.input_tensors = list(args[:length]) 85 | ctx.input_params = list(args[length:]) 86 | with torch.no_grad(): 87 | output_tensors = ctx.run_function(*ctx.input_tensors) 88 | return output_tensors 89 | 90 | @staticmethod 91 | def backward(ctx, *output_grads): 92 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 93 | with torch.enable_grad(): 94 | # Fixes a bug where the first op in run_function modifies the 95 | # Tensor storage in place, which is not allowed for detach()'d 96 | # Tensors. 97 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 98 | output_tensors = ctx.run_function(*shallow_copies) 99 | input_grads = torch.autograd.grad( 100 | output_tensors, 101 | ctx.input_tensors + ctx.input_params, 102 | output_grads, 103 | allow_unused=True, 104 | ) 105 | del ctx.input_tensors 106 | del ctx.input_params 107 | del output_tensors 108 | return (None, None) + input_grads 109 | 110 | 111 | def count_flops_attn(model, _x, y): 112 | """ 113 | A counter for the `thop` package to count the operations in an 114 | attention operation.
115 | Meant to be used like: 116 | macs, params = thop.profile( 117 | model, 118 | inputs=(inputs, timestamps), 119 | custom_ops={QKVAttention: QKVAttention.count_flops}, 120 | ) 121 | """ 122 | b, c, *spatial = y[0].shape 123 | num_spatial = int(np.prod(spatial)) 124 | # We perform two matmuls with the same number of ops. 125 | # The first computes the weight matrix, the second computes 126 | # the combination of the value vectors. 127 | matmul_ops = 2 * b * (num_spatial ** 2) * c 128 | model.total_ops += torch.DoubleTensor([matmul_ops]) 129 | 130 | 131 | def gamma_embedding(gammas, dim:int, max_period:int=10000): 132 | """ 133 | Create sinusoidal timestep embeddings. 134 | :param gammas: a 1-D Tensor of N indices, one per batch element. 135 | These may be fractional. 136 | :param dim: the dimension of the output. 137 | :param max_period: controls the minimum frequency of the embeddings. 138 | :return: an [N x dim] Tensor of positional embeddings. 139 | """ 140 | half = dim // 2 141 | freqs = torch.exp( 142 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 143 | ).to(device=gammas.device) 144 | args = gammas[:, None].float() * freqs[None] 145 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 146 | if dim % 2: 147 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 148 | return embedding -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | # class mse_loss(nn.Module): 8 | # def __init__(self) -> None: 9 | # super().__init__() 10 | # self.loss_fn = nn.MSELoss() 11 | # def forward(self, output, target): 12 | # return self.loss_fn(output, target) 13 | 14 | 15 | def mse_loss(output, target): 16 | return F.mse_loss(output, target) 17 | 18 | 19 | def l1_loss(output, target): 20 | return F.l1_loss(output, target) 21 | 22 | 23 | def loss_predict_loss(out, target, pred_loss): 24 | target_loss = F.mse_loss(out, target, reduction='none') 25 | return torch.sum(target_loss) / ( 26 | target.shape[0] * target.shape[1] * target.shape[2] * target.shape[3]), LossPredLoss(pred_loss, 27 | target_loss) 28 | 29 | 30 | def pin_loss(q_upper, q_lower, target): 31 | q_lo_loss = PinballLoss(0.05) 32 | q_hi_loss = PinballLoss(0.95) 33 | loss = q_lo_loss(q_lower, target) + q_hi_loss(q_upper, target) 34 | return loss 35 | 36 | 37 | def SampleLossPredLoss(input, target, margin=1.0, reduction='mean'): 38 | # input: (b, w * h) 39 | 40 | b = input.shape[0] 41 | target = target.detach() 42 | target = target.view(b, -1) 43 | target = torch.mean(target, dim=1) 44 | input = input.view(b, -1) 45 | input = torch.mean(input, dim=1) 46 | assert input.shape[0] % 2 == 0, 'the batch size is not even.' 47 | assert input.shape == input.flip(0).shape 48 | input = (input - input.flip(0))[ 49 | :input.shape[0] // 2] # [l_1 - l_2B, l_2 - l_2B-1, ... 
, l_B - l_B+1], where batch_size = 2B 50 | target = (target - target.flip(0))[:target.shape[0] // 2] 51 | target = target.detach() 52 | one = 2 * torch.sign(torch.clamp(target, min=0)) - 1 # +1/-1 ranking sign derived from the sign of the target-loss difference 53 | if reduction == 'mean': 54 | loss = torch.sum(torch.clamp(margin - one * input, min=0)) 55 | loss = loss / (input.size(0)) # Note that the size of input is already halved 56 | elif reduction == 'none': 57 | loss = torch.clamp(margin - one * input, min=0) 58 | else: 59 | raise NotImplementedError(reduction) 60 | return loss 61 | 62 | 63 | def LossPredLoss(input, target, margin=1.0, reduction='mean'): 64 | # input: (b, w * h) 65 | 66 | b = input.shape[0] 67 | target = target.view(b, -1) 68 | input = input.view(b, -1) 69 | assert input.shape[1] % 2 == 0, 'the flattened feature dimension is not even.' 70 | assert input.shape == input.flip(1).shape 71 | index_shuffle = torch.randperm(input.shape[1]) 72 | 73 | input = input[:, index_shuffle] 74 | target = target[:, index_shuffle] 75 | input = (input - input.flip(1))[:, 76 | :input.shape[1] // 2] # [l_1 - l_2B, l_2 - l_2B-1, ... , l_B - l_B+1], where batch_size = 2B 77 | target = (target - target.flip(1))[:, :target.shape[1] // 2] 78 | target = target.detach() 79 | one = 2 * torch.sign(torch.clamp(target, min=0)) - 1 # +1/-1 ranking sign derived from the sign of the target-loss difference 80 | if reduction == 'mean': 81 | loss = torch.sum(torch.clamp(margin - one * input, min=0)) 82 | loss = loss / (input.size(0) * input.size(1)) # Note that the size of input is already halved 83 | elif reduction == 'none': 84 | loss = torch.clamp(margin - one * input, min=0) 85 | else: 86 | raise NotImplementedError(reduction) 87 | return loss 88 | 89 | 90 | def pin_loss2(q_lower, q_upper, out, target): 91 | q_lo_loss = PinballLoss(0.05) 92 | q_hi_loss = PinballLoss(0.95) 93 | loss = q_lo_loss(q_lower, target) + q_hi_loss(q_upper, target) + mse_loss(out, target) 94 | return loss 95 | 96 | 97 | def mse_var_loss(output, target, variance, weight=1): 98 | variance = weight * variance 99 | loss1 = torch.mul(torch.exp(-variance), (output - target) ** 2) 100 | loss2 = variance 101 | loss = .5 * (loss1 + loss2) 102 | return loss.mean() 103 | 104 | def mse_var_loss2(output, target, variance, var_weight): 105 | # print((1-var_weight).max(), (1-var_weight).min()) 106 | variance = variance * torch.clamp(var_weight, min=1e-2, max=1) 107 | loss1 = torch.mul(torch.exp(-variance), (output - target) ** 2) 108 | loss2 = variance 109 | loss = .5 * (loss1 + loss2) 110 | return loss.mean() 111 | 112 | 113 | def mse_var_loss_sample(output, target, variance, weight=1): 114 | # variance = 4 * variance 115 | target_loss = (output - target) ** 2 116 | loss1 = torch.mul(torch.exp(-variance), target_loss) 117 | loss2 = variance 118 | loss3 = SampleLossPredLoss(variance, target_loss, reduction='mean') 119 | var_loss = .5 * (loss1 + loss2) 120 | 121 | return var_loss.mean() + loss3 122 | 123 | 124 | 125 | class MSE_VAR(nn.Module): 126 | def __init__(self, var_weight): 127 | super(MSE_VAR, self).__init__() 128 | self.var_weight = var_weight 129 | 130 | def forward(self, results, label): 131 | mean, var = results['mean'], results['var'] 132 | var = self.var_weight * var 133 | 134 | loss1 = torch.mul(torch.exp(-var), (mean - label) ** 2) 135 | loss2 = var 136 | loss = .5 * (loss1 + loss2) 137 | return loss.mean() 138 | 139 | 140 | class PinballLoss(): 141 | 142 | def __init__(self, quantile=0.10, reduction='mean'): 143 | self.quantile = quantile 144 | assert 0 < self.quantile 145 | assert self.quantile < 1 146 | self.reduction =
reduction 147 | 148 | def __call__(self, output, target): 149 | assert output.shape == target.shape 150 | loss = torch.zeros_like(target, dtype=torch.float) 151 | error = output - target 152 | smaller_index = error < 0 153 | bigger_index = 0 < error 154 | loss[smaller_index] = self.quantile * (abs(error)[smaller_index]) 155 | loss[bigger_index] = (1 - self.quantile) * (abs(error)[bigger_index]) 156 | 157 | if self.reduction == 'sum': 158 | loss = loss.sum() 159 | if self.reduction == 'mean': 160 | loss = loss.mean() 161 | 162 | return loss 163 | 164 | 165 | class FocalLoss(nn.Module): 166 | def __init__(self, gamma=2, alpha=None, size_average=True): 167 | super(FocalLoss, self).__init__() 168 | self.gamma = gamma 169 | self.alpha = alpha 170 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) 171 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) 172 | self.size_average = size_average 173 | 174 | def forward(self, input, target): 175 | if input.dim() > 2: 176 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 177 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 178 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 179 | target = target.view(-1, 1) 180 | 181 | logpt = F.log_softmax(input, dim=1) # explicit class dimension; the implicit-dim form is deprecated 182 | logpt = logpt.gather(1, target) 183 | logpt = logpt.view(-1) 184 | pt = Variable(logpt.data.exp()) 185 | 186 | if self.alpha is not None: 187 | if self.alpha.type() != input.data.type(): 188 | self.alpha = self.alpha.type_as(input.data) 189 | at = self.alpha.gather(0, target.data.view(-1)) 190 | logpt = logpt * Variable(at) 191 | 192 | loss = -1 * (1 - pt) ** self.gamma * logpt 193 | if self.size_average: 194 | return loss.mean() 195 | else: 196 | return loss.sum() 197 | -------------------------------------------------------------------------------- /models/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from scipy.stats import entropy 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | from torchvision.models.inception import inception_v3 9 | 10 | 11 | def mae(input, target): 12 | with torch.no_grad(): 13 | loss = nn.L1Loss() 14 | output = loss(input, target) 15 | return output 16 | 17 | 18 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): 19 | """Computes the inception score of the generated images imgs 20 | 21 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 22 | cuda -- whether or not to run on GPU 23 | batch_size -- batch size for feeding into Inception v3 24 | splits -- number of splits 25 | """ 26 | N = len(imgs) 27 | 28 | assert batch_size > 0 29 | assert N > batch_size 30 | 31 | # Set up dtype 32 | if cuda: 33 | dtype = torch.cuda.FloatTensor 34 | else: 35 | if torch.cuda.is_available(): 36 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 37 | dtype = torch.FloatTensor 38 | 39 | # Set up dataloader 40 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 41 | 42 | # Load inception model 43 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) 44 | inception_model.eval() 45 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) 46 | def get_pred(x): 47 | if resize: 48 | x = up(x) 49 | x = inception_model(x) 50 | return F.softmax(x, dim=1).data.cpu().numpy() 51 | 52 | # Get predictions
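# preds accumulates the 1000-way Inception softmax output for each of the N images; the loop below fills it batch by batch, and the split loop then computes exp(mean KL(p(y|x) || p(y))) per split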
53 | preds = np.zeros((N, 1000)) 54 | 55 | for i, batch in enumerate(dataloader, 0): 56 | batch = batch.type(dtype) 57 | batchv = Variable(batch) 58 | batch_size_i = batch.size()[0] 59 | 60 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 61 | 62 | # Now compute the mean kl-div 63 | split_scores = [] 64 | 65 | for k in range(splits): 66 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 67 | py = np.mean(part, axis=0) 68 | scores = [] 69 | for i in range(part.shape[0]): 70 | pyx = part[i, :] 71 | scores.append(entropy(pyx, py)) 72 | split_scores.append(np.exp(np.mean(scores))) 73 | 74 | return np.mean(split_scores), np.std(split_scores) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.0 2 | torchvision>=0.14.0 3 | matplotlib 4 | tensorboard 5 | scipy 6 | tifffile 7 | opencv-python 8 | pandas 9 | imutils 10 | image_registration 11 | numpy>=1.23.0 12 | pytest 13 | warmup_scheduler 14 | tqdm 15 | imagecodecs -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | import torch 5 | import torch.multiprocessing as mp 6 | 7 | from core.logger import VisualWriter, InfoLogger 8 | import core.praser as Praser 9 | import core.util as Util 10 | from data import define_dataloader 11 | from models import create_model, define_network, define_loss, define_metric 12 | 13 | 14 | def main_worker(gpu, ngpus_per_node, opt): 15 | """ worker process that runs on each GPU """ 16 | if 'local_rank' not in opt: 17 | opt['local_rank'] = opt['global_rank'] = gpu 18 | if opt['distributed']: 19 | torch.cuda.set_device(int(opt['local_rank'])) 20 | print('using GPU {} for training'.format(int(opt['local_rank']))) 21 | torch.distributed.init_process_group(backend='nccl', 22 | init_method=opt['init_method'], 23 | world_size=opt['world_size'], 24 | rank=opt['global_rank'], 25 | group_name='mtorch' 26 | ) 27 | '''set seed and cuDNN environment''' 28 | torch.backends.cudnn.enabled = False 29 | warnings.warn('cuDNN is disabled (torch.backends.cudnn.enabled = False); set it to True if you want cuDNN acceleration.') 30 | Util.set_seed(opt['seed']) 31 | 32 | ''' set logger ''' 33 | phase_logger = InfoLogger(opt) 34 | phase_writer = VisualWriter(opt, phase_logger) 35 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root'])) 36 | 37 | '''set networks and dataset''' 38 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test.
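# the remainder of this worker mirrors models.create_EMDiffuse: networks, metrics and losses are built from the config and handed to create_model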
39 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']] 40 | 41 | ''' set metrics, loss, optimizer and schedulers ''' 42 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']] 43 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']] 44 | 45 | model = create_model( 46 | opt=opt, 47 | networks=networks, 48 | phase_loader=phase_loader, 49 | val_loader=val_loader, 50 | losses=losses, 51 | metrics=metrics, 52 | logger=phase_logger, 53 | writer=phase_writer 54 | ) 55 | 56 | phase_logger.info('Begin model {}.'.format(opt['phase'])) 57 | 58 | if opt['phase'] == 'train': 59 | model.train() 60 | else: 61 | model.test() 62 | 63 | phase_writer.close() 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('-c', '--config', type=str, default='config/EMDiffuse-n.json', 69 | help='JSON file for configuration') 70 | parser.add_argument('--path', type=str, default=None, help='path of the cropped patches') 71 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], help='Run train or test', default='train') 72 | parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size on every GPU') 73 | parser.add_argument('--gpu', type=str, default=None, help='the GPU devices to use') 74 | parser.add_argument('-d', '--debug', action='store_true') 75 | parser.add_argument('-z', '--z_times', default=None, type=int, help='Axial anisotropy factor of the volume EM data') 76 | parser.add_argument('-P', '--port', default='21012', type=str) 77 | parser.add_argument('--mean', type=int, default=2, 78 | help='Number of plausible solutions EMDiffuse samples from the learned ' 79 | 'distribution and averages for each input') 80 | parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate') 81 | parser.add_argument('--step', type=int, default=None, help='Number of diffusion steps; more steps generally ' 82 | 'yield better image quality. 
') 83 | parser.add_argument('--resume', type=str, default=None, 84 | help='Resume state path with the iteration number appended, e.g., experiments/EMDiffuse-n/2720') 85 | 86 | ''' parser configs ''' 87 | args = parser.parse_args() 88 | 89 | opt = Praser.parse(args) 90 | 91 | ''' cuda devices ''' 92 | gpu_str = ','.join(str(x) for x in opt['gpu_ids']) 93 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str 94 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str)) 95 | 96 | ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training''' 97 | # [Todo]: multi GPU on multi machine 98 | if opt['distributed']: 99 | ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count() 100 | opt['world_size'] = ngpus_per_node 101 | opt['init_method'] = 'tcp://127.0.0.1:' + args.port 102 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) 103 | else: 104 | opt['world_size'] = 1 105 | main_worker(0, 1, opt) 106 | -------------------------------------------------------------------------------- /test_pre.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import cv2 6 | from tifffile import imwrite 7 | 8 | 9 | def mkdir(path): 10 | if os.path.exists(path): 11 | shutil.rmtree(path) 12 | os.mkdir(path) 13 | 14 | 15 | def process_denoise_pair(wf_img, save_wf_path, patch_size=256, stride=224): 16 | # crop overlapping patches from the input image, row by row 17 | if len(wf_img.shape) > 2: 18 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY) 19 | border = 0 20 | x = border 21 | x_end = wf_img.shape[0] - border 22 | y_end = wf_img.shape[1] - border # bound y by the image width (shape[1]); shape[0] was wrong for non-square images 23 | row = 0 24 | while x + patch_size < x_end: 25 | y = border 26 | col = 0 27 | while y + patch_size < y_end: 28 | crop_wf_img = wf_img[x: x + patch_size, y: y + patch_size] 29 | 30 | imwrite(os.path.join(save_wf_path, str(row) + '_' + str(col) + '.tif'), 31 | crop_wf_img) 32 | col += 1 33 | y += stride 34 | row += 1 35 | x += stride 36 | 37 | 38 | def test_pre(data_root, task='denoise'): 39 | target_path = os.path.join(data_root, task + '_test_crop_patches') 40 | mkdir(target_path) 41 | if task == 'denoise': 42 | image_types = ['Brain__4w_04.tif', 'Brain__4w_05.tif', 'Brain__4w_06.tif', 'Brain__4w_07.tif', 43 | 'Brain__4w_08.tif'] 44 | else: 45 | image_types = ['Brain__2w_01.tif', 'Brain__2w_02.tif', 'Brain__2w_03.tif'] 46 | for region_index in os.listdir(data_root): 47 | if not region_index.isdigit(): 48 | continue 49 | mkdir(os.path.join(target_path, region_index)) 50 | for type in image_types: 51 | # mkdir(os.path.join(target_path, region_index, type)) 52 | save_wf_path = os.path.join(os.path.join(target_path, region_index, type[:-4])) 53 | mkdir(save_wf_path) 54 | print(os.path.join(data_root, region_index, type)) 55 | wf_file_img = cv2.imread(os.path.join(data_root, region_index, type)) 56 | if task == 'denoise': 57 | process_denoise_pair(wf_file_img, save_wf_path, patch_size=256, stride=224) 58 | else: 59 | process_denoise_pair(wf_file_img, save_wf_path, patch_size=128, stride=112) 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--task', default="denoise") 63 | parser.add_argument('--path', help="dataset for evaluation") 64 | args = parser.parse_args() 65 | test_pre(args.path, args.task) -------------------------------------------------------------------------------- /vEM_test_pre.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tifffile import imread, imwrite 3 | import numpy as
np 4 | import shutil 5 | def mkdir(path): 6 | if os.path.exists(path): 7 | shutil.rmtree(path) 8 | os.mkdir(path) 9 | 10 | def prctile_norm(x, min_prc=0, max_prc=100): 11 | x = np.array(x).astype(np.float64) 12 | y = (x - min_prc) / (max_prc - min_prc + 1e-7) 13 | y[y > 1] = 1 14 | y[y < 0] = 0 15 | return y * 255. 16 | 17 | def recon_pre(root_path): 18 | target_path = os.path.join(root_path, 'crop_patches') 19 | mkdir(os.path.join(target_path)) 20 | for file in os.listdir(root_path): 21 | if 'tif' not in file: 22 | continue 23 | x = 0 24 | # index = file.split('_')[1][:-4] 25 | index = file[:-4] 26 | patch_size = 256 27 | stride = 224 28 | img = imread(os.path.join(root_path, file)) 29 | col_num = 0 30 | os.makedirs(os.path.join(target_path, index)) 31 | while x + patch_size <= img.shape[0]: 32 | y = 0 33 | row_num = 0 34 | x_start = x 35 | while y + patch_size <= img.shape[1]: 36 | y_start = y 37 | patch = img[x_start:x_start + patch_size, y_start:y_start + patch_size] 38 | imwrite(os.path.join(target_path, index, str(row_num) + '_' + str(col_num) + '.tif'), patch) 39 | row_num += 1 40 | y += stride 41 | 42 | x += stride 43 | col_num += 1 44 | return os.path.join(root_path, 'crop_patches') 45 | 46 | -------------------------------------------------------------------------------- /vEMa_pre.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tifffile import imread, imwrite 4 | import argparse 5 | 6 | 7 | def find_max_number(folder_path): 8 | max_number = 0 9 | for filename in os.listdir(folder_path): 10 | if filename.endswith('.tif'): 11 | filename = filename[:-4] 12 | if not filename.isdigit(): 13 | continue 14 | number = int(filename) 15 | 16 | max_number = max(max_number, number) 17 | 18 | return max_number 19 | 20 | 21 | def mkdir(path): 22 | if os.path.exists(path): 23 | import shutil 24 | shutil.rmtree(path) 25 | os.mkdir(path) 26 | 27 | 28 | def vem_transpose(data_root): 29 | stacks = [] 30 | z_depth = find_max_number(data_root) 31 | 32 | for i in range(z_depth + 1): # slices are assumed to be numbered 0.tif .. {max}.tif, so read all max + 1 of them 33 | stacks.append(imread(os.path.join(data_root, f'{i}.tif'))) 34 | stack = np.stack(stacks) 35 | print(stack.shape) 36 | stack = stack.transpose(1, 0, 2) # swap z and y: re-slice the (z, y, x) volume along y 37 | target_file_path = os.path.join(data_root, 'transposed') 38 | mkdir(target_file_path) 39 | for i in range(stack.shape[0]): 40 | imwrite(os.path.join(target_file_path, str(i) + '.tif'), stack[i]) 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | # parser.add_argument('--task', default="denoise") 46 | parser.add_argument('--path', help="dataset for evaluation") 47 | args = parser.parse_args() 48 | vem_transpose(args.path) 49 | --------------------------------------------------------------------------------
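Taken together, the two volume-EM helpers above can be chained. A sketch only, not the repository's documented pipeline: data_root is a placeholder directory of numbered slices (0.tif, 1.tif, ...), and the 256/224 patch geometry is the one hard-coded in recon_pre:

import os
from vEMa_pre import vem_transpose
from vEM_test_pre import recon_pre

data_root = '/path/to/volume_slices'  # placeholder
vem_transpose(data_root)  # writes the y-resliced planes to <data_root>/transposed
patch_dir = recon_pre(os.path.join(data_root, 'transposed'))  # overlapping 256x256 tiles, stride 224
print('cropped patches written to', patch_dir)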