├── .gitignore
├── 3D-SR-Unet
│   ├── __init__.py
│   ├── data.py
│   ├── main.py
│   ├── model.py
│   └── train.py
├── LICENSE
├── RAFT
│   ├── .gitignore
│   ├── LICENSE
│   ├── RAFT.png
│   ├── README.md
│   ├── __init__.py
│   ├── alt_cuda_corr
│   │   ├── correlation.cpp
│   │   ├── correlation_kernel.cu
│   │   └── setup.py
│   ├── chairs_split.txt
│   ├── core
│   │   ├── __init__.py
│   │   ├── align_functions.py
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── extractor.py
│   │   ├── raft.py
│   │   ├── raftConfig.py
│   │   ├── register.py
│   │   ├── register_custom.py
│   │   ├── super_res_register.py
│   │   ├── update.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   ├── demo-frames
│   │   ├── frame_0016.png
│   │   ├── frame_0017.png
│   │   ├── frame_0018.png
│   │   ├── frame_0019.png
│   │   ├── frame_0020.png
│   │   ├── frame_0021.png
│   │   ├── frame_0022.png
│   │   ├── frame_0023.png
│   │   ├── frame_0024.png
│   │   └── frame_0025.png
│   ├── demo.py
│   ├── download_models.sh
│   ├── evaluate.py
│   ├── models
│   │   └── raft-things.pth
│   ├── train.py
│   ├── train_mixed.sh
│   └── train_standard.sh
├── README.md
├── config
│   ├── EMDiffuse-n-big.json
│   ├── EMDiffuse-n-transfer.json
│   ├── EMDiffuse-n.json
│   ├── EMDiffuse-r.json
│   ├── vEMDiffuse-a.json
│   └── vEMDiffuse-i.json
├── core
│   ├── __pycache__
│   │   ├── base_model.cpython-37.pyc
│   │   ├── base_network.cpython-37.pyc
│   │   ├── logger.cpython-37.pyc
│   │   ├── praser.cpython-37.pyc
│   │   └── util.cpython-37.pyc
│   ├── base_dataset.py
│   ├── base_model.py
│   ├── base_network.py
│   ├── calibration.py
│   ├── logger.py
│   ├── praser.py
│   └── util.py
├── crop_single_file.py
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── dataset.cpython-37.pyc
│   ├── dataset.py
│   └── util
│       ├── auto_augment.py
│       └── mask.py
├── demo
│   ├── denoise_demo.tif
│   ├── microns_demo
│   │   ├── 0.tif
│   │   └── 1.tif
│   ├── mouse_liver_demo
│   │   ├── 0.tif
│   │   └── 1.tif
│   └── super_res_demo.tif
├── emdiffuse_conifg.py
├── example
│   ├── denoise
│   │   ├── prediction.ipynb
│   │   └── training.ipynb
│   ├── super-res
│   │   ├── prediction.ipynb
│   │   └── training.ipynb
│   ├── vEMDiffuse-a
│   │   ├── prediction.ipynb
│   │   └── training.ipynb
│   └── vEMDiffuse-i
│       ├── prediction.ipynb
│       └── training.ipynb
├── models
│   ├── EMDiffuse_model.py
│   ├── EMDiffuse_network.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── EMDiffuse_model.cpython-37.pyc
│   │   ├── EMDiffuse_network.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── loss.cpython-37.pyc
│   │   └── metric.cpython-37.pyc
│   ├── guided_diffusion_modules
│   │   ├── __pycache__
│   │   │   ├── nn.cpython-37.pyc
│   │   │   ├── unet.cpython-37.pyc
│   │   │   └── unet_jit2.cpython-37.pyc
│   │   ├── nn.py
│   │   ├── unet.py
│   │   ├── unet_3d.py
│   │   ├── unet_3d_aleatoric.py
│   │   ├── unet_aleatoric.py
│   │   ├── unet_jit.py
│   │   └── unet_jit2.py
│   ├── loss.py
│   ├── metric.py
│   ├── unet.py
│   ├── vEMDiffuse_model.py
│   └── vEMDiffuse_network.py
├── requirements.txt
├── run.py
├── test_pre.py
├── vEM_test_pre.py
└── vEMa_pre.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | .DS_STORE
4 |
--------------------------------------------------------------------------------
/3D-SR-Unet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/3D-SR-Unet/__init__.py
--------------------------------------------------------------------------------
/3D-SR-Unet/data.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from torchvision import transforms
3 | from PIL import Image, ImageOps, ImageFilter
4 | import os
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | import random
8 | import torch
9 | from scipy.ndimage import zoom
10 | from scipy.ndimage import gaussian_filter1d, gaussian_filter
11 | from tifffile import imread
12 |
13 | class KidneySRUData(data.Dataset):
14 | def __init__(self, data_root):
15 | self.data_root = data_root
16 |
17 | self.volume_list = self.read_dataset(data_root=self.data_root)
18 |
19 | def __getitem__(self, index):
20 | ret = {}
21 | gt = imread(self.volume_list[index])
22 |         # simulate anisotropic acquisition: keep every 6th axial slice
23 |         img = gt[::6, :, :]
24 |         img_upsampled = zoom(img, (6, 1, 1), order=3)  # cubic upsampling back to full depth
25 | img, gt, img_upsampled = self.aug(img, gt, img_upsampled)
26 | img = img / 255.
27 | gt = gt / 255.
28 | img_upsampled = img_upsampled / 255.
29 | img = self.norm(img)
30 | gt = self.norm(gt)
31 | img_upsampled = self.norm(img_upsampled)
32 | img = torch.tensor(img, dtype=torch.float32).unsqueeze_(dim=0)
33 | gt = torch.tensor(gt, dtype=torch.float32).unsqueeze_(dim=0)
34 | img_upsampled = torch.tensor(img_upsampled, dtype=torch.float32).unsqueeze_(dim=0)
35 |         return img, gt, img_upsampled
36 |
37 | def norm(self, img):
38 | img = img.astype(np.float32)
39 | img = (img - 0.5) / 0.5
40 | return img
41 |
42 | def __len__(self):
43 | return len(self.volume_list)
44 |
45 | def aug(self, img, gt, img_up):
46 | if random.random() < 0.5:
47 | img = np.flip(img, axis=2)
48 | gt = np.flip(gt, axis=2)
49 | img_up = np.flip(img_up, axis=2)
50 | if random.random() < 0.5:
51 | img = np.rot90(img, k=1, axes=(1, 2))
52 | gt = np.rot90(gt, k=1, axes=(1, 2))
53 | img_up = np.rot90(img_up, k=1, axes=(1, 2))
54 | return img, gt, img_up
55 |
56 | def read_dataset(self, data_root):
57 | volume_list = []
58 | for i in range(2000):
59 | if os.path.exists(os.path.join(data_root, str(i) + '.tif')):
60 | volume_list.append(os.path.join(data_root, str(i) + '.tif'))
61 | return volume_list
62 |
--------------------------------------------------------------------------------
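KidneySRUData simulates anisotropic acquisition by keeping every 6th axial slice and interpolating back with cubic `zoom`. A minimal, self-contained sketch of that pipeline on a synthetic volume (numpy/scipy only; the sizes are illustrative, not taken from the repo's data):

```python
import numpy as np
from scipy.ndimage import zoom

# Synthetic "ground-truth" volume: 48 axial slices of 64x64 pixels.
gt = np.random.randint(0, 256, size=(48, 64, 64)).astype(np.float32)

# Keep every 6th slice, as KidneySRUData does with gt[::6, :, :].
img = gt[::6, :, :]                            # (8, 64, 64)

# Cubic interpolation back to the original axial sampling.
img_upsampled = zoom(img, (6, 1, 1), order=3)  # (48, 64, 64)

# Same normalization as KidneySRUData.norm: [0, 255] -> [-1, 1].
norm = lambda a: (a / 255.0 - 0.5) / 0.5
print(img.shape, img_upsampled.shape, norm(gt).min(), norm(gt).max())
```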
/3D-SR-Unet/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import argparse
3 | import os
4 |
5 | import torch
6 | import numpy as np
7 | from torch import Generator, randperm
8 | import random
9 | from torch.utils.data import DataLoader, Subset
10 | from train import train_cnn
11 | from model import SRUNet, CubicWeightedPSNRLoss
12 | from data import KidneySRUData
13 |
14 | warnings.filterwarnings('ignore')
15 |
16 |
17 | def train_distributed(args):
18 | model = SRUNet(up_scale=6).cuda()
19 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
20 | criterion = CubicWeightedPSNRLoss().cuda()
21 | dataset_train = KidneySRUData(data_root='/data/cxlu/srunet_liver_training_large/srunet_training')
22 | data_len = len(dataset_train)
23 | valid_len = int(data_len * 0.1)
24 | data_len -= valid_len
25 | dataset_train, dataset_val = subset_split(dataset_train, lengths=[data_len, valid_len],
26 | generator=Generator().manual_seed(args.seed))
27 | train_loader = DataLoader(dataset=dataset_train, num_workers=args.num_worker, batch_size=args.b, pin_memory=True,
28 | shuffle=True)
29 | val_loader = DataLoader(dataset=dataset_val, num_workers=args.num_worker, batch_size=args.b, pin_memory=True,
30 | )
31 | train_cnn(train_generator=train_loader, valid_generator=val_loader, args=args, optimizer=optimizer, model=model,
32 | criterion=criterion)
33 |
34 | def subset_split(dataset, lengths, generator):
35 |     """Split a dataset into Subsets of the given lengths, shuffled with `generator`."""
36 |
37 | indices = randperm(sum(lengths), generator=generator).tolist()
38 | Subsets = []
39 | for offset, length in zip(np.add.accumulate(lengths), lengths):
40 | if length == 0:
41 | Subsets.append(None)
42 | else:
43 | Subsets.append(Subset(dataset, indices[offset - length: offset]))
44 | return Subsets
45 |
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser(description='Self Training benchmark')
49 | parser.add_argument('--b', default=16, type=int, help='batch size')
50 | parser.add_argument('--epoch', default=100, type=int, help='epochs to train')
51 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
52 | parser.add_argument('--output', default='./model_genesis_pretrain', type=str, help='output path')
53 | parser.add_argument('--gpus', default='0,1,2,3', type=str, help='gpu indexs')
54 | parser.add_argument('--seed', default=42, type=int)
55 | parser.add_argument('--num_worker', type=int, default=8)
56 | args = parser.parse_args()
57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
58 | seed = args.seed
59 | torch.manual_seed(seed)
60 | torch.cuda.manual_seed_all(seed)
61 | np.random.seed(seed)
62 | random.seed(seed)
63 | torch.backends.cudnn.deterministic = True
64 | torch.backends.cudnn.benchmark = False
65 | if not os.path.exists(args.output):
66 | os.makedirs(args.output)
67 | print(args)
68 | train_distributed(args)
69 |
--------------------------------------------------------------------------------
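`subset_split` is a deterministic stand-in for `torch.utils.data.random_split` that returns `None` for zero-length splits. A quick self-contained check of its behavior, re-implemented inline on a toy dataset (the sizes are illustrative):

```python
import torch
from torch import Generator, randperm
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10))
lengths = [9, 1]

# Same logic as subset_split in main.py, inlined for self-containment.
indices = randperm(sum(lengths), generator=Generator().manual_seed(42)).tolist()
subsets, offset = [], 0
for length in lengths:
    offset += length
    subsets.append(Subset(dataset, indices[offset - length: offset]))

print(len(subsets[0]), len(subsets[1]))  # 9 1
print(subsets[1].indices)                # identical on every run with seed 42
```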
/3D-SR-Unet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from scipy.ndimage import zoom
5 |
6 |
7 | class CubicWeightedPSNRLoss(nn.Module):
8 | def __init__(self):
9 | super(CubicWeightedPSNRLoss, self).__init__()
10 |
11 | def forward(self, upsampled_input, pred, target):
12 |         # Weight each pixel by how poorly cubic upsampling reconstructs it
13 |         error = (upsampled_input - target) ** 2
14 |         weight = error / (error.max() * 2) + 0.5  # weights lie in [0.5, 1]
15 |         # Compute the pixel-wise cubic-weighted MSE loss
16 |         weighted_mse = ((pred - target) ** 2 * weight).mean()
17 |         # (the returned value is a weighted MSE, not a PSNR, despite the class name)
18 |         return weighted_mse
19 |
20 |
21 | def conv3x3(in_channels, out_channels, stride=1,
22 | padding=1, bias=True, groups=1):
23 | return nn.Conv2d(
24 | in_channels,
25 | out_channels,
26 | kernel_size=3,
27 | stride=stride,
28 | padding=padding,
29 | bias=bias,
30 | groups=groups)
31 |
32 |
33 | def conv3x3x3(in_channels, out_channels, stride=1,
34 | padding=1, bias=True, groups=1):
35 | return nn.Conv3d(
36 | in_channels,
37 | out_channels,
38 | kernel_size=3,
39 | stride=stride,
40 | padding=padding,
41 | bias=bias,
42 | groups=groups)
43 |
44 |
45 | class SRUNet(nn.Module):
46 | def __init__(self, up_scale=6):
47 | super().__init__()
48 | self.up_scale = up_scale
49 | self.conv1_1 = conv3x3x3(1, 32)
50 | self.conv1_2 = conv3x3x3(32, 32)
51 | self.conv1_3 = conv3x3x3(32, 32)
52 | self.fracconv1 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3,
53 | stride=(self.up_scale, 1, 1), padding=1)
54 | self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
55 | self.conv2_1 = conv3x3x3(32, 64)
56 | self.conv2_2 = conv3x3x3(64, 64)
57 | self.conv2_3 = conv3x3x3(64, 64)
58 | self.fracconv2 = nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=3, stride=(2, 1, 1), padding=1)
59 | self.conv3_1 = conv3x3x3(64, 128)
60 | self.conv3_2 = conv3x3x3(128, 128)
61 | self.conv3_3 = conv3x3x3(128, 128)
62 | self.fracconv3 = nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=3, stride=(2, 2, 2),
63 | padding=1)
64 | self.conv2_4 = conv3x3x3(128, 64)
65 | self.conv2_5 = conv3x3x3(64, 64)
66 | self.conv2_6 = conv3x3x3(64, 64)
67 | self.fracconv4 = nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=3,
68 | stride=(self.up_scale // 2, 2, 2), padding=1)
69 | self.conv1_4 = conv3x3x3(64, 32)
70 | self.conv1_5 = conv3x3x3(32, 32)
71 | self.conv1_6 = conv3x3x3(32, 32)
72 | self.final_conv = conv3x3x3(32, 1)
73 |
74 | def forward(self, x):
75 | x_1_1 = F.relu(self.conv1_1(x))
76 | x_1_2 = F.relu(self.conv1_2(x_1_1))
77 | x_1_3 = F.relu(self.conv1_3(x_1_2))
78 | b, c, d, h, w = x_1_3.shape
79 | x_frac1 = self.fracconv1(x_1_3, output_size=(b, c, d * self.up_scale, h, w))
80 | # print(x_frac1.shape)
81 | x_2_1 = self.pool(x_1_3)
82 | x_2_2 = F.relu(self.conv2_1(x_2_1))
83 |
84 | x_2_3 = F.relu(self.conv2_2(x_2_2))
85 | x_2_4 = F.relu(self.conv2_3(x_2_3))
86 | b, c, d, h, w = x_2_4.shape
87 | x_frac2 = self.fracconv2(x_2_4, output_size=(b, c, d * 2, h, w))
88 | # print(x_frac2.shape)
89 | x_3_1 = self.pool(x_2_4)
90 | x_3_2 = F.relu(self.conv3_1(x_3_1))
91 | x_3_3 = F.relu(self.conv3_2(x_3_2))
92 | x_3_4 = F.relu(self.conv3_3(x_3_3))
93 | b, c, d, h, w = x_3_4.shape
94 | x_frac3 = self.fracconv3(x_3_4, output_size=(b, c, d * 2, h * 2, w * 2))
95 | # print(x_frac3.shape)
96 | x_merge_2 = torch.concatenate([x_frac3, x_frac2], dim=1)
97 | x_2_5 = F.relu(self.conv2_4(x_merge_2))
98 | x_2_6 = F.relu(self.conv2_5(x_2_5))
99 | x_2_7 = F.relu(self.conv2_6(x_2_6))
100 | b, c, d, h, w = x_2_7.shape
101 | x_frac4 = self.fracconv4(x_2_7, output_size=(b, c, d * self.up_scale // 2, h * 2, w * 2))
102 | # print(x_frac4.shape)
103 | x_merge_1 = torch.concatenate([x_frac1, x_frac4], dim=1)
104 | x_1_4 = F.relu(self.conv1_4(x_merge_1))
105 | x_1_5 = F.relu(self.conv1_5(x_1_4))
106 | x_1_6 = F.relu(self.conv1_6(x_1_5))
107 | out = self.final_conv(x_1_6)
108 | return out
109 |
110 |
111 | if __name__ == '__main__':
112 |     model = SRUNet(up_scale=6)
113 |     # Smoke test: depth-16 input, depth-96 (= 16 * up_scale) target.
114 |     test_gt = torch.rand((1, 1, 96, 128, 128))
115 |     test_input = torch.rand((1, 1, 16, 128, 128))
116 |     test_upsampled = torch.rand((1, 1, 96, 128, 128))  # stand-in for the cubic-upsampled input
117 |     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
118 |     test_out = model(test_input)
119 |     loss_function = CubicWeightedPSNRLoss()
120 |     loss = loss_function(test_upsampled, test_out, test_gt)
121 |     optimizer.zero_grad()
122 |     loss.backward()
--------------------------------------------------------------------------------
/3D-SR-Unet/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sys
5 | import os
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 |
11 | def train_cnn(optimizer, model, train_generator, valid_generator, criterion, args):
12 | n_iteration_per_epoch = len(train_generator)
13 |
14 | tb_logger = SummaryWriter(log_dir=args.output)
15 | print(n_iteration_per_epoch)
16 |     step_size = 100  # adjust the learning rate every 100 steps
17 |     # LambdaLR multiplies the optimizer's base lr (args.lr) by the returned factor
18 |     lambda_lr = lambda step: (step // step_size + 1) ** 0.5 if step > 0 else 1.0
19 |     scheduler = LambdaLR(optimizer, lambda_lr)
20 | train_losses = []
21 | valid_losses_l1 = []
22 | avg_train_losses = []
23 | best_loss = 100000
24 | num_epoch_no_improvement = 0
25 | for epoch in range(args.epoch + 1):
26 | model.train()
27 |
28 | iteration = 0
29 | total_step = 0
30 | for idx, (image, gt, img_upsampled) in enumerate(train_generator):
31 | # scheduler.step()
32 | total_step += args.b
33 | img = image.cuda(non_blocking=True).float()
34 | gt = gt.cuda(non_blocking=True).float()
35 | img_upsampled = img_upsampled.cuda(non_blocking=True).float()
36 | pred = model(img)
37 | loss = criterion(img_upsampled, pred, gt)
38 | iteration += 1
39 | optimizer.zero_grad()
40 | loss.backward()
41 | optimizer.step()
42 | scheduler.step(epoch * n_iteration_per_epoch + iteration)
43 | train_losses.append(round(loss.item(), 2))
44 | if (iteration + 1) % 20 == 0:
45 |                 print('Epoch [{}/{}], iteration {}, loss: {:.6f}, avg loss: {:.6f}, learning rate: {:.6f}'
46 | .format(epoch + 1, args.epoch, iteration + 1, loss.sum().item(), np.average(train_losses),
47 | optimizer.state_dict()['param_groups'][0]['lr']))
48 | sys.stdout.flush()
49 |
50 | with torch.no_grad():
51 | model.eval()
52 | print("validating....")
53 |             for i, (image, gt, img_upsampled) in enumerate(valid_generator):
54 |                 image_scale = image.cuda(non_blocking=True).float()
55 |                 gt_scale = gt.cuda(non_blocking=True).float()
56 |                 pred = model(image_scale)
57 |                 loss = criterion(img_upsampled.cuda(non_blocking=True).float(), pred, gt_scale)
58 | valid_losses_l1.append(loss.sum().item())
59 | # logging
60 | train_loss = np.average(train_losses)
61 | valid_loss_l1 = np.average(valid_losses_l1)
62 | valid_loss = valid_loss_l1
63 | tb_logger.add_scalar('valid loss', valid_loss_l1, epoch)
64 | avg_train_losses.append(train_loss)
65 | print("Epoch {}, validation loss is {:.4f}, training loss is {:.4f}".format(epoch + 1, valid_loss,
66 | train_loss))
67 | train_losses = []
68 |             valid_losses_l1 = []
69 |
70 | if valid_loss < best_loss:
71 | print("Validation loss decreases from {:.4f} to {:.4f}".format(best_loss, valid_loss))
72 | best_loss = valid_loss
73 | num_epoch_no_improvement = 0
74 | # save model
75 | # save all the weight for 3d unet
76 | torch.save({
77 | 'args': args,
78 | 'epoch': epoch + 1,
79 | 'state_dict': model.state_dict(),
80 | 'optimizer_state_dict': optimizer.state_dict()
81 | }, os.path.join(args.output,
82 | 'best' + '.pt'))
83 | print("Saving model ",
84 | os.path.join(args.output,
85 | 'best' + '.pt'))
86 | else:
87 | if epoch % 10 == 0:
88 | torch.save({
89 | 'args': args,
90 | 'epoch': epoch + 1,
91 | 'state_dict': model.state_dict(),
92 | 'optimizer_state_dict': optimizer.state_dict()
93 | }, os.path.join(args.output,
94 | 'epoch_' + str(epoch) + '.pt'))
95 | print("Saving model ",
96 | os.path.join(args.output,
97 | 'epoch_' + str(epoch) + '.pt'))
98 | print("Validation loss does not decrease from {:.4f}, num_epoch_no_improvement {}".format(best_loss,
99 | num_epoch_no_improvement))
100 | num_epoch_no_improvement += 1
101 | if num_epoch_no_improvement > 10:
102 | break
103 | sys.stdout.flush()
104 | tb_logger.close()
105 |
--------------------------------------------------------------------------------
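One subtlety in `train_cnn`'s scheduler: `LambdaLR` multiplies the optimizer's base learning rate by whatever the lambda returns, so the lambda should yield a bare factor (pre-multiplying by `lr_initial` would scale the rate by 1e-4 twice). A self-contained sketch of the resulting schedule under that reading:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

step_size = 100
lambda_lr = lambda step: (step // step_size + 1) ** 0.5 if step > 0 else 1.0
scheduler = LambdaLR(optimizer, lambda_lr)

for step in [0, 99, 100, 399]:
    scheduler.step(step)  # explicit step index, as train_cnn passes one
    print(step, optimizer.param_groups[0]['lr'])
# lr = 1e-4 * sqrt(step // 100 + 1): the rate grows slowly with training steps
```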
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Luchixiang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/RAFT/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | dist
4 | datasets
5 | pytorch_env
6 | build
7 | correlation.egg-info
8 |
--------------------------------------------------------------------------------
/RAFT/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/RAFT/RAFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/RAFT.png
--------------------------------------------------------------------------------
/RAFT/README.md:
--------------------------------------------------------------------------------
1 | # RAFT
2 | This repository contains the source code for our paper:
3 |
4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 |
8 |
9 |
10 | ## Requirements
11 | The code has been tested with PyTorch 1.6 and Cuda 10.1.
12 | ```Shell
13 | conda create --name raft
14 | conda activate raft
15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
16 | ```
17 |
18 | ## Demos
19 | Pretrained models can be downloaded by running
20 | ```Shell
21 | ./download_models.sh
22 | ```
23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
24 |
25 | You can demo a trained model on a sequence of frames
26 | ```Shell
27 | python demo.py --model=models/raft-things.pth --path=demo-frames
28 | ```
29 |
30 | ## Required Data
31 | To evaluate/train RAFT, you will need to download the required datasets.
32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
34 | * [Sintel](http://sintel.is.tue.mpg.de/)
35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
37 |
38 |
39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder pointing to wherever the datasets were downloaded
40 |
41 | ```Shell
42 | ├── datasets
43 | ├── Sintel
44 | ├── test
45 | ├── training
46 | ├── KITTI
47 | ├── testing
48 | ├── training
49 | ├── devkit
50 | ├── FlyingChairs_release
51 | ├── data
52 | ├── FlyingThings3D
53 | ├── frames_cleanpass
54 | ├── frames_finalpass
55 | ├── optical_flow
56 | ```
57 |
58 | ## Evaluation
59 | You can evaluate a trained model using `evaluate.py`
60 | ```Shell
61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
62 | ```
63 |
64 | ## Training
65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
66 | ```Shell
67 | ./train_standard.sh
68 | ```
69 |
70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
71 | ```Shell
72 | ./train_mixed.sh
73 | ```
74 |
75 | ## (Optional) Efficient Implementation
76 | You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
77 | ```Shell
78 | cd alt_cuda_corr && python setup.py install && cd ..
79 | ```
80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
81 |
--------------------------------------------------------------------------------
/RAFT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/__init__.py
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
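After `python setup.py install`, the extension exposes raw `forward`/`backward` bindings; `AlternateCorrBlock` in `core/corr.py` calls `alt_cuda_corr.forward(fmap1, fmap2, coords, radius)` with channels-last, contiguous tensors. A hedged usage sketch (shapes inferred from that call site; requires a CUDA device and the compiled extension):

```python
import torch
import alt_cuda_corr  # built by this setup.py

B, H, W, C, radius = 1, 32, 32, 64, 4
# corr.py permutes feature maps from (B, C, H, W) to (B, H, W, C) first.
fmap1 = torch.randn(B, H, W, C, device='cuda').contiguous()
fmap2 = torch.randn(B, H, W, C, device='cuda').contiguous()
coords = torch.zeros(B, 1, H, W, 2, device='cuda').contiguous()

corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, radius)
print(corr.shape)
```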
/RAFT/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/core/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/align_functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tifffile import imread, imwrite
3 | import shutil
4 | import cv2
5 | import imutils
6 | import numpy as np
7 | import image_registration
8 | from scipy.ndimage import shift
9 |
10 |
11 | def mkdir(path):
12 | if os.path.exists(path):
13 | shutil.rmtree(path)
14 | os.mkdir(path)
15 |
16 |
17 | def delete_outlier(points1, points2, move=0, outlier_percent=0.3):
18 | """
19 | Delete the outliers/mismatches based on the angle and distance.
20 | Args:
21 | points1: key points detected in frame1
22 | points2: key points detected in frame2
23 |         move: shift points1 by this distance so matched points do not coincide
24 |         outlier_percent: fraction of matches discarded as outliers
25 |
26 | Returns: indexes of selected key points
27 | """
28 | # angle
29 | points1_mv = points1.copy()
30 | points1_mv[:, 0] = points1[:, 0] - move
31 | vecs = points2 - points1_mv
32 | norms = np.linalg.norm(vecs, axis=1, keepdims=True)
33 | vec_norms = vecs / (norms + 1e-6)
34 | vec_means = np.mean(vec_norms, axis=0).reshape((2, 1))
35 | cross_angles = vec_norms.dot(vec_means)[:, 0]
36 | index = np.argsort(-cross_angles)
37 | num_select = int(len(index) * (1 - outlier_percent))
38 | index_selected = index[0:num_select]
39 |
40 | # distance
41 | index1 = np.argsort(norms[:, 0])
42 | # print(index1)
43 | index1_selected = index1[0:num_select]
44 |
45 | index_selected = list(set(index1_selected) & set(index_selected))
46 |
47 | return index_selected, np.mean(norms[index_selected])
48 |
49 |
50 | def align_images(imageGray, templateGray, maxFeatures=500, keepPercent=0.2,
51 | debug=False, outlier=True, sup_img=None):
52 | # convert both the input image and template to grayscale
53 |
54 | orb = cv2.ORB_create(maxFeatures)
55 | (kpsA, descsA) = orb.detectAndCompute(imageGray, None)
56 | (kpsB, descsB) = orb.detectAndCompute(templateGray, None)
57 | # match the features
58 | method = cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING
59 | matcher = cv2.DescriptorMatcher_create(method)
60 | matches = matcher.match(descsA, descsB, None)
61 | matches = sorted(matches, key=lambda x: x.distance)
62 | # keep only the top matches
63 | keep = int(len(matches) * keepPercent)
64 | matches = matches[:keep]
65 | # check to see if we should visualize the matched keypoints
66 | ptsA = np.zeros((len(matches), 2), dtype="float")
67 | ptsB = np.zeros((len(matches), 2), dtype="float")
68 | # loop over the top matches
69 | for (i, m) in enumerate(matches):
70 | # indicate that the two keypoints in the respective images
71 | # map to each other
72 | ptsA[i] = kpsA[m.queryIdx].pt
73 | ptsB[i] = kpsB[m.trainIdx].pt
74 | if outlier:
75 | index, distance = delete_outlier(ptsA, ptsB)
76 |
77 | matches = list(np.array(matches)[index])
78 | ptsA = ptsA[index, :]
79 | ptsB = ptsB[index, :]
80 | if len(matches) < 10:
81 | return None
82 | (H, mask) = cv2.findHomography(ptsA, ptsB, method=cv2.RANSAC)
83 | return H
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/RAFT/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except ImportError:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
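`CorrBlock.corr` builds the all-pairs correlation volume with a single batched matmul: every feature vector in `fmap1` is dotted with every one in `fmap2`, then scaled by `sqrt(dim)`. A self-contained check of the shape bookkeeping:

```python
import torch

batch, dim, ht, wd = 2, 64, 8, 8
fmap1 = torch.randn(batch, dim, ht, wd)
fmap2 = torch.randn(batch, dim, ht, wd)

# Same computation as CorrBlock.corr.
f1 = fmap1.view(batch, dim, ht * wd)
f2 = fmap2.view(batch, dim, ht * wd)
corr = torch.matmul(f1.transpose(1, 2), f2) / torch.sqrt(torch.tensor(dim).float())
corr = corr.view(batch, ht, wd, 1, ht, wd)

print(corr.shape)  # torch.Size([2, 8, 8, 1, 8, 8]): each pixel against every pixel
```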
/RAFT/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except AttributeError:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
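In `test_mode` the forward pass returns the 1/8-resolution flow and the convex-upsampled full-resolution flow, which is how `super_res_register.py` drives the model. A minimal inference sketch following that call pattern (assumes it is run from the `RAFT` directory with the checkpoint downloaded; the random tensors stand in for real [0, 255] frames whose sides are divisible by 8):

```python
import sys
sys.path.append('core')

import torch
from raft import RAFT
from raftConfig import RaftConfig

args = RaftConfig(path='.')  # Namespace-like config from core/raftConfig.py
model = torch.nn.DataParallel(RAFT(args))  # checkpoint keys carry a 'module.' prefix
model.load_state_dict(torch.load('models/raft-things.pth', map_location='cpu'))
model = model.module.eval()

with torch.no_grad():
    image1 = torch.rand(1, 3, 256, 256) * 255
    image2 = torch.rand(1, 3, 256, 256) * 255
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    print(flow_low.shape, flow_up.shape)  # (1, 2, 32, 32) and (1, 2, 256, 256)
```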
/RAFT/core/raftConfig.py:
--------------------------------------------------------------------------------
1 | class RaftConfig:
2 | def __init__(self, path, patch_size=256, border=32, tissue='Brain', overlap=0.125):
3 | self.path = path
4 | self.patch_size = patch_size
5 | self.border = border
6 | self.tissue = tissue
7 | self.small = False
8 | self.model = 'RAFT/models/raft-things.pth'
9 | self.overlap = overlap
10 | self.mixed_precision = False
11 | self.alternate_corr = False
12 | self.occlusion = False
13 |
14 | def __getattr__(self, item):
15 |         # Called only when normal attribute lookup fails; missing keys fall back to None.
16 | try:
17 | return self.__dict__[item]
18 | except KeyError:
19 | return None
20 |
21 | def __setattr__(self, key, value):
22 | # This method allows setting attributes directly.
23 | self.__dict__[key] = value
24 |
25 | def __contains__(self, item):
26 | # This enables the use of 'in' to check for attribute existence.
27 | return item in self.__dict__
28 |
29 |
--------------------------------------------------------------------------------
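`RaftConfig` stands in for the argparse `Namespace` that `raft.py` expects: `RAFT.__init__` probes `'dropout' in self.args` and assigns missing fields, which is exactly what the `__contains__`/`__getattr__`/`__setattr__` trio supports. A quick check of that behavior (run from `RAFT/core`; the path argument is illustrative):

```python
from raftConfig import RaftConfig

cfg = RaftConfig(path='/data/rois')
print('dropout' in cfg)  # False: never set in __init__
print(cfg.dropout)       # None: __getattr__ falls back instead of raising
cfg.dropout = 0          # what raft.py does when the key is absent
print('dropout' in cfg)  # True
```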
/RAFT/core/super_res_register.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import glob
9 | import numpy as np
10 | import torch
11 | from PIL import Image
12 |
13 | from raft import RAFT
14 | from utils import flow_viz
15 | from utils.utils import InputPadder
16 | from align_functions import *
17 | import torch.nn.functional as F
18 |
19 | from tifffile import imwrite
20 |
21 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
22 | DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
23 |
24 |
25 | def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
26 | # initialize the dimensions of the image to be resized and
27 | # grab the image size
28 | dim = None
29 | (h, w) = image.shape[:2]
30 |
31 | # if both the width and height are None, then return the
32 | # original image
33 | if width is None and height is None:
34 | return image
35 |
36 | # check to see if the width is None
37 | if width is None:
38 | # calculate the ratio of the height and construct the
39 | # dimensions
40 | r = height / float(h)
41 | dim = (int(w * r), height)
42 |
43 | # otherwise, the height is None
44 | else:
45 | # calculate the ratio of the width and construct the
46 | # dimensions
47 | r = width / float(w)
48 | dim = (width, int(h * r))
49 |
50 | # resize the image
51 | resized = cv2.resize(image, dim)
52 |
53 | # return the resized image
54 | return resized
55 |
56 |
57 | ######################
58 | ## Image format transforms ##
59 | ######################
60 | def img2tensor(img):
61 | img_t = np.expand_dims(img.transpose(2, 0, 1), axis=0)
62 | img_t = torch.from_numpy(img_t.astype(np.float32))
63 |
64 | return img_t
65 |
66 |
67 | def tensor2img(img_t):
68 | img = img_t[0].detach().to("cpu").numpy()
69 | img = np.transpose(img, (1, 2, 0))
70 |
71 | return img
72 |
73 |
74 | ######################
75 | # occlusion detection#
76 | ######################
77 |
78 | def warp(x, flo):
79 | """
80 | warp an image/tensor (im2) back to im1, according to the optical flow
81 | x: [B, C, H, W] (im2)
82 | flo: [B, 2, H, W] flow
83 | """
84 | B, C, H, W = x.size()
85 | # mesh grid
86 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
87 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
88 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
89 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
90 | grid = torch.cat((xx, yy), 1).float()
91 |
92 | if x.is_cuda:
93 | grid = grid.cuda()
94 | vgrid = grid + flo
95 | # scale grid to [-1,1]
96 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
97 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
98 |
99 | vgrid = vgrid.permute(0, 2, 3, 1)
100 | output = F.grid_sample(x, vgrid, align_corners=False)
101 | return output
102 |
103 |
104 |
105 | ###########################
106 | ## raft functions
107 | ###########################
108 | def load_image(img):
109 | img = np.stack([img, img, img], axis=2)
110 |
111 | img = torch.from_numpy(img).permute(2, 0, 1).float()
112 | return img[None].to(DEVICE)
113 |
114 | def process_pair(wf_img, gt_img, save_wf_path, save_gt_path, sup_wf_img=None, patch_size=256, stride=224, model=None,
115 | border=32):
116 | wf_img_origin = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY)
117 | gt_img_origin = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY)
118 | wf_img = wf_img_origin[gt_img_origin.shape[0] // 2 - gt_img_origin.shape[0] // 4: gt_img_origin.shape[0] // 2 +
119 | gt_img_origin.shape[0] // 4,
120 | gt_img_origin.shape[1] // 2 - gt_img_origin.shape[1] // 4: gt_img_origin.shape[1] // 2 +
121 | gt_img_origin.shape[1] // 4]
122 | gt_img = cv2.resize(gt_img_origin, (gt_img_origin.shape[1] // 2, gt_img_origin.shape[0] // 2))
123 | x_offset, y_offset, _, _ = image_registration.chi2_shift(wf_img, gt_img, 0.1, return_error=True)
124 | wf_img = shift(wf_img_origin, (y_offset, x_offset))[
125 | gt_img_origin.shape[0] // 2 - gt_img_origin.shape[0] // 4 - 8: gt_img_origin.shape[0] // 2 +
126 | gt_img_origin.shape[0] // 4 + 8,
127 | gt_img_origin.shape[1] // 2 - gt_img_origin.shape[1] // 4 - 8: gt_img_origin.shape[1] // 2 +
128 | gt_img_origin.shape[1] // 4 + 8]
129 |
130 | H = align_images(wf_img, gt_img, debug=False)
131 | h, w = gt_img.shape
132 | aligned = cv2.warpPerspective(wf_img, H, (w, h))
133 | x = border
134 | x_end = wf_img.shape[0] - border
135 |     y_end = wf_img.shape[1] - border  # width along the second axis, not shape[0]
136 | count = 1
137 | while x + patch_size < x_end:
138 | y = border
139 | while y + patch_size < y_end:
140 | crop_wf_img = aligned[x - border: x + patch_size + border, y - border: y + patch_size + border]
141 | crop_gt_img = gt_img[x - border: x + patch_size + border, y - border: y + patch_size + border]
142 | H_sub = align_images(crop_wf_img, crop_gt_img)
143 | if H_sub is None:
144 | count += 1
145 | y += stride
146 | continue
147 | else:
148 | (h_sub, w_sub) = crop_gt_img.shape[:2]
149 | crop_wf_img = cv2.warpPerspective(crop_wf_img, H_sub, (w_sub, h_sub))
150 | if np.sum(crop_wf_img[border:-border, border:-border] == 0) > 10:
151 | count += 1
152 | y += stride
153 | continue
154 | image1 = load_image(crop_gt_img)
155 | image2 = load_image(crop_wf_img)
156 |
157 | padder = InputPadder(image1.shape)
158 | image1, image2 = padder.pad(image1, image2)
159 |
160 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
161 | image_warped = warp(image2 / 255.0, flow_up)
162 | crop_wf_img = image_warped[0].permute(1, 2, 0).cpu().numpy()
163 | crop_wf_img = np.uint8(crop_wf_img[:, :, 0] * 255)
164 | if np.sum(crop_wf_img[border:-border, border:-border] == 0) > 10:
165 | count += 1
166 | y += stride
167 | continue
168 | imwrite(os.path.join(save_wf_path, str(count) + '.tif'), crop_wf_img[border:-border, border:-border])
169 | imwrite(os.path.join(save_gt_path, str(count) + '.tif'),
170 | gt_img_origin[2 * x: 2 * x + 2 * patch_size, 2 * y: 2 * y + 2 * patch_size])
171 | count += 1
172 | y += stride
173 | x += stride
174 |
175 |
176 | def registration(args):
177 | model = torch.nn.DataParallel(RAFT(args))
178 | model.load_state_dict(torch.load(args.model, map_location='cpu'))
179 |
180 | model = model.module
181 | model.to(DEVICE)
182 | model.eval()
183 |
184 | with torch.no_grad():
185 | task = 'zoom'
186 | path = args.path
187 | target_path = os.path.join(path, task)
188 | mkdir(target_path)
189 | image_types = ['Brain__2w_01.tif', 'Brain__2w_02.tif', 'Brain__2w_03.tif']
190 | train_wf_path = os.path.join(target_path, 'train_wf')
191 | train_gt_path = os.path.join(target_path, 'train_gt')
192 | mkdir(train_wf_path)
193 | mkdir(train_gt_path)
194 | for i in range(100):
195 | if not os.path.exists(os.path.join(path, str(i), 'Brain__4w_09.tif')):
196 | continue
197 | roi_wf_path = os.path.join(train_wf_path, str(i))
198 | roi_gt_path = os.path.join(train_gt_path, str(i))
199 |
200 | mkdir(roi_wf_path)
201 | mkdir(roi_gt_path)
202 | for type in image_types:
203 | print(f'processing image {i}, {type}')
204 | save_wf_path = os.path.join(roi_wf_path, type[:-4])
205 | save_gt_path = os.path.join(roi_gt_path, type[:-4])
206 | mkdir(save_wf_path)
207 | mkdir(save_gt_path)
208 | gt_file_img = cv2.imread(os.path.join(path, str(i), 'Brain__4w_09.tif'))
209 | wf_file_img = cv2.imread(os.path.join(path, str(i), type))
210 | sup_wf_img = None
211 | # print(wf_file_img.min())
212 | process_pair(wf_file_img, gt_file_img, save_wf_path, save_gt_path, sup_wf_img=sup_wf_img, model=model,
213 | patch_size=args.patch_size, border=args.border, stride=int(args.patch_size * (1-args.overlap)))
214 |
215 |
216 | if __name__ == '__main__':
217 | parser = argparse.ArgumentParser()
218 | parser.add_argument('--model', default="../models/raft-things.pth")
219 | parser.add_argument('--path', help="dataset for evaluation")
220 | parser.add_argument('--category', help="save warped images")
221 | parser.add_argument('--small', action='store_true', help='use small model')
222 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
223 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
224 | parser.add_argument('--occlusion', action='store_true', help='predict occlusion masks')
225 | parser.add_argument('--patch_size', default=128, type=int)
226 | parser.add_argument('--border', default=32, type=int)
227 | parser.add_argument('--overlap', default=0.125, type=float)
228 |
229 | args = parser.parse_args()
230 |
231 | registration(args)
232 |
--------------------------------------------------------------------------------
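A numerical note on `warp()`: the sampling grid it builds is normalized as `2*x/(W-1) - 1`, which is the `align_corners=True` convention, while `grid_sample` is invoked with `align_corners=False`. A self-contained sketch of the zero-flow case that makes the difference visible (pure PyTorch, mirroring the grid construction above):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 1, 8, 8
x = torch.rand(B, C, H, W)

# Rebuild warp()'s sampling grid for zero flow.
xx = torch.arange(0, W).view(1, -1).repeat(H, 1).view(1, 1, H, W).float()
yy = torch.arange(0, H).view(-1, 1).repeat(1, W).view(1, 1, H, W).float()
vgrid = torch.cat((xx, yy), 1)
vgrid[:, 0] = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
vgrid[:, 1] = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
vgrid = vgrid.permute(0, 2, 3, 1)

# With align_corners=True this grid is the exact identity mapping ...
print(torch.allclose(F.grid_sample(x, vgrid, align_corners=True), x))  # True
# ... with align_corners=False, as warp() uses, the sample points shift slightly.
print((F.grid_sample(x, vgrid, align_corners=False) - x).abs().max())  # > 0
```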
/RAFT/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/RAFT/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/core/utils/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
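A minimal usage sketch for the color-wheel visualization above, assuming flow_viz is importable with RAFT/core on sys.path (as demo.py arranges); the synthetic flow field is illustrative only:

import numpy as np
from utils import flow_viz  # assumes RAFT/core is on sys.path, as in demo.py

# Synthetic [H, W, 2] flow: horizontal motion growing left to right,
# plus a constant vertical component.
H, W = 64, 128
u = np.tile(np.linspace(-10.0, 10.0, W, dtype=np.float32), (H, 1))
v = np.full((H, W), 3.0, dtype=np.float32)
flow = np.stack([u, v], axis=-1)

rgb = flow_viz.flow_to_image(flow)  # uint8 visualization of shape [H, W, 3]
assert rgb.shape == (H, W, 3) and rgb.dtype == np.uint8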
/RAFT/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
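A hedged round-trip sketch for writeFlow/readFlow above, assuming the same import path as evaluate.py; the temporary file name is illustrative:

import numpy as np
from utils import frame_utils  # assumes RAFT/core is on sys.path

flow = np.random.randn(32, 48, 2).astype(np.float32)
frame_utils.writeFlow('/tmp/example.flo', flow)  # illustrative path
back = frame_utils.readFlow('/tmp/example.flo')
assert back.shape == flow.shape and np.allclose(back, flow)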
/RAFT/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
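A short sketch of InputPadder from above, which pads to the next multiple of 8 (RAFT's encoder downsamples by 8) and restores the original size afterwards; the tensor shape is illustrative:

import torch
from utils.utils import InputPadder  # assumes RAFT/core is on sys.path

img = torch.randn(1, 3, 436, 1024)  # Sintel-sized frame; 436 is not divisible by 8
padder = InputPadder(img.shape)     # default 'sintel' mode: symmetric padding
(padded,) = padder.pad(img)
assert padded.shape[-2] % 8 == 0 and padded.shape[-1] % 8 == 0
restored = padder.unpad(padded)
assert restored.shape == img.shape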
/RAFT/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0018.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0019.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0020.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0021.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0022.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0023.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0024.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/demo-frames/frame_0025.png
--------------------------------------------------------------------------------
/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from raft import RAFT
13 | from utils import flow_viz
14 | from utils.utils import InputPadder
15 |
16 |
17 |
18 | DEVICE = 'cuda'
19 |
20 | def load_image(imfile):
21 | img = np.array(Image.open(imfile)).astype(np.uint8)
22 | img = torch.from_numpy(img).permute(2, 0, 1).float()
23 | return img[None].to(DEVICE)
24 |
25 |
26 | def viz(img, flo):
27 | img = img[0].permute(1,2,0).cpu().numpy()
28 | flo = flo[0].permute(1,2,0).cpu().numpy()
29 |
30 | # map flow to rgb image
31 | flo = flow_viz.flow_to_image(flo)
32 | img_flo = np.concatenate([img, flo], axis=0)
33 |
34 | # import matplotlib.pyplot as plt
35 | # plt.imshow(img_flo / 255.0)
36 | # plt.show()
37 |
38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
39 | cv2.waitKey()
40 |
41 |
42 | def demo(args):
43 | model = torch.nn.DataParallel(RAFT(args))
44 | model.load_state_dict(torch.load(args.model))
45 |
46 | model = model.module
47 | model.to(DEVICE)
48 | model.eval()
49 |
50 | with torch.no_grad():
51 | images = glob.glob(os.path.join(args.path, '*.png')) + \
52 | glob.glob(os.path.join(args.path, '*.jpg'))
53 |
54 | images = sorted(images)
55 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
56 | image1 = load_image(imfile1)
57 | image2 = load_image(imfile2)
58 |
59 | padder = InputPadder(image1.shape)
60 | image1, image2 = padder.pad(image1, image2)
61 |
62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
63 | viz(image1, flow_up)
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--model', help="restore checkpoint")
69 | parser.add_argument('--path', help="dataset for evaluation")
70 | parser.add_argument('--small', action='store_true', help='use small model')
71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
72 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
73 | args = parser.parse_args()
74 |
75 | demo(args)
76 |
--------------------------------------------------------------------------------
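demo.py pairs each frame with its successor and displays the upsampled flow. A typical invocation, assuming it is run from the RAFT directory with the bundled checkpoint and demo frames:

python demo.py --model=models/raft-things.pth --path=demo-frames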
/RAFT/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/RAFT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | from raft import RAFT
18 | from utils.utils import InputPadder, forward_interpolate
19 |
20 |
21 | @torch.no_grad()
22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
23 | """ Create submission for the Sintel leaderboard """
24 | model.eval()
25 | for dstype in ['clean', 'final']:
26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
27 |
28 | flow_prev, sequence_prev = None, None
29 | for test_id in range(len(test_dataset)):
30 | image1, image2, (sequence, frame) = test_dataset[test_id]
31 | if sequence != sequence_prev:
32 | flow_prev = None
33 |
34 | padder = InputPadder(image1.shape)
35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
36 |
37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
39 |
40 | if warm_start:
41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
42 |
43 | output_dir = os.path.join(output_path, dstype, sequence)
44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
45 |
46 | if not os.path.exists(output_dir):
47 | os.makedirs(output_dir)
48 |
49 | frame_utils.writeFlow(output_file, flow)
50 | sequence_prev = sequence
51 |
52 |
53 | @torch.no_grad()
54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
55 | """ Create submission for the Sintel leaderboard """
56 | model.eval()
57 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
58 |
59 | if not os.path.exists(output_path):
60 | os.makedirs(output_path)
61 |
62 | for test_id in range(len(test_dataset)):
63 | image1, image2, (frame_id, ) = test_dataset[test_id]
64 | padder = InputPadder(image1.shape, mode='kitti')
65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
66 |
67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
69 |
70 | output_filename = os.path.join(output_path, frame_id)
71 | frame_utils.writeFlowKITTI(output_filename, flow)
72 |
73 |
74 | @torch.no_grad()
75 | def validate_chairs(model, iters=24):
76 | """ Perform evaluation on the FlyingChairs (test) split """
77 | model.eval()
78 | epe_list = []
79 |
80 | val_dataset = datasets.FlyingChairs(split='validation')
81 | for val_id in range(len(val_dataset)):
82 | image1, image2, flow_gt, _ = val_dataset[val_id]
83 | image1 = image1[None].cuda()
84 | image2 = image2[None].cuda()
85 |
86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
88 | epe_list.append(epe.view(-1).numpy())
89 |
90 | epe = np.mean(np.concatenate(epe_list))
91 | print("Validation Chairs EPE: %f" % epe)
92 | return {'chairs': epe}
93 |
94 |
95 | @torch.no_grad()
96 | def validate_sintel(model, iters=32):
97 | """ Peform validation using the Sintel (train) split """
98 | model.eval()
99 | results = {}
100 | for dstype in ['clean', 'final']:
101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
102 | epe_list = []
103 |
104 | for val_id in range(len(val_dataset)):
105 | image1, image2, flow_gt, _ = val_dataset[val_id]
106 | image1 = image1[None].cuda()
107 | image2 = image2[None].cuda()
108 |
109 | padder = InputPadder(image1.shape)
110 | image1, image2 = padder.pad(image1, image2)
111 |
112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
113 | flow = padder.unpad(flow_pr[0]).cpu()
114 |
115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
116 | epe_list.append(epe.view(-1).numpy())
117 |
118 | epe_all = np.concatenate(epe_list)
119 | epe = np.mean(epe_all)
120 | px1 = np.mean(epe_all<1)
121 | px3 = np.mean(epe_all<3)
122 | px5 = np.mean(epe_all<5)
123 |
124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
125 | results[dstype] = np.mean(epe_list)
126 |
127 | return results
128 |
129 |
130 | @torch.no_grad()
131 | def validate_kitti(model, iters=24):
132 | """ Peform validation using the KITTI-2015 (train) split """
133 | model.eval()
134 | val_dataset = datasets.KITTI(split='training')
135 |
136 | out_list, epe_list = [], []
137 | for val_id in range(len(val_dataset)):
138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
139 | image1 = image1[None].cuda()
140 | image2 = image2[None].cuda()
141 |
142 | padder = InputPadder(image1.shape, mode='kitti')
143 | image1, image2 = padder.pad(image1, image2)
144 |
145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
146 | flow = padder.unpad(flow_pr[0]).cpu()
147 |
148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
149 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
150 |
151 | epe = epe.view(-1)
152 | mag = mag.view(-1)
153 | val = valid_gt.view(-1) >= 0.5
154 |
155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
156 | epe_list.append(epe[val].mean().item())
157 | out_list.append(out[val].cpu().numpy())
158 |
159 | epe_list = np.array(epe_list)
160 | out_list = np.concatenate(out_list)
161 |
162 | epe = np.mean(epe_list)
163 | f1 = 100 * np.mean(out_list)
164 |
165 | print("Validation KITTI: %f, %f" % (epe, f1))
166 | return {'kitti-epe': epe, 'kitti-f1': f1}
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument('--model', help="restore checkpoint")
172 | parser.add_argument('--dataset', help="dataset for evaluation")
173 | parser.add_argument('--small', action='store_true', help='use small model')
174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
175 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
176 | args = parser.parse_args()
177 |
178 | model = torch.nn.DataParallel(RAFT(args))
179 | model.load_state_dict(torch.load(args.model))
180 |
181 | model.cuda()
182 | model.eval()
183 |
184 | # create_sintel_submission(model.module, warm_start=True)
185 | # create_kitti_submission(model.module)
186 |
187 | with torch.no_grad():
188 | if args.dataset == 'chairs':
189 | validate_chairs(model.module)
190 |
191 | elif args.dataset == 'sintel':
192 | validate_sintel(model.module)
193 |
194 | elif args.dataset == 'kitti':
195 | validate_kitti(model.module)
196 |
197 |
198 |
--------------------------------------------------------------------------------
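All four validation routines above reduce to the same end-point error (EPE) metric: the per-pixel L2 distance between predicted and ground-truth flow. A self-contained sketch on synthetic tensors:

import torch

flow_pr = torch.zeros(2, 4, 4)  # predicted flow, shape [2, H, W]
flow_gt = torch.ones(2, 4, 4)   # ground truth, same shape
epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()
print(epe.mean().item())        # sqrt(1^2 + 1^2) ~= 1.414 at every pixel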
/RAFT/models/raft-things.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/RAFT/models/raft-things.pth
--------------------------------------------------------------------------------
/RAFT/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | from raft import RAFT
19 | import evaluate
20 | import datasets
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | try:
25 | from torch.cuda.amp import GradScaler
26 | except ImportError:
27 | # dummy GradScaler for PyTorch < 1.6
28 | class GradScaler:
29 |         def __init__(self, enabled=False):  # accept the enabled kwarg passed by train()
30 | pass
31 | def scale(self, loss):
32 | return loss
33 | def unscale_(self, optimizer):
34 | pass
35 | def step(self, optimizer):
36 | optimizer.step()
37 | def update(self):
38 | pass
39 |
40 |
41 | # exclude extremely large displacements
42 | MAX_FLOW = 400
43 | SUM_FREQ = 100
44 | VAL_FREQ = 5000
45 |
46 |
47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
48 | """ Loss function defined over sequence of flow predictions """
49 |
50 | n_predictions = len(flow_preds)
51 | flow_loss = 0.0
52 |
53 |     # exclude invalid pixels and extremely large displacements
54 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
55 | valid = (valid >= 0.5) & (mag < max_flow)
56 |
57 | for i in range(n_predictions):
58 | i_weight = gamma**(n_predictions - i - 1)
59 | i_loss = (flow_preds[i] - flow_gt).abs()
60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
61 |
62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
63 | epe = epe.view(-1)[valid.view(-1)]
64 |
65 | metrics = {
66 | 'epe': epe.mean().item(),
67 | '1px': (epe < 1).float().mean().item(),
68 | '3px': (epe < 3).float().mean().item(),
69 | '5px': (epe < 5).float().mean().item(),
70 | }
71 |
72 | return flow_loss, metrics
73 |
74 |
75 | def count_parameters(model):
76 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
77 |
78 |
79 | def fetch_optimizer(args, model):
80 | """ Create the optimizer and learning rate scheduler """
81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
82 |
83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
85 |
86 | return optimizer, scheduler
87 |
88 |
89 | class Logger:
90 | def __init__(self, model, scheduler):
91 | self.model = model
92 | self.scheduler = scheduler
93 | self.total_steps = 0
94 | self.running_loss = {}
95 | self.writer = None
96 |
97 | def _print_training_status(self):
98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
101 |
102 | # print the training status
103 | print(training_str + metrics_str)
104 |
105 | if self.writer is None:
106 | self.writer = SummaryWriter()
107 |
108 | for k in self.running_loss:
109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
110 | self.running_loss[k] = 0.0
111 |
112 | def push(self, metrics):
113 | self.total_steps += 1
114 |
115 | for key in metrics:
116 | if key not in self.running_loss:
117 | self.running_loss[key] = 0.0
118 |
119 | self.running_loss[key] += metrics[key]
120 |
121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
122 | self._print_training_status()
123 | self.running_loss = {}
124 |
125 | def write_dict(self, results):
126 | if self.writer is None:
127 | self.writer = SummaryWriter()
128 |
129 | for key in results:
130 | self.writer.add_scalar(key, results[key], self.total_steps)
131 |
132 | def close(self):
133 | self.writer.close()
134 |
135 |
136 | def train(args):
137 |
138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
139 | print("Parameter Count: %d" % count_parameters(model))
140 |
141 | if args.restore_ckpt is not None:
142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
143 |
144 | model.cuda()
145 | model.train()
146 |
147 | if args.stage != 'chairs':
148 | model.module.freeze_bn()
149 |
150 | train_loader = datasets.fetch_dataloader(args)
151 | optimizer, scheduler = fetch_optimizer(args, model)
152 |
153 | total_steps = 0
154 | scaler = GradScaler(enabled=args.mixed_precision)
155 | logger = Logger(model, scheduler)
156 |
157 | VAL_FREQ = 5000
158 | add_noise = True
159 |
160 | should_keep_training = True
161 | while should_keep_training:
162 |
163 | for i_batch, data_blob in enumerate(train_loader):
164 | optimizer.zero_grad()
165 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
166 |
167 | if args.add_noise:
168 | stdv = np.random.uniform(0.0, 5.0)
169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
171 |
172 | flow_predictions = model(image1, image2, iters=args.iters)
173 |
174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
175 | scaler.scale(loss).backward()
176 | scaler.unscale_(optimizer)
177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
178 |
179 | scaler.step(optimizer)
180 | scheduler.step()
181 | scaler.update()
182 |
183 | logger.push(metrics)
184 |
185 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
187 | torch.save(model.state_dict(), PATH)
188 |
189 | results = {}
190 | for val_dataset in args.validation:
191 | if val_dataset == 'chairs':
192 | results.update(evaluate.validate_chairs(model.module))
193 | elif val_dataset == 'sintel':
194 | results.update(evaluate.validate_sintel(model.module))
195 | elif val_dataset == 'kitti':
196 | results.update(evaluate.validate_kitti(model.module))
197 |
198 | logger.write_dict(results)
199 |
200 | model.train()
201 | if args.stage != 'chairs':
202 | model.module.freeze_bn()
203 |
204 | total_steps += 1
205 |
206 | if total_steps > args.num_steps:
207 | should_keep_training = False
208 | break
209 |
210 | logger.close()
211 | PATH = 'checkpoints/%s.pth' % args.name
212 | torch.save(model.state_dict(), PATH)
213 |
214 | return PATH
215 |
216 |
217 | if __name__ == '__main__':
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument('--name', default='raft', help="name your experiment")
220 | parser.add_argument('--stage', help="determines which dataset to use for training")
221 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
222 | parser.add_argument('--small', action='store_true', help='use small model')
223 | parser.add_argument('--validation', type=str, nargs='+')
224 |
225 | parser.add_argument('--lr', type=float, default=0.00002)
226 | parser.add_argument('--num_steps', type=int, default=100000)
227 | parser.add_argument('--batch_size', type=int, default=6)
228 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
229 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
230 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
231 |
232 | parser.add_argument('--iters', type=int, default=12)
233 | parser.add_argument('--wdecay', type=float, default=.00005)
234 | parser.add_argument('--epsilon', type=float, default=1e-8)
235 | parser.add_argument('--clip', type=float, default=1.0)
236 | parser.add_argument('--dropout', type=float, default=0.0)
237 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
238 | parser.add_argument('--add_noise', action='store_true')
239 | args = parser.parse_args()
240 |
241 | torch.manual_seed(1234)
242 | np.random.seed(1234)
243 |
244 | if not os.path.isdir('checkpoints'):
245 | os.mkdir('checkpoints')
246 |
247 | train(args)
--------------------------------------------------------------------------------
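sequence_loss above weights the i-th of n refinement iterates by gamma**(n - i - 1), so the final prediction always gets weight 1.0 and earlier iterates decay geometrically. A tiny sketch of that weighting:

def gamma_weights(n_predictions, gamma=0.8):
    # Same exponent as in sequence_loss: later predictions count more.
    return [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]

print(gamma_weights(4))  # [0.512, 0.64, 0.8, 1.0] (up to float rounding)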
/RAFT/train_mixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
7 |
--------------------------------------------------------------------------------
/RAFT/train_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
7 |
--------------------------------------------------------------------------------
/config/EMDiffuse-n-big.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "EMDiffuse-r", // experiments name
3 | "norm": true,
4 | "percent": false,
5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0
6 | "seed" : -1, // random seed, seed <0 represents randomization not used
7 | "finetune_norm": false, // find the parameters to optimize
8 | "task" : "denoise",
9 | "path": { //set every part file path
10 | "base_dir": "experiments", // base path for all log except resume_state
11 | "code": "code", // code backup
12 | "tb_logger": "tb_logger", // path of tensorboard logger
13 | "results": "results",
14 | "checkpoint": "checkpoint",
15 | "resume_state": "experiments/train_EMDiffuse-n-large_240125_221819/5180" // checkpoint path, set to null if used for training
16 | // "resume_state": "experiments/EMDiffuse-n/2720" // checkpoint path, set to null if used for training
17 | // "resume_state": null // checkpoint path, set to null if used for training
18 | },
19 |
20 | "datasets": { // train or test
21 | "train": {
22 | "which_dataset": { // import designated dataset using arguments
23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
24 | "args":{ // arguments to initialize dataset
25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf",
26 | "data_len": -1,
27 | "norm": true,
28 | "percent": false,
29 | "image_size": [768, 768]
30 | }
31 | },
32 | "dataloader":{
33 | "validation_split": 2, // percent or number
34 | "args":{ // arguments to initialize train_dataloader
35 | "batch_size": 3, // batch size in each gpu
36 | "num_workers": 4,
37 | "shuffle": true,
38 | "pin_memory": true,
39 | "drop_last": true
40 | },
41 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
42 | "batch_size": 1, // batch size in each gpu
43 | "num_workers": 4,
44 | "shuffle": false,
45 | "pin_memory": true,
46 | "drop_last": false
47 | }
48 | }
49 | },
50 | "test": {
51 | "which_dataset": {
52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file
53 | "args":{
54 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/test_wf",
55 | "norm":true,
56 | "percent": false,
57 | "phase": "val",
58 | "image_size": [768, 768]
59 | }
60 | },
61 | "dataloader":{
62 | "args":{
63 | "batch_size": 8,
64 | "num_workers": 0,
65 | "pin_memory": true
66 | }
67 | }
68 | }
69 | },
70 |
71 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict
72 | "which_model": { // import designated model(trainer) using arguments
73 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
74 | "args": {
75 | "sample_num": 1, // process of each image
76 | "task": "denoise",
77 | "ema_scheduler": {
78 | "ema_start": 1,
79 | "ema_iter": 1,
80 | "ema_decay": 0.9999
81 | },
82 | "optimizers": [
83 | { "lr": 5e-5, "weight_decay": 0}
84 | ]
85 | }
86 | },
87 | "which_networks": [ // import designated list of networks using arguments
88 | {
89 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py])
90 | "args": { // arguments to initialize network
91 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming
92 | "module_name": "guided_diffusion", // sr3 | guided_diffusion
93 | "norm": true,
94 | "unet": {
95 | "in_channel": 2,
96 | "out_channel": 1,
97 | "inner_channel": 32,
98 | "channel_mults": [
99 | 1,
100 | 2,
101 | 4,
102 | 8
103 | ],
104 | "attn_res": [
105 | // 32,
106 | 16
107 | // 8
108 | ],
109 | "num_head_channels": 32,
110 | "res_blocks": 2,
111 | "dropout": 0.2,
112 | "image_size": 256
113 | },
114 | "beta_schedule": {
115 | "train": {
116 | "schedule": "linear",
117 | "n_timestep": 2000,
118 | // "n_timestep": 5, // debug
119 | "linear_start": 1e-6,
120 | "linear_end": 0.01
121 | },
122 | "test": {
123 | "schedule": "linear",
124 | "n_timestep": 500,
125 | "linear_start": 1e-4,
126 | "linear_end": 0.09
127 | }
128 |
129 | }
130 | }
131 | }
132 | ],
133 | "which_losses": [ // import designated list of losses without arguments
134 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
135 | ],
136 | "which_metrics": [ // import designated list of metrics without arguments
137 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
138 | ]
139 | },
140 |
141 | "train": { // arguments for basic training
142 | "n_epoch": 1e8, // max epochs, not limited now
143 | "n_iter": 1e8, // max interations
144 | "val_epoch": 20, // valdation every specified number of epochs
145 | "save_checkpoint_epoch": 20,
146 | "log_iter": 1e4, // log every specified number of iterations
147 | "tensorboard" : true // tensorboardX enable
148 | },
149 |
150 | "debug": { // arguments in debug mode, which will replace arguments in train
151 | "val_epoch": 1,
152 | "save_checkpoint_epoch": 1,
153 | "log_iter": 10,
154 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split.
155 | }
156 | }
157 |
--------------------------------------------------------------------------------
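These config files are JSON with //-style line comments, which the standard json module rejects. The repository parses them in core/praser.py; below is a minimal stand-in that strips the comments first (an assumption for illustration, not the project's own loader):

import json
import re

def load_commented_json(path):
    # Drop everything from '//' to end of line before parsing. Naive: it
    # would also clip '//' inside string values, which these configs avoid.
    with open(path) as f:
        text = re.sub(r'//.*', '', f.read())
    return json.loads(text)

cfg = load_commented_json('config/EMDiffuse-n-big.json')
print(cfg['model']['which_model']['name'])  # ['models.EMDiffuse_model', 'DiReP']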
/config/EMDiffuse-n-transfer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "EMDiffuse-n-NCtransfer_unfinetune", // experiments name
3 | "norm": true,
4 | "percent": false,
5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0
6 | "seed" : -1, // random seed, seed <0 represents randomization not used
7 | "finetune_norm": false, // find the parameters to optimize
8 | "task" : "denoise",
9 | "path": { //set every part file path
10 | "base_dir": "experiments", // base path for all log except resume_state
11 | "code": "code", // code backup
12 | "tb_logger": "tb_logger", // path of tensorboard logger
13 | "results": "results",
14 | "checkpoint": "checkpoint",
15 | "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training
16 | // "resume_state": "experiments/train_EMDiffuse-n-NCtransfer_231025_151833/checkpoint/3700" // checkpoint path, set to null if usedkua for training
17 | // "resume_state": "experiments/EMDiffuse-n/2720" // checkpoint path, set to null if used for training
18 | // "resume_state": null // checkpoint path, set to null if used for training
19 | },
20 | "datasets": { // train or test
21 | "train": {
22 | "which_dataset": { // import designated dataset using arguments
23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
24 | "args":{ // arguments to initialize dataset
25 | "data_root": "/data/cxlu/transfer/NC/denoise/train_wf",
26 | "data_len": -1,
27 | "norm": true,
28 | "percent": false
29 | }
30 | },
31 | "dataloader":{
32 | "validation_split": 20, // percent or number
33 | "args":{ // arguments to initialize train_dataloader
34 | "batch_size": 3, // batch size in each gpu
35 | "num_workers": 4,
36 | "shuffle": true,
37 | "pin_memory": true,
38 | "drop_last": true
39 | },
40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
41 | "batch_size": 10, // batch size in each gpu
42 | "num_workers": 4,
43 | "shuffle": false,
44 | "pin_memory": true,
45 | "drop_last": false
46 | }
47 | }
48 | },
49 | "test": {
50 | "which_dataset": {
51 |
52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file
53 | "args":{
54 | "data_root": "/data/cxlu/transfer/NC/denoise_test/test_wf",
55 | "norm":true,
56 | "percent": false,
57 | "phase": "val"
58 | }
59 | },
60 | "dataloader":{
61 | "args":{
62 | "batch_size": 8,
63 | "num_workers": 4,
64 | "pin_memory": true
65 | }
66 | }
67 | }
68 | },
69 |
70 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict
71 | "which_model": { // import designated model(trainer) using arguments
72 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
73 | "args": {
74 | "sample_num": 8, // process of each image
75 | "task": "denoise",
76 | "ema_scheduler": {
77 | "ema_start": 1,
78 | "ema_iter": 1,
79 | "ema_decay": 0.9999
80 | },
81 | "optimizers": [
82 | { "lr": 5e-5, "weight_decay": 0}
83 | ]
84 | }
85 | },
86 | "which_networks": [ // import designated list of networks using arguments
87 | {
88 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py])
89 | "args": { // arguments to initialize network
90 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming
91 | "module_name": "guided_diffusion", // sr3 | guided_diffusion
92 | "norm": true,
93 | "unet": {
94 | "in_channel": 2,
95 | "out_channel": 1,
96 | "inner_channel": 32,
97 | "channel_mults": [
98 | 1,
99 | 2,
100 | 4,
101 | 8
102 | ],
103 | "attn_res": [
104 | // 32,
105 | 16
106 | // 8
107 | ],
108 | "num_head_channels": 32,
109 | "res_blocks": 2,
110 | "dropout": 0.2,
111 | "image_size": 256
112 | },
113 | "beta_schedule": {
114 | "train": {
115 | "schedule": "linear",
116 | "n_timestep": 2000,
117 | // "n_timestep": 5, // debug
118 | "linear_start": 1e-6,
119 | "linear_end": 0.01
120 | },
121 | "test": {
122 | "schedule": "linear",
123 | "n_timestep": 1000,
124 | "linear_start": 1e-4,
125 | "linear_end": 0.09
126 | }
127 |
128 | }
129 | }
130 | }
131 | ],
132 | "which_losses": [ // import designated list of losses without arguments
133 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
134 | ],
135 | "which_metrics": [ // import designated list of metrics without arguments
136 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
137 | ]
138 | },
139 |
140 | "train": { // arguments for basic training
141 | "n_epoch": 1e8, // max epochs, not limited now
142 | "n_iter": 1e8, // max interations
143 | "val_epoch": 100, // valdation every specified number of epochs
144 | "save_checkpoint_epoch": 20,
145 | "log_iter": 1e4, // log every specified number of iterations
146 | "tensorboard" : true // tensorboardX enable
147 | },
148 |
149 | "debug": { // arguments in debug mode, which will replace arguments in train
150 | "val_epoch": 1,
151 | "save_checkpoint_epoch": 1,
152 | "log_iter": 10,
153 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split.
154 | }
155 | }
156 |
--------------------------------------------------------------------------------
/config/EMDiffuse-n.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "EMDiffuse-n", // experiments name
3 | "norm": true,
4 | "percent": false,
5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0
6 | "seed" : -1, // random seed, seed <0 represents randomization not used
7 | "finetune_norm": false, // find the parameters to optimize
8 | "task" : "denoise",
9 | "path": { //set every part file path
10 | "base_dir": "experiments", // base path for all log except resume_state
11 | "code": "code", // code backup
12 | "tb_logger": "tb_logger", // path of tensorboard logger
13 | "results": "results",
14 | "checkpoint": "checkpoint",
15 | // "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training
16 | "resume_state": "experiments/EMDiffuse-n/best" // checkpoint path, set to null if used for training
17 | // "resume_state": null // checkpoint path, set to null if used for training
18 | },
19 |
20 | "datasets": { // train or test
21 | "train": {
22 | "which_dataset": { // import designated dataset using arguments
23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
24 | "args":{ // arguments to initialize dataset
25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf",
26 | "data_len": -1,
27 | "norm": true,
28 | "percent": false
29 | }
30 | },
31 | "dataloader":{
32 | "validation_split": 0.1, // percent or number
33 | "args":{ // arguments to initialize train_dataloader
34 | "batch_size": 3, // batch size in each gpu
35 | "num_workers": 4,
36 | "shuffle": true,
37 | "pin_memory": true,
38 | "drop_last": true
39 | },
40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
41 | "batch_size": 10, // batch size in each gpu
42 | "num_workers": 4,
43 | "shuffle": false,
44 | "pin_memory": true,
45 | "drop_last": false
46 | }
47 | }
48 | },
49 | "test": {
50 | "which_dataset": {
51 |
52 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file
53 | "args":{
54 | "data_root": "/data/cxlu/denoise_single",
55 | "norm":true,
56 | "percent": false,
57 | "phase": "val"
58 | }
59 | },
60 | "dataloader":{
61 | "args":{
62 | "batch_size": 8,
63 | "num_workers": 0,
64 | "pin_memory": true
65 | }
66 | }
67 | }
68 | },
69 |
70 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict
71 | "which_model": { // import designated model(trainer) using arguments
72 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
73 | "args": {
74 | "sample_num": 8, // process of each image
75 | "task": "denoise",
76 | "ema_scheduler": {
77 | "ema_start": 1,
78 | "ema_iter": 1,
79 | "ema_decay": 0.9999
80 | },
81 | "optimizers": [
82 | { "lr": 5e-5, "weight_decay": 0}
83 | ]
84 | }
85 | },
86 | "which_networks": [ // import designated list of networks using arguments
87 | {
88 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py])
89 | "args": { // arguments to initialize network
90 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming
91 | "module_name": "guided_diffusion", // sr3 | guided_diffusion
92 | "norm": true,
93 | "unet": {
94 | "in_channel": 2,
95 | "out_channel": 1,
96 | "inner_channel": 32,
97 | "channel_mults": [
98 | 1,
99 | 2,
100 | 4,
101 | 8
102 | ],
103 | "attn_res": [
104 | // 32,
105 | 16
106 | // 8
107 | ],
108 | "num_head_channels": 32,
109 | "res_blocks": 2,
110 | "dropout": 0.2,
111 | "image_size": 256
112 | },
113 | "beta_schedule": {
114 | "train": {
115 | "schedule": "linear",
116 | "n_timestep": 2000,
117 | // "n_timestep": 5, // debug
118 | "linear_start": 1e-6,
119 | "linear_end": 0.01
120 | },
121 | "test": {
122 | "schedule": "linear",
123 | "n_timestep": 1000,
124 | "linear_start": 1e-4,
125 | "linear_end": 0.09
126 | }
127 |
128 | }
129 | }
130 | }
131 | ],
132 | "which_losses": [ // import designated list of losses without arguments
133 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
134 | ],
135 | "which_metrics": [ // import designated list of metrics without arguments
136 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
137 | ]
138 | },
139 |
140 | "train": { // arguments for basic training
141 | "n_epoch": 1e8, // max epochs, not limited now
142 | "n_iter": 1e8, // max interations
143 | "val_epoch": 20, // valdation every specified number of epochs
144 | "save_checkpoint_epoch": 20,
145 | "log_iter": 1e4, // log every specified number of iterations
146 | "tensorboard" : true // tensorboardX enable
147 | },
148 |
149 | "debug": { // arguments in debug mode, which will replace arguments in train
150 | "val_epoch": 1,
151 | "save_checkpoint_epoch": 1,
152 | "log_iter": 10,
153 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split.
154 | }
155 | }
156 |
--------------------------------------------------------------------------------
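Each beta_schedule block above requests a linear noise schedule from linear_start to linear_end over n_timestep steps. A common construction for such a schedule, shown as a sketch (the network code under models/ is the authoritative implementation):

import numpy as np

def linear_beta_schedule(n_timestep, linear_start=1e-6, linear_end=0.01):
    # Evenly spaced betas, mirroring the train-time config values above.
    return np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)

betas = linear_beta_schedule(2000)
alphas_cumprod = np.cumprod(1.0 - betas)  # cumulative signal retention per step
print(betas[0], betas[-1], alphas_cumprod[-1])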
/config/EMDiffuse-r.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "EMDiffuse-r", // experiments name
3 | "norm": true,
4 | "percent": false,
5 | "gpu_ids": [0, 1], // gpu ids list, default is single 0
6 | "seed" : -1, // random seed, seed <0 represents randomization not used
7 | "finetune_norm": false, // find the parameters to optimize
8 | "task" : "denoise",
9 | "path": { //set every part file path
10 | "base_dir": "experiments", // base path for all log except resume_state
11 | "code": "code", // code backup
12 | "tb_logger": "tb_logger", // path of tensorboard logger
13 | "results": "results",
14 | "checkpoint": "checkpoint",
15 | // "resume_state": "experiments/train_EMDiffuse-n_230712_163715/checkpoint/2720" // checkpoint path, set to null if used for training
16 | "resume_state": "experiments/EMDiffuse-n/best" // checkpoint path, set to null if used for training
17 | // "resume_state": null // checkpoint path, set to null if used for training
18 | },
19 |
20 | "datasets": { // train or test
21 | "train": {
22 | "which_dataset": { // import designated dataset using arguments
23 | "name": ["data.dataset", "EMDiffusenDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
24 | "args":{ // arguments to initialize dataset
25 | "data_root": "/data/cxlu/macos_backup/EMDiffuse_dataset/denoise/train_wf",
26 | "data_len": -1,
27 | "norm": true,
28 | "percent": false
29 | }
30 | },
31 | "dataloader":{
32 | "validation_split": 0.1, // percent or number
33 | "args":{ // arguments to initialize train_dataloader
34 | "batch_size": 3, // batch size in each gpu
35 | "num_workers": 4,
36 | "shuffle": true,
37 | "pin_memory": true,
38 | "drop_last": true
39 | },
40 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
41 | "batch_size": 10, // batch size in each gpu
42 | "num_workers": 4,
43 | "shuffle": false,
44 | "pin_memory": true,
45 | "drop_last": false
46 | }
47 | }
48 | },
49 | "test": {
50 | "which_dataset": {
51 | "name": "EMDiffusenDataset", // import Dataset() class / function(not recommend) from default file
52 | "args":{
53 | "data_root": "/data/cxlu/denoise_single",
54 | "norm":true,
55 | "percent": false,
56 | "phase": "val"
57 | }
58 | },
59 | "dataloader":{
60 | "args":{
61 | "batch_size": 8,
62 | "num_workers": 0,
63 | "pin_memory": true
64 | }
65 | }
66 | }
67 | },
68 |
69 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict
70 | "which_model": { // import designated model(trainer) using arguments
71 | "name": ["models.EMDiffuse_model", "DiReP"], // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
72 | "args": {
73 | "sample_num": 8, // process of each image
74 | "task": "denoise",
75 | "ema_scheduler": {
76 | "ema_start": 1,
77 | "ema_iter": 1,
78 | "ema_decay": 0.9999
79 | },
80 | "optimizers": [
81 | { "lr": 5e-5, "weight_decay": 0}
82 | ]
83 | }
84 | },
85 | "which_networks": [ // import designated list of networks using arguments
86 | {
87 | "name": ["models.EMDiffuse_network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/EMDiffuse_network.py])
88 | "args": { // arguments to initialize network
89 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming
90 | "module_name": "guided_diffusion", // sr3 | guided_diffusion
91 | "norm": true,
92 | "unet": {
93 | "in_channel": 2,
94 | "out_channel": 1,
95 | "inner_channel": 32,
96 | "channel_mults": [
97 | 1,
98 | 2,
99 | 4,
100 | 8
101 | ],
102 | "attn_res": [
103 | // 32,
104 | 16
105 | // 8
106 | ],
107 | "num_head_channels": 32,
108 | "res_blocks": 2,
109 | "dropout": 0.2,
110 | "image_size": 256
111 | },
112 | "beta_schedule": {
113 | "train": {
114 | "schedule": "linear",
115 | "n_timestep": 2000,
116 | // "n_timestep": 5, // debug
117 | "linear_start": 1e-6,
118 | "linear_end": 0.01
119 | },
120 | "test": {
121 | "schedule": "linear",
122 | "n_timestep": 1000,
123 | "linear_start": 1e-4,
124 | "linear_end": 0.09
125 | }
126 |
127 | }
128 | }
129 | }
130 | ],
131 | "which_losses": [ // import designated list of losses without arguments
132 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
133 | ],
134 | "which_metrics": [ // import designated list of metrics without arguments
135 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
136 | ]
137 | },
138 |
139 | "train": { // arguments for basic training
140 | "n_epoch": 1e8, // max epochs, not limited now
141 | "n_iter": 1e8, // max interations
142 | "val_epoch": 20, // valdation every specified number of epochs
143 | "save_checkpoint_epoch": 20,
144 | "log_iter": 1e4, // log every specified number of iterations
145 | "tensorboard" : true // tensorboardX enable
146 | },
147 |
148 | "debug": { // arguments in debug mode, which will replace arguments in train
149 | "val_epoch": 1,
150 | "save_checkpoint_epoch": 1,
151 | "log_iter": 10,
152 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split.
153 | }
154 | }
155 |
--------------------------------------------------------------------------------
/config/vEMDiffuse-a.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "vEMDiffuse-a",
3 | // experiments name
4 | "norm": true,
5 | "percent": false,
6 | "gpu_ids": [
7 | 0,
8 | 1
9 | ],
10 | // gpu ids list, default is single 0
11 | "seed": -1,
12 |     // random seed; a value < 0 disables fixed seeding
13 | "finetune_norm": false,
14 | // find the parameters to optimize
15 | "task": "3d_reconstruction",
16 | "path": {
17 |     // set the file path for every output component
18 | "base_dir": "experiments",
19 | // base path for all log except resume_state
20 | "code": "code",
21 | // code backup
22 | "tb_logger": "tb_logger",
23 | // path of tensorboard logger
24 | "results": "results",
25 | "checkpoint": "checkpoint",
26 | // "resume_state": "experiments/emdiffusie-a-phlep/4860"
27 | "resume_state": "experiments/vEMDiffuse-a/best"
28 |
29 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration
30 | },
31 | "datasets": {
32 | // train or test
33 | "train": {
34 | "which_dataset": {
35 | // import designated dataset using arguments
36 | "name": [
37 | "data.dataset",
38 | "vEMDiffuseTrainingDatasetVolume"
39 | ],
40 | // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
41 | "args": {
42 | // arguments to initialize dataset
43 | "data_root": "/data/cxlu/phelps_test_patches_6144/",
44 | "data_len": -1,
45 | "norm": true,
46 | "percent": false,
47 | "z_times": 10,
48 | "method": "vEMDiffuse-a",
49 | "image_size": [256, 256]
50 | }
51 | },
52 | "dataloader": {
53 | "validation_split": 20,
54 | // percent or number
55 | "args": {
56 | // arguments to initialize train_dataloader
57 | "batch_size": 3,
58 | // batch size in each gpu
59 | "num_workers": 2,
60 | "shuffle": true,
61 | "pin_memory": false,
62 | "drop_last": true
63 | },
64 | "val_args": {
65 | // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
66 | "batch_size": 10,
67 | // batch size in each gpu
68 | "num_workers": 2,
69 | "shuffle": false,
70 | "pin_memory": false,
71 | "drop_last": false
72 | }
73 | }
74 | },
75 | "test": {
76 | "which_dataset": {
77 | "name": "vEMDiffuseTestAnIsotropic",
78 | // import Dataset() class / function(not recommend) from default file
79 | "args": {
80 | // "data_root": "/data/cxlu/phelps_test_patches_2048/",
81 | "data_root": "/mnt/sdb/cxlu/phelps_test_patches_6144/",
82 | "norm": true,
83 | "percent": false,
84 | "phase": "val",
85 | "z_times": 10
86 | }
87 | },
88 | "dataloader": {
89 | "args": {
90 | "batch_size": 8,
91 | "num_workers": 0,
92 | "pin_memory": true
93 | }
94 | }
95 | }
96 | },
97 | "model": {
98 | // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict
99 | "which_model": {
100 | // import designated model(trainer) using arguments
101 | "name": [
102 | "models.vEMDiffuse_model",
103 | "DiReP"
104 | ],
105 | // import Model() class / function(not recommend) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
106 | "args": {
107 | "sample_num": 8,
108 | // number of intermediate samples recorded per image
109 | "task": "3d_reconstruct",
110 | "ema_scheduler": {
111 | "ema_start": 1,
112 | "ema_iter": 1,
113 | "ema_decay": 0.9999
114 | },
115 | "optimizers": [
116 | {
117 | "lr": 5e-5,
118 | "weight_decay": 0
119 | }
120 | ]
121 | }
122 | },
123 | "which_networks": [
124 | // import designated list of networks using arguments
125 | {
126 | "name": [
127 | "models.vEMDiffuse_network",
128 | "Network"
129 | ],
130 | // import Network() class / function (not recommended) from default file (default is [models/EMDiffuse_network.py])
131 | "args": {
132 | // arguments to initialize network
133 | "init_type": "kaiming",
134 | // method can be [normal | xavier | xavier_uniform | kaiming | orthogonal], default is kaiming
135 | "module_name": "guided_diffusion_3d_2d",
136 | // sr3 | guided_diffusion
137 | "norm": true,
138 | "unet": {
139 | "in_channel": 3,
140 | "out_channel": 1,
141 | "inner_channel": 32,
142 | "channel_mults": [
143 | 1,
144 | 2,
145 | 4,
146 | 8
147 | ],
148 | "attn_res": [
149 | // 32,
150 | 16
151 | // 8
152 | ],
153 | "num_head_channels": 32,
154 | "res_blocks": 2,
155 | "dropout": 0.2,
156 | "image_size": 256
157 | },
158 | "beta_schedule": {
159 | "train": {
160 | "schedule": "linear",
161 | "n_timestep": 2000,
162 | // "n_timestep": 5, // debug
163 | "linear_start": 1e-6,
164 | "linear_end": 0.01
165 | },
166 | "test": {
167 | "schedule": "linear",
168 | "n_timestep": 1000,
169 | // "n_timestep": 5, // debug
170 | "linear_start": 1e-4,
171 | "linear_end": 0.09
172 | }
173 | }
174 | }
175 | }
176 | ],
177 | "which_losses": [
178 | // import designated list of losses without arguments
179 | "mse_loss"
180 | // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
181 | ],
182 | "which_metrics": [
183 | // import designated list of metrics without arguments
184 | "mae"
185 | // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
186 | ]
187 | },
188 | "train": {
189 | // arguments for basic training
190 | "n_epoch": 1e8,
191 | // max epochs; effectively unlimited
192 | "n_iter": 1e8,
193 | // max iterations
194 | "val_epoch": 20,
195 | // validation every specified number of epochs
196 | "save_checkpoint_epoch": 20,
197 | "log_iter": 1e4,
198 | // log every specified number of iterations
199 | "tensorboard": true
200 | // tensorboardX enable
201 | },
202 | "debug": {
203 | // arguments for debug mode, which override the corresponding arguments in train
204 | "val_epoch": 1,
205 | "save_checkpoint_epoch": 1,
206 | "log_iter": 10,
207 | "debug_split": 50
208 | // percent or number; shrinks the dataloader to debug_split.
209 | }
210 | }
211 |
--------------------------------------------------------------------------------
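
The `//` comments in the config above are not valid JSON; core/praser.py (listed later in this repo) strips everything after `//` on each line before parsing. A minimal standalone loader mirroring that behavior, assuming the config path exists relative to the repo root:

import json
from collections import OrderedDict

def load_commented_json(path):
    # Mirror core/praser.py parse(): drop everything after '//' on each line,
    # then parse the remainder with the standard json module.
    # Note this also truncates any value that itself contains '//' (e.g. URLs).
    json_str = ''
    with open(path, 'r') as f:
        for line in f:
            json_str += line.split('//')[0] + '\n'
    return json.loads(json_str, object_pairs_hook=OrderedDict)

opt = load_commented_json('config/vEMDiffuse-a.json')
print(opt['model']['which_model']['name'])  # ['models.vEMDiffuse_model', 'DiReP']
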
/config/vEMDiffuse-i.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "vEMDiffuse-i",
3 | // experiment name
4 | "norm": true,
5 | "percent": false,
6 | "gpu_ids": [
7 | 0,
8 | 1
9 | ],
10 | // list of gpu ids; default is a single gpu 0
11 | "seed": -1,
12 | // random seed; seed < 0 disables seeding (non-deterministic runs)
13 | "finetune_norm": false,
14 | // when finetuning, restrict which parameters are optimized
15 | "task": "3d_reconstruction",
16 | "path": {
17 | // set the file path for each output component
18 | "base_dir": "experiments",
19 | // base path for all logs except resume_state
20 | "code": "code",
21 | // code backup
22 | "tb_logger": "tb_logger",
23 | // path of tensorboard logger
24 | "results": "results",
25 | "checkpoint": "checkpoint",
26 | "resume_state": "experiments/vEMDiffuse-i/best"
27 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration
28 | },
29 | "datasets": {
30 | // train or test
31 | "train": {
32 | "which_dataset": {
33 | // import designated dataset using arguments
34 | "name": [
35 | "data.dataset",
36 | "vEMDiffuseTrainingDatasetVolume"
37 | ],
38 | // import Dataset() class / function (not recommended) from data.dataset.py (default is [data.dataset.py])
39 | "args": {
40 | // arguments to initialize dataset
41 | "data_root": "/data/cxlu/liver_3d_inter3_continous_clean/train_img",
42 | "data_len": -1,
43 | "norm": true,
44 | "percent": false,
45 | "z_times": 6,
46 | "method": "vEMDiffuse-i"
47 | }
48 | },
49 | "dataloader": {
50 | "validation_split": 20,
51 | // percent or number
52 | "args": {
53 | // arguments to initialize train_dataloader
54 | "batch_size": 3,
55 | // batch size in each gpu
56 | "num_workers": 2,
57 | "shuffle": true,
58 | "pin_memory": false,
59 | "drop_last": true
60 | },
61 | "val_args": {
62 | // arguments to initialize valid_dataloader; these override the train_dataloader parameters
63 | "batch_size": 10,
64 | // batch size in each gpu
65 | "num_workers": 2,
66 | "shuffle": false,
67 | "pin_memory": false,
68 | "drop_last": false
69 | }
70 | }
71 | },
72 | "test": {
73 | "which_dataset": {
74 | "name": "vEMDiffuseTestAnIsotropic",
75 | // import Dataset() class / function (not recommended) from default file
76 | "args": {
77 | "data_root": "/data/cxlu/liver_3d_test_patches",
78 | // "data_root": "/lustre1/g/chem_jianglab/cxlu/kai_3d/test_patches",
79 | "norm": true,
80 | "percent": false,
81 | "phase": "val",
82 | "z_times": 6
83 | }
84 | },
85 | "dataloader": {
86 | "args": {
87 | "batch_size": 8,
88 | "num_workers": 0,
89 | "pin_memory": true
90 | }
91 | }
92 | }
93 | },
94 | "model": {
95 | // networks/metrics/losses/optimizers/lr_schedulers are lists; model is a dict
96 | "which_model": {
97 | // import designated model(trainer) using arguments
98 | "name": [
99 | "models.vEMDiffuse_model",
100 | "DiReP"
101 | ],
102 | // import Model() class / function (not recommended) from models.EMDiffuse_model.py (default is [models.EMDiffuse_model.py])
103 | "args": {
104 | "sample_num": 8,
105 | // number of intermediate samples recorded per image
106 | "task": "3d_reconstruct",
107 | "ema_scheduler": {
108 | "ema_start": 1,
109 | "ema_iter": 1,
110 | "ema_decay": 0.9999
111 | },
112 | "optimizers": [
113 | {
114 | "lr": 5e-5,
115 | "weight_decay": 0
116 | }
117 | ]
118 | }
119 | },
120 | "which_networks": [
121 | // import designated list of networks using arguments
122 | {
123 | "name": [
124 | "models.vEMDiffuse_network",
125 | "Network"
126 | ],
127 | // import Network() class / function (not recommended) from default file (default is [models/EMDiffuse_network.py])
128 | "args": {
129 | // arguments to initialize network
130 | "init_type": "kaiming",
131 | // method can be [normal | xavier | xavier_uniform | kaiming | orthogonal], default is kaiming
132 | "module_name": "guided_diffusion_3d_2d",
133 | // sr3 | guided_diffusion
134 | "norm": true,
135 | "unet": {
136 | "in_channel": 3,
137 | "out_channel": 1,
138 | "inner_channel": 32,
139 | "channel_mults": [
140 | 1,
141 | 2,
142 | 4,
143 | 8
144 | ],
145 | "attn_res": [
146 | // 32,
147 | 16
148 | // 8
149 | ],
150 | "num_head_channels": 32,
151 | "res_blocks": 2,
152 | "dropout": 0.2,
153 | "image_size": 256
154 | },
155 | "beta_schedule": {
156 | "train": {
157 | "schedule": "linear",
158 | "n_timestep": 2000,
159 | // "n_timestep": 5, // debug
160 | "linear_start": 1e-6,
161 | "linear_end": 0.01
162 | },
163 | "test": {
164 | "schedule": "linear",
165 | "n_timestep": 1000,
166 | // "n_timestep": 5, // debug
167 | "linear_start": 1e-4,
168 | "linear_end": 0.09
169 | }
170 | }
171 | }
172 | }
173 | ],
174 | "which_losses": [
175 | // import designated list of losses without arguments
176 | "mse_loss"
177 | // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
178 | ],
179 | "which_metrics": [
180 | // import designated list of metrics without arguments
181 | "mae"
182 | // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}}
183 | ]
184 | },
185 | "train": {
186 | // arguments for basic training
187 | "n_epoch": 1e8,
188 | // max epochs; effectively unlimited
189 | "n_iter": 1e8,
190 | // max iterations
191 | "val_epoch": 20,
192 | // validation every specified number of epochs
193 | "save_checkpoint_epoch": 20,
194 | "log_iter": 1e4,
195 | // log every specified number of iterations
196 | "tensorboard": true
197 | // tensorboardX enable
198 | },
199 | "debug": {
200 | // arguments for debug mode, which override the corresponding arguments in train
201 | "val_epoch": 1,
202 | "save_checkpoint_epoch": 1,
203 | "log_iter": 10,
204 | "debug_split": 50
205 | // percent or number; shrinks the dataloader to debug_split.
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
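
Both vEMDiffuse configs train with a linear beta schedule (n_timestep=2000, linear_start=1e-6, linear_end=0.01). A small sketch reproducing the 'linear' branch of make_beta_schedule from models/EMDiffuse_network.py shows how little signal survives at the final timestep (values are approximate):

import numpy as np

betas = np.linspace(1e-6, 0.01, 2000, dtype=np.float64)  # the training schedule above
gammas = np.cumprod(1.0 - betas)                         # cumulative signal retention
print(gammas[0])    # ~0.999999: the first step is almost noise-free
print(gammas[-1])   # ~4e-5: the last step is essentially pure noise
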
/core/__pycache__/base_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/base_model.cpython-37.pyc
--------------------------------------------------------------------------------
/core/__pycache__/base_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/base_network.cpython-37.pyc
--------------------------------------------------------------------------------
/core/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/core/__pycache__/praser.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/praser.cpython-37.pyc
--------------------------------------------------------------------------------
/core/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/core/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/core/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from torchvision import transforms
3 | from PIL import Image
4 | import os
5 | import numpy as np
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP','.tif'
10 | ]
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14 |
15 | def make_dataset(dir):
16 | if os.path.isfile(dir):
17 | images = [i for i in np.genfromtxt(dir, dtype=str, encoding='utf-8')]  # np.str was removed from NumPy; the builtin str works everywhere
18 | else:
19 | images = []
20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
21 | for root, _, fnames in sorted(os.walk(dir)):
22 | for fname in sorted(fnames):
23 | if is_image_file(fname):
24 | path = os.path.join(root, fname)
25 | images.append(path)
26 |
27 | return images
28 |
29 | def pil_loader(path):
30 | return Image.open(path).convert('RGB')
31 |
32 | class BaseDataset(data.Dataset):
33 | def __init__(self, data_root, image_size=[256, 256], loader=pil_loader):
34 | self.imgs = make_dataset(data_root)
35 | self.tfs = transforms.Compose([
36 | transforms.Resize((image_size[0], image_size[1])),
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
39 | ])
40 | self.loader = loader
41 |
42 | def __getitem__(self, index):
43 | path = self.imgs[index]
44 | img = self.tfs(self.loader(path))
45 | return img
46 |
47 | def __len__(self):
48 | return len(self.imgs)
49 |
--------------------------------------------------------------------------------
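
A minimal usage sketch for BaseDataset, assuming a hypothetical images/ directory and the repo root on PYTHONPATH; make_dataset recursively collects files matching IMG_EXTENSIONS and the transform yields normalized 3x256x256 tensors:

from torch.utils.data import DataLoader
from core.base_dataset import BaseDataset

ds = BaseDataset('images/', image_size=[256, 256])  # hypothetical folder of .tif/.png files
loader = DataLoader(ds, batch_size=4, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 3, 256, 256])
    break
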
/core/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 | from functools import partial
4 | import collections
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | import core.util as Util
11 | CustomResult = collections.namedtuple('CustomResult', 'name result')
12 |
13 | class BaseModel():
14 |
15 | def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer):
16 | """ init model with basic input, which are from __init__(**kwargs) function in inherited class """
17 | self.opt = opt
18 | self.phase = opt['phase']
19 | self.set_device = partial(Util.set_device, rank=opt['global_rank'])
20 | self.mean = opt['mean'] if 'mean' in opt.keys() else 1
21 | ''' optimizers and schedulers '''
22 | self.schedulers = []
23 | self.optimizers = []
24 | ''' process record '''
25 | self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size']
26 | self.epoch = 0
27 | self.transfer_epoch = 0
28 | self.iter = 0
29 | self.phase_loader = phase_loader
30 | self.val_loader = val_loader
31 | self.metrics = metrics
32 |
33 | ''' logger to log file, which only work on GPU 0. writer to tensorboard and result file '''
34 | self.logger = logger
35 | self.writer = writer
36 | self.results_dict = CustomResult([],[]) # {"name":[], "result":[]}
37 |
38 | def train(self):
39 | # val_log = self.val_step()
40 | while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']:
41 | self.epoch += 1
42 | if self.opt['distributed']:
43 | ''' sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch '''
44 | self.phase_loader.sampler.set_epoch(self.epoch)
45 |
46 | train_log = self.train_step()
47 |
48 | ''' save logged information into the log dict '''
49 | print('epoch {}: training step finished'.format(self.epoch))
50 | train_log.update({'epoch': self.epoch, 'iters': self.iter})
51 |
52 | ''' print logged information to the screen and tensorboard '''
53 | for key, value in train_log.items():
54 | self.logger.info('{:5s}: {}\t'.format(str(key), value))
55 | if self.epoch < 500:
56 | if self.epoch % 100 == 0:
57 | self.logger.info('Saving the model at the end of epoch {:.0f}'.format(self.epoch))
58 | self.save_everything()
59 | if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0 and self.epoch > 500:
60 | self.logger.info('Saving the model at the end of epoch {:.0f}'.format(self.epoch))
61 | self.save_everything()
62 |
63 | if self.epoch % self.opt['train']['val_epoch'] == 0 and self.epoch > 500:
64 | self.logger.info("\n\n\n------------------------------Validation Start------------------------------")
65 | if self.val_loader is None:
66 | self.logger.warning('Validation skipped: val dataloader is None.')
67 | else:
68 | val_log = self.val_step()
69 | for key, value in val_log.items():
70 | self.logger.info('{:5s}: {}\t'.format(str(key), value))
71 | self.logger.info("\n------------------------------Validation End------------------------------\n\n")
72 | self.logger.info('Number of epochs has reached the limit; ending.')
73 |
74 | def test(self):
75 | pass
76 |
77 | @abstractmethod
78 | def train_step(self):
79 | raise NotImplementedError('You must specify how to train your networks.')
80 |
81 | @abstractmethod
82 | def val_step(self):
83 | raise NotImplementedError('You must specify how to do validation on your networks.')
84 |
85 | def test_step(self):
86 | pass
87 |
88 | def print_network(self, network):
89 | """ print network structure, only work on GPU 0 """
90 | if self.opt['global_rank'] !=0:
91 | return
92 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
93 | network = network.module
94 |
95 | s, n = str(network), sum(map(lambda x: x.numel(), network.parameters()))
96 | net_struc_str = '{}'.format(network.__class__.__name__)
97 | self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
98 | self.logger.info(s)
99 |
100 | def save_network(self, network, network_label):
101 | """ save network structure, only work on GPU 0 """
102 | if self.opt['global_rank'] !=0:
103 | return
104 | save_filename = '{}_{}.pth'.format(self.epoch, network_label)
105 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename)
106 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
107 | network = network.module
108 | state_dict = network.state_dict()
109 | for key, param in state_dict.items():
110 | state_dict[key] = param.cpu()
111 | torch.save(state_dict, save_path)
112 |
113 | def load_network(self, network, network_label, strict=True):
114 | if self.opt['path']['resume_state'] is None:
115 | return
116 | self.logger.info('Begin loading pretrained model [{:s}] ...'.format(network_label))
117 |
118 | model_path = "{}_{}.pth".format(self. opt['path']['resume_state'], network_label)
119 |
120 | if not os.path.exists(model_path):
121 | self.logger.warning('Pretrained model [{:s}] does not exist; skipping.'.format(model_path))
122 | return
123 |
124 | self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path))
125 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
126 | network = network.module
127 | network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict)
128 |
129 | def save_training_state(self):
130 | """ saves training state during training, only work on GPU 0 """
131 | if self.opt['global_rank'] !=0:
132 | return
133 |
134 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'
135 | state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []}
136 | for s in self.schedulers:
137 | state['schedulers'].append(s.state_dict())
138 | for o in self.optimizers:
139 | state['optimizers'].append(o.state_dict())
140 | save_filename = '{}.state'.format(self.epoch)
141 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename)
142 | torch.save(state, save_path)
143 |
144 | def resume_training(self):
145 | """ resume the optimizers and schedulers for training, only work when phase is test or resume training enable """
146 | if self.phase!='train' or self. opt['path']['resume_state'] is None:
147 | return
148 | self.logger.info('Begin loading training states')
149 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'
150 |
151 | state_path = "{}.state".format(self. opt['path']['resume_state'])
152 |
153 | if not os.path.exists(state_path):
154 | self.logger.warning('Training state [{:s}] does not exist; skipping.'.format(state_path))
155 | return
156 |
157 | self.logger.info('Loading training state for [{:s}] ...'.format(state_path))
158 | resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage))
159 |
160 | resume_optimizers = resume_state['optimizers']
161 | resume_schedulers = resume_state['schedulers']
162 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers))
163 | # assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers))
164 | # for i, o in enumerate(resume_optimizers):
165 | if len(resume_schedulers)== len(self.schedulers):
166 | # self.optimizers[i].load_state_dict(o)
167 | for i, s in enumerate(resume_schedulers):
168 | self.schedulers[i].load_state_dict(s)
169 |
170 | self.epoch = resume_state['epoch']
171 | self.transfer_epoch = resume_state['epoch']
172 | self.iter = resume_state['iter']
173 |
174 | def load_everything(self):
175 | pass
176 |
177 | @abstractmethod
178 | def save_everything(self):
179 | raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.')
180 |
--------------------------------------------------------------------------------
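
BaseModel owns the epoch loop in train(); subclasses only supply the hooks below. A minimal subclass sketch (this is not the repo's own trainer, which lives in models/EMDiffuse_model.py and models/vEMDiffuse_model.py; constructing an instance additionally needs the parsed opt and the dataloaders):

from core.base_model import BaseModel

class ToyModel(BaseModel):
    def train_step(self):
        # iterate self.phase_loader here and return a dict of scalar logs
        return {'loss': 0.0}

    def val_step(self):
        # run validation on self.val_loader and return a dict of scalar logs
        return {'mae': 0.0}

    def save_everything(self):
        # typically one save_network(...) call per network plus save_training_state()
        pass
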
/core/base_network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class BaseNetwork(nn.Module):
4 | def __init__(self, init_type='kaiming', gain=0.02):
5 | super(BaseNetwork, self).__init__()
6 | self.init_type = init_type
7 | self.gain = gain
8 |
9 |
10 | def init_weights(self):
11 | """
12 | initialize network's weights
13 | init_type: normal | xavier | kaiming | orthogonal
14 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
15 | """
16 |
17 | def init_func(m):
18 | classname = m.__class__.__name__
19 | if classname.find('InstanceNorm2d') != -1:
20 | if hasattr(m, 'weight') and m.weight is not None:
21 | nn.init.constant_(m.weight.data, 1.0)
22 | if hasattr(m, 'bias') and m.bias is not None:
23 | nn.init.constant_(m.bias.data, 0.0)
24 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
25 | if self.init_type == 'normal':
26 | nn.init.normal_(m.weight.data, 0.0, self.gain)
27 | elif self.init_type == 'xavier':
28 | nn.init.xavier_normal_(m.weight.data, gain=self.gain)
29 | elif self.init_type == 'xavier_uniform':
30 | nn.init.xavier_uniform_(m.weight.data, gain=1.0)
31 | elif self.init_type == 'kaiming':
32 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
33 | elif self.init_type == 'orthogonal':
34 | nn.init.orthogonal_(m.weight.data, gain=self.gain)
35 | elif self.init_type == 'none': # uses pytorch's default init method
36 | m.reset_parameters()
37 | else:
38 | raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type)
39 | if hasattr(m, 'bias') and m.bias is not None:
40 | nn.init.constant_(m.bias.data, 0.0)
41 |
42 | self.apply(init_func)
43 | # propagate to children
44 | for m in self.children():
45 | if hasattr(m, 'init_weights'):
46 | m.init_weights(self.init_type, self.gain)
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
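
A small sketch of the initialization contract, using a toy module: init_weights() applies init_func to every Conv/Linear/InstanceNorm2d submodule according to init_type, so any network deriving from BaseNetwork gets its weights set in one call.

import torch.nn as nn
from core.base_network import BaseNetwork

class TinyNet(BaseNetwork):
    def __init__(self, **kwargs):
        super(TinyNet, self).__init__(**kwargs)
        self.conv = nn.Conv2d(1, 8, 3, padding=1)

net = TinyNet(init_type='orthogonal')
net.init_weights()  # orthogonal init for the conv weight, zeros for its bias
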
/core/calibration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.optimize import brentq
4 | from scipy.stats import binom
5 |
6 |
7 | def fraction_missed_loss(lower_bound, upper_bound, ground_truth, avg_channels=True):
8 | misses = (lower_bound > ground_truth).float() + (upper_bound < ground_truth).float()
9 | misses[misses > 1.0] = 1.0
10 | if avg_channels:
11 | return misses.mean(), misses
12 | else:
13 | return misses.mean(dim=(1, 2)), misses
14 |
15 |
16 | def get_rcps_losses_from_outputs(cal_l, cal_u, ground_truth, lam):
17 | risk, misses = fraction_missed_loss(cal_l / lam, cal_u * lam, ground_truth)
18 | return risk
19 |
20 |
21 | def h1(y, mu):
22 | return y * np.log(y / mu) + (1 - y) * np.log((1 - y) / (1 - mu))
23 |
24 |
25 | ### Log tail inequalities of mean
26 | def hoeffding_plus(mu, x, n):
27 | return -n * h1(np.maximum(mu, x), mu)
28 |
29 |
30 | def bentkus_plus(mu, x, n):
31 | return np.log(max(binom.cdf(np.floor(n * x), n, mu), 1e-10)) + 1
32 |
33 |
34 | def HB_mu_plus(muhat, n, delta, maxiters=1000):
35 | def _tailprob(mu):
36 | hoeffding_mu = hoeffding_plus(mu, muhat, n)
37 | bentkus_mu = bentkus_plus(mu, muhat, n)
38 | return min(hoeffding_mu, bentkus_mu) - np.log(delta)
39 |
40 | if _tailprob(1 - 1e-10) > 0:
41 | return 1
42 | else:
43 | try:
44 | return brentq(_tailprob, muhat, 1 - 1e-10, maxiter=maxiters)
45 | except (ValueError, RuntimeError):  # brentq may fail to bracket or converge
46 | print(f"BRENTQ RUNTIME ERROR at muhat={muhat}")
47 | return 1.0
48 |
49 |
50 | def calibrate_model(cal_l, cal_u, ground_truth):
51 | alpha = 0.1
52 | delta = 0.1
53 | minimum_lambda = 0.9
54 | maximum_lambda = 1.3
55 | num_lambdas = 1000
56 |
57 | lambdas = torch.linspace(minimum_lambda, maximum_lambda, num_lambdas)
58 | dlambda = lambdas[1] - lambdas[0]
59 | lambda_hat = (lambdas[-1] + dlambda - 1e-9)
60 |
61 | for lam in reversed(lambdas):
62 | losses = get_rcps_losses_from_outputs(cal_l, cal_u, ground_truth, lam=(lam - dlambda))
63 |
64 | Rhat = losses
65 | # print(cal_l.shape)
66 | RhatPlus = HB_mu_plus(Rhat.item(), cal_l.shape[0] * cal_l.shape[1] * cal_l.shape[2] * cal_l.shape[3], delta)
67 |
68 | print(f"\rLambda: {lam:.4f} | Rhat: {Rhat:.4f} | RhatPlus: {RhatPlus:.4f} ", end='')
69 | if Rhat >= alpha or RhatPlus > alpha:
70 | lambda_hat = lam
71 | print(f"Model's lambda_hat is {lambda_hat}")
72 | break
73 | return lambda_hat, (cal_l / lambda_hat).clamp(0., 1.), (cal_u * lambda_hat).clamp(0., 1.)
74 |
--------------------------------------------------------------------------------
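
calibrate_model searches lambda in [0.9, 1.3] for the smallest multiplicative widening whose Hoeffding-Bentkus risk bound stays below alpha = delta = 0.1. A synthetic usage sketch with made-up tensors; the exact lambda_hat depends on the random draw:

import torch
from core.calibration import calibrate_model

torch.manual_seed(0)
gt = torch.rand(16, 1, 32, 32)                          # synthetic ground truth in [0, 1]
pred = (gt + 0.02 * torch.randn_like(gt)).clamp(0., 1.)
cal_l = (pred - 0.03).clamp(0., 1.)                     # lower uncertainty bound
cal_u = (pred + 0.03).clamp(0., 1.)                     # upper uncertainty bound
lam, lower, upper = calibrate_model(cal_l, cal_u, gt)   # calibrated, clamped bounds
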
/core/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import importlib
4 | from datetime import datetime
5 | import logging
6 | import pandas as pd
7 |
8 | import core.util as Util
9 |
10 | class InfoLogger():
11 | """
12 | use logging to record logs; only works on GPU 0, judged by global_rank
13 | """
14 | def __init__(self, opt):
15 | self.opt = opt
16 | self.rank = opt['global_rank']
17 | self.phase = opt['phase']
18 |
19 | self.setup_logger(None, opt['path']['experiments_root'], opt['phase'], level=logging.INFO, screen=False)
20 | self.logger = logging.getLogger(opt['phase'])
21 | self.infologger_ftns = {'info', 'warning', 'debug'}
22 |
23 | def __getattr__(self, name):
24 | if self.rank != 0: # info only print on GPU 0.
25 | def wrapper(info, *args, **kwargs):
26 | pass
27 | return wrapper
28 | if name in self.infologger_ftns:
29 | print_info = getattr(self.logger, name, None)
30 | def wrapper(info, *args, **kwargs):
31 | print_info(info, *args, **kwargs)
32 | return wrapper
33 |
34 | @staticmethod
35 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
36 | """ set up logger """
37 | l = logging.getLogger(logger_name)
38 | formatter = logging.Formatter(
39 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
40 | log_file = os.path.join(root, '{}.log'.format(phase))
41 | fh = logging.FileHandler(log_file, mode='a+')
42 | fh.setFormatter(formatter)
43 | l.setLevel(level)
44 | l.addHandler(fh)
45 | if screen:
46 | sh = logging.StreamHandler()
47 | sh.setFormatter(formatter)
48 | l.addHandler(sh)
49 |
50 | class VisualWriter():
51 | """
52 | use tensorboard to record visuals; supports 'add_scalar', 'add_scalars', 'add_image', 'add_images', etc.
53 | Also integrates a function for saving results.
54 | """
55 | def __init__(self, opt, logger):
56 | log_dir = opt['path']['tb_logger']
57 | self.result_dir = opt['path']['results']
58 | enabled = opt['train']['tensorboard']
59 | self.rank = opt['global_rank']
60 | self.task = opt['task']
61 |
62 | self.writer = None
63 | self.selected_module = ""
64 |
65 | if enabled and self.rank==0:
66 | log_dir = str(log_dir)
67 |
68 | # Retrieve the visualization writer.
69 | succeeded = False
70 | for module in ["tensorboardX", "torch.utils.tensorboard"]:
71 | try:
72 | self.writer = importlib.import_module(module).SummaryWriter(log_dir)
73 | succeeded = True
74 | break
75 | except ImportError:
76 | succeeded = False
77 | self.selected_module = module
78 |
79 | if not succeeded:
80 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
81 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
82 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
83 | logger.warning(message)
84 |
85 | self.epoch = 0
86 | self.iter = 0
87 | self.phase = ''
88 |
89 | self.tb_writer_ftns = {
90 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
91 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
92 | }
93 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
94 | self.custom_ftns = {'close'}
95 | self.timer = datetime.now()
96 |
97 | def set_iter(self, epoch, iter, phase='train'):
98 | self.phase = phase
99 | self.epoch = epoch
100 | self.iter = iter
101 |
102 | def save_images(self, results, norm=True, percent=False):
103 | result_path = os.path.join(self.result_dir, self.phase)
104 | os.makedirs(result_path, exist_ok=True)
105 | result_path = os.path.join(result_path, str(self.epoch))
106 | os.makedirs(result_path, exist_ok=True)
107 | from tifffile import imwrite
108 | import numpy as np
109 | ''' get names and corresponding images from results[OrderedDict] '''
110 | try:
111 | names = results['name']
112 | outputs = Util.postprocess(results['result'], out_type=np.uint8, min_max=(-1, 1), norm=norm)
113 | for i in range(len(names)):
114 | Image.fromarray(outputs[i]).save(os.path.join(result_path, names[i]))
115 | except Exception:
116 | raise NotImplementedError('You must specify the context of name and result in save_current_results functions of model.')
117 |
118 | def close(self):
119 | self.writer.close()
120 | print('Close the Tensorboard SummaryWriter.')
121 |
122 |
123 | def __getattr__(self, name):
124 | """
125 | If visualization is configured to use:
126 | return add_data() methods of tensorboard with additional information (step, tag) added.
127 | Otherwise:
128 | return a blank function handle that does nothing
129 | """
130 | if name in self.tb_writer_ftns:
131 | add_data = getattr(self.writer, name, None)
132 | def wrapper(tag, data, *args, **kwargs):
133 | if add_data is not None:
134 | # add phase(train/valid) tag
135 | if name not in self.tag_mode_exceptions:
136 | tag = '{}/{}'.format(self.phase, tag)
137 | add_data(tag, data, self.iter, *args, **kwargs)
138 | return wrapper
139 | else:
140 | # default action for returning attributes defined on this class, set_iter() for instance.
141 | try:
142 | attr = object.__getattribute__(self, name)
143 | except AttributeError:
144 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
145 | return attr
146 |
147 |
148 | class LogTracker:
149 | """
150 | record training numerical indicators.
151 | """
152 | def __init__(self, *keys, phase='train'):
153 | self.phase = phase
154 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
155 | self.reset()
156 |
157 | def reset(self):
158 | for col in self._data.columns:
159 | self._data[col].values[:] = 0
160 |
161 | def update(self, key, value, n=1):
162 | self._data.total[key] += value * n
163 | self._data.counts[key] += n
164 | self._data.average[key] = self._data.total[key] / self._data.counts[key]
165 |
166 | def avg(self, key):
167 | return self._data.average[key]
168 |
169 | def result(self):
170 | return {'{}/{}'.format(self.phase, k):v for k, v in dict(self._data.average).items()}
171 |
--------------------------------------------------------------------------------
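
A short usage sketch for LogTracker: update() accumulates a running sum weighted by the sample count n, and result() prefixes every key with the phase.

from core.logger import LogTracker

tracker = LogTracker('loss', 'mae', phase='train')
tracker.update('loss', 0.25, n=8)   # a batch of 8 with mean loss 0.25
tracker.update('loss', 0.15, n=8)
print(tracker.avg('loss'))          # 0.2
print(tracker.result())             # {'train/loss': 0.2, 'train/mae': 0}
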
/core/praser.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import json
4 | from pathlib import Path
5 | from datetime import datetime
6 | from functools import partial
7 | import importlib
8 | from types import FunctionType
9 | import shutil
10 |
11 |
12 | def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network',
13 | **modify_kwargs):
14 | """
15 | finds a function handle with the name given as 'name' in config,
16 | and returns the instance initialized with corresponding args.
17 | """
18 | if opt is None or len(opt) < 1:
19 | logger.info('Option is None when initializing {}'.format(init_type))
20 | return None
21 |
22 | ''' default format is dict with name key '''
23 | if isinstance(opt, str):
24 | opt = {'name': opt}
25 | logger.warning('Config is a str; converting to a dict {}'.format(opt))
26 |
27 | name = opt['name']
28 | ''' name can be list, indicates the file and class name of function '''
29 | if isinstance(name, list):
30 | file_name, class_name = name[0], name[1]
31 | else:
32 | file_name, class_name = default_file_name, name
33 |
34 | if given_module is not None:
35 | module = given_module
36 | else:
37 | module = importlib.import_module(file_name)
38 |
39 | attr = getattr(module, class_name)
40 | kwargs = opt.get('args', {})
41 | kwargs.update(modify_kwargs)
42 | ''' import class or function with args '''
43 | if isinstance(attr, type):
44 | ret = attr(*args, **kwargs)
45 | ret.__name__ = ret.__class__.__name__
46 | elif isinstance(attr, FunctionType):
47 | ret = partial(attr, *args, **kwargs)
48 | ret.__name__ = attr.__name__
49 | # ret = attr
50 | logger.info('{} [{:s}() from {:s}] is created.'.format(init_type, class_name, file_name))
51 |
52 | return ret
53 |
54 |
55 | def mkdirs(paths):
56 | if isinstance(paths, str):
57 | os.makedirs(paths, exist_ok=True)
58 | else:
59 | for path in paths:
60 | os.makedirs(path, exist_ok=True)
61 |
62 |
63 | def get_timestamp():
64 | return datetime.now().strftime('%y%m%d_%H%M%S')
65 |
66 |
67 | def write_json(content, fname):
68 | fname = Path(fname)
69 | with fname.open('wt') as handle:
70 | json.dump(content, handle, indent=4, sort_keys=False)
71 |
72 |
73 | class NoneDict(dict):
74 | def __missing__(self, key):
75 | return None
76 |
77 |
78 | def dict_to_nonedict(opt):
79 | """ convert to NoneDict, which return None for missing key. """
80 | if isinstance(opt, dict):
81 | new_opt = dict()
82 | for key, sub_opt in opt.items():
83 | new_opt[key] = dict_to_nonedict(sub_opt)
84 | return NoneDict(**new_opt)
85 | elif isinstance(opt, list):
86 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
87 | else:
88 | return opt
89 |
90 |
91 | def dict2str(opt, indent_l=1):
92 | """ dict to string for logger """
93 | msg = ''
94 | for k, v in opt.items():
95 | if isinstance(v, dict):
96 | msg += ' ' * (indent_l * 2) + k + ':[\n'
97 | msg += dict2str(v, indent_l + 1)
98 | msg += ' ' * (indent_l * 2) + ']\n'
99 | else:
100 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
101 | return msg
102 |
103 |
104 | def parse(args):
105 | json_str = ''
106 | with open(args.config, 'r') as f:
107 | for line in f:
108 | line = line.split('//')[0] + '\n'
109 | json_str += line
110 | opt = json.loads(json_str, object_pairs_hook=OrderedDict)
111 |
112 | ''' replace the config context using args '''
113 | opt['phase'] = args.phase
114 | if args.gpu is not None:
115 | opt['gpu_ids'] = [int(id) for id in args.gpu.split(',')]
116 | if args.batch is not None:
117 | opt['datasets'][opt['phase']]['dataloader']['args']['batch_size'] = args.batch
118 | if args.path is not None:
119 | opt['datasets'][opt['phase']]['which_dataset']['args']['data_root'] = args.path
120 | if args.z_times is not None:
121 | opt['datasets'][opt['phase']]['which_dataset']['args']['z_times'] = args.z_times
122 | if args.lr is not None:
123 | opt['model']['which_model']['args']['optimizers'][0]['lr'] = args.lr
124 | if args.step is not None:
125 | opt['model']['which_networks'][0]['args']['beta_schedule'][opt['phase']]['n_timestep'] = args.step
126 | ''' set cuda environment '''
127 | if len(opt['gpu_ids']) > 1:
128 | opt['distributed'] = True
129 | else:
130 | opt['distributed'] = False
131 |
132 | ''' update name '''
133 | if args.debug:
134 | opt['name'] = 'debug_{}'.format(opt['name'])
135 | elif opt['finetune_norm']:
136 | opt['name'] = 'finetune_{}'.format(opt['name'])
137 | else:
138 | opt['name'] = '{}_{}'.format(opt['phase'], opt['name'])
139 |
140 | ''' set log directory '''
141 | experiments_root = os.path.join(opt['path']['base_dir'], '{}_{}'.format(opt['name'], get_timestamp()))
142 | mkdirs(experiments_root)
143 | print('results and model will be saved in {}'.format(experiments_root))
144 | ''' save json '''
145 | write_json(opt, '{}/config.json'.format(experiments_root))
146 |
147 | ''' change folder relative hierarchy '''
148 | opt['path']['experiments_root'] = experiments_root
149 | for key, path in opt['path'].items():
150 | if 'resume' not in key and 'base' not in key and 'root' not in key:
151 | opt['path'][key] = os.path.join(experiments_root, path)
152 | mkdirs(opt['path'][key])
153 | if args.resume is not None:
154 | opt['path']['resume_state'] = args.resume
155 |
156 | ''' debug mode '''
157 | if 'debug' in opt['name']:
158 | opt['train'].update(opt['debug'])
159 |
160 | ''' code backup '''
161 | for name in os.listdir('.'):
162 | if name in ['config', 'models', 'core', 'slurm', 'data']:
163 | dst = os.path.join(opt['path']['code'], name)
164 | if os.path.exists(dst):
165 | shutil.rmtree(dst)
166 | shutil.copytree(name, dst)
167 | # shutil.copytree(name,dst , ignore=shutil.ignore_patterns("*.pyc", "__pycache__"))
168 | if '.py' in name or '.sh' in name:
169 | shutil.copy(name, opt['path']['code'])
170 | opt['mean'] = args.mean
171 | return dict_to_nonedict(opt)
172 |
--------------------------------------------------------------------------------
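
dict_to_nonedict wraps the parsed config so that missing keys read as None instead of raising KeyError, which is why the rest of the code can probe optional options freely. A tiny sketch:

from core.praser import dict_to_nonedict

opt = dict_to_nonedict({'train': {'n_epoch': 100}})
print(opt['train']['n_epoch'])    # 100
print(opt['train']['missing'])    # None, no KeyError
print(opt['no_such_section'])     # None
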
/core/util.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import math
4 | import torch
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 | from torchvision.utils import make_grid
7 | import os
8 | import cv2
9 |
10 |
11 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1), norm=True):
12 | '''
13 | Converts a torch Tensor into an image Numpy array
14 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
15 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
16 | '''
17 | tensor = tensor.clamp_(*min_max) # clamp
18 | n_dim = tensor.dim()
19 | if n_dim == 4:
20 | n_img = len(tensor)
21 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
23 | elif n_dim == 3:
24 | img_np = tensor.numpy()
25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
26 | elif n_dim == 2:
27 | img_np = tensor.numpy()
28 | else:
29 | raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
30 | img_np = ((img_np + 1) * 127.5).round()  # maps [-1, 1] to [0, 255]; assumes min_max=(-1, 1)
31 | return img_np.astype(out_type).squeeze()
32 |
33 |
34 | def postprocess(images, out_type=np.uint8, min_max=(-1, 1), norm=True):
35 | return [tensor2img(image, out_type, min_max, norm) for image in images]
36 |
37 |
38 | def normalize_tensor(tensor, min_max=(-1, 1)):
39 | tensor = tensor.float().clamp_(*min_max) # clamp
40 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
41 | return tensor
42 |
43 |
44 | def set_seed(seed, gl_seed=0):
45 | """ set random seed, gl_seed used in worker_init_fn function """
46 | if seed >= 0 and gl_seed >= 0:
47 | seed += gl_seed
48 | torch.manual_seed(seed)
49 | torch.cuda.manual_seed_all(seed)
50 | np.random.seed(seed)
51 | random.seed(seed)
52 |
53 | ''' changing the deterministic and benchmark flags may cause non-deterministic convolution behavior.
54 | speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html '''
55 | if seed >= 0 and gl_seed >= 0: # slower, more reproducible
56 | torch.backends.cudnn.deterministic = True
57 | torch.backends.cudnn.benchmark = False
58 | else: # faster, less reproducible
59 | torch.backends.cudnn.deterministic = False
60 | torch.backends.cudnn.benchmark = True
61 |
62 |
63 | def set_gpu(args, distributed=False, rank=0):
64 | """ set parameter to gpu or ddp """
65 | if args is None:
66 | return None
67 | if distributed and isinstance(args, torch.nn.Module):
68 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True,
69 | find_unused_parameters=True)
70 | else:
71 | return args.cuda()
72 |
73 |
74 | def set_device(args, distributed=False, rank=0):
75 | """ set parameter to gpu or cpu """
76 | if torch.cuda.is_available():
77 | if isinstance(args, list):
78 | return (set_gpu(item, distributed, rank) for item in args)
79 | elif isinstance(args, dict):
80 | return {key: set_gpu(args[key], distributed, rank) for key in args}
81 | else:
82 | args = set_gpu(args, distributed, rank)
83 | return args
84 |
85 |
86 |
--------------------------------------------------------------------------------
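
A minimal sketch of tensor2img on a fake single-channel image; the function clamps to min_max and then maps [-1, 1] to [0, 255] uint8:

import torch
from core.util import tensor2img

t = torch.linspace(-1, 1, 64 * 64).reshape(1, 64, 64)  # fake CxHxW image in [-1, 1]
img = tensor2img(t)
print(img.shape, img.dtype)  # (64, 64) uint8
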
/crop_single_file.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import numpy as np
5 | import cv2
6 | from tifffile import imwrite
7 |
8 | def mkdir(path):
9 | if os.path.exists(path):
10 | shutil.rmtree(path)
11 | os.mkdir(path)
12 |
13 |
14 | def crop(wf_img, save_wf_path, patch_size=256, overlap=0.125):
15 | if len(wf_img.shape) > 2:
16 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY)
17 | if wf_img.dtype == np.uint16: # convert to 8 bit
18 | cmin = wf_img.min()
19 | cmax = wf_img.max()
20 | cscale = cmax - cmin
21 | if cscale == 0:
22 | cscale = 1
23 | scale = float(255 - 0) / cscale
24 | wf_img = (wf_img - cmin) * scale + 0
25 | wf_img = np.clip(wf_img, 0, 255) + 0.5
26 | wf_img = wf_img.astype(np.uint8)
27 | stride = int(patch_size * (1 - overlap))
28 | if len(wf_img.shape) > 2:
29 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY)
30 | border = 0
31 |
32 | x = border
33 | x_end = wf_img.shape[0] - border
34 | y_end = wf_img.shape[1] - border  # second axis; using shape[0] here broke non-square images
35 | row = 0
36 | while x + patch_size < x_end:
37 | y = border
38 | col = 0
39 | while y + patch_size < y_end:
40 | crop_wf_img = wf_img[x: x + patch_size, y: y + patch_size]
41 | imwrite(os.path.join(save_wf_path, str(row) + '_' + str(col) + '.tif'),
42 | crop_wf_img)
43 | col += 1
44 | y += stride
45 | row += 1
46 | x += stride
47 |
48 |
49 | def test_pre(data_root, task='denoise'):
50 | target_path = os.path.join(data_root, task + '_test_crop_patches')
51 | mkdir(target_path)
52 | for file in os.listdir(data_root):
53 | if not file.endswith('tif'):
54 | continue
55 | mkdir(os.path.join(target_path, file[:-4]))
56 | save_wf_path = os.path.join(os.path.join(target_path, file[:-4], '0'))
57 | mkdir(save_wf_path)
58 | wf_file_img = cv2.imread(os.path.join(data_root, file))
59 | if task == 'denoise':
60 | crop(wf_file_img, save_wf_path, patch_size=256, overlap=0.125)
61 | else:
62 | crop(wf_file_img, save_wf_path, patch_size=128, overlap=0.125)
63 |
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--task', default="denoise")
68 | parser.add_argument('--path', help="dataset for evaluation")
69 | args = parser.parse_args()
70 | test_pre(args.path, args.task)
71 |
--------------------------------------------------------------------------------
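
With the denoise defaults above (patch_size=256, overlap=0.125) the stride between patch origins is 224 px. A short sketch of the resulting patch grid for a hypothetical 2048-px-wide image, mirroring the while-loop bound in crop():

patch_size, overlap = 256, 0.125
stride = int(patch_size * (1 - overlap))  # 224
width = 2048                              # hypothetical image width
cols, y = 0, 0
while y + patch_size < width:             # same bound as crop()
    cols += 1
    y += stride
print(stride, cols)  # 224 8
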
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 |
4 | from torch.utils.data.distributed import DistributedSampler
5 | from torch import Generator, randperm
6 | from torch.utils.data import DataLoader, Subset
7 |
8 | import core.util as Util
9 | from core.praser import init_obj
10 | from vEM_test_pre import recon_pre
11 |
12 |
13 |
14 | def define_dataloader(logger, opt):
15 | """ create train/test dataloader and validation dataloader, validation dataloader is None when phase is test or not GPU 0 """
16 | '''create dataset and set random seed'''
17 | dataloader_args = opt['datasets'][opt['phase']]['dataloader']['args']
18 | worker_init_fn = partial(Util.set_seed, gl_seed=opt['seed'])
19 |
20 | phase_dataset, val_dataset = define_dataset(logger, opt)
21 |
22 | '''create datasampler'''
23 | data_sampler = None
24 | if opt['distributed']:
25 | data_sampler = DistributedSampler(phase_dataset, shuffle=dataloader_args.get('shuffle', False),
26 | num_replicas=opt['world_size'], rank=opt['global_rank'])
27 | dataloader_args.update({'shuffle': False}) # sampler option is mutually exclusive with shuffle
28 | ''' create dataloader and validation dataloader '''
29 | dataloader = DataLoader(phase_dataset, sampler=data_sampler, worker_init_fn=worker_init_fn, **dataloader_args)
30 |
31 | ''' val_dataloader does not use a DistributedSampler; it runs on GPU 0 only '''
32 | if opt['global_rank'] == 0 and val_dataset is not None:
33 | dataloader_args.update(opt['datasets'][opt['phase']]['dataloader'].get('val_args', {}))
34 | val_dataloader = DataLoader(val_dataset, worker_init_fn=worker_init_fn, **dataloader_args)
35 | else:
36 | val_dataloader = None
37 | return dataloader, val_dataloader
38 |
39 |
40 | def define_dataset(logger, opt):
41 | ''' loading Dataset() class from given file's name '''
42 | dataset_opt = opt['datasets'][opt['phase']]['which_dataset']
43 | if opt['phase'] != 'train':
44 | if opt['task'] == '3d_reconstruction':
45 | dataset_opt['args']['data_root'] = recon_pre(dataset_opt['args']['data_root'])
46 |
47 | phase_dataset = init_obj(dataset_opt, logger, default_file_name='data.dataset', init_type='Dataset')
48 | val_dataset = None
49 |
50 | valid_len = 0
51 | data_len = len(phase_dataset)
52 | if 'debug' in opt['name']:
53 | debug_split = opt['debug'].get('debug_split', 1.0)
54 | if isinstance(debug_split, int):
55 | data_len = debug_split
56 | else:
57 | data_len = int(data_len * debug_split)  # debug_split may be a fraction; keep data_len an int
58 |
59 | dataloder_opt = opt['datasets'][opt['phase']]['dataloader']
60 | valid_split = dataloder_opt.get('validation_split', 0)
61 |
62 | ''' divide validation dataset, valid_split==0 when phase is test or validation_split is 0. '''
63 | if valid_split > 0.0 or 'debug' in opt['name']:
64 | if isinstance(valid_split, int):
65 | assert valid_split < data_len, "Validation set size is configured to be larger than entire dataset."
66 | valid_len = valid_split
67 | else:
68 | valid_len = int(data_len * valid_split)
69 | data_len -= valid_len
70 | phase_dataset, val_dataset = subset_split(dataset=phase_dataset, lengths=[data_len, valid_len],
71 | generator=Generator().manual_seed(opt['seed']))
72 |
73 | logger.info('Dataset for {} has {} samples.'.format(opt['phase'], data_len))
74 | if opt['phase'] == 'train':
75 | logger.info('Dataset for {} has {} samples.'.format('val', valid_len))
76 | return phase_dataset, val_dataset
77 |
78 |
79 | def subset_split(dataset, lengths, generator):
80 | """
81 | split a dataset into non-overlapping new datasets of given lengths; the main code is adapted from PyTorch's random_split
82 | """
83 | indices = randperm(sum(lengths), generator=generator).tolist()
84 | Subsets = []
85 | for offset, length in zip(np.add.accumulate(lengths), lengths):
86 | if length == 0:
87 | Subsets.append(None)
88 | else:
89 | Subsets.append(Subset(dataset, indices[offset - length: offset]))
90 | return Subsets
91 |
--------------------------------------------------------------------------------
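
subset_split is a seeded variant of torch's random_split that returns None for zero-length splits. A usage sketch, assuming the repo root is on PYTHONPATH:

import torch
from torch import Generator
from torch.utils.data import TensorDataset
from data import subset_split

ds = TensorDataset(torch.arange(100))
train, val = subset_split(ds, lengths=[80, 20], generator=Generator().manual_seed(0))
print(len(train), len(val))  # 80 20
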
/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/data/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/demo/denoise_demo.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/denoise_demo.tif
--------------------------------------------------------------------------------
/demo/microns_demo/0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/microns_demo/0.tif
--------------------------------------------------------------------------------
/demo/microns_demo/1.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/microns_demo/1.tif
--------------------------------------------------------------------------------
/demo/mouse_liver_demo/0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/mouse_liver_demo/0.tif
--------------------------------------------------------------------------------
/demo/mouse_liver_demo/1.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/mouse_liver_demo/1.tif
--------------------------------------------------------------------------------
/demo/super_res_demo.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/demo/super_res_demo.tif
--------------------------------------------------------------------------------
/emdiffuse_conifg.py:
--------------------------------------------------------------------------------
1 | class EMDiffuseConfig():
2 |
3 | def __init__(self, config, path, phase, batch_size, lr=5e-5, resume=None, gpu='0', subsample=None, port='21012', mean=2, step=None):
4 | self.path = path
5 | self.config = config
6 | self.phase = phase
7 | self.batch = batch_size
8 | self.gpu = gpu
9 | self.debug = False
10 | self.z_times = subsample
11 | self.port = port
12 | self.resume = resume
13 | self.mean = mean
14 | self.lr = lr
15 | self.step = step
16 |
17 | def __getattr__(self, item):
18 | # This method is called when an attribute access is attempted.
19 | try:
20 | return self.__dict__[item]
21 | except KeyError:
22 | return None
23 |
24 | def __setattr__(self, key, value):
25 | # This method allows setting attributes directly.
26 | self.__dict__[key] = value
27 |
28 | def __contains__(self, item):
29 | # This enables the use of 'in' to check for attribute existence.
30 | return item in self.__dict__
31 |
--------------------------------------------------------------------------------
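
EMDiffuseConfig is a plain attribute bag whose __getattr__ returns None for anything never set, matching the NoneDict convention in core/praser.py. A usage sketch (note the module file is spelled 'conifg' in the repo; './my_data' is a hypothetical path):

from emdiffuse_conifg import EMDiffuseConfig

cfg = EMDiffuseConfig(config='config/EMDiffuse-n.json', path='./my_data',
                      phase='train', batch_size=4)
print(cfg.batch)        # 4
print(cfg.z_times)      # None (subsample was not given)
print('lr' in cfg)      # True
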
/models/EMDiffuse_network.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from inspect import isfunction
4 | from functools import partial
5 | import numpy as np
6 | from tqdm import tqdm
7 | from core.base_network import BaseNetwork
8 |
9 |
10 | class Network(BaseNetwork):
11 | def __init__(self, unet, beta_schedule, norm=True, module_name='sr3', **kwargs):
12 | super(Network, self).__init__(**kwargs)
13 |
14 | from .guided_diffusion_modules.unet import UNet
15 | self.denoise_fn = UNet(**unet)
16 | self.beta_schedule = beta_schedule
17 | self.norm = norm
18 |
19 | def set_loss(self, loss_fn):
20 | self.loss_fn = loss_fn
21 |
22 | def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'):
23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # note: overrides the device argument above
24 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
25 | betas = make_beta_schedule(**self.beta_schedule[phase])
26 | betas = betas.detach().cpu().numpy() if isinstance(
27 | betas, torch.Tensor) else betas
28 | alphas = 1. - betas
29 |
30 | timesteps, = betas.shape
31 | self.num_timesteps = int(timesteps)
32 |
33 | gammas = np.cumprod(alphas, axis=0)
34 | gammas_prev = np.append(1., gammas[:-1])
35 |
36 | # calculations for diffusion q(x_t | x_{t-1}) and others
37 | self.register_buffer('gammas', to_torch(gammas))
38 | self.register_buffer('sqrt_recip_gammas', to_torch(np.sqrt(1. / gammas)))
39 | self.register_buffer('sqrt_recipm1_gammas', to_torch(np.sqrt(1. / gammas - 1)))
40 |
41 | # calculations for posterior q(x_{t-1} | x_t, x_0)
42 | posterior_variance = betas * (1. - gammas_prev) / (1. - gammas)
43 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
44 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
45 | self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(gammas_prev) / (1. - gammas)))
46 | self.register_buffer('posterior_mean_coef2', to_torch((1. - gammas_prev) * np.sqrt(alphas) / (1. - gammas)))
47 |
48 | def predict_start_from_noise(self, y_t, t, noise):
49 | return (
50 | extract(self.sqrt_recip_gammas, t, y_t.shape) * y_t -
51 | extract(self.sqrt_recipm1_gammas, t, y_t.shape) * noise
52 | )
53 |
54 | def q_posterior(self, y_0_hat, y_t, t):
55 | posterior_mean = (
56 | extract(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat +
57 | extract(self.posterior_mean_coef2, t, y_t.shape) * y_t
58 | )
59 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, y_t.shape)
60 | return posterior_mean, posterior_log_variance_clipped
61 |
62 | def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None):
63 | noise_level = extract(self.gammas, t, x_shape=(1, 1)).to(y_t.device)
64 | y_0_hat = self.predict_start_from_noise(
65 | y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level))
66 |
67 | if clip_denoised: # todo: clip
68 | if self.norm:
69 | y_0_hat.clamp_(-1., 1.)
70 | else:
71 | y_0_hat.clamp_(0., 1.)
72 |
73 | model_mean, posterior_log_variance = self.q_posterior(
74 | y_0_hat=y_0_hat, y_t=y_t, t=t)
75 | return model_mean, posterior_log_variance, y_0_hat
76 |
77 | def q_sample(self, y_0, sample_gammas, noise=None):
78 | noise = default(noise, lambda: torch.randn_like(y_0))
79 | return (
80 | sample_gammas.sqrt() * y_0 +
81 | (1 - sample_gammas).sqrt() * noise
82 | )
83 |
84 | @torch.no_grad()
85 | def p_sample(self, y_t, t, clip_denoised=True, y_cond=None, adjust=False):
86 | model_mean, model_log_variance, y_0_hat = self.p_mean_variance(
87 | y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond)
88 |
89 | noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t)
90 | if adjust:
91 | if t[0] < (self.num_timesteps * 0.2):
92 | mean_diff = model_mean.view(model_mean.size(0), -1).mean(1) - y_cond.view(y_cond.size(0), -1).mean(1)
93 | mean_diff = mean_diff.view(model_mean.size(0), 1, 1, 1)
94 | model_mean = model_mean - 0.5 * mean_diff.repeat(
95 | (1, model_mean.shape[1], model_mean.shape[2], model_mean.shape[3]))
96 | return model_mean + noise * (0.5 * model_log_variance).exp(), y_0_hat
97 |
98 | @torch.no_grad()
99 | def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8, adjust=False):
100 | b, *_ = y_cond.shape
101 |
102 | assert self.num_timesteps > sample_num, 'num_timesteps must be greater than sample_num'
103 | sample_inter = (self.num_timesteps // sample_num)
104 | if y_0 is not None:
105 | y_t = default(y_t, lambda: torch.randn_like(y_0))
106 | else:
107 | y_t = default(y_t, lambda: torch.randn_like(y_cond))
108 | ret_arr = y_t
109 |
110 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
111 | t = torch.full((b,), i, device=y_cond.device, dtype=torch.long)
112 | y_t, y_0_hat = self.p_sample(y_t, t, y_cond=y_cond, adjust=adjust)
113 | if mask is not None:
114 | y_t = y_0 * (1. - mask) + mask * y_t
115 | if i % sample_inter == 0:
116 | ret_arr = torch.cat([ret_arr, y_0_hat], dim=0)
117 | return y_t, ret_arr
118 |
119 | def forward(self, y_0, y_cond=None, mask=None, noise=None):
120 | # sampling from p(gammas)
121 | b, _, _, _ = y_0.shape
122 | t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long()
123 | gamma_t1 = extract(self.gammas, t - 1, x_shape=(1, 1))
124 | sqrt_gamma_t2 = extract(self.gammas, t, x_shape=(1, 1))
125 | sample_gammas = (sqrt_gamma_t2 - gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1  # sample a noise level uniformly between consecutive gammas (continuous noise-level conditioning)
126 | sample_gammas = sample_gammas.view(b, -1)
127 | if noise is None:
128 | noise = torch.randn_like(y_0)
129 | # noise = default(noise, lambda: torch.randn_like(y_0))
130 | y_noisy = self.q_sample(
131 | y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise)
132 |
133 | if mask is not None:
134 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy * mask + (1. - mask) * y_0], dim=1), sample_gammas)
135 | loss = self.loss_fn(mask * noise, mask * noise_hat)
136 | else:
137 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas)
138 | loss = self.loss_fn(noise_hat, noise)
139 | return loss
140 |
141 |
142 | # helper functions for the gaussian diffusion trainer
143 | def exists(x):
144 | return x is not None
145 |
146 |
147 | def default(val, d):
148 | if exists(val):
149 | return val
150 | return d() if isfunction(d) else d
151 |
152 |
153 | def extract(a, t, x_shape=(1, 1, 1, 1)):
154 | b, *_ = t.shape
155 | out = a.gather(-1, t)
156 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
157 |
158 |
159 | # beta_schedule function
160 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
161 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
162 | warmup_time = int(n_timestep * warmup_frac)
163 | betas[:warmup_time] = np.linspace(
164 | linear_start, linear_end, warmup_time, dtype=np.float64)
165 | return betas
166 |
167 |
168 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3):
169 | if schedule == 'quad':
170 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
171 | n_timestep, dtype=np.float64) ** 2
172 | elif schedule == 'linear':
173 | betas = np.linspace(linear_start, linear_end,
174 | n_timestep, dtype=np.float64)
175 | elif schedule == 'warmup10':
176 |
177 | betas = _warmup_beta(linear_start, linear_end,
178 | n_timestep, 0.1)
179 | elif schedule == 'warmup50':
180 | betas = _warmup_beta(linear_start, linear_end,
181 | n_timestep, 0.5)
182 | elif schedule == 'const':
183 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
184 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
185 | betas = 1. / np.linspace(n_timestep,
186 | 1, n_timestep, dtype=np.float64)
187 | elif schedule == "cosine":
188 | timesteps = (
189 | torch.arange(n_timestep + 1, dtype=torch.float64) /
190 | n_timestep + cosine_s
191 | )
192 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
193 | alphas = torch.cos(alphas).pow(2)
194 | alphas = alphas / alphas[0]
195 | betas = 1 - alphas[1:] / alphas[:-1]
196 | betas = betas.clamp(max=0.999)
197 | else:
198 | raise NotImplementedError(schedule)
199 | return betas
200 |
--------------------------------------------------------------------------------
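
The forward-process pieces above (make_beta_schedule, extract, q_sample) compose in a straightforward way. A minimal sketch, assuming the 'linear' schedule branch and an illustrative [-1, 1]-normalized batch (all shapes and values here are hypothetical, not taken from the configs):

import torch
import numpy as np

n_timestep = 2000
betas = np.linspace(1e-6, 1e-2, n_timestep, dtype=np.float64)  # the 'linear' branch above
alphas = 1.0 - betas
gammas = torch.from_numpy(np.cumprod(alphas, axis=0)).float()  # cumulative product, as self.gammas

y_0 = torch.randn(4, 1, 256, 256)                              # stand-in image batch
t = torch.randint(1, n_timestep, (4,))
sample_gammas = gammas[t].view(-1, 1, 1, 1)

noise = torch.randn_like(y_0)
y_noisy = sample_gammas.sqrt() * y_0 + (1 - sample_gammas).sqrt() * noise  # mirrors q_sample

The reverse pass in restoration then walks t = T-1 down to 0, calling p_sample at each step to gradually denoise y_t toward the restored image.
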
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from core.praser import init_obj
2 | import torch
3 | import warnings
4 | from core.logger import VisualWriter, InfoLogger
5 | import core.praser as Praser
6 | import core.util as Util
7 | from data import define_dataloader
8 |
9 | def create_model(**cfg_model):
10 | """ create_model """
11 | opt = cfg_model['opt']
12 | logger = cfg_model['logger']
13 |
14 | model_opt = opt['model']['which_model']
15 | model_opt['args'].update(cfg_model)
16 | model = init_obj(model_opt, logger, default_file_name='models.model', init_type='Model')
17 |
18 | return model
19 |
20 |
21 | def define_network(logger, opt, network_opt):
22 | """ define network with weights initialization """
23 | net = init_obj(network_opt, logger, default_file_name='models.network', init_type='Network')
24 |
25 | if opt['phase'] == 'train':
26 | logger.info('Network [{}] weights initialized using [{:s}] method.'.format(net.__class__.__name__,
27 | network_opt['args'].get('init_type',
28 | 'default')))
29 | net.init_weights()
30 | return net
31 |
32 |
33 | def define_loss(logger, loss_opt):
34 | return init_obj(loss_opt, logger, default_file_name='models.loss', init_type='Loss')
35 |
36 |
37 | def define_metric(logger, metric_opt):
38 | return init_obj(metric_opt, logger, default_file_name='models.metric', init_type='Metric')
39 |
40 |
41 | def create_EMDiffuse(opt):
42 | gpu = 0
43 | if 'local_rank' not in opt:
44 | opt['local_rank'] = opt['global_rank'] = gpu
45 | if opt['distributed']:
46 | torch.cuda.set_device(int(opt['local_rank']))
47 | print('using GPU {} for training'.format(int(opt['local_rank'])))
48 | torch.distributed.init_process_group(backend='nccl',
49 | init_method=opt['init_method'],
50 | world_size=opt['world_size'],
51 | rank=opt['global_rank'],
52 | group_name='mtorch'
53 | )
54 | '''set seed and cuDNN environment'''
55 | torch.backends.cudnn.enabled = False
56 | # warnings.warn('You have chosen to use cudnn for accleration. torch.backends.cudnn.enabled=True')
57 | Util.set_seed(opt['seed'])
58 |
59 | ''' set logger '''
60 | phase_logger = InfoLogger(opt)
61 | phase_writer = VisualWriter(opt, phase_logger)
62 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root']))
63 |
64 | '''set networks and dataset'''
65 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test.
66 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']]
67 |
68 | ''' set metrics, loss, optimizer and schedulers '''
69 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']]
70 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']]
71 |
72 | model = create_model(
73 | opt=opt,
74 | networks=networks,
75 | phase_loader=phase_loader,
76 | val_loader=val_loader,
77 | losses=losses,
78 | metrics=metrics,
79 | logger=phase_logger,
80 | writer=phase_writer
81 | )
82 | return model
83 |
84 |
--------------------------------------------------------------------------------
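
create_model and define_network above both defer to core.praser.init_obj, which is not shown in this dump. A hedged stand-in of the dynamic-construction idea it implements (the real signature and config layout may differ; this is only a sketch of the pattern):

import importlib

def init_obj_sketch(obj_opt, default_file_name):
    # obj_opt is assumed to look like {'name': [module, class], 'args': {...}};
    # fall back to default_file_name when only a class name is given.
    name = obj_opt['name']
    module_name, class_name = name if isinstance(name, list) else (default_file_name, name)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**obj_opt.get('args', {}))

With this pattern, the JSON configs in config/ can swap networks, losses, and metrics without any code changes.
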
/models/__pycache__/EMDiffuse_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/__pycache__/EMDiffuse_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/EMDiffuse_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/__pycache__/EMDiffuse_network.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/models/guided_diffusion_modules/__pycache__/nn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/guided_diffusion_modules/__pycache__/nn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/guided_diffusion_modules/__pycache__/unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/guided_diffusion_modules/__pycache__/unet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/guided_diffusion_modules/__pycache__/unet_jit2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luchixiang/EMDiffuse/79c8045a7b5db882156ae84689158e4f4ddef322/models/guided_diffusion_modules/__pycache__/unet_jit2.cpython-37.pyc
--------------------------------------------------------------------------------
/models/guided_diffusion_modules/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class GroupNorm32(nn.GroupNorm):
12 | def forward(self, x):
13 | return super().forward(x.float()).type(x.dtype)
14 |
15 |
16 | def zero_module(module):
17 | """
18 | Zero out the parameters of a module and return it.
19 | """
20 | for p in module.parameters():
21 | p.detach().zero_()
22 | return module
23 |
24 |
25 | def scale_module(module, scale):
26 | """
27 | Scale the parameters of a module and return it.
28 | """
29 | for p in module.parameters():
30 | p.detach().mul_(scale)
31 | return module
32 |
33 |
34 | def mean_flat(tensor):
35 | """
36 | Take the mean over all non-batch dimensions.
37 | """
38 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
39 |
40 |
41 | def normalization(channels, group_num):
42 | """
43 | Make a standard normalization layer.
44 |
45 | :param channels: number of input channels; group_num: number of GroupNorm groups.
46 | :return: an nn.Module for normalization.
47 | """
48 |
49 | # return GroupNorm32(group_num, channels)
50 | return nn.GroupNorm(group_num, channels) # todo: normalization changed
51 |
52 | def Layernormalization(channels):
53 | """
54 | Make a standard LayerNorm layer.
55 |
56 | :param channels: number of input channels.
57 | :return: an nn.Module for normalization.
58 | """
59 |
60 | return nn.LayerNorm(channels)
61 |
62 | def checkpoint(func, inputs, params, flag):
63 | """
64 | Evaluate a function without caching intermediate activations, allowing for
65 | reduced memory at the expense of extra compute in the backward pass.
66 |
67 | :param func: the function to evaluate.
68 | :param inputs: the argument sequence to pass to `func`.
69 | :param params: a sequence of parameters `func` depends on but does not
70 | explicitly take as arguments.
71 | :param flag: if False, disable gradient checkpointing.
72 | """
73 | if flag:
74 | args = tuple(inputs) + tuple(params)
75 | return CheckpointFunction.apply(func, len(inputs), *args)
76 | else:
77 | return func(*inputs)
78 |
79 |
80 | class CheckpointFunction(torch.autograd.Function):
81 | @staticmethod
82 | def forward(ctx, run_function, length, *args):
83 | ctx.run_function = run_function
84 | ctx.input_tensors = list(args[:length])
85 | ctx.input_params = list(args[length:])
86 | with torch.no_grad():
87 | output_tensors = ctx.run_function(*ctx.input_tensors)
88 | return output_tensors
89 |
90 | @staticmethod
91 | def backward(ctx, *output_grads):
92 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
93 | with torch.enable_grad():
94 | # Fixes a bug where the first op in run_function modifies the
95 | # Tensor storage in place, which is not allowed for detach()'d
96 | # Tensors.
97 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
98 | output_tensors = ctx.run_function(*shallow_copies)
99 | input_grads = torch.autograd.grad(
100 | output_tensors,
101 | ctx.input_tensors + ctx.input_params,
102 | output_grads,
103 | allow_unused=True,
104 | )
105 | del ctx.input_tensors
106 | del ctx.input_params
107 | del output_tensors
108 | return (None, None) + input_grads
109 |
110 |
111 | def count_flops_attn(model, _x, y):
112 | """
113 | A counter for the `thop` package to count the operations in an
114 | attention operation.
115 | Meant to be used like:
116 | macs, params = thop.profile(
117 | model,
118 | inputs=(inputs, timestamps),
119 | custom_ops={QKVAttention: QKVAttention.count_flops},
120 | )
121 | """
122 | b, c, *spatial = y[0].shape
123 | num_spatial = int(np.prod(spatial))
124 | # We perform two matmuls with the same number of ops.
125 | # The first computes the weight matrix, the second computes
126 | # the combination of the value vectors.
127 | matmul_ops = 2 * b * (num_spatial ** 2) * c
128 | model.total_ops += torch.DoubleTensor([matmul_ops])
129 |
130 |
131 | def gamma_embedding(gammas, dim: int, max_period: int = 10000):
132 | """
133 | Create sinusoidal timestep embeddings.
134 | :param gammas: a 1-D Tensor of N indices, one per batch element.
135 | These may be fractional.
136 | :param dim: the dimension of the output.
137 | :param max_period: controls the minimum frequency of the embeddings.
138 | :return: an [N x dim] Tensor of positional embeddings.
139 | """
140 | half = dim // 2
141 | freqs = torch.exp(
142 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
143 | ).to(device=gammas.device)
144 | args = gammas[:, None].float() * freqs[None]
145 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
146 | if dim % 2:
147 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
148 | return embedding
--------------------------------------------------------------------------------
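
A quick usage sketch for gamma_embedding above: one continuous noise level per batch element goes in, an [N x dim] sinusoidal embedding comes out (the values here are illustrative):

import torch
from models.guided_diffusion_modules.nn import gamma_embedding

gammas = torch.tensor([0.1, 0.5, 0.9])  # one noise level per batch element
emb = gamma_embedding(gammas, dim=64)   # cos half and sin half concatenated
print(emb.shape)                        # torch.Size([3, 64])
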
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
7 | # class mse_loss(nn.Module):
8 | # def __init__(self) -> None:
9 | # super().__init__()
10 | # self.loss_fn = nn.MSELoss()
11 | # def forward(self, output, target):
12 | # return self.loss_fn(output, target)
13 |
14 |
15 | def mse_loss(output, target):
16 | return F.mse_loss(output, target)
17 |
18 |
19 | def l1_loss(output, target):
20 | return F.l1_loss(output, target)
21 |
22 |
23 | def loss_predict_loss(out, target, pred_loss):
24 | target_loss = F.mse_loss(out, target, reduction='none')
25 | return torch.sum(target_loss) / (
26 | target.shape[0] * target.shape[1] * target.shape[2] * target.shape[3]), LossPredLoss(pred_loss,
27 | target_loss)
28 |
29 |
30 | def pin_loss(q_upper, q_lower, target):
31 | q_lo_loss = PinballLoss(0.05)
32 | q_hi_loss = PinballLoss(0.95)
33 | loss = q_lo_loss(q_lower, target) + q_hi_loss(q_upper, target)
34 | return loss
35 |
36 |
37 | def SampleLossPredLoss(input, target, margin=1.0, reduction='mean'):
38 | # input: (b, w * h)
39 |
40 | b = input.shape[0]
41 | target = target.detach()
42 | target = target.view(b, -1)
43 | target = torch.mean(target, dim=1)
44 | input = input.view(b, -1)
45 | input = torch.mean(input, dim=1)
46 | assert input.shape[0] % 2 == 0, 'the batch size is not even.'
47 | assert input.shape == input.flip(0).shape
48 | input = (input - input.flip(0))[
49 | :input.shape[0] // 2] # [l_1 - l_2B, l_2 - l_2B-1, ... , l_B - l_B+1], where batch_size = 2B
50 | target = (target - target.flip(0))[:target.shape[0] // 2]
51 | target = target.detach()
52 | one = 2 * torch.sign(torch.clamp(target, min=0)) - 1 # +1/-1 indicator of the target-difference sign (the 1(.) operation defined by the authors)
53 | if reduction == 'mean':
54 | loss = torch.sum(torch.clamp(margin - one * input, min=0))
55 | loss = loss / (input.size(0)) # Note that the size of input is already halved
56 | elif reduction == 'none':
57 | loss = torch.clamp(margin - one * input, min=0)
58 | else:
59 | raise NotImplementedError()
60 | return loss
61 |
62 |
63 | def LossPredLoss(input, target, margin=1.0, reduction='mean'):
64 | # input: (b, w * h)
65 |
66 | b = input.shape[0]
67 | target = target.view(b, -1)
68 | input = input.view(b, -1)
69 | assert input.shape[1] % 2 == 0, 'the number of elements per sample is not even.'
70 | assert input.shape == input.flip(1).shape
71 | index_shuffle = torch.randperm(input.shape[1])
72 |
73 | input = input[:, index_shuffle]
74 | target = target[:, index_shuffle]
75 | input = (input - input.flip(1))[:,
76 | :input.shape[1] // 2] # [l_1 - l_2B, l_2 - l_2B-1, ... , l_B - l_B+1], where batch_size = 2B
77 | target = (target - target.flip(1))[:, :target.shape[1] // 2]
78 | target = target.detach()
79 | one = 2 * torch.sign(torch.clamp(target, min=0)) - 1 # +1/-1 indicator of the target-difference sign (the 1(.) operation defined by the authors)
80 | if reduction == 'mean':
81 | loss = torch.sum(torch.clamp(margin - one * input, min=0))
82 | loss = loss / (input.size(0) * input.size(1)) # Note that the size of input is already halved
83 | elif reduction == 'none':
84 | loss = torch.clamp(margin - one * input, min=0)
85 | else:
86 | raise NotImplementedError()
87 | return loss
88 |
89 |
90 | def pin_loss2(q_lower, q_upper, out, target):
91 | q_lo_loss = PinballLoss(0.05)
92 | q_hi_loss = PinballLoss(0.95)
93 | loss = q_lo_loss(q_lower, target) + q_hi_loss(q_upper, target) + mse_loss(out, target)
94 | return loss
95 |
96 |
97 | def mse_var_loss(output, target, variance, weight=1):
98 | variance = weight * variance
99 | loss1 = torch.mul(torch.exp(-variance), (output - target) ** 2)
100 | loss2 = variance
101 | loss = .5 * (loss1 + loss2)
102 | return loss.mean()
103 |
104 | def mse_var_loss2(output, target, variance, var_weight):
105 | # print((1-var_weight).max(), (1-var_weight).min())
106 | variance = variance * torch.clamp(var_weight, min=1e-2, max=1)
107 | loss1 = torch.mul(torch.exp(-variance), (output - target) ** 2)
108 | loss2 = variance
109 | loss = .5 * (loss1 + loss2)
110 | return loss.mean()
111 |
112 |
113 | def mse_var_loss_sample(output, target, variance, weight=1):
114 | # variance = 4 * variance
115 | target_loss = (output - target) ** 2
116 | loss1 = torch.mul(torch.exp(-variance), target_loss)
117 | loss2 = variance
118 | loss3 = SampleLossPredLoss(variance, target_loss, reduction='mean')
119 | var_loss = .5 * (loss1 + loss2)
120 |
121 | return var_loss.mean() + loss3
122 |
123 |
124 |
125 | class MSE_VAR(nn.Module):
126 | def __init__(self, var_weight):
127 | super(MSE_VAR, self).__init__()
128 | self.var_weight = var_weight
129 |
130 | def forward(self, results, label):
131 | mean, var = results['mean'], results['var']
132 | var = self.var_weight * var
133 |
134 | loss1 = torch.mul(torch.exp(-var), (mean - label) ** 2)
135 | loss2 = var
136 | loss = .5 * (loss1 + loss2)
137 | return loss.mean()
138 |
139 |
140 | class PinballLoss():
141 |
142 | def __init__(self, quantile=0.10, reduction='mean'):
143 | self.quantile = quantile
144 | assert 0 < self.quantile
145 | assert self.quantile < 1
146 | self.reduction = reduction
147 |
148 | def __call__(self, output, target):
149 | assert output.shape == target.shape
150 | loss = torch.zeros_like(target, dtype=torch.float)
151 | error = output - target
152 | smaller_index = error < 0
153 | bigger_index = 0 < error
154 | loss[smaller_index] = self.quantile * (abs(error)[smaller_index])
155 | loss[bigger_index] = (1 - self.quantile) * (abs(error)[bigger_index])
156 |
157 | if self.reduction == 'sum':
158 | loss = loss.sum()
159 | if self.reduction == 'mean':
160 | loss = loss.mean()
161 |
162 | return loss
163 |
164 |
165 | class FocalLoss(nn.Module):
166 | def __init__(self, gamma=2, alpha=None, size_average=True):
167 | super(FocalLoss, self).__init__()
168 | self.gamma = gamma
169 | self.alpha = alpha
170 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
171 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
172 | self.size_average = size_average
173 |
174 | def forward(self, input, target):
175 | if input.dim() > 2:
176 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W
177 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
178 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
179 | target = target.view(-1, 1)
180 |
181 | logpt = F.log_softmax(input, dim=1)  # explicit dim; input is 2-D (samples, classes) at this point
182 | logpt = logpt.gather(1, target)
183 | logpt = logpt.view(-1)
184 | pt = Variable(logpt.data.exp())
185 |
186 | if self.alpha is not None:
187 | if self.alpha.type() != input.data.type():
188 | self.alpha = self.alpha.type_as(input.data)
189 | at = self.alpha.gather(0, target.data.view(-1))
190 | logpt = logpt * Variable(at)
191 |
192 | loss = -1 * (1 - pt) ** self.gamma * logpt
193 | if self.size_average:
194 | return loss.mean()
195 | else:
196 | return loss.sum()
197 |
--------------------------------------------------------------------------------
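
A worked example of PinballLoss above: with quantile q, under-predictions (output below target) are weighted by q and over-predictions by (1 - q), so minimizing the loss pushes the output toward the q-th quantile of the target distribution. A tiny check with made-up tensors:

import torch
from models.loss import PinballLoss

target = torch.tensor([1.0, 1.0, 1.0, 1.0])
output = torch.tensor([0.5, 0.5, 1.5, 1.5])
loss_fn = PinballLoss(quantile=0.05)
# errors: [-0.5, -0.5, 0.5, 0.5]
# under-predictions cost 0.05 * 0.5 each, over-predictions 0.95 * 0.5 each
print(loss_fn(output, target))  # tensor(0.2500)

This is why pin_loss pairs a 0.05 and a 0.95 quantile head: together they bound a 90% prediction interval.
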
/models/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 | from scipy.stats import entropy
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.nn import functional as F
8 | from torchvision.models.inception import inception_v3
9 |
10 |
11 | def mae(input, target):
12 | with torch.no_grad():
13 | loss = nn.L1Loss()
14 | output = loss(input, target)
15 | return output
16 |
17 |
18 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
19 | """Computes the inception score of the generated images imgs
20 |
21 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
22 | cuda -- whether or not to run on GPU
23 | batch_size -- batch size for feeding into Inception v3
24 | splits -- number of splits
25 | """
26 | N = len(imgs)
27 |
28 | assert batch_size > 0
29 | assert N > batch_size
30 |
31 | # Set up dtype
32 | if cuda:
33 | dtype = torch.cuda.FloatTensor
34 | else:
35 | if torch.cuda.is_available():
36 | print("WARNING: You have a CUDA device, so you should probably set cuda=True")
37 | dtype = torch.FloatTensor
38 |
39 | # Set up dataloader
40 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
41 |
42 | # Load inception model
43 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
44 | inception_model.eval()
45 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
46 | def get_pred(x):
47 | if resize:
48 | x = up(x)
49 | x = inception_model(x)
50 | return F.softmax(x, dim=1).data.cpu().numpy()
51 |
52 | # Get predictions
53 | preds = np.zeros((N, 1000))
54 |
55 | for i, batch in enumerate(dataloader, 0):
56 | batch = batch.type(dtype)
57 | batchv = Variable(batch)
58 | batch_size_i = batch.size()[0]
59 |
60 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
61 |
62 | # Now compute the mean kl-div
63 | split_scores = []
64 |
65 | for k in range(splits):
66 | part = preds[k * (N // splits): (k+1) * (N // splits), :]
67 | py = np.mean(part, axis=0)
68 | scores = []
69 | for i in range(part.shape[0]):
70 | pyx = part[i, :]
71 | scores.append(entropy(pyx, py))
72 | split_scores.append(np.exp(np.mean(scores)))
73 |
74 | return np.mean(split_scores), np.std(split_scores)
--------------------------------------------------------------------------------
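
The score returned above is exp of the average per-image KL(p(y|x) || p(y)), estimated per split. A tiny standalone check of that formula with made-up predictions (not real Inception outputs):

import numpy as np
from scipy.stats import entropy

preds = np.array([[0.9, 0.1],
                  [0.1, 0.9]])        # p(y|x) for two hypothetical images
py = preds.mean(axis=0)               # marginal p(y) = [0.5, 0.5]
kls = [entropy(p, py) for p in preds] # entropy(p, q) computes KL(p || q)
print(np.exp(np.mean(kls)))           # ~1.44 for this toy case
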
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.13.0
2 | torchvision>=0.14.0
3 | matplotlib
4 | tensorboard
5 | scipy
6 | tifffile
7 | opencv-python
8 | pandas
9 | imutils
10 | image_registration
11 | numpy>=1.23.0
12 | pytest
13 | warmup_scheduler
14 | tqdm
15 | imagecodecs
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | import torch
5 | import torch.multiprocessing as mp
6 |
7 | from core.logger import VisualWriter, InfoLogger
8 | import core.praser as Praser
9 | import core.util as Util
10 | from data import define_dataloader
11 | from models import create_model, define_network, define_loss, define_metric
12 |
13 |
14 | def main_worker(gpu, ngpus_per_node, opt):
15 | """ threads running on each GPU """
16 | if 'local_rank' not in opt:
17 | opt['local_rank'] = opt['global_rank'] = gpu
18 | if opt['distributed']:
19 | torch.cuda.set_device(int(opt['local_rank']))
20 | print('using GPU {} for training'.format(int(opt['local_rank'])))
21 | torch.distributed.init_process_group(backend='nccl',
22 | init_method=opt['init_method'],
23 | world_size=opt['world_size'],
24 | rank=opt['global_rank'],
25 | group_name='mtorch'
26 | )
27 | '''set seed and cuDNN environment'''
28 | torch.backends.cudnn.enabled = False
29 | warnings.warn('cuDNN acceleration is disabled: torch.backends.cudnn.enabled = False')
30 | Util.set_seed(opt['seed'])
31 |
32 | ''' set logger '''
33 | phase_logger = InfoLogger(opt)
34 | phase_writer = VisualWriter(opt, phase_logger)
35 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root']))
36 |
37 | '''set networks and dataset'''
38 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test.
39 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']]
40 |
41 | ''' set metrics, loss, optimizer and schedulers '''
42 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']]
43 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']]
44 |
45 | model = create_model(
46 | opt=opt,
47 | networks=networks,
48 | phase_loader=phase_loader,
49 | val_loader=val_loader,
50 | losses=losses,
51 | metrics=metrics,
52 | logger=phase_logger,
53 | writer=phase_writer
54 | )
55 |
56 | phase_logger.info('Begin model {}.'.format(opt['phase']))
57 |
58 | if opt['phase'] == 'train':
59 | model.train()
60 | else:
61 | model.test()
62 |
63 | phase_writer.close()
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('-c', '--config', type=str, default='config/EMDiffuse-n.json',
69 | help='JSON file for configuration')
70 | parser.add_argument('--path', type=str, default=None, help='path of the cropped patches')
71 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], help='Run train or test', default='train')
72 | parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size in every gpu')
73 | parser.add_argument('--gpu', type=str, default=None, help='the gpu devices used')
74 | parser.add_argument('-d', '--debug', action='store_true')
75 | parser.add_argument('-z', '--z_times', default=None, type=int, help='The anisotropy factor of the volume EM')
76 | parser.add_argument('-P', '--port', default='21012', type=str)
77 | parser.add_argument('--mean', type=int, default=2,
78 | help='EMDiffuse samples plausible solutions from the learned distribution; this sets the '
79 | 'number of samples to generate and average')
80 | parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate')
81 | parser.add_argument('--step', type=int, default=None, help='Steps of the diffusion process; more steps '
82 | 'lead to better image quality')
83 | parser.add_argument('--resume', type=str, default=None,
84 | help='Resume state path and load epoch number e.g., experiments/EMDiffuse-n/2720')
85 |
86 | ''' parser configs '''
87 | args = parser.parse_args()
88 |
89 | opt = Praser.parse(args)
90 |
91 | ''' cuda devices '''
92 | gpu_str = ','.join(str(x) for x in opt['gpu_ids'])
93 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
94 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str))
95 |
96 | ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training'''
97 | # [Todo]: multi GPU on multi machine
98 | if opt['distributed']:
99 | ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count()
100 | opt['world_size'] = ngpus_per_node
101 | opt['init_method'] = 'tcp://127.0.0.1:' + args.port
102 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
103 | else:
104 | opt['world_size'] = 1
105 | main_worker(0, 1, opt)
106 |
--------------------------------------------------------------------------------
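
Putting the flags above together, an illustrative training invocation might look like the line below (the dataset path is hypothetical, and parsing of --gpu into opt['gpu_ids'] happens inside core.praser, which is not shown here):

python run.py -c config/EMDiffuse-n.json -p train --path ./denoise_train_crop_patches -b 16 --gpu 0,1

When opt['gpu_ids'] holds more than one device, main_worker is spawned once per GPU via mp.spawn with a tcp:// rendezvous on the given port; otherwise it runs in-process on a single GPU.
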
/test_pre.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import cv2
6 | from tifffile import imwrite
7 |
8 |
9 | def mkdir(path):
10 | if os.path.exists(path):
11 | shutil.rmtree(path)
12 | os.mkdir(path)
13 |
14 |
15 | def process_denoise_pair(wf_img, save_wf_path, patch_size=256, stride=224):
16 | # print(wf_img.shape)
17 | if len(wf_img.shape) > 2:
18 | wf_img = cv2.cvtColor(wf_img, cv2.COLOR_BGR2GRAY)
19 | board = 0
20 | x = board
21 | x_end = wf_img.shape[0] - board
22 | y_end = wf_img.shape[1] - board # fixed: the original used shape[0], which assumes a square image
23 | row = 0
24 | while x + patch_size < x_end:
25 | y = board
26 | col = 0
27 | while y + patch_size < y_end:
28 | crop_wf_img = wf_img[x: x + patch_size, y: y + patch_size]
29 |
30 | imwrite(os.path.join(save_wf_path, str(row) + '_' + str(col) + '.tif'),
31 | crop_wf_img)
32 | col += 1
33 | y += stride
34 | row += 1
35 | x += stride
36 |
37 |
38 | def test_pre(data_root, task='denoise'):
39 | target_path = os.path.join(data_root, task + '_test_crop_patches')
40 | mkdir(target_path)
41 | if task == 'denoise':
42 | image_types = ['Brain__4w_04.tif', 'Brain__4w_05.tif', 'Brain__4w_06.tif', 'Brain__4w_07.tif',
43 | 'Brain__4w_08.tif']
44 | else:
45 | image_types = ['Brain__2w_01.tif', 'Brain__2w_02.tif', 'Brain__2w_03.tif']
46 | for region_index in os.listdir(data_root):
47 | if not region_index.isdigit():
48 | continue
49 | mkdir(os.path.join(target_path, region_index))
50 | for image_type in image_types:
51 | # mkdir(os.path.join(target_path, region_index, image_type))
52 | save_wf_path = os.path.join(target_path, region_index, image_type[:-4])
53 | mkdir(save_wf_path)
54 | print(os.path.join(data_root, region_index, image_type))
55 | wf_file_img = cv2.imread(os.path.join(data_root, region_index, image_type))
56 | if task == 'denoise':
57 | process_denoise_pair(wf_file_img, save_wf_path, patch_size=256, stride=224)
58 | else:
59 | process_denoise_pair(wf_file_img, save_wf_path, patch_size=128, stride=112)
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument('--task', default="denoise")
63 | parser.add_argument('--path', help="dataset for evaluation")
64 | args = parser.parse_args()
65 | test_pre(args.path, args.task)
--------------------------------------------------------------------------------
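
The tiling in process_denoise_pair uses patch_size=256 with stride=224, i.e. a 32-pixel overlap between neighbouring patches. A small sketch of how many patches that yields per axis (the 3072-pixel image size is hypothetical):

patch_size, stride, size = 256, 224, 3072
n, x = 0, 0
while x + patch_size < size:  # same bound as the loop above
    n += 1
    x += stride
print(n)  # 13 patches per axis, so a 13x13 grid for a 3072x3072 image
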
/vEM_test_pre.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tifffile import imread, imwrite
3 | import numpy as np
4 | import shutil
5 | def mkdir(path):
6 | if os.path.exists(path):
7 | shutil.rmtree(path)
8 | os.mkdir(path)
9 |
10 | def prctile_norm(x, min_prc=0, max_prc=100): # note: min_prc/max_prc are raw intensity bounds, not percentile ranks
11 | x = np.array(x).astype(np.float64)
12 | y = (x - min_prc) / (max_prc - min_prc + 1e-7)
13 | y[y > 1] = 1
14 | y[y < 0] = 0
15 | return y * 255.
16 |
17 | def recon_pre(root_path):
18 | target_path = os.path.join(root_path, 'crop_patches')
19 | mkdir(os.path.join(target_path))
20 | for file in os.listdir(root_path):
21 | if 'tif' not in file:
22 | continue
23 | x = 0
24 | # index = file.split('_')[1][:-4]
25 | index = file[:-4]
26 | patch_size = 256
27 | stride = 224
28 | img = imread(os.path.join(root_path, file))
29 | col_num = 0
30 | os.makedirs(os.path.join(target_path, index))
31 | while x + patch_size <= img.shape[0]:
32 | y = 0
33 | row_num = 0
34 | x_start = x
35 | while y + patch_size <= img.shape[1]:
36 | y_start = y
37 | patch = img[x_start:x_start + patch_size, y_start:y_start + patch_size]
38 | imwrite(os.path.join(target_path, index, str(row_num) + '_' + str(col_num) + '.tif'), patch)
39 | row_num += 1
40 | y += stride
41 | # removed a dead 'row_num += 1' here; the counter is reset at the top of the outer loop
42 | x += stride
43 | col_num += 1
44 | return target_path
45 |
46 |
--------------------------------------------------------------------------------
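
As noted above, prctile_norm takes raw intensity bounds rather than percentile ranks. A hedged sketch of percentile-based use would compute the bounds first (assuming that is the intended calling convention; the array here is synthetic):

import numpy as np
from vEM_test_pre import prctile_norm

img = np.random.randint(0, 256, (64, 64)).astype(np.float64)
lo, hi = np.percentile(img, 1), np.percentile(img, 99)
normed = prctile_norm(img, lo, hi)  # clipped to [0, 1], then scaled to [0, 255]
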
/vEMa_pre.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from tifffile import imread, imwrite
4 | import argparse
5 |
6 |
7 | def find_max_number(folder_path):
8 | max_number = 0
9 | for filename in os.listdir(folder_path):
10 | if filename.endswith('.tif'):
11 | filename = filename[:-4]
12 | if not filename.isdigit():
13 | continue
14 | number = int(filename)
15 | 
16 | # (a redundant second int() conversion was removed here)
17 | max_number = max(max_number, number)
18 | return max_number
19 |
20 |
21 | def mkdir(path):
22 | if os.path.exists(path):
23 | import shutil
24 | shutil.rmtree(path)
25 | os.mkdir(path)
26 |
27 |
28 | def vem_transpose(data_root):
29 | stacks = []
30 | z_depth = find_max_number(data_root)
31 |
32 | for i in range(z_depth + 1): # slices are named 0.tif .. z_depth.tif; range(z_depth) would drop the last one
33 | stacks.append(imread(os.path.join(data_root, f'{i}.tif')))
34 | stack = np.stack(stacks)
35 | print(stack.shape)
36 | stack = stack.transpose(1, 0, 2)
37 | target_file_path = os.path.join(data_root, 'transposed')
38 | mkdir(target_file_path)
39 | for i in range(stack.shape[0]):
40 | imwrite(os.path.join(target_file_path, str(i) + '.tif'), stack[i])
41 |
42 |
43 | if __name__ == '__main__':
44 | parser = argparse.ArgumentParser()
45 | # parser.add_argument('--task', default="denoise")
46 | parser.add_argument('--path', help="dataset for evaluation")
47 | args = parser.parse_args()
48 | vem_transpose(args.path)
49 |
--------------------------------------------------------------------------------
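
vem_transpose stacks the Z slices into a (Z, Y, X) volume and then re-slices it along Y via transpose(1, 0, 2). A shape-only sketch (the sizes are illustrative):

import numpy as np

stack = np.zeros((100, 512, 512))    # (Z, Y, X): 100 slices of 512x512
resliced = stack.transpose(1, 0, 2)  # -> (Y, Z, X): 512 slices of 100x512
print(resliced.shape)                # (512, 100, 512)
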