├── LICENSE
├── README.md
├── data
│   └── data_split.py
└── feeders
    ├── __init__.py
    └── feeder.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 KAIST-VICLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# C-DiffSET: Leveraging Latent Diffusion for SAR-to-EO Image Translation with Confidence-Guided Reliable Object Generation

Jeonghyeok Do<sup>1</sup>, Jaehyup Lee<sup>†2</sup>, Munchurl Kim<sup>†1</sup>

<sup>†</sup>Co-corresponding authors

<sup>1</sup>Korea Advanced Institute of Science and Technology, South Korea

<sup>2</sup>Kyungpook National University, South Korea

---

This repository is the official PyTorch implementation of "C-DiffSET: Leveraging Latent Diffusion for SAR-to-EO Image Translation with Confidence-Guided Reliable Object Generation". C-DiffSET achieves state-of-the-art results on multiple datasets, outperforming recent image-to-image translation methods and SAR-to-EO image translation methods.

---

## 📧 News
- **⚠ The code will be released later.**
- **Dec 9, 2024:** This repository was created.

---

## Results
Please visit our [project page](https://kaist-viclab.github.io/C-DiffSET_site/) for more experimental results.

## License
The source code, including the checkpoints, may be used freely for research and education only. Any commercial use requires formal permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr).

## Acknowledgement
This repository is built upon [FMA-Net](https://github.com/KAIST-VICLab/FMA-Net/).
--------------------------------------------------------------------------------
/data/data_split.py:
--------------------------------------------------------------------------------
import os
import random
import pickle


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_relpath(dir, abs_path, max_dataset_size=float("inf")):
    # Same as make_dataset, but stores paths relative to abs_path.
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                path = os.path.relpath(path, abs_path)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_list(dir_list, max_dataset_size=float("inf")):
    images = []
    for dir in sorted(dir_list):
        assert os.path.isdir(dir), '%s is not a valid directory' % dir
        for root, _, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_list_relpath(dir_list, abs_path, max_dataset_size=float("inf")):
    images = []
    for dir in sorted(dir_list):
        assert os.path.isdir(dir), '%s is not a valid directory' % dir
        for root, _, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    path = os.path.relpath(path, abs_path)
                    images.append(path)
    return images[:min(max_dataset_size, len(images))]


def SpaceNet_split(dataroot, ratio=80):
    ''' https://arxiv.org/abs/2004.06500 '''
    SAVE_PATH = './SpaceNet_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    random.seed(2025)  # fixed seed so the train/test split is reproducible

    eo_dataroot = os.path.join(dataroot, 'train/AOI_11_Rotterdam/PS-RGB/')
    eo_list = make_dataset_relpath(eo_dataroot, dataroot)
    test_list = random.sample(eo_list, int(len(eo_list) * (1 - ratio / 100)))
    train_list = [x for x in eo_list if x not in test_list]

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(test_list, f)


def QXS_split(dataroot, ratio=80):
    ''' https://arxiv.org/abs/2103.08259 '''
    SAVE_PATH = './QXS_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    random.seed(2025)

    eo_dataroot = os.path.join(dataroot, 'opt_256_oc_0.2')
    eo_list = make_dataset_relpath(eo_dataroot, dataroot)
    test_list = random.sample(eo_list, int(len(eo_list) * (1 - ratio / 100)))
    train_list = [x for x in eo_list if x not in test_list]

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(test_list, f)


def SAROpt_split(dataroot):
    ''' https://ieeexplore.ieee.org/document/9779739 '''
    SAVE_PATH = './SAROpt_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    # SAR2Opt ships with a fixed train/test split, so only collect the lists.
    train_dataroot = os.path.join(dataroot, 'trainB')
    test_dataroot = os.path.join(dataroot, 'testB')

    train_list = make_dataset_relpath(train_dataroot, dataroot)
    test_list = make_dataset_relpath(test_dataroot, dataroot)

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list.pkl', 'wb') as f:
        pickle.dump(test_list, f)


if __name__ == '__main__':
    # Generate the train/test lists for the three datasets.
    SpaceNet_split('./SpaceNet6', ratio=80)
    QXS_split('./QXS_SAROPT', ratio=80)
    SAROpt_split('./SAR2Opt')
--------------------------------------------------------------------------------
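As a quick sanity check, a minimal sketch for inspecting a generated split; the paths are hypothetical and assume the default SAVE_PATH and ratio=80 used above:

import pickle

# {ratio:03} zero-pads the ratio, so ratio=80 yields ..._080.pkl.
with open('./SpaceNet_split/train_eo_list_080.pkl', 'rb') as f:
    train_list = pickle.load(f)

print(len(train_list), 'training EO tiles')
print(train_list[0])  # a path relative to the ./SpaceNet6 data root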
/feeders/__init__.py:
--------------------------------------------------------------------------------
from . import feeder
--------------------------------------------------------------------------------
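With `from . import feeder` in the package initializer, the dataset classes below are importable as `from feeders import feeder` or directly, e.g. `from feeders.feeder import SpaceNetFeeder`.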
/feeders/feeder.py:
--------------------------------------------------------------------------------
import os
import pickle

import numpy as np
import tifffile
import albumentations as Alb
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset


class SpaceNetFeeder(Dataset):
    ''' https://arxiv.org/abs/2004.06500 '''
    def __init__(self, dataroot, listroot, load_size=900, crop_size=512, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError('Unsupported list file format: %s' % listroot)

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]
        # Each EO tile is paired with the SAR tile at the same relative path.
        self.sar_paths = [x.replace('PS-RGB', 'SAR-Intensity') for x in self.eo_paths]

        self.load_size = load_size
        self.crop_size = crop_size
        assert (self.load_size >= self.crop_size)

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

        if 'train' in listroot:
            self.split = 'train'
        elif 'test' in listroot:
            self.split = 'test'
        else:
            raise ValueError('Cannot infer train/test split from: %s' % listroot)

    def im_percent_norm(self, x, p=(1, 99), eps=1 / (2 ** 10)):
        # Per-channel percentile stretch: clip to the p-th percentiles and
        # rescale to [0, 255], which tames the heavy-tailed SAR intensities.
        pv = np.percentile(x, p, axis=(0, 1))
        y = x.astype(np.float32)
        pmin = pv[0, ...]
        pmax = pv[1, ...]
        y = np.clip(y, pmin, pmax)
        y = (y - pmin) / np.maximum((pmax - pmin), eps) * 255.0
        return y

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB')).astype(np.float32)
        SAR = tifffile.imread(sar_path).astype(np.float32)
        SAR = self.im_percent_norm(SAR)
        # Collapse the four SAR polarization bands to three channels by
        # averaging the two cross-polarization bands.
        SAR = np.stack((SAR[:, :, 0], (SAR[:, :, 1] + SAR[:, :, 2]) / 2, SAR[:, :, 3]), axis=-1)

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        # Map [0, 255] to [-1, 1] and convert the HWC arrays to CHW tensors.
        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        # additional_targets applies identical (random) ops to SAR and EO.
        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)

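# Worked example of im_percent_norm above (hypothetical values): if a band's
# 1st/99th percentiles are pmin = 12.0 and pmax = 480.0, then a raw intensity
# of 240.0 maps to (240 - 12) / (480 - 12) * 255 ≈ 124.2, while values below
# 12 or above 480 are clipped to 0 or 255 respectively.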

class QXSFeeder(Dataset):
    ''' https://arxiv.org/abs/2103.08259 '''
    def __init__(self, dataroot, listroot, load_size=256, crop_size=256, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError('Unsupported list file format: %s' % listroot)

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]
        self.sar_paths = [x.replace('opt_256_oc_0.2', 'sar_256_oc_0.2') for x in self.eo_paths]

        self.load_size = load_size
        self.crop_size = crop_size
        assert (self.load_size >= self.crop_size)

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

        if 'train' in listroot:
            self.split = 'train'
        elif 'test' in listroot:
            self.split = 'test'
        else:
            raise ValueError('Cannot infer train/test split from: %s' % listroot)

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB'))
        SAR = np.array(Image.open(sar_path).convert('RGB'))

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)


class SAROptFeeder(Dataset):
    ''' https://ieeexplore.ieee.org/document/9779739 '''
    def __init__(self, dataroot, listroot, load_size=600, crop_size=512, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError('Unsupported list file format: %s' % listroot)

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]

        # SAR tiles live in trainA/testA, the paired EO tiles in trainB/testB.
        if 'train' in listroot:
            self.split = 'train'
            self.sar_paths = [x.replace('trainB', 'trainA') for x in self.eo_paths]
        elif 'test' in listroot:
            self.split = 'test'
            self.sar_paths = [x.replace('testB', 'testA') for x in self.eo_paths]
        else:
            raise ValueError('Cannot infer train/test split from: %s' % listroot)

        self.load_size = load_size
        self.crop_size = crop_size
        assert (self.load_size >= self.crop_size)

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB'))
        SAR = np.array(Image.open(sar_path).convert('RGB'))

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)
--------------------------------------------------------------------------------
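For orientation, a minimal usage sketch (paths hypothetical): it assumes data_split.py has been run with the defaults above and that ./SpaceNet6 holds the SpaceNet 6 imagery. Any of the three feeders can be wrapped in a standard PyTorch DataLoader the same way.

from torch.utils.data import DataLoader
from feeders.feeder import SpaceNetFeeder

# 'train' in the list path selects random crops plus the enabled flips/rotations.
dataset = SpaceNetFeeder(dataroot='./SpaceNet6',
                         listroot='./SpaceNet_split/train_eo_list_080.pkl',
                         crop_size=512, hflip=True, vflip=True, rot=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

sar, eo = next(iter(loader))  # a paired SAR/EO batch
print(sar.shape, eo.shape)    # torch.Size([4, 3, 512, 512]) each, values in [-1, 1]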