├── LICENSE ├── README.md ├── Stain_seperation ├── Main_Stain_Norm.py ├── stain_Norm_Vahadane.py └── stain_utils.py ├── Transfer_learning_Resnet ├── Resnet_P_hash_value.py └── Resnet_SSI.py ├── WSI_test_CycleGan_pathology_A2B_patch.py ├── datasets.py ├── evaluation_metrics ├── 3_1_0.png ├── 3_1_1.png ├── 4_4_0.png ├── 4_4_1.png ├── A ├── A_rA_Similarity.py ├── A_rB_Similarity.py ├── Neg_area_pearson.py ├── Pos_area_pearson.py ├── Resnet_P_hash_value.py └── Resnet_P_hash_value2.py ├── img ├── a.png └── b.png ├── mymodels.py ├── train_Cycle_Gan.py ├── train_Cycle_GanSSIM.py ├── train_Cycle_Gan_UnetSSIM.py ├── train_Cycle_Gan_pathology_cls.py ├── train_Cycle_Gan_pathology_seg.py ├── train_Cycle_Gan_unet.py ├── unet_utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Pathology-Consistent-Stain-Transfer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unpaired-Stain-Transfer-using-Pathology-Consistent-Constrained-Generative-Adversarial-Networks 2 | 3 | ![image](./img/a.png) 4 | Fig.3 The overview of the proposed method, which includes two generators. Each generator is composed of an encoder-decoder architecture and a pathological representation network. The pathological representation network is co-trained on the expert knowledge database and the training dataset. The expert knowledge dataset is annotated by experienced pathologists, where blue areas are cancer lesion areas and white areas are the normal tissue areas. The objective function includes: adversarial loss, cycle consistency loss, pathological consistency loss and base space aligned loss, where cycle consistency loss includes L1 loss and structural similarity constraint SSIM loss. 5 | 6 | Result: 7 | ![image](./img/b.png) 8 | 9 | Please cite: 10 | Liu, Shuting, et al. "Unpaired Stain Transfer Using Pathology-Consistent Constrained Generative Adversarial Networks." IEEE Transactions on Medical Imaging 40.8 (2021): 1977-1989. 
def assure_path_exists(path):
    """Create directory ``path`` (and any missing parents) if nothing exists there.

    :param path: directory path to create
    :raises OSError: if the directory cannot be created
    """
    if not os.path.exists(path):
        # exist_ok closes the window between the check above and the creation
        # (another process may create the directory in between).
        os.makedirs(path, exist_ok=True)


def stain_Norm(target_tiles_list, source_tiles_list, output_path, re_fit=False):
    """Stain-normalise every source tile and write the results to ``output_path``.

    :param target_tiles_list: paths of reference tiles; only used when ``re_fit`` is true
    :param source_tiles_list: paths of the tiles to normalise
    :param output_path: output directory (created if missing); each result is
        written under the source tile's own file name
    :param re_fit: when true, re-estimate the target stain matrix from
        ``target_tiles_list`` before transforming; otherwise the normalizer's
        built-in matrix is used
    :return: None
    """
    assure_path_exists(output_path)
    if len(target_tiles_list) == 0 or len(source_tiles_list) == 0:
        print('the soucer or target path is invalidate')
        return

    start_time = time.time()
    print('Start stain normalization')
    norm = Norm_Vahadane.normalizer()
    if re_fit:
        norm.fit(target_tiles_list)
        print('done: target stain matrix')
    for img in source_tiles_list:
        # basename() also copes with platform-specific separators
        img_name = os.path.basename(img)
        transformed_img = norm.transform(utils.read_image(img))
        # transform() yields RGB; OpenCV writes BGR
        cv2.imwrite(os.path.join(output_path, img_name), cv2.cvtColor(transformed_img, cv2.COLOR_RGB2BGR))
    elapsed = (time.time() - start_time)
    # NOTE(review): the reported figure is half of the measured wall time - confirm this is intended.
    print("--- %s seconds ---" % round((elapsed / 2), 2))
def get_stain_matrix(I, threshold=0.8, lamda=0.1):
    """
    Estimate a 2x3 stain matrix for an image by sparse dictionary learning.

    Per the original note, the first row is taken to be H and the second E.
    :param I: image array handed to ``ut.notwhite_mask`` and ``ut.RGB_to_OD``
    :param threshold: threshold forwarded to ``ut.notwhite_mask``
    :param lamda: value passed to ``spams.trainDL`` as ``lambda1``
    :return: the learned two-atom dictionary with rows normalised
    """
    # The mask is taken first, before the optical-density conversion sees I.
    tissue = ut.notwhite_mask(I, thresh=threshold).reshape((-1,))
    od_pixels = ut.RGB_to_OD(I).reshape((-1, 3))[tissue]
    stains = spams.trainDL(od_pixels.T, K=2, lambda1=lamda, mode=2, modeD=0,
                           posAlpha=True, posD=True, verbose=False).T
    # Put the row with the larger first component on top.
    swap_needed = stains[0, 0] < stains[1, 0]
    if swap_needed:
        stains = stains[[1, 0], :]
    return ut.normalize_rows(stains)
"Sample_target" 57 | # self.stain_matrix_target = np.array([[0.60559458, 0.69559906, 0.38651928], 58 | # [0.1100605, 0.94701408, 0.30174662]]) 59 | # [[0.59958405,0.70248408,0.38342546] 60 | # [0.06893222,0.95236792,0.2970584]] 61 | 62 | # [[0.60559458 0.69559906 0.38651928] 63 | # [0.1100605 0.94701408 0.30174662]] 64 | 65 | # [[0.60715608 0.72015621 0.3357626] 66 | # [0.21154943 0.9271104 0.30937542]] 67 | 68 | def fit(self, target_list): 69 | if target_list.__len__() > 1: 70 | Ws = [] 71 | for f_id in range(target_list.__len__()): 72 | target = ut.read_image(target_list[f_id]) 73 | target = ut.standardize_brightness(target) 74 | stain_matrix_target = get_stain_matrix(target) 75 | Ws.append(stain_matrix_target) 76 | Ws = np.asarray(Ws) 77 | Median_W = np.median(Ws, axis=0) 78 | self.stain_matrix_target = ut.normalize_rows(Median_W) 79 | print('WSI target stain matrix: ', self.stain_matrix_target) 80 | else: 81 | target = ut.read_image(target_list[0]) 82 | target = ut.standardize_brightness(target) 83 | self.stain_matrix_target = get_stain_matrix(target) 84 | print('Single target image stain matrix: ', self.stain_matrix_target) 85 | 86 | def stains_Vec_RGB(self, stain_matrix_target): 87 | return ut.OD_to_RGB(stain_matrix_target) 88 | 89 | def transform(self, I): 90 | I = ut.standardize_brightness(I) 91 | stain_matrix_source = get_stain_matrix(I) 92 | source_concentrations = ut.get_concentrations(I, stain_matrix_source) 93 | return (255 * np.exp(-1 * np.dot(source_concentrations, self.stain_matrix_target).reshape(I.shape))).astype( 94 | np.uint8) 95 | 96 | def hematoxylin_eosin(self, I): 97 | I = ut.standardize_brightness(I) 98 | h, w, _ = I.shape 99 | stain_matrix_source = get_stain_matrix(I) 100 | source_concentrations = ut.get_concentrations(I, stain_matrix_source) 101 | 102 | H = source_concentrations[:, 0].reshape(h, w) 103 | H = np.exp(-1 * H) 104 | 105 | E = source_concentrations[:, 1].reshape(h, w) 106 | E = np.exp(-1 * E) 107 | 108 | # H = 
def Get_file_list(dir):
    """
    Collect every ``*.png`` file under ``dir``, including sub-directories.

    The earlier loop walked the tree but globbed the top directory on each
    step, so top-level files were repeated once per directory visited.
    :param dir: root directory to search
    :return: sorted list of matching paths, each listed once
    """
    pattern = "*.png"
    imageslist = []
    for root, _, _ in os.walk(dir):
        imageslist.extend(glob(os.path.join(root, pattern)))
    imageslist.sort()
    return imageslist


def read_image(path):
    """
    Read an image with OpenCV and convert it from BGR to RGB.
    :param path: image file path
    :return: the converted image array
    """
    im = cv.imread(path)
    im = cv.cvtColor(im, cv.COLOR_BGR2RGB)
    return im


def standardize_brightness(I):
    """
    Rescale intensities so that the 90th percentile maps to 255.
    :param I: image array
    :return: uint8 array of the same shape, clipped to [0, 255]
    """
    p = np.percentile(I, 90)
    if p == 0:
        # Limit of the scaling below for p -> 0 (positives saturate, the rest
        # stay 0), computed without dividing by zero.
        return np.where(np.asarray(I) > 0, 255, 0).astype(np.uint8)
    return np.clip(I * 255.0 / p, 0, 255).astype(np.uint8)
def remove_zeros(I):
    """
    Replace zeros with ones, in place.
    :param I: array; it is modified
    :return: the same array object
    """
    mask = (I == 0)
    I[mask] = 1
    return I


def RGB_to_OD(I):
    """
    Convert RGB to optical density: ``-log(I / 255)``.

    Zeros are replaced by ones first so the logarithm stays finite. The
    replacement is done on a copy, so the caller's image is left untouched.
    :param I: RGB image array
    :return: float array of optical densities, same shape as ``I``
    """
    I = remove_zeros(np.array(I))  # np.array() copies
    return -1 * np.log(I / 255)


def OD_to_RGB(OD):
    """
    Convert optical density back to RGB: ``255 * exp(-OD)`` cast to uint8.
    :param OD: optical density array
    :return: uint8 array
    """
    return (255 * np.exp(-1 * OD)).astype(np.uint8)


def normalize_rows(A):
    """
    Scale each row of a 2-D array to unit Euclidean norm.
    :param A: 2-D array
    :return: array of the same shape
    """
    return A / np.linalg.norm(A, axis=1)[:, None]


def notwhite_mask(I, thresh=0.8):
    """
    Boolean mask that is True where the pixel counts as 'not white'.

    The first LAB channel, divided by 255, is compared against ``thresh``.
    :param I: RGB image array
    :param thresh: cut-off; pixels strictly below it are kept
    :return: boolean array over the image's first two axes
    """
    I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
    L = I_LAB[:, :, 0] / 255.0
    return (L < thresh)


def sign(x):
    """
    Return the sign of x: +1, -1 or 0.

    Falls through (returns None) when none of the comparisons holds.
    :param x: a number
    :return: +1, -1, 0 or None
    """
    if x > 0:
        return +1
    elif x < 0:
        return -1
    elif x == 0:
        return 0


def get_concentrations(I, stain_matrix, lamda=0.01):
    """
    Solve a non-negative lasso for per-pixel stain concentrations.
    :param I: RGB image array
    :param stain_matrix: a 2x3 stain matrix
    :param lamda: value passed to ``spams.lasso`` as ``lambda1``
    :return: npix x 2 matrix of concentrations
    """
    OD = RGB_to_OD(I).reshape((-1, 3))
    return spams.lasso(OD.T, D=stain_matrix.T, mode=2, lambda1=lamda, pos=True).toarray().T
def getImage(filepath):
    """Load an image, resize it to the module-level (w, h), and return it as a tensor.

    The image is read with OpenCV, resized with cubic interpolation,
    converted from BGR to RGB, then passed through ``ToTensor`` and
    ``Normalize`` with mean and std 0.5 on each channel.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    img_data = cv2.imread(filepath)
    img_data = cv2.resize(img_data, (w, h), interpolation=cv2.INTER_CUBIC)
    img_data = np.array(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)).astype(np.uint8)
    return transform(img_data)


class FeatureExtractor(nn.Module):
    """Run a network child by child and collect the outputs of selected children."""

    def __init__(self, submodule, extracted_layers):
        """
        :param submodule: network whose direct children are applied in order
        :param extracted_layers: names of the children whose outputs are kept
        """
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested children, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Compare by value: `is` on strings only works when both happen to be interned.
            if name == "fc":
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


def getHash(f1, f2, f0, thre=None):
    """Score two feature stacks against a reference by thresholded channel means.

    Each input is averaged over axes 1 and 2. A channel 'differs' from the
    reference when the absolute difference of the means exceeds ``thre``.

    :param f1: first feature stack
    :param f2: second feature stack
    :param f0: reference feature stack
    :param thre: threshold; when None, the mean of the min and max of both
        distance vectors is used
    :return: (degree1, degree2), the percentage of channels that do not differ
    """
    f1 = np.mean(f1, axis=(1, 2))
    f2 = np.mean(f2, axis=(1, 2))
    f0 = np.mean(f0, axis=(1, 2))
    dist1 = np.abs(f0 - f1)
    dist2 = np.abs(f0 - f2)
    if thre is None:
        thre = (np.min(dist1) + np.max(dist1) + np.min(dist2) + np.max(dist2)) / 4.0
    print(thre)
    # Both percentages use the channel count of the first comparison.
    length = dist1.shape[0]
    degree1 = (length - np.sum(dist1 > thre)) / length * 100
    degree2 = (length - np.sum(dist2 > thre)) / length * 100
    return degree1, degree2


def P_Hash(f1, f0, thre=0.005):
    """Percentage of channels whose mean differs from the reference by at most ``thre``.

    :param f1: feature stack to score
    :param f0: reference feature stack
    :param thre: threshold on the absolute difference of channel means
    :return: percentage in [0, 100]
    """
    f1 = np.mean(f1, axis=(1, 2))
    f0 = np.mean(f0, axis=(1, 2))
    dist = np.abs(f0 - f1)
    print(thre)
    length = dist.shape[0]
    degree = (length - np.sum(dist > thre)) / length * 100
    return degree
class TestImage(Dataset):
    """Patches cut from one image file on a regular grid."""

    def __init__(self, filepath, patch_size=256):
        """
        :param filepath: image file to tile
        :param patch_size: side length of each patch; also the grid step
        """
        # Set before getImage(), which now tiles with this value instead of
        # reading the module-level `opt.size`.
        self.patch_size = patch_size
        self.Image, self.pos_list = self.getImage(filepath)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    def __getitem__(self, index):
        """Return ``{'Img': transformed patch}`` for the patch ending at ``pos_list[index]``."""
        x_end, y_end = self.pos_list[index]
        size = self.patch_size
        item = self.transform(self.Image[x_end - size:x_end, y_end - size:y_end, :])
        return {'Img': item}

    def __len__(self):
        return len(self.pos_list)

    @staticmethod
    def _tile_ends(extent, step):
        """Sorted unique patch end coordinates along one axis, always including ``extent``."""
        ends = list(range(step, extent, step))
        ends.append(extent)
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        return np.unique(np.asarray(ends, dtype=int))

    def getImage(self, filepath):
        """Load the image as RGB uint8 and list the end corner of every patch.

        :return: (image array, list of ``[x_end, y_end]`` pairs)
        """
        Testimg_data = np.array(Image.open(filepath).convert('RGB')).astype(np.uint8)
        Testimg_shape = Testimg_data.shape
        X_index = self._tile_ends(Testimg_shape[0], int(self.patch_size))
        Y_index = self._tile_ends(Testimg_shape[1], int(self.patch_size))
        pos_list = [[x_id, y_id] for x_id in X_index for y_id in Y_index]
        return Testimg_data, pos_list


class FeatureExtractor(nn.Module):
    """Run a network child by child and collect the outputs of selected children."""

    def __init__(self, submodule, extracted_layers):
        """
        :param submodule: network whose direct children are applied in order
        :param extracted_layers: names of the children whose outputs are kept
        """
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested children, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Compare by value: `is` on strings only works when both happen to be interned.
            if name == "fc":
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
models.resnet18(pretrained=True) 83 | print(resnet) # 可以打印看模型结构 84 | if opt.cuda: 85 | resnet.cuda() 86 | 87 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 88 | input1 = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size*2) 89 | input2 = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size*2) 90 | input3 = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size*2) 91 | 92 | dataloader_fake_IHC = DataLoader(TestImage(file_fake, patch_size=opt.size), 93 | batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu) 94 | 95 | dataloader_real_HE = DataLoader(TestImage(file_real, patch_size=opt.size), 96 | batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu) 97 | dataloader_real_IHC = DataLoader(TestImage(file_real2, patch_size=opt.size), 98 | batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu) 99 | MSSI_fakeIHCImg_vector=[] 100 | for i, batch in enumerate(dataloader_fake_IHC): 101 | # Set model input 102 | fake_input = Variable(input1.copy_(batch['Img'])) 103 | extract_result = FeatureExtractor(resnet, extract_list) 104 | SSI_fake_vector = extract_result(fake_input)[0].cpu().detach().numpy() 105 | for patch_id in range(opt.batchSize): 106 | MSSI_fakeIHCImg_vector.append(SSI_fake_vector[patch_id].squeeze()) 107 | MSSI_fakeIHCImg_vector = np.squeeze(np.mean(np.array(MSSI_fakeIHCImg_vector), axis=0)) 108 | 109 | MSSI_realHEImg_vector = [] 110 | for i, batch in enumerate(dataloader_real_HE): 111 | # Set model input 112 | real_input = Variable(input2.copy_(batch['Img'])) 113 | extract_result = FeatureExtractor(resnet, extract_list) 114 | SSI_real_vector = extract_result(real_input)[0].cpu().detach().numpy() 115 | for patch_id in range(opt.batchSize): 116 | MSSI_realHEImg_vector.append(SSI_real_vector[patch_id].squeeze()) 117 | 118 | MSSI_realHEImg_vector = np.squeeze(np.mean(np.array(MSSI_realHEImg_vector), axis=0)) 119 | 120 | MSSI_realIHCImg_vector = [] 121 | for i, batch in enumerate(dataloader_real_IHC): 122 | # Set model 
input 123 | real_input2 = Variable(input3.copy_(batch['Img'])) 124 | extract_result = FeatureExtractor(resnet, extract_list) 125 | SSI_real_vector2 = extract_result(real_input2)[0].cpu().detach().numpy() 126 | for patch_id in range(opt.batchSize): 127 | MSSI_realIHCImg_vector.append(SSI_real_vector2[patch_id].squeeze()) 128 | MSSI_realIHCImg_vector = np.squeeze(np.mean(np.array(MSSI_realIHCImg_vector), axis=0)) 129 | 130 | Vec_HE2fakeIHC = (MSSI_realHEImg_vector - MSSI_fakeIHCImg_vector)*len(MSSI_realHEImg_vector) 131 | Vec_HE2realIHC = (MSSI_realHEImg_vector - MSSI_realIHCImg_vector)*len(MSSI_realHEImg_vector) 132 | Vec_fakeIHC2realIHC = (MSSI_fakeIHCImg_vector - MSSI_realIHCImg_vector)*len(MSSI_realHEImg_vector) 133 | 134 | D = 1-np.linalg.norm(Vec_fakeIHC2realIHC)/np.linalg.norm(Vec_HE2realIHC) 135 | 136 | Vec_HE2fakeIHC = Vec_HE2fakeIHC / np.linalg.norm(Vec_HE2fakeIHC) 137 | Vec_HE2realIHC = Vec_HE2realIHC / np.linalg.norm(Vec_HE2realIHC) 138 | 139 | MSSI = ((np.dot(Vec_HE2fakeIHC, Vec_HE2realIHC)+1)/2+D)/2 140 | #/np.maximum(np.linalg.norm(Vec_HE2fakeIHC), np.linalg.norm(Vec_HE2realIHC)) 141 | 142 | return MSSI 143 | 144 | if __name__ == '__main__': 145 | file_fake = '/home/zhangbc/Mydataspace/LST/mymodel/HE2IHC_20/epoch19/output/A2B/H6_pre_B.png' 146 | file_real2 = '/home/zhangbc/Mydataspace/LST/raw_data/HE_IHC_20/test/K6.png' 147 | file_real = '/home/zhangbc/Mydataspace/LST/raw_data/HE_IHC_20/test/H6.png' 148 | MSSI = Resnet_SSI(file_fake, file_real, file_real2) # b-B:c+E 149 | print("MSSI:", MSSI) 150 | -------------------------------------------------------------------------------- /WSI_test_CycleGan_pathology_A2B_patch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | import torch 10 | from torch.utils.data 
class TestImage(Dataset):
    """Patches of an in-memory image, addressed by their end corner."""

    def __init__(self, Image_data, pos_list, transforms_=None, patch_size=512):
        """
        :param Image_data: image array indexed as [row, column, channel]
        :param pos_list: sequence of ``[x_end, y_end]`` patch end corners
        :param transforms_: callable applied to each patch; None returns the raw patch
        :param patch_size: side length of each patch
        """
        self.transform = transforms_
        self.Image = Image_data
        self.pos_list = pos_list
        self.patch_size = patch_size

    def __getitem__(self, index):
        """Return the patch ending at ``pos_list[index]`` together with that position.

        :return: dict with 'A' (the patch, transformed when a transform was
            given), 'x' and 'y' (the patch's end coordinates)
        """
        x_end = self.pos_list[index][0]
        y_end = self.pos_list[index][1]
        size = self.patch_size
        item = self.Image[x_end - size:x_end, y_end - size:y_end, :]
        # The default transforms_=None used to be called unconditionally and crash.
        if self.transform is not None:
            item = self.transform(item)
        return {'A': item, 'x': x_end, 'y': y_end}

    def __len__(self):
        return len(self.pos_list)
'/home/zhangbc/Mydataspace/LST/'+datasetname+'/mymodelx288/Cycle_GAN_pathology_seg/exp1/epoch3/netG_B2A.pth' 50 | opt.generator_A2B = '/home/zhangbc/Mydataspace/LST/'+datasetname+'/mymodelx288/Cycle_GAN_pathology_seg_PN3_1/other_epoch/epoch9/netG_A2B.pth' 51 | opt.generator_B2A = '/home/zhangbc/Mydataspace/LST/'+datasetname+'/mymodelx288/Cycle_GAN_pathology_seg_PN3_1/other_epoch/epoch9/netG_B2A.pth' 52 | print(opt) 53 | if torch.cuda.is_available() and not opt.cuda: 54 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 55 | 56 | ###### Definition of variables ###### 57 | # Networks 58 | netG_A2B = Generator_unet_seg(opt.input_nc, opt.output_nc, 10, alt_leak=True, neg_slope=0.1) 59 | netG_B2A = Generator_unet_seg(opt.output_nc, opt.input_nc, 10, alt_leak=True, neg_slope=0.1) 60 | 61 | if opt.cuda: 62 | netG_A2B.cuda() 63 | netG_B2A.cuda() 64 | 65 | netG_A2B.load_state_dict(torch.load(opt.generator_A2B)) 66 | netG_B2A.load_state_dict(torch.load(opt.generator_B2A)) 67 | # Set model's test mode 68 | netG_A2B.eval() 69 | netG_B2A.eval() 70 | 71 | # Inputs & targets memory allocation 72 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 73 | input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size) 74 | 75 | # neuroendocrine 76 | # wsi_name = [1,1,2,2,2,3,3,4,4,4,4,4,5,5,6,6,7,7,7,8,8] 77 | # id = [1,2,1,2,3,1,2,1,2,3,4,5,1,2,1,2,1,2,3,1,2] 78 | wsi_name = [1,1,2,2,2,3,3,4,4,4,4,4,5,5,6,6,7,7,7,8,8] 79 | id = [1,2,1,2,3,1,2,1,2,3,4,5,1,2,1,2,1,2,3,1,2] 80 | # breast 81 | # wsi_name = [4,4,4,4,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8, 9,9,9] 82 | # id = [1,2,3,4,1,2,3,4,5,6,7,8,9,10,1,2,3,4,5,6,7,8,9,10,1,2,3] 83 | 84 | for i in range(wsi_name.__len__()): 85 | print(i) 86 | opt.filepath = '/home/zhangbc/Desktop/Mydataspace/LST/raw_data/'+datasetname+'/Test/HE/WSI_data00'+str(wsi_name[i])+'/'+str(id[i]).zfill(2)+'.png' 87 | if not os.path.exists(opt.filepath): 88 | print('wrong filepath') 89 | exit(0) 90 | # read test 
data and add padding 91 | Testimg_data = cv2.imread(opt.filepath) 92 | Testimg_data = cv2.copyMakeBorder(Testimg_data, opt.size//4, opt.size//4, opt.size//4, opt.size//4, cv2.BORDER_REFLECT) 93 | Testimg_data = np.array(cv2.cvtColor(Testimg_data, cv2.COLOR_BGR2RGB)).astype(np.uint8) 94 | 95 | filename = opt.filepath.split('/')[-1].split('.')[0] 96 | Testimg_shape = Testimg_data.shape 97 | X_index = list(range(opt.size, Testimg_shape[0], int(opt.size / 2))) 98 | X_index.append(Testimg_shape[0]) 99 | X_index = np.unique(np.asarray(X_index, dtype=np.int)) 100 | Y_index = list(range(opt.size, Testimg_shape[1], int(opt.size / 2))) 101 | Y_index.append(Testimg_shape[1]) 102 | Y_index = np.unique(np.asarray(Y_index, dtype=np.int)) 103 | pos_list = [] 104 | for x_id in X_index: 105 | for y_id in Y_index: 106 | pos_list.append([x_id, y_id]) 107 | Testimg_data = Testimg_data[0:X_index[-1], 0:Y_index[-1], :] 108 | Pre_img = np.zeros_like(Testimg_data) 109 | rec_img = np.zeros_like(Testimg_data) 110 | # Dataset loader 111 | transforms_ = transforms.Compose([ 112 | transforms.ToTensor(), 113 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 114 | 115 | dataloader = DataLoader(TestImage(Testimg_data, pos_list, transforms_=transforms_, patch_size=opt.size), 116 | batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu) 117 | ################################### 118 | 119 | ###### Testing###### 120 | 121 | for i, batch in enumerate(dataloader): 122 | # Set model input 123 | x_ = batch['x'].numpy() # list of x position 124 | y_ = batch['y'].numpy() 125 | real_A = Variable(input_A.copy_(batch['A'])) 126 | # Generate output 127 | fake_B, _, c_out, _ = netG_A2B(real_A) 128 | rec_A, _, _, _ = netG_B2A(fake_B) 129 | fake_B = 0.5 * (fake_B.data + 1.0) 130 | rec_A = 0.5 * (rec_A.data + 1.0) 131 | #print(c_out.data) 132 | 133 | for patch_id in range(opt.batchSize): 134 | pre_patch_B = fake_B[patch_id].squeeze() 135 | pre_patch_B_np = torch.mul(pre_patch_B, 
255).__add__(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', 136 | torch.uint8).numpy() 137 | 138 | pre_patch_B_np_center = pre_patch_B_np[int(opt.size / 4):int(opt.size / 4) * 3, 139 | int(opt.size / 4):int(opt.size / 4) * 3, :] 140 | # Save image files 141 | Pre_img[x_[patch_id] - opt.size + int(opt.size / 4):x_[patch_id] - opt.size + int(opt.size / 4) * 3, 142 | y_[patch_id] - opt.size + int(opt.size / 4):y_[patch_id] - opt.size + int(opt.size / 4) * 3, 143 | :] = pre_patch_B_np_center 144 | 145 | pre_patch_A = rec_A[patch_id].squeeze() 146 | pre_patch_A_np = torch.mul(pre_patch_A, 255).__add__(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', 147 | torch.uint8).numpy() 148 | 149 | pre_patch_A_np_center = pre_patch_A_np[int(opt.size / 4):int(opt.size / 4) * 3, 150 | int(opt.size / 4):int(opt.size / 4) * 3, :] 151 | # Save image files 152 | rec_img[x_[patch_id] - opt.size + int(opt.size / 4):x_[patch_id] - opt.size + int(opt.size / 4) * 3, 153 | y_[patch_id] - opt.size + int(opt.size / 4):y_[patch_id] - opt.size + int(opt.size / 4) * 3, 154 | :] = pre_patch_A_np_center 155 | 156 | sys.stdout.write('\rGenerated images %04d of %04d' % (i + 1, len(dataloader))) 157 | 158 | Pre_img = Pre_img[int(opt.size / 4):X_index[-1] - int(opt.size / 4), 159 | int(opt.size / 4):Y_index[-1] - int(opt.size / 4), :] 160 | rec_img = rec_img[int(opt.size / 4):X_index[-1] - int(opt.size / 4), 161 | int(opt.size / 4):Y_index[-1] - int(opt.size / 4), :] 162 | Testimg_data = Testimg_data[int(opt.size / 4):X_index[-1] - int(opt.size / 4), 163 | int(opt.size / 4):Y_index[-1] - int(opt.size / 4), :] 164 | Pre_img = Image.fromarray(Pre_img) 165 | rec_img = Image.fromarray(rec_img) 166 | Testimg = Image.fromarray(Testimg_data) 167 | save_dir = os.path.join(os.path.dirname(opt.generator_A2B), 'output/A2B/') 168 | save_dir = os.path.join(save_dir, os.path.dirname(opt.filepath).split('/')[-1]) 169 | if not os.path.exists(save_dir): 170 | os.makedirs(save_dir) 171 | Pre_img.save(os.path.join(save_dir, 
def _epoch_length(count_a, count_b, batch_size):
    """Return the number of samples per epoch for a two-domain dataset.

    The larger domain defines the epoch; each count is first truncated to a
    whole number of batches.  A falsy ``batch_size`` (``None`` or ``0``)
    disables the truncation instead of raising.
    """
    if not batch_size:
        return max(count_a, count_b)
    return max(count_a // batch_size * batch_size,
               count_b // batch_size * batch_size)


def _read_npz(path, keys):
    """Return the arrays stored under ``keys`` in ``path`` and close the archive."""
    # np.load hands back a lazy NpzFile that keeps the file open; the context
    # manager releases the handle once the requested arrays are in memory.
    with np.load(path) as archive:
        return [archive[key] for key in keys]


def _to_rgb(array):
    """Convert a numeric image array to an 8-bit RGB PIL image."""
    return Image.fromarray(array.astype('uint8')).convert('RGB')


class ImageDataset(Dataset):
    """Unpaired HE / Ki67 patches read from ``root/HE`` and ``root/Ki67``.

    Args:
        root: directory containing the ``HE`` and ``Ki67`` sub-trees of PNGs.
        transforms_: callable applied to every opened image.
        batch_size: epoch length is truncated to a multiple of this value.
        unaligned: when true, the Ki67 sample is drawn at random.
    """

    def __init__(self, root, transforms_=None, batch_size=None, unaligned=False):
        self.transform = transforms_
        self.unaligned = unaligned
        self.batch_size = batch_size
        self.files_A = sorted(glob.glob(os.path.join(root, 'HE') + '/*/*.png'))
        self.files_B = sorted(glob.glob(os.path.join(root, 'Ki67') + '/*/*.png'))

    def __getitem__(self, index):
        # A is transformed before B is chosen so RNG consumption keeps the
        # original order.
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        item_B = self.transform(Image.open(path_B))
        return {'HE': item_A, 'Ki67': item_B}

    def __len__(self):
        return _epoch_length(len(self.files_A), len(self.files_B), self.batch_size)


class ExpertDataset_label(Dataset):
    """Expert-annotated patches (``.npz``) with a label and a proportion per patch.

    Each archive is read for the keys ``image``, ``proportion`` and ``label``.
    """

    def __init__(self, export_root, transform_expert=None, batch_size=None, unaligned=False):
        self.transform_expert = transform_expert
        self.batch_size = batch_size
        self.unaligned = unaligned
        self.files_expert_A = sorted(glob.glob(os.path.join(export_root, 'HE') + '/*/*.npz'))
        self.files_expert_B = sorted(glob.glob(os.path.join(export_root, 'Ki67') + '/*/*.npz'))

    def _load_sample(self, path):
        """Return ``(image, label, proportion)`` tensors for one archive."""
        image, proportion, label = _read_npz(path, ('image', 'proportion', 'label'))
        image = self.transform_expert(_to_rgb(image))
        proportion = torch.from_numpy(np.asarray(proportion, dtype=np.float32))
        label = torch.from_numpy(np.asarray(label, dtype=np.float32))
        return image, label, proportion

    def __getitem__(self, index):
        item_A, label_A, prop_A = self._load_sample(
            self.files_expert_A[index % len(self.files_expert_A)])

        if self.unaligned:
            path_B = self.files_expert_B[random.randint(0, len(self.files_expert_B) - 1)]
        else:
            path_B = self.files_expert_B[index % len(self.files_expert_B)]
        item_B, label_B, prop_B = self._load_sample(path_B)

        return {'expert_HE': item_A, 'expert_Ki67': item_B,
                'expert_HE_label': label_A, 'expert_Ki67_label': label_B,
                'expert_HE_p': prop_A, 'expert_Ki67_p': prop_B, }

    def __len__(self):
        return _epoch_length(len(self.files_expert_A), len(self.files_expert_B),
                             self.batch_size)


class ExpertDataset_mask(Dataset):
    """Expert-annotated patches (``.npz``) with a segmentation mask per patch.

    Each archive is read for the keys ``image`` and ``mask``.  Image and mask
    receive the same random flips.
    """

    def __init__(self, export_root, batch_size=None, unaligned=False):
        self.batch_size = batch_size
        self.transform_img = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        self.transform_mask = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor()])
        self.unaligned = unaligned
        self.files_expert_A = sorted(glob.glob(os.path.join(export_root, 'HE') + '/*/*.npz'))
        self.files_expert_B = sorted(glob.glob(os.path.join(export_root, 'Ki67') + '/*/*.npz'))

    @staticmethod
    def _apply_with_seed(transform, item, seed):
        """Run ``transform(item)`` with both RNGs that torchvision may use seeded.

        Seeding only ``random`` (as the original did) does not synchronise
        transforms that draw from the torch generator, which lets the mask
        be flipped differently from its image.  The torch CPU RNG state is
        restored afterwards so the caller's random stream is not disturbed.
        """
        random.seed(seed)  # older torchvision releases draw from `random`
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(int(seed))
            return transform(item)

    def _transform_pair(self, image, mask):
        """Apply the image and mask pipelines under one shared seed."""
        seed = np.random.randint(2147483647)
        return (self._apply_with_seed(self.transform_img, image, seed),
                self._apply_with_seed(self.transform_mask, mask, seed))

    @staticmethod
    def _load_pair(path):
        """Return ``(RGB image, 8-bit mask)`` PIL images for one archive."""
        image, mask = _read_npz(path, ('image', 'mask'))
        mask = Image.fromarray((mask * 255.0).astype('uint8'))
        return _to_rgb(image), mask

    def __getitem__(self, index):
        image_A, mask_A = self._load_pair(
            self.files_expert_A[index % len(self.files_expert_A)])

        if self.unaligned:
            path_B = self.files_expert_B[random.randint(0, len(self.files_expert_B) - 1)]
        else:
            path_B = self.files_expert_B[index % len(self.files_expert_B)]
        image_B, mask_B = self._load_pair(path_B)

        image_A, mask_A = self._transform_pair(image_A, mask_A)
        image_B, mask_B = self._transform_pair(image_B, mask_B)

        return {'expert_HE': image_A, 'expert_Ki67': image_B,
                'expert_HE_mask': mask_A, 'expert_Ki67_mask': mask_B}

    def __len__(self):
        return _epoch_length(len(self.files_expert_A), len(self.files_expert_B),
                             self.batch_size)
# evaluation_metrics/A_rA_Similarity.py
# Scores how well the cycle-reconstructed image (rec_A) matches the real
# image (real_A) on two side-by-side tiles per image, and writes the scores
# to an Excel sheet.
import cv2
import numpy as np
from sewar.full_ref import ssim, msssim, psnr
from pandas import DataFrame

modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM', 'Cycle_GAN_pathology_cls', 'Cycle_GAN_pathology_seg']
model_id = 3
epoch_id = 9
print(modelname[model_id], epoch_id)
wsi_name = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8]
img_ids = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 5, 1, 2, 1, 2, 1, 2, 3, 1, 2]
xlsdata = {
    'fid': [],
    'msssim': [],
    'ssim': [],
    'psnr': [],
    'mae': []}

# Directory that holds the generated images and receives the result sheet.
output_root = ('/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id]
               + '/exp1/epoch' + str(epoch_id) + '/output/')
# Column bounds of the two tiles scored per image (rows are always 0:1500).
tile_columns = ((0, 1505), (1505, 3010))


def load_image(path):
    """Read ``path`` with OpenCV, failing loudly when it cannot be decoded."""
    image = cv2.imread(path)
    if image is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(path)
    return image


def tile_metrics(tile0, tile1):
    """Return mae / msssim / ssim / psnr between two equally sized BGR tiles."""
    degree_msssim = msssim(tile0, tile1, ws=16)
    degree_psnr = psnr(tile0, tile1)
    degree_ssim, _ = ssim(tile0, tile1, ws=16)
    gray0 = np.asarray(cv2.cvtColor(tile0, cv2.COLOR_BGR2GRAY), dtype=float)
    gray1 = np.asarray(cv2.cvtColor(tile1, cv2.COLOR_BGR2GRAY), dtype=float)
    mae = np.mean(np.abs(gray0 - gray1), axis=(0, 1))
    return {'mae': mae,
            'msssim': degree_msssim.real * 100,
            'ssim': degree_ssim * 100,
            'psnr': degree_psnr}


for wsi, img_id in zip(wsi_name, img_ids):
    stem = output_root + 'A2B/WSI_data00' + str(wsi) + '/0' + str(img_id)
    img0 = load_image(stem + '_real_A.png')
    img1 = load_image(stem + '_rec_A.png')

    print('<-----WSI_data00' + str(wsi) + '/0' + str(img_id) + '------>')
    for part, (left, right) in enumerate(tile_columns, start=1):
        scores = tile_metrics(img0[0:1500, left:right, :], img1[0:1500, left:right, :])
        print('>>part ' + str(part))
        for key in ('mae', 'msssim', 'ssim', 'psnr'):
            print(key + ':', scores[key])
            xlsdata[key].append(scores[key])
        xlsdata['fid'].append(str(wsi) + '_' + str(img_id) + 'part' + str(part))
    print('<---------end---------->')

df = DataFrame(xlsdata)
df.to_excel(output_root + 'A_A.xlsx')
# evaluation_metrics/A_rB_Similarity.py
# Scores the structural similarity between the real image (real_A) and the
# generated image (pre_B) on two side-by-side tiles per image.
# NOTE(review): the file-level `from skimage.measure import compare_ssim` is
# not used by this script, and that name no longer exists in recent
# scikit-image releases; it is deliberately not re-imported here.
import cv2
from sewar.full_ref import ssim
from pandas import DataFrame

modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM', 'Cycle_GAN_pathology_cls', 'Cycle_GAN_pathology_seg']
model_id = 2
epoch_id = 5
print(modelname[model_id], epoch_id)
wsi_name = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8]
img_ids = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 5, 1, 2, 1, 2, 1, 2, 3, 1, 2]

xlsdata = {
    'fid': [],
    'cs': [],
    'ssim': []}

# Directory that holds the generated images and receives the result sheet.
output_root = ('/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id]
               + '/exp1/epoch' + str(epoch_id) + '/output/')
# Column bounds of the two tiles scored per image (rows are always 0:1500).
# The first tile used to stop at column 1500, five columns short of where the
# second tile starts; it now matches the 1505-wide tiles used by the sibling
# evaluation scripts.
tile_columns = ((0, 1505), (1505, 3010))


def load_image(path):
    """Read ``path`` with OpenCV, failing loudly when it cannot be decoded."""
    image = cv2.imread(path)
    if image is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(path)
    return image


for wsi, img_id in zip(wsi_name, img_ids):
    stem = output_root + 'A2B/WSI_data00' + str(wsi) + '/0' + str(img_id)
    img0 = load_image(stem + '_real_A.png')
    img1 = load_image(stem + '_pre_B.png')

    print('<-----WSI_data00' + str(wsi) + '/0' + str(img_id) + '------>')
    for part, (left, right) in enumerate(tile_columns, start=1):
        # sewar's ssim returns two values; the script stores the first as
        # 'ssim' and the second as 'cs'.
        degree_ssim1, degree_ssim2 = ssim(img0[0:1500, left:right, :],
                                          img1[0:1500, left:right, :], ws=5)
        print('>>part ' + str(part))
        print('ssim1:', degree_ssim1 * 100)
        print('ssim2:', degree_ssim2 * 100)
        xlsdata['fid'].append(str(wsi) + '_' + str(img_id) + 'part' + str(part))
        xlsdata['ssim'].append(degree_ssim1 * 100)
        xlsdata['cs'].append(degree_ssim2 * 100)
    print('<---------end---------->')

df = DataFrame(xlsdata)
df.to_excel(output_root + 'B_B_SSIM.xlsx')
def getImage(filepath):
    """Read ``filepath`` with OpenCV and return it as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img_data = cv2.imread(filepath)
    if img_data is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(filepath)
    return cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)


def _window_ends(length, size, step):
    """Return the sorted end offsets of the sliding windows along one axis."""
    ends = set(range(size, length, step))
    ends.add(length)  # always include a window flush with the far edge
    return sorted(ends)


def window_areas(mask, size=750, step=750, shape=None):
    """Sum ``mask`` over ``size`` x ``size`` sliding windows, row-major.

    Args:
        mask: 2-D array to integrate.
        size: window edge length.
        step: stride between window ends.
        shape: image shape that defines the window grid; defaults to
            ``mask.shape``.

    Returns:
        A list with one sum per window.
    """
    rows, cols = (mask.shape if shape is None else shape)[:2]
    # The start is clamped at 0: a negative start would be read by NumPy as
    # an offset from the end and silently shrink the window when the image
    # is smaller than ``size``.
    return [np.sum(mask[max(x_end - size, 0):x_end, max(y_end - size, 0):y_end])
            for x_end in _window_ends(rows, size, step)
            for y_end in _window_ends(cols, size, step)]


def Neg_area(img1_path, img0_path, size=750, step=750, re_fit=False, showfig=False):
    """Measure the thresholded ``1 - E`` stain area per window for two images.

    Args:
        img1_path: path of the generated image.
        img0_path: path of the reference image; its shape defines the grid.
        size, step: sliding-window edge length and stride.
        re_fit: refit the stain normalizer on both images first.
        showfig: display the images and their stain maps.

    Returns:
        ``(Neg0_area, Neg1_area)``: per-window areas for the reference and
        the generated image.

    Side effects:
        Writes the two binary maps to ``3_1_0.png`` and ``3_1_1.png`` in the
        working directory.
    """
    img0 = getImage(img0_path)
    img1 = getImage(img1_path)
    imgshape = img0.shape

    norm = Norm_Vahadane.normalizer()
    if re_fit:
        norm.fit([img0_path, img1_path])
    _, E0 = norm.hematoxylin_eosin(img0)
    _, E1 = norm.hematoxylin_eosin(img1)
    Neg0 = 1 - E0
    Neg1 = 1 - E1

    # Suppress weak responses first so the preview shows the soft map ...
    Neg0[Neg0 < 0.4] = 0
    Neg1[Neg1 < 0.4] = 0

    if showfig:
        plt.figure(1)
        for slot, picture in enumerate((img0, Neg0, img1, Neg1), start=1):
            plt.subplot(2, 2, slot)
            plt.imshow(picture)
        plt.show()

    # ... then binarise what is left.
    Neg0[Neg0 >= 0.4] = 1
    Neg1[Neg1 >= 0.4] = 1

    cv2.imwrite('3_1_0.png', Neg0 * 255)
    cv2.imwrite('3_1_1.png', Neg1 * 255)

    Neg0_area = window_areas(Neg0, size, step, shape=imgshape)
    Neg1_area = window_areas(Neg1, size, step, shape=imgshape)
    return Neg0_area, Neg1_area


# neuroendocrine,breast
if __name__ == '__main__':
    modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM',
                 'Cycle_GAN_pathology_cls', 'Cycle_GAN_tum',
                 'Cycle_GAN_pathology_seg', 'Cycle_GAN_pathology_seg_PN1_1',
                 'Cycle_GAN_pathology_seg_PN1_3', 'Cycle_GAN_pathology_seg_PN3_1',
                 'Cycle_GAN_PN_1_1', 'Cycle_GAN_PN_1_3', 'Cycle_GAN_PN_3_1',
                 ]
    model_id = 5
    print(model_id)
    epoch_id = 10
    wsi_name = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8]
    img_ids = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 5, 1, 2, 1, 2, 1, 2, 3, 1, 2]
    xlsdata = {
        'fid': [],
        'Neg_area': [],
        'ref_Neg_area': []}
    for wsi, img_id in zip(wsi_name, img_ids):
        # NOTE(review): both paths point at the generated '_pre_B.png', so the
        # "reference" equals the generated image.  The original kept the real
        # reference ('/home/zhangbc/Mydataspace/LST/raw_data/neuroendocrine/
        # Test/Ki67/WSI_data00<wsi>/<id>.png') commented out -- confirm which
        # one is intended before trusting the sheet.
        img1 = 'G:/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/epoch' + str(epoch_id) + '/output/A2B/WSI_data00' + str(wsi) + '/' + str(img_id).zfill(2) + '_pre_B.png'
        img0 = 'G:/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/epoch' + str(epoch_id) + '/output/A2B/WSI_data00' + str(wsi) + '/' + str(img_id).zfill(2) + '_pre_B.png'
        Neg0_area, Neg1_area = Neg_area(img1, img0, size=750, step=375, re_fit=True, showfig=True)
        print(str(wsi) + '_' + str(img_id))
        for pid, (ref_area, gen_area) in enumerate(zip(Neg0_area, Neg1_area)):
            print('part', str(pid))
            print('reference ki67 Neg area:', ref_area)
            print('generated ki67 Neg area:', gen_area)
            xlsdata['fid'].append(str(wsi) + '_' + str(img_id) + '_part' + str(pid))
            xlsdata['ref_Neg_area'].append(ref_area)
            xlsdata['Neg_area'].append(gen_area)

    df = DataFrame(xlsdata)
    order = ['fid', 'Neg_area', 'ref_Neg_area']
    df = df[order]
    df.to_excel('/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/exp3/epoch' + str(epoch_id) + '/output/B_B_neg.xlsx')
def getImage(filepath):
    """Read ``filepath`` with OpenCV and return it as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img_data = cv2.imread(filepath)
    if img_data is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(filepath)
    return cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)


def _window_ends(length, size, step):
    """Return the sorted end offsets of the sliding windows along one axis."""
    ends = set(range(size, length, step))
    ends.add(length)  # always include a window flush with the far edge
    return sorted(ends)


def window_areas(mask, size=750, step=750, shape=None):
    """Sum ``mask`` over ``size`` x ``size`` sliding windows, row-major.

    ``shape`` is the image shape that defines the window grid and defaults
    to ``mask.shape``.  Returns a list with one sum per window.
    """
    rows, cols = (mask.shape if shape is None else shape)[:2]
    # The start is clamped at 0: a negative start would be read by NumPy as
    # an offset from the end and silently shrink the window when the image
    # is smaller than ``size``.
    return [np.sum(mask[max(x_end - size, 0):x_end, max(y_end - size, 0):y_end])
            for x_end in _window_ends(rows, size, step)
            for y_end in _window_ends(cols, size, step)]


def _positive_response(hema, eosin, kernel):
    """Return the soft positive-stain map derived from the two stain channels.

    The ``1 - eosin`` map is binarised at 0.3 and dilated with ``kernel``,
    subtracted from ``hema``, and the result is inverted after low values
    are saturated.  Responses below 0.2 are zeroed.
    """
    neg = 1 - eosin
    neg[neg < 0.3] = 0
    neg[neg >= 0.3] = 1
    neg = sm.dilation(neg, kernel)
    pos = hema - neg
    # The original used `< 0.1` for the reference image and `<= 0.1` for the
    # generated one; both now share the same comparison.
    pos[pos < 0.1] = 1
    pos = 1 - pos
    pos[pos < 0.2] = 0
    return pos


def Neg_area(img1_path, img0_path, size=750, step=750, re_fit=False, showfig=False):
    """Measure the positive-stain area per window for two images.

    The function keeps its historical name although it returns positive areas.

    Args:
        img1_path: path of the generated image.
        img0_path: path of the reference image; its shape defines the grid.
        size, step: sliding-window edge length and stride.
        re_fit: refit the stain normalizer on both images first.
        showfig: display the images and their stain maps.

    Returns:
        ``(Pos0_area, Pos1_area)``: per-window areas for the reference and
        the generated image.

    Side effects:
        Writes the two binary maps to ``4_4_0.png`` and ``4_4_1.png`` in the
        working directory.
    """
    img0 = getImage(img0_path)
    img1 = getImage(img1_path)
    imgshape = img0.shape

    norm = Norm_Vahadane.normalizer()
    if re_fit:
        norm.fit([img0_path, img1_path])
    kernel = sm.disk(5)
    H0, E0 = norm.hematoxylin_eosin(img0)
    Pos0 = _positive_response(H0, E0, kernel)
    H1, E1 = norm.hematoxylin_eosin(img1)
    Pos1 = _positive_response(H1, E1, kernel)

    if showfig:
        plt.figure(1)
        for slot, picture in enumerate((img0, Pos0, img1, Pos1), start=1):
            plt.subplot(2, 2, slot)
            plt.imshow(picture)
        plt.show()

    # Binarise after the optional preview so the figure shows the soft map.
    Pos0[Pos0 >= 0.2] = 1
    Pos1[Pos1 >= 0.2] = 1

    cv2.imwrite('4_4_0.png', Pos0 * 255)
    cv2.imwrite('4_4_1.png', Pos1 * 255)

    Pos0_area = window_areas(Pos0, size, step, shape=imgshape)
    Pos1_area = window_areas(Pos1, size, step, shape=imgshape)
    return Pos0_area, Pos1_area


if __name__ == '__main__':
    modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM',
                 'Cycle_GAN_pathology_cls',
                 'Cycle_GAN_pathology_seg', 'Cycle_GAN_pathology_seg_PN1_1',
                 'Cycle_GAN_pathology_seg_PN1_3', 'Cycle_GAN_pathology_seg_PN3_1',
                 'Cycle_GAN_PN_1_1', 'Cycle_GAN_PN_1_3', 'Cycle_GAN_PN_3_1',
                 ]
    model_id = 8
    print(model_id)
    epoch_id = 9
    wsi_name = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8]
    img_ids = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 5, 1, 2, 1, 2, 1, 2, 3, 1, 2]
    # The sheet keeps the 'Neg_area' column names of the sibling script even
    # though the values are positive areas.
    xlsdata = {
        'fid': [],
        'Neg_area': [],
        'ref_Neg_area': []}
    for wsi, img_id in zip(wsi_name, img_ids):
        img1 = '/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/exp3/epoch' + str(epoch_id) + '/output/A2B/WSI_data00' + str(wsi) + '/' + str(img_id).zfill(2) + '_pre_B.png'
        img0 = '/home/zhangbc/Mydataspace/LST/raw_data/neuroendocrine/Test/Ki67/WSI_data00' + str(wsi) + '/' + str(img_id).zfill(2) + '.png'
        Neg0_area, Neg1_area = Neg_area(img1, img0, size=750, step=375, re_fit=True, showfig=False)
        print(str(wsi) + '_' + str(img_id))
        for pid, (ref_area, gen_area) in enumerate(zip(Neg0_area, Neg1_area)):
            print('part', str(pid))
            print('reference ki67 Neg area:', ref_area)
            print('generated ki67 Neg area:', gen_area)
            xlsdata['fid'].append(str(wsi) + '_' + str(img_id) + '_part' + str(pid))
            xlsdata['ref_Neg_area'].append(ref_area)
            xlsdata['Neg_area'].append(gen_area)

    df = DataFrame(xlsdata)
    order = ['fid', 'Neg_area', 'ref_Neg_area']
    df = df[order]
    df.to_excel('/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/exp3/epoch' + str(epoch_id) + '/output/B_B_Pos.xlsx')
# Thresholds noted by the authors for resnet-18 feature maps:
#   layer3 ---> thre=0.005
#   layer2 ---> thre=0.01
#   layer1 ---> thre=0.02

use_gpu = True
h = 1500
w = 1505


def getImage(filepath, part=None):
    """Load one evaluation tile as a normalised ``3 x h x w`` tensor.

    Args:
        filepath: image path readable by OpenCV.
        part: ``1`` selects columns 0:1505, ``2`` selects columns 1505:3010
            (rows 0:1500 in both cases); any other value keeps the whole
            image.  Defaults to ``None`` so callers that pass only a path
            work.

    Returns:
        The tile resized to ``(w, h)``, converted to RGB, scaled by
        ``ToTensor`` and normalised with mean 0.5 and std 0.5 per channel.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    img_data = cv2.imread(filepath)
    if img_data is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(filepath)
    if part == 1:
        img_data = img_data[0:1500, 0:1505, :]
    elif part == 2:
        img_data = img_data[0:1500, 1505:3010, :]
    img_data = cv2.resize(img_data, (w, h), interpolation=cv2.INTER_CUBIC)
    img_data = np.array(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)).astype(np.uint8)
    return transform(img_data)


class FeatureExtractor(nn.Module):
    """Run a network child by child and collect selected intermediate outputs.

    Args:
        submodule: network whose direct children are applied in registration
            order.
        extracted_layers: names of the children whose outputs are returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested children, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Compare by value: `name is "fc"` only worked through string
            # interning and raises a SyntaxWarning on Python 3.8+.
            if name == "fc":
                x = x.view(x.size(0), -1)  # flatten before the linear head
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


def getHash(f1, f2, f0, thre=None):
    """Compare two feature stacks against a reference, channel by channel.

    Each ``C x H x W`` stack is reduced to per-channel means; a channel
    "differs" when its absolute mean difference to ``f0`` exceeds ``thre``.

    Args:
        f1, f2: feature stacks under test.
        f0: reference feature stack.
        thre: difference threshold; when ``None`` it is the average of the
            minimum and maximum differences of both comparisons.

    Returns:
        ``(degree1, degree2)``: percentage of channels of ``f1`` / ``f2``
        that do not differ from ``f0``.
    """
    mean1 = np.mean(f1, axis=(1, 2))
    mean2 = np.mean(f2, axis=(1, 2))
    mean0 = np.mean(f0, axis=(1, 2))
    dist1 = np.abs(mean0 - mean1)
    dist2 = np.abs(mean0 - mean2)
    if thre is None:
        thre = (np.min(dist1) + np.max(dist1) + np.min(dist2) + np.max(dist2)) / 4.0
    length = dist1.shape[0]
    degree1 = (length - np.count_nonzero(dist1 > thre)) / length * 100
    degree2 = (length - np.count_nonzero(dist2 > thre)) / length * 100
    return degree1, degree2
def P_Hash(f1, f0, sf_list=None, thre=0.005):
    """Score how many feature channels of ``f1`` match the reference ``f0``.

    Args:
        f1: ``C x H x W`` feature stack under test.
        f0: ``C x H x W`` reference feature stack.
        sf_list: optional channel indices to restrict the comparison to.
        thre: a channel differs when the absolute difference of the
            per-channel means exceeds this value.

    Returns:
        ``(degree, mean1, mean0, bits)``: the percentage of matching
        channels, the per-channel means of both stacks, and a list holding
        ``1`` for every differing channel and ``0`` otherwise.
    """
    if sf_list is not None:
        f1 = f1[sf_list, :, :]
        f0 = f0[sf_list, :, :]
    mean1 = np.mean(f1, axis=(1, 2))
    mean0 = np.mean(f0, axis=(1, 2))
    bits = (np.abs(mean0 - mean1) > thre).astype(int).tolist()
    length = len(bits)
    degree = (length - sum(bits)) / length * 100
    return degree, mean1, mean0, bits


def Resnet_P_hash2(file_fake1, file_fake2, file_real, showfeature):
    """Print the P-hash score of two generated images against a real one.

    Features are taken from ``layer2`` of a pretrained ResNet-18.

    Args:
        file_fake1, file_fake2: paths of the two generated images.
        file_real: path of the reference image.
        showfeature: step through every feature channel in a figure.
    """
    device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
    net = models.resnet18(pretrained=True).to(device)
    # Inference only: freeze batch-norm statistics instead of using (and
    # updating) per-image batch statistics.
    net.eval()
    extractor = FeatureExtractor(net, ["layer2"])

    def channel_features(path):
        # `part` is passed explicitly: getImage requires it, and the original
        # call omitted it, which raised TypeError.  None keeps the whole image.
        image = getImage(path, None)
        with torch.no_grad():
            features = extractor(image.unsqueeze(0).to(device))[0]
        return np.squeeze(features.cpu().numpy())

    x1_channelfeature = channel_features(file_fake1)
    x2_channelfeature = channel_features(file_fake2)
    x0_channelfeature = channel_features(file_real)

    degree, _, _, _ = P_Hash(x1_channelfeature, x0_channelfeature, thre=0.01)
    print('our——phash', degree)
    degree, _, _, _ = P_Hash(x2_channelfeature, x0_channelfeature, thre=0.01)
    print('cycle——phash', degree)

    print(x1_channelfeature.shape)
    if showfeature:
        plt.figure()
        for i in range(x1_channelfeature.shape[0]):
            for slot, stack in enumerate(
                    (x1_channelfeature, x2_channelfeature, x0_channelfeature), start=1):
                plt.subplot(1, 3, slot)
                plt.imshow(stack[i, :, :])
            plt.waitforbuttonpress()
        plt.show()
# Pretrained backbones keyed by device, so the network is loaded once per
# process rather than once per scored image.
_BACKBONE_CACHE = {}


def _cached_resnet101(device):
    """Return a shared, eval-mode pretrained ResNet-101 on ``device``."""
    key = str(device)
    if key not in _BACKBONE_CACHE:
        # eval() freezes batch-norm statistics, so the shared instance gives
        # the same output for the same image on every call.
        _BACKBONE_CACHE[key] = models.resnet101(pretrained=True).to(device).eval()
    return _BACKBONE_CACHE[key]


def Resnet_P_hash(testImage, refImage, part=1, layername="layer1", thre=0.02, sf_list=None, showfeature=False):
    """Compute the P-hash score of ``testImage`` against ``refImage``.

    Args:
        testImage: path of the image under test.
        refImage: path of the reference image.
        part: tile selector forwarded to ``getImage``.
        layername: ResNet-101 child whose output is compared.
        thre: per-channel difference threshold forwarded to ``P_Hash``.
        sf_list: optional channel subset forwarded to ``P_Hash``.
        showfeature: step through every feature channel in a figure.

    Returns:
        ``(degree, list_A)``: the percentage of matching channels and the
        indices of the channels that differ.
    """
    device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
    extractor = FeatureExtractor(_cached_resnet101(device), [layername])

    def channel_features(path):
        image = getImage(path, part)
        with torch.no_grad():  # inference only, no autograd graph needed
            features = extractor(image.unsqueeze(0).to(device))[0]
        return np.squeeze(features.cpu().numpy())

    x1_channelfeature = channel_features(testImage)
    x0_channelfeature = channel_features(refImage)

    degree, f1, f0, hash_tabel = P_Hash(x1_channelfeature, x0_channelfeature, sf_list=sf_list, thre=thre)
    list_A = np.flatnonzero(np.asarray(hash_tabel) > 0)

    if showfeature:
        # Summary strip: test means, reference means and the hash bits, each
        # repeated over 10 rows and separated by 5 blank rows.
        shape = f1.shape[0]
        black = np.zeros(shape=(5, shape))
        show_img = np.vstack((
            np.tile(np.reshape(f1, (1, shape)), (10, 1)),
            black,
            np.tile(np.reshape(f0, (1, shape)), (10, 1)),
            black,
            np.tile(np.reshape(hash_tabel, (1, shape)), (10, 1))))
        plt.figure()
        for i in range(x1_channelfeature.shape[0]):
            plt.subplot(3, 1, 1)
            plt.imshow(show_img)
            plt.subplot(3, 1, 2)
            plt.imshow(x1_channelfeature[i, :, :])
            plt.subplot(3, 1, 3)
            plt.imshow(x0_channelfeature[i, :, :])
            plt.waitforbuttonpress()
            print(i)
        plt.show()
    return degree, list_A
# Restrict the process to the second GPU before any CUDA work starts.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

if __name__ == '__main__':
    # Sweep the P-hash threshold and record mean / std of the score over
    # both tiles of every test image.
    modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM', 'Cycle_GAN_pathology_cls',
                 'Cycle_GAN_pathology_seg']
    model_id = 5
    print(model_id)
    epoch_id = 7
    wsi_name = [4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9]
    img_ids = [1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]
    # Per-channel count of how often a channel was flagged as differing.
    D_list = np.zeros(shape=(2048))
    xlsdata = {
        'thre': [],
        'p-hash_value_mean': [],
        'p-hash_value_std': []}
    layer_id = 2
    layername = "layer" + str(layer_id)
    for index in range(1, 21):
        thre = index * 0.001
        scores = []
        for wsi, img_id in zip(wsi_name, img_ids):
            img1 = '/home/zhangbc/Mydataspace/LST/breast/mymodelx288/' + modelname[model_id] + '/exp2/epoch' + str(
                epoch_id) + '/output/A2B/WSI_data00' + str(wsi) + '/' + str(img_id).zfill(2) + '_pre_B.png'
            img0 = '/home/zhangbc/Mydataspace/LST/raw_data/breast/Test/Ki67/WSI_data00' + str(
                wsi) + '/' + str(img_id).zfill(2) + '.png'

            for part in (1, 2):
                degree, A = Resnet_P_hash(img1, img0, part=part, layername=layername, sf_list=None, thre=thre, showfeature=False)
                scores.append(degree)
                for channel in A:
                    D_list[channel] += 1
        degree_list = np.asarray(scores, dtype=np.float32)
        mean_pv = np.mean(degree_list)
        std_pv = np.std(degree_list)
        print(index)
        print('Resnet_phash', mean_pv)
        print('Resnet_phash_std', std_pv)
        xlsdata['thre'].append(thre)
        xlsdata['p-hash_value_mean'].append(mean_pv)
        xlsdata['p-hash_value_std'].append(std_pv)

    # NOTE(review): inputs are read from the 'breast' tree but the sheet is
    # written under 'neuroendocrine' -- confirm the output path is intended.
    df = DataFrame(xlsdata)
    order = ['thre', 'p-hash_value_mean', 'p-hash_value_std']
    df = df[order]
    df.to_excel('/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/' + modelname[model_id] + '/epoch' + str(
        epoch_id) + '/output/B_B_pv_'+layername+'.xlsx')
# ---- /evaluation_metrics/Resnet_P_hash_value2.py ----
import functools
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pandas import DataFrame
from torch.autograd import Variable
from torchvision import models

# Thresholds recorded by the original author:
#   resnet-18 layer3--->thre=0.0025
#   resnet-18 layer3--->thre=0.005
#   resnet-18 layer2--->thre=0.01
#   resnet-18 layer1--->thre=0.02

use_gpu = True
h = 1500
w = 1505


def getImage(filepath, part=None):
    """Load an image file as a normalised RGB tensor of size (3, h, w).

    Args:
        filepath: path of the image to read with OpenCV.
        part: 1 selects rows 0:1500 / columns 0:1505, 2 selects rows 0:1500 /
            columns 1505:3010; any other value (including the default) keeps
            the whole image.

    Returns:
        Float tensor produced by ToTensor + Normalize(mean=0.5, std=0.5).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``filepath``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    img_data = cv2.imread(filepath)
    if img_data is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: ' + str(filepath))
    if part == 1:
        img_data = img_data[0:1500, 0:1505, :]
    elif part == 2:
        img_data = img_data[0:1500, 1505:3010, :]
    img_data = cv2.resize(img_data, (w, h), interpolation=cv2.INTER_CUBIC)
    img_data = np.array(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)).astype(np.uint8)
    return transform(img_data)


class FeatureExtractor(nn.Module):
    """Run ``submodule`` child by child and collect the outputs of named children."""

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the list of activations of the children named in ``extracted_layers``."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Compare by value: ``is`` on a string literal tests identity only.
            if name == "fc":
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
                # Nothing after the last requested layer is used, so stop early.
                if len(outputs) == len(self.extracted_layers):
                    break
        return outputs


def getHash(f1, f2, f0, thre=None):
    """Score two feature stacks against a reference stack.

    Each input is averaged over axes (1, 2); a channel counts as "different"
    when its absolute mean difference to ``f0`` exceeds ``thre``. When ``thre``
    is None it is the mean of the min and max of both distance vectors.

    Returns:
        (degree1, degree2): percentage of channels of f1 / f2 that are not
        flagged as different.
    """
    f1 = np.mean(f1, axis=(1, 2))
    f2 = np.mean(f2, axis=(1, 2))
    f0 = np.mean(f0, axis=(1, 2))
    dist1 = np.abs(f0 - f1)
    dist2 = np.abs(f0 - f2)
    if thre is None:
        thre = (np.min(dist1) + np.max(dist1) + np.min(dist2) + np.max(dist2)) / 4.0
    length = dist1.shape[0]
    degree1 = (length - np.sum(dist1 > thre)) / length * 100
    degree2 = (length - np.sum(dist2 > thre)) / length * 100
    return degree1, degree2


def P_Hash(f1, f0, sf_list=None, thre=0.005):
    """Compare per-channel mean activations of ``f1`` against ``f0``.

    Args:
        f1, f0: arrays indexed as [channel, row, col].
        sf_list: optional channel indices to restrict the comparison to.
        thre: a channel is flagged (bit 1) when |mean(f0) - mean(f1)| > thre.

    Returns:
        (degree, f1_means, f0_means, bits) where ``degree`` is the percentage
        of unflagged channels and ``bits`` is a list of 0/1 ints.
    """
    if sf_list is None:
        f1 = np.mean(f1, axis=(1, 2))
        f0 = np.mean(f0, axis=(1, 2))
    else:
        f1 = np.mean(f1[sf_list, :, :], axis=(1, 2))
        f0 = np.mean(f0[sf_list, :, :], axis=(1, 2))
    bits = (np.abs(f0 - f1) > thre).astype(int)
    length = bits.shape[0]
    degree = (length - np.sum(bits)) / length * 100
    return degree, f1, f0, bits.tolist()


@functools.lru_cache(maxsize=None)
def _load_backbone(arch):
    """Build the pretrained torchvision model ``arch`` once and reuse it across calls."""
    net = getattr(models, arch)(pretrained=True)
    # print(net)  # can be printed to inspect the model structure
    if use_gpu:
        net.cuda()
    # NOTE(review): the network is left in its default (training) mode exactly
    # as before, so BatchNorm normalises with per-image statistics. The
    # thresholds above were presumably tuned that way -- confirm before
    # switching to net.eval().
    return net


def _channel_features(extractor, image):
    """Return the first extracted feature map of ``image`` as a squeezed numpy array."""
    batch = torch.unsqueeze(image, dim=0)
    if use_gpu:
        batch = batch.cuda()
    with torch.no_grad():  # inference only; avoids keeping the autograd graph
        feature = extractor(batch)[0]
    return np.squeeze(feature.cpu().detach().numpy())


def Resnet_P_hash2(file_fake1, file_fake2, file_real, showfeature):
    """Print the layer2 ResNet-18 P-hash score of two generated images against a real one.

    Args:
        file_fake1, file_fake2: paths of the two generated images.
        file_real: path of the reference image.
        showfeature: when true, step through the feature maps channel by
            channel (advances on a button press).
    """
    extractor = FeatureExtractor(_load_backbone("resnet18"), ["layer2"])

    # getImage is called without a crop: the whole image is resized to (h, w).
    image_fake1 = getImage(file_fake1)
    image_fake2 = getImage(file_fake2)
    image_real = getImage(file_real)

    x1_channelfeature = _channel_features(extractor, image_fake1)
    x2_channelfeature = _channel_features(extractor, image_fake2)
    x0_channelfeature = _channel_features(extractor, image_real)

    degree, _, _, _ = P_Hash(x1_channelfeature, x0_channelfeature, thre=0.01)
    print('our——phash', degree)
    degree, _, _, _ = P_Hash(x2_channelfeature, x0_channelfeature, thre=0.01)
    print('cycle——phash', degree)

    print(x1_channelfeature.shape)
    if showfeature:
        plt.figure()
        for i in range(x1_channelfeature.shape[0]):
            panels = (x1_channelfeature, x2_channelfeature, x0_channelfeature)
            for column, feature in enumerate(panels, start=1):
                plt.subplot(1, 3, column)
                plt.imshow(feature[i, :, :])
            plt.waitforbuttonpress()
        plt.show()


def Resnet_P_hash(testImage, refImage, part=1, layername="layer1", thre=0.02, sf_list=None, showfeature=False):
    """Score ``testImage`` against ``refImage`` with ResNet-101 features.

    Args:
        testImage, refImage: image paths.
        part: crop selector forwarded to getImage.
        layername: name of the ResNet child whose output is compared.
        thre, sf_list: forwarded to P_Hash.
        showfeature: when true, step through the feature maps interactively.

    Returns:
        (degree, list_A): the P_Hash percentage and the indices of the
        channels flagged as different.
    """
    extractor = FeatureExtractor(_load_backbone("resnet101"), [layername])

    image_test = getImage(testImage, part)
    image_ref = getImage(refImage, part)

    x1_channelfeature = _channel_features(extractor, image_test)
    x0_channelfeature = _channel_features(extractor, image_ref)

    degree, f1, f0, hash_table = P_Hash(x1_channelfeature, x0_channelfeature, sf_list=sf_list, thre=thre)
    list_A = np.flatnonzero(np.asarray(hash_table) > 0)

    if showfeature:
        # Strip image: test means, gap, reference means, gap, flagged bits.
        n_channels = f1.shape[0]
        gap = np.zeros(shape=(5, n_channels))
        show_img = np.vstack((
            np.tile(np.reshape(f1, (1, n_channels)), (10, 1)),
            gap,
            np.tile(np.reshape(f0, (1, n_channels)), (10, 1)),
            gap,
            np.tile(np.reshape(hash_table, (1, n_channels)), (10, 1)),
        ))
        plt.figure()
        for i in range(x1_channelfeature.shape[0]):
            plt.subplot(3, 1, 1)
            plt.imshow(show_img)
            plt.subplot(3, 1, 2)
            plt.imshow(x1_channelfeature[i, :, :])
            plt.subplot(3, 1, 3)
            plt.imshow(x0_channelfeature[i, :, :])
            plt.waitforbuttonpress()
            print(i)
        plt.show()
    return degree, list_A


os.environ['CUDA_VISIBLE_DEVICES'] = '0'

if __name__ == '__main__':

    modelname = ['Cycle_GAN', 'Cycle_GAN_unet', 'Cycle_GAN_UnetSSIM', 'Cycle_GAN_SSIM', 'Cycle_GAN_pathology_cls',
                 'Cycle_GAN_pathology_seg']
    model_id = 2
    print(model_id)
    epoch_id = 10
    # wsi_name = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8]
    # id = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 5, 1, 2, 1, 2, 1, 2, 3, 1, 2]
    wsi_name = [4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9]
    patch_ids = [1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]
    D_list = np.zeros(shape=(2048))  # per-channel count of "different" flags
    xlsdata = {
        'thre': [],
        'p-hash_value_mean': [],
        'p-hash_value_std': []}
    layer_id = 3
    layername = "layer" + str(layer_id)
    thre = 0.005
    degree_list = []
    model_root = '/home/zhangbc/Mydataspace/LST/breast/mymodelx288/' + modelname[model_id] + '/exp2/epoch' + str(epoch_id)
    for wsi, patch in zip(wsi_name, patch_ids):
        stem = 'WSI_data00' + str(wsi) + '/' + str(patch).zfill(2)
        img1 = model_root + '/output/A2B/' + stem + '_pre_B.png'
        img0 = '/home/zhangbc/Mydataspace/LST/raw_data/breast/Test/Ki67/' + stem + '.png'

        # Each file is scored on both crops (part 1 and part 2).
        for part in (1, 2):
            degree, A = Resnet_P_hash(img1, img0, part=part, layername=layername, sf_list=None, thre=thre,
                                      showfeature=False)
            degree_list.append(degree)
            D_list[A] += 1
    degree_list = np.asarray(degree_list, dtype=np.float32)
    mean_pv = np.mean(degree_list)
    std_pv = np.std(degree_list)
    print('Resnet_phash', mean_pv)
    print('Resnet_phash_std', std_pv)
    xlsdata['thre'].append(thre)
    xlsdata['p-hash_value_mean'].append(mean_pv)
    xlsdata['p-hash_value_std'].append(std_pv)

    df = DataFrame(xlsdata)
    order = ['thre', 'p-hash_value_mean', 'p-hash_value_std']
    df = df[order]
    df.to_excel(model_root + '/output/B_B_pv_' + layername + '.xlsx')

# Bookkeeping kept from the original file (it was a trailing no-op string literal):
#     print(D_list)
#     d_flist = np.where(D_list>0)
#     print(d_flist)
#     DF = DF + np.reshape(np.asarray(d_flist), [-1]).tolist()
#     print(np.unique(DF).tolist())
#
# DF =[15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47,
#      49, 50, 51, 52, 54, 55, 56, 59, 60, 61, 62, 63, 67, 68, 69, 70, 71, 72, 73, 74, 75, 77, 79, 80, 82, 83, 84,
#      85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
#      112, 113, 114, 116, 117, 118, 120, 122, 123, 125, 126, 127, 129, 130, 131, 132, 133, 134, 136, 137, 139,
#      140, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161,
#      162, 163, 164, 165, 166, 167, 168, 170, 171, 172, 174, 175, 176, 179, 180, 181, 184, 187, 188, 189, 192,
#      194, 195, 196, 197, 198, 201, 202, 203, 205, 207, 208, 209, 211, 213, 214, 215, 216, 217, 218, 219, 220,
#      221, 222, 223, 224, 225, 226, 228, 229, 230, 231, 232, 233, 235, 236, 237, 238, 239, 241, 243, 244, 245,
#      246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 258, 259, 260, 261, 264, 265, 267, 271, 273, 275, 276,
#      277, 278, 279, 280, 282, 283, 284, 285, 287, 288, 289, 291, 292, 293, 294, 295, 296, 299, 301, 302, 303,
#      304, 307, 308, 310, 311, 315, 316, 317, 318, 319, 321, 322, 323, 324, 325, 326, 327, 328, 330, 331, 332,
#      334, 335, 336, 337, 339, 346, 348, 350, 352, 354, 356, 357, 358, 359, 362, 365, 366, 367, 368, 369, 370,
#      371, 372, 373, 374, 376, 377, 378, 379, 384, 386, 387, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401,
#      402, 403, 404, 407, 410, 411, 412, 414, 415, 416, 417, 418, 420, 421, 422, 423, 424, 425, 427, 428, 430,
#      431, 432, 434, 435, 438, 439, 442, 443, 444, 445, 446, 447, 448, 449, 450, 452, 453, 454, 455, 456, 457,
#      458, 459, 460, 464, 465, 466, 467, 468, 469, 470, 472, 473, 475, 476, 477, 478, 479, 482, 483, 484, 485,
#      486, 487, 488, 489, 490, 494, 496, 497, 498, 499, 500, 502, 503, 504, 505, 509, 510, 511]

# ---- /img/a.png: https://raw.githubusercontent.com/fightingkitty/Unpaired-Stain-Transfer-using-Pathology-Consistent-Constrained-Generative-Adversarial-Networks/b57c56b314e65a0f31d9e44f57174108599c8b14/img/a.png
# ---- /img/b.png: https://raw.githubusercontent.com/fightingkitty/Unpaired-Stain-Transfer-using-Pathology-Consistent-Constrained-Generative-Adversarial-Networks/b57c56b314e65a0f31d9e44f57174108599c8b14/img/b.png
# ---- /mymodels.py ----
import torch.nn as nn
import torch.nn.functional as F
from unet_utils import Up, Down
import torch


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + InstanceNorm stages with an identity skip."""

    def __init__(self, in_features, alt_leak=False, neg_slope=1e-2):
        super(ResidualBlock, self).__init__()

        # LeakyReLU(neg_slope) when alt_leak is set, plain ReLU otherwise.
        activation = nn.LeakyReLU(neg_slope, inplace=True) if alt_leak else nn.ReLU(inplace=True)
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            activation,
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)


def _activation(alt_leak, neg_slope):
    """Return in-place LeakyReLU(neg_slope) when ``alt_leak`` is set, else in-place ReLU."""
    if alt_leak:
        return nn.LeakyReLU(neg_slope, inplace=True)
    return nn.ReLU(inplace=True)


def _stem_layers(input_nc, alt_leak, neg_slope):
    """Layers of the initial 7x7 convolution block: [N input_nc H W] -> [N 32 H W]."""
    return [nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 32, 7),
            nn.InstanceNorm2d(32),
            _activation(alt_leak, neg_slope)]


def _head_layers(output_nc):
    """Layers of the output block: 7x7 convolution from 32 channels followed by Tanh."""
    return [nn.ReflectionPad2d(3),
            nn.Conv2d(32, output_nc, 7),
            nn.Tanh()]


def _residual_stack(channels, count, alt_leak, neg_slope):
    """Return ``count`` freshly built ResidualBlock modules of width ``channels``."""
    return [ResidualBlock(channels, alt_leak, neg_slope) for _ in range(count)]


class Pathology_block(nn.Module):
    """Fuse three encoder scales and reduce them with a strided conv + residual blocks."""

    def __init__(self, in_features, out_features, n_residual_blocks, alt_leak=False, neg_slope=1e-2):
        super(Pathology_block, self).__init__()

        layers = [
            # 1x1 convolution mixing the concatenated scales
            nn.Conv2d(in_features, out_features, 1),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
            # stride-2 4x4 convolution on the reflection-padded map
            nn.ReflectionPad2d(1),
            nn.Conv2d(out_features, out_features, 4, stride=2),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
        ]
        layers.extend(_residual_stack(out_features, n_residual_blocks, alt_leak, neg_slope))
        self.extractor = nn.Sequential(*layers)

    @staticmethod
    def _pad_like(tensor, reference):
        """Zero-pad ``tensor`` on H and W (split as evenly as possible) to the size of ``reference``."""
        gap_h = reference.size()[2] - tensor.size()[2]
        gap_w = reference.size()[3] - tensor.size()[3]
        left, top = gap_w // 2, gap_h // 2
        return F.pad(tensor, [left, gap_w - left, top, gap_h - top])

    def forward(self, x1, x2, x3):
        """Bring ``x1`` (halved) and ``x3`` (doubled, bilinear) to ``x2``'s size, concatenate, extract."""
        shrunk = self._pad_like(F.interpolate(x1, scale_factor=0.5), x2)
        grown = self._pad_like(
            F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=True), x2)
        return self.extractor(torch.cat([shrunk, x2, grown], dim=1))


class Generator_resnet(nn.Module):
    """Encoder/decoder generator: stem, 3 strided convs, residual blocks, 3 transposed convs, head."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=10, alt_leak=False, neg_slope=1e-2):
        super(Generator_resnet, self).__init__()
        per_side = n_residual_blocks // 2

        # Initial convolution block [N 32 H W]
        encoder_layers = _stem_layers(input_nc, alt_leak, neg_slope)

        # Strided convolutions [N 64 H/2 W/2]-->[N 256 H/8 W/8]
        width = 32
        for _ in range(3):
            encoder_layers.extend([
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                _activation(alt_leak, neg_slope)])
            width *= 2

        # Residual blocks [N 256 H/8 W/8], half in the encoder ...
        encoder_layers.extend(_residual_stack(width, per_side, alt_leak, neg_slope))
        # ... and half in the decoder
        decoder_layers = _residual_stack(width, per_side, alt_leak, neg_slope)

        # Transposed convolutions [N 128 H/4 W/4]-->[N 32 H W]
        for _ in range(3):
            decoder_layers.extend([
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2),
                _activation(alt_leak, neg_slope)])
            width //= 2

        # Output layer [N 3 H W]
        decoder_layers.extend(_head_layers(output_nc))

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(decoded output, encoder features)``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent


class Generator_unet(nn.Module):
    """Generator with skip connections built from the project's ``Down`` / ``Up`` stages."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=8, alt_leak=False, neg_slope=1e-2):
        super(Generator_unet, self).__init__()
        per_side = n_residual_blocks // 2
        # Initial convolution block [N 32 H W]
        self.inc = nn.Sequential(*_stem_layers(input_nc, alt_leak, neg_slope))
        # Down stages [N 64 H/2 W/2] -> [N 128 H/4 W/4] -> [N 256 H/8 W/8]
        self.down1 = Down(32, 64, alt_leak, neg_slope)
        self.down2 = Down(64, 128, alt_leak, neg_slope)
        self.down3 = Down(128, 256, alt_leak, neg_slope)

        # Residual blocks [N 256 H/8 W/8]: first half yields the returned latent features
        self.ext_f1 = nn.Sequential(*_residual_stack(256, per_side, alt_leak, neg_slope))
        self.ext_f2 = nn.Sequential(*_residual_stack(256, per_side, alt_leak, neg_slope))
        # Up stages [N 128 H/4 W/4] -> [N 64 H/2 W/2] -> [N 32 H W]
        self.up1 = Up(256, 128, alt_leak, neg_slope)
        self.up2 = Up(128, 64, alt_leak, neg_slope)
        self.up3 = Up(64, 32, alt_leak, neg_slope)
        # Output block [N 3 H W]
        self.outc = nn.Sequential(*_head_layers(output_nc))

    def forward(self, x):
        """Return ``(outputs, latent_features)``; latent features are taken after ``ext_f1``."""
        skip0 = self.inc(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        bottom = self.down3(skip2)
        latent_features = self.ext_f1(bottom)
        y = self.up1(self.ext_f2(latent_features), skip2)
        y = self.up2(y, skip1)
        y = self.up3(y, skip0)
        return self.outc(y), latent_features


class _PathologyUNetBase(nn.Module):
    """Shared trunk of the pathology-aware generators; subclasses add the prediction head."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=8, alt_leak=False, neg_slope=1e-2):
        super(_PathologyUNetBase, self).__init__()
        per_side = n_residual_blocks // 2
        # Initial convolution block [N 32 H W]
        self.inc = nn.Sequential(*_stem_layers(input_nc, alt_leak, neg_slope))
        # Down stages [N 64 H/2 W/2] -> [N 128 H/4 W/4] -> [N 256 H/8 W/8]
        self.down1 = Down(32, 64, alt_leak, neg_slope)
        self.down2 = Down(64, 128, alt_leak, neg_slope)
        self.down3 = Down(128, 256, alt_leak, neg_slope)

        # Residual blocks [N 256 H/8 W/8]
        self.ext_f1 = nn.Sequential(*_residual_stack(256, per_side, alt_leak, neg_slope))
        # Pathology branch over the three down-stage outputs (64 + 128 + 256 = 448 channels)
        self.pathology_f = Pathology_block(448, 256, per_side, alt_leak, neg_slope)
        # 1x1 merge of latent and pathology features (512 -> 256)
        self.merge = nn.Sequential(nn.Conv2d(512, 256, 1),
                                   nn.InstanceNorm2d(256),
                                   _activation(alt_leak, neg_slope))
        # Residual blocks [N 256 H/8 W/8]
        self.ext_f2 = nn.Sequential(*_residual_stack(256, per_side, alt_leak, neg_slope))
        # Up stages [N 128 H/4 W/4] -> [N 64 H/2 W/2] -> [N 32 H W]
        self.up1 = Up(256, 128, alt_leak, neg_slope)
        self.up2 = Up(128, 64, alt_leak, neg_slope)
        self.up3 = Up(64, 32, alt_leak, neg_slope)
        # Output block [N 3 H W]
        self.outc = nn.Sequential(*_head_layers(output_nc))

    def _encode(self, x):
        """Return the stem output and the three down-stage outputs."""
        skip0 = self.inc(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        return skip0, skip1, skip2, self.down3(skip2)

    def _decode(self, skips, pathology_features):
        """Merge latent and pathology features, decode, and return ``(outputs, latent_features)``."""
        skip0, skip1, skip2, bottom = skips
        latent_features = self.ext_f1(bottom)
        fused = self.merge(torch.cat([latent_features, pathology_features], dim=1))
        y = self.up1(self.ext_f2(fused), skip2)
        y = self.up2(y, skip1)
        y = self.up3(y, skip0)
        return self.outc(y), latent_features


class Generator_unet_cls(_PathologyUNetBase):
    """Pathology-aware generator whose auxiliary head yields one score per sample."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=8, alt_leak=False, neg_slope=1e-2):
        super(Generator_unet_cls, self).__init__(input_nc, output_nc, n_residual_blocks, alt_leak, neg_slope)
        self.out_cls = nn.Sequential(nn.Dropout(0.15),
                                     nn.ReflectionPad2d(1),
                                     nn.Conv2d(256, 1, 3),
                                     nn.Sigmoid())

    def forward(self, x, mode='G'):
        """Return the pooled score alone for ``mode == 'C'``, else the full 4-tuple.

        The 4-tuple is ``(outputs, latent_features, c_out, pathology_features)``.
        """
        skips = self._encode(x)
        pathology_features = self.pathology_f(skips[1], skips[2], skips[3])
        score_map = self.out_cls(pathology_features)
        # Global average pooling, flattened to one value per sample
        c_out = F.avg_pool2d(score_map, score_map.size()[2:]).view(score_map.size()[0])
        if mode == 'C':
            return c_out
        outputs, latent_features = self._decode(skips, pathology_features)
        return outputs, latent_features, c_out, pathology_features


class Generator_unet_seg(_PathologyUNetBase):
    """Pathology-aware generator whose auxiliary head yields a sigmoid map (no pooling)."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=8, alt_leak=False, neg_slope=1e-2):
        super(Generator_unet_seg, self).__init__(input_nc, output_nc, n_residual_blocks, alt_leak, neg_slope)
        self.out_seg = nn.Sequential(nn.ReflectionPad2d(1),
                                     nn.Conv2d(256, 1, 3),
                                     nn.Sigmoid())

    def forward(self, x, mode='G'):
        """Return the map alone for ``mode == 'C'``, else the full 4-tuple.

        The 4-tuple is ``(outputs, latent_features, c_out, pathology_features)``.
        """
        skips = self._encode(x)
        pathology_features = self.pathology_f(skips[1], skips[2], skips[3])
        c_out = self.out_seg(pathology_features)
        if mode == 'C':
            return c_out
        outputs, latent_features = self._decode(skips, pathology_features)
        return outputs, latent_features, c_out, pathology_features


class Discriminator(nn.Module):
    """Fully convolutional discriminator that averages its score map to one value per sample."""

    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # Three stride-2 4x4 convolutions; the first one has no normalisation
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Stride-1 4x4 convolution
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # FCN classification layer
            nn.Conv2d(512, 1, 4, padding=1),
        )

    def forward(self, x):
        """Return a tensor of shape (batch,) holding the spatially averaged score."""
        score_map = self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(score_map, score_map.size()[2:]).view(score_map.size()[0])
# ---- /train_Cycle_Gan.py ----
#!/usr/bin/python3
"""Train a CycleGAN with ResNet generators between the 'HE' and 'Ki67' image domains."""

import argparse
import itertools
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch

from mymodels import Generator_resnet
from mymodels import Discriminator
from utils import ReplayBuffer
from utils import LambdaLR
from utils import Logger
from utils import weights_init_normal
from datasets import ImageDataset
# breast/ neuroendocrine / GLAS
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string (including "False") into
    True, so the accepted spellings are matched explicitly instead.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _training_state(net, optimizer, scheduler, **extra):
    """Bundle model, optimizer and scheduler state (plus ``extra`` entries) for a resumable checkpoint."""
    state = {
        "model": net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_schedule': scheduler.state_dict(),
    }
    state.update(extra)
    return state


parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/datasetX20_288/Train_PN_3_1', help='root directory of the dataset')
parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_PN_3_1', help='root directory of the model')

parser.add_argument('--epoch', type=int, default=1, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=2, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0')
parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')

parser.add_argument('--cuda', type=_str2bool, default=True, help='use GPU computation')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')

parser.add_argument('--continue_train', type=_str2bool, default=False, help='load model and continue trainning')
parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_PN_3_1/temp', help='continue train root directory of the model')
opt = parser.parse_args()
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks
netG_A2B = Generator_resnet(opt.input_nc, opt.output_nc, 10, False)
netG_B2A = Generator_resnet(opt.output_nc, opt.input_nc, 10, False)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Optimizers & LR schedulers (both generators share one optimizer)
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
# NOTE(review): the buffers have a fixed batch dimension, so copy_ below needs
# every batch to contain exactly opt.batchSize samples -- presumably
# ImageDataset(batch_size=...) guarantees that; confirm.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

transforms_ = transforms.Compose([
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

###################################
start_epoch = opt.epoch

if opt.continue_train:
    # Resume from the 'temp' checkpoints written during training. Only the
    # B2A generator checkpoint carries the shared generator optimizer,
    # scheduler and epoch counter.
    netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth'))  # load checkpoint
    netG_A2B.load_state_dict(netG_A2B_checkpoint['model'])  # load learnable parameters

    netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth'))  # load checkpoint
    netG_B2A.load_state_dict(netG_B2A_checkpoint['model'])  # load learnable parameters
    optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule'])  # load lr_scheduler state
    start_epoch = netG_B2A_checkpoint['epoch'] + 1

    netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth'))  # load checkpoint
    netD_A.load_state_dict(netD_A_checkpoint['model'])  # load learnable parameters
    optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule'])  # load lr_scheduler state

    netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth'))  # load checkpoint
    netD_B.load_state_dict(netD_B_checkpoint['model'])  # load learnable parameters
    optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule'])  # load lr_scheduler state

# Loss plot
logger = Logger(opt.n_epochs, len(dataloader), start_epoch)

# Intermediate checkpoints are written roughly every fifth of an epoch.
temp_save_interval = len(dataloader) // 5 + 1

###### Training ######
# NOTE(review): with the default --epoch 1 this loop runs n_epochs - 1 epochs
# (1 .. n_epochs - 1); kept as in the original -- confirm that is intended.
for epoch in range(start_epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(input_A.copy_(batch['HE']))
        real_B = Variable(input_B.copy_(batch['Ki67']))

        # Generators A2B and B2A
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B, _ = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)
        # G_B2A(A) should equal A if real A is fed
        same_A, _ = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)

        # GAN loss
        fake_B, _ = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        fake_A, _ = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        # Cycle loss
        recovered_A, _ = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

        recovered_B, _ = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

        # Total loss
        loss_G = 5.0 * (loss_identity_A + loss_identity_B) + \
                 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \
                 10.0 * (loss_cycle_ABA + loss_cycle_BAB)

        loss_G.backward()

        optimizer_G.step()
        ###################################

        # Discriminator A
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss (drawn from the replay buffer, detached from the generator)
        fake_Ad = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = netD_A(fake_Ad.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()

        optimizer_D_A.step()
        ###################################

        # Discriminator B
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss
        fake_Bd = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD_B(fake_Bd.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()

        optimizer_D_B.step()
        ###################################

        # Progress report (http://localhost:8097)
        logger.log({'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)},
                   images={'real_cycleGAN_A': real_A, 'real_cycleGAN_B': real_B,
                           'fake_cycleGAN_A': fake_A, 'fake_cycleGAN_B': fake_B})

        # Save resumable checkpoints several times per epoch
        if (i + 1) % temp_save_interval == 0:
            saveroot = os.path.join(opt.modelroot, 'temp')
            os.makedirs(saveroot, exist_ok=True)

            # The A2B generator stores weights only; its optimizer state is
            # shared with B2A and saved there together with the epoch.
            torch.save({"model": netG_A2B.state_dict()}, os.path.join(saveroot, 'netG_A2B.pth'))
            torch.save(_training_state(netG_B2A, optimizer_G, lr_scheduler_G, epoch=epoch),
                       os.path.join(saveroot, 'netG_B2A.pth'))
            torch.save(_training_state(netD_A, optimizer_D_A, lr_scheduler_D_A),
                       os.path.join(saveroot, 'netD_A.pth'))
            torch.save(_training_state(netD_B, optimizer_D_B, lr_scheduler_D_B),
                       os.path.join(saveroot, 'netD_B.pth'))

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    saveroot = os.path.join(opt.modelroot, 'epoch' + str(epoch))
    os.makedirs(saveroot, exist_ok=True)

    # Save end-of-epoch weights (state dicts only)
    torch.save(netG_A2B.state_dict(), os.path.join(saveroot, 'netG_A2B.pth'))
    torch.save(netG_B2A.state_dict(), os.path.join(saveroot, 'netG_B2A.pth'))
    torch.save(netD_A.state_dict(), os.path.join(saveroot, 'netD_A.pth'))
    torch.save(netD_B.state_dict(), os.path.join(saveroot, 'netD_B.pth'))

# ---- /train_Cycle_GanSSIM.py ----
#!/usr/bin/python3

import argparse
import itertools
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch

from mymodels import Generator_resnet
from mymodels import Discriminator
from utils import ReplayBuffer
from utils import LambdaLR

from utils import Logger
from utils import weights_init_normal
from utils import MS_SSIM_Loss
from datasets import ImageDataset


def _str2bool(value):
    """Parse a command-line boolean such as ``--cuda False``.

    ``type=bool`` makes argparse call ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This parser
    maps the usual spellings explicitly instead.

    Args:
        value: the raw command-line string (or an already parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    # The empty string stays False, as it was with ``bool('')``.
    if text in ('', 'false', 'f', 'no', 'n', '0'):
        return False
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


# breast/ neuroendocrine / GLAS
parser = argparse.ArgumentParser()

parser.add_argument('--dataroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/datasetX20_288/Train', help='root directory of the dataset')
parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_SSIM', help='root directory of the model')

parser.add_argument('--epoch', type=int, default=1, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=12, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=4, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0')
parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')

# Booleans go through _str2bool so that "--cuda False" really disables CUDA.
parser.add_argument('--cuda', type=_str2bool, default=True, help='use GPU computation')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')

parser.add_argument('--continue_train', type=_str2bool, default=False, help='load model and continue trainning')
parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_SSIM/temp', help='continue train root directory of the model')
opt = parser.parse_args()

print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks: one generator and one discriminator per translation direction.
# NOTE(review): the positional arguments `10, False` are interpreted by
# mymodels.Generator_resnet -- confirm their meaning there.
netG_A2B = Generator_resnet(opt.input_nc, opt.output_nc, 10, False)
netG_B2A = Generator_resnet(opt.output_nc, opt.input_nc, 10, False)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
# NOTE(review): inputs are normalised with mean 0.5 / std 0.5 below, while
# data_range is 1.0 -- confirm MS_SSIM_Loss rescales its inputs accordingly.
criterion_ssim = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3)

# Optimizers & LR schedulers
# A single optimizer drives both generators; each discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
# The input buffers are allocated once and refilled with copy_() per batch.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

# NOTE(review): presumably a history of generated images used for the
# discriminator updates -- see utils.ReplayBuffer.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

transforms_ = transforms.Compose([
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

###################################
start_epoch = opt.epoch

if opt.continue_train:
    # Restore the resumable checkpoints from opt.loadroot. The generator
    # optimizer/scheduler state and the epoch counter are stored in the
    # netG_B2A checkpoint; the netG_A2B file only provides model weights.
    netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth'))  # load checkpoint
    netG_A2B.load_state_dict(netG_A2B_checkpoint['model'])  # load learnable parameters

    netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth'))  # load checkpoint
    netG_B2A.load_state_dict(netG_B2A_checkpoint['model'])  # load learnable parameters
    optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule'])  # load lr_scheduler state
    start_epoch = netG_B2A_checkpoint['epoch']

    netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth'))  # load checkpoint
    netD_A.load_state_dict(netD_A_checkpoint['model'])  # load learnable parameters
    optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule'])  # load lr_scheduler state

    netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth'))  # load checkpoint
    netD_B.load_state_dict(netD_B_checkpoint['model'])  # load learnable parameters
    optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer'])  # load optimizer state
    lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule'])  # load lr_scheduler state

# Loss plot
logger = Logger(opt.n_epochs, len(dataloader), start_epoch)

###### Training ######
130 | for epoch in range(opt.epoch, opt.n_epochs): 131 | for i, batch in enumerate(dataloader): 132 | 133 | # Set model input 134 | real_A = Variable(input_A.copy_(batch['HE'])) 135 | real_B = Variable(input_B.copy_(batch['Ki67'])) 136 | 137 | ###### Generators A2B and B2A ###### 138 | optimizer_G.zero_grad() 139 | 140 | # TUM remove the Identity loss 141 | # G_A2B(B) should equal B if real B is fed 142 | same_B, _ = netG_A2B(real_B) 143 | loss_identity_B = criterion_identity(same_B, real_B) 144 | # G_B2A(A) should equal A if real A is fed 145 | same_A, _ = netG_B2A(real_A) 146 | loss_identity_A = criterion_identity(same_A, real_A) 147 | 148 | # GAN loss 149 | fake_B, features_fb = netG_A2B(real_A) 150 | pred_fake = netD_B(fake_B) 151 | loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 152 | 153 | fake_A, features_fa = netG_B2A(real_B) 154 | pred_fake = netD_A(fake_A) 155 | loss_GAN_B2A = criterion_GAN(pred_fake, target_real) 156 | 157 | # Cycle loss 158 | recovered_A, features_ra = netG_B2A(fake_B) 159 | loss_cycle_ABA = criterion_cycle(recovered_A, real_A) 160 | criterion_ssimABA = criterion_ssim(recovered_A, real_A) 161 | 162 | recovered_B, features_rb = netG_A2B(fake_A) 163 | loss_cycle_BAB = criterion_cycle(recovered_B, real_B) 164 | criterion_ssimBAB = criterion_ssim(recovered_B, real_B) 165 | 166 | # Total loss 167 | loss_G = 5.0 * (loss_identity_A + loss_identity_B) + \ 168 | 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \ 169 | 10.0 * (loss_cycle_ABA + loss_cycle_BAB) + \ 170 | 2.0 * (criterion_ssimABA + criterion_ssimBAB) 171 | 172 | loss_G.backward() 173 | 174 | optimizer_G.step() 175 | ################################### 176 | 177 | ###### Discriminator A ###### 178 | optimizer_D_A.zero_grad() 179 | 180 | # Real loss 181 | pred_real = netD_A(real_A) 182 | loss_D_real = criterion_GAN(pred_real, target_real) 183 | 184 | # Fake loss 185 | fake_Ad = fake_A_buffer.push_and_pop(fake_A) 186 | pred_fake = netD_A(fake_Ad.detach()) 187 | loss_D_fake = 
criterion_GAN(pred_fake, target_fake) 188 | 189 | # Total loss 190 | loss_D_A = (loss_D_real + loss_D_fake)*0.5 191 | loss_D_A.backward() 192 | 193 | optimizer_D_A.step() 194 | ################################### 195 | 196 | ###### Discriminator B ###### 197 | optimizer_D_B.zero_grad() 198 | 199 | # Real loss 200 | pred_real = netD_B(real_B) 201 | loss_D_real = criterion_GAN(pred_real, target_real) 202 | 203 | # Fake loss 204 | fake_Bd = fake_B_buffer.push_and_pop(fake_B) 205 | pred_fake = netD_B(fake_Bd.detach()) 206 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 207 | 208 | # Total loss 209 | loss_D_B = (loss_D_real + loss_D_fake)*0.5 210 | loss_D_B.backward() 211 | 212 | optimizer_D_B.step() 213 | ################################### 214 | 215 | # Progress report (http://localhost:8097) 216 | logger.log({'loss_G': loss_G, 217 | 'loss_G_SSIM': (criterion_ssimABA + criterion_ssimBAB), 218 | 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A), 219 | 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 220 | 'loss_D': (loss_D_A + loss_D_B)}, 221 | images={'real_cycleSSIM_A': real_A, 'real_cycleSSIM_B': real_B, 222 | 'fake_cycleSSIM_A': fake_A, 'fake_cycleSSIM_B': fake_B}) 223 | 224 | # save models at half of an epoch 225 | if (i + 1) % (dataloader.__len__() // 5 + 1) == 0: 226 | saveroot = os.path.join(opt.modelroot, 'temp') 227 | if not os.path.exists(saveroot): 228 | os.makedirs(saveroot) 229 | 230 | # Save models checkpoints 231 | netG_A2B_checkpoints = { 232 | "model": netG_A2B.state_dict() 233 | } 234 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 235 | 236 | netG_B2A_checkpoints = { 237 | "model": netG_B2A.state_dict(), 238 | 'optimizer': optimizer_G.state_dict(), 239 | "epoch": epoch, 240 | 'lr_schedule': lr_scheduler_G.state_dict() 241 | } 242 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 243 | 244 | netD_A_checkpoints = { 245 | "model": netD_A.state_dict(), 246 | 'optimizer': optimizer_D_A.state_dict(), 
247 | 'lr_schedule': lr_scheduler_D_A.state_dict() 248 | } 249 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 250 | 251 | netD_B_checkpoints = { 252 | "model": netD_B.state_dict(), 253 | 'optimizer': optimizer_D_B.state_dict(), 254 | 'lr_schedule': lr_scheduler_D_B.state_dict() 255 | } 256 | torch.save(netD_B_checkpoints, os.path.join(saveroot, 'netD_B.pth')) 257 | 258 | # Update learning rates 259 | lr_scheduler_G.step() 260 | lr_scheduler_D_A.step() 261 | lr_scheduler_D_B.step() 262 | 263 | saveroot = os.path.join(opt.modelroot, 'epoch'+str(epoch)) 264 | if not os.path.exists(saveroot): 265 | os.makedirs(saveroot) 266 | 267 | # Save models checkpoints 268 | torch.save(netG_A2B.state_dict(), os.path.join(saveroot, 'netG_A2B.pth')) 269 | torch.save(netG_B2A.state_dict(), os.path.join(saveroot, 'netG_B2A.pth')) 270 | torch.save(netD_A.state_dict(), os.path.join(saveroot, 'netD_A.pth')) 271 | torch.save(netD_B.state_dict(), os.path.join(saveroot, 'netD_B.pth')) 272 | -------------------------------------------------------------------------------- /train_Cycle_Gan_UnetSSIM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | import itertools 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch 12 | 13 | from mymodels import Generator_unet 14 | from mymodels import Discriminator 15 | from utils import ReplayBuffer 16 | from utils import LambdaLR 17 | from utils import Logger 18 | from utils import weights_init_normal 19 | from utils import MS_SSIM_Loss 20 | from datasets import ImageDataset 21 | 22 | # breast/ neuroendocrine / GLAS 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--dataroot', type=str, 
default='/home/zhangbc/Mydataspace/LST/breast/datasetX20_288/Train', help='root directory of the dataset') 26 | parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_UnetSSIM', help='root directory of the model') 27 | 28 | parser.add_argument('--epoch', type=int, default=1, help='starting epoch') 29 | parser.add_argument('--n_epochs', type=int, default=12, help='number of epochs of training') 30 | parser.add_argument('--batchSize', type=int, default=4, help='size of the batches') 31 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate') 32 | parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0') 33 | parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)') 34 | parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data') 35 | parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data') 36 | 37 | parser.add_argument('--cuda', type=bool, default=True, help='use GPU computation') 38 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 39 | 40 | parser.add_argument('--continue_train', type=bool, default=False, help='load model and continue trainning') 41 | parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_UnetSSIM/temp', help='continue train root directory of the model') 42 | opt = parser.parse_args() 43 | 44 | print(opt) 45 | 46 | if torch.cuda.is_available() and not opt.cuda: 47 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 48 | 49 | ###### Definition of variables ###### 50 | # Networks 51 | netG_A2B = Generator_unet(opt.input_nc, opt.output_nc, 10, alt_leak=True, neg_slope=0.1) 52 | netG_B2A = Generator_unet(opt.output_nc, opt.input_nc, 
10, alt_leak=True, neg_slope=0.1) 53 | netD_A = Discriminator(opt.input_nc) 54 | netD_B = Discriminator(opt.output_nc) 55 | 56 | if opt.cuda: 57 | netG_A2B.cuda() 58 | netG_B2A.cuda() 59 | netD_A.cuda() 60 | netD_B.cuda() 61 | 62 | netG_A2B.apply(weights_init_normal) 63 | netG_B2A.apply(weights_init_normal) 64 | netD_A.apply(weights_init_normal) 65 | netD_B.apply(weights_init_normal) 66 | 67 | # Lossess 68 | criterion_GAN = nn.MSELoss() 69 | criterion_cycle = nn.L1Loss() 70 | criterion_identity = nn.L1Loss() 71 | criterion_ssim = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3) 72 | 73 | # Optimizers & LR schedulers 74 | optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 75 | lr=opt.lr, betas=(0.5, 0.999)) 76 | optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 77 | optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 78 | 79 | lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 80 | lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 81 | lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 82 | 83 | # Inputs & targets memory allocation 84 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 85 | input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size) 86 | input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size) 87 | target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False) 88 | target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False) 89 | 90 | fake_A_buffer = ReplayBuffer() 91 | fake_B_buffer = ReplayBuffer() 92 | 93 | transforms_ = transforms.Compose([ 94 | transforms.RandomCrop(opt.size), 95 | transforms.RandomHorizontalFlip(), 96 | 
transforms.RandomVerticalFlip(), 97 | transforms.ToTensor(), 98 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 99 | 100 | dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True), 101 | batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu) 102 | 103 | ################################### 104 | start_epoch = opt.epoch 105 | 106 | if opt.continue_train: 107 | 108 | netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth')) # 加载断点 109 | netG_A2B.load_state_dict(netG_A2B_checkpoint['model']) # 加载模型可学习参数 110 | 111 | netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth')) # 加载断点 112 | netG_B2A.load_state_dict(netG_B2A_checkpoint['model']) # 加载模型可学习参数 113 | optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer']) # 加载优化器参数 114 | lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule']) # 加载lr_scheduler 115 | start_epoch = netG_B2A_checkpoint['epoch'] 116 | 117 | netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth')) # 加载断点 118 | netD_A.load_state_dict(netD_A_checkpoint['model']) # 加载模型可学习参数 119 | optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer']) # 加载优化器参数 120 | lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule']) # 加载lr_scheduler 121 | 122 | netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth')) # 加载断点 123 | netD_B.load_state_dict(netD_B_checkpoint['model']) # 加载模型可学习参数 124 | optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer']) # 加载优化器参数 125 | lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule']) # 加载lr_scheduler 126 | 127 | # Loss plot 128 | logger = Logger(opt.n_epochs, len(dataloader), start_epoch) 129 | 130 | ###### Training ###### 131 | for epoch in range(opt.epoch, opt.n_epochs): 132 | for i, batch in enumerate(dataloader): 133 | 134 | # Set model input 135 | real_A = Variable(input_A.copy_(batch['HE'])) 136 | real_B = 
Variable(input_B.copy_(batch['Ki67'])) 137 | 138 | ###### Generators A2B and B2A ###### 139 | optimizer_G.zero_grad() 140 | 141 | # TUM remove the Identity loss 142 | # G_A2B(B) should equal B if real B is fed 143 | same_B, _ = netG_A2B(real_B) 144 | loss_identity_B = criterion_identity(same_B, real_B) 145 | # G_B2A(A) should equal A if real A is fed 146 | same_A, _ = netG_B2A(real_A) 147 | loss_identity_A = criterion_identity(same_A, real_A) 148 | 149 | # GAN loss 150 | fake_B, features_fb = netG_A2B(real_A) 151 | pred_fake = netD_B(fake_B) 152 | loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 153 | 154 | fake_A, features_fa = netG_B2A(real_B) 155 | pred_fake = netD_A(fake_A) 156 | loss_GAN_B2A = criterion_GAN(pred_fake, target_real) 157 | 158 | # Cycle loss 159 | recovered_A, features_ra = netG_B2A(fake_B) 160 | loss_cycle_ABA = criterion_cycle(recovered_A, real_A) 161 | criterion_ssimABA = criterion_ssim(recovered_A, real_A) 162 | 163 | recovered_B, features_rb = netG_A2B(fake_A) 164 | loss_cycle_BAB = criterion_cycle(recovered_B, real_B) 165 | criterion_ssimBAB = criterion_ssim(recovered_B, real_B) 166 | 167 | # Total loss 168 | loss_G = 5.0 * (loss_identity_A + loss_identity_B) +\ 169 | 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \ 170 | 10.0 * (loss_cycle_ABA + loss_cycle_BAB) + \ 171 | 2.0 * (criterion_ssimABA + criterion_ssimBAB) 172 | 173 | loss_G.backward() 174 | 175 | optimizer_G.step() 176 | ################################### 177 | 178 | ###### Discriminator A ###### 179 | optimizer_D_A.zero_grad() 180 | 181 | # Real loss 182 | pred_real = netD_A(real_A) 183 | loss_D_real = criterion_GAN(pred_real, target_real) 184 | 185 | # Fake loss 186 | fake_Ad = fake_A_buffer.push_and_pop(fake_A) 187 | pred_fake = netD_A(fake_Ad.detach()) 188 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 189 | 190 | # Total loss 191 | loss_D_A = (loss_D_real + loss_D_fake)*0.5 192 | loss_D_A.backward() 193 | 194 | optimizer_D_A.step() 195 | 
################################### 196 | 197 | ###### Discriminator B ###### 198 | optimizer_D_B.zero_grad() 199 | 200 | # Real loss 201 | pred_real = netD_B(real_B) 202 | loss_D_real = criterion_GAN(pred_real, target_real) 203 | 204 | # Fake loss 205 | fake_Bd = fake_B_buffer.push_and_pop(fake_B) 206 | pred_fake = netD_B(fake_Bd.detach()) 207 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 208 | 209 | # Total loss 210 | loss_D_B = (loss_D_real + loss_D_fake)*0.5 211 | loss_D_B.backward() 212 | 213 | optimizer_D_B.step() 214 | ################################### 215 | 216 | # Progress report (http://localhost:8097) 217 | logger.log({'loss_G': loss_G, 218 | 'loss_G_SSIM': (criterion_ssimABA + criterion_ssimBAB), 219 | 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A), 220 | 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 221 | 'loss_D': (loss_D_A + loss_D_B)}, 222 | images={'real_cycleUnetSSIM_A': real_A, 'real_cycleUnetSSIM_B': real_B, 223 | 'fake_cycleUnetSSIM_A': fake_A, 'fake_cycleUnetSSIM_B': fake_B}) 224 | 225 | # save models at half of an epoch 226 | if (i + 1) % (dataloader.__len__() // 5 + 1) == 0: 227 | saveroot = os.path.join(opt.modelroot, 'temp') 228 | if not os.path.exists(saveroot): 229 | os.makedirs(saveroot) 230 | 231 | # Save models checkpoints 232 | netG_A2B_checkpoints = { 233 | "model": netG_A2B.state_dict() 234 | } 235 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 236 | 237 | netG_B2A_checkpoints = { 238 | "model": netG_B2A.state_dict(), 239 | 'optimizer': optimizer_G.state_dict(), 240 | "epoch": epoch, 241 | 'lr_schedule': lr_scheduler_G.state_dict() 242 | } 243 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 244 | 245 | netD_A_checkpoints = { 246 | "model": netD_A.state_dict(), 247 | 'optimizer': optimizer_D_A.state_dict(), 248 | 'lr_schedule': lr_scheduler_D_A.state_dict() 249 | } 250 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 251 | 252 | 
netD_B_checkpoints = { 253 | "model": netD_B.state_dict(), 254 | 'optimizer': optimizer_D_B.state_dict(), 255 | 'lr_schedule': lr_scheduler_D_B.state_dict() 256 | } 257 | torch.save(netD_B_checkpoints, os.path.join(saveroot, 'netD_B.pth')) 258 | 259 | # Update learning rates 260 | lr_scheduler_G.step() 261 | lr_scheduler_D_A.step() 262 | lr_scheduler_D_B.step() 263 | 264 | saveroot = os.path.join(opt.modelroot, 'epoch'+str(epoch)) 265 | if not os.path.exists(saveroot): 266 | os.makedirs(saveroot) 267 | 268 | # Save models checkpoints 269 | torch.save(netG_A2B.state_dict(), os.path.join(saveroot, 'netG_A2B.pth')) 270 | torch.save(netG_B2A.state_dict(), os.path.join(saveroot, 'netG_B2A.pth')) 271 | torch.save(netD_A.state_dict(), os.path.join(saveroot, 'netD_A.pth')) 272 | torch.save(netD_B.state_dict(), os.path.join(saveroot, 'netD_B.pth')) 273 | -------------------------------------------------------------------------------- /train_Cycle_Gan_pathology_cls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | import itertools 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch 12 | 13 | from mymodels import Generator_unet_cls 14 | from mymodels import Discriminator 15 | from utils import ReplayBuffer 16 | from utils import LambdaLR 17 | from utils import Logger 18 | from utils import weights_init_normal 19 | from utils import MS_SSIM_Loss 20 | from datasets import ImageDataset, ExpertDataset_label 21 | 22 | # breast/ neuroendocrine / GLAS 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--dataroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/datasetX20_288/Train', help='root directory of the dataset') 26 | parser.add_argument('--dataroot_ek', type=str, 
default='/home/zhangbc/Mydataspace/LST/breast/datasetX20_288/Expert_knowledge', help='root directory of the expert knowledge') 27 | parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_pathology_cls', help='root directory of the model') 28 | 29 | parser.add_argument('--epoch', type=int, default=1, help='starting epoch') 30 | parser.add_argument('--n_epochs', type=int, default=12, help='number of epochs of training') 31 | parser.add_argument('--batchSize', type=int, default=2, help='size of the batches') 32 | parser.add_argument('--batchSize2', type=int, default=8, help='size of the batches for expert knowledge learning') 33 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate') 34 | parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0') 35 | parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)') 36 | parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data') 37 | parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data') 38 | 39 | parser.add_argument('--cuda', type=bool, default=True, help='use GPU computation') 40 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 41 | 42 | parser.add_argument('--continue_train', type=bool, default=False, help='load model and continue trainning') 43 | parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/breast/mymodelx288/Cycle_GAN_pathology_cls/temp', help='continue train root directory of the model') 44 | opt = parser.parse_args() 45 | print(opt) 46 | 47 | if torch.cuda.is_available() and not opt.cuda: 48 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 49 | 50 | ###### Definition of variables ###### 51 | # Networks 52 | 
netG_A2B = Generator_unet_cls(opt.input_nc, opt.output_nc, 10, alt_leak=True, neg_slope=0.1) 53 | netG_B2A = Generator_unet_cls(opt.output_nc, opt.input_nc, 10, alt_leak=True, neg_slope=0.1) 54 | netD_A = Discriminator(opt.input_nc) 55 | netD_B = Discriminator(opt.output_nc) 56 | 57 | if opt.cuda: 58 | netG_A2B.cuda() 59 | netG_B2A.cuda() 60 | netD_A.cuda() 61 | netD_B.cuda() 62 | 63 | netG_A2B.train() 64 | netG_B2A.train() 65 | 66 | netG_A2B.apply(weights_init_normal) 67 | netG_B2A.apply(weights_init_normal) 68 | netD_A.apply(weights_init_normal) 69 | netD_B.apply(weights_init_normal) 70 | # Lossess 71 | criterion_GAN = nn.MSELoss() 72 | criterion_cycle = nn.L1Loss() 73 | criterion_identity = nn.L1Loss() 74 | criterion_ssim = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3) 75 | 76 | # Optimizers & LR schedulers 77 | optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 78 | lr=opt.lr, betas=(0.5, 0.999)) 79 | optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 80 | optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 81 | 82 | lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 83 | lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 84 | lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 85 | 86 | # Inputs & targets memory allocation 87 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 88 | input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size) 89 | input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size) 90 | expert_A = Tensor(opt.batchSize2, opt.input_nc, opt.size, opt.size) 91 | expert_B = Tensor(opt.batchSize2, opt.output_nc, opt.size, opt.size) 92 | expert_AL = 
Tensor(opt.batchSize2) 93 | expert_BL = Tensor(opt.batchSize2) 94 | target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False) 95 | target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False) 96 | 97 | fake_A_buffer = ReplayBuffer() 98 | fake_B_buffer = ReplayBuffer() 99 | 100 | transforms_ = transforms.Compose([ 101 | transforms.RandomCrop(opt.size), 102 | transforms.RandomHorizontalFlip(), 103 | transforms.RandomVerticalFlip(), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 106 | 107 | transform_expert_ = transforms.Compose([ 108 | transforms.RandomHorizontalFlip(), 109 | transforms.RandomVerticalFlip(), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 112 | 113 | dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True), 114 | batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu) 115 | 116 | expertloader = DataLoader(ExpertDataset_label(opt.dataroot_ek, transform_expert=transform_expert_, 117 | batch_size=opt.batchSize2, unaligned=True), 118 | batch_size=opt.batchSize2, shuffle=True, num_workers=opt.n_cpu) 119 | 120 | ################################### 121 | start_epoch = opt.epoch 122 | pre_trainning = 1 123 | if opt.continue_train: 124 | pre_trainning = 0 125 | netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth')) # 加载断点 126 | netG_A2B.load_state_dict(netG_A2B_checkpoint['model']) # 加载模型可学习参数 127 | #netG_A2B.load_state_dict(netG_A2B_checkpoint) 128 | 129 | netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth')) # 加载断点 130 | netG_B2A.load_state_dict(netG_B2A_checkpoint['model']) # 加载模型可学习参数 131 | #netG_B2A.load_state_dict(netG_B2A_checkpoint) 132 | optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer']) # 加载优化器参数 133 | lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule']) # 加载lr_scheduler 134 | start_epoch = 
netG_B2A_checkpoint['epoch'] 135 | 136 | netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth')) # 加载断点 137 | netD_A.load_state_dict(netD_A_checkpoint['model']) # 加载模型可学习参数 138 | optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer']) # 加载优化器参数 139 | lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule']) # 加载lr_scheduler 140 | 141 | netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth')) # 加载断点 142 | netD_B.load_state_dict(netD_B_checkpoint['model']) # 加载模型可学习参数 143 | optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer']) # 加载优化器参数 144 | lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule']) # 加载lr_scheduler 145 | 146 | # Loss plot 147 | logger = Logger(opt.n_epochs, len(dataloader), start_epoch) 148 | expert_iter = iter(expertloader) 149 | ###### Training ###### 150 | for epoch in range(start_epoch, opt.n_epochs): 151 | for i, batch in enumerate(dataloader): 152 | real_A = Variable(input_A.copy_(batch['HE'])) 153 | real_B = Variable(input_B.copy_(batch['Ki67'])) 154 | 155 | try: 156 | expert_batch = expert_iter.__next__() 157 | except StopIteration: 158 | expert_iter = iter(expertloader) 159 | expert_batch = expert_iter.__next__() 160 | 161 | real_expert_A = Variable(expert_A.copy_(expert_batch['expert_HE'])) 162 | real_expert_B = Variable(expert_B.copy_(expert_batch['expert_Ki67'])) 163 | real_expert_AL = Variable(expert_AL.copy_(expert_batch['expert_HE_label'])) 164 | real_expert_BL = Variable(expert_BL.copy_(expert_batch['expert_Ki67_label'])) 165 | 166 | # Generators A2B and B2A 167 | optimizer_G.zero_grad() 168 | # learn expert knowledge 169 | c_expa = netG_A2B(real_expert_A, mode='C') 170 | loss_expert_A = criterion_GAN(c_expa, real_expert_AL) 171 | 172 | c_expb = netG_B2A(real_expert_B, mode='C') 173 | loss_expert_B = criterion_GAN(c_expb, real_expert_BL) 174 | 175 | loss_C = (loss_expert_A + loss_expert_B) 176 | 177 | # Identity loss 178 | # G_A2B(B) should equal B if real B is fed 179 
| same_B, _, _, _ = netG_A2B(real_B) 180 | loss_identity_B = criterion_identity(same_B, real_B) 181 | # G_B2A(A) should equal A if real A is fed 182 | same_A, _, _, _ = netG_B2A(real_A) 183 | loss_identity_A = criterion_identity(same_A, real_A) 184 | 185 | # GAN loss 186 | fake_B, features_lfb, c_fb, features_pfb = netG_A2B(real_A) 187 | pred_fake = netD_B(fake_B) 188 | loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 189 | 190 | fake_A, features_lfa, c_fa, features_pfa = netG_B2A(real_B) 191 | pred_fake = netD_A(fake_A) 192 | loss_GAN_B2A = criterion_GAN(pred_fake, target_real) 193 | 194 | # Cycle loss 195 | recovered_A, features_lra, c_ra, features_pra = netG_B2A(fake_B) 196 | loss_cycle_embd_ABA = criterion_cycle(features_lra, features_lfb) 197 | loss_cycle_pathology_ABA = criterion_cycle(features_pra, features_pfb) 198 | loss_class_pathology_ABA = criterion_GAN(c_ra, c_fb) 199 | loss_cycle_ABA = criterion_cycle(recovered_A, real_A) 200 | loss_cycle_ssimABA = criterion_ssim(recovered_A, real_A) 201 | 202 | recovered_B, features_lrb, c_rb, features_prb = netG_A2B(fake_A) 203 | loss_cycle_embd_BAB = criterion_cycle(features_lrb, features_lfa) 204 | loss_cycle_pathology_BAB = criterion_cycle(features_prb, features_pfa) 205 | loss_class_pathology_BAB = criterion_GAN(c_rb, c_fa) 206 | loss_cycle_BAB = criterion_cycle(recovered_B, real_B) 207 | loss_cycle_ssimBAB = criterion_ssim(recovered_B, real_B) 208 | 209 | # Total loss 210 | loss_G = 5.0 * loss_C +\ 211 | 2.5 * (loss_identity_A + loss_identity_B) + \ 212 | 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \ 213 | 10.0 * (loss_cycle_ABA + loss_cycle_BAB) + \ 214 | 2.0 * (loss_cycle_ssimABA + loss_cycle_ssimBAB) + \ 215 | 1.0 * (loss_cycle_embd_ABA + loss_cycle_embd_BAB) + \ 216 | 1.0 * (loss_cycle_pathology_ABA + loss_cycle_pathology_BAB) + \ 217 | 2.0 * (loss_class_pathology_ABA + loss_class_pathology_BAB) 218 | 219 | loss_G.backward() 220 | 221 | optimizer_G.step() 222 | ################################### 223 | 224 | 
# Discriminator A 225 | optimizer_D_A.zero_grad() 226 | 227 | # Real loss 228 | pred_real = netD_A(real_A) 229 | loss_D_real = criterion_GAN(pred_real, target_real) 230 | 231 | # Fake loss 232 | fake_Ad = fake_A_buffer.push_and_pop(fake_A) 233 | pred_fake = netD_A(fake_Ad.detach()) 234 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 235 | 236 | # Total loss 237 | loss_D_A = (loss_D_real + loss_D_fake)*0.5 238 | loss_D_A.backward() 239 | 240 | optimizer_D_A.step() 241 | ################################### 242 | 243 | # Discriminator B 244 | optimizer_D_B.zero_grad() 245 | 246 | # Real loss 247 | pred_real = netD_B(real_B) 248 | loss_D_real = criterion_GAN(pred_real, target_real) 249 | 250 | # Fake loss 251 | fake_Bd = fake_B_buffer.push_and_pop(fake_B) 252 | pred_fake = netD_B(fake_Bd.detach()) 253 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 254 | 255 | # Total loss 256 | loss_D_B = (loss_D_real + loss_D_fake)*0.5 257 | loss_D_B.backward() 258 | 259 | optimizer_D_B.step() 260 | ################################### 261 | 262 | # Progress report (http://localhost:8097) 263 | logger.log({'loss_G': loss_G, 264 | 'loss_G_idt': (loss_identity_A + loss_identity_B), 265 | 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A), 266 | 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 267 | 'loss_G_ssim': (loss_cycle_ssimABA + loss_cycle_ssimBAB), 268 | 'loss_pc': (loss_class_pathology_ABA + loss_class_pathology_BAB), 269 | 'loss_pf': (loss_cycle_pathology_ABA + loss_cycle_pathology_BAB), 270 | 'loss_embd': (loss_cycle_embd_ABA + loss_cycle_embd_BAB), 271 | 'loss_C': loss_C, 272 | 'loss_D': (loss_D_A + loss_D_B)}, 273 | images={'real_cycleGAN_pf_cls_A': real_A, 'real_cycleGAN_pf_cls_B': real_B, 274 | 'fake_cycleGAN_pf_cls_A': fake_A, 'fake_cycleGAN_pf_cls_B': fake_B}) 275 | 276 | # save models at half of an epoch 277 | if (i+1) % (dataloader.__len__()//6 + 1) == 0: 278 | saveroot = os.path.join(opt.modelroot, 'temp') 279 | if not os.path.exists(saveroot): 280 | 
os.makedirs(saveroot) 281 | 282 | # Save models checkpoints 283 | netG_A2B_checkpoints = { 284 | "model": netG_A2B.state_dict() 285 | } 286 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 287 | 288 | netG_B2A_checkpoints = { 289 | "model": netG_B2A.state_dict(), 290 | 'optimizer': optimizer_G.state_dict(), 291 | "epoch": epoch, 292 | 'lr_schedule': lr_scheduler_G.state_dict() 293 | } 294 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 295 | 296 | netD_A_checkpoints = { 297 | "model": netD_A.state_dict(), 298 | 'optimizer': optimizer_D_A.state_dict(), 299 | 'lr_schedule': lr_scheduler_D_A.state_dict() 300 | } 301 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 302 | 303 | netD_B_checkpoints = { 304 | "model": netD_B.state_dict(), 305 | 'optimizer': optimizer_D_B.state_dict(), 306 | 'lr_schedule': lr_scheduler_D_B.state_dict() 307 | } 308 | torch.save(netD_B_checkpoints, os.path.join(saveroot, 'netD_B.pth')) 309 | 310 | # Update learning rates 311 | lr_scheduler_G.step() 312 | lr_scheduler_D_A.step() 313 | lr_scheduler_D_B.step() 314 | 315 | saveroot = os.path.join(opt.modelroot, 'epoch'+str(epoch)) 316 | if not os.path.exists(saveroot): 317 | os.makedirs(saveroot) 318 | 319 | # Save models checkpoints 320 | torch.save(netG_A2B.state_dict(), os.path.join(saveroot, 'netG_A2B.pth')) 321 | torch.save(netG_B2A.state_dict(), os.path.join(saveroot, 'netG_B2A.pth')) 322 | torch.save(netD_A.state_dict(), os.path.join(saveroot, 'netD_A.pth')) 323 | torch.save(netD_B.state_dict(), os.path.join(saveroot, 'netD_B.pth')) 324 | -------------------------------------------------------------------------------- /train_Cycle_Gan_pathology_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | import itertools 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 7 | import torchvision.transforms as transforms 8 | from 
torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch 12 | 13 | from mymodels import Generator_unet_seg 14 | from mymodels import Discriminator 15 | from utils import ReplayBuffer 16 | from utils import LambdaLR 17 | from utils import Logger 18 | from utils import weights_init_normal 19 | from utils import MS_SSIM_Loss 20 | from datasets import ImageDataset, ExpertDataset_mask 21 | # breast/ neuroendocrine / GLAS 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/datasetX20_288/Train_PN_1_1', help='root directory of the dataset') 24 | parser.add_argument('--dataroot_ek', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/datasetX20_288/Expert_knowledge', help='root directory of the expert knowledge') 25 | parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_pathology_seg_PN1_1', help='root directory of the model') 26 | 27 | parser.add_argument('--epoch', type=int, default=1, help='starting epoch') 28 | parser.add_argument('--n_epochs', type=int, default=11, help='number of epochs of training') 29 | parser.add_argument('--batchSize', type=int, default=2, help='size of the batches') 30 | parser.add_argument('--batchSize2', type=int, default=8, help='size of the batches for expert knowledge learning') 31 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate') 32 | parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0') 33 | parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)') 34 | parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data') 35 | parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data') 36 | 37 | 
parser.add_argument('--cuda', type=bool, default=True, help='use GPU computation') 38 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 39 | 40 | parser.add_argument('--continue_train', type=bool, default=False, help='load model and continue trainning') 41 | parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_pathology_seg_PN1_1/temp', help='continue train root directory of the model') 42 | opt = parser.parse_args() 43 | print(opt) 44 | 45 | if torch.cuda.is_available() and not opt.cuda: 46 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 47 | 48 | ###### Definition of variables ###### 49 | # Networks 50 | netG_A2B = Generator_unet_seg(opt.input_nc, opt.output_nc, 10, alt_leak=True, neg_slope=0.1) 51 | netG_B2A = Generator_unet_seg(opt.output_nc, opt.input_nc, 10, alt_leak=True, neg_slope=0.1) 52 | netD_A = Discriminator(opt.input_nc) 53 | netD_B = Discriminator(opt.output_nc) 54 | 55 | if opt.cuda: 56 | netG_A2B.cuda() 57 | netG_B2A.cuda() 58 | netD_A.cuda() 59 | netD_B.cuda() 60 | 61 | netG_A2B.train() 62 | netG_B2A.train() 63 | 64 | netG_A2B.apply(weights_init_normal) 65 | netG_B2A.apply(weights_init_normal) 66 | netD_A.apply(weights_init_normal) 67 | netD_B.apply(weights_init_normal) 68 | 69 | # Lossess 70 | criterion_GAN = nn.MSELoss() 71 | criterion_cycle = nn.L1Loss() 72 | criterion_identity = nn.L1Loss() 73 | 74 | criterion_ssim = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3) 75 | 76 | # Optimizers & LR schedulers 77 | optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 78 | lr=opt.lr, betas=(0.5, 0.999)) 79 | optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 80 | optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 81 | 82 | lr_scheduler_G = 
torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 83 | lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 84 | lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 85 | 86 | # Inputs & targets memory allocation 87 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 88 | input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size) 89 | input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size) 90 | expert_A = Tensor(opt.batchSize2, opt.input_nc, opt.size, opt.size) 91 | expert_B = Tensor(opt.batchSize2, opt.output_nc, opt.size, opt.size) 92 | expert_AL = Tensor(opt.batchSize2, 1, opt.size//8, opt.size//8) 93 | expert_BL = Tensor(opt.batchSize2, 1, opt.size//8, opt.size//8) 94 | target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False) 95 | target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False) 96 | 97 | fake_A_buffer = ReplayBuffer() 98 | fake_B_buffer = ReplayBuffer() 99 | 100 | transforms_ = transforms.Compose([ 101 | transforms.RandomCrop(opt.size), 102 | transforms.RandomHorizontalFlip(), 103 | transforms.RandomVerticalFlip(), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 106 | 107 | dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True), 108 | batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu) 109 | 110 | expertloader = DataLoader(ExpertDataset_mask(opt.dataroot_ek, batch_size=opt.batchSize2, unaligned=False), 111 | batch_size=opt.batchSize2, shuffle=True, num_workers=opt.n_cpu) 112 | 113 | ################################### 114 | start_epoch = opt.epoch 115 | pre_trainning = 1 116 | if opt.continue_train: 117 | pre_trainning = 0 118 | 
netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth')) # 加载断点 119 | netG_A2B.load_state_dict(netG_A2B_checkpoint['model']) # 加载模型可学习参数 120 | 121 | netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth')) # 加载断点 122 | netG_B2A.load_state_dict(netG_B2A_checkpoint['model']) # 加载模型可学习参数 123 | optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer']) # 加载优化器参数 124 | lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule']) # 加载lr_scheduler 125 | start_epoch = netG_B2A_checkpoint['epoch'] 126 | 127 | netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth')) # 加载断点 128 | netD_A.load_state_dict(netD_A_checkpoint['model']) # 加载模型可学习参数 129 | optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer']) # 加载优化器参数 130 | lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule']) # 加载lr_scheduler 131 | 132 | netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth')) # 加载断点 133 | netD_B.load_state_dict(netD_B_checkpoint['model']) # 加载模型可学习参数 134 | optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer']) # 加载优化器参数 135 | lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule']) # 加载lr_scheduler 136 | 137 | # Loss plot 138 | logger = Logger(opt.n_epochs, len(dataloader), start_epoch) 139 | expert_iter = iter(expertloader) 140 | ###### Training ###### 141 | for epoch in range(start_epoch, opt.n_epochs): 142 | for i, batch in enumerate(dataloader): 143 | real_A = Variable(input_A.copy_(batch['HE'])) 144 | real_B = Variable(input_B.copy_(batch['Ki67'])) 145 | 146 | try: 147 | expert_batch = expert_iter.__next__() 148 | except StopIteration: 149 | expert_iter = iter(expertloader) 150 | expert_batch = expert_iter.__next__() 151 | 152 | real_expert_A = Variable(expert_A.copy_(expert_batch['expert_HE'])) 153 | real_expert_B = Variable(expert_B.copy_(expert_batch['expert_Ki67'])) 154 | real_expert_AL = Variable(expert_AL.copy_(expert_batch['expert_HE_mask'])) 155 | real_expert_BL = 
Variable(expert_BL.copy_(expert_batch['expert_Ki67_mask'])) 156 | 157 | # Generators A2B and B2A 158 | optimizer_G.zero_grad() 159 | # learn expert knowledge 160 | c_expa = netG_A2B(real_expert_A, mode='C') 161 | loss_expert_A = criterion_GAN(c_expa, real_expert_AL) 162 | 163 | c_expb = netG_B2A(real_expert_B, mode='C') 164 | loss_expert_B = criterion_GAN(c_expb, real_expert_BL) 165 | 166 | loss_C = (loss_expert_A + loss_expert_B) 167 | 168 | # Identity loss 169 | # G_A2B(B) should equal B if real B is fed 170 | same_B, _, _, _ = netG_A2B(real_B) 171 | loss_identity_B = criterion_identity(same_B, real_B) 172 | # G_B2A(A) should equal A if real A is fed 173 | same_A, _, _, _ = netG_B2A(real_A) 174 | loss_identity_A = criterion_identity(same_A, real_A) 175 | 176 | # GAN loss 177 | fake_B, features_lfb, c_fb, features_pfb = netG_A2B(real_A) 178 | pred_fake = netD_B(fake_B) 179 | loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 180 | 181 | fake_A, features_lfa, c_fa, features_pfa = netG_B2A(real_B) 182 | pred_fake = netD_A(fake_A) 183 | loss_GAN_B2A = criterion_GAN(pred_fake, target_real) 184 | 185 | # Cycle loss 186 | recovered_A, features_lra, c_ra, features_pra = netG_B2A(fake_B) 187 | loss_cycle_embd_ABA = criterion_cycle(features_lra, features_lfb) 188 | loss_cycle_pathology_ABA = criterion_cycle(features_pra, features_pfb) 189 | loss_class_pathology_ABA = criterion_GAN(c_ra, c_fb) 190 | loss_cycle_ABA = criterion_cycle(recovered_A, real_A) 191 | loss_cycle_ssimABA = criterion_ssim(recovered_A, real_A) 192 | 193 | recovered_B, features_lrb, c_rb, features_prb = netG_A2B(fake_A) 194 | loss_cycle_embd_BAB = criterion_cycle(features_lrb, features_lfa) 195 | loss_cycle_pathology_BAB = criterion_cycle(features_prb, features_pfa) 196 | loss_class_pathology_BAB = criterion_GAN(c_rb, c_fa) 197 | loss_cycle_BAB = criterion_cycle(recovered_B, real_B) 198 | loss_cycle_ssimBAB = criterion_ssim(recovered_B, real_B) 199 | 200 | # Total loss 201 | loss_G = 5.0 * loss_C +\ 
202 | 2.5 * (loss_identity_A + loss_identity_B) + \ 203 | 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \ 204 | 10.0 * (loss_cycle_ABA + loss_cycle_BAB) + \ 205 | 2.0 * (loss_cycle_ssimABA + loss_cycle_ssimBAB) + \ 206 | 1.0 * (loss_cycle_embd_ABA + loss_cycle_embd_BAB) + \ 207 | 1.0 * (loss_cycle_pathology_ABA + loss_cycle_pathology_BAB) + \ 208 | 2.0 * (loss_class_pathology_ABA + loss_class_pathology_BAB) 209 | 210 | loss_G.backward() 211 | 212 | optimizer_G.step() 213 | ################################### 214 | 215 | # Discriminator A 216 | optimizer_D_A.zero_grad() 217 | 218 | # Real loss 219 | pred_real = netD_A(real_A) 220 | loss_D_real = criterion_GAN(pred_real, target_real) 221 | 222 | # Fake loss 223 | fake_Ad = fake_A_buffer.push_and_pop(fake_A) 224 | pred_fake = netD_A(fake_Ad.detach()) 225 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 226 | 227 | # Total loss 228 | loss_D_A = (loss_D_real + loss_D_fake)*0.5 229 | loss_D_A.backward() 230 | 231 | optimizer_D_A.step() 232 | ################################### 233 | 234 | # Discriminator B 235 | optimizer_D_B.zero_grad() 236 | 237 | # Real loss 238 | pred_real = netD_B(real_B) 239 | loss_D_real = criterion_GAN(pred_real, target_real) 240 | 241 | # Fake loss 242 | fake_Bd = fake_B_buffer.push_and_pop(fake_B) 243 | pred_fake = netD_B(fake_Bd.detach()) 244 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 245 | 246 | # Total loss 247 | loss_D_B = (loss_D_real + loss_D_fake)*0.5 248 | loss_D_B.backward() 249 | 250 | optimizer_D_B.step() 251 | ################################### 252 | 253 | # Progress report (http://localhost:8097) 254 | logger.log({'loss_G': loss_G, 255 | 'loss_G_idt': (loss_identity_A + loss_identity_B), 256 | 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A), 257 | 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 258 | 'loss_G_ssim': (loss_cycle_ssimABA + loss_cycle_ssimBAB), 259 | 'loss_ps': (loss_class_pathology_ABA + loss_class_pathology_BAB), 260 | 'loss_pf': (loss_cycle_pathology_ABA + 
loss_cycle_pathology_BAB), 261 | 'loss_embd': (loss_cycle_embd_ABA + loss_cycle_embd_BAB), 262 | 'loss_C': loss_C, 263 | 'loss_D': (loss_D_A + loss_D_B)}, 264 | images={'real_cycleGAN_pf_seg_A': real_A, 'real_cycleGAN_pf_seg_B': real_B, 265 | 'fake_cycleGAN_pf_seg_A': fake_A, 'fake_cycleGAN_pf_seg_B': fake_B, 266 | 'seg_A': c_expa, 'real_mask_A':real_expert_AL, 267 | 'G_sA': c_fb, 'G_sB': c_ra}) 268 | 269 | # save models at half of an epoch 270 | if i == 200 and epoch == 1: 271 | saveroot = os.path.join(opt.modelroot, 'star_temp') 272 | if not os.path.exists(saveroot): 273 | os.makedirs(saveroot) 274 | 275 | # Save models checkpoints 276 | netG_A2B_checkpoints = { 277 | "model": netG_A2B.state_dict() 278 | } 279 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 280 | 281 | netG_B2A_checkpoints = { 282 | "model": netG_B2A.state_dict(), 283 | 'optimizer': optimizer_G.state_dict(), 284 | "epoch": epoch, 285 | 'lr_schedule': lr_scheduler_G.state_dict() 286 | } 287 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 288 | 289 | netD_A_checkpoints = { 290 | "model": netD_A.state_dict(), 291 | 'optimizer': optimizer_D_A.state_dict(), 292 | 'lr_schedule': lr_scheduler_D_A.state_dict() 293 | } 294 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 295 | 296 | netD_B_checkpoints = { 297 | "model": netD_B.state_dict(), 298 | 'optimizer': optimizer_D_B.state_dict(), 299 | 'lr_schedule': lr_scheduler_D_B.state_dict() 300 | } 301 | torch.save(netD_B_checkpoints, os.path.join(saveroot, 'netD_B.pth')) 302 | 303 | if (i+1) % (dataloader.__len__()//6 + 1) == 0: 304 | saveroot = os.path.join(opt.modelroot, 'temp') 305 | if not os.path.exists(saveroot): 306 | os.makedirs(saveroot) 307 | 308 | # Save models checkpoints 309 | netG_A2B_checkpoints = { 310 | "model": netG_A2B.state_dict() 311 | } 312 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 313 | 314 | netG_B2A_checkpoints = { 315 | 
"model": netG_B2A.state_dict(), 316 | 'optimizer': optimizer_G.state_dict(), 317 | "epoch": epoch, 318 | 'lr_schedule': lr_scheduler_G.state_dict() 319 | } 320 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 321 | 322 | netD_A_checkpoints = { 323 | "model": netD_A.state_dict(), 324 | 'optimizer': optimizer_D_A.state_dict(), 325 | 'lr_schedule': lr_scheduler_D_A.state_dict() 326 | } 327 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 328 | 329 | netD_B_checkpoints = { 330 | "model": netD_B.state_dict(), 331 | 'optimizer': optimizer_D_B.state_dict(), 332 | 'lr_schedule': lr_scheduler_D_B.state_dict() 333 | } 334 | torch.save(netD_B_checkpoints, os.path.join(saveroot, 'netD_B.pth')) 335 | 336 | # Update learning rates 337 | lr_scheduler_G.step() 338 | lr_scheduler_D_A.step() 339 | lr_scheduler_D_B.step() 340 | 341 | saveroot = os.path.join(opt.modelroot, 'epoch'+str(epoch)) 342 | if not os.path.exists(saveroot): 343 | os.makedirs(saveroot) 344 | 345 | # Save models checkpoints 346 | torch.save(netG_A2B.state_dict(), os.path.join(saveroot, 'netG_A2B.pth')) 347 | torch.save(netG_B2A.state_dict(), os.path.join(saveroot, 'netG_B2A.pth')) 348 | torch.save(netD_A.state_dict(), os.path.join(saveroot, 'netD_A.pth')) 349 | torch.save(netD_B.state_dict(), os.path.join(saveroot, 'netD_B.pth')) 350 | 351 | -------------------------------------------------------------------------------- /train_Cycle_Gan_unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import argparse 4 | import itertools 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch 12 | 13 | from mymodels import Generator_unet 14 | from mymodels import Discriminator 15 | from utils import ReplayBuffer 16 | from utils import 
LambdaLR 17 | from utils import Logger 18 | from utils import weights_init_normal 19 | from datasets import ImageDataset 20 | # breast/ neuroendocrine / GLAS 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/datasetX20_288/Train_PN_3_1', help='root directory of the dataset') 23 | parser.add_argument('--modelroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_unet_PN_3_1', help='root directory of the model') 24 | 25 | parser.add_argument('--epoch', type=int, default=1, help='starting epoch') 26 | parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs of training') 27 | parser.add_argument('--batchSize', type=int, default=2, help='size of the batches') 28 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate') 29 | parser.add_argument('--decay_epoch', type=int, default=2, help='epoch to start linearly decaying the learning rate to 0') 30 | parser.add_argument('--size', type=int, default=288, help='size of the data crop (squared assumed)') 31 | parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data') 32 | parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data') 33 | 34 | parser.add_argument('--cuda', type=bool, default=True, help='use GPU computation') 35 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 36 | 37 | parser.add_argument('--continue_train', type=bool, default=False, help='load model and continue trainning') 38 | parser.add_argument('--loadroot', type=str, default='/home/zhangbc/Mydataspace/LST/neuroendocrine/mymodelx288/Cycle_GAN_unet_PN_3_1/temp', help='continue train root directory of the model') 39 | opt = parser.parse_args() 40 | print(opt) 41 | 42 | if torch.cuda.is_available() and not opt.cuda: 43 | print("WARNING: You 
have a CUDA device, so you should probably run with --cuda") 44 | 45 | ###### Definition of variables ###### 46 | # Networks 47 | netG_A2B = Generator_unet(opt.input_nc, opt.output_nc, 10, alt_leak=True, neg_slope=0.1) 48 | netG_B2A = Generator_unet(opt.output_nc, opt.input_nc, 10, alt_leak=True, neg_slope=0.1) 49 | netD_A = Discriminator(opt.input_nc) 50 | netD_B = Discriminator(opt.output_nc) 51 | 52 | if opt.cuda: 53 | netG_A2B.cuda() 54 | netG_B2A.cuda() 55 | netD_A.cuda() 56 | netD_B.cuda() 57 | 58 | netG_A2B.apply(weights_init_normal) 59 | netG_B2A.apply(weights_init_normal) 60 | netD_A.apply(weights_init_normal) 61 | netD_B.apply(weights_init_normal) 62 | # Lossess 63 | criterion_GAN = nn.MSELoss() 64 | criterion_cycle = nn.L1Loss() 65 | criterion_identity = nn.L1Loss() 66 | 67 | # Optimizers & LR schedulers 68 | optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 69 | lr=opt.lr, betas=(0.5, 0.999)) 70 | optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 71 | optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999)) 72 | 73 | lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 74 | lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 75 | lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step) 76 | 77 | # Inputs & targets memory allocation 78 | Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor 79 | input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size) 80 | input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size) 81 | target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False) 82 | target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False) 83 | 84 | 
fake_A_buffer = ReplayBuffer() 85 | fake_B_buffer = ReplayBuffer() 86 | 87 | transforms_ = transforms.Compose([ 88 | transforms.RandomCrop(opt.size), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.RandomVerticalFlip(), 91 | transforms.ToTensor(), 92 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 93 | 94 | dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, batch_size=opt.batchSize, unaligned=True), 95 | batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu) 96 | 97 | ################################### 98 | start_epoch = opt.epoch 99 | 100 | if opt.continue_train: 101 | 102 | netG_A2B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_A2B.pth')) # 加载断点 103 | netG_A2B.load_state_dict(netG_A2B_checkpoint['model']) # 加载模型可学习参数 104 | 105 | netG_B2A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netG_B2A.pth')) # 加载断点 106 | netG_B2A.load_state_dict(netG_B2A_checkpoint['model']) # 加载模型可学习参数 107 | optimizer_G.load_state_dict(netG_B2A_checkpoint['optimizer']) # 加载优化器参数 108 | lr_scheduler_G.load_state_dict(netG_B2A_checkpoint['lr_schedule']) # 加载lr_scheduler 109 | start_epoch = netG_B2A_checkpoint['epoch'] 110 | 111 | netD_A_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_A.pth')) # 加载断点 112 | netD_A.load_state_dict(netD_A_checkpoint['model']) # 加载模型可学习参数 113 | optimizer_D_A.load_state_dict(netD_A_checkpoint['optimizer']) # 加载优化器参数 114 | lr_scheduler_D_A.load_state_dict(netD_A_checkpoint['lr_schedule']) # 加载lr_scheduler 115 | 116 | netD_B_checkpoint = torch.load(os.path.join(opt.loadroot, 'netD_B.pth')) # 加载断点 117 | netD_B.load_state_dict(netD_B_checkpoint['model']) # 加载模型可学习参数 118 | optimizer_D_B.load_state_dict(netD_B_checkpoint['optimizer']) # 加载优化器参数 119 | lr_scheduler_D_B.load_state_dict(netD_B_checkpoint['lr_schedule']) # 加载lr_scheduler 120 | 121 | # Loss plot 122 | logger = Logger(opt.n_epochs, len(dataloader), start_epoch) 123 | 124 | ###### Training ###### 125 | for epoch in range(start_epoch, 
opt.n_epochs): 126 | for i, batch in enumerate(dataloader): 127 | 128 | # Set model input 129 | real_A = Variable(input_A.copy_(batch['HE'])) 130 | real_B = Variable(input_B.copy_(batch['Ki67'])) 131 | 132 | # Generators A2B and B2A 133 | optimizer_G.zero_grad() 134 | 135 | # Identity loss 136 | # G_A2B(B) should equal B if real B is fed 137 | same_B, _ = netG_A2B(real_B) 138 | loss_identity_B = criterion_identity(same_B, real_B) 139 | # G_B2A(A) should equal A if real A is fed 140 | same_A, _ = netG_B2A(real_A) 141 | loss_identity_A = criterion_identity(same_A, real_A) 142 | 143 | # GAN loss 144 | fake_B, _ = netG_A2B(real_A) 145 | pred_fake = netD_B(fake_B) 146 | loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 147 | 148 | fake_A, _ = netG_B2A(real_B) 149 | pred_fake = netD_A(fake_A) 150 | loss_GAN_B2A = criterion_GAN(pred_fake, target_real) 151 | 152 | # Cycle loss 153 | recovered_A, _ = netG_B2A(fake_B) 154 | loss_cycle_ABA = criterion_cycle(recovered_A, real_A) 155 | 156 | recovered_B, _ = netG_A2B(fake_A) 157 | loss_cycle_BAB = criterion_cycle(recovered_B, real_B) 158 | 159 | # Total loss 160 | loss_G = 5.0 * (loss_identity_A + loss_identity_B) +\ 161 | 1.0 * (loss_GAN_A2B + loss_GAN_B2A) + \ 162 | 10.0 * (loss_cycle_ABA + loss_cycle_BAB) 163 | 164 | loss_G.backward() 165 | 166 | optimizer_G.step() 167 | ################################### 168 | 169 | # Discriminator A 170 | optimizer_D_A.zero_grad() 171 | 172 | # Real loss 173 | pred_real = netD_A(real_A) 174 | loss_D_real = criterion_GAN(pred_real, target_real) 175 | 176 | # Fake loss 177 | fake_Ad = fake_A_buffer.push_and_pop(fake_A) 178 | pred_fake = netD_A(fake_Ad.detach()) 179 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 180 | 181 | # Total loss 182 | loss_D_A = (loss_D_real + loss_D_fake)*0.5 183 | loss_D_A.backward() 184 | 185 | optimizer_D_A.step() 186 | ################################### 187 | 188 | # Discriminator B 189 | optimizer_D_B.zero_grad() 190 | 191 | # Real loss 192 | 
pred_real = netD_B(real_B) 193 | loss_D_real = criterion_GAN(pred_real, target_real) 194 | 195 | # Fake loss 196 | fake_Bd = fake_B_buffer.push_and_pop(fake_B) 197 | pred_fake = netD_B(fake_Bd.detach()) 198 | loss_D_fake = criterion_GAN(pred_fake, target_fake) 199 | 200 | # Total loss 201 | loss_D_B = (loss_D_real + loss_D_fake)*0.5 202 | loss_D_B.backward() 203 | 204 | optimizer_D_B.step() 205 | ################################### 206 | 207 | # Progress report (http://localhost:8097) 208 | logger.log({'loss_G': loss_G, 209 | 'loss_G_identity': (loss_identity_A + loss_identity_B), 210 | 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A), 211 | 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 212 | 'loss_D': (loss_D_A + loss_D_B)}, 213 | images={'real_cycle_unet_A': real_A, 'real_cycle_unet_B': real_B, 214 | 'fake_cycle_unet_A': fake_A, 'fake_cycle_unet_B': fake_B}) 215 | 216 | # save models at half of an epoch 217 | if (i+1) % (dataloader.__len__()//5 + 1) == 0: 218 | saveroot = os.path.join(opt.modelroot, 'temp') 219 | if not os.path.exists(saveroot): 220 | os.makedirs(saveroot) 221 | 222 | # Save models checkpoints 223 | netG_A2B_checkpoints = { 224 | "model": netG_A2B.state_dict() 225 | } 226 | torch.save(netG_A2B_checkpoints, os.path.join(saveroot, 'netG_A2B.pth')) 227 | 228 | netG_B2A_checkpoints = { 229 | "model": netG_B2A.state_dict(), 230 | 'optimizer': optimizer_G.state_dict(), 231 | "epoch": epoch, 232 | 'lr_schedule': lr_scheduler_G.state_dict() 233 | } 234 | torch.save(netG_B2A_checkpoints, os.path.join(saveroot, 'netG_B2A.pth')) 235 | 236 | netD_A_checkpoints = { 237 | "model": netD_A.state_dict(), 238 | 'optimizer': optimizer_D_A.state_dict(), 239 | 'lr_schedule': lr_scheduler_D_A.state_dict() 240 | } 241 | torch.save(netD_A_checkpoints, os.path.join(saveroot, 'netD_A.pth')) 242 | 243 | netD_B_checkpoints = { 244 | "model": netD_B.state_dict(), 245 | 'optimizer': optimizer_D_B.state_dict(), 246 | 'lr_schedule': lr_scheduler_D_B.state_dict() 247 | } 248 | 
def _activation(alt_leak, neg_slope):
    """Return ``LeakyReLU(neg_slope)`` when ``alt_leak`` is set, else ``ReLU``."""
    if alt_leak:
        return nn.LeakyReLU(neg_slope, inplace=True)
    return nn.ReLU(inplace=True)


def _pad_to_match(x, ref):
    """Pad dims 2 and 3 of ``x`` up to the sizes of the same dims of ``ref``.

    The size difference is split in two; any odd remainder goes to the
    right/bottom side. The computed amounts are handed to ``F.pad`` as-is.
    """
    diff_y = ref.size(2) - x.size(2)
    diff_x = ref.size(3) - x.size(3)
    return F.pad(x, [diff_x // 2, diff_x - diff_x // 2,
                     diff_y // 2, diff_y - diff_y // 2])


class Down(nn.Module):
    """Downscaling block: 3x3 conv with stride 2, InstanceNorm, (Leaky)ReLU.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        alt_leak: use ``LeakyReLU`` instead of ``ReLU`` when true.
        neg_slope: negative slope for the ``LeakyReLU`` variant.
    """

    def __init__(self, in_features, out_features, alt_leak=False, neg_slope=1e-2):
        super().__init__()
        self.down_conv = nn.Sequential(
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
        )

    def forward(self, x):
        return self.down_conv(x)


class Up(nn.Module):
    """Upscaling block with a gated skip connection.

    ``x1`` is upsampled by a stride-2 transposed conv to ``out_features``
    channels; the skip input ``x2`` is squeezed to ``out_features // 2``
    channels by a 1x1 conv (``gate``); both are concatenated on the channel
    axis and fused back to ``out_features`` channels by a 1x1 conv (``merge``).

    Args:
        in_features: channels of ``x1``.
        out_features: channels of ``x2`` and of the output.
        alt_leak: use ``LeakyReLU`` instead of ``ReLU`` when true.
        neg_slope: negative slope for the ``LeakyReLU`` variant.
    """

    def __init__(self, in_features, out_features, alt_leak=False, neg_slope=1e-2):
        super().__init__()
        gate_features = out_features // 2

        # Learned 2x upsampling of the deeper feature map.
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
        )

        # 1x1 conv that halves the skip connection's channel count.
        # The norm must be sized for the conv's *output* (gate_features);
        # sizing it with out_features triggers a num_features mismatch
        # warning on every forward pass. The layer holds no parameters with
        # the default affine=False, so existing checkpoints still load.
        self.gate = nn.Sequential(
            nn.Conv2d(out_features, gate_features, 1),
            nn.InstanceNorm2d(gate_features),
            _activation(alt_leak, neg_slope),
        )

        # 1x1 conv that fuses the upsampled map with the gated skip.
        self.merge = nn.Sequential(
            nn.Conv2d(out_features + gate_features, out_features, 1),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, gate ``x2``, align their sizes and merge them."""
        x1 = self.up_conv(x1)
        x2 = _pad_to_match(self.gate(x2), x1)
        return self.merge(torch.cat([x1, x2], dim=1))


class Pathology_feature(nn.Module):
    """Fuse three feature maps taken at different scales into one.

    ``x1`` is bilinearly upsampled 4x and ``x2`` 2x; both are padded to the
    spatial size of ``x3``. The three maps are concatenated in the order
    ``[x3, x1, x2]`` and projected to ``out_features`` channels by a 1x1 conv
    followed by InstanceNorm and (Leaky)ReLU.

    Args:
        in_features: total channel count of ``x1``, ``x2`` and ``x3`` together.
        out_features: number of output channels.
        alt_leak: use ``LeakyReLU`` instead of ``ReLU`` when true.
        neg_slope: negative slope for the ``LeakyReLU`` variant.
    """

    def __init__(self, in_features, out_features, alt_leak=False, neg_slope=1e-2):
        super().__init__()
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

        self.merge = nn.Sequential(
            nn.Conv2d(in_features, out_features, 1),
            nn.InstanceNorm2d(out_features),
            _activation(alt_leak, neg_slope),
        )

    def forward(self, x1, x2, x3):
        """Bring ``x1`` and ``x2`` to the size of ``x3`` and merge all three."""
        x1 = _pad_to_match(self.up4(x1), x3)
        x2 = _pad_to_match(self.up2(x2), x3)
        # Channel order [x3, x1, x2] must stay fixed: merge's weights depend on it.
        return self.merge(torch.cat([x3, x1, x2], dim=1))
def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), numpy and ``random``.

    Also switches cuDNN to its deterministic mode.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def tensor2image(tensor):
    """Convert the first item of a batch tensor into a ``uint8`` numpy image.

    Values are mapped with ``127.5 * (x + 1)``, so an input in ``[-1, 1]``
    lands in ``[0, 255]``. A single-channel image is tiled to three channels.
    """
    image = 127.5 * (tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3, 1, 1))
    return image.astype(np.uint8)


class Logger():
    """Console and Visdom progress reporter for a training loop.

    Call :meth:`log` once per batch. Running loss sums are printed as
    per-epoch averages, images are pushed to Visdom on every call, and the
    averaged losses are plotted once per epoch.

    Args:
        n_epochs: total number of epochs (used for the display and the ETA).
        batches_epoch: number of batches in one epoch.
        star_epoch: epoch number to start counting from.
    """

    def __init__(self, n_epochs, batches_epoch, star_epoch):
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = star_epoch
        self.epoch_re = 1  # epochs counted from 1 in this run, for the ETA
        self.batch = 1
        self.prev_time = time.time()
        self.mean_period = 0  # accumulated wall time across log() calls
        self.losses = {}
        self.loss_windows = {}
        self.image_windows = {}

    def log(self, losses=None, images=None):
        """Record one batch.

        Args:
            losses: mapping of loss name to a value exposing ``.item()``
                (e.g. a scalar tensor) or a plain number. ``None`` is
                treated as "no losses".
            images: mapping of window title to an image batch tensor.
                ``None`` is treated as "no images".
        """
        # The signature advertises None defaults; treat them as empty
        # mappings instead of failing on .keys()/.items().
        losses = {} if losses is None else losses
        images = {} if images is None else images

        now = time.time()
        self.mean_period += (now - self.prev_time)
        self.prev_time = now

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- '
                         % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        n_losses = len(losses)
        for i, (loss_name, value) in enumerate(losses.items()):
            scalar = value.item() if hasattr(value, 'item') else float(value)
            self.losses[loss_name] = self.losses.get(loss_name, 0.0) + scalar

            # The last entry is followed by ' -- ', the others by ' | '.
            separator = ' -- ' if (i + 1) == n_losses else ' | '
            sys.stdout.write('%s: %.4f%s'
                             % (loss_name, self.losses[loss_name] / self.batch, separator))

        batches_done = self.batches_epoch * (self.epoch_re - 1) + self.batch
        batches_left = (self.batches_epoch * (self.n_epochs - self.epoch)
                        + self.batches_epoch - self.batch)
        eta = datetime.timedelta(seconds=batches_left * self.mean_period / batches_done)
        sys.stdout.write('ETA: %s' % eta)

        # Draw images: create the window on first sight, reuse it afterwards.
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(
                    tensor2image(tensor.data), opts={'title': image_name})
            else:
                self.viz.image(tensor2image(tensor.data),
                               win=self.image_windows[image_name],
                               opts={'title': image_name})

        if (self.batch % self.batches_epoch) == 0:
            # End of epoch: plot the averaged losses, then reset the sums.
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(
                        X=np.array([self.epoch]), Y=np.array([loss / self.batch]),
                        opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss / self.batch]),
                                  win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.epoch_re += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1
class ReplayBuffer():
    """Bounded pool of previously seen samples.

    While the pool is not full every incoming sample is stored and returned
    unchanged. Once full, each incoming sample is, with probability 0.5,
    swapped against a randomly chosen stored one (the stored one is
    returned); otherwise it is returned as-is and not stored.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the buffer and return a batch of equal size.

        Iterates ``data.data`` along its first dimension; the returned tensor
        is the concatenation of the selected samples.
        """
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Hand back an old sample and keep the new one in its slot.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return Variable(torch.cat(to_return))


class LambdaLR():
    """Linear learning-rate decay factor.

    ``step(epoch)`` returns 1.0 until ``epoch + offset`` reaches
    ``decay_start_epoch`` and then falls linearly, reaching 0.0 when
    ``epoch + offset`` equals ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative learning-rate factor for ``epoch``."""
        decayed = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - decayed / (self.n_epochs - self.decay_start_epoch)


def weights_init_normal(m):
    """Initialise a module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). ``BatchNorm2d`` modules get weights from N(1, 0.02) and a
    zero bias. Parameters that are absent (``None``) are skipped rather than
    dereferenced.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            # constant_ is the supported in-place initialiser; the
            # underscore-less torch.nn.init.constant is deprecated.
            torch.nn.init.constant_(bias.data, 0.0)


class MS_SSIM_Loss(MS_SSIM):
    """MS-SSIM turned into a loss: ``5 * (1 - MS_SSIM)``.

    Both inputs are shifted from ``[-1, 1]`` to ``[0, 1]`` before the parent
    metric is evaluated.
    NOTE(review): the parent presumably has to be constructed with
    ``data_range=1`` for this rescaling to be consistent; confirm at the
    construction site.
    """

    def forward(self, img1, img2):
        img1 = (img1 + 1) / 2
        img2 = (img2 + 1) / 2
        return 5 * (1 - super().forward(img1, img2))


class SSIM_Loss(SSIM):
    """SSIM turned into a loss: ``5 * (1 - SSIM)``.

    Both inputs are shifted from ``[-1, 1]`` to ``[0, 1]`` before the parent
    metric is evaluated.
    NOTE(review): the parent presumably has to be constructed with
    ``data_range=1`` for this rescaling to be consistent; confirm at the
    construction site.
    """

    def forward(self, img1, img2):
        img1 = (img1 + 1) / 2
        img2 = (img2 + 1) / 2
        return 5 * (1 - super().forward(img1, img2))