├── README.md
├── Image_preprocess.py
├── readme.md
├── main.py
├── LBM_2D.py
├── LICENSE
├── LBM_3D.py
├── Augmentation.py
├── Disco_GAN.py
├── Dual_GAN.py
├── DCGAN.py
├── Segmentation.py
└── Cycle_GAN_cross.py
/README.md:
--------------------------------------------------------------------------------
1 | # TGLABX on porous media analysis, image recognition and 3D reconstruction
2 | Deep-learning-aided-porous-media-hydrodynamic-analysis-and-three-dimensional-reconstruction
3 | 
4 | The study of hydrodynamic behavior and water-rock interaction mechanisms typically demands high computational efficiency, so that structural information can be extracted quickly and accurately. We therefore chose deep learning models to meet these requirements. In this paper we first compared the image segmentation performance of a series of autoencoder architectures on the complex geometries of porous media, with the goal of extracting hydrodynamic connectivity channels and the mineral composition of rock samples from SEM (scanning electron microscopy) data; we reached an accuracy of 0.97. We then improved the computational efficiency of the LBM (lattice Boltzmann method) through GPU acceleration, which allowed us to rapidly simulate the structural flow-field features of complex porous media and, on our hardware, improved computational efficiency by a factor of about 30. We subsequently employed an SWD-Cycle-GAN technique to migrate sedimentation features onto the initial 2D structure slices and reconstruct a 3D (three-dimensional) porous media geometry that fits the depositional features more closely. Overall, we propose a new deep-learning-based method for 3D structure reconstruction and permeability analysis of porous media that is fast, efficient and accurate.
5 | 
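A minimal sketch of the intended call order, assuming the function names documented in `main.py` and `readme.md` (paths and hyper-parameters are placeholders; see the configuration block at the top of each module):

```python
# 1-3: segmentation training-set preparation and augmentation (Image_preprocess.py, Augmentation.py)
path = train_set_create(file_path, out_folder)
label_me(if_labelme)                       # annotate, then run: labelme_json_to_dataset
mask_path, Classes = json2dataset(json_path, label_number, out_folder)
save_path_list = Augmentation_dataset(file_path, out_folder, json_path,
                                      label_number, train_num, test_num, val_num)
# 4-6: train, test and visualize the segmentation models (Segmentation.py)
# 7-8: GPU-accelerated LBM permeability analysis of the segmented slices (LBM_2D.py, LBM_3D.py)
LBM_2D_Analysis(path_all, out_folder)
# 9-18: DCGAN / Dual-GAN / Disco-GAN / Cycle-GAN training and the
#       SWD-guided Cycle-GAN 3D reconstruction (DCGAN.py ... Cycle_GAN_cross.py)
```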
--------------------------------------------------------------------------------
/Image_preprocess.py:
--------------------------------------------------------------------------------
1 | # Any input image, whatever its scale, should go through the following steps of the image segmentation task
2 | 
3 | # Two parts:
4 | # 1. Segmentation training-set preparation: 8000 -> 800 -> choose 5 to label -> labelme
5 | # 2. Segmentation model prediction
6 | 
7 | from PIL import Image
8 | import cv2
9 | import os
10 | import random
11 | import json
12 | import labelme
13 | import numpy as np
14 | 
15 | #################### Required user inputs ####################
16 | out_folder = 'H:/tmp/DEEP_WATERROCK_CODE/codetest/' # output folder
17 | label_number=5 # number of images to label
18 | if_labelme=0 # 0: do not launch labelme, 1: launch labelme
19 | file_path = 'H:/tmp/DEEP_WATERROCK_CODE/test_image/test_1.png' # image to segment
20 | json_path='H:/tmp/json/' # folder produced after labelme; avoid non-ASCII characters in this path
21 | #####################################################
22 | 
23 | def normalize_image_size(input_img_path,out_path):
24 |     img=cv2.imread(input_img_path)
25 |     img=cv2.resize(img,(8000,8000))
26 |     new_path=out_path+'normalize_image.png'
27 |     cv2.imwrite(new_path,img)
28 |     return new_path
29 | 
30 | def cut_image(image,width_out,height_out,path):
31 |     width,height=image.size
32 |     width_num=int(width/width_out)
33 |     height_num=int(height/height_out)
34 |     for j in range(width_num):
35 |         for i in range(height_num):
36 |             box = (width_out*i,height_out*j, width_out* (i + 1), height_out* (j + 1))
37 |             region = image.crop(box)
38 |             region.save(path+'width_num={}_height_num={}.png'.format(j, i))
39 |     return path
40 | 
41 | def random_choose_image(number:int,original_path):
42 |     img_list=[]
43 |     files=os.listdir(original_path)
44 |     for file in files:
45 |         img_list.append(original_path+file)
46 |     img_choose=random.sample(img_list,number)
47 |     return img_choose
48 | 
49 | def choose_image_resize(img_choose,label_path):
50 |     n=1
51 |     for img in img_choose:
52 |         img=cv2.imread(img)
53 |         img=cv2.resize(img,(800,800))
54 |         cv2.imwrite(label_path+str(n)+'.png',img)
55 |         n+=1
56 |     return label_path
57 | 
58 | def train_set_create(file_path,out_folder):
59 |     image_path=normalize_image_size(file_path,out_folder)
60 |     image=Image.open(image_path) # load the normalized image
61 |     image_width,image_height=image.size # PIL returns (width, height)
62 |     height=image_height/10
63 |     width=image_width/10
64 |     cut_image_save_path=out_folder+'cut_image/'
65 |     label_image_path=out_folder+'label_image/'
66 |     try:
67 |         os.makedirs(cut_image_save_path)
68 |     except OSError:
69 |         pass
70 |     try:
71 |         os.makedirs(label_image_path)
72 |     except OSError:
73 |         pass
74 |     image_cut_path = cut_image(image,width,height,cut_image_save_path) # tile the image
75 |     image_cut_choose=random_choose_image(label_number,image_cut_path) # randomly pick the tiles to label
76 |     image_label_path=choose_image_resize(image_cut_choose,label_image_path) # resize the picked tiles to (800,800) for labeling
77 |     path={'cut_path':image_cut_path,'label_path':image_label_path}
78 |     return path
79 | 
80 | def label_me(num=1):
81 |     if num==1:
82 |         os.system('labelme')
83 |     elif num==0:
84 |         pass
85 | # use labelme to annotate the training set
86 | def json2dataset(json_path,label_number,out_folder):
87 |     files=os.listdir(json_path)
88 |     file_list=[]
89 |     for file in files:
90 |         file_list.append(file)
91 |     json_list=file_list[0:label_number] # assumes the .json files come first,
92 |     jsonfile_list=file_list[label_number:2*label_number] # followed by the labelme_json_to_dataset output folders
93 |     for n in range(label_number):
94 |         json_file=json_list[n]
95 |         img_file=jsonfile_list[n]
96 |         data=json.load(open(json_path+json_file))
97 |         img=cv2.imread(json_path+img_file+'/img.png')
98 |         lbl, lbl_names = labelme.utils.labelme_shapes_to_label(img.shape, data['shapes'])
99 |         mask=[]
100 |         class_id=[]
101 |         class_name=[]
102 |         for name in lbl_names:
103 |             class_name.append(name)
104 |         for i in range(1,len(lbl_names)):
105 |             mask.append((lbl==i).astype(np.uint8))
106 |             class_id.append(lbl_names)
107 |         mask=np.asarray(mask,np.uint8)
108 |         mask_path=out_folder+'mask_image/'
109 |         for j in range(0,len(class_name)-1):
110 |             try:
111 |                 os.makedirs(mask_path+str(class_name[j+1]))
112 |             except OSError:
113 |                 pass
114 |             cv2.imwrite(mask_path+str(class_name[j+1])+'/'+str(n+1)+'.png',mask[j,:,:])
115 |     return mask_path
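# json2dataset writes one sub-folder per annotated class under out_folder+'mask_image/',
# with a binary mask <n>.png for every labeled picture; the background class (index 0)
# is skipped via the j+1 offset above.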
116 | 
117 | def model_predict_set_create(file_path,out_folder):
118 |     image_path=normalize_image_size(file_path,out_folder)
119 |     image=Image.open(image_path) # load the normalized image
120 |     image_width,image_height=image.size # PIL returns (width, height)
121 |     height=image_height/10
122 |     width=image_width/10
123 |     cut_image_save_path=out_folder+'cut_image/'
124 |     try:
125 |         os.makedirs(cut_image_save_path)
126 |     except OSError:
127 |         pass
128 |     image_cut_path = cut_image(image,width,height,cut_image_save_path) # tile the image into 100 pieces
129 |     return image_cut_path
130 | 
131 | '''
132 | 1. Segmentation training set prepare
133 |    path=train_set_create(file_path,out_folder)
134 |    label_me(if_labelme)
135 |    cmd->labelme_json_to_dataset
136 |    mask_path,Classes=json2dataset(json_path,label_number,out_folder)
137 | 
138 | 2. Segmentation model predict
139 |    model_predict_set_create(file_path,out_folder)
140 | '''
141 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | 1. Image_preprocess
2 | 2. Augmentation
3 | 3. Segmentation
4 | 4. LBM 2D,3D
5 | 5. Reconstruction
6 |    5.1 DCGAN
7 |    5.2 Dual-GAN
8 |    5.3 Disco-GAN
9 |    5.4 Cycle-GAN
10 | 6. Cross-generation
11 | 
12 | Functions:
13 | '''
14 | 1. Segmentation training set prepare
15 |    path=train_set_create(file_path,out_folder)
16 |    label_me(if_labelme)
17 |    cmd->labelme_json_to_dataset
18 |    mask_path,Classes=json2dataset(json_path,label_number,out_folder)
19 | 2. Segmentation model predict
20 |    model_predict_set_create(file_path,out_folder)
21 | '''
22 | '''
23 | 3. Augmentation_dataset for few-shot learning
24 |    save_path_list=Augmentation_dataset(file_path,out_folder,json_path,label_number,train_num,test_num,val_num)
25 | '''
26 | '''
27 | 4. Segmentation train process
28 |    Segmentation_train(Encoders=Encoders,
29 |                       Encoder_weights=Encoder_weights,
30 |                       model_name=model_name,
31 |                       Activation=Activation,
32 |                       Epochs=Epoch,
33 |                       batch_size=train_batch_size,
34 |                       dataset_choose=dataset_choose,
35 |                       out_folder=out_folder,
36 |                       json_path=json_path,
37 |                       label_number=label_number)
38 | 5. Segmentation test process from trained models
39 |    Segmentation_test(Activation=Activation,
40 |                      dataset_choose=dataset_choose,
41 |                      out_folder=out_folder,
42 |                      json_path=json_path,
43 |                      label_number=label_number)
44 | 6. Segmentation result visualization and saving
45 |    Segmentation_result(dataset_choose=dataset_choose,
46 |                        out_folder=out_folder,
47 |                        json_path=json_path,
48 |                        label_number=label_number)
49 | '''
50 | '''
51 | 7. LBM_2D_Analysis
52 |    LBM_2D_Analysis(path_all,out_folder)
53 | '''
54 | '''
55 | 8. LBM_3D_Analysis
56 |    LBM_3D_Analysis(path,out_folder)
57 | '''
58 | '''
59 | 9. DCGAN train
60 |    DCGAN_train(imageSize=imageSize,batchSize=batchSize,
61 |                ngf=number_generator_feature,
62 |                ndf=number_discriminator_feature,
63 |                nz=number_z,
64 |                niter=number_train_iterations,
65 |                ngpu=number_gpu,
66 |                manualSeed=manualSeed,
67 |                out_folder=out_folder,
68 |                dataset_name=dataset_name,
69 |                device=device)
70 | 
71 | '''
72 | '''
73 | 10. DCGAN generate
74 |     DCGAN_generator(seedmin=seedmin,
75 |                     seedmax=seedmax,
76 |                     ngf=number_generator_feature,
77 |                     ndf=number_discriminator_feature,
78 |                     nz=number_z,
79 |                     ngpu=number_gpu,
80 |                     imageSize=imageSize,
81 |                     imsize=image_generate_size,
82 |                     out_folder=out_folder,
83 |                     name=generate_name,
84 |                     device=device,
85 |                     netG=netG,
86 |                     )
87 | '''
88 | '''
89 | 11. DCGAN batch processing samples statistics
90 |     result_analysis(out_folder,generate_name)
91 | '''
92 | '''
93 | 12. Dual_GAN generate from promoted translation style
94 |     Dual_GAN(out_folder=out_folder,
95 |              dataset_name=dataset_name,
96 |              dataset_path=dataset_path,
97 |              checkpoint_interval=checkpoint_interval,
98 |              sample_interval=sample_interval,
99 |              n_epochs=n_epochs,
100 |              batch_size= batch_size,
101 |              lr=learning_rate,
102 |              img_size=generate_image_size,
103 |              channels=channels,
104 |              pre_trained=pre_trained,
105 |              trained_epoch=trained_epoch)
106 | '''
107 | '''
108 | 13. Disco_GAN generate from promoted translation style
109 |     Disco_GAN(out_folder=out_folder,
110 |               dataset_name=dataset_name,
111 |               dataset_path=dataset_path,
112 |               checkpoint_interval=checkpoint_interval,
113 |               sample_interval=sample_interval,
114 |               n_epochs=n_epochs,
115 |               batch_size= batch_size,
116 |               lr=learning_rate,
117 |               channels=channels,
118 |               img_height=img_height,
119 |               img_width=img_width,
120 |               pre_trained=pre_trained,
121 |               trained_epoch=trained_epoch
122 |               )
123 | '''
124 | '''
125 | 14. Cycle_GAN generate from promoted translation style
126 |     Cycle_GAN(out_folder=out_folder,
127 |               dataset_name=dataset_name,
128 |               dataset_path=dataset_path,
129 |               checkpoint_interval=checkpoint_interval,
130 |               sample_interval=sample_interval,
131 |               n_epochs=n_epochs,
132 |               batch_size=batch_size,
133 |               lr=learning_rate,
134 |               decay_epoch=decay_epoch,
135 |               n_residual_blocks=Resnet_blocks,
136 |               channels=channels,
137 |               img_height=img_height,
138 |               img_width=img_width,
139 |               pre_trained=pre_trained,
140 |               trained_epoch=trained_epoch
141 |               )
142 | '''
143 | '''
144 | 15. SWD WS and FID distribution calc
145 |     WD,SWD=WD_SWD_calc(berea_calc_WD_SWD_datasetpath)
146 |     WD_SWD_distribution_plot(WD,SWD,save_path=out_folder,dataset_name='berea')
147 | '''
148 | '''
149 | 16. cross-domain datasets create
150 |     cross_cycle_dataset(original_path=dataset_path,
151 |                         out_folder=out_folder,
152 |                         cross_number=cross_number)
153 | '''
154 | '''
155 | 17. cross datasets train models # train the models one at a time and clear variables between runs
156 |     cross train:
157 |     Cross_Cycle_GAN(out_folder=out_folder,
158 |                     cross_number=cross_number,
159 |                     checkpoint_interval=checkpoint_interval,
160 |                     sample_interval=sample_interval,
161 |                     n_epochs=n_epochs,
162 |                     batch_size=batch_size,
163 |                     lr=learning_rate,
164 |                     decay_epoch=decay_epoch,
165 |                     n_residual_blocks=Resnet_blocks,
166 |                     channels=channels,
167 |                     img_height=img_height,
168 |                     img_width=img_width,
169 |                     pre_trained=pre_trained,
170 |                     trained_epoch=trained_epoch)
171 |     testloader visualization:
172 |     testloader_result(test_loader=test_loader,
173 |                       n_resudual_blocks=Resnet_blocks,
174 |                       G_AB_path=G_AB_path,
175 |                       G_BA_path=G_BA_path,
176 |                       test_result_save_path=test_result_save_path,
177 |                       channels=channels,
178 |                       img_height=img_height,
179 |                       img_width=img_width)
180 | '''
181 | '''
182 | 18. SWD-guided Cycle-GAN 3D reconstruction
183 |     Generate_SWD=SWD_cross_cycle(out_folder=out_folder,
184 |                                  n_epochs=n_epochs,
185 |                                  channels=channels,
186 |                                  img_height=img_height,
187 |                                  img_width=img_width,
188 |                                  n_residual_blocks=Resnet_blocks,
189 |                                  cross_number=cross_number,
190 |                                  berea_calc_WD_SWD_datasetpath=berea_calc_WD_SWD_datasetpath,
191 |                                  test_loader=test_loader)
192 | '''
193 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(r'H:/tmp/DEEP_WATERROCK_CODE/Structure_code') # path where this code was downloaded
3 | import Image_preprocess
4 | import Augmentation
5 | import Segmentation
6 | import LBM_2D
7 | import LBM_3D
8 | import DCGAN
9 | import Dual_GAN
10 | import Disco_GAN
11 | import Cycle_GAN_cross
12 | '''
13 | Functions:
14 | 1. Segmentation training set prepare
15 | 
16 |    dir: Image_preprocess.py
17 |    run:
18 |    path=train_set_create(file_path,out_folder)
19 |    label_me(if_labelme)
20 |    cmd->labelme_json_to_dataset
21 |    mask_path,Classes=json2dataset(json_path,label_number,out_folder)
22 | 
23 | 2. Segmentation model predict
24 | 
25 |    dir: Image_preprocess.py
26 |    run:
27 |    model_predict_set_create(file_path,out_folder)
28 | 
29 | 3. Augmentation_dataset for few-shot learning
30 | 
31 |    dir: Augmentation.py
32 |    run:
33 |    save_path_list=Augmentation_dataset(file_path,out_folder,json_path,label_number,train_num,test_num,val_num)
34 | 
35 | 4. Segmentation train process
36 | 
37 |    dir: Segmentation.py
38 |    run:
39 |    Segmentation_train(Encoders=Encoders,
40 |                       Encoder_weights=Encoder_weights,
41 |                       model_name=model_name,
42 |                       Activation=Activation,
43 |                       Epochs=Epoch,
44 |                       batch_size=train_batch_size,
45 |                       dataset_choose=dataset_choose,
46 |                       out_folder=out_folder,
47 |                       json_path=json_path,
48 |                       label_number=label_number)
49 | 
50 | 5. Segmentation test process from trained models
51 | 
52 |    dir: Segmentation.py
53 |    run:
54 |    Segmentation_test(Activation=Activation,
55 |                      dataset_choose=dataset_choose,
56 |                      out_folder=out_folder,
57 |                      json_path=json_path,
58 |                      label_number=label_number)
59 | 
60 | 6. Segmentation result visualization and saving
61 | 
62 |    dir: Segmentation.py
63 |    run:
64 |    Segmentation_result(dataset_choose=dataset_choose,
65 |                        out_folder=out_folder,
66 |                        json_path=json_path,
67 |                        label_number=label_number)
68 | 
69 | 7. LBM_2D_Analysis
70 | 
71 |    dir: LBM_2D.py
72 |    run:
73 |    LBM_2D_Analysis(path_all,out_folder)
74 | 
75 | 8. LBM_3D_Analysis
76 | 
77 |    dir: LBM_3D.py
78 |    run:
79 |    LBM_3D_Analysis(path,out_folder)
80 | 
81 | 9. DCGAN train
82 | 
83 |    dir: DCGAN.py
84 |    run:
85 |    DCGAN_train(imageSize=imageSize,batchSize=batchSize,
86 |                ngf=number_generator_feature,
87 |                ndf=number_discriminator_feature,
88 |                nz=number_z,
89 |                niter=number_train_iterations,
90 |                ngpu=number_gpu,
91 |                manualSeed=manualSeed,
92 |                out_folder=out_folder,
93 |                dataset_name=dataset_name,
94 |                device=device)
95 | 
96 | 10. DCGAN generate
97 | 
98 |     dir: DCGAN.py
99 |     run:
100 |     DCGAN_generator(seedmin=seedmin,
101 |                     seedmax=seedmax,
102 |                     ngf=number_generator_feature,
103 |                     ndf=number_discriminator_feature,
104 |                     nz=number_z,
105 |                     ngpu=number_gpu,
106 |                     imageSize=imageSize,
107 |                     imsize=image_generate_size,
108 |                     out_folder=out_folder,
109 |                     name=generate_name,
110 |                     device=device,
111 |                     netG=netG,
112 |                     )
113 | 
114 | 11. DCGAN batch processing samples statistics
115 | 
116 |     dir: DCGAN.py
117 |     run:
118 |     result_analysis(out_folder,generate_name)
119 | 
120 | 12. Dual_GAN generate from promoted translation style
121 | 
122 |     dir: Dual_GAN.py
123 |     run:
124 |     Dual_GAN(out_folder=out_folder,
125 |              dataset_name=dataset_name,
126 |              dataset_path=dataset_path,
127 |              checkpoint_interval=checkpoint_interval,
128 |              sample_interval=sample_interval,
129 |              n_epochs=n_epochs,
130 |              batch_size= batch_size,
131 |              lr=learning_rate,
132 |              img_size=generate_image_size,
133 |              channels=channels,
134 |              pre_trained=pre_trained,
135 |              trained_epoch=trained_epoch)
136 | 
137 | 13. Disco_GAN generate from promoted translation style
138 | 
139 |     dir: Disco_GAN.py
140 |     run:
141 |     Disco_GAN(out_folder=out_folder,
142 |               dataset_name=dataset_name,
143 |               dataset_path=dataset_path,
144 |               checkpoint_interval=checkpoint_interval,
145 |               sample_interval=sample_interval,
146 |               n_epochs=n_epochs,
147 |               batch_size= batch_size,
148 |               lr=learning_rate,
149 |               channels=channels,
150 |               img_height=img_height,
151 |               img_width=img_width,
152 |               pre_trained=pre_trained,
153 |               trained_epoch=trained_epoch
154 |               )
155 | 
156 | 14. Cycle_GAN generate from promoted translation style
157 | 
158 |     dir: Cycle_GAN_cross.py
159 |     run:
160 |     Cycle_GAN(out_folder=out_folder,
161 |               dataset_name=dataset_name,
162 |               dataset_path=dataset_path,
163 |               checkpoint_interval=checkpoint_interval,
164 |               sample_interval=sample_interval,
165 |               n_epochs=n_epochs,
166 |               batch_size=batch_size,
167 |               lr=learning_rate,
168 |               decay_epoch=decay_epoch,
169 |               n_residual_blocks=Resnet_blocks,
170 |               channels=channels,
171 |               img_height=img_height,
172 |               img_width=img_width,
173 |               pre_trained=pre_trained,
174 |               trained_epoch=trained_epoch
175 |               )
176 | 
177 | 15. SWD WS and FID distribution calc
178 | 
179 |     dir: Cycle_GAN_cross.py
180 |     run:
181 |     WD,SWD=WD_SWD_calc(berea_calc_WD_SWD_datasetpath)
182 |     WD_SWD_distribution_plot(WD,SWD,save_path=out_folder,dataset_name='berea')
183 | 
184 | 16. cross-domain datasets create
185 | 
186 |     dir: Cycle_GAN_cross.py
187 |     run:
188 |     cross_cycle_dataset(original_path=dataset_path,
189 |                         out_folder=out_folder,
190 |                         cross_number=cross_number)
191 | 
192 | 17. cross datasets train models # train the models one at a time and clear variables between runs
193 | 
194 |     dir: Cycle_GAN_cross.py
195 |     run:
196 | 
197 |     *cross train:
198 |     Cross_Cycle_GAN(out_folder=out_folder,
199 |                     cross_number=cross_number,
200 |                     checkpoint_interval=checkpoint_interval,
201 |                     sample_interval=sample_interval,
202 |                     n_epochs=n_epochs,
203 |                     batch_size=batch_size,
204 |                     lr=learning_rate,
205 |                     decay_epoch=decay_epoch,
206 |                     n_residual_blocks=Resnet_blocks,
207 |                     channels=channels,
208 |                     img_height=img_height,
209 |                     img_width=img_width,
210 |                     pre_trained=pre_trained,
211 |                     trained_epoch=trained_epoch)
212 |     *testloader visualization:
213 |     testloader_result(test_loader=test_loader,
214 |                       n_resudual_blocks=Resnet_blocks,
215 |                       G_AB_path=G_AB_path,
216 |                       G_BA_path=G_BA_path,
217 |                       test_result_save_path=test_result_save_path,
218 |                       channels=channels,
219 |                       img_height=img_height,
220 |                       img_width=img_width)
221 | 
222 | 18. SWD-guided Cycle-GAN 3D reconstruction
223 | 
224 |     dir: Cycle_GAN_cross.py
225 |     run:
226 |     Generate_SWD=SWD_cross_cycle(out_folder=out_folder,
227 |                                  n_epochs=n_epochs,
228 |                                  channels=channels,
229 |                                  img_height=img_height,
230 |                                  img_width=img_width,
231 |                                  n_residual_blocks=Resnet_blocks,
232 |                                  cross_number=cross_number,
233 |                                  berea_calc_WD_SWD_datasetpath=berea_calc_WD_SWD_datasetpath,
234 |                                  test_loader=test_loader)
235 | '''
236 | 
--------------------------------------------------------------------------------
/LBM_2D.py:
--------------------------------------------------------------------------------
1 | # LBM 2D permeability analysis (D2Q9 lattice)
2 | 
3 | import matplotlib.pyplot as plt
4 | from PIL import Image
5 | import numpy as np
6 | import torch
7 | import os
8 | from pylab import meshgrid, arange, streamplot, show
9 | import time
10 | 
11 | #################### Required user inputs ####################
12 | # image paths
13 | device='cuda'
14 | out_folder='H:/tmp/DEEP_WATERROCK_CODE/codetest/'
15 | path1='H:/清华大学/论文/论文《DEEP FLOW》王明阳/深度分割重建计算/SEGMENTATION_RESULTS/results_image/formerwork/1_1.png'
16 | path2='H:/清华大学/论文/论文《DEEP FLOW》王明阳/深度分割重建计算/SEGMENTATION_RESULTS/results_image/planB/O_result/Unet++_densenet201/1.png'
17 | #####################################################
18 | 
19 | def get_img_obstacle(img_path):
20 |     num_0=0
21 |     num_1=0
22 |     obstacle_0=[]
23 |     obstacle_1=[]
24 |     img=Image.open(img_path)
25 |     data=np.array(img)/255.0
26 |     data[data!=1]=0
27 |     for x in range(0,data.shape[0]):
28 |         for y in range(0,data.shape[1]):
29 |             if data[x][y].all()==0:
30 |                 num_0+=1
31 |                 obstacle_0.append((x+1)*(y+1))
32 |             else:
33 |                 num_1+=1
34 |                 obstacle_1.append((x+1)*(y+1))
35 |     return obstacle_0,obstacle_1,data
36 | 
37 | def to_tensor_gpu(x):
38 |     return torch.tensor(x).to(device)
39 | 
40 | # the solver expects input of shape (3,800,800)
41 | def lbm_solver(img_path):
42 |     # start timing
43 |     time0=time.time()
44 |     # initial values
45 |     omega=1.0
46 |     density=1.0
47 |     t1=4/9
48 |     t2=1/9
49 |     t3=1/36
50 |     c_squ=1/3
51 |     avu=1
52 |     prevavu=1
53 |     ts=0
54 |     deltaU=1e-7
55 |     cxs = np.array([1, 1, 0,-1,-1,-1, 0, 1, 0])
56 |     cys = np.array([0, 1, 1, 1, 0,-1,-1,-1, 0])
57 |     weights = np.array([1/9,1/36,1/9,1/36,1/9,1/36,1/9,1/36,4/9])
58 |     NL=9
59 |     idxs=np.arange(9)
60 |     device='cuda'
61 |     # load data
62 |     img_path=img_path
63 |     obstacle,active_nodes,BOUND=get_img_obstacle(img_path)
64 |     BOUND=BOUND.astype('bool')
65 |     #print('number of obstacle:',len(obstacle),'\n','number of active nodes:',len(active_nodes))
66 |     porosity=len(obstacle)/(len(obstacle)+len(active_nodes))
67 |     print('porosity:',porosity)
68 |     print('Calc nodes number:',(len(obstacle)+len(active_nodes)))
69 |     BOUND=BOUND[:,:]
70 |     nx=BOUND.shape[0]
71 |     ny=BOUND.shape[1]
72 |     F=np.tile(density/9,[nx,ny,9]) # F=repmat(density/9,[nx ny 9])
73 |     FEQ=F
74 |     # move tensors to the GPU
75 |     F1=to_tensor_gpu(F)
76 |     FEQ1=to_tensor_gpu(FEQ)
77 |     cxs1=to_tensor_gpu(cxs)
78 |     cys1=to_tensor_gpu(cys)
79 |     weights1=to_tensor_gpu(weights)
80 |     idxs1=to_tensor_gpu(idxs)
81 |     avu1=to_tensor_gpu(avu)
82 |     prevavu1=to_tensor_gpu(prevavu)
83 |     ts1=to_tensor_gpu(ts)
84 |     # main iteration loop
85 |     while ts1<4000 and 1e-10
--------------------------------------------------------------------------------
/Augmentation.py:
--------------------------------------------------------------------------------
54 |     def flip(self,image,mask):
55 |         if random.random()>0.5:
56 |             image = tf.hflip(image)
57 |             mask = tf.hflip(mask)
58 |         if random.random()<0.5:
59 |             image = tf.vflip(image)
60 |             mask = tf.vflip(mask)
61 |         image = tf.to_tensor(image)
62 |         mask = tf.to_tensor(mask)
63 |         return image, mask
64 |     def adjustContrast(self,image,mask):
65 |         factor = transforms.RandomRotation.get_params([0,10]) # contrast factor for the augmented data, sampled uniformly from [0,10]
66 |         image = tf.adjust_contrast(image,factor)
67 |         #mask = tf.adjust_contrast(mask,factor)
68 |         image = tf.to_tensor(image)
69 |         mask = tf.to_tensor(mask)
70 |         return image,mask
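    # The mask adjustments above and below stay commented out on purpose: the masks are
    # class-label maps, so photometric changes (contrast/brightness/saturation) are applied
    # to the image only, while geometric transforms (rotate/flip) are applied to both.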
71 |     def adjustBrightness(self,image,mask):
72 |         factor = transforms.RandomRotation.get_params([1, 2]) # brightness factor for the augmented data
73 |         image = tf.adjust_brightness(image, factor)
74 |         #mask = tf.adjust_contrast(mask, factor)
75 |         image = tf.to_tensor(image)
76 |         mask = tf.to_tensor(mask)
77 |         return image, mask
78 |     def adjustSaturation(self,image,mask): # adjust saturation
79 |         factor = transforms.RandomRotation.get_params([1, 2]) # saturation factor for the augmented data
80 |         image = tf.adjust_saturation(image, factor)
81 |         #mask = tf.adjust_saturation(mask, factor)
82 |         image = tf.to_tensor(image)
83 |         mask = tf.to_tensor(mask)
84 |         return image, mask
85 | 
86 | #train data create
87 | def augmentationData_train(image_path,mask_path,option=[1,2,3,4,5],save_dir=None,multiple=70):
88 |     aug_image_savedDir = os.path.join(save_dir,'train')
89 |     aug_mask_savedDir = os.path.join(save_dir, 'train_mask')
90 |     aug = Augmentation()
91 |     res_image= os.walk(image_path)
92 |     images = []
93 |     masks = []
94 |     for root,dirs,files in res_image:
95 |         for f in files:
96 |             images.append(os.path.join(root,f))
97 |     res_mask = os.walk(mask_path)
98 |     for root,dirs,files in res_mask:
99 |         for f in files:
100 |             masks.append(os.path.join(root,f))
101 |     datas = list(zip(images,masks))
102 |     num = len(datas)
103 |     for epoch in range(int(multiple/5)): # generate augmented data for training; the original images are kept for the final test
104 |         for (image_path,mask_path) in datas:
105 |             image = Image.open(image_path)
106 |             mask = Image.open(mask_path)
107 |             if 1 in option:
108 |                 num+=1
109 |                 image_tensor, mask_tensor = aug.rotate(image, mask)
110 |                 image_rotate = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'train', str(num) + '_rotate.png'))
111 |                 mask_rotate = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'train_mask', str(num) + '_rotate.png'))
112 |             if 2 in option:
113 |                 num+=1
114 |                 image_tensor, mask_tensor = aug.flip(image, mask)
115 |                 image_flip = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir,'train',str(num)+'_flip.png'))
116 |                 mask_flip = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir,'train_mask',str(num)+'_flip.png'))
117 |             if 3 in option:
118 |                 num+=1
119 |                 image_tensor, mask_tensor = aug.adjustContrast(image, mask)
120 |                 image_Contrast = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'train', str(num) + '_Contrast.png'))
121 |                 mask_Contrast = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'train_mask', str(num) + '_Contrast.png'))
122 |             if 4 in option:
123 |                 num+=1
124 |                 image_tensor, mask_tensor = aug.adjustBrightness(image, mask)
125 |                 image_Brightness = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'train', str(num) + '_Brightness.png'))
126 |                 mask_Brightness = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'train_mask', str(num) + '_Brightness.png'))
127 |             if 5 in option:
128 |                 num+=1
129 |                 image_tensor, mask_tensor = aug.adjustSaturation(image, mask)
130 |                 image_Saturation = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'train', str(num) + '_Saturation.png'))
131 |                 mask_Saturation = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'train_mask', str(num) + '_Saturation.png'))
132 | 
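# In each of the three dataset creators, every pass of `range(int(multiple/5))` writes the
# five augmented variants once per source image, so `multiple` is roughly the number of
# augmented copies produced per image and should be a multiple of 5 (e.g. 70 -> 14 passes).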
133 | #test data create
134 | def augmentationData_test(image_path,mask_path,option=[1,2,3,4,5],save_dir=None,multiple=20):
135 |     aug_image_savedDir = os.path.join(save_dir,'test')
136 |     aug_mask_savedDir = os.path.join(save_dir, 'test_mask')
137 |     aug = Augmentation()
138 |     res_image= os.walk(image_path)
139 |     images = []
140 |     masks = []
141 |     for root,dirs,files in res_image:
142 |         for f in files:
143 |             images.append(os.path.join(root,f))
144 |     res_mask = os.walk(mask_path)
145 |     for root,dirs,files in res_mask:
146 |         for f in files:
147 |             masks.append(os.path.join(root,f))
148 |     datas = list(zip(images,masks))
149 |     num = len(datas)
150 |     for epoch in range(int(multiple/5)): # generate augmented data for testing; the original images are kept for the final test
151 |         for (image_path,mask_path) in datas:
152 |             image = Image.open(image_path)
153 |             mask = Image.open(mask_path)
154 |             if 1 in option:
155 |                 num+=1
156 |                 image_tensor, mask_tensor = aug.rotate(image, mask)
157 |                 image_rotate = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'test', str(num) + '_rotate.png'))
158 |                 mask_rotate = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'test_mask', str(num) + '_rotate.png'))
159 |             if 2 in option:
160 |                 num+=1
161 |                 image_tensor, mask_tensor = aug.flip(image, mask)
162 |                 image_flip = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir,'test',str(num)+'_flip.png'))
163 |                 mask_flip = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir,'test_mask',str(num)+'_flip.png'))
164 |             if 3 in option:
165 |                 num+=1
166 |                 image_tensor, mask_tensor = aug.adjustContrast(image, mask)
167 |                 image_Contrast = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'test', str(num) + '_Contrast.png'))
168 |                 mask_Contrast = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'test_mask', str(num) + '_Contrast.png'))
169 |             if 4 in option:
170 |                 num+=1
171 |                 image_tensor, mask_tensor = aug.adjustBrightness(image, mask)
172 |                 image_Brightness = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'test', str(num) + '_Brightness.png'))
173 |                 mask_Brightness = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'test_mask', str(num) + '_Brightness.png'))
174 |             if 5 in option:
175 |                 num+=1
176 |                 image_tensor, mask_tensor = aug.adjustSaturation(image, mask)
177 |                 image_Saturation = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'test', str(num) + '_Saturation.png'))
178 |                 mask_Saturation = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'test_mask', str(num) + '_Saturation.png'))
179 | 
180 | #validation data create
181 | def augmentationData_validation(image_path,mask_path,option=[1,2,3,4,5],save_dir=None,multiple=10):
182 |     aug_image_savedDir = os.path.join(save_dir,'val')
183 |     aug_mask_savedDir = os.path.join(save_dir, 'val_mask')
184 |     aug = Augmentation()
185 |     res_image= os.walk(image_path)
186 |     images = []
187 |     masks = []
188 |     for root,dirs,files in res_image:
189 |         for f in files:
190 |             images.append(os.path.join(root,f))
191 |     res_mask = os.walk(mask_path)
192 |     for root,dirs,files in res_mask:
193 |         for f in files:
194 |             masks.append(os.path.join(root,f))
195 |     datas = list(zip(images,masks))
196 |     num = len(datas)
197 |     for epoch in range(int(multiple/5)): # generate augmented data for validation; the original images are kept for the final test
198 |         for (image_path,mask_path) in datas:
199 |             image = Image.open(image_path)
200 |             mask = Image.open(mask_path)
201 |             if 1 in option:
202 |                 num+=1
203 |                 image_tensor, mask_tensor = aug.rotate(image, mask)
204 |                 image_rotate = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'val', str(num) + '_rotate.png'))
205 |                 mask_rotate = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'val_mask', str(num) + '_rotate.png'))
206 |             if 2 in option:
207 |                 num+=1
208 |                 image_tensor, mask_tensor = aug.flip(image, mask)
209 |                 image_flip = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir,'val',str(num)+'_flip.png'))
210 |                 mask_flip = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir,'val_mask',str(num)+'_flip.png'))
211 |             if 3 in option:
212 |                 num+=1
213 |                 image_tensor, mask_tensor = aug.adjustContrast(image, mask)
214 |                 image_Contrast = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'val', str(num) + '_Contrast.png'))
215 |                 mask_Contrast = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'val_mask', str(num) + '_Contrast.png'))
216 |             if 4 in option:
217 |                 num+=1
218 |                 image_tensor, mask_tensor = aug.adjustBrightness(image, mask)
219 |                 image_Brightness = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'val', str(num) + '_Brightness.png'))
220 |                 mask_Brightness = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'val_mask', str(num) + '_Brightness.png'))
221 |             if 5 in option:
222 |                 num+=1
223 |                 image_tensor, mask_tensor = aug.adjustSaturation(image, mask)
224 |                 image_Saturation = transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'val', str(num) + '_Saturation.png'))
225 |                 mask_Saturation = transforms.ToPILImage()(mask_tensor).save(os.path.join(save_dir, 'val_mask', str(num) + '_Saturation.png'))
226 | 
227 | def augmentation_dataset_create(gt_path,mask_files_list,save_path_list,train_num=70,test_num=20,val_num=10): # num * number = file count; each num should be a multiple of 5
228 |     for i in range(len(save_path_list)):
229 |         augmentationData_train(image_path=gt_path,
230 |                                mask_path=mask_files_list[i],
231 |                                save_dir=save_path_list[i],
232 |                                multiple=train_num)
233 |         augmentationData_test(image_path=gt_path,
234 |                               mask_path=mask_files_list[i],
235 |                               save_dir=save_path_list[i],
236 |                               multiple=test_num)
237 |         augmentationData_validation(image_path=gt_path,
238 |                                     mask_path=mask_files_list[i],
239 |                                     save_dir=save_path_list[i],
240 |                                     multiple=val_num)
241 | 
242 | def Augmentation_dataset(file_path,out_folder,json_path,label_number,
243 |                          train_num,test_num,val_num):
244 |     try:
245 |         os.makedirs(out_folder+'dataset')
246 |     except OSError:
247 |         pass
248 |     dataset_path=out_folder+'dataset/'
249 |     path=train_set_create(file_path,out_folder)
250 |     gt_path=path['label_path']
251 |     mask_path=json2dataset(json_path,label_number,out_folder)
252 |     mask_files=os.listdir(mask_path)
253 |     mask_files_list=[] # mask folder names
254 |     mask_path_list=[]
255 |     save_path_list=[] # paths where each dataset is saved
256 |     for file in mask_files:
257 |         mask_files_list.append(file)
258 |         mask_path_list.append(mask_path+file)
259 |     for file in mask_files:
260 |         save_path_list.append(dataset_path+file)
261 |     make_save_path(dataset_path,mask_files_list)
262 |     for i in range(len(save_path_list)):
263 |         make_dataset_dirs(save_path_list[i])
264 |     augmentation_dataset_create(gt_path,mask_path_list,save_path_list,
265 |                                 train_num=train_num,test_num=test_num,val_num=val_num)
266 |     return save_path_list
267 | 
268 | '''
269 | 3. Augmentation_dataset for few-shot learning
270 |    save_path_list=Augmentation_dataset(file_path,out_folder,json_path,label_number,train_num,test_num,val_num)
271 | '''
--------------------------------------------------------------------------------
/Disco_GAN.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import os
8 | import math
9 | import itertools
10 | import sys
11 | import datetime
12 | import time
13 | import torchvision.transforms as transforms
14 | from torchvision.utils import save_image
15 | from torch.utils.data import DataLoader
16 | from torchvision import datasets
17 | from torch.autograd import Variable
18 | from torch.utils.data import Dataset
19 | from PIL import Image
20 | 
21 | ############### Required user inputs ###############
22 | out_folder='H:/tmp/DEEP_WATERROCK_CODE/codetest/'
23 | dataset_name='pore2miropore'
24 | dataset_path='H:/tmp/DEEP_WATERROCK_CODE/pore2miropore'
25 | checkpoint_interval=10
26 | sample_interval=10
27 | n_epochs=20
28 | batch_size=1
29 | learning_rate=2e-4
30 | channels=1
31 | img_height=256
32 | img_width=256
33 | pre_trained=False
34 | trained_epoch=0
35 | ##########################################
36 | 
37 | class ImageDataset(Dataset):
38 |     def __init__(self, root, transforms_=None, mode='train'):
39 |         self.transform = transforms.Compose(transforms_)
40 | 
41 |         self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*'))
42 | 
43 |     def __getitem__(self, index):
44 | 
45 |         img = Image.open(self.files[index % len(self.files)])
46 |         w, h = img.size
47 |         img_A = img.crop((0, 0, w/2, h))
48 |         img_B = img.crop((w/2, 0, w, h))
49 | 
50 |         if np.random.random() < 0.5:
51 |             img_A = Image.fromarray(np.uint8(img_A)[ ::-1,:])
52 |             img_B = Image.fromarray(np.uint8(img_B)[ ::-1,:])
53 | 
54 |         img_A = self.transform(img_A)
55 |         img_B = self.transform(img_B)
56 | 
57 |         return {'A': img_A, 'B': img_B}
58 | 
59 |     def __len__(self):
60 |         return len(self.files)
61 | 
62 | def weights_init_normal(m):
63 |     classname = m.__class__.__name__
64 |     if classname.find("Conv") != -1:
65 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
66 |     elif classname.find("BatchNorm2d") != -1:
67 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
68 |         torch.nn.init.constant_(m.bias.data, 0.0)
69 | 
70 | ##############################
71 | #        U-NET
72 | ##############################
73 | class UNetDown(nn.Module):
74 |     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
75 |         super(UNetDown, self).__init__()
76 |         layers = [nn.Conv2d(in_size, out_size, 4, 2, 1)]
77 |         if normalize:
78 |             layers.append(nn.InstanceNorm2d(out_size))
79 |         layers.append(nn.LeakyReLU(0.2))
80 |         if dropout:
81 |             layers.append(nn.Dropout(dropout))
82 |         self.model = nn.Sequential(*layers)
83 | 
84 |     def forward(self, x):
85 |         return self.model(x)
86 | 
87 | class UNetUp(nn.Module):
88 |     def __init__(self, in_size, out_size, dropout=0.0):
89 |         super(UNetUp, self).__init__()
90 |         layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1), nn.InstanceNorm2d(out_size), nn.ReLU(inplace=True)]
91 |         if dropout:
92 |             layers.append(nn.Dropout(dropout))
93 | 
94 |         self.model = nn.Sequential(*layers)
95 | 
96 |     def forward(self, x, skip_input):
97 |         x = self.model(x)
98 |         x = torch.cat((x, skip_input), 1)
99 | 
100 |         return x
101 | 
102 | class GeneratorUNet(nn.Module):
103 |     def __init__(self, input_shape):
104 |         super(GeneratorUNet, self).__init__()
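        # only the channel count is used from input_shape; height and width must be
        # divisible by 2**6, since the encoder below downsamples six times with stride 2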
105 |         channels, _, _ = input_shape
106 |         self.down1 = UNetDown(channels, 64, normalize=False)
107 |         self.down2 = UNetDown(64, 128)
108 |         self.down3 = UNetDown(128, 256, dropout=0.5)
109 |         self.down4 = UNetDown(256, 512, dropout=0.5)
110 |         self.down5 = UNetDown(512, 512, dropout=0.5)
111 |         self.down6 = UNetDown(512, 512, dropout=0.5, normalize=False)
112 | 
113 |         self.up1 = UNetUp(512, 512, dropout=0.5)
114 |         self.up2 = UNetUp(1024, 512, dropout=0.5)
115 |         self.up3 = UNetUp(1024, 256, dropout=0.5)
116 |         self.up4 = UNetUp(512, 128)
117 |         self.up5 = UNetUp(256, 64)
118 | 
119 |         self.final = nn.Sequential(
120 |             nn.Upsample(scale_factor=2), nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(128, channels, 4, padding=1), nn.Tanh()
121 |         )
122 | 
123 |     def forward(self, x):
124 |         # U-Net generator with skip connections from encoder to decoder
125 |         d1 = self.down1(x)
126 |         d2 = self.down2(d1)
127 |         d3 = self.down3(d2)
128 |         d4 = self.down4(d3)
129 |         d5 = self.down5(d4)
130 |         d6 = self.down6(d5)
131 |         u1 = self.up1(d6, d5)
132 |         u2 = self.up2(u1, d4)
133 |         u3 = self.up3(u2, d3)
134 |         u4 = self.up4(u3, d2)
135 |         u5 = self.up5(u4, d1)
136 | 
137 |         return self.final(u5)
138 | 
139 | ##############################
140 | #       Discriminator
141 | ##############################
142 | class Discriminator(nn.Module):
143 |     def __init__(self, input_shape):
144 |         super(Discriminator, self).__init__()
145 | 
146 |         channels, height, width = input_shape
147 |         # Calculate output of image discriminator (PatchGAN)
148 |         self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)
149 | 
150 |         def discriminator_block(in_filters, out_filters, normalization=True):
151 |             """Returns downsampling layers of each discriminator block"""
152 |             layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
153 |             if normalization:
154 |                 layers.append(nn.InstanceNorm2d(out_filters))
155 |             layers.append(nn.LeakyReLU(0.2, inplace=True))
156 |             return layers
157 | 
158 |         self.model = nn.Sequential(
159 |             *discriminator_block(channels, 64, normalization=False),
160 |             *discriminator_block(64, 128),
161 |             *discriminator_block(128, 256),
162 |             nn.ZeroPad2d((1, 0, 1, 0)),
163 |             nn.Conv2d(256, 1, 4, padding=1)
164 |         )
165 | 
166 |     def forward(self, img):
167 |         # PatchGAN: one real/fake score per image patch
168 |         return self.model(img)
169 | 
170 | def Disco_GAN(out_folder,
171 |               dataset_name,
172 |               dataset_path,
173 |               checkpoint_interval,
174 |               sample_interval,
175 |               n_epochs,
176 |               batch_size,
177 |               lr,
178 |               channels,
179 |               img_height,
180 |               img_width,
181 |               pre_trained:bool,
182 |               trained_epoch):
183 |     # Create sample and checkpoint directories
184 |     save_path=out_folder+'Image2Image_translation/'
185 |     os.makedirs(save_path+"/disco_images/%s" % dataset_name, exist_ok=True)
186 |     os.makedirs(save_path+"/disco_saved_models/%s" %dataset_name, exist_ok=True)
187 |     disco_images_path=save_path+"/disco_images/%s/" % dataset_name
188 |     disco_saved_models_path=save_path+"/disco_saved_models/%s/" %dataset_name
189 |     # Losses
190 |     adversarial_loss = torch.nn.MSELoss()
191 |     cycle_loss = torch.nn.L1Loss()
192 |     pixelwise_loss = torch.nn.L1Loss()
193 |     cuda = torch.cuda.is_available()
194 | 
195 |     input_shape = (channels, img_height, img_width)
196 |     # Initialize generator and discriminator
197 |     G_AB = GeneratorUNet(input_shape)
198 |     G_BA = GeneratorUNet(input_shape)
199 |     D_A = Discriminator(input_shape)
200 |     D_B = Discriminator(input_shape)
201 |     if cuda:
202 |         G_AB = G_AB.cuda()
203 |         G_BA = G_BA.cuda()
204 |         D_A = D_A.cuda()
205 |         D_B = D_B.cuda()
206 |         adversarial_loss.cuda()
207 |         cycle_loss.cuda()
208 |         pixelwise_loss.cuda()
209 | 
210 |     if pre_trained==True:
211 |         if trained_epoch != 0:
212 |             G_AB.load_state_dict(torch.load(disco_saved_models_path+"G_AB_%d.pth" % (trained_epoch)))
213 |             G_BA.load_state_dict(torch.load(disco_saved_models_path+"G_BA_%d.pth" % (trained_epoch)))
214 |             D_A.load_state_dict(torch.load(disco_saved_models_path+"D_A_%d.pth" % (trained_epoch)))
215 |             D_B.load_state_dict(torch.load(disco_saved_models_path+"D_B_%d.pth" % (trained_epoch)))
216 |         else:
217 |             G_AB.apply(weights_init_normal)
218 |             G_BA.apply(weights_init_normal)
219 |             D_A.apply(weights_init_normal)
220 |             D_B.apply(weights_init_normal)
221 |     else:
222 |         G_AB.apply(weights_init_normal)
223 |         G_BA.apply(weights_init_normal)
224 |         D_A.apply(weights_init_normal)
225 |         D_B.apply(weights_init_normal)
226 | 
227 |     # Optimizers
228 |     optimizer_G = torch.optim.Adam(
229 |         itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(0.5, 0.999)
230 |     )
231 |     optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
232 |     optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))
233 |     Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
234 |     # Dataset loader
235 |     transforms_ = [
236 |         transforms.Resize((img_height, img_width), Image.BICUBIC),
237 |         transforms.ToTensor(),
238 |         #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
239 |         transforms.Normalize((0.1307,), (0.3081,))
240 |     ]
241 |     dataloader = DataLoader(
242 |         ImageDataset(dataset_path, transforms_=transforms_, mode="train"),
243 |         batch_size=batch_size,
244 |         shuffle=True,
245 |         num_workers=0,
246 |     )
247 |     val_dataloader = DataLoader(
248 |         ImageDataset(dataset_path , transforms_=transforms_, mode="val"),
249 |         batch_size=8,
250 |         shuffle=True,
251 |         num_workers=0,
252 |     )
253 | 
254 |     def sample_images(batches_done,path):
255 |         """Saves a generated sample from the validation set"""
256 |         imgs = next(iter(val_dataloader))
257 |         G_AB.eval()
258 |         G_BA.eval()
259 |         real_A = Variable(imgs["A"].type(Tensor))
260 |         fake_B = G_AB(real_A)
261 |         real_B = Variable(imgs["B"].type(Tensor))
262 |         fake_A = G_BA(real_B)
263 |         img_sample = torch.cat((real_A.data, fake_B.data, real_B.data, fake_A.data), 0)
264 |         save_image(img_sample, path+"/%s.png" % (batches_done), nrow=8, normalize=True)
265 | 
266 |     # ----------
267 |     #  Training
268 |     # ----------
269 |     epochs=0
270 |     prev_time = time.time()
271 |     for epoch in range(epochs,n_epochs):
272 |         for i, batch in enumerate(dataloader):
273 |             # Model inputs
274 |             real_A = Variable(batch["A"].type(Tensor))
275 |             real_B = Variable(batch["B"].type(Tensor))
276 |             # Adversarial ground truths
277 |             valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
278 |             fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
279 |             # ------------------
280 |             #  Train Generators
281 |             # ------------------
282 |             G_AB.train()
283 |             G_BA.train()
284 |             optimizer_G.zero_grad()
285 |             # GAN loss
286 |             fake_B = G_AB(real_A)
287 |             loss_GAN_AB = adversarial_loss(D_B(fake_B), valid)
288 |             fake_A = G_BA(real_B)
289 |             loss_GAN_BA = adversarial_loss(D_A(fake_A), valid)
290 |             loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
291 |             # Pixelwise translation loss
292 |             loss_pixelwise = (pixelwise_loss(fake_A, real_A) + pixelwise_loss(fake_B, real_B)) / 2
293 |             # Cycle loss
294 |             loss_cycle_A = cycle_loss(G_BA(fake_B), real_A)
295 |             loss_cycle_B = cycle_loss(G_AB(fake_A), real_B)
296 |             loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
297 |             # Total loss
298 |             loss_G = loss_GAN + loss_cycle + loss_pixelwise
299 |             loss_G.backward()
300 |             optimizer_G.step()
301 |             # -----------------------
302 |             #  Train Discriminator A
303 |             # -----------------------
304 |             optimizer_D_A.zero_grad()
305 |             # Real loss
306 |             loss_real = adversarial_loss(D_A(real_A), valid)
307 |             # Fake loss (on batch of previously generated samples)
308 |             loss_fake = adversarial_loss(D_A(fake_A.detach()), fake)
309 |             # Total loss
310 |             loss_D_A = (loss_real + loss_fake) / 2
311 |             loss_D_A.backward()
312 |             optimizer_D_A.step()
313 |             # -----------------------
314 |             #  Train Discriminator B
315 |             # -----------------------
316 |             optimizer_D_B.zero_grad()
317 |             # Real loss
318 |             loss_real = adversarial_loss(D_B(real_B), valid)
319 |             # Fake loss (on batch of previously generated samples)
320 |             loss_fake = adversarial_loss(D_B(fake_B.detach()), fake)
321 |             # Total loss
322 |             loss_D_B = (loss_real + loss_fake) / 2
323 |             loss_D_B.backward()
324 |             optimizer_D_B.step()
325 |             loss_D = 0.5 * (loss_D_A + loss_D_B)
326 |             # --------------
327 |             #  Log Progress
328 |             # --------------
329 |             # Determine approximate time left
330 |             batches_done = epoch * len(dataloader) + i
331 |             batches_left = n_epochs * len(dataloader) - batches_done
332 |             time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
333 |             prev_time = time.time()
334 |             # Print log
335 |             print(
336 |                 "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, pixel: %f, cycle: %f] ETA: %s"
337 |                 % (
338 |                     epoch,
339 |                     n_epochs,
340 |                     i,
341 |                     len(dataloader),
342 |                     loss_D.item(),
343 |                     loss_G.item(),
344 |                     loss_GAN.item(),
345 |                     loss_pixelwise.item(),
346 |                     loss_cycle.item(),
347 |                     time_left,
348 |                 )
349 |             )
350 |             f=open(save_path+'disco_process.txt','a')
351 |             f.write(
352 |                 "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, pixel: %f, cycle: %f] ETA: %s"
353 |                 % (
354 |                     epoch,
355 |                     n_epochs,
356 |                     i,
357 |                     len(dataloader),
358 |                     loss_D.item(),
359 |                     loss_G.item(),
360 |                     loss_GAN.item(),
361 |                     loss_pixelwise.item(),
362 |                     loss_cycle.item(),
363 |                     time_left,
364 |                 )
365 |             )
366 |             f.close()
367 |             # If at sample interval save image
368 |             if batches_done % sample_interval == 0:
369 |                 sample_images(batches_done,disco_images_path)
370 |         if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
371 |             # Save model checkpoints
372 |             torch.save(G_AB.state_dict(), disco_saved_models_path+"G_AB_%d.pth" % (epoch))
373 |             torch.save(G_BA.state_dict(), disco_saved_models_path+"G_BA_%d.pth" % (epoch))
374 |             torch.save(D_A.state_dict(), disco_saved_models_path+"D_A_%d.pth" % (epoch))
375 |             torch.save(D_B.state_dict(), disco_saved_models_path+"D_B_%d.pth" % (epoch))
376 | 
377 | '''
378 | 13. Disco_GAN generate from promoted translation style
379 |     Disco_GAN(out_folder=out_folder,
380 |               dataset_name=dataset_name,
381 |               dataset_path=dataset_path,
382 |               checkpoint_interval=checkpoint_interval,
383 |               sample_interval=sample_interval,
384 |               n_epochs=n_epochs,
385 |               batch_size= batch_size,
386 |               lr=learning_rate,
387 |               channels=channels,
388 |               img_height=img_height,
389 |               img_width=img_width,
390 |               pre_trained=pre_trained,
391 |               trained_epoch=trained_epoch
392 |               )
393 | '''
--------------------------------------------------------------------------------
/Dual_GAN.py:
--------------------------------------------------------------------------------
1 | # Dual_GAN: generate slices in a different style
2 | 
3 | import glob
4 | import random
5 | import os
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch
10 | from torchvision.models import vgg19
11 | import math
12 | from torch.utils.data import Dataset
13 | from PIL import Image
14 | import torchvision.transforms as transforms
15 | import itertools
16 | import scipy
17 | import sys
18 | import time
19 | import datetime
20 | from torchvision.utils import save_image
21 | from torch.utils.data import DataLoader
22 | from torchvision import datasets
23 | from torch.autograd import Variable
24 | import torch.autograd as autograd
25 | 
26 | ############### Required user inputs ###############
27 | out_folder='H:/tmp/DEEP_WATERROCK_CODE/codetest/'
28 | dataset_name='pore2miropore'
29 | dataset_path='H:/tmp/DEEP_WATERROCK_CODE/pore2miropore'
30 | checkpoint_interval=10
31 | sample_interval=10
32 | n_epochs=20
33 | batch_size=8
34 | learning_rate=2e-4
35 | channels=1
36 | generate_image_size=128
37 | pre_trained=False
38 | trained_epoch=0
39 | ##########################################
40 | 
41 | class ImageDataset(Dataset):
42 |     def __init__(self, root, transforms_=None, mode="train"):
43 |         self.transform = transforms.Compose(transforms_)
44 | 
45 |         self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
46 | 
47 |     def __getitem__(self, index):
48 | 
49 |         img = Image.open(self.files[index % len(self.files)])
50 |         w, h = img.size
51 |         img_A = img.crop((0, 0, w / 2, h))
52 |         img_B = img.crop((w / 2, 0, w, h))
53 | 
54 |         if np.random.random() < 0.5:
55 |             img_A = Image.fromarray(np.uint8(img_A)[ ::-1,:])
56 |             img_B = Image.fromarray(np.uint8(img_B)[ ::-1,:])
57 | 
58 |         img_A = self.transform(img_A)
59 |         img_B = self.transform(img_B)
60 | 
61 |         return {"A": img_A, "B": img_B}
62 | 
63 |     def __len__(self):
64 |         return len(self.files)
65 | 
66 | def weights_init_normal(m):
67 |     classname = m.__class__.__name__
68 |     if classname.find("Conv") != -1:
69 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
70 |     elif classname.find("BatchNorm2d") != -1:
71 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
72 |         torch.nn.init.constant_(m.bias.data, 0.0)
73 | 
74 | ##############################
75 | #        U-NET
76 | ##############################
77 | 
78 | class UNetDown(nn.Module):
79 |     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
80 |         super(UNetDown, self).__init__()
81 |         layers = [nn.Conv2d(in_size, out_size, 4, stride=2, padding=1, bias=False)]
82 |         if normalize:
83 |             layers.append(nn.InstanceNorm2d(out_size, affine=True))
84 |         layers.append(nn.LeakyReLU(0.2))
85 |         if dropout:
86 |             layers.append(nn.Dropout(dropout))
87 |         self.model = nn.Sequential(*layers)
88 | 
89 |     def forward(self, x):
90 |         return self.model(x)
91 | 
92 | class UNetUp(nn.Module):
93 |     def __init__(self, in_size, out_size, dropout=0.0):
94 |         super(UNetUp, self).__init__()
95 |         layers = [
96 |             nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, bias=False),
97 |             nn.InstanceNorm2d(out_size, affine=True),
98 |             nn.ReLU(inplace=True),
99 |         ]
100 |         if dropout:
101 |             layers.append(nn.Dropout(dropout))
102 | 
103 |         self.model = nn.Sequential(*layers)
104 | 
105 |     def forward(self, x, skip_input):
106 |         x = self.model(x)
107 |         x = torch.cat((x, skip_input), 1)
108 | 
109 |         return x
110 | 
111 | class Generator(nn.Module):
112 |     def __init__(self, channels=1):
113 |         super(Generator, self).__init__()
114 | 
115 |         self.down1 = UNetDown(channels, 64, normalize=False)
116 |         self.down2 = UNetDown(64, 128)
117 |         self.down3 = UNetDown(128, 256)
118 |         self.down4 = UNetDown(256, 512, dropout=0.5)
119 |         self.down5 = UNetDown(512, 512, dropout=0.5)
120 |         self.down6 = UNetDown(512, 512, dropout=0.5)
121 |         self.down7 = UNetDown(512, 512, dropout=0.5, normalize=False)
122 | 
123 |         self.up1 = UNetUp(512, 512, dropout=0.5)
124 |         self.up2 = UNetUp(1024, 512, dropout=0.5)
125 |         self.up3 = UNetUp(1024, 512, dropout=0.5)
126 |         self.up4 = UNetUp(1024, 256)
127 |         self.up5 = UNetUp(512, 128)
128 |         self.up6 = UNetUp(256, 64)
129 | 
130 |         self.final = nn.Sequential(nn.ConvTranspose2d(128, channels, 4, stride=2, padding=1), nn.Tanh())
131 | 
132 |     def forward(self, x):
133 |         # U-Net forward pass with skip connections from encoder to decoder
134 |         d1 = self.down1(x)
135 |         d2 = self.down2(d1)
136 |         d3 = self.down3(d2)
137 |         d4 = self.down4(d3)
138 |         d5 = self.down5(d4)
139 |         d6 = self.down6(d5)
140 |         d7 = self.down7(d6)
141 |         u1 = self.up1(d7, d6)
142 |         u2 = self.up2(u1, d5)
143 |         u3 = self.up3(u2, d4)
144 |         u4 = self.up4(u3, d3)
145 |         u5 = self.up5(u4, d2)
146 |         u6 = self.up6(u5, d1)
147 | 
148 |         return self.final(u6)
149 | 
150 | ##############################
151 | #       Discriminator
152 | ##############################
153 | 
154 | class Discriminator(nn.Module):
155 |     def __init__(self, in_channels=1):
156 |         super(Discriminator, self).__init__()
157 | 
158 |         def discriminator_block(in_features, out_features, normalize=True):
159 |             """Discriminator block"""
160 |             layers = [nn.Conv2d(in_features, out_features, 4, stride=2, padding=1)]
161 |             if normalize:
162 |                 layers.append(nn.BatchNorm2d(out_features, 0.8))
163 |             layers.append(nn.LeakyReLU(0.2, inplace=True))
164 |             return layers
165 | 
166 |         self.model = nn.Sequential(
167 |             *discriminator_block(in_channels, 64, normalize=False),
168 |             *discriminator_block(64, 128),
169 |             *discriminator_block(128, 256),
170 |             nn.ZeroPad2d((1, 0, 1, 0)),
171 |             nn.Conv2d(256, 1, kernel_size=4)
172 |         )
173 | 
174 |     def forward(self, img):
175 |         return self.model(img)
176 | 
177 | 
178 | 
179 | def Dual_GAN(out_folder,dataset_name,dataset_path,
180 |              checkpoint_interval,sample_interval,
181 |              n_epochs,batch_size,lr,img_size,channels,
182 |              pre_trained:bool,trained_epoch):
183 | 
184 |     def compute_gradient_penalty(D, real_samples, fake_samples):
185 |         """Calculates the gradient penalty loss for WGAN GP"""
186 |         # Random weight term for interpolation between real and fake samples
187 |         alpha = FloatTensor(np.random.random((real_samples.size(0), 1, 1, 1)))
188 |         # Get random interpolation between real and fake samples
189 |         interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
190 |         validity = D(interpolates)
191 |         fake = Variable(FloatTensor(np.ones(validity.shape)), requires_grad=False)
192 |         # Get gradient w.r.t. interpolates
193 |         gradients = autograd.grad(
194 |             outputs=validity,
195 |             inputs=interpolates,
196 |             grad_outputs=fake,
197 |             create_graph=True,
198 |             retain_graph=True,
199 |             only_inputs=True,
200 |         )[0]
201 |         gradients = gradients.view(gradients.size(0), -1)
202 |         gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
203 |         return gradient_penalty
204 | 
205 |     def sample_images(batches_done,path):
206 |         """Saves a generated sample from the test set"""
207 |         imgs = next(iter(val_dataloader))
208 |         real_A = Variable(imgs["A"].type(FloatTensor))
209 |         fake_B = G_AB(real_A)
210 |         AB = torch.cat((real_A.data, fake_B.data), -2)
211 |         real_B = Variable(imgs["B"].type(FloatTensor))
212 |         fake_A = G_BA(real_B)
213 |         BA = torch.cat((real_B.data, fake_A.data), -2)
214 |         img_sample = torch.cat((AB, BA), 0)
215 |         save_image(img_sample, path+"/%s.png" % (batches_done), nrow=8, normalize=True)
216 | 
217 |     save_path=out_folder+'Image2Image_translation/'
218 |     os.makedirs(save_path+"/dual_images/%s" % dataset_name, exist_ok=True)
219 |     os.makedirs(save_path+"/dual_saved_models/%s" %dataset_name, exist_ok=True)
220 |     dual_images_path=save_path+"/dual_images/%s/" % dataset_name
221 |     dual_saved_models_path=save_path+"/dual_saved_models/%s/" %dataset_name
222 | 
223 |     epoch=0
224 |     b1=0.5
225 |     b2=0.999
226 |     n_cpu=0
227 |     channels=1
228 |     n_critic=5
229 |     img_shape = (channels, img_size, img_size)
230 |     cuda = True if torch.cuda.is_available() else False
231 |     # Loss function
232 |     cycle_loss = torch.nn.L1Loss()
233 |     # Loss weights
234 |     lambda_adv = 1
235 |     lambda_cycle = 10
236 |     lambda_gp = 10
237 |     # Initialize generator and discriminator
238 |     G_AB = Generator()
239 |     G_BA = Generator()
240 |     D_A = Discriminator()
241 |     D_B = Discriminator()
242 |     if cuda:
243 |         G_AB.cuda()
244 |         G_BA.cuda()
245 |         D_A.cuda()
246 |         D_B.cuda()
247 |         cycle_loss.cuda()
248 |     if pre_trained==True:
249 |         if trained_epoch != 0:
250 |             G_AB.load_state_dict(torch.load(dual_saved_models_path+"G_AB_%d.pth" % (trained_epoch)))
251 |             G_BA.load_state_dict(torch.load(dual_saved_models_path+"G_BA_%d.pth" % (trained_epoch)))
252 |             D_A.load_state_dict(torch.load(dual_saved_models_path+"D_A_%d.pth" % (trained_epoch)))
253 |             D_B.load_state_dict(torch.load(dual_saved_models_path+"D_B_%d.pth" % (trained_epoch)))
254 |         else:
255 |             G_AB.apply(weights_init_normal)
256 |             G_BA.apply(weights_init_normal)
257 |             D_A.apply(weights_init_normal)
258 |             D_B.apply(weights_init_normal)
259 |     else:
260 |         G_AB.apply(weights_init_normal)
261 |         G_BA.apply(weights_init_normal)
262 |         D_A.apply(weights_init_normal)
263 |         D_B.apply(weights_init_normal)
264 | 
265 |     # Configure data loader
266 |     transforms_ = [
267 |         transforms.Resize((img_size, img_size), Image.BICUBIC),
268 |         transforms.ToTensor(),
269 |         #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
270 |         transforms.Normalize((0.1307,), (0.3081,))
271 |     ]
272 |     dataloader = DataLoader(
273 |         ImageDataset(dataset_path,mode='train', transforms_=transforms_),
274 |         batch_size=batch_size,
275 |         shuffle=True,
276 |         num_workers=n_cpu,
277 |     )
278 |     val_dataloader = DataLoader(
279 |         ImageDataset(dataset_path, mode="val", transforms_=transforms_),
280 |         batch_size=3,
281 |         shuffle=True,
282 |         num_workers=0,
283 |     )
284 |     # Optimizers
285 |     optimizer_G = torch.optim.Adam(
286 |         itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
287 |     )
288 |     optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
289 |     optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))
290 |     FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
291 |     LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
292 |     # ----------
293 |     #  Training
294 |     # ----------
295 |     batches_done = 0
296 |     prev_time = time.time()
297 |     for epoch in range(0,n_epochs):
298 |         for i, batch in enumerate(dataloader):
299 |             # Configure input
300 |             imgs_A = Variable(batch["A"].type(FloatTensor))
301 |             imgs_B = Variable(batch["B"].type(FloatTensor))
302 |             # ----------------------
303 |             #  Train Discriminators
304 |             # ----------------------
305 |             optimizer_D_A.zero_grad()
306 |             optimizer_D_B.zero_grad()
307 |             # Generate a batch of images
308 |             fake_A = G_BA(imgs_B).detach()
309 |             fake_B = G_AB(imgs_A).detach()
310 |             # ----------
311 |             #  Domain A
312 |             # ----------
313 |             # Compute gradient penalty for improved wasserstein training
314 |             gp_A = compute_gradient_penalty(D_A, imgs_A.data, fake_A.data)
315 |             # Adversarial loss
316 |             D_A_loss = -torch.mean(D_A(imgs_A)) + torch.mean(D_A(fake_A)) + lambda_gp * gp_A
317 |             # ----------
318 |             #  Domain B
319 |             # ----------
320 |             # Compute gradient penalty for improved wasserstein training
321 |             gp_B = compute_gradient_penalty(D_B, imgs_B.data, fake_B.data)
322 |             # Adversarial loss
323 |             D_B_loss = -torch.mean(D_B(imgs_B)) + torch.mean(D_B(fake_B)) + lambda_gp * gp_B
324 |             # Total loss
325 |             D_loss = D_A_loss + D_B_loss
326 |             D_loss.backward()
327 |             optimizer_D_A.step()
328 |             optimizer_D_B.step()
329 |             if i % n_critic == 0:
330 |                 # ------------------
331 |                 #  Train Generators
332 |                 # ------------------
333 |                 optimizer_G.zero_grad()
334 |                 # Translate images to opposite domain
335 |                 fake_A = G_BA(imgs_B)
336 |                 fake_B = G_AB(imgs_A)
337 |                 # Reconstruct images
338 |                 recov_A = G_BA(fake_B)
339 |                 recov_B = G_AB(fake_A)
340 |                 # Adversarial loss
341 |                 G_adv = -torch.mean(D_A(fake_A)) - torch.mean(D_B(fake_B))
342 |                 # Cycle loss
343 |                 G_cycle = cycle_loss(recov_A, imgs_A) + cycle_loss(recov_B, imgs_B)
344 |                 # Total loss
345 |                 G_loss = lambda_adv * G_adv + lambda_cycle * G_cycle
346 |                 G_loss.backward()
347 |                 optimizer_G.step()
348 |                 # --------------
349 |                 #  Log Progress
350 |                 # --------------
351 |                 # Determine approximate time left
352 |                 batches_left = n_epochs * len(dataloader) - batches_done
353 |                 time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time) / n_critic)
354 |                 prev_time = time.time()
355 |                 print(
356 |                     "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, cycle: %f] ETA: %s"
357 |                     % (
358 |                         epoch,
359 |                         n_epochs,
360 |                         i,
361 |                         len(dataloader),
362 |                         D_loss.item(),
363 |                         G_adv.data.item(),
364 |                         G_cycle.item(),
365 |                         time_left,
366 |                     )
367 |                 )
368 |                 f=open(save_path+'dual_process.txt','a')
369 |                 f.write(
370 |                     "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, cycle: %f] ETA: %s"
371 |                     % (
372 |                         epoch,
373 |                         n_epochs,
374 |                         i,
375 |                         len(dataloader),
376 |                         D_loss.item(),
377 |                         G_adv.data.item(),
378 |                         G_cycle.item(),
379 |                         time_left,
380 |                     )
381 |                 )
382 |                 f.close()
383 |             # Save a sample every `sample_interval` batches
384 |             if batches_done % sample_interval == 0:
385 |                 sample_images(batches_done,path=dual_images_path)
386 |             batches_done += 1
387 |         if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
388 |             # Save model checkpoints
389 |             torch.save(G_AB.state_dict(), dual_saved_models_path+"G_AB_%d.pth" % (epoch))
390 |             torch.save(G_BA.state_dict(), dual_saved_models_path+"G_BA_%d.pth" % (epoch))
(epoch))
392 |     torch.save(D_B.state_dict(), dual_saved_models_path+"D_B_%d.pth" % (epoch))
393 | 
394 | '''
395 | 12. Dual_GAN: train and generate with the promoted translation style
396 | Dual_GAN(out_folder=out_folder,
397 |          dataset_name=dataset_name,
398 |          dataset_path=dataset_path,
399 |          checkpoint_interval=checkpoint_interval,
400 |          sample_interval=sample_interval,
401 |          n_epochs=n_epochs,
402 |          batch_size=batch_size,
403 |          lr=learning_rate,
404 |          img_size=generate_image_size,
405 |          channels=channels,
406 |          pre_trained=pre_trained,
407 |          trained_epoch=trained_epoch)
408 | '''
409 | 
--------------------------------------------------------------------------------
/DCGAN.py:
--------------------------------------------------------------------------------
1 | #DCGAN train and generate
2 | # 1. train
3 | # 2. generate
4 | import tifffile
5 | import h5py
6 | import torch.utils.data
7 | from torch import Tensor
8 | from os import listdir
9 | from os.path import join
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import os
15 | import random
16 | import torch.backends.cudnn as cudnn
17 | import torch.optim as optim
18 | import torchvision.datasets as dset
19 | import torchvision.transforms as transforms
20 | import torchvision.utils as vutils
21 | from torch.autograd import Variable
22 | 
23 | from scipy.ndimage.filters import median_filter
24 | from skimage.filters import threshold_otsu
25 | from collections import Counter
26 | 
27 | 
28 | #################### REQUIRED INPUTS FOR TRAINING ####################
29 | out_folder = 'H:/tmp/DEEP_WATERROCK_CODE/codetest/'   # data storage path
30 | dataset_name='berea'
31 | device='cuda'
32 | manualSeed=43
33 | imageSize=64
34 | batchSize=32
35 | number_generator_feature=64
36 | number_discriminator_feature=32
37 | number_z=512
38 | number_train_iterations=10
39 | number_gpu=1
40 | #################### REQUIRED INPUTS FOR GENERATION ##################
41 | seedmin=62
42 | seedmax=64
43 | netG='H:/tmp/DEEP_WATERROCK_CODE/codetest/DCGAN/result/netG/netG_epoch_9.pth'
44 | generate_name='test'
45 | image_generate_size=4
46 | ##############################################################
47 | # Discriminator
48 | class DCGAN3d_D(nn.Container):
49 |     def __init__(self,
50 |                  image_size,            # size of the volume fed to the discriminator
51 |                  dimension_n,           # nz, dimensionality of the latent space
52 |                  channel_in,            # nc, number of input channels
53 |                  D_feature_number,      # ndf, initial feature count of the discriminator
54 |                  gpu_number,
55 |                  extra_layers_number=0):
56 |         super(DCGAN3d_D,self).__init__()
57 |         self.gpu_number=gpu_number
58 |         assert image_size % 16 ==0,'image size has to be a multiple of 16'
59 | 
60 |         D=nn.Sequential(
61 |             nn.Conv3d(channel_in,D_feature_number,4,2,1,bias=False),
62 |             nn.LeakyReLU(0.2,inplace=True),
63 |         )
64 |         i=3
65 |         next_size=image_size//2   # integer division keeps the size bookkeeping exact
66 |         next_D_feature_number=D_feature_number
67 | 
68 |         # build the optional extra layers
69 |         for t in range(extra_layers_number):
70 |             D.add_module(str(i),
71 |                          nn.Conv3d(next_D_feature_number,
72 |                                    next_D_feature_number,
73 |                                    3,1,1,bias=False))
74 |             D.add_module(str(i+1),
75 |                          nn.BatchNorm3d(next_D_feature_number))
76 |             D.add_module(str(i+2),
77 |                          nn.LeakyReLU(0.2,inplace=True))
78 |             i+=3
79 |         while next_size>4:
80 |             in_feat=next_D_feature_number
81 |             out_feat=next_D_feature_number * 2
82 |             D.add_module(str(i),
83 |                          nn.Conv3d(in_feat,out_feat,4,2,1,bias=False))
84 |             D.add_module(str(i+1),
85 |                          nn.BatchNorm3d(out_feat))
86 |             D.add_module(str(i+2),
87 |                          nn.LeakyReLU(0.2,inplace=True))
88 |             i+=3
89 |             next_D_feature_number=next_D_feature_number * 2
90 |             next_size=next_size//2
91 |         D.add_module(str(i),
92 |                      nn.Conv3d(next_D_feature_number,1,4,1,0,bias=False))
93 |         D.add_module(str(i+1),
94 |                      nn.Sigmoid())
95 |         self.D=D
96 | 
97 |     def forward(self,input):
98 |         gpu_ids=None
99 |         if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_number > 1:
100 |             gpu_ids = range(self.gpu_number)
101 |         output=nn.parallel.data_parallel(self.D,input,gpu_ids)
102 |         return output.view(-1,1)
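# -----------------------------------------------------------------------------
# Illustrative aside, not part of the original script: a minimal shape check
# for the discriminator above, assuming the configured imageSize=64 and one
# input channel. The conv stack halves the volume 64 -> 32 -> 16 -> 8 -> 4,
# and the final 4x4x4 convolution collapses it to one sigmoid score per
# sample. (Run on a CUDA device, since forward routes through data_parallel.)
#
#   netD_demo = DCGAN3d_D(image_size=64, dimension_n=512, channel_in=1,
#                         D_feature_number=32, gpu_number=1)
#   scores = netD_demo(torch.randn(2, 1, 64, 64, 64))
#   assert scores.shape == (2, 1)
# -----------------------------------------------------------------------------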
103 | # Generator
104 | class DCGAN3d_G(nn.Container):
105 |     def __init__(self,
106 |                  image_size,
107 |                  dimension_n,
108 |                  channel_in,
109 |                  G_feature_number,      # ngf, initial feature count of the generator
110 |                  gpu_number,
111 |                  extra_layers_number=0):
112 |         super(DCGAN3d_G,self).__init__()
113 |         self.gpu_number=gpu_number
114 |         assert image_size % 16 ==0, "image size has to be a multiple of 16"
115 | 
116 |         next_G_feature_number=G_feature_number//2
117 |         end_image_size=4
118 | 
119 |         while end_image_size!=image_size:
120 |             next_G_feature_number=next_G_feature_number * 2
121 |             end_image_size = end_image_size * 2
122 | 
123 |         G=nn.Sequential(
124 |             nn.ConvTranspose3d(dimension_n,next_G_feature_number,4,1,0,bias=False),
125 |             nn.BatchNorm3d(next_G_feature_number),
126 |             nn.ReLU(True),
127 |         )
128 |         i=3
129 |         next_size=4
130 |         next_G_feature_number=next_G_feature_number
131 | 
132 |         # NOTE: the upsampling stack below was garbled in this copy of the
133 |         # file; it is reconstructed following the pattern of DCGAN3D_G_CPU.
134 |         while next_size<image_size//2:
135 |             G.add_module(str(i),
136 |                          nn.ConvTranspose3d(next_G_feature_number,next_G_feature_number//2,4,2,1,bias=False))
137 |             G.add_module(str(i+1),
138 |                          nn.BatchNorm3d(next_G_feature_number//2))
139 |             G.add_module(str(i+2),
140 |                          nn.ReLU(True))
141 |             i+=3
142 |             next_G_feature_number=next_G_feature_number//2
143 |             next_size=next_size*2
144 | 
145 |         # build the optional extra layers
146 |         for t in range(extra_layers_number):
147 |             G.add_module(str(i),
148 |                          nn.Conv3d(next_G_feature_number,next_G_feature_number,3,1,1,bias=False))
149 |             G.add_module(str(i+1),
150 |                          nn.BatchNorm3d(next_G_feature_number))
151 |             G.add_module(str(i+2),
152 |                          nn.ReLU(True))
153 |             i+=3
154 | 
155 |         G.add_module(str(i),
156 |                      nn.ConvTranspose3d(next_G_feature_number,channel_in,4,2,1,bias=False))
157 |         G.add_module(str(i+1),
158 |                      nn.Tanh())
159 |         self.G=G
160 | 
161 |     def forward(self,input):
162 |         gpu_ids=None
163 |         if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_number > 1:
164 |             gpu_ids = range(self.gpu_number)
165 |         return nn.parallel.data_parallel(self.G, input, gpu_ids)
166 | 
167 | class DCGAN3D_G_CPU(nn.Container):
168 |     def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
169 |         super(DCGAN3D_G_CPU, self).__init__()
170 |         self.ngpu = ngpu
171 |         assert isize % 16 == 0, "isize has to be a multiple of 16"
172 | 
173 |         cngf, tisize = ngf//2, 4
174 |         while tisize != isize:
175 |             cngf = cngf * 2
176 |             tisize = tisize * 2
177 | 
178 |         main = nn.Sequential(
179 |             # input is Z, going into a convolution
180 |             nn.ConvTranspose3d(nz, cngf, 4, 1, 0, bias=True),
181 |             nn.BatchNorm3d(cngf),
182 |             nn.ReLU(True),
183 |         )
184 | 
185 |         i, csize, cndf = 3, 4, cngf
186 |         while csize < isize//2:
187 |             main.add_module(str(i),
188 |                             nn.ConvTranspose3d(cngf, cngf//2, 4, 2, 1, bias=True))
189 |             main.add_module(str(i+1),
190 |                             nn.BatchNorm3d(cngf//2))
191 |             main.add_module(str(i+2),
192 |                             nn.ReLU(True))
193 |             i += 3
194 |             cngf = cngf // 2
195 |             csize = csize * 2
196 | 
197 |         # Extra layers
198 |         for t in range(n_extra_layers):
199 |             main.add_module(str(i),
200 |                             nn.Conv3d(cngf, cngf, 3, 1, 1, bias=True))
201 |             main.add_module(str(i+1),
202 |                             nn.BatchNorm3d(cngf))
203 |             main.add_module(str(i+2),
204 |                             nn.ReLU(True))
205 |             i += 3
206 | 
207 |         main.add_module(str(i),
208 |                         nn.ConvTranspose3d(cngf, nc, 4, 2, 1, bias=True))
209 |         main.add_module(str(i+1), nn.Tanh())
210 |         self.main = main
211 | 
212 |     def forward(self, input):
213 |         return self.main(input)
214 | 
215 | def save_hdf5(tensor, filename):
216 |     tensor = tensor.cpu()
217 |     ndarr = tensor.mul(0.5).add(0.5).mul(255).byte().numpy()
218 |     with h5py.File(filename, 'w') as f:
219 |         f.create_dataset('data', data=ndarr, dtype="i8", compression="gzip")
220 | 
221 | def is_image_file(filename):
222 |     return any(filename.endswith(extension) for extension in [".hdf5", ".h5"])
223 | 
224 | def load_img(filepath):
225 |     img = None
226 |     with h5py.File(filepath, "r") as f:
227 |         img = f['data'][()]
228 |     img = np.expand_dims(img, axis=0)
229 |     torch_img = Tensor(img)
230 |     torch_img = torch_img.div(255).sub(0.5).div(0.5)
231 |     return torch_img
232 | 
233 | def weights_init(m):
234 |     classname = m.__class__.__name__
235 |     if classname.find('Conv') != -1:
236 |         m.weight.data.normal_(0.0, 0.02)
237 |     elif
classname.find('BatchNorm') != -1: 238 | m.weight.data.normal_(1.0, 0.02) 239 | m.bias.data.fill_(0) 240 | 241 | class HDF5Dataset(torch.utils.data.Dataset): 242 | def __init__(self, image_dir, input_transform=None, target_transform=None): 243 | super(HDF5Dataset, self).__init__() 244 | self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] 245 | self.input_transform = input_transform 246 | self.target_transform = target_transform 247 | 248 | def __getitem__(self, index): 249 | input = load_img(self.image_filenames[index]) 250 | target = None 251 | 252 | return input 253 | 254 | def __len__(self): 255 | return len(self.image_filenames) 256 | 257 | ###creat training images 258 | def train_dataset_preprocess(dataset_name): 259 | tiff_path=out_folder+dataset_name+'.tif' 260 | edge_length=64 261 | stride=32 262 | train_images_path=out_folder+'DCGAN/train_images/' 263 | try: 264 | os.makedirs(out_folder+'DCGAN/train_images') 265 | except OSError: 266 | pass 267 | img=tifffile.imread(tiff_path) 268 | N = edge_length 269 | M = edge_length 270 | O = edge_length 271 | I_inc = stride 272 | J_inc = stride 273 | K_inc = stride 274 | count = 0 275 | for i in range(0, img.shape[0], I_inc): 276 | for j in range(0, img.shape[1], J_inc): 277 | for k in range(0, img.shape[2], K_inc): 278 | subset = img[i:i+N, j:j+N, k:k+O] 279 | if subset.shape == (N, M, O): 280 | f = h5py.File(train_images_path+"/"+str(dataset_name)+"_"+str(count)+".hdf5", "w") 281 | f.create_dataset('data', data=subset, dtype="i8", compression="gzip") 282 | f.close() 283 | count += 1 284 | print('Generate images/dataset number count:',count) 285 | return train_images_path 286 | 287 | def DCGAN_train(imageSize, 288 | batchSize, 289 | ngf, 290 | ndf, 291 | nz, 292 | niter, 293 | ngpu, 294 | manualSeed, 295 | out_folder, 296 | dataset_name, 297 | device): 298 | data_root=train_dataset_preprocess(dataset_name) 299 | lr=1e-5 300 | workers=0 301 | nc=1 302 | criterion=nn.BCELoss() 303 | result_path=out_folder+'DCGAN/result/' 304 | outf=out_folder+'DCGAN/output/' 305 | try: 306 | os.makedirs(out_folder+'DCGAN/output') 307 | os.makedirs(out_folder+'DCGAN/result') 308 | except OSError: 309 | pass 310 | np.random.seed(43) 311 | random.seed(manualSeed) 312 | torch.manual_seed(manualSeed) 313 | cudnn.benchmark=True 314 | if torch.cuda.is_available() and device!='cuda': 315 | print("WARNING: You have a CUDA device, so you should probably run with device='cuda'") 316 | if dataset_name in ['berea']: 317 | dataset=HDF5Dataset(data_root, 318 | input_transform=transforms.Compose([transforms.ToTensor()])) 319 | assert dataset 320 | dataloader=torch.utils.data.DataLoader(dataset,batch_size=batchSize,shuffle=True,num_workers=int(workers)) 321 | 322 | netG=DCGAN3d_G(imageSize,nz,nc,ngf,ngpu) 323 | netG.apply(weights_init) 324 | print(netG) 325 | netD=DCGAN3d_D(imageSize,nz,nc,ndf,ngpu) 326 | netD.apply(weights_init) 327 | print(netD) 328 | 329 | input,noise,fixed_noise,fixed_noise_TI=None,None,None,None 330 | input=torch.FloatTensor(batchSize,nc,imageSize,imageSize,imageSize) 331 | noise=torch.FloatTensor(batchSize,nz,1,1,1) 332 | fixed_noise=torch.FloatTensor(1,nz,7,7,7).normal_(0,1) 333 | fixed_noise_TI=torch.FloatTensor(1,nz,1,1,1).normal_(0,1) 334 | label=torch.FloatTensor(batchSize) 335 | real_label=0.9 336 | fake_label=0 337 | 338 | if device=='cuda': 339 | netD.cuda() 340 | netG.cuda() 341 | criterion.cuda() 342 | input, label = input.cuda(), label.cuda() 343 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 
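        # Illustrative aside, not part of the original script: training noise
        # is a 1x1x1 latent grid per sample (yielding 64^3 volumes), while
        # `fixed_noise` uses a 7x7x7 latent grid. Because the generator is
        # fully convolutional, the same trained weights synthesize a much
        # larger monitoring volume from the bigger latent grid at sample time.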
344 |         fixed_noise_TI = fixed_noise_TI.cuda()
345 |     input = Variable(input)            # wrapped in Variable (legacy autograd API)
346 |     label = Variable(label)            # wrapped in Variable (legacy autograd API)
347 |     noise = Variable(noise)            # wrapped in Variable (legacy autograd API)
348 |     fixed_noise=Variable(fixed_noise)
349 |     fixed_noise_TI=Variable(fixed_noise_TI)
350 | 
351 |     optimizerD=optim.Adam(netD.parameters(),lr=lr,betas=(0.5,0.999))
352 |     optimizerG=optim.Adam(netG.parameters(),lr=lr,betas=(0.5,0.999))
353 |     # main training loop
354 |     gen_iterations=0
355 |     G_loss=[]
356 |     D_loss=[]
357 |     iters=0
358 |     for epoch in range(niter):
359 |         print('Epoch:',epoch)
360 |         for i,data in enumerate(dataloader,0):
361 |             f=open(result_path+'training_curve.csv','a')
362 |             netD.zero_grad()
363 |             real_cpu=data.to(device)
364 |             batch_size=real_cpu.size(0)
365 |             label=torch.full((batch_size,),real_label,device=device)
366 |             output=netD(real_cpu).view(-1)
367 |             errD_real=criterion(output,label)
368 |             errD_real.backward()
369 |             D_x=output.mean().item()
370 | 
371 |             noise=torch.randn(batch_size,nz,1,1,1,device=device)
372 |             fake=netG(noise)
373 |             label.fill_(fake_label)
374 |             output = netD(fake.detach()).view(-1)
375 |             errD_fake = criterion(output, label)
376 |             errD_fake.backward()
377 |             D_G_z1 = output.mean().item()
378 |             errD = errD_real + errD_fake
379 |             optimizerD.step()
380 | 
381 |             # generator update
382 |             netG.zero_grad()
383 |             label.fill_(1.0)
384 |             noise2=torch.randn(batch_size,nz,1,1,1,device=device)
385 |             fake2=netG(noise2)
386 |             output = netD(fake2).view(-1)
387 |             errG = criterion(output, label)
388 |             errG.backward()
389 |             D_G_z2 = output.mean().item()
390 |             optimizerG.step()
391 | 
392 |             gen_iterations+=1
393 |             print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
394 |                   % (epoch, niter, i, len(dataloader),
395 |                      errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
396 |             f.write('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
397 |                     % (epoch, niter, i, len(dataloader),
398 |                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
399 |             f.write('\n')
400 |             f.close()
401 | 
402 |         fake = netG(fixed_noise)
403 |         fake_TI = netG(fixed_noise_TI)
404 |         try:
405 |             os.makedirs(result_path+'fake_samples')
406 |             os.makedirs(result_path+'fake_TI')
407 |         except OSError:
408 |             pass
409 |         save_hdf5(fake.data, result_path+'fake_samples/'+'fake_samples_{0}.hdf5'.format(gen_iterations))
410 |         save_hdf5(fake_TI.data, result_path+'fake_TI/'+'fake_TI_{0}.hdf5'.format(gen_iterations))
411 |         # do checkpointing
412 |         try:
413 |             os.makedirs(result_path+'netG')
414 |             os.makedirs(result_path+'netD')
415 |         except OSError:
416 |             pass
417 |         torch.save(netG.state_dict(), result_path+'netG/'+'netG_epoch_%d.pth' % (epoch))
418 |         torch.save(netD.state_dict(), result_path+'netD/'+'netD_epoch_%d.pth' % (epoch))
419 |         # record loss
420 |         G_loss.append(errG.item())
421 |         D_loss.append(errD.item())
422 |         iters+=1
423 |     f=open(result_path+'Loss_log.txt','a')
424 |     f.write('G_loss:')
425 |     f.write('\n')
426 |     for k in range(len(G_loss)):
427 |         f.write(str(G_loss[k]))
428 |         f.write('\n')
429 |     f.write('D_loss:')
430 |     f.write('\n')
431 |     for k in range(len(D_loss)):
432 |         f.write(str(D_loss[k]))
433 |         f.write('\n')
434 |     f.close()
435 | 
436 | def DCGAN_generator(seedmin,
437 |                     seedmax,
438 |                     ngf,
439 |                     ndf,
440 |                     nz,
441 |                     ngpu,
442 |                     imageSize,
443 |                     imsize,
444 |                     out_folder,
445 |                     name,
446 |                     device,
447 |                     netG,
448 |                     ):
449 | 
450 |     if name is None:
451 |         name = 'samples'
452 |     try:
453 |         os.makedirs(out_folder+'DCGAN/output/'+name)
454 |     except OSError:
455 |         pass
456 |     outf=out_folder+'DCGAN/output/'
457 | 
458 |     for seed in range(seedmin,
seedmax, 1): 459 | random.seed(seed) 460 | torch.manual_seed(seed) 461 | cudnn.benchmark = True 462 | ngpu = int(ngpu) 463 | nz = int(nz) 464 | ngf = int(ngf) 465 | ndf = int(ndf) 466 | nc = 1 467 | 468 | net = DCGAN3d_G(imageSize, nz, nc, ngf, ngpu) 469 | net.apply(weights_init) 470 | net.load_state_dict(torch.load(netG)) 471 | print(net) 472 | 473 | fixed_noise = torch.FloatTensor(1, nz, imsize, imsize, imsize).normal_(0, 1) 474 | if device=='cuda': 475 | net.cuda() 476 | fixed_noise = fixed_noise.cuda() 477 | fixed_noise = Variable(fixed_noise) 478 | fake = net(fixed_noise) 479 | save_hdf5(fake.data, '{0}/{1}_{2}.hdf5'.format(outf+name, name, seed)) 480 | 481 | def result_analysis(out_folder,generate_name): 482 | path=out_folder+'DCGAN/output/'+generate_name+'/' 483 | tiff_name=generate_name+'_tiff' 484 | datalist=os.listdir(path) 485 | try: 486 | os.makedirs(out_folder+'DCGAN/output/'+tiff_name) 487 | except OSError: 488 | pass 489 | for img in datalist: 490 | f=h5py.File(path+img,'r') 491 | array=f['data'][()] 492 | tiff=array[0,0,:,:,:].astype(np.float32) 493 | tifffile.imsave(out_folder+'DCGAN/output/{0}/{1}.tiff'.format(tiff_name,img[:-5]),tiff) 494 | 495 | path2=out_folder+'DCGAN/output/'+tiff_name 496 | tifflist=os.listdir(path2) 497 | for img in tifflist: 498 | f=open(out_folder+'DCGAN/output/'+generate_name+'_log.txt','a') 499 | im_in=tifffile.imread(path2+'/'+img) 500 | im_in=median_filter(im_in,size=(3,3,3)) 501 | im_in=im_in[40:240,40:240,40:240] 502 | im_in=im_in/255. 503 | threshold_global_otsu=threshold_otsu(im_in) 504 | segmented_image=(im_in>=threshold_global_otsu).astype(np.int32) 505 | porc=Counter(segmented_image.flatten()) 506 | porosity=porc[0]/(porc[0]+porc[1]) 507 | print(img[:-5],' porosity: ',porosity) 508 | f.write(str(img[:-5])+' porosity: '+str(porosity)) 509 | f.write('\n') 510 | f.close() 511 | ''' 512 | 9. DCGAN train 513 | 514 | DCGAN_train(imageSize=imageSize,batchSize=batchSize, 515 | ngf=number_generator_feature, 516 | ndf=number_discriminator_feature, 517 | nz=number_z, 518 | niter=number_train_iterations, 519 | ngpu=number_gpu, 520 | manualSeed=manualSeed, 521 | out_folder=out_folder, 522 | dataset_name=dataset_name, 523 | device=device) 524 | 525 | ''' 526 | ''' 527 | 10. DCGAN generate 528 | 529 | DCGAN_generator(seedmin=seedmin, 530 | seedmax=seedmax, 531 | ngf=number_generator_feature, 532 | ndf=number_discriminator_feature, 533 | nz=number_z, 534 | ngpu=number_gpu, 535 | imageSize=imageSize, 536 | imsize=image_generate_size, 537 | out_folder=out_folder, 538 | name=generate_name, 539 | device=device, 540 | netG=netG, 541 | ) 542 | ''' 543 | ''' 544 | 11. DCGAN batch processing samples statistic 545 | 546 | result_analysis(out_folder,generate_name) 547 | ''' 548 | -------------------------------------------------------------------------------- /Segmentation.py: -------------------------------------------------------------------------------- 1 | #Segmentation 2 | # 1. train and test 3 | # 2. 
test and viz using trained model parameters
4 | 
5 | import os
6 | import torch
7 | import json
8 | import labelme
9 | import numpy as np
10 | import cv2
11 | import torch.nn as nn
12 | from PIL import Image
13 | import segmentation_models_pytorch as smp
14 | from torch.utils.data import DataLoader
15 | from torch.utils.data import Dataset as BaseDataset
16 | import albumentations as albu
17 | import matplotlib.pyplot as plt
18 | import torchvision.transforms.functional as tf
19 | from Image_preprocess import *
20 | 
21 | #from Augmentation import *
22 | 
23 | #################### REQUIRED INPUTS ####################
24 | out_folder = 'H:/tmp/DEEP_WATERROCK_CODE/codetest/'   # data storage path
25 | label_number=5             # number of images to label
26 | json_path='H:/tmp/json/'   # folder with the labelme-converted JSON output; avoid non-ASCII characters in the path
27 | dataset_choose='obstacle'  # check in dataset, choose what you want to extract
28 | #####################################################
29 | 
30 | ############### TRAINING PARAMETERS ##############
31 | model_name='Unet'
32 | Encoders=['resnet18','vgg16']
33 | Activation='sigmoid'       # sigmoid, relu, tanh
34 | Encoder_weights ='imagenet'
35 | Epoch=10
36 | train_batch_size=3
37 | #####################################################
38 | 
39 | class Dataset(BaseDataset):
40 |     CLASSES = ['obstacle','soil','media']
41 |     def __init__(
42 |         self,
43 |         images_dir,
44 |         masks_dir,
45 |         classes=None,
46 |         augmentation=None,
47 |         preprocessing=None,
48 |     ):
49 |         self.ids = os.listdir(images_dir)
50 |         self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
51 |         self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
52 |         self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
53 |         self.augmentation = augmentation
54 |         self.preprocessing = preprocessing
55 |     def __getitem__(self, i):
56 |         image = cv2.imread(self.images_fps[i])
57 |         image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
58 |         mask = cv2.imread(self.masks_fps[i], 0)
59 |         masks = [(mask == v) for v in self.class_values]
60 |         mask = np.stack(masks, axis=-1).astype('float')
61 |         if self.augmentation:
62 |             sample = self.augmentation(image=image, mask=mask)
63 |             image, mask = sample['image'], sample['mask']
64 |         if self.preprocessing:
65 |             sample = self.preprocessing(image=image, mask=mask)
66 |             image, mask = sample['image'], sample['mask']
67 |         return image, mask
68 |     def __len__(self):
69 |         return len(self.ids)
70 | 
71 | class Dataset2(BaseDataset):   # identical to Dataset; kept separate for test-time loading
72 |     CLASSES = ['obstacle','soil','media']
73 |     def __init__(
74 |         self,
75 |         images_dir,
76 |         masks_dir,
77 |         classes=None,
78 |         augmentation=None,
79 |         preprocessing=None,
80 |     ):
81 |         self.ids = os.listdir(images_dir)
82 |         self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
83 |         self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
84 |         self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
85 |         self.augmentation = augmentation
86 |         self.preprocessing = preprocessing
87 |     def __getitem__(self, i):
88 |         image = cv2.imread(self.images_fps[i])
89 |         image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
90 |         mask = cv2.imread(self.masks_fps[i], 0)
91 |         masks = [(mask == v) for v in self.class_values]
92 |         mask = np.stack(masks, axis=-1).astype('float')
93 |         if self.augmentation:
94 |             sample = self.augmentation(image=image, mask=mask)
95 |             image, mask = sample['image'], sample['mask']
96 |         if self.preprocessing:
97 |             sample = self.preprocessing(image=image, mask=mask)
98 |             image, mask = sample['image'], sample['mask']
99 |         return image, mask
100
| def __len__(self): 101 | return len(self.ids) 102 | 103 | def visualize(**images): 104 | n = len(images) 105 | plt.figure(figsize=(8,4)) 106 | for i, (name, image) in enumerate(images.items()): 107 | plt.subplot(1, n, i + 1) 108 | plt.xticks([]) 109 | plt.yticks([]) 110 | plt.title(' '.join(name.split('_')).title()) 111 | plt.imshow(image,cmap=plt.cm.gray) 112 | plt.show() 113 | 114 | def get_training_augmentation(): 115 | train_transform = [ 116 | albu.HorizontalFlip(p=0.5), 117 | albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0), 118 | albu.PadIfNeeded(min_height=800, min_width=800, always_apply=True, border_mode=0), 119 | albu.RandomCrop(height=800, width=800, always_apply=True), 120 | albu.IAAAdditiveGaussianNoise(p=0.2), 121 | albu.IAAPerspective(p=0.5), 122 | albu.OneOf( 123 | [ 124 | albu.CLAHE(p=1), 125 | albu.RandomBrightness(p=1), 126 | albu.RandomGamma(p=1), 127 | ], 128 | p=0.9, 129 | ), 130 | albu.OneOf( 131 | [ 132 | albu.IAASharpen(p=1), 133 | albu.Blur(blur_limit=3, p=1), 134 | albu.MotionBlur(blur_limit=3, p=1), 135 | ], 136 | p=0.9, 137 | ), 138 | albu.OneOf( 139 | [ 140 | albu.RandomContrast(p=1), 141 | albu.HueSaturationValue(p=1), 142 | ], 143 | p=0.9, 144 | ), 145 | ] 146 | return albu.Compose(train_transform) 147 | 148 | def get_validation_augmentation(): 149 | """Add paddings to make image shape divisible by 32""" 150 | test_transform = [ 151 | albu.PadIfNeeded(800,800) 152 | ] 153 | return albu.Compose(test_transform) 154 | 155 | def to_tensor(x, **kwargs): 156 | return x.transpose(2,0,1).astype('float32') 157 | 158 | def get_preprocessing(preprocessing_fn): 159 | _transform = [ 160 | albu.Lambda(image=preprocessing_fn), 161 | albu.Lambda(image=to_tensor, mask=to_tensor), 162 | ] 163 | return albu.Compose(_transform) 164 | 165 | def Model_create(model_name:str,Encoder_name:str,Encoder_weights:str,CLASSES,Activation:str): 166 | Models=['DeeplabV3','Deeplabv3+','FPN','Linknet','Unet','Unet++'] 167 | if model_name=='DeeplabV3': 168 | model=smp.DeepLabV3( 169 | encoder_name=Encoder_name, 170 | encoder_weights=Encoder_weights, 171 | classes=len(CLASSES), 172 | activation=Activation,) 173 | elif model_name=='Deeplabv3+': 174 | model=smp.DeepLabV3Plus( 175 | encoder_name=Encoder_name, 176 | encoder_weights=Encoder_weights, 177 | classes=len(CLASSES), 178 | activation=Activation,) 179 | elif model_name=='FPN': 180 | model=smp.FPN( 181 | encoder_name=Encoder_name, 182 | encoder_weights=Encoder_weights, 183 | classes=len(CLASSES), 184 | activation=Activation,) 185 | elif model_name=='Linknet': 186 | model=smp.Linknet( 187 | encoder_name=Encoder_name, 188 | encoder_weights=Encoder_weights, 189 | classes=len(CLASSES), 190 | activation=Activation,) 191 | elif model_name=='Unet': 192 | model=smp.Unet( 193 | encoder_name=Encoder_name, 194 | encoder_weights=Encoder_weights, 195 | classes=len(CLASSES), 196 | activation=Activation,) 197 | elif model_name=='Unet++': 198 | model=smp.UnetPlusPlus( 199 | encoder_name=Encoder_name, 200 | encoder_weights=Encoder_weights, 201 | classes=len(CLASSES), 202 | activation=Activation,) 203 | else: 204 | print('Please select a model from the following list\n',Models) 205 | return model 206 | 207 | def cuda_is_available(): 208 | if torch.cuda.is_available(): 209 | gpu_num=torch.cuda.device_count() 210 | if gpu_num==1: 211 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 212 | device='cuda' 213 | device_ids=[0] 214 | elif gpu_num==2: 215 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 216 | device='cuda' 217 | 
device_ids=[0,1] 218 | elif gpu_num==3: 219 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' 220 | device='cuda' 221 | device_ids=[0,1,2] 222 | else: 223 | device='cpu' 224 | device_ids=[0] 225 | return device,device_ids 226 | 227 | def get_classes(json_path,label_number): 228 | files=os.listdir(json_path) 229 | file_list=[] 230 | for file in files: 231 | file_list.append(file) 232 | json_list=file_list[0:label_number] 233 | jsonfile_list=file_list[label_number:2*label_number] 234 | json_file=json_list[0] 235 | img_file=jsonfile_list[0] 236 | data=json.load(open(json_path+json_file)) 237 | img=cv2.imread(json_path+img_file+'/img.png') 238 | lbl, lbl_names = labelme.utils.labelme_shapes_to_label(img.shape, data['shapes']) 239 | class_name=[] 240 | for name in lbl_names: 241 | class_name.append(name) 242 | return class_name[1:] 243 | 244 | def train_dataset_choose(out_folder,dataset_choose): 245 | dataset_path=out_folder+'dataset/'+dataset_choose 246 | return dataset_path 247 | 248 | def get_name(n): 249 | col=n%10 250 | row=int(n/10) 251 | name='row='+str(row)+'_col='+str(col) 252 | return name 253 | 254 | def Segmentation_train(Encoders,Encoder_weights,model_name,Activation,Epochs,batch_size:int, 255 | dataset_choose,out_folder,json_path,label_number): 256 | dataset_path=train_dataset_choose(out_folder,dataset_choose) 257 | x_train_dir = os.path.join(dataset_path, 'train') 258 | y_train_dir = os.path.join(dataset_path, 'train_mask') 259 | x_valid_dir = os.path.join(dataset_path, 'test') 260 | y_valid_dir = os.path.join(dataset_path, 'test_mask') 261 | x_test_dir = os.path.join(dataset_path, 'val') 262 | y_test_dir = os.path.join(dataset_path, 'val_mask') 263 | loss = smp.utils.losses.DiceLoss() 264 | #记录指标 265 | metrics = [smp.utils.metrics.IoU(threshold=0.5), 266 | smp.utils.metrics.Accuracy(), 267 | smp.utils.metrics.Recall(), 268 | smp.utils.metrics.Precision(), 269 | smp.utils.metrics.Fscore(),] 270 | device,device_ids=cuda_is_available() 271 | Classes=get_classes(json_path,label_number) 272 | for i in range(len(Encoders)): 273 | model=Model_create(model_name=model_name, 274 | Encoder_name=Encoders[i], 275 | Encoder_weights=Encoder_weights, 276 | CLASSES=Classes, 277 | Activation=Activation) 278 | if torch.cuda.is_available(): 279 | model=torch.nn.DataParallel(model,device_ids=device_ids) 280 | else: 281 | model=model 282 | preprocessing_fn = smp.encoders.get_preprocessing_fn(Encoders[i], Encoder_weights) 283 | train_dataset = Dataset(x_train_dir, 284 | y_train_dir, 285 | augmentation=get_training_augmentation(), 286 | preprocessing=get_preprocessing(preprocessing_fn), 287 | classes=Classes, 288 | ) 289 | valid_dataset = Dataset(x_valid_dir, 290 | y_valid_dir, 291 | augmentation=get_validation_augmentation(), 292 | preprocessing=get_preprocessing(preprocessing_fn), 293 | classes=Classes, 294 | ) 295 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 296 | valid_loader = DataLoader(valid_dataset, batch_size=int(0.8*batch_size), shuffle=False) 297 | optimizer=torch.optim.Adam([dict(params=model.parameters(),lr=0.0001)]) 298 | train_epoch=smp.utils.train.TrainEpoch(model, 299 | loss=loss, 300 | metrics=metrics, 301 | optimizer=optimizer, 302 | device=device, 303 | verbose=True,) 304 | valid_epoch=smp.utils.train.ValidEpoch(model, 305 | loss=loss, 306 | metrics=metrics, 307 | device=device, 308 | verbose=True,) 309 | max_score=0 310 | model_save_path=out_folder+'model/'+dataset_choose+'/'+model_name+'/' 311 | try: 312 | os.makedirs(model_save_path) 313 | except 
OSError: 314 | pass 315 | #train 316 | for j in range(0,Epochs): 317 | print('\nEpoch: {}'.format(j)) 318 | train_logs = train_epoch.run(train_loader) 319 | valid_logs = valid_epoch.run(valid_loader) 320 | if max_score < valid_logs['iou_score']: 321 | max_score = valid_logs['iou_score'] 322 | torch.save(model,model_save_path+Encoders[i]+'_best_model.pth') 323 | print('Model saved!') 324 | if j == 25: 325 | optimizer.param_groups[0]['lr'] = 1e-5 326 | print('Decrease decoder learning rate to 1e-5!') 327 | print('The final results of :Unet++_',Encoders[i]) 328 | print('All models have been trained and generated.') 329 | 330 | def Segmentation_test(Activation, 331 | dataset_choose,out_folder,json_path,label_number): 332 | dataset_path=train_dataset_choose(out_folder,dataset_choose) 333 | x_test_dir = os.path.join(dataset_path, 'val') 334 | y_test_dir = os.path.join(dataset_path, 'val_mask') 335 | ENCODER_WEIGHTS = 'imagenet' 336 | ENCODER_WEIGHTS_2='instagram' 337 | loss = smp.utils.losses.DiceLoss() 338 | #记录指标 339 | metrics = [smp.utils.metrics.IoU(threshold=0.5), 340 | smp.utils.metrics.Accuracy(), 341 | smp.utils.metrics.Recall(), 342 | smp.utils.metrics.Precision(), 343 | smp.utils.metrics.Fscore(),] 344 | device,device_ids=cuda_is_available() 345 | Classes=get_classes(json_path,label_number) 346 | check_path=out_folder+'model' 347 | print('Check for the existence of paths and models') 348 | model_dirname=[] 349 | model_filename=[] 350 | model_dirname_path=[] 351 | all_model=[] 352 | for parent, dirnames, filenames in os.walk(check_path): 353 | for dirname in dirnames: 354 | print("Model save path check:", parent) 355 | model_dirname.append(dirname) 356 | print("Model name:", dirname) 357 | for filename in filenames: 358 | print("Models Path check:", parent) 359 | model_filename.append(parent+'/'+filename) 360 | print("Encoder:", filename[:-15]) 361 | print('Model test result') 362 | for parent,dirnames,filenames in os.walk(check_path): 363 | for dirname in dirnames: 364 | model_dirname.append(dirname) 365 | model_dirname_path.append(parent+'/'+dirname) 366 | for filename in filenames: 367 | encoder=filename[:-15] 368 | print('-'*60) 369 | print('Architecture',parent[len(check_path):]) 370 | print('parent path:',parent,'\n','trained model name:',filename,'\n','used encoder:',encoder) 371 | state_model=parent+'/'+filename 372 | all_model.append(parent+'/'+filename) 373 | print('model path :',state_model) 374 | if encoder=='resnext101_32x16d': 375 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS_2) 376 | elif encoder == 'resnext101_32x8d': 377 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS_2) 378 | else: 379 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS) 380 | test_dataset = Dataset2( 381 | x_test_dir, 382 | y_test_dir, 383 | augmentation=get_validation_augmentation(), 384 | preprocessing=get_preprocessing(preprocessing_fn), 385 | classes=Classes, 386 | ) 387 | best_model = torch.load(state_model) 388 | test_dataloader = DataLoader(test_dataset) 389 | test_epoch = smp.utils.train.ValidEpoch(model=best_model, 390 | loss=loss, 391 | metrics=metrics, 392 | device=device, 393 | ) 394 | logs = test_epoch.run(test_dataloader) 395 | print('-'*60) 396 | 397 | def Segmentation_result(dataset_choose,out_folder,json_path,label_number): 398 | dataset_path=train_dataset_choose(out_folder,dataset_choose) 399 | device,device_ids=cuda_is_available() 400 | Classes=get_classes(json_path,label_number) 401 | 
cut_image_original=out_folder+'cut_image' 402 | cut_image_names=[] 403 | for parent,dirnames,filenames in os.walk(cut_image_original): 404 | for file in filenames: 405 | cut_image_names.append(file) 406 | model_dirname=[] 407 | model_dirname_path=[] 408 | model_path=[] 409 | all_model=[] 410 | metrics = [smp.utils.metrics.IoU(threshold=0.5), 411 | smp.utils.metrics.Accuracy(), 412 | smp.utils.metrics.Recall(), 413 | smp.utils.metrics.Precision(), 414 | smp.utils.metrics.Fscore(),] 415 | ENCODER_WEIGHTS = 'imagenet' 416 | ENCODER_WEIGHTS_2='instagram' 417 | ACTIVATION = 'sigmoid' 418 | models_path=out_folder+'model' 419 | models_path_tmp=out_folder+'model/' 420 | for parent,dirnames,filenames in os.walk(models_path): 421 | for dirname in dirnames: 422 | model_dirname.append(dirname) 423 | model_dirname_path.append(parent+'/'+dirname) 424 | for filename in filenames: 425 | encoder=filename[:-15] 426 | print('-'*60) 427 | if encoder=='resnext101_32x16d': 428 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS_2) 429 | elif encoder == 'resnext101_32x8d': 430 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS_2) 431 | elif encoder == 'resnext101_32x16d': 432 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS_2) 433 | else: 434 | preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, ENCODER_WEIGHTS) 435 | architecture_name=parent[len(models_path_tmp):] 436 | print('Architecture',architecture_name) 437 | print('parent path:',parent,'\n','trained model name:',filename,'\n','used encoder:',encoder) 438 | state_model=parent+'/'+filename 439 | all_model.append(parent+'/'+filename) 440 | print('model path :',state_model) 441 | test_dataset_vis = Dataset(cut_image_original, 442 | cut_image_original, 443 | preprocessing=get_preprocessing(preprocessing_fn), 444 | classes=Classes 445 | ) 446 | best_model = torch.load(state_model) 447 | try: 448 | os.makedirs(out_folder+'Segmentation_results/'+parent[len(models_path_tmp):]+'_'+encoder) 449 | except OSError: 450 | pass 451 | savepath=out_folder+'Segmentation_results/'+parent[len(models_path_tmp):]+'_'+encoder+'/' 452 | for i in range(len(cut_image_names)): 453 | name=get_name(i) 454 | image_vis=test_dataset_vis[i][0].astype('uint8') 455 | image,mask=test_dataset_vis[i] 456 | mask=mask.squeeze() 457 | x_tensor=torch.from_numpy(image).to(device).unsqueeze(0) 458 | pr_mask=best_model.module.predict(x_tensor) 459 | pr_mask=pr_mask.squeeze().cpu().numpy().round() 460 | plt.figure(figsize=(5,5),frameon=False) 461 | plt.axis('off') 462 | plt.tight_layout(pad = 0) 463 | if np.asarray(pr_mask).shape==(800,800): 464 | plt.imshow(-pr_mask,cmap=plt.cm.gray) 465 | plt.savefig(savepath+name+'.png',dpi=160) 466 | else: 467 | plt.imshow(-pr_mask[0],cmap=plt.cm.gray) 468 | plt.savefig(savepath+name+'.png',dpi=160) 469 | print(architecture_name,'cut_image results are saved!') 470 | #paste them 471 | past_image_list=os.listdir(savepath) 472 | target=Image.new('RGB',(8000,8000)) 473 | paste_size=800 474 | left_num_p=0 475 | top_num_p=0 476 | img=[] 477 | for n in past_image_list: 478 | img.append(Image.open(savepath+n)) 479 | for i in range(1,11): 480 | left_num_p=0 481 | for j in range(1,11): 482 | a=paste_size*left_num_p #zuo 483 | b=paste_size*top_num_p #shang 484 | c=paste_size*(left_num_p+1) #you 485 | d=paste_size*(top_num_p+1) #xia 486 | target.paste(img[10*(i-1)+j-1], (a, b, c, d)) 487 | left_num_p+=1 488 | top_num_p+=1 489 | target.save(savepath[:-1]+'_result.png') 490 | 491 
| '''
492 | 4. Segmentation train process
493 | Segmentation_train(Encoders=Encoders,
494 |                    Encoder_weights=Encoder_weights,
495 |                    model_name=model_name,
496 |                    Activation=Activation,
497 |                    Epochs=Epoch,
498 |                    batch_size=train_batch_size,
499 |                    dataset_choose=dataset_choose,
500 |                    out_folder=out_folder,
501 |                    json_path=json_path,
502 |                    label_number=label_number)
503 | 
504 | 5. Segmentation test process using the trained models
505 | Segmentation_test(Activation=Activation,
506 |                   dataset_choose=dataset_choose,
507 |                   out_folder=out_folder,
508 |                   json_path=json_path,
509 |                   label_number=label_number)
510 | 
511 | # 6. Segmentation result visualization and saving
512 | Segmentation_result(dataset_choose=dataset_choose,
513 |                     out_folder=out_folder,
514 |                     json_path=json_path,
515 |                     label_number=label_number)
516 | 
517 | '''
518 | 
--------------------------------------------------------------------------------
/Cycle_GAN_cross.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | import datetime
4 | import sys
5 | import cv2
6 | import itertools
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 | import torch
11 | import numpy as np
12 | import glob
13 | import os
14 | from torch.utils.data import Dataset
15 | from PIL import Image
16 | import torchvision.transforms as transforms
17 | from torchvision.utils import save_image, make_grid
18 | from torch.utils.data import DataLoader
19 | from torchvision import datasets
20 | import ot
21 | from scipy.stats import wasserstein_distance
22 | import matplotlib.pyplot as plt
23 | 
24 | ############### REQUIRED INPUTS ###############
25 | #Cycle_GAN
26 | out_folder='H:/tmp/DEEP_WATERROCK_CODE/codetest/'
27 | dataset_name='pore2miropore_cycle'
28 | dataset_path='H:/tmp/DEEP_WATERROCK_CODE/pore2miropore_cycle'
29 | n_epochs=10
30 | decay_epoch=8
31 | batch_size=1
32 | learning_rate=2e-5
33 | Resnet_blocks=9
34 | img_height=256
35 | img_width=256
36 | channels=3
37 | sample_interval=5
38 | checkpoint_interval=1
39 | pre_trained=False
40 | trained_epoch=0
41 | 
42 | #SWD
43 | cross_number=[1,10,100]
44 | berea_calc_WD_SWD_datasetpath='H:/tmp/DEEP_WATERROCK_CODE/pore2miropore_cycle/train/B/'
45 | test_loader='H:/tmp/DEEP_WATERROCK_CODE/test_loader'
46 | 
47 | #test
48 | G_AB_path='H:/清华大学/论文/论文《DEEP FLOW》王明阳/深度分割重建计算/study_gan_5/229_cross_result/model/dataset_100/G_AB_29.pth'
49 | G_BA_path='H:/清华大学/论文/论文《DEEP FLOW》王明阳/深度分割重建计算/study_gan_5/229_cross_result/model/dataset_100/G_BA_29.pth'
50 | test_result_save_path=out_folder+'/Image2Image_translation'
51 | ##########################################
52 | 
53 | class ReplayBuffer:
54 |     def __init__(self, max_size=50):
55 |         assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
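        # Illustrative aside, not part of the original script: this class is
        # the standard CycleGAN image pool -- push_and_pop() below keeps up to
        # `max_size` previously generated fakes and, with probability 0.5,
        # returns an older fake to the discriminator instead of the newest
        # one, which damps oscillations in adversarial training.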
56 | self.max_size = max_size 57 | self.data = [] 58 | 59 | def push_and_pop(self, data): 60 | to_return = [] 61 | for element in data.data: 62 | element = torch.unsqueeze(element, 0) 63 | if len(self.data) < self.max_size: 64 | self.data.append(element) 65 | to_return.append(element) 66 | else: 67 | if random.uniform(0, 1) > 0.5: 68 | i = random.randint(0, self.max_size - 1) 69 | to_return.append(self.data[i].clone()) 70 | self.data[i] = element 71 | else: 72 | to_return.append(element) 73 | return Variable(torch.cat(to_return)) 74 | 75 | 76 | class LambdaLR: 77 | def __init__(self, n_epochs, offset, decay_start_epoch): 78 | assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!" 79 | self.n_epochs = n_epochs 80 | self.offset = offset 81 | self.decay_start_epoch = decay_start_epoch 82 | 83 | def step(self, epoch): 84 | return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch) 85 | 86 | def weights_init_normal(m): 87 | classname = m.__class__.__name__ 88 | if classname.find("Conv") != -1: 89 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 90 | if hasattr(m, "bias") and m.bias is not None: 91 | torch.nn.init.constant_(m.bias.data, 0.0) 92 | elif classname.find("BatchNorm2d") != -1: 93 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 94 | torch.nn.init.constant_(m.bias.data, 0.0) 95 | 96 | ############################## 97 | # RESNET 98 | ############################## 99 | 100 | class ResidualBlock(nn.Module): 101 | def __init__(self, in_features): 102 | super(ResidualBlock, self).__init__() 103 | 104 | self.block = nn.Sequential( 105 | nn.ReflectionPad2d(1), 106 | nn.Conv2d(in_features, in_features, 3), 107 | nn.InstanceNorm2d(in_features), 108 | nn.ReLU(inplace=True), 109 | nn.ReflectionPad2d(1), 110 | nn.Conv2d(in_features, in_features, 3), 111 | nn.InstanceNorm2d(in_features), 112 | ) 113 | 114 | def forward(self, x): 115 | return x + self.block(x) 116 | 117 | class GeneratorResNet(nn.Module): 118 | def __init__(self, input_shape, num_residual_blocks): 119 | super(GeneratorResNet, self).__init__() 120 | 121 | channels = input_shape[0] 122 | 123 | # Initial convolution block 124 | out_features = 64 125 | model = [ 126 | nn.ReflectionPad2d(channels), 127 | nn.Conv2d(channels, out_features, 7), 128 | nn.InstanceNorm2d(out_features), 129 | nn.ReLU(inplace=True), 130 | ] 131 | in_features = out_features 132 | 133 | # Downsampling 134 | for _ in range(2): 135 | out_features *= 2 136 | model += [ 137 | nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), 138 | nn.InstanceNorm2d(out_features), 139 | nn.ReLU(inplace=True), 140 | ] 141 | in_features = out_features 142 | 143 | # Residual blocks 144 | for _ in range(num_residual_blocks): 145 | model += [ResidualBlock(out_features)] 146 | 147 | # Upsampling 148 | for _ in range(2): 149 | out_features //= 2 150 | model += [ 151 | nn.Upsample(scale_factor=2), 152 | nn.Conv2d(in_features, out_features, 3, stride=1, padding=1), 153 | nn.InstanceNorm2d(out_features), 154 | nn.ReLU(inplace=True), 155 | ] 156 | in_features = out_features 157 | 158 | # Output layer 159 | model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()] 160 | 161 | self.model = nn.Sequential(*model) 162 | 163 | def forward(self, x): 164 | return self.model(x) 165 | 166 | ############################## 167 | # Discriminator 168 | ############################## 169 | 170 | class Discriminator(nn.Module): 171 | def __init__(self, input_shape): 172 | 
super(Discriminator, self).__init__() 173 | 174 | channels, height, width = input_shape 175 | 176 | # Calculate output shape of image discriminator (PatchGAN) 177 | self.output_shape = (1, height // 2 ** 4, width // 2 ** 4) 178 | 179 | def discriminator_block(in_filters, out_filters, normalize=True): 180 | """Returns downsampling layers of each discriminator block""" 181 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 182 | if normalize: 183 | layers.append(nn.InstanceNorm2d(out_filters)) 184 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 185 | return layers 186 | 187 | self.model = nn.Sequential( 188 | *discriminator_block(channels, 64, normalize=False), 189 | *discriminator_block(64, 128), 190 | *discriminator_block(128, 256), 191 | *discriminator_block(256, 512), 192 | nn.ZeroPad2d((1, 0, 1, 0)), 193 | nn.Conv2d(512, 1, 4, padding=1) 194 | ) 195 | 196 | def forward(self, img): 197 | return self.model(img) 198 | 199 | def to_rgb(image): 200 | rgb_image = Image.new("RGB", image.size) 201 | rgb_image.paste(image) 202 | return rgb_image 203 | 204 | class ImageDataset(Dataset): 205 | def __init__(self, root, transforms_=None, unaligned=False, mode="train"): 206 | self.transform = transforms.Compose(transforms_) 207 | self.unaligned = unaligned 208 | 209 | self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*")) 210 | self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*")) 211 | 212 | def __getitem__(self, index): 213 | image_A = Image.open(self.files_A[index % len(self.files_A)]) 214 | 215 | if self.unaligned: 216 | image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]) 217 | else: 218 | image_B = Image.open(self.files_B[index % len(self.files_B)]) 219 | 220 | # Convert grayscale images to rgb 221 | if image_A.mode != "RGB": 222 | image_A = to_rgb(image_A) 223 | if image_B.mode != "RGB": 224 | image_B = to_rgb(image_B) 225 | 226 | item_A = self.transform(image_A) 227 | item_B = self.transform(image_B) 228 | return {"A": item_A, "B": item_B} 229 | 230 | def __len__(self): 231 | return max(len(self.files_A), len(self.files_B)) 232 | 233 | 234 | def Cycle_GAN(out_folder, 235 | dataset_name, 236 | dataset_path, 237 | checkpoint_interval, 238 | sample_interval, 239 | n_epochs, 240 | batch_size, 241 | lr, 242 | decay_epoch, 243 | n_residual_blocks, 244 | channels, 245 | img_height, 246 | img_width, 247 | pre_trained:bool, 248 | trained_epoch=0): 249 | lambda_cyc=10.0 250 | lambda_id=5.0 251 | b1=0.5 252 | b2=0.999 253 | n_cpu=0 254 | # Create sample and checkpoint directories 255 | save_path=out_folder+'Image2Image_translation/' 256 | os.makedirs(save_path+'cycle_images/%s' % dataset_name, exist_ok=True) 257 | os.makedirs(save_path+'cycle_saved_models/%s' % dataset_name, exist_ok=True) 258 | cycle_images_path=save_path+'cycle_images/%s' % dataset_name 259 | cycle_saved_models_path=save_path+'cycle_saved_models/%s' % dataset_name 260 | # Losses 261 | criterion_GAN = torch.nn.MSELoss() 262 | criterion_cycle = torch.nn.L1Loss() 263 | criterion_identity = torch.nn.L1Loss() 264 | cuda = torch.cuda.is_available() 265 | input_shape = (channels, img_height, img_width) 266 | # Initialize generator and discriminator 267 | G_AB = GeneratorResNet(input_shape, n_residual_blocks) 268 | G_BA = GeneratorResNet(input_shape, n_residual_blocks) 269 | D_A = Discriminator(input_shape) 270 | D_B = Discriminator(input_shape) 271 | if cuda: 272 | G_AB = G_AB.cuda() 273 | G_BA = G_BA.cuda() 274 | D_A = D_A.cuda() 275 | D_B = 
D_B.cuda() 276 | criterion_GAN.cuda() 277 | criterion_cycle.cuda() 278 | criterion_identity.cuda() 279 | if pre_trained==True: 280 | if trained_epoch != 0: 281 | G_AB.load_state_dict(torch.load(cycle_saved_models_path+"G_AB_%d.pth" % (trained_epoch))) 282 | G_BA.load_state_dict(torch.load(cycle_saved_models_path+"G_BA_%d.pth" % (trained_epoch))) 283 | D_A.load_state_dict(torch.load(cycle_saved_models_path+"D_A_%d.pth" % (trained_epoch))) 284 | D_B.load_state_dict(torch.load(cycle_saved_models_path+"D_B_%d.pth" % (trained_epoch))) 285 | else: 286 | G_AB.apply(weights_init_normal) 287 | G_BA.apply(weights_init_normal) 288 | D_A.apply(weights_init_normal) 289 | D_B.apply(weights_init_normal) 290 | else: 291 | G_AB.apply(weights_init_normal) 292 | G_BA.apply(weights_init_normal) 293 | D_A.apply(weights_init_normal) 294 | D_B.apply(weights_init_normal) 295 | # Optimizers 296 | optimizer_G = torch.optim.Adam( 297 | itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2) 298 | ) 299 | optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2)) 300 | optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2)) 301 | # Learning rate update schedulers 302 | lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR( 303 | optimizer_G, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step 304 | ) 305 | lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR( 306 | optimizer_D_A, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step 307 | ) 308 | lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR( 309 | optimizer_D_B, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step 310 | ) 311 | Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor 312 | # Buffers of previously generated samples 313 | fake_A_buffer = ReplayBuffer() 314 | fake_B_buffer = ReplayBuffer() 315 | # Image transformations 316 | transforms_ = [ 317 | transforms.Resize(int(img_height * 1.12), Image.BICUBIC), 318 | transforms.RandomCrop((img_height, img_width)), 319 | transforms.RandomHorizontalFlip(), 320 | transforms.ToTensor(), 321 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 322 | ] 323 | # Training data loader 324 | dataloader = DataLoader( 325 | ImageDataset(dataset_path , transforms_=transforms_, unaligned=True), 326 | batch_size=batch_size, 327 | shuffle=True, 328 | num_workers=0, 329 | ) 330 | # Test data loader 331 | val_dataloader = DataLoader( 332 | ImageDataset(dataset_path, transforms_=transforms_, unaligned=True, mode="val"), 333 | batch_size=5, 334 | shuffle=True, 335 | num_workers=0, 336 | ) 337 | 338 | def sample_images(batches_done): 339 | """Saves a generated sample from the test set""" 340 | imgs = next(iter(val_dataloader)) 341 | G_AB.eval() 342 | G_BA.eval() 343 | real_A = Variable(imgs["A"].type(Tensor)) 344 | fake_B = G_AB(real_A) 345 | real_B = Variable(imgs["B"].type(Tensor)) 346 | fake_A = G_BA(real_B) 347 | # Arange images along x-axis 348 | real_A = make_grid(real_A, nrow=5, normalize=True) 349 | real_B = make_grid(real_B, nrow=5, normalize=True) 350 | fake_A = make_grid(fake_A, nrow=5, normalize=True) 351 | fake_B = make_grid(fake_B, nrow=5, normalize=True) 352 | # Arange images along y-axis 353 | image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1) 354 | save_image(image_grid, cycle_images_path+"/%s.png" % (batches_done), normalize=False) 355 | # ---------- 356 | # Training 357 | # ---------- 358 | prev_time = time.time() 359 | for epoch in range(trained_epoch, n_epochs): 360 | for i, batch in enumerate(dataloader): 361 | 
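            # Illustrative aside, not part of the original script: the `valid`
            # and `fake` targets created below are PatchGAN label maps of shape
            # (batch, 1, img_height//16, img_width//16) -- see
            # Discriminator.output_shape -- so the MSE GAN loss is scored per
            # image patch rather than once per image.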
# Set model input 362 | real_A = Variable(batch["A"].type(Tensor)) 363 | real_B = Variable(batch["B"].type(Tensor)) 364 | # Adversarial ground truths 365 | valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False) 366 | fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False) 367 | # ------------------ 368 | # Train Generators 369 | # ------------------ 370 | G_AB.train() 371 | G_BA.train() 372 | optimizer_G.zero_grad() 373 | # Identity loss 374 | loss_id_A = criterion_identity(G_BA(real_A), real_A) 375 | loss_id_B = criterion_identity(G_AB(real_B), real_B) 376 | loss_identity = (loss_id_A + loss_id_B) / 2 377 | # GAN loss 378 | fake_B = G_AB(real_A) 379 | loss_GAN_AB = criterion_GAN(D_B(fake_B), valid) 380 | fake_A = G_BA(real_B) 381 | loss_GAN_BA = criterion_GAN(D_A(fake_A), valid) 382 | loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2 383 | # Cycle loss 384 | recov_A = G_BA(fake_B) 385 | loss_cycle_A = criterion_cycle(recov_A, real_A) 386 | recov_B = G_AB(fake_A) 387 | loss_cycle_B = criterion_cycle(recov_B, real_B) 388 | loss_cycle = (loss_cycle_A + loss_cycle_B) / 2 389 | # Total loss 390 | loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity 391 | loss_G.backward() 392 | optimizer_G.step() 393 | # ----------------------- 394 | # Train Discriminator A 395 | # ----------------------- 396 | optimizer_D_A.zero_grad() 397 | # Real loss 398 | loss_real = criterion_GAN(D_A(real_A), valid) 399 | # Fake loss (on batch of previously generated samples) 400 | fake_A_ = fake_A_buffer.push_and_pop(fake_A) 401 | loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake) 402 | # Total loss 403 | loss_D_A = (loss_real + loss_fake) / 2 404 | loss_D_A.backward() 405 | optimizer_D_A.step() 406 | # ----------------------- 407 | # Train Discriminator B 408 | # ----------------------- 409 | optimizer_D_B.zero_grad() 410 | # Real loss 411 | loss_real = criterion_GAN(D_B(real_B), valid) 412 | # Fake loss (on batch of previously generated samples) 413 | fake_B_ = fake_B_buffer.push_and_pop(fake_B) 414 | loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake) 415 | # Total loss 416 | loss_D_B = (loss_real + loss_fake) / 2 417 | loss_D_B.backward() 418 | optimizer_D_B.step() 419 | loss_D = (loss_D_A + loss_D_B) / 2 420 | # -------------- 421 | # Log Progress 422 | # -------------- 423 | # Determine approximate time left 424 | batches_done = epoch * len(dataloader) + i 425 | batches_left = n_epochs * len(dataloader) - batches_done 426 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 427 | prev_time = time.time() 428 | # Print log 429 | print( 430 | "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s" 431 | % ( 432 | epoch, 433 | n_epochs, 434 | i, 435 | len(dataloader), 436 | loss_D.item(), 437 | loss_G.item(), 438 | loss_GAN.item(), 439 | loss_cycle.item(), 440 | loss_identity.item(), 441 | time_left, 442 | ) 443 | ) 444 | f=open(save_path+'cycle_process.txt','a') 445 | f.write( 446 | "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s" 447 | % ( 448 | epoch, 449 | n_epochs, 450 | i, 451 | len(dataloader), 452 | loss_D.item(), 453 | loss_G.item(), 454 | loss_GAN.item(), 455 | loss_cycle.item(), 456 | loss_identity.item(), 457 | time_left, 458 | ) 459 | ) 460 | f.close() 461 | # If at sample interval save image 462 | if batches_done % sample_interval == 0: 463 | sample_images(batches_done) 464 | # Update 
learning rates 465 | lr_scheduler_G.step() 466 | lr_scheduler_D_A.step() 467 | lr_scheduler_D_B.step() 468 | if checkpoint_interval != -1 and epoch % checkpoint_interval == 0: 469 | # Save model checkpoints 470 | torch.save(G_AB.state_dict(), cycle_saved_models_path+"G_AB_%d.pth" % (epoch)) 471 | torch.save(G_BA.state_dict(), cycle_saved_models_path+"G_BA_%d.pth" % (epoch)) 472 | torch.save(D_A.state_dict(), cycle_saved_models_path+"D_A_%d.pth" % (epoch)) 473 | torch.save(D_B.state_dict(), cycle_saved_models_path+"D_B_%d.pth" % (epoch)) 474 | 475 | ''' 476 | 用于cross SWD cycle GAN 477 | ''' 478 | def get_random_projections(n_projections, d, seed=None): 479 | if not isinstance(seed, np.random.RandomState): 480 | random_state = np.random.RandomState(seed) 481 | else: 482 | random_state = seed 483 | projections = random_state.normal(0., 1., [n_projections, d]) 484 | norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) 485 | projections = projections / norm 486 | return projections 487 | 488 | def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): 489 | from ot.lp import emd2_1d 490 | X_s = np.asanyarray(X_s) 491 | X_t = np.asanyarray(X_t) 492 | n = X_s.shape[0] 493 | m = X_t.shape[0] 494 | if X_s.shape[1] != X_t.shape[1]: 495 | raise ValueError( 496 | "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],X_t.shape[1])) 497 | if a is None: 498 | a = np.full(n, 1 / n) 499 | if b is None: 500 | b = np.full(m, 1 / m) 501 | d = X_s.shape[1] 502 | projections = get_random_projections(n_projections, d, seed) 503 | X_s_projections = np.dot(projections, X_s.T) 504 | X_t_projections = np.dot(projections, X_t.T) 505 | if log: 506 | projected_emd = np.empty(n_projections) 507 | else: 508 | projected_emd = None 509 | res = 0. 
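    # Illustrative aside, not part of the original script: each random unit
    # vector theta_l projects the samples to 1-D, where optimal transport has
    # a closed form (emd2_1d sorts the projected values). The sliced
    # Wasserstein distance is then estimated as
    #     SWD(X_s, X_t) ~ sqrt( (1/L) * sum_l W_2^2( <X_s,theta_l>, <X_t,theta_l> ) )
    # with L = n_projections, matching the loop and final square root below.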
510 |     for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
511 |         emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
512 |         if projected_emd is not None:
513 |             projected_emd[i] = emd
514 |         res += emd
515 |     res = (res / n_projections) ** 0.5
516 |     if log:
517 |         return res, {"projections": projections, "projected_emds": projected_emd}
518 |     return res
519 | 
520 | def get_fid(img_A,img_B):
521 |     import numpy as np
522 |     from numpy import cov
523 |     from scipy.linalg import sqrtm
524 |     mu1,sigma1=img_A.mean(axis=0),cov(img_A,rowvar=False)
525 |     mu2,sigma2=img_B.mean(axis=0),cov(img_B,rowvar=False)
526 |     ssdiff=np.sum((mu1-mu2)**2.0)
527 |     covmean=sqrtm(sigma1.dot(sigma2))
528 |     fid=ssdiff+np.trace(sigma1+sigma2-2.0*covmean)
529 |     return fid
530 | 
531 | def images_style_index(img_A:str,img_B:str):
532 |     imgA=np.asarray(Image.open(img_A))
533 |     imgB=np.asarray(Image.open(img_B))
534 |     fid_1=get_fid(imgA,imgB)
535 |     wd_1=wasserstein_distance(imgA.flatten(),imgB.flatten())
536 |     swd_1=sliced_wasserstein_distance(imgA,imgB)
537 |     print('Fid:',fid_1)
538 |     print('Wasserstein distance:',wd_1)
539 |     print('Sliced Wasserstein distance:',swd_1)
540 |     result={'Fid:':fid_1,
541 |             'Wasserstein distance:':wd_1,
542 |             'Sliced Wasserstein distance:':swd_1}
543 |     return result
544 | 
545 | def WD_SWD_calc(path):
546 |     WD=[]
547 |     SWD=[]
548 |     samples=os.listdir(path)
549 |     number=[]
550 |     for num in samples:
551 |         num=num[:-4]
552 |         number.append(num)
553 |     number=sorted(np.array(number).astype('int32'))   # image indices, sorted low to high
554 |     for i in range(len(number)-1):
555 |         img_forward=np.asarray(Image.open(path+str(number[i])+'.png'))
556 |         img_backward=np.asarray(Image.open(path+str(number[i+1])+'.png'))
557 |         wd=wasserstein_distance(img_forward.flatten(),img_backward.flatten())
558 |         WD.append(wd)
559 |         swd=sliced_wasserstein_distance(img_forward,img_backward)
560 |         SWD.append(swd)
561 |     return WD,SWD
562 | 
563 | def WD_SWD_distribution_plot(WD,SWD,save_path,dataset_name):
564 |     n=len(SWD)
565 |     X=np.arange(0,n,1)
566 |     plt.figure(figsize=(30,6))
567 |     plt.plot(X,WD,'b')
568 |     plt.fill_between(X, y1=0, y2=WD, facecolor='red', alpha=0.6)
569 |     plt.title('Wasserstein distance distribution',fontsize=35)
570 |     plt.grid()
571 |     plt.xticks(fontsize=30)
572 |     plt.yticks(fontsize=30)
573 |     plt.savefig(save_path+dataset_name+'_WD_distribution.png',dpi=100,bbox_inches='tight')
574 | 
575 |     plt.figure(figsize=(30,6))
576 |     plt.plot(X,SWD,'b')
577 |     plt.fill_between(X, y1=8.3, y2=SWD, facecolor='red', alpha=0.6)
578 |     plt.title('Sliced Wasserstein distance distribution',fontsize=35)
579 |     plt.xticks(fontsize=30)
580 |     plt.yticks(fontsize=30)
581 |     plt.grid()
582 |     plt.savefig(save_path+dataset_name+'_SWD_distribution.png',dpi=100,bbox_inches='tight')
583 | 
584 | def cross_cycle_dataset(original_path,out_folder,cross_number:list):
585 |     save_path=out_folder+'cross_datasets/'
586 | 
587 | 
588 |     original_train_A=original_path+'/train/A/'
589 |     original_train_B=original_path+'/train/B/'
590 |     original_val_A=original_path+'/val/A/'
591 |     original_val_B=original_path+'/val/B/'
592 |     # collect the train-set slice ordering
593 |     samples=os.listdir(original_train_A)
594 |     original_img_sequence=[]
595 |     for num in samples:
596 |         num=num[:-4]
597 |         original_img_sequence.append(num)
598 |     original_img_sequence=sorted(np.array(original_img_sequence).astype('int32'))   # original name sequence
599 |     # collect the val-set slice ordering
600 |     samples2=os.listdir(original_val_A)
601 |     original_img2_sequence=[]
602 |     for num in samples2:
603 |         num=num[:-4]
604 | 
def cross_cycle_dataset(original_path, out_folder, cross_number: list):
    save_path = out_folder + 'cross_datasets/'

    original_train_A = original_path + '/train/A/'
    original_train_B = original_path + '/train/B/'
    original_val_A = original_path + '/val/A/'
    original_val_B = original_path + '/val/B/'
    # collect the train slice ordering
    samples = os.listdir(original_train_A)
    original_img_sequence = []
    for num in samples:
        num = num[:-4]
        original_img_sequence.append(num)
    original_img_sequence = sorted(np.array(original_img_sequence).astype('int32'))  # original name sequence
    # collect the val slice ordering
    samples2 = os.listdir(original_val_A)
    original_img2_sequence = []
    for num in samples2:
        num = num[:-4]
        original_img2_sequence.append(num)
    original_img2_sequence = sorted(np.array(original_img2_sequence).astype('int32'))

    n = min(len(original_img_sequence), len(original_img2_sequence))
    if cross_number == []:
        for i in range(n):
            # train/A/: names rolled by i+1 so each A slice is paired with a shifted layer
            os.makedirs(save_path + 'dataset_' + str(i) + '/train/A/', exist_ok=True)
            new_sequence_trainA = list(np.roll(np.array(original_img_sequence), i + 1))  # adjusted name sequence
            new_path_trainA = save_path + 'dataset_' + str(i) + '/train/A/'
            for j in range(len(original_img_sequence)):
                image = cv2.imread(original_train_A + str(original_img_sequence[j]) + '.png')
                cv2.imwrite(new_path_trainA + str(new_sequence_trainA[j]) + '.png', image)
            # train/B/: names kept as-is
            os.makedirs(save_path + 'dataset_' + str(i) + '/train/B/', exist_ok=True)
            new_path_trainB = save_path + 'dataset_' + str(i) + '/train/B/'
            for j in range(len(original_img_sequence)):
                image = cv2.imread(original_train_B + str(original_img_sequence[j]) + '.png')
                cv2.imwrite(new_path_trainB + str(original_img_sequence[j]) + '.png', image)

            # val/A/
            os.makedirs(save_path + 'dataset_' + str(i) + '/val/A/', exist_ok=True)
            new_sequence_valA = list(np.roll(np.array(original_img2_sequence), i + 1))  # adjusted name sequence
            new_path_valA = save_path + 'dataset_' + str(i) + '/val/A/'
            for j in range(len(original_img2_sequence)):
                image = cv2.imread(original_val_A + str(original_img2_sequence[j]) + '.png')
                cv2.imwrite(new_path_valA + str(new_sequence_valA[j]) + '.png', image)
            # val/B/
            os.makedirs(save_path + 'dataset_' + str(i) + '/val/B/', exist_ok=True)
            new_path_valB = save_path + 'dataset_' + str(i) + '/val/B/'
            for j in range(len(original_img2_sequence)):
                image = cv2.imread(original_val_B + str(original_img2_sequence[j]) + '.png')
                cv2.imwrite(new_path_valB + str(original_img2_sequence[j]) + '.png', image)

    else:
        for number in cross_number:
            os.makedirs(save_path + 'dataset_' + str(number) + '/train/A/', exist_ok=True)
            new_sequence_trainA = list(np.roll(np.array(original_img_sequence), number))
            new_path_trainA = save_path + 'dataset_' + str(number) + '/train/A/'
            os.makedirs(save_path + 'dataset_' + str(number) + '/val/A/', exist_ok=True)
            new_sequence_valA = list(np.roll(np.array(original_img2_sequence), number))  # adjusted name sequence
            new_path_valA = save_path + 'dataset_' + str(number) + '/val/A/'
            for j in range(len(original_img_sequence)):
                image = cv2.imread(original_train_A + str(original_img_sequence[j]) + '.png')
                cv2.imwrite(new_path_trainA + str(new_sequence_trainA[j]) + '.png', image)
            for j in range(len(original_img2_sequence)):
                image = cv2.imread(original_val_A + str(original_img2_sequence[j]) + '.png')
                cv2.imwrite(new_path_valA + str(new_sequence_valA[j]) + '.png', image)

            os.makedirs(save_path + 'dataset_' + str(number) + '/train/B/', exist_ok=True)
            new_path_trainB = save_path + 'dataset_' + str(number) + '/train/B/'
            os.makedirs(save_path + 'dataset_' + str(number) + '/val/B/', exist_ok=True)
            new_path_valB = save_path + 'dataset_' + str(number) + '/val/B/'
            for j in range(len(original_img_sequence)):
                image = cv2.imread(original_train_B + str(original_img_sequence[j]) + '.png')
                cv2.imwrite(new_path_trainB + str(original_img_sequence[j]) + '.png', image)
            for j in range(len(original_img2_sequence)):
                image = cv2.imread(original_val_B + str(original_img2_sequence[j]) + '.png')  # fixed: previously read from train/B/
                cv2.imwrite(new_path_valB + str(original_img2_sequence[j]) + '.png', image)

    print('All datasets generated')
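# --- Added illustration (assumption-free, runnable on its own): the name shift
# that cross_cycle_dataset applies to the A-side images is a circular roll, e.g.
# np.roll([0, 1, 2, 3, 4], 1) -> [4, 0, 1, 2, 3], so dataset_i pairs each B
# slice with the A slice i+1 layers away, turning one aligned stack into a
# family of cross-layer translation tasks.
def _roll_shift_demo():
    seq = np.arange(5)
    for shift in (1, 2):
        print('shift', shift, '->', list(np.roll(seq, shift)))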
def testloader_result(test_loader, n_residual_blocks, G_AB_path, G_BA_path, test_result_save_path, channels, img_height, img_width):
    input_shape = (channels, img_height, img_width)
    G_AB = GeneratorResNet(input_shape, n_residual_blocks).cuda()
    G_BA = GeneratorResNet(input_shape, n_residual_blocks).cuda()
    G_AB.load_state_dict(torch.load(G_AB_path))
    G_BA.load_state_dict(torch.load(G_BA_path))

    cuda = torch.cuda.is_available()
    transforms_ = [
        transforms.Resize(int(img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((img_height, img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    test_loader = DataLoader(
        ImageDataset(test_loader, transforms_=transforms_, unaligned=True, mode="test"),
        batch_size=1,
        shuffle=True,
        num_workers=0,
    )
    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
    for i, batch in enumerate(test_loader):
        real_A = Variable(batch['A'].type(Tensor))
        real_B = Variable(batch['B'].type(Tensor))

        fake_B = G_AB(real_A).cpu().detach().numpy()
        fake_A = G_BA(real_B).cpu().detach().numpy()
        plt.figure(figsize=(16, 4))
        plt.subplot(1, 4, 1)
        plt.title('G_AB')
        plt.imshow(fake_B[0, 0, :, :], cmap=plt.cm.gray)
        plt.subplot(1, 4, 2)
        plt.title('G_BA')
        plt.imshow(fake_A[0, 0, :, :], cmap=plt.cm.gray)
        plt.subplot(1, 4, 3)
        plt.title('GT_A')
        plt.imshow(batch['A'][0, 0, :, :].type(Tensor).cpu().detach().numpy(), cmap=plt.cm.gray)
        plt.subplot(1, 4, 4)
        plt.title('GT_B')
        plt.imshow(batch['B'][0, 0, :, :].type(Tensor).cpu().detach().numpy(), cmap=plt.cm.gray)
        # indexed so each test batch gets its own figure instead of overwriting one file
        plt.savefig(test_result_save_path + '/generate_result_%d.png' % i, dpi=100, bbox_inches='tight')
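# --- Added usage note: despite its name, the first argument of
# testloader_result is the dataset root handed to ImageDataset(mode="test"),
# not a DataLoader. All values below are placeholders, not settings used
# elsewhere in this repo:
# testloader_result(test_loader=out_folder + 'cross_datasets/dataset_0',
#                   n_residual_blocks=9,
#                   G_AB_path=out_folder + '3D_Reconstruction/cross_models/dataset_0/G_AB_199.pth',
#                   G_BA_path=out_folder + '3D_Reconstruction/cross_models/dataset_0/G_BA_199.pth',
#                   test_result_save_path=out_folder + 'test_result',
#                   channels=1, img_height=128, img_width=128)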
def Cross_Cycle_GAN(out_folder,
                    cross_number,
                    checkpoint_interval,
                    sample_interval,
                    n_epochs,
                    batch_size,
                    lr,
                    decay_epoch,
                    n_residual_blocks,
                    channels,
                    img_height,
                    img_width,
                    pre_trained: bool,
                    trained_epoch=0):
    lambda_cyc = 10.0  # weight of the cycle-consistency loss
    lambda_id = 5.0    # weight of the identity loss
    b1 = 0.5
    b2 = 0.999
    # Create sample and checkpoint directories
    save_path = out_folder + '3D_Reconstruction/'

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid, cross_images_path + "/%s.png" % batches_done, normalize=False)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    cuda = torch.cuda.is_available()
    input_shape = (channels, img_height, img_width)
    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    for number in cross_number:
        dataset_path = out_folder + 'cross_datasets/dataset_%s' % number
        os.makedirs(save_path + 'cross_images/dataset_%s' % number, exist_ok=True)
        os.makedirs(save_path + 'cross_models/dataset_%s' % number, exist_ok=True)
        cross_images_path = save_path + 'cross_images/dataset_%s' % number
        cross_saved_models_path = save_path + 'cross_models/dataset_%s' % number

        if cuda:
            G_AB = G_AB.cuda()
            G_BA = G_BA.cuda()
            D_A = D_A.cuda()
            D_B = D_B.cuda()
            criterion_GAN.cuda()
            criterion_cycle.cuda()
            criterion_identity.cuda()
        if pre_trained and trained_epoch != 0:
            # resume from saved checkpoints; note the '/' separator, since the
            # model directory path carries no trailing slash
            G_AB.load_state_dict(torch.load(cross_saved_models_path + "/G_AB_%d.pth" % trained_epoch))
            G_BA.load_state_dict(torch.load(cross_saved_models_path + "/G_BA_%d.pth" % trained_epoch))
            D_A.load_state_dict(torch.load(cross_saved_models_path + "/D_A_%d.pth" % trained_epoch))
            D_B.load_state_dict(torch.load(cross_saved_models_path + "/D_B_%d.pth" % trained_epoch))
        else:
            G_AB.apply(weights_init_normal)
            G_BA.apply(weights_init_normal)
            D_A.apply(weights_init_normal)
            D_B.apply(weights_init_normal)
        # Optimizers
        optimizer_G = torch.optim.Adam(
            itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
        )
        optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
        optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))
        # Learning rate update schedulers
        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step
        )
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step
        )
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B, lr_lambda=LambdaLR(n_epochs, trained_epoch, decay_epoch).step
        )
        Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
        # Buffers of previously generated samples
        fake_A_buffer = ReplayBuffer()
        fake_B_buffer = ReplayBuffer()
        # Image transformations
        transforms_ = [
            transforms.Resize(int(img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((img_height, img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        # Training data loader
        dataloader = DataLoader(
            ImageDataset(dataset_path, transforms_=transforms_, unaligned=True),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
        )
        # Validation data loader
        val_dataloader = DataLoader(
            ImageDataset(dataset_path, transforms_=transforms_, unaligned=True, mode="val"),
            batch_size=5,
            shuffle=True,
            num_workers=0,
        )
        # ----------
        #  Training
        # ----------
        prev_time = time.time()
        for epoch in range(trained_epoch, n_epochs):
            for i, batch in enumerate(dataloader):
                # Set model input
                f = open(save_path + 'dataset_%s_cross_log.txt' % number, 'a')
                real_A = Variable(batch["A"].type(Tensor))
                real_B = Variable(batch["B"].type(Tensor))
                # Adversarial ground truths
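                # Added note: `valid` and `fake` below are LSGAN-style patch targets;
                # D_A.output_shape is the PatchGAN output grid, so criterion_GAN (MSE)
                # pushes every patch of the discriminator map toward 1 on real slices
                # and 0 on generated ones.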
                valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
                fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
                # ------------------
                #  Train Generators
                # ------------------
                G_AB.train()
                G_BA.train()
                optimizer_G.zero_grad()
                # Identity loss
                loss_id_A = criterion_identity(G_BA(real_A), real_A)
                loss_id_B = criterion_identity(G_AB(real_B), real_B)
                loss_identity = (loss_id_A + loss_id_B) / 2

                # Monitor WD/SWD between identity-mapped and real slices
                calc1 = G_BA(real_A).cpu().detach().numpy()[0, 0, :, :]
                calc2 = real_A.cpu().detach().numpy()[0, 0, :, :]
                wd_1 = wasserstein_distance(calc1.flatten(), calc2.flatten())
                calc3 = G_AB(real_B).cpu().detach().numpy()[0, 0, :, :]
                calc4 = real_B.cpu().detach().numpy()[0, 0, :, :]
                wd_2 = wasserstein_distance(calc3.flatten(), calc4.flatten())
                wd = (wd_1 + wd_2) / 2

                swd_1 = sliced_wasserstein_distance(calc1, calc2)
                swd_2 = sliced_wasserstein_distance(calc3, calc4)
                swd = (swd_1 + swd_2) / 2

                # GAN loss
                fake_B = G_AB(real_A)
                loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
                fake_A = G_BA(real_B)
                loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
                loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
                # Cycle loss
                recov_A = G_BA(fake_B)
                loss_cycle_A = criterion_cycle(recov_A, real_A)
                recov_B = G_AB(fake_A)
                loss_cycle_B = criterion_cycle(recov_B, real_B)
                loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
                # Total loss
                loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
                loss_G.backward()
                optimizer_G.step()
                # -----------------------
                #  Train Discriminator A
                # -----------------------
                optimizer_D_A.zero_grad()
                # Real loss
                loss_real = criterion_GAN(D_A(real_A), valid)
                # Fake loss (on batch of previously generated samples)
                fake_A_ = fake_A_buffer.push_and_pop(fake_A)
                loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
                # Total loss
                loss_D_A = (loss_real + loss_fake) / 2
                loss_D_A.backward()
                optimizer_D_A.step()
                # -----------------------
                #  Train Discriminator B
                # -----------------------
                optimizer_D_B.zero_grad()
                # Real loss
                loss_real = criterion_GAN(D_B(real_B), valid)
                # Fake loss (on batch of previously generated samples)
                fake_B_ = fake_B_buffer.push_and_pop(fake_B)
                loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
                # Total loss
                loss_D_B = (loss_real + loss_fake) / 2
                loss_D_B.backward()
                optimizer_D_B.step()
                loss_D = (loss_D_A + loss_D_B) / 2
                # --------------
                #  Log Progress
                # --------------
                # Determine approximate time left
                batches_done = epoch * len(dataloader) + i
                batches_left = n_epochs * len(dataloader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                # Build the log line once, print it, and append it to the run log
                log_line = (
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s [WD: %f, SWD: %f]"
                    % (
                        epoch,
                        n_epochs,
                        i,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_GAN.item(),
                        loss_cycle.item(),
                        loss_identity.item(),
                        time_left,
                        wd,
                        swd,
                    )
                )
                print(log_line)
                f.write(log_line + '\n')
                f.close()
                # If at sample interval, save a generated image grid
                if batches_done % sample_interval == 0:
                    imgs = next(iter(val_dataloader))
                    G_AB.eval()
                    G_BA.eval()
                    real_A = Variable(imgs["A"].type(Tensor))
                    fake_B = G_AB(real_A)
                    real_B = Variable(imgs["B"].type(Tensor))
                    fake_A = G_BA(real_B)
                    real_A = make_grid(real_A, nrow=5, normalize=True)
                    real_B = make_grid(real_B, nrow=5, normalize=True)
                    fake_A = make_grid(fake_A, nrow=5, normalize=True)
                    fake_B = make_grid(fake_B, nrow=5, normalize=True)
                    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
                    save_image(image_grid, cross_images_path + '/%s.png' % batches_done, normalize=False)
            # Update learning rates
            lr_scheduler_G.step()
            lr_scheduler_D_A.step()
            lr_scheduler_D_B.step()
            if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
                # Save model checkpoints (with the '/' separator, matching the resume paths)
                torch.save(G_AB.state_dict(), cross_saved_models_path + "/G_AB_%d.pth" % epoch)
                torch.save(G_BA.state_dict(), cross_saved_models_path + "/G_BA_%d.pth" % epoch)
                torch.save(D_A.state_dict(), cross_saved_models_path + "/D_A_%d.pth" % epoch)
                torch.save(D_B.state_dict(), cross_saved_models_path + "/D_B_%d.pth" % epoch)
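# --- Added sketch of the per-layer acceptance rule that SWD_cross_cycle (below)
# implements inline: keep regenerating a candidate slice until its SWD to the
# previous layer lands within `tol` of the reference profile, or give up after
# `max_tries`. `generate` is a hypothetical zero-argument callable standing in
# for one G_BA forward pass that returns a 2D numpy array.
def accept_layer(generate, prev_layer, target_swd, tol=2.0, max_tries=100):
    candidate, swd = None, None
    for _ in range(max_tries):
        candidate = generate()
        swd = sliced_wasserstein_distance(candidate, prev_layer)
        if abs(swd - target_swd) <= tol:
            break  # candidate matches the reference SWD closely enough
    return candidate, swd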
def SWD_cross_cycle(out_folder,
                    n_epochs,
                    channels,
                    img_height,
                    img_width,
                    n_residual_blocks,
                    cross_number,
                    berea_calc_WD_SWD_datasetpath,
                    test_loader,
                    ):
    models_path = out_folder + '3D_Reconstruction/cross_models/'
    save_path = out_folder + '3D_Reconstruction/cross_generate/'
    os.makedirs(save_path, exist_ok=True)

    models = cross_number

    WD_berea, SWD_berea = WD_SWD_calc(berea_calc_WD_SWD_datasetpath)
    layer = len(SWD_berea)  # number of layer-to-layer gaps in the reference stack, e.g. 399
    SWD_new = []
    input_shape = (channels, img_height, img_width)
    G_AB = GeneratorResNet(input_shape, n_residual_blocks).cuda()
    G_BA = GeneratorResNet(input_shape, n_residual_blocks).cuda()
    cuda = torch.cuda.is_available()
    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    for l in range(layer):
        swd = 0
        if l == 0:
            for i in models:
                if i == 1:
                    break
                else:
                    G_BA.load_state_dict(torch.load(models_path + '/dataset_%s/G_AB_%s.pth' % (i, n_epochs - 1)))
                    for j, batch in enumerate(test_loader):
                        real_A = Variable(batch['A'].type(Tensor))
                        real_B = Variable(batch['B'].type(Tensor))
                        fake_A = G_BA(real_B)
                        real_AA = real_A.cpu().detach().numpy()  # used to compute the SWD
                        fake_AA = fake_A.cpu().detach().numpy()  # used to compute the SWD
                        swd = sliced_wasserstein_distance(fake_AA[0, 0, :, :], real_AA[0, 0, :, :])
                        SWD_new.append(swd)
                        save_image(fake_A, save_path + str(l + 1) + '_layer.png', normalize=False)
                        break
        else:
            for i in models:
                ts = 0
                G_BA.load_state_dict(torch.load(models_path + '/dataset_%s/G_AB_%s.pth' % (i, n_epochs - 1)))
                # regenerate until the SWD to the previous layer is within a
                # tolerance of 2 of the reference profile, or after 100 tries
                while np.abs(swd - SWD_berea[l]) > 2 and ts < 100:
                    for j, batch in enumerate(test_loader):
                        real_A = Variable(batch['A'].type(Tensor))
                        real_B = Variable(batch['B'].type(Tensor))
                        fake_A = G_BA(real_B)
                        real_AA = real_A.cpu().detach().numpy()  # used to compute the SWD
                        fake_AA = fake_A.cpu().detach().numpy()  # used to compute the SWD
                        save_image(fake_A, save_path + str(l + 1) + '_layer.png', normalize=False)
                        last_layer = np.asarray(Image.open(save_path + str(l) + '_layer.png'))
                        this_layer = np.asarray(Image.open(save_path + str(l + 1) + '_layer.png'))
                        swd = sliced_wasserstein_distance(this_layer[0, :, :], last_layer[0, :, :])
                        ts += 1
                SWD_new.append(swd)
                break
    return SWD_new

'''
14. Cycle_GAN generation with the promoted translation style
Cycle_GAN(out_folder=out_folder,
          dataset_name=dataset_name,
          dataset_path=dataset_path,
          checkpoint_interval=checkpoint_interval,
          sample_interval=sample_interval,
          n_epochs=n_epochs,
          batch_size=batch_size,
          lr=learning_rate,
          decay_epoch=decay_epoch,
          n_residual_blocks=Resnet_blocks,
          channels=channels,
          img_height=img_height,
          img_width=img_width,
          pre_trained=pre_trained,
          trained_epoch=trained_epoch
          )
'''
'''
15. WD and SWD distribution calculation and plotting
WD,SWD=WD_SWD_calc(berea_calc_WD_SWD_datasetpath)
WD_SWD_distribution_plot(WD,SWD,save_path=out_folder,dataset_name='berea')
'''
'''
16. cross-domain dataset creation
cross_cycle_dataset(original_path=dataset_path,
                    out_folder=out_folder,
                    cross_number=cross_number)
'''
'''
17. cross dataset model training  # train the models one at a time and clear variables between runs
cross train:
Cross_Cycle_GAN(out_folder=out_folder,
                cross_number=cross_number,
                checkpoint_interval=checkpoint_interval,
                sample_interval=sample_interval,
                n_epochs=n_epochs,
                batch_size=batch_size,
                lr=learning_rate,
                decay_epoch=decay_epoch,
                n_residual_blocks=Resnet_blocks,
                channels=channels,
                img_height=img_height,
                img_width=img_width,
                pre_trained=pre_trained,
                trained_epoch=trained_epoch)
testloader visualization:
testloader_result(test_loader=test_loader,
                  n_residual_blocks=Resnet_blocks,
                  G_AB_path=G_AB_path,
                  G_BA_path=G_BA_path,
                  test_result_save_path=test_result_save_path,
                  channels=channels,
                  img_height=img_height,
                  img_width=img_width)
'''
'''
18. SWD-guided Cycle-GAN 3D reconstruction
Generate_SWD=SWD_cross_cycle(out_folder=out_folder,
                             n_epochs=n_epochs,
                             channels=channels,
                             img_height=img_height,
                             img_width=img_width,
                             n_residual_blocks=Resnet_blocks,
                             cross_number=cross_number,
                             berea_calc_WD_SWD_datasetpath=berea_calc_WD_SWD_datasetpath,
                             test_loader=test_loader)
'''
--------------------------------------------------------------------------------