├── README.md ├── main_unet.py ├── preprocessing ├── data_processing.m └── retinex.m ├── program of collecting data ├── generate_expdata.m └── readme.md ├── prototype ├── main_realshow.py └── readme.md ├── training └── main1.py └── video ├── Vehicle-mounted prototyp.mp4 ├── Video1 Non-invasive real-time restoration of real-world objects.mp4 └── Video2 Non-invasive real-time restoration of real-world objects.mp4 /README.md: -------------------------------------------------------------------------------- 1 | # Learning-based real-time imaging through dynamic scattering media 2 | In this study, we propose a deep-learning-based method to image through dynamic scattering media in a non-invasive manner under incoherence illumination and obtain superior restoration result. 3 | 4 | paper: [https://www.nature.com/articles/s41377-024-01569-0](https://www.nature.com/articles/s41377-024-01569-0) 5 | # 1 Framework 6 | ![image](https://github.com/user-attachments/assets/681fcd41-8cc7-4cd7-a3a1-7a304d9b2d93) 7 | 8 | # 2 Recovery of unseen real-world objects 9 | 10 | ![image](https://github.com/user-attachments/assets/fd4bbcae-7fc8-4f11-982f-fc4571d305e3) 11 | 12 | # 3 Prototype 13 | ![prototype2](https://github.com/LittleMount/DescatterNet-for-unseen-real-world-objects/assets/38102067/764a4986-b4bc-40da-90a8-e83a77ea8dcc) 14 | -------------------------------------------------------------------------------- /main_unet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import scipy.io as scio 5 | import tensorflow as tf 6 | import matplotlib.pyplot as plt 7 | 8 | f = lambda x: (x - np.min(x)) / (np.max(x) - np.min(x)) 9 | 10 | from PIL import Image 11 | from tensorflow.keras.models import Model 12 | from tensorflow.keras.optimizers import Adam 13 | from tensorflow.keras.layers import concatenate 14 | from skimage.metrics import structural_similarity as ssim 15 | from tensorflow.keras.layers 
import BatchNormalization, Activation, Conv2DTranspose 16 | from tensorflow.keras.layers import Input, Dropout, Conv2D, MaxPooling2D 17 | from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau 18 | 19 | 20 | class U_Net(): 21 | def __init__(self): 22 | # setting the size of image 23 | self.height = 448 24 | self.width = 448 25 | self.channels = 1 26 | self.shape = (self.height, self.width, self.channels) 27 | 28 | # optimization 29 | optimizer = Adam(0.008, 0.5) 30 | 31 | # Unet 32 | self.unet = self.build_unet() 33 | self.unet.compile(loss='mse', 34 | optimizer=optimizer, 35 | metrics=[self.metric_fun]) 36 | self.unet.summary() 37 | 38 | def build_unet(self, n_filters=32, dropout=0.1, 39 | batchnorm=True, padding='same'): 40 | 41 | # define a conv block 42 | def conv2d_block(input_tensor, n_filters=16, kernel_size=3, 43 | batchnorm=True, padding='same'): 44 | # the first layer 45 | x = Conv2D(n_filters, kernel_size, 46 | padding=padding)(input_tensor) 47 | if batchnorm: 48 | x = BatchNormalization()(x) 49 | x = Activation('relu')(x) 50 | 51 | # the second layer 52 | x = Conv2D(n_filters, kernel_size, 53 | padding=padding)(x) 54 | if batchnorm: 55 | x = BatchNormalization()(x) 56 | x = Activation('relu')(x) 57 | 58 | return x 59 | 60 | # construct an input 61 | img = Input(shape=self.shape) 62 | 63 | # contracting path 64 | c1 = conv2d_block(img, n_filters=n_filters * 1, kernel_size=3, batchnorm=batchnorm, padding=padding) 65 | p1 = MaxPooling2D((2, 2))(c1) 66 | p1 = Dropout(dropout * 0.5)(p1) 67 | 68 | c2 = conv2d_block(p1, n_filters=n_filters * 2, kernel_size=3, batchnorm=batchnorm, padding=padding) 69 | p2 = MaxPooling2D((2, 2))(c2) 70 | p2 = Dropout(dropout)(p2) 71 | 72 | c3 = conv2d_block(p2, n_filters=n_filters * 4, kernel_size=3, batchnorm=batchnorm, padding=padding) 73 | p3 = MaxPooling2D((2, 2))(c3) 74 | p3 = Dropout(dropout)(p3) 75 | 76 | c4 = conv2d_block(p3, n_filters=n_filters * 8, kernel_size=3, batchnorm=batchnorm, 
padding=padding) 77 | p4 = MaxPooling2D((2, 2))(c4) 78 | p4 = Dropout(dropout)(p4) 79 | 80 | c5 = conv2d_block(p4, n_filters=n_filters * 16, kernel_size=3, batchnorm=batchnorm, padding=padding) 81 | 82 | # extending path 83 | u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides=(2, 2), padding='same')(c5) 84 | u6 = concatenate([u6, c4]) 85 | u6 = Dropout(dropout)(u6) 86 | c6 = conv2d_block(u6, n_filters=n_filters * 8, kernel_size=3, batchnorm=batchnorm, padding=padding) 87 | 88 | u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides=(2, 2), padding='same')(c6) 89 | u7 = concatenate([u7, c3]) 90 | u7 = Dropout(dropout)(u7) 91 | c7 = conv2d_block(u7, n_filters=n_filters * 4, kernel_size=3, batchnorm=batchnorm, padding=padding) 92 | 93 | u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides=(2, 2), padding='same')(c7) 94 | u8 = concatenate([u8, c2]) 95 | u8 = Dropout(dropout)(u8) 96 | c8 = conv2d_block(u8, n_filters=n_filters * 2, kernel_size=3, batchnorm=batchnorm, padding=padding) 97 | 98 | u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides=(2, 2), padding='same')(c8) 99 | u9 = concatenate([u9, c1]) 100 | u9 = Dropout(dropout)(u9) 101 | c9 = conv2d_block(u9, n_filters=n_filters * 1, kernel_size=3, batchnorm=batchnorm, padding=padding) 102 | 103 | output = Conv2D(1, (1, 1), activation='sigmoid')(c9) 104 | 105 | return Model(img, output) 106 | 107 | def metric_fun(self, y_true, y_pred): 108 | return tf.image.ssim(y_true, y_pred, max_val=1) 109 | 110 | def train(self, epochs=1001, batch_size=8): 111 | os.makedirs('./weights', exist_ok=True) 112 | os.makedirs('./evaluation', exist_ok=True) 113 | # obtain data 114 | data_input = scio.loadmat('./mat_data/input_retinex.mat') 115 | data_label = scio.loadmat('./mat_data/label_retinex.mat') 116 | 117 | # load the trained model 118 | # self.unet.load_weights(r"./best_model.h5") 119 | 120 | # setting the check point 121 | callbacks = [EarlyStopping(patience=1000, verbose=2), 122 | ReduceLROnPlateau(factor=0.5, patience=50, 
min_lr=0.00005), 123 | ModelCheckpoint('./weights/best_model.h5', verbose=2, save_best_only=True)] 124 | 125 | # training 126 | results = self.unet.fit(np.expand_dims(data_input['input'], axis=3), 127 | np.expand_dims(data_label['label'], axis=3), 128 | batch_size=batch_size, epochs=epochs, verbose=2, 129 | callbacks=callbacks, validation_split=0.1, shuffle=True) 130 | 131 | # plot loss curve 132 | loss = results.history['loss'] 133 | val_loss = results.history['val_loss'] 134 | metric = results.history['metric_fun'] 135 | val_metric = results.history['val_metric_fun'] 136 | fig, ax = plt.subplots(1, 2, figsize=(12, 6)) 137 | x = np.linspace(0, len(loss), len(loss)) 138 | plt.subplot(121), plt.plot(x, loss, x, val_loss) 139 | plt.title('Loss curve'), plt.legend(['loss', 'val_loss']) 140 | plt.xlabel('Epochs'), plt.ylabel('loss') 141 | plt.subplot(122), plt.plot(x, metric, x, val_metric) 142 | plt.title('metric curve'), plt.legend(['metric', 'val_metric']) 143 | plt.xlabel('Epochs'), plt.ylabel('ssim') 144 | plt.show() 145 | fig.savefig('./evaluation/curve.png', bbox_inches='tight', pad_inches=0.1) 146 | plt.close() 147 | 148 | def test(self): 149 | os.makedirs('./evaluation/test_result', exist_ok=True) 150 | os.makedirs('./evaluation/single picture', exist_ok=True) 151 | self.unet.load_weights(r'weights/best_model.h5') 152 | # obtain data 153 | test_input = scio.loadmat('./mat_data/test_input_retinex.mat') 154 | test_label = scio.loadmat('./mat_data/test_label_retinex.mat') 155 | test_num = test_input['input'].shape[0] 156 | index, step = 0, 0 157 | n = 0 158 | 159 | while index < test_num: 160 | print('schedule: %d/%d' % (index, test_num)) 161 | step += 1 162 | output = self.unet.predict((np.expand_dims(test_input['input'][index:index + 1], axis=3))) 163 | label = test_label['label'][index] 164 | result = np.concatenate([test_input['input'][index], output.squeeze(), label], axis=1) 165 | result = f(result) 166 | img = Image.fromarray(np.uint8(result * 255)) 167 | 
img.save('./evaluation/test_result/%d_%.3f_%.3f.png' % (step, ssim(test_input['input'][index], label), ssim(output.squeeze(), label))) 168 | temp = Image.fromarray(np.uint8(255*f(output.squeeze()))) 169 | temp.save('./evaluation/single picture/%d.png' % step) 170 | index += 1 171 | 172 | def test_video(self): 173 | self.unet.load_weights(r'weights/best_model.h5') 174 | video_input = scio.loadmat('./mat_data/test_video_2.8ml.mat') 175 | result = self.unet.predict(np.expand_dims(video_input['Video_28'], axis=3)) 176 | # scio.savemat('./evaluation/result_video.mat', {'result_video': result.squeeze()}) 177 | 178 | if __name__ == '__main__': 179 | unet = U_Net() 180 | # unet.train() 181 | unet.test() 182 | # unet.test_video() 183 | -------------------------------------------------------------------------------- /preprocessing/data_processing.m: -------------------------------------------------------------------------------- 1 | % 数据预处理 2 | % 流程:读取图片 -> 裁剪 -> 归一化拉伸 -> 保存图片; 3 | clc,clear 4 | close all 5 | 6 | % 文件名的批量读取 7 | nongdu = {'0','0.6','1.2','1.8','2.4','2.8','3.2','3.6'}; 8 | for i = 1:8 9 | consentration = nongdu{i}; 10 | path = ['../20221102/',consentration, 'ml/']; % 指明哪个文件夹 11 | file_set = dir(fullfile(path,'*.png')); % 读取后缀为.jpg的文件信息,保存为结构体数组 12 | name_set = {file_set.name}; % 获取批量的文件名,保存为元胞数组 13 | save_path = ['./retinex/',consentration,'ml/']; 14 | mkdir(save_path); 15 | 16 | % 文件名批量获取已完成,下面是遍历使用文件的示例 17 | for j = 1351:1366 18 | filename = [path,num2str(j),'.png']; % 注意组合“文件夹+文件名”才可读取到图片 19 | img = im2double(imread(filename)); % 执行后续操作 20 | img_crop = img(528:1455,555:1482); 21 | % img_crop_r = imresize(img_crop,[256 256]); 22 | % img_norm = mat2gray(img_crop_r); 23 | img_retinex = retinex(img_crop); 24 | 25 | imwrite(img_retinex,[save_path,num2str(j),'.png']); 26 | % imshow(img); 27 | % pause(0.5); % 暂停0.5s,把图片显示出来 28 | end 29 | 30 | end 31 | 32 | 33 | -------------------------------------------------------------------------------- 
/preprocessing/retinex.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SituLab/DescatterNet-for-unseen-real-world-objects/d4856cececf1bfe72b08aa2071f4e4257b06403b/preprocessing/retinex.m
--------------------------------------------------------------------------------
/program of collecting data/generate_expdata.m:
--------------------------------------------------------------------------------
% Prepare the images used for the experiment.
% Pipeline: read image -> grayscale -> skip low-contrast images -> resize to
% 1400x1400 -> pad with a black band -> save as a zero-padded 4-digit PNG.
%
% The output index n is shared by every section below. Section order and the
% hard-coded resets of n therefore decide which file names are written.
% Sections whose loop range is empty (1:0, 801:800) are disabled: their loop
% bodies never run.

clc,clear
close all

% initial
path = './ILSVRC2012_img_val/';
dirname = './exp data(20220907)/';
% rmdir(dirname);
mkdir(dirname);
n = 1;   % index of the next output file

%% Dataset 1 (ILSVRC2012 validation images) -- disabled: 1:0 is an empty range
for i = 1:0
    img = imread([path,num2str(i),'.jpeg']);
    if ndims(img)==3
        img = rgb2gray(img);   % colour -> grayscale
    end
    judge_num = std2(img);
    if judge_num<50 % skip low-contrast images (pixel standard deviation < 50)
        continue
    end
    img = imresize(img,[1400,1400],'bicubic');
    % Zero-pad 2 rows above/below and 236 columns left/right.
    img_exp = padarray(img, [2 236]);
    % n+10000 gives '1xxxx'; dropping the leading '1' leaves a zero-padded 4-digit name.
    str_name = num2str(n+10000);
    imwrite(img_exp,[dirname,str_name(2:end),'.png']);
    if mod(i,10) == 0
        disp(i);   % progress every 10 images
    end
    n = n+1;
end
% NOTE(review): hard-coded restart index, presumably the value n reached when
% dataset 1 was actually generated -- confirm before re-enabling sections.
n = 669;
%% Dataset 2 (DIV2K training images) -- disabled: the range is 1:0; the trailing %800 comment records the former upper bound
path = './DIV2K_train_HR/';
for i = 1:0%800
    % DIV2K files are named with 4 zero-padded digits, e.g. 0001.png.
    str1 = num2str(i+10000);
    img = imread([path,str1(2:end),'.png']);
    if ndims(img)==3
        img = rgb2gray(img);
    end
    judge_num = std2(img);
    if judge_num<50 % skip low-contrast images
        continue
    end
    img = imresize(img,[1400,1400],'bicubic');
    img_exp = padarray(img, [2 236]);
    % img = imresize(img,[576,576],'bicubic');
    % img_exp = padarray(img, [414 648], 100);
    str_name = num2str(n+10000);
    imwrite(img_exp,[dirname,str_name(2:end),'.png']);
    if mod(i,10) == 0
        disp(i);
    end
    n = n+1;
end

%% Dataset 3 (DIV2K validation images) -- disabled: 801:800 is an empty range (the former bound is in the %900 comment)
path = './DIV2K_valid_HR/';
for i = 801:800%900
    str1 = num2str(i+10000);
    img = imread([path,str1(2:end),'.png']);
    if ndims(img)==3
        img = rgb2gray(img);
    end
    judge_num = std2(img);
    if judge_num<50 % skip low-contrast images
        continue
    end
    img = imresize(img,[1400,1400],'bicubic');
    img_exp = padarray(img, [2 236]);
    % img = imresize(img,[576,576],'bicubic');
    % img_exp = padarray(img, [414 648], 100);
    str_name = num2str(n+10000);
    imwrite(img_exp,[dirname,str_name(2:end),'.png']);
    if mod(i,10) == 0
        disp(i);
    end
    n = n+1;
end
% n = 1349;

%% MNIST dataset
% No MNIST images are written here; only the index is advanced.
% NOTE(review): presumably those images were generated elsewhere -- confirm.
n=1350;

%% Dataset 4: standard test images
path = './standard image/';
for i = 1:10
    img = imread([path,num2str(i),'.png']);
    if ndims(img)==3
        img = rgb2gray(img);
    end
    judge_num = std2(img);
    % The continue below is commented out, so this check skips nothing:
    % every standard image is kept regardless of contrast.
    if judge_num<50 % skip low-contrast images (disabled)
        % continue
    end
    img = imresize(img,[1400,1400],'bicubic');
    img_exp = padarray(img, [2 236]);
    % img = imresize(img,[576,576],'bicubic');
    % img_exp = padarray(img, [414 648], 100);
    str_name = num2str(n+10000);
    imwrite(img_exp,[dirname,str_name(2:end),'.png']);
    if mod(i,10) == 0
        disp(i);
    end
    n = n+1;
end

%% All-black and all-white images
% Both reuse the size of img_exp left over from the last loop iteration above.
% They are double arrays, so imwrite treats 0 as black and 1 as white.
img0 = zeros(size(img_exp));
str_name = num2str(n+10000);
imwrite(img0,[dirname,str_name(2:end),'.png']);
n = n+1;

img1 = ones(size(img_exp));
str_name = num2str(n+10000);
imwrite(img1,[dirname,str_name(2:end),'.png']);
--------------------------------------------------------------------------------
/program of collecting data/readme.md:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/prototype/main_realshow.py:
--------------------------------------------------------------------------------
# Pipeline: read image -> preprocess image -> feed the network -> display the result
# Image preprocessing: cropping, Retinex processing and CLAHE processing
# The steps above run repeatedly in a loop
# Goal: compare the restoration speed on the CPU and on the GPU

6 | import pco 7 | import cv2 8 | import time 9 | import torch 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | from unet_use_in_hardware.func import * 13 | # from unet_use_in_hardware.tensorflow_unet import * 14 | 15 | ## 初始化 16 | gap = (np.ones([448, 20], dtype=np.float32)*65535).astype(np.uint16) 17 | cam = pco.Camera() # 指定相机 18 | exp_time = 0.003 19 | cam.configuration = {'exposure time': exp_time, 'roi': (95, 112, 1886, 1903)} # (543, 560, 1438, 1455) 20 | cam.record(4, mode='ring buffer') 21 | cam.wait_for_first_image() 22 | 23 | # 指定用GPU复原 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | 26 | # 加载网络 27 | unet = UNet(1, 1) # 定义UNet网络 28 | unet.load_state_dict(torch.load('unet_2.pkl')) # 加载训练好的网络参数 29 | unet.to(device) 30 | 31 | ## 循环模块: 读取图像 -> 预处理 -> 生成input -> 复原结果 32 | n = 120 33 | while True: 34 | n += 1 35 | start = time.time() 36 | # 读取图像 37 | img, meta = cam.image(image_index=-1) 38 | 39 | '''# 自适应曝光 40 | if np.max(img[543:1438:2, 560:1455:2]) == 65535: # 过曝了,缩小曝光时间 41 | exp_time = exp_time*55000/65535 42 | cam.exposure_time = exp_time 43 | print('exp_time: ', exp_time) 44 | elif np.max(img[543:1438:2, 560:1455:2]) > 50000: # 最大像素值>50000, 无需调节 45 | pass 46 | else: # 最大像素值<50000, 增加曝光时间 47 | exp_time = exp_time*55000/np.max(img[543:1438:2, 560:1455:2]) 48 | cam.exposure_time = exp_time 49 | print('exp_time: ', exp_time)''' 50 | 51 | # 数据预处理 52 | # start = time.time() 53 | img_pre = pre_process(img) 54 | # print('图像预处理时间:{:.6f}s'.format(time.time()-start)) 55 | 56 | # 生成网络输入所需要的数据 57 | input = torch.tensor(normalize(img_pre)).reshape([1, 1, 448, 448]).to(device) 58 | 59 | # 用网络复原 60 | # start = time.time() 61 | result = unet(input) 62 | # result1 = unet1(input) 63 | # print('图像复原时间:{:.6f}s'.format(time.time() - start)) 64 | # result1 = convert2img(result) img_pre, # 65 | cv2.imshow('image', np.concatenate((img[::4, ::4], gap, im2uint16(normalize(img[::4, ::4])), 66 | gap, im2uint16(convert2img(result))), 
axis=1)) 67 | # gap, im2uint16(convert2img(result1))), axis=1)) 68 | # plt.figure(), plt.imshow(result1, cmap='gray'), plt.show() 69 | 70 | # 按下 'q' 键退出循环 71 | if cv2.waitKey(1) == ord('q'): 72 | break 73 | 74 | # if n > 200: 75 | # break 76 | 77 | print(time.time()-start) 78 | 79 | cv2.waitKey(-1) 80 | cam.close() 81 | cv2.destroyAllWindows() 82 | 83 | ''' else 84 | # 绘制图窗 85 | cv2.namedWindow('img', cv2.WINDOW_NORMAL) 86 | cv2.namedWindow('img_crop', cv2.WINDOW_NORMAL) 87 | cv2.namedWindow('img_retinex', cv2.WINDOW_NORMAL) 88 | cv2.namedWindow('img_clahe', cv2.WINDOW_NORMAL) 89 | cv2.resizeWindow('img', 800, 800) 90 | cv2.resizeWindow('img_crop', 800, 800) 91 | cv2.resizeWindow('img_retinex', 800, 800) 92 | cv2.resizeWindow('img_clahe', 800, 800) 93 | 94 | 95 | plt.figure(1), plt.imshow(img_crop) 96 | plt.axis('off') 97 | plt.show() 98 | runfile('./main_realshow.py') 99 | 100 | # cv2.imshow('img', img), cv2.waitKey(1) 101 | # plt.figure(1), plt.hist(img.ravel(), 655, [0, 65535]) 102 | 103 | # 裁剪数据 104 | img_crop = img[544:1439:2, 561:1456:2] 105 | # cv2.imshow('img_crop', img_crop), cv2.waitKey(1) 106 | # plt.figure(2), plt.hist(img_crop.ravel(), 655, [0, 65535]) 107 | 108 | # 计算Retinex图像 109 | retinex1 = retinex(img_crop) 110 | img_retinex = np.uint16(retinex1 * 65535) 111 | # cv2.imshow('img_retinex', img_retinex), cv2.waitKey(1) 112 | # plt.figure(3), plt.hist(img_retinex.ravel(), 655, [0, 65535]) 113 | 114 | # 使用CLAHE处理 115 | clahe = cv2.createCLAHE(clipLimit=0.1, tileGridSize=(8, 8)) 116 | img_clahe = clahe.apply(img_retinex) 117 | # cv2.imshow('img_clahe', img_clahe), cv2.waitKey(1) 118 | # plt.figure(4), plt.hist(img_clahe.ravel(), 655, [0, 65535]) 119 | 120 | # 使用imadjust 121 | img_adjust = py_imadjust(img_clahe) 122 | 123 | ''' 124 | -------------------------------------------------------------------------------- /prototype/readme.md: -------------------------------------------------------------------------------- 1 | 
![prototype2](https://github.com/LittleMount/DescatterNet-for-unseen-real-world-objects/assets/38102067/a748590a-1d3c-458a-89e5-4ab3c4d834d5) 2 | 3 | -------------------------------------------------------------------------------- /training/main1.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /video/Vehicle-mounted prototyp.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SituLab/DescatterNet-for-unseen-real-world-objects/d4856cececf1bfe72b08aa2071f4e4257b06403b/video/Vehicle-mounted prototyp.mp4 -------------------------------------------------------------------------------- /video/Video1 Non-invasive real-time restoration of real-world objects.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SituLab/DescatterNet-for-unseen-real-world-objects/d4856cececf1bfe72b08aa2071f4e4257b06403b/video/Video1 Non-invasive real-time restoration of real-world objects.mp4 -------------------------------------------------------------------------------- /video/Video2 Non-invasive real-time restoration of real-world objects.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SituLab/DescatterNet-for-unseen-real-world-objects/d4856cececf1bfe72b08aa2071f4e4257b06403b/video/Video2 Non-invasive real-time restoration of real-world objects.mp4 --------------------------------------------------------------------------------