├── LICENSE
├── README.md
├── data
│   └── data_split.py
└── feeders
    ├── __init__.py
    └── feeder.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 KAIST-VICLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# C-DiffSET: Leveraging Latent Diffusion for SAR-to-EO Image Translation with Confidence-Guided Reliable Object Generation

†Co-corresponding authors

¹Korea Advanced Institute of Science and Technology, South Korea
²Kyungpook National University, South Korea

---

This repository is the official PyTorch implementation of "C-DiffSET: Leveraging Latent Diffusion for SAR-to-EO Image Translation with Confidence-Guided Reliable Object Generation". C-DiffSET achieves state-of-the-art results on multiple datasets, outperforming recent image-to-image translation and SAR-to-EO image translation methods.

---

## 📧 News
- **⚠ The code will be released later**
- **Dec 9, 2024:** This repository was created

---

## Results
Please visit our [project page](https://kaist-viclab.github.io/C-DiffSET_site/) for more experimental results.

## License
The source code, including the checkpoints, may be freely used for research and education purposes only. Any commercial use requires formal permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr).

## Acknowledgement
This repository is built upon [FMA-Net](https://github.com/KAIST-VICLab/FMA-Net/).
--------------------------------------------------------------------------------
/data/data_split.py:
--------------------------------------------------------------------------------
import os
import random
import pickle


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir_path, max_dataset_size=float("inf")):
    """Recursively collect absolute paths of all image files under dir_path."""
    images = []
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path

    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_relpath(dir_path, abs_path, max_dataset_size=float("inf")):
    """Like make_dataset, but returns paths relative to abs_path."""
    images = []
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path

    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                path = os.path.relpath(path, abs_path)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_list(dir_list, max_dataset_size=float("inf")):
    """Collect absolute image paths from a list of directories."""
    images = []
    for dir_path in sorted(dir_list):
        assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
        for root, _, fnames in sorted(os.walk(dir_path)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append(path)
    return images[:min(max_dataset_size, len(images))]


def make_dataset_list_relpath(dir_list, abs_path, max_dataset_size=float("inf")):
    """Collect image paths from a list of directories, relative to abs_path."""
    images = []
    for dir_path in sorted(dir_list):
        assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
        for root, _, fnames in sorted(os.walk(dir_path)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    path = os.path.relpath(path, abs_path)
                    images.append(path)
    return images[:min(max_dataset_size, len(images))]


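# For example, make_dataset_relpath('./SpaceNet6/train/AOI_11_Rotterdam/PS-RGB/', './SpaceNet6')
# returns entries such as 'train/AOI_11_Rotterdam/PS-RGB/<name>.tif' ('<name>' is a placeholder),
# i.e. paths relative to the dataset root. The split functions below store these relative paths
# so that the feeders can rebuild absolute paths from any dataroot.
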
def SpaceNet_split(dataroot, ratio=80):
    ''' https://arxiv.org/abs/2004.06500 '''
    SAVE_PATH = './SpaceNet_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    random.seed(2025)

    eo_dataroot = os.path.join(dataroot, 'train/AOI_11_Rotterdam/PS-RGB/')
    eo_list = make_dataset_relpath(eo_dataroot, dataroot)

    # Hold out (100 - ratio)% of the images for testing; the rest are for training.
    test_list = random.sample(eo_list, int(len(eo_list) * (1 - ratio / 100)))
    test_set = set(test_list)
    train_list = [x for x in eo_list if x not in test_set]

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(test_list, f)


def QXS_split(dataroot, ratio=80):
    ''' https://arxiv.org/abs/2103.08259 '''
    SAVE_PATH = './QXS_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    random.seed(2025)

    eo_dataroot = os.path.join(dataroot, 'opt_256_oc_0.2')
    eo_list = make_dataset_relpath(eo_dataroot, dataroot)

    # Hold out (100 - ratio)% of the images for testing; the rest are for training.
    test_list = random.sample(eo_list, int(len(eo_list) * (1 - ratio / 100)))
    test_set = set(test_list)
    train_list = [x for x in eo_list if x not in test_set]

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list_{ratio:03}.pkl', 'wb') as f:
        pickle.dump(test_list, f)


def SAROpt_split(dataroot):
    ''' https://ieeexplore.ieee.org/document/9779739 '''
    SAVE_PATH = './SAROpt_split/'
    if not os.path.exists(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    # SAR2Opt ships with a fixed train/test split (trainB/testB hold the EO
    # images), so no random sampling is needed here.
    train_dataroot = os.path.join(dataroot, 'trainB')
    test_dataroot = os.path.join(dataroot, 'testB')

    train_list = make_dataset_relpath(train_dataroot, dataroot)
    test_list = make_dataset_relpath(test_dataroot, dataroot)

    train_list = sorted(train_list)
    test_list = sorted(test_list)

    with open(f'{SAVE_PATH}/train_eo_list.txt', 'w') as f:
        for path in train_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/train_eo_list.pkl', 'wb') as f:
        pickle.dump(train_list, f)

    with open(f'{SAVE_PATH}/test_eo_list.txt', 'w') as f:
        for path in test_list:
            f.write(path + '\n')

    with open(f'{SAVE_PATH}/test_eo_list.pkl', 'wb') as f:
        pickle.dump(test_list, f)


if __name__ == '__main__':
    SpaceNet_split('./SpaceNet6', ratio=80)
    QXS_split('./QXS_SAROPT', ratio=80)
    SAROpt_split('./SAR2Opt')

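# Sketch (not part of the original script): reading a saved split back. The
# filename assumes the default ratio of 80 used above ({ratio:03} -> '080').
#
#     with open('./SpaceNet_split/train_eo_list_080.pkl', 'rb') as f:
#         train_list = pickle.load(f)
#     print(len(train_list), train_list[0])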
--------------------------------------------------------------------------------
/feeders/__init__.py:
--------------------------------------------------------------------------------
from . import feeder
--------------------------------------------------------------------------------
/feeders/feeder.py:
--------------------------------------------------------------------------------
import os
import pickle

import numpy as np
import tifffile
import albumentations as Alb
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset


class SpaceNetFeeder(Dataset):
    ''' https://arxiv.org/abs/2004.06500 '''
    def __init__(self, dataroot, listroot, load_size=900, crop_size=512, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError(f"Unsupported list file format: {listroot}")

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]
        # SAR images live next to the EO images, in SAR-Intensity instead of PS-RGB.
        self.sar_paths = [x.replace('PS-RGB', 'SAR-Intensity') for x in self.eo_paths]

        self.load_size = load_size
        self.crop_size = crop_size
        assert self.load_size >= self.crop_size

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

        if 'train' in listroot:
            self.split = 'train'
        elif 'test' in listroot:
            self.split = 'test'
        else:
            raise ValueError(f"Cannot infer split ('train' or 'test') from list path: {listroot}")

    def im_percent_norm(self, x, p=(1, 99), eps=1 / (2 ** 10)):
        """Percentile normalization: clip each channel to its [p1, p99] range,
        then rescale to [0, 255] to suppress extreme SAR backscatter values."""
        pv = np.percentile(x, p, axis=(0, 1))
        y = x.astype(np.float32)
        pmin = pv[0, ...]
        pmax = pv[1, ...]
        y = np.clip(y, pmin, pmax)
        y = (y - pmin) / np.maximum((pmax - pmin), eps) * 255.0
        return y

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB')).astype(np.float32)
        SAR = tifffile.imread(sar_path).astype(np.float32)
        SAR = self.im_percent_norm(SAR)
        # Compress the 4-band SAR intensity image to 3 channels (averaging the
        # two middle, presumably cross-pol, bands) so it matches the RGB EO input.
        SAR = np.stack((SAR[:, :, 0], (SAR[:, :, 1] + SAR[:, :, 2]) / 2, SAR[:, :, 3]), axis=-1)

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        # Map pixel values to [-1, 1] and convert to CHW tensors; the same spatial
        # transform is applied to both images via the additional target.
        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)


class QXSFeeder(Dataset):
    ''' https://arxiv.org/abs/2103.08259 '''
    def __init__(self, dataroot, listroot, load_size=256, crop_size=256, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError(f"Unsupported list file format: {listroot}")

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]
        self.sar_paths = [x.replace('opt_256_oc_0.2', 'sar_256_oc_0.2') for x in self.eo_paths]

        self.load_size = load_size
        self.crop_size = crop_size
        assert self.load_size >= self.crop_size

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

        if 'train' in listroot:
            self.split = 'train'
        elif 'test' in listroot:
            self.split = 'test'
        else:
            raise ValueError(f"Cannot infer split ('train' or 'test') from list path: {listroot}")

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB'))
        SAR = np.array(Image.open(sar_path).convert('RGB'))

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)


class SAROptFeeder(Dataset):
    ''' https://ieeexplore.ieee.org/document/9779739 '''
    def __init__(self, dataroot, listroot, load_size=600, crop_size=512, hflip=False, vflip=False, rot=False):
        if listroot.endswith('.txt'):
            with open(listroot, 'r') as f:
                eo_paths = f.read().splitlines()
        elif listroot.endswith('.pkl'):
            with open(listroot, 'rb') as f:
                eo_paths = pickle.load(f)
        else:
            raise ValueError(f"Unsupported list file format: {listroot}")

        self.eo_paths = [os.path.join(dataroot, eo_path) for eo_path in eo_paths]

        if 'train' in listroot:
            self.split = 'train'
            self.sar_paths = [x.replace('trainB', 'trainA') for x in self.eo_paths]
        elif 'test' in listroot:
            self.split = 'test'
            self.sar_paths = [x.replace('testB', 'testA') for x in self.eo_paths]
        else:
            raise ValueError(f"Cannot infer split ('train' or 'test') from list path: {listroot}")

        self.load_size = load_size
        self.crop_size = crop_size
        assert self.load_size >= self.crop_size

        self.hflip = hflip
        self.vflip = vflip
        self.rot = rot

    def __getitem__(self, index):
        eo_path = self.eo_paths[index]
        sar_path = self.sar_paths[index]
        EO = np.array(Image.open(eo_path).convert('RGB'))
        SAR = np.array(Image.open(sar_path).convert('RGB'))

        transform = []
        if self.split == 'train':
            transform.append(Alb.RandomCrop(width=self.crop_size, height=self.crop_size))
            if self.hflip:
                transform.append(Alb.HorizontalFlip(p=0.5))
            if self.vflip:
                transform.append(Alb.VerticalFlip(p=0.5))
            if self.rot:
                transform.append(Alb.RandomRotate90(p=0.5))
        else:
            transform.append(Alb.CenterCrop(width=self.crop_size, height=self.crop_size))

        transform.append(Alb.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0))
        transform.append(ToTensorV2())

        transform = Alb.Compose(transform, additional_targets={'image2': 'image'})
        augmented = transform(image=SAR, image2=EO)
        SAR, EO = augmented['image'], augmented['image2']
        return SAR, EO

    def __len__(self):
        return len(self.eo_paths)
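

# Usage sketch (an assumption, not part of the released training code): the
# feeders plug directly into a standard PyTorch DataLoader. The list path below
# assumes the default output of data/data_split.py with ratio=80.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = SpaceNetFeeder(dataroot='./SpaceNet6',
                             listroot='./SpaceNet_split/train_eo_list_080.pkl',
                             hflip=True, vflip=True, rot=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    SAR, EO = next(iter(loader))
    # Both tensors are (B, 3, crop_size, crop_size) floats in [-1, 1].
    print(SAR.shape, EO.shape)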
--------------------------------------------------------------------------------