├── README.md ├── checkpoints ├── gaussian_pretrained │ ├── latest_net_G.pth │ └── opt.txt ├── poisson_pretrained │ ├── latest_net_G.pth │ └── opt.txt └── speckle_pretrained │ ├── latest_net_G.pth │ └── opt.txt ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── aligned_dataset.cpython-36.pyc │ ├── aligned_dataset.cpython-37.pyc │ ├── base_data_loader.cpython-36.pyc │ ├── base_data_loader.cpython-37.pyc │ ├── base_dataset.cpython-36.pyc │ ├── base_dataset.cpython-37.pyc │ ├── image_folder.cpython-36.pyc │ ├── image_folder.cpython-37.pyc │ ├── single_dataset.cpython-36.pyc │ └── unaligned_dataset.cpython-36.pyc ├── aligned_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── image_folder.py └── single_dataset.py ├── datasets ├── 0.txt ├── test │ └── 0.txt └── train │ └── 0.txt ├── docs ├── datasets.md ├── qa.md └── tips.md ├── environment.yml ├── imgs └── 0.txt ├── matlab_code_for_synthesizing_gaussian_noise.m ├── matlab_code_for_synthesizing_speckle_noise.m ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_model.cpython-36.pyc │ ├── base_model.cpython-37.pyc │ ├── cycle_gan_model.cpython-36.pyc │ ├── dbpn.cpython-36.pyc │ ├── dbpn.cpython-37.pyc │ ├── denoise_model.cpython-37.pyc │ ├── derain_model.cpython-37.pyc │ ├── networks.cpython-36.pyc │ ├── networks.cpython-37.pyc │ ├── pix2pix_model.cpython-36.pyc │ ├── pix2pix_model.cpython-37.pyc │ ├── test_model.cpython-36.pyc │ ├── unet_parts.cpython-36.pyc │ └── unet_parts.cpython-37.pyc ├── base_model.py ├── denoise_model.py ├── networks.py ├── pytorch-ssim-master │ ├── LICENSE.txt │ ├── README.md │ ├── einstein.png │ ├── max_ssim.gif │ ├── max_ssim.py │ ├── setup.cfg │ └── setup.py ├── pytorch_ssim │ ├── __init__.py │ ├── __init__.pyc │ └── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-37.pyc ├── ssim2 │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-36.pyc ├── test_model.py └── unet_parts.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_options.cpython-36.pyc │ ├── base_options.cpython-37.pyc │ ├── test_options.cpython-36.pyc │ ├── test_options.cpython-37.pyc │ ├── train_options.cpython-36.pyc │ └── train_options.cpython-37.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── psnr_and_ssim.py ├── python_code_for_synthesizing_poisson_noise.py ├── requirements.txt ├── results └── 0.txt ├── test.py ├── train.py └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── html.cpython-36.pyc ├── html.cpython-37.pyc ├── image_pool.cpython-36.pyc ├── image_pool.cpython-37.pyc ├── util.cpython-36.pyc ├── util.cpython-37.pyc ├── visualizer.cpython-36.pyc └── visualizer.cpython-37.pyc ├── get_data.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | # Noise2Grad_Pytorch_code 2 | Pytorch implementation for "Noise2Grad: Extract Image Noise to Denoise" (IJCAI-2021) 3 | 4 | Paper Link: [Link](https://www.ijcai.org/proceedings/2021/115) 5 | 6 | 7 | Training 8 | ================================= 9 | Download the training datasets from [Google Drive](https://drive.google.com/drive/folders/1xRJLe8D3rhUWssnZczSEy3-LqPSRZ0kN). 10 | Unzip "ground_truth.zip" and "reference_clean_image.zip" and put them in "./datasets/train/". 
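Noisy training images are synthesized next, using one of the scripts named below. For orientation, here is a minimal, hedged Python sketch of Poisson-noise synthesis — an illustration only, not the repository's 'python_code_for_synthesizing_poisson_noise.py' (which is what should actually be run); the peak range is a hypothetical noise-level choice, analogous to the random sigma range in the Gaussian MATLAB script:

```python
# Illustrative sketch only -- run python_code_for_synthesizing_poisson_noise.py
# for the repository's actual synthesis. The peak range below is hypothetical.
import os
import cv2
import numpy as np

src_dir = './datasets/train/ground_truth/'
dst_dir = './datasets/train/noisy_train/'
os.makedirs(dst_dir, exist_ok=True)

for fname in os.listdir(src_dir):
    if not fname.endswith('.png'):
        continue
    img = cv2.imread(os.path.join(src_dir, fname)).astype(np.float32) / 255.0
    peak = np.random.uniform(30.0, 300.0)  # smaller peak => stronger noise (hypothetical range)
    noisy = np.random.poisson(img * peak).astype(np.float32) / peak
    noisy = np.clip(noisy * 255.0, 0.0, 255.0).astype(np.uint8)
    cv2.imwrite(os.path.join(dst_dir, fname), noisy)
```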
Use 'matlab_code_for_synthesizing_gaussian_noise.m', 'matlab_code_for_synthesizing_speckle_noise.m', or 'python_code_for_synthesizing_poisson_noise.py' to synthesize noisy images, and then put them into "./datasets/train/noisy_train/". 11 | 12 | 13 | - Train the model: 14 | 15 | *python train.py --dataroot ./datasets/train/noisy_train/ --name new --model denoise* 16 | 17 | 18 | Testing 19 | ======= 20 | 21 | Download the testing dataset from [Google Drive](https://drive.google.com/drive/folders/1xRJLe8D3rhUWssnZczSEy3-LqPSRZ0kN). 22 | 23 | Unzip "ground_truth.zip" in './datasets/test/'. Use 'matlab_code_for_synthesizing_gaussian_noise.m', 'matlab_code_for_synthesizing_speckle_noise.m', or 'python_code_for_synthesizing_poisson_noise.py' to synthesize noisy images, and then put them into "./datasets/test/noisy_test/". 24 | 25 | 26 | - Test: 27 | 28 | *python test.py --dataroot ./datasets/test/noisy_test/ --name new --model denoise* 29 | 30 | - Test with our pretrained models: 31 | 32 | #gaussian pretrained 33 | 34 | *python test.py --dataroot ./datasets/test/noisy_test/ --name gaussian_pretrained --model denoise* 35 | 36 | #poisson pretrained 37 | 38 | *python test.py --dataroot ./datasets/test/noisy_test/ --name poisson_pretrained --model denoise* 39 | 40 | #speckle pretrained 41 | 42 | *python test.py --dataroot ./datasets/test/noisy_test/ --name speckle_pretrained --model denoise* 43 | 44 | After testing, results are saved in './results/'. 45 | 46 | Run "psnr_and_ssim.py" to calculate PSNR and SSIM. 47 | 48 | 49 | -------------------------------------------------------------------------------- /checkpoints/gaussian_pretrained/latest_net_G.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/checkpoints/gaussian_pretrained/latest_net_G.pth -------------------------------------------------------------------------------- /checkpoints/gaussian_pretrained/opt.txt: -------------------------------------------------------------------------------- 1 | ----------------- Options --------------- 2 | aspect_ratio: 1.0 3 | batch_size: 1 4 | checkpoints_dir: ./checkpoints 5 | dataroot: ./datasets/test/noisy_test/ [default: None] 6 | dataset_mode: aligned 7 | display_winsize: 256 8 | epoch: latest 9 | eval: False 10 | fineSize: 256 11 | gpu_ids: 0 12 | init_gain: 0.02 13 | init_type: normal 14 | input_nc: 3 15 | isTrain: False [default: None] 16 | loadSize: 256 17 | load_iter: 0 [default: 0] 18 | max_dataset_size: inf 19 | model: denoise [default: test] 20 | n_layers_D: 3 21 | name: gaussian_pretrained [default: experiment_name] 22 | ndf: 16 23 | netD: basic 24 | netG: unet_256 25 | ngf: 64 26 | no_dropout: False 27 | no_flip: False 28 | norm: batch 29 | ntest: inf 30 | num_test: 50000 31 | num_threads: 4 32 | output_nc: 3 33 | phase: test 34 | resize_or_crop: resize_and_crop 35 | results_dir: ./results/ 36 | serial_batches: False 37 | suffix: 38 | verbose: False 39 | ----------------- End ------------------- 40 | -------------------------------------------------------------------------------- /checkpoints/poisson_pretrained/latest_net_G.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/checkpoints/poisson_pretrained/latest_net_G.pth
-------------------------------------------------------------------------------- /checkpoints/poisson_pretrained/opt.txt: -------------------------------------------------------------------------------- 1 | ----------------- Options --------------- 2 | aspect_ratio: 1.0 3 | batch_size: 1 4 | checkpoints_dir: ./checkpoints 5 | dataroot: ./datasets/test/noisy_test/ [default: None] 6 | dataset_mode: aligned 7 | display_winsize: 256 8 | epoch: latest 9 | eval: False 10 | fineSize: 256 11 | gpu_ids: 0 12 | init_gain: 0.02 13 | init_type: normal 14 | input_nc: 3 15 | isTrain: False [default: None] 16 | loadSize: 256 17 | load_iter: 0 [default: 0] 18 | max_dataset_size: inf 19 | model: denoise [default: test] 20 | n_layers_D: 3 21 | name: poisson_pretrained [default: experiment_name] 22 | ndf: 16 23 | netD: basic 24 | netG: unet_256 25 | ngf: 64 26 | no_dropout: False 27 | no_flip: False 28 | norm: batch 29 | ntest: inf 30 | num_test: 50000 31 | num_threads: 4 32 | output_nc: 3 33 | phase: test 34 | resize_or_crop: resize_and_crop 35 | results_dir: ./results/ 36 | serial_batches: False 37 | suffix: 38 | verbose: False 39 | ----------------- End ------------------- 40 | -------------------------------------------------------------------------------- /checkpoints/speckle_pretrained/latest_net_G.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/checkpoints/speckle_pretrained/latest_net_G.pth -------------------------------------------------------------------------------- /checkpoints/speckle_pretrained/opt.txt: -------------------------------------------------------------------------------- 1 | ----------------- Options --------------- 2 | aspect_ratio: 1.0 3 | batch_size: 1 4 | checkpoints_dir: ./checkpoints 5 | dataroot: ./datasets/test/noisy_test/ [default: None] 6 | dataset_mode: aligned 7 | display_winsize: 256 8 | epoch: latest 9 | eval: False 10 | fineSize: 256 11 | gpu_ids: 0 12 | init_gain: 0.02 13 | init_type: normal 14 | input_nc: 3 15 | isTrain: False [default: None] 16 | loadSize: 256 17 | load_iter: 0 [default: 0] 18 | max_dataset_size: inf 19 | model: denoise [default: test] 20 | n_layers_D: 3 21 | name: speckle_pretrained [default: experiment_name] 22 | ndf: 16 23 | netD: basic 24 | netG: unet_256 25 | ngf: 64 26 | no_dropout: False 27 | no_flip: False 28 | norm: batch 29 | ntest: inf 30 | num_test: 50000 31 | num_threads: 4 32 | output_nc: 3 33 | phase: test 34 | resize_or_crop: resize_and_crop 35 | results_dir: ./results/ 36 | serial_batches: False 37 | suffix: 38 | verbose: False 39 | ----------------- End ------------------- 40 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | from data.base_dataset import BaseDataset 5 | 6 | 7 | def find_dataset_using_name(dataset_name): 8 | # Given the option --dataset_mode [datasetname], 9 | # the file "data/datasetname_dataset.py" 10 | # will be imported. 11 | dataset_filename = "data." + dataset_name + "_dataset" 12 | datasetlib = importlib.import_module(dataset_filename) 13 | 14 | # In the file, the class called DatasetNameDataset() will 15 | # be instantiated. It has to be a subclass of BaseDataset, 16 | # and it is case-insensitive. 
17 | dataset = None 18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 19 | for name, cls in datasetlib.__dict__.items(): 20 | if name.lower() == target_dataset_name.lower() \ 21 | and issubclass(cls, BaseDataset): 22 | dataset = cls 23 | 24 | if dataset is None: 25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 26 | exit(0) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataset(opt): 37 | dataset = find_dataset_using_name(opt.dataset_mode) 38 | instance = dataset() 39 | instance.initialize(opt) 40 | print("dataset [%s] was created" % (instance.name())) 41 | return instance 42 | 43 | 44 | def CreateDataLoader(opt): 45 | data_loader = CustomDatasetDataLoader() 46 | data_loader.initialize(opt) 47 | return data_loader 48 | 49 | 50 | # Wrapper class of Dataset class that performs 51 | # multi-threaded data loading 52 | class CustomDatasetDataLoader(BaseDataLoader): 53 | def name(self): 54 | return 'CustomDatasetDataLoader' 55 | 56 | def initialize(self, opt): 57 | BaseDataLoader.initialize(self, opt) 58 | self.dataset = create_dataset(opt) 59 | self.dataloader = torch.utils.data.DataLoader( 60 | self.dataset, 61 | batch_size=opt.batch_size, 62 | shuffle=not opt.serial_batches, 63 | num_workers=int(opt.num_threads)) 64 | 65 | def load_data(self): 66 | return self 67 | 68 | def __len__(self): 69 | return min(len(self.dataset), self.opt.max_dataset_size) 70 | 71 | def __iter__(self): 72 | for i, data in enumerate(self.dataloader): 73 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 74 | break 75 | yield data 76 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/aligned_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/aligned_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/aligned_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/aligned_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/base_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/base_data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/single_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/single_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/unaligned_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/data/__pycache__/unaligned_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | import cv2 8 | 9 | 10 | 11 | 12 | class AlignedDataset(BaseDataset): 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | self.opt = opt 19 | self.istrain = opt.isTrain 20 | self.root = opt.dataroot 21 | self.dir_X = os.path.join(opt.dataroot) 22 | self.X_paths = sorted(make_dataset(self.dir_X)) 23 | 24 | 25 | 26 | 
def __getitem__(self, index): 27 | if self.istrain: 28 | X_path = self.X_paths[index] 29 | X = cv2.imread(X_path) 30 | (h, w, n) = X.shape 31 | 32 | width = 128  # random crop size 33 | 34 | h_off = random.randint(0, h - width) 35 | w_off = random.randint(0, w - width) 36 | 37 | X = X[h_off:h_off + width, w_off:w_off + width] 38 | 39 | X = cv2.cvtColor(X, cv2.COLOR_BGR2RGB) 40 | 41 | rr = random.randint(0, 3)  # random flip augmentation 42 | 43 | if rr == 1: 44 | X = cv2.flip(X, 0) 45 | elif rr == 2: 46 | X = cv2.flip(X, 1) 47 | elif rr == 3: 48 | X = cv2.flip(X, -1) 49 | else: 50 | pass 51 | 52 | ind = random.randint(0, 5000)  # pick a random unpaired reference clean image 53 | Y = cv2.imread('./datasets/train/reference_clean_image/' + str(ind) + '.png') 54 | 55 | (h3, w3, _) = Y.shape 56 | h3_off = random.randint(0, h3 - width) 57 | w3_off = random.randint(0, w3 - width) 58 | 59 | Y = Y[h3_off:h3_off + width, w3_off:w3_off + width] 60 | Y = cv2.cvtColor(Y, cv2.COLOR_BGR2RGB) 61 | 62 | X = transforms.ToTensor()(X) 63 | 64 | Y = transforms.ToTensor()(Y) 65 | 66 | return {'X': X, 'Y': Y, 'X_paths': X_path} 67 | else: 68 | 69 | X_path = self.X_paths[index] 70 | X = cv2.imread(X_path) 71 | X = cv2.cvtColor(X, cv2.COLOR_BGR2RGB) 72 | 73 | X = transforms.ToTensor()(X) 74 | 75 | return {'X': X, 'X_paths': X_path} 76 | 77 | 78 | 79 | def __len__(self): 80 | return len(self.X_paths) 81 | 82 | def name(self): 83 | return 'AlignedDataset' 84 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(self): 10 | return None 11 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.loadSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.'
% opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | 54 | # just modify the width and height to be multiple of 4 55 | def __adjust(img): 56 | ow, oh = img.size 57 | 58 | # the size needs to be a multiple of this number, 59 | # because going through generator network may change img size 60 | # and eventually cause size mismatch error 61 | mult = 4 62 | if ow % mult == 0 and oh % mult == 0: 63 | return img 64 | w = (ow - 1) // mult 65 | w = (w + 1) * mult 66 | h = (oh - 1) // mult 67 | h = (h + 1) * mult 68 | 69 | if ow != w or oh != h: 70 | __print_size_warning(ow, oh, w, h) 71 | 72 | return img.resize((w, h), Image.BICUBIC) 73 | 74 | 75 | def __scale_width(img, target_width): 76 | ow, oh = img.size 77 | 78 | # the size needs to be a multiple of this number, 79 | # because going through generator network may change img size 80 | # and eventually cause size mismatch error 81 | mult = 4 82 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 83 | if (ow == target_width and oh % mult == 0): 84 | return img 85 | w = target_width 86 | target_height = int(target_width * oh / ow) 87 | m = (target_height - 1) // mult 88 | h = (m + 1) * mult 89 | 90 | if target_height != h: 91 | __print_size_warning(target_width, target_height, w, h) 92 | 93 | return img.resize((w, h), Image.BICUBIC) 94 | 95 | 96 | def __print_size_warning(ow, oh, w, h): 97 | if not hasattr(__print_size_warning, 'has_printed'): 98 | print("The image size needs to be a multiple of 4. " 99 | "The loaded image size was (%d, %d), so it was adjusted to " 100 | "(%d, %d). 
This adjustment will be done to all images " 101 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 102 | __print_size_warning.has_printed = True 103 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class SingleDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | if input_nc == 1: # RGB to gray 33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 34 | A = tmp.unsqueeze(0) 35 | 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | return len(self.A_paths) 40 | 41 | def name(self): 42 | return 'SingleImageDataset' 43 | -------------------------------------------------------------------------------- /datasets/0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/datasets/0.txt -------------------------------------------------------------------------------- /datasets/test/0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/datasets/test/0.txt -------------------------------------------------------------------------------- /datasets/train/0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/datasets/train/0.txt -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### CycleGAN Datasets 4 | Download the CycleGAN datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data. 5 | ```bash 6 | bash ./datasets/download_cyclegan_dataset.sh dataset_name 7 | ``` 8 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)] 9 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)] 10 | - `maps`: 1096 training images scraped from Google Maps. 11 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org) using the keywords `wild horse` and `zebra`. 12 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org) using the keywords `apple` and `navel orange`. 13 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using the Flickr API. See more details in our paper. 14 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853. 15 | - `iphone2dslr_flower`: Both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper. 16 | 17 | To train a model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and B. You can test your model on your training set by setting `--phase train` in `test.py`. You can also create subdirectories `testA` and `testB` if you have test data. 18 | 19 | You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content.
For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. 20 | 21 | ### pix2pix datasets 22 | Download the pix2pix datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data. 23 | ```bash 24 | bash ./datasets/download_pix2pix_dataset.sh dataset_name 25 | ``` 26 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)] 27 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)] 28 | - `maps`: 1096 training images scraped from Google Maps. 29 | - `edges2shoes`: 50k training images from the [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)] 30 | - `edges2handbags`: 137K Amazon handbag images from the [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)] 31 | - `night2day`: around 20K natural scene images from the [Transient Attributes dataset](http://transattr.cs.brown.edu/) [[Citation](datasets/bibtex/transattr.tex)]. To train a `day2night` pix2pix model, you need to add `--direction BtoA`. 32 | 33 | We provide a Python script to generate pix2pix training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A: 34 | 35 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for other data splits (`val`, `test`, etc). 36 | 37 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`. 38 | 39 | Once the data is formatted this way, call: 40 | ```bash 41 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 42 | ``` 43 | 44 | This will combine each pair of images (A,B) into a single image file, ready for training. 45 | -------------------------------------------------------------------------------- /docs/qa.md: -------------------------------------------------------------------------------- 1 | ## Frequently Asked Questions 2 | Before you post a new question, please first look at the following Q & A and existing GitHub issues. You may also want to read [Training/Test tips](docs/tips.md) for more suggestions. 3 | 4 | #### Connection Error: HTTPConnectionPool ([#230](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/230), [#24](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/24), [#38](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/38)) 5 | Similar error messages include “Failed to establish a new connection/Connection refused”.
6 | 7 | Please start the visdom server before starting the training: 8 | ```bash 9 | python -m visdom.server 10 | ``` 11 | To install visdom, you can use the following command: 12 | ```bash 13 | pip install visdom 14 | ``` 15 | You can also disable visdom by setting `--display_id 0`. 16 | 17 | #### My PyTorch errors on CUDA-related code. 18 | Try running the following code snippet to make sure that CUDA is working (assuming you are using PyTorch >= 0.4): 19 | ```python 20 | import torch 21 | torch.cuda.init() 22 | print(torch.randn(1, device='cuda')) 23 | ``` 24 | 25 | If you get an error, it is likely that your PyTorch build does not work with CUDA, e.g., it was installed from the official macOS binary, or you have a GPU that is too old and no longer supported. You may run the code on CPU using `--gpu_ids -1`. 26 | 27 | #### TypeError: Object of type 'Tensor' is not JSON serializable ([#258](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/258)) 28 | Similar errors: AttributeError: module 'torch' has no attribute 'device' ([#314](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/314)) 29 | 30 | The current code only works with PyTorch 0.4+. An earlier PyTorch version can often cause the above errors. 31 | 32 | #### ValueError: empty range for randrange() ([#390](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/390), [#376](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/376), [#194](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/194)) 33 | Similar error messages include "ConnectionRefusedError: [Errno 111] Connection refused" 34 | 35 | It is related to the data augmentation step. It often happens when you use `--resize_or_crop crop`. The program will crop random `fineSize x fineSize` patches out of the input training images. But if some of your image sizes (e.g., `256x384`) are smaller than the `fineSize` (e.g., 512), you will get this error. A simple fix is to use other data augmentation methods such as `--resize_or_crop resize_and_crop` or `--resize_or_crop scale_width_and_crop`. Our program will automatically resize the images according to `loadSize` before applying the `fineSize x fineSize` cropping. Make sure that `loadSize >= fineSize`. 36 | 37 | 38 | #### Can I continue/resume my training? ([#350](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/350), [#275](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/275), [#234](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/234), [#87](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/87)) 39 | You can use the option `--continue_train`. Also set `--epoch_count` to specify a different starting epoch count. See more discussion in [training/test tips](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#trainingtest-tips). 40 | 41 | #### Why does my training loss not converge? ([#335](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/335), [#164](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/164), [#30](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/30)) 42 | Many GAN losses do not converge (exceptions: WGAN, WGAN-GP, etc.) due to the nature of minimax optimization. For the DCGAN and LSGAN objectives, it is quite normal for the G and D losses to go up and down. It should be fine as long as they do not blow up. 43 | 44 | #### How can I make it work for my own data (e.g., 16-bit png, tiff, hyperspectral images)?
([#309](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/309), [#320](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/), [#202](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/202)) 45 | The current code only supports RGB and grayscale images. If you would like to train the model on other data types, please take the following steps: 46 | 47 | - Change the parameters `--input_nc` and `--output_nc` to the number of channels in your input/output images. 48 | - Write your own custom data loader (it is easy as long as you know how to load your data with Python). If you write a new data loader class, you need to change the flag `--dataset_mode` accordingly. Alternatively, you can modify the existing data loader. For aligned datasets, change this [line](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/aligned_dataset.py#L24); for unaligned datasets, change these two [lines](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/unaligned_dataset.py#L36). 49 | 50 | - If you use visdom and HTML to visualize the results, you may also need to change the visualization code. 51 | 52 | #### Multi-GPU Training ([#327](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/327), [#292](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/292), [#137](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/137), [#35](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/35)) 53 | You can use multi-GPU training by setting `--gpu_ids` (e.g., `--gpu_ids 0,1,2,3` for the first four GPUs on your machine). To fully utilize all the GPUs, you need to increase your batch size. Try `--batch_size 4`, `--batch_size 16`, or an even larger batch size. Each GPU will process batch_size/#GPUs images. The optimal batch size depends on the number of GPUs you have, the memory per GPU, and the resolution of your training images. 54 | 55 | We also recommend that you use instance normalization for multi-GPU training by setting `--norm instance`. The current batch normalization might not work for multiple GPUs, as the batchnorm parameters are not shared across different GPUs. Advanced users can try [synchronized batchnorm](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch). 56 | 57 | 58 | #### Can I run the model on CPU? ([#310](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/310)) 59 | Yes, you can set `--gpu_ids -1`. See [training/test tips](docs/tips.md) for more details. 60 | 61 | 62 | #### Are pre-trained models available? ([#10](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/10)) 63 | Yes, you can download pretrained models with the bash script `./scripts/download_cyclegan_model.sh`. See [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details. We are slowly adding more models to the repo. 64 | 65 | #### Out of memory ([#174](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/174)) 66 | CycleGAN is more memory-intensive than pix2pix as it requires two generators and two discriminators. If you would like to produce high-resolution images, you can do the following. 67 | 68 | - During training, train CycleGAN on cropped images of the training set. Please be careful not to change the aspect ratio or the scale of the original image, as this can lead to a training/test gap. You can usually do this by using the `--resize_or_crop crop` option, or `--resize_or_crop scale_width_and_crop`.
69 | 70 | - Then at test time, you can load only one generator to produce the results in a single direction. This greatly saves GPU memory as you are not loading the discriminators and the other generator in the opposite direction. You can probably take the whole image as input. You can do this using `--model test --dataroot [path to the directory that contains your test images (e.g., ./datasets/horse2zebra/trainA)] --model_suffix _A --resize_or_crop none`. You can use either `--resize_or_crop none` or `--resize_or_crop scale_width --fineSize [your_desired_image_width]`. Please see [model_suffix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/test_model.py#L16) and [resize_or_crop](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/base_dataset.py#L24) for more details. 71 | 72 | #### What is the identity loss? ([#322](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/322), [#373](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/373), [#362](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/pull/362)) 73 | We use the identity loss for our photo-to-painting application. The identity loss can regularize the generator to be close to an identity mapping when fed with real samples from the *target* domain. If something already looks like it comes from the target domain, the generator should preserve the image without making additional changes. The generator trained with this loss will often be more conservative for unknown content. Please see more details in Sec 5.2 ''Photo generation from paintings'' and Figure 12 in the CycleGAN [paper](https://arxiv.org/pdf/1703.10593.pdf). The loss was first proposed in Equation 6 of the prior work [[Taigman et al., 2017]](https://arxiv.org/pdf/1611.02200.pdf). 74 | 75 | #### The color gets inverted from the beginning of training ([#249](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/249)) 76 | The authors also observe that the generator unnecessarily inverts the color of the input image early in training, and then never learns to undo the inversion. In this case, you can try two things. 77 | 78 | - First, try using the identity loss `--identity 1.0` or `--identity 0.1`. We observe that the identity loss makes the generator more conservative, so it makes fewer unnecessary changes. However, because of this, the change may not be as dramatic. 79 | 80 | - Second, try a smaller variance when initializing weights by changing `--init_gain`. We observe that a smaller variance in weight initialization results in less color inversion. 81 | 82 | #### For labels2photo Cityscapes evaluation, why does the pretrained FCN-8s model not work well on the original Cityscapes input images? ([#150](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/150)) 83 | The model was trained on 256x256 images that are resized/upsampled to 1024x2048, so the expected input images to the network are very blurry. The purpose of the resizing was to 1) keep the label maps in the original high resolution untouched and 2) avoid the need to change the standard FCN training code for Cityscapes. 84 | 85 | #### How do I get the `ground-truth` numbers on the labels2photo Cityscapes evaluation? ([#150](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/150)) 86 | You need to resize the original Cityscapes images to 256x256 before running the evaluation code.
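A minimal sketch of that resizing step might look like the following; it is illustrative only, and the directory names are placeholders, not paths from this repository:

```python
# Hedged sketch: downscale Cityscapes frames to 256x256 before evaluation.
# src_dir and dst_dir are placeholder paths, not repository paths.
import os
from PIL import Image

src_dir = './cityscapes_original/'
dst_dir = './cityscapes_256/'
os.makedirs(dst_dir, exist_ok=True)

for fname in os.listdir(src_dir):
    if fname.lower().endswith('.png'):
        img = Image.open(os.path.join(src_dir, fname)).convert('RGB')
        img.resize((256, 256), Image.BICUBIC).save(os.path.join(dst_dir, fname))
```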
87 | 88 | 89 | #### Using resize-conv to reduce checkerboard artifacts ([#190](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190), [#64](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/64)) 90 | This Distill [blog post](https://distill.pub/2016/deconv-checkerboard/) discusses one of the potential causes of checkerboard artifacts. You can fix the issue by switching from "deconvolution" to upsampling followed by a regular convolution. Here is one implementation provided by [@SsnL](https://github.com/SsnL). You can replace each `ConvTranspose2d` with the following layers: 91 | ```python 92 | nn.Upsample(scale_factor=2, mode='bilinear'), 93 | nn.ReflectionPad2d(1), 94 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0), 95 | ``` 96 | We have also noticed that sometimes the checkerboard artifacts go away if you train long enough. Maybe you can try training your model a bit longer. 97 | 98 | #### pix2pix/CycleGAN has no random noise z ([#152](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/152)) 99 | The current pix2pix/CycleGAN model does not take z as input. In both pix2pix and CycleGAN, we tried to add z to the generator: e.g., adding z to a latent state, concatenating with a latent state, applying dropout, etc., but often found the output did not vary significantly as a function of z. Conditional GANs do not need noise as long as the input is sufficiently complex, so that the input can effectively play the role of noise. Without noise, the mapping is deterministic. 100 | 101 | Please check out the following papers that show ways of getting z to actually have a substantial effect: e.g., [BicycleGAN](https://github.com/junyanz/BicycleGAN), [AugmentedCycleGAN](https://arxiv.org/abs/1802.10151), [MUNIT](https://arxiv.org/abs/1804.04732), [DRIT](https://arxiv.org/pdf/1808.00948.pdf), etc. 102 | 103 | #### Experiment details (e.g., BW->color) ([#306](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/306)) 104 | You can find more training details and hyperparameter settings in the appendices of the [CycleGAN](https://arxiv.org/abs/1703.10593) and [pix2pix](https://arxiv.org/abs/1611.07004) papers. 105 | 106 | #### Results with [Cycada](https://arxiv.org/pdf/1711.03213.pdf) 107 | We generated the [result of translating GTA images to Cityscapes-style images](https://junyanz.github.io/CycleGAN/) using our Torch repo. Our PyTorch and Torch implementations seem to produce slightly different results, although we have not measured the FCN score using the PyTorch-trained model. To reproduce the result of Cycada, please use the Torch repo for now. 108 | -------------------------------------------------------------------------------- /docs/tips.md: -------------------------------------------------------------------------------- 1 | ## Training/test Tips 2 | #### Training/test options 3 | Please see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. There are some model-specific flags as well, which are added in the model files, such as the `--lambda_A` option in `model/cycle_gan_model.py`. The default values of these options are also adjusted in the model files. 4 | #### CPU/GPU (default `--gpu_ids 0`) 5 | Please set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g., `--batch_size 32`) to benefit from multiple GPUs.
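For reference, the `--gpu_ids` flag is consumed by the models roughly as sketched below. This is a condensed illustration of the device-selection line in `models/base_model.py` (not the full multi-GPU setup): an empty list, produced by `--gpu_ids -1`, falls back to CPU, and otherwise the first listed GPU hosts the model.

```python
# Condensed from the device-selection logic in models/base_model.py:
# an empty gpu_ids list (--gpu_ids -1) falls back to CPU; otherwise the
# first listed GPU is used as the primary device.
import torch

def resolve_device(gpu_ids):
    # gpu_ids: list of ints parsed from --gpu_ids, e.g. [] or [0, 1, 2]
    if gpu_ids:
        return torch.device('cuda:{}'.format(gpu_ids[0]))
    return torch.device('cpu')

print(resolve_device([]))      # cpu
print(resolve_device([0, 1]))  # cuda:0
```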
6 | 7 | #### Visualization 8 | During training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running via the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`. 9 | 10 | #### Preprocessing 11 | Images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to size `(opt.loadSize, opt.loadSize)` and then takes a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `'none'` tries to skip all these preprocessing steps. However, if the image size is not a multiple of some number that depends on the number of downsamplings of the generator, you will get an error because the size of the output image may be different from the size of the input image. Therefore, the `'none'` option still tries to adjust the image size to be a multiple of 4. You might need a bigger adjustment if you change the generator architecture. Please see `data/base_dataset.py` to see how all of this is implemented. 12 | 13 | #### Fine-tuning/resume training 14 | To fine-tune a pre-trained model, or to resume a previous training run, use the `--continue_train` flag. The program will then load the model based on `epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count` to specify a different starting epoch count. 15 | 16 | 17 | #### Prepare your own datasets for CycleGAN 18 | You need to create two directories to host images from domain A `/path/to/data/trainA` and from domain B `/path/to/data/trainB`. Then you can train the model with the dataset flag `--dataroot /path/to/data`. Optionally, you can create hold-out test datasets at `/path/to/data/testA` and `/path/to/data/testB` to test your model on unseen images. 19 | 20 | #### Prepare your own datasets for pix2pix 21 | Pix2pix's training requires paired data. We provide a Python script to generate training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A: 22 | 23 | Create folder `/path/to/data` with subdirectories `A` and `B`. `A` and `B` should each have their own subdirectories `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for other data splits (`val`, `test`, etc).
24 | 25 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`. 26 | 27 | Once the data is formatted this way, call: 28 | ```bash 29 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 30 | ``` 31 | 32 | This will combine each pair of images (A,B) into a single image file, ready for training. 33 | 34 | 35 | #### About image size 36 | Since the generator architecture in CycleGAN involves a series of downsampling / upsampling operations, the sizes of the input and output images may not match if the input image size is not a multiple of 4. As a result, you may get a runtime error because the L1 identity loss cannot be enforced with images of different sizes. Therefore, we slightly resize the image to a multiple of 4 even with the `--resize_or_crop none` option. For the same reason, `--fineSize` needs to be a multiple of 4. 37 | 38 | #### Training/Testing with high-res images 39 | CycleGAN is quite memory-intensive as four networks (two generators and two discriminators) need to be loaded on one GPU, so a large image cannot be entirely loaded. In this case, we recommend training with cropped images. For example, to generate 1024px results, you can train with `--resize_or_crop scale_width_and_crop --loadSize 1024 --fineSize 360`, and test with `--resize_or_crop scale_width --fineSize 1024`. This ensures that training and testing are carried out at the same scale. At test time, you can afford a higher resolution because you don't need to load all the networks. 40 | 41 | #### About loss curve 42 | Unfortunately, the loss curve does not reveal much information in training GANs, and CycleGAN is no exception. To check whether the training has converged or not, we recommend periodically generating a few samples and looking at them. 43 | 44 | #### About batch size 45 | For all experiments in the paper, we set the batch size to 1. If memory allows, you can use a higher batch size with batch norm or instance norm. (Note that the default batchnorm does not work well with multi-GPU training. You may consider using [synchronized batchnorm](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) instead.) But please be aware that it can impact the training. In particular, even with Instance Normalization, different batch sizes can lead to different results. Moreover, increasing `--fineSize` may be a good alternative to increasing the batch size.
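For reference, a `--norm` option of this kind is usually mapped to a layer constructor along the lines of the sketch below. This is patterned after the upstream pix2pix `get_norm_layer()` helper and is illustrative only; it is not copied from this repository's `models/networks.py`:

```python
# Hedged sketch of a --norm switch, patterned after the upstream pix2pix
# get_norm_layer(); illustrative, not copied from this repository's networks.py.
import functools
import torch.nn as nn

def get_norm_layer(norm_type='batch'):
    if norm_type == 'batch':
        # BatchNorm statistics depend on the batch, hence the batch-size caveats above.
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        # InstanceNorm normalizes each sample on its own, so batch size matters less.
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
```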
46 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | 2 | -linux 3 | dependencies: 4 | - python=3.5.5 5 | - pytorch=0.4 6 | - scipy 7 | - pip: 8 | - dominate==2.3.1 9 | - Pillow==5.0.0 10 | - numpy==1.14.1 11 | - visdom==0.1.7 12 | -------------------------------------------------------------------------------- /imgs/0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/imgs/0.txt -------------------------------------------------------------------------------- /matlab_code_for_synthesizing_gaussian_noise.m: -------------------------------------------------------------------------------- 1 | 2 | path = dir('./datasets/train/ground_truth/*.png') 3 | new_folder = './datasets/train/noisy_train/' 4 | mkdir(new_folder); 5 | 6 | for k = 1:200000 7 | 8 | 9 | img = imread(['./datasets/train/ground_truth/' path(k).name]); 10 | 11 | sigma=randi([1,50],1); 12 | 13 | p_noise=imnoise(img, 'gaussian',0, sigma^2/255^2); 14 | 15 | imwrite(p_noise,['./datasets/train/noisy_train/' path(k).name]); 16 | 17 | 18 | 19 | end -------------------------------------------------------------------------------- /matlab_code_for_synthesizing_speckle_noise.m: -------------------------------------------------------------------------------- 1 | path = dir('./datasets/train/ground_truth/*.png') 2 | new_folder = './datasets/train/noisy_train/' 3 | mkdir(new_folder); 4 | 5 | for k = 1:200000 6 | 7 | 8 | img = imread(['./datasets/train/ground_truth/' path(k).name]); 9 | 10 | 11 | v=randi([1,20],1)/100.0; 12 | 13 | p_noise=imnoise(img, 'speckle', v); 14 | 15 | imwrite(p_noise,['./datasets/train/noisy_train/' path(k).name]); 16 | 17 | 18 | 19 | end 20 | 21 | 22 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | # model_filename = "models." + model_name + "_model" 10 | model_filename = "models." + "denoise_model" 11 | modellib = importlib.import_module(model_filename) 12 | 13 | # In the file, the class called ModelNameModel() will 14 | # be instantiated. It has to be a subclass of BaseModel, 15 | # and it is case-insensitive. 16 | model = None 17 | target_model_name = 'denoisemodel' 18 | for name, cls in modellib.__dict__.items(): 19 | if name.lower() == target_model_name.lower() \ 20 | and issubclass(cls, BaseModel): 21 | model = cls 22 | 23 | if model is None: 24 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 25 | exit(0) 26 | 27 | return model 28 | 29 | 30 | def get_option_setter(): 31 | model_class = find_model_using_name() 32 | return model_class.modify_commandline_options 33 | 34 | 35 | def create_model(opt): 36 | model = find_model_using_name() 37 | instance = model() 38 | instance.initialize(opt) 39 | print("model [%s] was created" % (instance.name())) 40 | return instance 41 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/cycle_gan_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/cycle_gan_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/dbpn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/dbpn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/dbpn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/dbpn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/denoise_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/denoise_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/derain_model.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/derain_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/pix2pix_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/pix2pix_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/pix2pix_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/pix2pix_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/test_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/test_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_parts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/unet_parts.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_parts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/__pycache__/unet_parts.cpython-37.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . 
import networks 5 | 6 | 7 | class BaseModel(): 8 | 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train): 11 | return parser 12 | 13 | def name(self): 14 | return 'BaseModel' 15 | 16 | def initialize(self, opt): 17 | self.opt = opt 18 | self.gpu_ids = opt.gpu_ids 19 | self.isTrain = opt.isTrain 20 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 21 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 22 | if opt.resize_or_crop != 'scale_width': 23 | torch.backends.cudnn.benchmark = True 24 | self.loss_names = [] 25 | self.model_names = [] 26 | self.visual_names = [] 27 | self.image_paths = [] 28 | 29 | def set_input(self, input): 30 | pass 31 | 32 | def forward(self): 33 | pass 34 | 35 | # load and print networks; create schedulers 36 | def setup(self, opt, parser=None): 37 | if self.isTrain: 38 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 39 | # self.load_networks2(2000) 40 | if not self.isTrain or opt.continue_train: 41 | # load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 42 | # self.load_networks(load_suffix) 43 | self.load_networks('latest') 44 | 45 | self.print_networks(opt.verbose) 46 | 47 | # make models eval mode during test time 48 | def eval(self): 49 | for name in self.model_names: 50 | if isinstance(name, str): 51 | net = getattr(self, 'net' + name) 52 | net.eval() 53 | 54 | # used in test time, wrapping `forward` in no_grad() so we don't save 55 | # intermediate steps for backprop 56 | def test(self): 57 | with torch.no_grad(): 58 | self.forward() 59 | 60 | # get image paths 61 | def get_image_paths(self): 62 | return self.image_paths 63 | 64 | def optimize_parameters(self): 65 | pass 66 | 67 | # update learning rate (called once every epoch) 68 | def update_learning_rate(self): 69 | for scheduler in self.schedulers: 70 | scheduler.step() 71 | lr = self.optimizers[0].param_groups[0]['lr'] 72 | print('learning rate = %.7f' % lr) 73 | 74 | # return visualization images. train.py will display these images, and save the images to a html 75 | def get_current_visuals(self): 76 | visual_ret = OrderedDict() 77 | for name in self.visual_names: 78 | if isinstance(name, str): 79 | visual_ret[name] = getattr(self, name) 80 | return visual_ret 81 | 82 | # return traning losses/errors. train.py will print out these errors as debugging information 83 | def get_current_losses(self,epoch): 84 | errors_ret = OrderedDict() 85 | loss_names = ['Denoise', 'grad'] 86 | 87 | 88 | for name in loss_names: 89 | if isinstance(name, str): 90 | # float(...) 
works for both scalar tensor and float number 91 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 92 | return errors_ret 93 | 94 | # save models to the disk 95 | def save_networks(self, epoch): 96 | name=['G'] 97 | for name in name: 98 | # for name in self.model_names: 99 | if isinstance(name, str): 100 | save_filename = '%s_net_%s.pth' % (epoch, name) 101 | save_path = os.path.join(self.save_dir, save_filename) 102 | net = getattr(self, 'net' + name) 103 | 104 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 105 | torch.save(net.module.cpu().state_dict(), save_path) 106 | net.cuda(self.gpu_ids[0]) 107 | else: 108 | torch.save(net.cpu().state_dict(), save_path) 109 | 110 | 111 | 112 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 113 | key = keys[i] 114 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 115 | if module.__class__.__name__.startswith('InstanceNorm') and \ 116 | (key == 'running_mean' or key == 'running_var'): 117 | if getattr(module, key) is None: 118 | state_dict.pop('.'.join(keys)) 119 | if module.__class__.__name__.startswith('InstanceNorm') and \ 120 | (key == 'num_batches_tracked'): 121 | state_dict.pop('.'.join(keys)) 122 | else: 123 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 124 | 125 | # load models from the disk 126 | def load_networks2(self, epoch): 127 | model=['G'] 128 | for name in model: 129 | if isinstance(name, str): 130 | load_filename = 'o_%s_net_%s.pth' % (epoch, name) 131 | load_path = os.path.join(self.save_dir, load_filename) 132 | net = getattr(self, 'net' + name) 133 | if isinstance(net, torch.nn.DataParallel): 134 | net = net.module 135 | print('loading the model from %s' % load_path) 136 | 137 | state_dict = torch.load(load_path, map_location=str(self.device)) 138 | if hasattr(state_dict, '_metadata'): 139 | del state_dict._metadata 140 | 141 | # patch InstanceNorm checkpoints prior to 0.4 142 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 143 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 144 | net.load_state_dict(state_dict) 145 | 146 | def load_networks(self, epoch): 147 | for name in self.model_names: 148 | if isinstance(name, str): 149 | load_filename = '%s_net_%s.pth' % (epoch, name) 150 | load_path = os.path.join(self.save_dir, load_filename) 151 | net = getattr(self, 'net' + name) 152 | if isinstance(net, torch.nn.DataParallel): 153 | net = net.module 154 | print('loading the model from %s' % load_path) 155 | # if you are using PyTorch newer than 0.4 (e.g., built from 156 | # GitHub source), you can remove str() on self.device 157 | state_dict = torch.load(load_path, map_location=str(self.device)) 158 | if hasattr(state_dict, '_metadata'): 159 | del state_dict._metadata 160 | 161 | # patch InstanceNorm checkpoints prior to 0.4 162 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 163 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 164 | net.load_state_dict(state_dict) 165 | 166 | # print network information 167 | def print_networks(self, verbose): 168 | print('---------- Networks initialized -------------') 169 | for name in self.model_names: 170 | if isinstance(name, str): 171 | net = getattr(self, 'net' + name) 172 | num_params = 0 173 | for param in net.parameters(): 174 | num_params += param.numel() 175 | if verbose: 176 | print(net) 177 | print('[Network %s] Total number of parameters : %.3f M' % 
(name, num_params / 1e6)) 178 | print('-----------------------------------------------') 179 | 180 | # set requies_grad=Fasle to avoid computation 181 | def set_requires_grad(self, nets, requires_grad=False): 182 | if not isinstance(nets, list): 183 | nets = [nets] 184 | for net in nets: 185 | if net is not None: 186 | for param in net.parameters(): 187 | param.requires_grad = requires_grad 188 | -------------------------------------------------------------------------------- /models/denoise_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from util.image_pool import ImagePool 4 | from .base_model import BaseModel 5 | from . import networks 6 | from . import pytorch_ssim 7 | import numpy as np 8 | 9 | 10 | 11 | class DenoiseModel(BaseModel): 12 | def name(self): 13 | return 'DenoiseModel' 14 | 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train=True): 17 | 18 | 19 | parser.set_defaults(norm='batch', netG='unet_256') 20 | parser.set_defaults(dataset_mode='aligned') 21 | if is_train: 22 | parser.set_defaults(pool_size=0, no_lsgan=True) 23 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 24 | 25 | return parser 26 | 27 | 28 | 29 | def initialize(self, opt): 30 | BaseModel.initialize(self, opt) 31 | self.isTrain = opt.isTrain 32 | 33 | if self.isTrain: 34 | 35 | self.visual_names = [ 'X','Y','X_denoise1','X_denoise2','n_hat','n_tilde','X_s','X_s_denoise'] 36 | else: 37 | self.visual_names = ['X_denoise'] 38 | 39 | if self.isTrain: 40 | self.model_names = ['G'] 41 | else: # during test time, only load Gs 42 | self.model_names = ['G'] 43 | # load/define networks 44 | # self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 45 | # not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,id=2) 46 | self.netG = networks.define_G(opt.init_type, opt.init_gain, self.gpu_ids) 47 | 48 | 49 | if self.isTrain: 50 | # define loss functions 51 | 52 | self.criterionL1 = torch.nn.L1Loss() 53 | self.criterionL2 = torch.nn.MSELoss() 54 | self.ssim_loss = pytorch_ssim.SSIM() 55 | 56 | self.optimizers = [] 57 | 58 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 59 | lr=opt.lr, betas=(opt.beta1, 0.999)) 60 | 61 | self.optimizers.append(self.optimizer_G) 62 | 63 | def gradient(self,img): 64 | img_5_h = img[:,:,:-1,:-1] 65 | img_6 = img[:,:,:-1,1:] 66 | h_res = img_6 - img_5_h 67 | 68 | img5_v = img[:,:,:-1,:-1] 69 | img_8 = img[:,:,1:,:-1] 70 | v_res = img_8 - img5_v 71 | 72 | grad =(h_res+v_res) * 0.5 73 | 74 | return grad 75 | 76 | 77 | def set_input(self, input,epoch,iteration): 78 | if self.isTrain: 79 | self.X = input['X'].to(self.device) 80 | self.Y = input['Y'].to(self.device) 81 | 82 | self.X_grad = self.gradient(self.X) 83 | 84 | self.epoch = epoch 85 | self.iteration = iteration 86 | 87 | self.image_paths = input['X_paths'] 88 | else: 89 | self.X = input['X'].to(self.device) 90 | self.image_paths = input['X_paths'] 91 | 92 | 93 | def forward(self): 94 | if self.isTrain: 95 | self.n_hat, self.n_tilde, self.X_denoise1, self.X_denoise2 = self.netG(self.X) 96 | 97 | self.n_grad = self.gradient(self.n_tilde) 98 | 99 | noise3 = self.n_tilde.detach() 100 | 101 | a = torch.ones_like(self.n_tilde) * 0.5 102 | mask = torch.bernoulli(a) 103 | mask = mask * 2 - 1 104 | 105 | self.X_s = noise3 * mask + self.Y 106 | 107 | self.X_s[self.X_s > 1.0] = 1.0 108 | self.X_s[self.X_s < 0] = 0 109 | _, _, self.X_s_denoise,_ = 
self.netG(self.X_s.detach()) 110 | 111 | 112 | 113 | else: 114 | _,_ , self.X_denoise,_ = self.netG(self.X) 115 | 116 | 117 | 118 | def backward_G(self): 119 | tau = int(self.iteration / 500) + 1 120 | 121 | if self.iteration %tau == 0: 122 | self.loss_grad = self.criterionL2(self.n_grad, self.X_grad.detach()) 123 | else: 124 | self.loss_grad = 0 125 | 126 | self.loss_Denoise = self.criterionL2(self.X_s_denoise, self.Y) 127 | 128 | 129 | 130 | self.loss_G = self.loss_Denoise + self.loss_grad 131 | 132 | 133 | 134 | self.loss_G.backward() 135 | 136 | 137 | 138 | 139 | def optimize_parameters(self): 140 | self.forward() 141 | # update D 142 | 143 | self.optimizer_G.zero_grad() 144 | self.backward_G() 145 | self.optimizer_G.step() -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from .unet_parts import * 5 | import functools 6 | from torch.optim import lr_scheduler 7 | 8 | 9 | 10 | ############################################################################### 11 | # Helper Functions 12 | ############################################################################### 13 | 14 | 15 | def get_norm_layer(norm_type='instance'): 16 | if norm_type == 'batch': 17 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 18 | elif norm_type == 'instance': 19 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 20 | elif norm_type == 'none': 21 | norm_layer = None 22 | else: 23 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 24 | return norm_layer 25 | 26 | 27 | def get_scheduler(optimizer, opt): 28 | if opt.lr_policy == 'lambda': 29 | def lambda_rule(epoch): 30 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 31 | return lr_l 32 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 33 | elif opt.lr_policy == 'step': 34 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 35 | elif opt.lr_policy == 'plateau': 36 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 37 | elif opt.lr_policy == 'cosine': 38 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 39 | else: 40 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 41 | return scheduler 42 | 43 | 44 | def init_weights(net, init_type='normal', gain=0.02): 45 | def init_func(m): 46 | classname = m.__class__.__name__ 47 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 48 | if init_type == 'normal': 49 | init.normal_(m.weight.data, 0.0, gain) 50 | elif init_type == 'xavier': 51 | init.xavier_normal_(m.weight.data, gain=gain) 52 | elif init_type == 'kaiming': 53 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 54 | elif init_type == 'orthogonal': 55 | init.orthogonal_(m.weight.data, gain=gain) 56 | else: 57 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 58 | if hasattr(m, 'bias') and m.bias is not None: 59 | init.constant_(m.bias.data, 0.0) 60 | elif classname.find('BatchNorm2d') != -1: 61 | init.normal_(m.weight.data, 1.0, gain) 62 | init.constant_(m.bias.data, 0.0) 63 | 64 | print('initialize network with %s' % init_type) 65 | net.apply(init_func) 66 | 67 | 68 | 
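# --- Illustrative aside (not part of the original networks.py) ---
# A minimal sketch of the noise-exchange step in DenoiseModel.forward()
# above: the estimated noise n_tilde is detached, its sign is flipped
# pixel-wise by a Bernoulli(0.5) mask, and the result is added to a clean
# reference image Y to synthesize the paired noisy sample X_s. The helper
# name is hypothetical and only mirrors the model code.
def _noise_exchange_sketch(n_tilde, Y):
    mask = torch.bernoulli(torch.ones_like(n_tilde) * 0.5) * 2 - 1  # random +/-1 per pixel
    X_s = n_tilde.detach() * mask + Y  # paste the extracted noise onto the clean image
    return X_s.clamp(0.0, 1.0)  # same [0, 1] clipping as in forward()
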
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 69 | if len(gpu_ids) > 0: 70 | assert(torch.cuda.is_available()) 71 | net.to(gpu_ids[0]) 72 | net = torch.nn.DataParallel(net, gpu_ids) 73 | init_weights(net, init_type, gain=init_gain) 74 | return net 75 | 76 | 77 | def define_G(init_type='normal', init_gain=0.02, gpu_ids=[]): 78 | net = denoisenet(feature_num=32) 79 | 80 | 81 | return init_net(net, init_type, init_gain, gpu_ids) 82 | 83 | 84 | 85 | class denoisenet(nn.Module): 86 | def __init__(self, feature_num=8): 87 | super(denoisenet, self).__init__() 88 | 89 | self.inc = inconv(3, feature_num) 90 | 91 | self.down1 = down(feature_num, feature_num*2) 92 | self.down2 = down(feature_num*2, feature_num*4) 93 | self.down3 = down(feature_num*4, feature_num*4) 94 | 95 | self.up1 = up(feature_num*8, feature_num*2) 96 | self.up2 = up(feature_num*4, feature_num*1) 97 | self.up3 = up(feature_num*2, feature_num) 98 | 99 | self.outc=nn.Sequential( 100 | nn.Conv2d(feature_num, 3, kernel_size=3, stride=1, padding=1), 101 | 102 | ) 103 | 104 | 105 | # Noise approximation module 106 | self.s = nn.Sequential( 107 | nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0), 108 | ) 109 | 110 | 111 | 112 | def forward(self, input1): 113 | 114 | fi = self.inc(input1) 115 | x2 = self.down1(fi) 116 | x3 = self.down2(x2) 117 | x4 = self.down3(x3) 118 | 119 | x = self.up1(x4, x3) 120 | x = self.up2(x, x2) 121 | ff3 = self.up3(x, fi) 122 | 123 | cont1 = self.outc(ff3) 124 | 125 | n_hat = input1 - cont1 126 | n_tilde = self.s(n_hat) 127 | 128 | cont2 = input1 - n_tilde 129 | 130 | return n_hat, n_tilde, cont1, cont2 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /models/pytorch-ssim-master/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT 2 | -------------------------------------------------------------------------------- /models/pytorch-ssim-master/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-ssim 2 | 3 | ### Differentiable structural similarity (SSIM) index. 4 | ![einstein](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/einstein.png) ![Max_ssim](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/max_ssim.gif) 5 | 6 | ## Installation 7 | 1. Clone this repo. 8 | 2. Copy "pytorch_ssim" folder in your project. 
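3. Alternatively, since this copy ships with a `setup.py`, installing it as a package with `pip install .` (run inside the `pytorch-ssim-master` folder) should also work, though that route is untested here. 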
9 | 10 | ## Example 11 | ### basic usage 12 | ```python 13 | import pytorch_ssim 14 | import torch 15 | from torch.autograd import Variable 16 | 17 | img1 = Variable(torch.rand(1, 1, 256, 256)) 18 | img2 = Variable(torch.rand(1, 1, 256, 256)) 19 | 20 | if torch.cuda.is_available(): 21 | img1 = img1.cuda() 22 | img2 = img2.cuda() 23 | 24 | print(pytorch_ssim.ssim(img1, img2)) 25 | 26 | ssim_loss = pytorch_ssim.SSIM(window_size = 11) 27 | 28 | print(ssim_loss(img1, img2)) 29 | 30 | ``` 31 | ### maximize ssim 32 | ```python 33 | import pytorch_ssim 34 | import torch 35 | from torch.autograd import Variable 36 | from torch import optim 37 | import cv2 38 | import numpy as np 39 | 40 | npImg1 = cv2.imread("einstein.png") 41 | 42 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 43 | img2 = torch.rand(img1.size()) 44 | 45 | if torch.cuda.is_available(): 46 | img1 = img1.cuda() 47 | img2 = img2.cuda() 48 | 49 | 50 | img1 = Variable( img1, requires_grad=False) 51 | img2 = Variable( img2, requires_grad = True) 52 | 53 | 54 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) 55 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0] 56 | print("Initial ssim:", ssim_value) 57 | 58 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) 59 | ssim_loss = pytorch_ssim.SSIM() 60 | 61 | optimizer = optim.Adam([img2], lr=0.01) 62 | 63 | while ssim_value < 0.95: 64 | optimizer.zero_grad() 65 | ssim_out = -ssim_loss(img1, img2) 66 | ssim_value = - ssim_out.data[0] 67 | print(ssim_value) 68 | ssim_out.backward() 69 | optimizer.step() 70 | 71 | ``` 72 | 73 | ## Reference 74 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ 75 | -------------------------------------------------------------------------------- /models/pytorch-ssim-master/einstein.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/pytorch-ssim-master/einstein.png -------------------------------------------------------------------------------- /models/pytorch-ssim-master/max_ssim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/pytorch-ssim-master/max_ssim.gif -------------------------------------------------------------------------------- /models/pytorch-ssim-master/max_ssim.py: -------------------------------------------------------------------------------- 1 | import pytorch_ssim 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import optim 5 | import cv2 6 | import numpy as np 7 | 8 | npImg1 = cv2.imread("einstein.png") 9 | 10 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 11 | img2 = torch.rand(img1.size()) 12 | 13 | if torch.cuda.is_available(): 14 | img1 = img1.cuda() 15 | img2 = img2.cuda() 16 | 17 | 18 | img1 = Variable( img1, requires_grad=False) 19 | img2 = Variable( img2, requires_grad = True) 20 | 21 | 22 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) 23 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0] 24 | print("Initial ssim:", ssim_value) 25 | 26 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) 27 | ssim_loss = pytorch_ssim.SSIM() 28 | 29 | optimizer = optim.Adam([img2], lr=0.01) 30 | 31 | while ssim_value < 0.95: 32 | optimizer.zero_grad() 
33 | ssim_out = -ssim_loss(img1, img2) 34 | ssim_value = - ssim_out.data[0] 35 | print(ssim_value) 36 | ssim_out.backward() 37 | optimizer.step() 38 | -------------------------------------------------------------------------------- /models/pytorch-ssim-master/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /models/pytorch-ssim-master/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | setup( 3 | name = 'pytorch_ssim', 4 | packages = ['pytorch_ssim'], # this must be the same as the name above 5 | version = '0.1', 6 | description = 'Differentiable structural similarity (SSIM) index', 7 | author = 'Po-Hsun (Evan) Su', 8 | author_email = 'evan.pohsun.su@gmail.com', 9 | url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo 10 | download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second 11 | keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords 12 | classifiers = [], 13 | ) 14 | -------------------------------------------------------------------------------- /models/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | # ssim_map = ((2*mu1_mu2 + C1)*(2*torch.abs(sigma12) + C2))/((mu1_sq + mu2_sq + C1)*(torch.abs(sigma1_sq) + torch.abs(sigma2_sq) + C2)) 33 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == 
img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /models/pytorch_ssim/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/pytorch_ssim/__init__.pyc -------------------------------------------------------------------------------- /models/pytorch_ssim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/pytorch_ssim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/pytorch_ssim/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/pytorch_ssim/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/ssim2/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian2(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window2(window_size, channel): 12 | _1D_window = gaussian2(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim2(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ## 33 | C3 = C2 / 2 34 | L = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1) 35 | C = (2 * torch.sqrt(torch.abs(sigma1_sq)) * torch.sqrt(torch.abs(sigma2_sq)) + C2) / ( 36 | torch.abs(sigma1_sq) + torch.abs(sigma2_sq) + C2) 37 | S = 
(sigma12 + C3) / (torch.sqrt(torch.abs(sigma1_sq)) * torch.sqrt(torch.abs(sigma2_sq)) + C3) 38 | ## 39 | 40 | # ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 41 | 42 | if size_average: 43 | return L.mean(),C.mean(),S.mean() 44 | else: 45 | return C.mean(1).mean(1).mean(1),S.mean(1).mean(1).mean(1) 46 | 47 | class SSIM2(torch.nn.Module): 48 | def __init__(self, window_size = 11, size_average = True): 49 | super(SSIM2, self).__init__() 50 | self.window_size = window_size 51 | self.size_average = size_average 52 | self.channel = 1 53 | self.window = create_window2(window_size, self.channel) 54 | 55 | def forward(self, img1, img2): 56 | (_, channel, _, _) = img1.size() 57 | 58 | if channel == self.channel and self.window.data.type() == img1.data.type(): 59 | window = self.window 60 | else: 61 | window = create_window2(self.window_size, channel) 62 | 63 | if img1.is_cuda: 64 | window = window.cuda(img1.get_device()) 65 | window = window.type_as(img1) 66 | 67 | self.window = window 68 | self.channel = channel 69 | 70 | 71 | return _ssim2(img1, img2, window, self.window_size, channel, self.size_average) 72 | 73 | def ssim2(img1, img2, window_size = 11, size_average = True): 74 | (_, channel, _, _) = img1.size() 75 | window = create_window2(window_size, channel) 76 | 77 | if img1.is_cuda: 78 | window = window.cuda(img1.get_device()) 79 | window = window.type_as(img1) 80 | 81 | return _ssim2(img1, img2, window, window_size, channel, size_average) 82 | -------------------------------------------------------------------------------- /models/ssim2/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/models/ssim2/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . import networks 3 | 4 | 5 | class TestModel(BaseModel): 6 | def name(self): 7 | return 'TestModel' 8 | 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train=True): 11 | assert not is_train, 'TestModel cannot be used in train mode' 12 | parser.set_defaults(dataset_mode='single') 13 | parser.add_argument('--model_suffix', type=str, default='', 14 | help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will' 15 | ' be loaded as the generator of TestModel') 16 | 17 | return parser 18 | 19 | def initialize(self, opt): 20 | assert(not opt.isTrain) 21 | BaseModel.initialize(self, opt) 22 | 23 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 24 | self.loss_names = [] 25 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 26 | self.visual_names = ['real_A', 'fake_B'] 27 | # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 28 | self.model_names = ['G' + opt.model_suffix] 29 | 30 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, 31 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 32 | 33 | # assigns the model to self.netG_[suffix] so that it can be loaded 34 | # please see BaseModel.load_networks 35 | setattr(self, 'netG' + opt.model_suffix, self.netG) 36 | 37 | def set_input(self, input): 38 | # we need to use single_dataset mode 39 | self.real_A = input['A'].to(self.device) 40 | self.image_paths = input['A_paths'] 41 | 42 | def forward(self): 43 | self.fake_B = self.netG(self.real_A) 44 | -------------------------------------------------------------------------------- /models/unet_parts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class double_conv(nn.Module): 7 | '''(conv => BN => ReLU) * 2''' 8 | 9 | def __init__(self, in_ch, out_ch): 10 | super(double_conv, self).__init__() 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | return x 21 | 22 | 23 | class inconv(nn.Module): 24 | def __init__(self, in_ch, out_ch): 25 | super(inconv, self).__init__() 26 | self.conv = double_conv(in_ch, out_ch) 27 | 28 | def forward(self, x): 29 | x = self.conv(x) 30 | return x 31 | 32 | 33 | class down(nn.Module): 34 | def __init__(self, in_ch, out_ch): 35 | super(down, self).__init__() 36 | self.mpconv = nn.Sequential( 37 | nn.MaxPool2d(2), 38 | double_conv(in_ch, out_ch) 39 | ) 40 | 41 | def forward(self, x): 42 | x = self.mpconv(x) 43 | return x 44 | 45 | 46 | class up(nn.Module): 47 | def __init__(self, in_ch, out_ch, bilinear=True): 48 | super(up, self).__init__() 49 | 50 | # would be a nice idea if the upsampling could be learned too, 51 | # but my machine do not have enough memory to handle all those weights 52 | if bilinear: 53 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 54 | else: 55 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 56 | 57 | self.conv = double_conv(in_ch, out_ch) 58 | 59 | def forward(self, x1, x2): 60 | x1 = self.up(x1) 61 | 62 | 63 | 64 | # for padding issues, see 65 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 66 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 67 | 68 | x = torch.cat([x2, x1], dim=1) 69 | x = self.conv(x) 70 | return x 71 | 72 | 73 | class outconv(nn.Module): 74 | def __init__(self, in_ch, out_ch): 75 | super(outconv, self).__init__() 76 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 77 | 78 | def forward(self, x): 79 | x = self.conv(x) 80 | 81 | 82 | return x -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | #. 
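With `unet_parts.py` and the `denoisenet` generator in `models/networks.py` above in place, a minimal shape check helps confirm the wiring (a sketch; run from the repo root, CPU is fine):

```python
import torch
from models.networks import denoisenet

net = denoisenet(feature_num=32)       # the width that define_G() uses
x = torch.rand(1, 3, 256, 256)         # H and W must be divisible by 8 (three 2x downsamplings)
n_hat, n_tilde, denoised1, denoised2 = net(x)
print(n_tilde.shape, denoised1.shape)  # both: torch.Size([1, 3, 256, 256])
```

The four outputs are the raw noise estimate `n_hat`, its refinement `n_tilde` from the 1x1 noise-approximation module, and the two corresponding denoised images.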
-------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/test_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/train_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/options/__pycache__/train_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 15 | 
parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 16 | parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') 17 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 18 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 19 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 20 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 21 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 22 | parser.add_argument('--ndf', type=int, default=16, help='# of discrim filters in first conv layer') 23 | parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD') 24 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='selects model to use for netG') 25 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 28 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') 29 | parser.add_argument('--model', type=str, default='derain') 30 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 31 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 32 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 33 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 35 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 36 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 37 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 38 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') 39 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 40 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 41 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 42 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 43 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 44 | self.initialized = True 45 | return parser 46 | 47 | def gather_options(self): 48 | # initialize parser with basic options 49 | if not self.initialized: 50 | parser = argparse.ArgumentParser( 51 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 52 | parser = self.initialize(parser) 53 | 54 | # get the basic options 55 | opt, _ = parser.parse_known_args() 56 | 57 | # modify model-related parser options 58 | # model_name = opt.model 59 | model_option_setter = models.get_option_setter() 60 | parser = model_option_setter(parser, self.isTrain) 61 | opt, _ = parser.parse_known_args() # parse again with the new defaults 62 | 63 | # modify dataset-related parser options 64 | dataset_name = opt.dataset_mode 65 | dataset_option_setter = data.get_option_setter(dataset_name) 66 | parser = dataset_option_setter(parser, self.isTrain) 67 | 68 | self.parser = parser 69 | 70 | return parser.parse_args() 71 | 72 | def print_options(self, opt): 73 | message = '' 74 | message += '----------------- Options ---------------\n' 75 | for k, v in sorted(vars(opt).items()): 76 | comment = '' 77 | default = self.parser.get_default(k) 78 | if v != default: 79 | comment = '\t[default: %s]' % str(default) 80 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 81 | message += '----------------- End -------------------' 82 | print(message) 83 | 84 | # save to the disk 85 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 86 | util.mkdirs(expr_dir) 87 | file_name = os.path.join(expr_dir, 'opt.txt') 88 | with open(file_name, 'wt') as opt_file: 89 | opt_file.write(message) 90 | opt_file.write('\n') 91 | 92 | def parse(self): 93 | 94 | opt = self.gather_options() 95 | opt.isTrain = self.isTrain # train or test 96 | 97 | # process opt.suffix 98 | if opt.suffix: 99 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 100 | opt.name = opt.name + suffix 101 | 102 | self.print_options(opt) 103 | 104 | # set gpu ids 105 | str_ids = opt.gpu_ids.split(',') 106 | opt.gpu_ids = [] 107 | for str_id in str_ids: 108 | id = int(str_id) 109 | if id >= 0: 110 | opt.gpu_ids.append(id) 111 | if len(opt.gpu_ids) > 0: 112 | torch.cuda.set_device(opt.gpu_ids[0]) 113 | 114 | self.opt = opt 115 | return self.opt 116 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | 
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | # Dropout and BatchNorm have different behavior during training and test. 12 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 13 | parser.add_argument('--num_test', type=int, default=50000, help='how many test images to run') 14 | 15 | parser.set_defaults(model='test') 16 | # To avoid cropping, the loadSize should be the same as fineSize 17 | parser.set_defaults(loadSize=parser.get_default('fineSize')) 18 | self.isTrain = False 19 | return parser 20 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=5000, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with a certain number of images per row.') 9 | parser.add_argument('--display_id', type=int, default=-1, help='window id of the web display') 10 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 11 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 12 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 13 | parser.add_argument('--update_html_freq', type=int, default=1000000, help='frequency of saving training results to html') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=50000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=50, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration') 18 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 19 | # parser.add_argument('--continue_train', type=int, default=1, help='continue training: load the latest model') 20 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 21 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 22 | parser.add_argument('--niter', type=int, default=500, help='# of iter at starting learning rate') 23 | parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero') 24 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 25 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 26 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use 
least square GAN, if false, use vanilla GAN') 27 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 28 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 29 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 30 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 31 | 32 | self.isTrain = True 33 | return parser 34 | -------------------------------------------------------------------------------- /psnr_and_ssim.py: -------------------------------------------------------------------------------- 1 | import skimage 2 | import cv2 3 | from skimage.measure import compare_psnr, compare_ssim 4 | import os 5 | 6 | 7 | def calc_psnr(im1, im2): 8 | im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 9 | im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 10 | return compare_psnr(im1_y, im2_y) 11 | 12 | def calc_ssim(im1, im2): 13 | im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 14 | im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 15 | 16 | return compare_ssim(im1_y, im2_y) 17 | 18 | 19 | ## ground truth 20 | dir='./datasets/test/ground_truth/' 21 | 22 | 23 | ## denoiseed results 24 | # dir2="./results/gaussian_pretrained/test_latest/images/" 25 | # dir2="./results/speckle_pretrained/test_latest/images/" 26 | # dir2="./results/poisson_pretrained/test_latest/images/" 27 | dir2="./results/new/test_latest/images/" 28 | 29 | ssim=0 30 | psnr=0 31 | 32 | total_num=0 33 | for picname in os.listdir(dir): 34 | img1 = cv2.imread(dir+picname) 35 | name = picname.split('.')[0] 36 | img2 = cv2.imread(dir2+name+'_X_denoise.png') 37 | (h, w, n) = img1.shape 38 | (h2, w2, n) = img2.shape 39 | 40 | if h2 != h or w2 != w: 41 | print('error', picname) 42 | break 43 | 44 | a = calc_ssim(img1, img2) 45 | b = calc_psnr(img1, img2) 46 | print(picname, ':', a, b) 47 | ssim += a 48 | psnr += b 49 | 50 | total_num= total_num+1 51 | 52 | 53 | ssim=ssim/total_num 54 | psnr=psnr/total_num 55 | print("ssim=",ssim) 56 | print("psnr=",psnr) 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /python_code_for_synthesizing_poisson_noise.py: -------------------------------------------------------------------------------- 1 | #-*- coding=utf-8 -*- 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | from matplotlib.image import imread, imsave 5 | from matplotlib.image import imsave 6 | import tensorflow as tf 7 | import os 8 | import cv2 9 | import scipy 10 | import scipy.misc 11 | 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 13 | config = tf.ConfigProto() 14 | #config.gpu_options.per_process_gpu_memory_fraction = 0.5 15 | config.gpu_options.allow_growth = True 16 | 17 | 18 | 19 | def add_train_noise_tf(x,lam_max): 20 | chi_rng = tf.random_uniform(shape=[1, 1, 1], minval=5, maxval=lam_max) 21 | # chi_rng = tf.random_uniform(shape=[1, 1, 1], minval=30, maxval=30) 22 | # print(chi_rng.shape) 23 | # out = tf.random_poisson(chi_rng*(x+0.5), shape=[])/chi_rng - 0.5 24 | # print(out) 25 | #chi_rng = lam_max 26 | return tf.random_poisson(chi_rng*(x), shape=[])/chi_rng 27 | 28 | 29 | a= os.path.exists('./datasets/train/noisy_train/') 30 | if a: 31 | pass 32 | else: 33 | os.mkdir('./datasets/train/noisy_train/') 34 | count = 0 35 | 
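# --- Illustrative aside (not part of the original script) ---
# The TF1 session API used below is deprecated; a NumPy-only equivalent of
# add_train_noise_tf() could look like this sketch (the function name is
# hypothetical; x is a float image scaled to [0, 1]):
def add_train_noise_np(x, lam_min=5.0, lam_max=50.0):
    chi = np.random.uniform(lam_min, lam_max)  # random photon level per image
    return np.random.poisson(chi * x) / chi    # scaled Poisson counts
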
36 | with tf.Session(config = config) as sess: 37 | path = os.listdir('./datasets/train/ground_truth/') 38 | for img_name in path: 39 | img = imread(os.path.join('./datasets/train/ground_truth/',img_name)) 40 | # img = img /255.0 41 | count += 1 42 | print(count) 43 | img_shape = img.shape 44 | img_input = tf.placeholder(tf.float32,img_shape) 45 | 46 | img_tensor = add_train_noise_tf(img_input,50) 47 | img_np = sess.run(img_tensor,feed_dict={img_input:img}) 48 | imsave(os.path.join('./datasets/train/noisy_train/',img_name), np.clip(img_np,0.0,1.0)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | dominate>=2.3.1 4 | visdom>=0.1.8.3 5 | -------------------------------------------------------------------------------- /results/0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/results/0.txt -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | 8 | 9 | if __name__ == '__main__': 10 | opt = TestOptions().parse() 11 | # hard-code some parameters for test 12 | opt.num_threads = 1 # test code only supports num_threads = 1 13 | opt.batch_size = 1 # test code only supports batch_size = 1 14 | opt.serial_batches = True # no shuffle 15 | opt.no_flip = True # no flip 16 | opt.display_id = -1 # no visdom display 17 | data_loader = CreateDataLoader(opt) 18 | dataset = data_loader.load_data() 19 | model = create_model(opt) 20 | model.setup(opt) 21 | # create a website 22 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 24 | # test with eval mode. This only affects layers like batchnorm and dropout. 25 | # pix2pix: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 26 | # CycleGAN: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 27 | if opt.eval: 28 | model.eval() 29 | for i, data in enumerate(dataset): 30 | if i >= opt.num_test: 31 | break 32 | model.set_input(data,1000,1000) 33 | model.test() 34 | visuals = model.get_current_visuals() 35 | img_path = model.get_image_paths() 36 | if i % 5 == 0: 37 | print('processing (%04d)-th image... 
%s' % (i, img_path)) 38 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 39 | # save the website 40 | # webpage.save() 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | if __name__ == '__main__': 8 | opt = TrainOptions().parse() 9 | data_loader = CreateDataLoader(opt) 10 | dataset = data_loader.load_data() 11 | dataset_size = len(data_loader) 12 | print('#training images = %d' % dataset_size) 13 | 14 | model = create_model(opt) 15 | model.setup(opt) 16 | visualizer = Visualizer(opt) 17 | total_steps = 0 18 | iteration=0 19 | 20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 21 | epoch_start_time = time.time() 22 | iter_data_time = time.time() 23 | epoch_iter = 0 24 | print(iteration) 25 | 26 | for i, data in enumerate(dataset): 27 | iteration=iteration+1 28 | iter_start_time = time.time() 29 | if total_steps % opt.print_freq == 0: 30 | t_data = iter_start_time - iter_data_time 31 | visualizer.reset() 32 | total_steps += opt.batch_size 33 | epoch_iter += opt.batch_size 34 | model.set_input(data,epoch,iteration) 35 | model.optimize_parameters() 36 | 37 | if total_steps % opt.display_freq == 0: 38 | save_result = total_steps % opt.update_html_freq == 0 39 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 40 | 41 | if total_steps % opt.print_freq == 0: 42 | losses = model.get_current_losses(epoch) 43 | t = (time.time() - iter_start_time) / opt.batch_size 44 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 45 | if opt.display_id > 0: 46 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 47 | 48 | if total_steps % opt.save_latest_freq == 0: 49 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 50 | save_suffix = 'iter_%d' % total_steps if opt.save_by_iter else 'latest' 51 | model.save_networks(save_suffix) 52 | 53 | iter_data_time = time.time() 54 | if epoch % opt.save_epoch_freq == 0: 55 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 56 | model.save_networks('latest') 57 | model.save_networks(epoch) 58 | 59 | print('End of epoch %d / %d \t Time Taken: %d sec' % 60 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 61 | model.update_learning_rate() 62 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | #. 
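A practical note on evaluation: `psnr_and_ssim.py` above imports `compare_psnr` and `compare_ssim` from `skimage.measure`, which were removed in scikit-image 0.18. With a newer scikit-image installed, a compatible import shim (a sketch, not part of the original script) would be:

```python
try:
    from skimage.measure import compare_psnr, compare_ssim  # scikit-image < 0.18
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim
```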
-------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/html.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/html.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/html.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/html.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/image_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/image_pool.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/image_pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/image_pool.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HuangxingLin123/Noise2Grad_Pytorch_code/a38c2b3c53511deaa51aa5fb0bc7b0d70af00bfd/util/__pycache__/visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. 
Skipping download.".format( 109 | save_path_full)) 110 | else: 111 | self._print('Downloading Data...') 112 | url = "{0}/{1}".format(self.url, selected_dataset) 113 | self._download_data(url, save_path=save_path) 114 | 115 | return abspath(save_path_full) 116 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | # import dominate 2 | # from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | # NOTE: the dominate-based page rendering is commented out; the methods below that use self.doc and the dominate tags require restoring the imports above. test.py only calls get_image_dir(). 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | # self.doc = dominate.document(title=title) 18 | # if refresh > 0: 19 | # with self.doc.head: 20 | # meta(http_equiv="refresh", content=str(refresh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, text): 26 | with self.doc: 27 | h3(text) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | # History buffer of generated images: query() returns a random mix of incoming and previously stored images, which helps stabilise GAN discriminator updates. 5 | class ImagePool(): 6 | def __init__(self, pool_size): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images: 17 | image = torch.unsqueeze(image.data, 0) 18 | if self.num_imgs < self.pool_size: 19 | self.num_imgs += 1 20 | self.images.append(image) 21 | return_images.append(image) 22 | else: 23 | p = random.uniform(0, 1) 24 | if p > 0.5: 25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 26 | tmp = self.images[random_id].clone() 27 | self.images[random_id] = image 28 | return_images.append(tmp) 29 | else: 30 | return_images.append(image) 31 | return_images = torch.cat(return_images, 0) 32 | return return_images
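33 | 34 | 35 | if __name__ == '__main__': 36 | # Minimal usage sketch (hypothetical tensor sizes): once the pool is full, query() returns a random mix of fresh and historical images with the same shape as its input. 37 | pool = ImagePool(pool_size=50) 38 | fake = torch.randn(4, 3, 256, 256) 39 | mixed = pool.query(fake) 40 | print(mixed.shape) # torch.Size([4, 3, 256, 256])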
-------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | # (an earlier commented-out version rescaled from [-1, 1] via (x + 1) / 2 * 255; this model's outputs are already in [0, 1], so only a 255 rescale and a clip are needed) 11 | def tensor2im(input_image, imtype=np.uint8): 12 | if isinstance(input_image, torch.Tensor): 13 | image_tensor = input_image.data 14 | else: 15 | return input_image 16 | image_numpy = image_tensor[0].cpu().float().numpy() 17 | if image_numpy.shape[0] == 1: 18 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 19 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 20 | # author's change: clip to the valid range before casting 21 | image_numpy = np.clip(image_numpy, 0, 255) 22 | return image_numpy.astype(imtype) 23 | 24 | 25 | def diagnose_network(net, name='network'): 26 | mean = 0.0 27 | count = 0 28 | for param in net.parameters(): 29 | if param.grad is not None: 30 | mean += torch.mean(torch.abs(param.grad.data)) 31 | count += 1 32 | if count > 0: 33 | mean = mean / count 34 | print(name) 35 | print(mean) 36 | 37 | 38 | def save_image(image_numpy, image_path): 39 | image_pil = Image.fromarray(image_numpy) 40 | image_pil.save(image_path) 41 | 42 | 43 | def print_numpy(x, val=True, shp=False): 44 | x = x.astype(np.float64) 45 | if shp: 46 | print('shape,', x.shape) 47 | if val: 48 | x = x.flatten() 49 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 50 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 51 | 52 | 53 | def mkdirs(paths): 54 | if isinstance(paths, list) and not isinstance(paths, str): 55 | for path in paths: 56 | mkdir(path) 57 | else: 58 | mkdir(paths) 59 | 60 | 61 | def mkdir(path): 62 | if not os.path.exists(path): 63 | os.makedirs(path)
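64 | 65 | 66 | if __name__ == '__main__': 67 | # Minimal sketch (stand-in tensor, hypothetical filename): tensor2im expects a 1 x C x H x W float tensor in [0, 1] and returns an H x W x 3 uint8 array; save_image then writes it with PIL. 68 | demo = torch.rand(1, 3, 64, 64) 69 | arr = tensor2im(demo) 70 | print(arr.shape, arr.dtype) # (64, 64, 3) uint8 71 | save_image(arr, 'tensor2im_demo.png')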
-------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util 7 | # from . import html 8 | # from scipy.misc import imresize 9 | 10 | if sys.version_info[0] == 2: 11 | VisdomExceptionBase = Exception 12 | else: 13 | VisdomExceptionBase = ConnectionError 14 | 15 | 16 | # save each entry of |visuals| to disk as '<image name>_<label>.png' 17 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 18 | image_dir = webpage.get_image_dir() 19 | short_path = ntpath.basename(image_path[0]) 20 | name = os.path.splitext(short_path)[0] 21 | 22 | # webpage.add_header(name) 23 | ims, txts, links = [], [], [] 24 | 25 | for label, im_data in visuals.items(): 26 | im = util.tensor2im(im_data) 27 | # set jpg or png 28 | image_name = '%s_%s.png' % (name, label) 29 | save_path = os.path.join(image_dir, image_name) 30 | h, w, _ = im.shape 31 | # aspect-ratio resizing (scipy.misc.imresize) is disabled: 32 | # if aspect_ratio > 1.0: 33 | # im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 34 | # if aspect_ratio < 1.0: 35 | # im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 36 | util.save_image(im, save_path) 37 | 38 | ims.append(image_name) 39 | txts.append(label) 40 | links.append(image_name) 41 | # ims/txts/links feed the HTML gallery, which is disabled along with dominate: 42 | # webpage.add_images(ims, txts, links, width=width) 43 |
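44 | # Usage sketch: test.py calls save_images(webpage, model.get_current_visuals(), model.get_image_paths(), aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) once per test image.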
45 | 46 | 47 | class Visualizer(): 48 | def __init__(self, opt): 49 | self.display_id = opt.display_id 50 | self.use_html = opt.isTrain and not opt.no_html 51 | self.win_size = opt.display_winsize 52 | self.name = opt.name 53 | self.opt = opt 54 | self.saved = False 55 | if self.display_id > 0: 56 | import visdom 57 | self.ncols = opt.display_ncols 58 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 59 | 60 | if self.use_html: 61 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 62 | self.img_dir = os.path.join(self.web_dir, 'images') 63 | print('create web directory %s...' % self.web_dir) 64 | util.mkdirs([self.web_dir, self.img_dir]) 65 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 66 | with open(self.log_name, "a") as log_file: 67 | now = time.strftime("%c") 68 | log_file.write('================ Training Loss (%s) ================\n' % now) 69 | 70 | def reset(self): 71 | self.saved = False 72 | 73 | def throw_visdom_connection_error(self): 74 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run\n$ pip install visdom\nand start the server with\n$ python -m visdom.server\n\n') 75 | exit(1) 76 | 77 | # |visuals|: dictionary of images to display or save 78 | def display_current_results(self, visuals, epoch, save_result): 79 | if self.display_id > 0: # show images in the browser 80 | ncols = self.ncols 81 | if ncols > 0: 82 | ncols = min(ncols, len(visuals)) 83 | h, w = next(iter(visuals.values())).shape[:2] 84 | table_css = """<style> 85 | table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} 86 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 87 | </style>""" % (w, h) 88 | title = self.name 89 | label_html = '' 90 | label_html_row = '' 91 | images = [] 92 | idx = 0 93 | for label, image in visuals.items(): 94 | image_numpy = util.tensor2im(image) 95 | label_html_row += '<td>%s</td>' % label 96 | images.append(image_numpy.transpose([2, 0, 1])) 97 | idx += 1 98 | if idx % ncols == 0: 99 | label_html += '<tr>%s</tr>' % label_html_row 100 | label_html_row = '' 101 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 102 | while idx % ncols != 0: 103 | images.append(white_image) 104 | label_html_row += '<td></td>' 105 | idx += 1 106 | if label_html_row != '': 107 | label_html += '<tr>%s</tr>' % label_html_row 108 | # pane col = image row 109 | try: 110 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 111 | padding=2, opts=dict(title=title + ' images')) 112 | label_html = '<table>%s</table>
' % label_html 113 | self.vis.text(table_css + label_html, win=self.display_id + 2, 114 | opts=dict(title=title + ' labels')) 115 | except VisdomExceptionBase: 116 | self.throw_visdom_connection_error() 117 | 118 | else: 119 | idx = 1 120 | for label, image in visuals.items(): 121 | image_numpy = util.tensor2im(image) 122 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 123 | win=self.display_id + idx) 124 | idx += 1 125 | 126 | if self.use_html and (save_result or not self.saved): # save images to a html file 127 | self.saved = True 128 | for label, image in visuals.items(): 129 | image_numpy = util.tensor2im(image) 130 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 131 | util.save_image(image_numpy, img_path) 132 | # update website (dominate-based HTML output is disabled) 133 | # webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 134 | for n in range(epoch, 0, -1): 135 | # webpage.add_header('epoch [%d]' % n) 136 | ims, txts, links = [], [], [] 137 | 138 | for label, image_numpy in visuals.items(): 139 | image_numpy = util.tensor2im(image_numpy) 140 | img_path = 'epoch%.3d_%s.png' % (n, label) 141 | ims.append(img_path) 142 | txts.append(label) 143 | links.append(img_path) 144 | # webpage.add_images(ims, txts, links, width=self.win_size) 145 | # webpage.save() 146 | 147 | # losses: dictionary of error labels and values 148 | def plot_current_losses(self, epoch, counter_ratio, opt, losses): 149 | if not hasattr(self, 'plot_data'): 150 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 151 | self.plot_data['X'].append(epoch + counter_ratio) 152 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 153 | try: 154 | self.vis.line( 155 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 156 | Y=np.array(self.plot_data['Y']), 157 | opts={ 158 | 'title': self.name + ' loss over time', 159 | 'legend': self.plot_data['legend'], 160 | 'xlabel': 'epoch', 161 | 'ylabel': 'loss'}, 162 | win=self.display_id) 163 | except VisdomExceptionBase: 164 | self.throw_visdom_connection_error() 165 | 166 | # losses: same format as |losses| of plot_current_losses 167 | def print_current_losses(self, epoch, i, losses, t, t_data): 168 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 169 | for k, v in losses.items(): 170 | message += '%s: %.3f ' % (k, v) 171 | 172 | print(message) 173 | with open(self.log_name, "a") as log_file: 174 | log_file.write('%s\n' % message)
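175 | 176 | 177 | # Example of a line written to loss_log.txt by print_current_losses (hypothetical loss names and values): 178 | # (epoch: 1, iters: 400, time: 0.123, data: 0.004) loss_G: 0.532 loss_D: 0.267 --------------------------------------------------------------------------------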