├── LICENSE
├── README.md
├── benchmark_models.py
├── evaluation_accuracy.py
├── figure_depth.py
├── figure_noise_resistance.py
├── generate_training_set.py
├── requirements.txt
├── toolbox.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
Copyright 2020 Idiap Research Institute

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PSF Estimation
Code for the PyTorch implementation of "Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy", IEEE Transactions on Image Processing, 2020.

https://ieeexplore.ieee.org/document/9068472

## Abstract
Optical microscopy is an essential tool in biology and medicine. Imaging thin, yet non-flat objects in a single shot (without relying on more sophisticated sectioning setups) remains challenging as the shallow depth of field that comes with high-resolution microscopes leads to unsharp image regions and makes depth localization and quantitative image interpretation difficult. Here, we present a method that improves the resolution of light microscopy images of such objects by locally estimating image distortion while jointly estimating object distance to the focal plane. Specifically, we estimate the parameters of a spatially-variant Point Spread Function (PSF) model using a Convolutional Neural Network (CNN), which does not require instrument- or object-specific calibration. Our method recovers PSF parameters from the image itself with up to a squared Pearson correlation coefficient of 0.99 in ideal conditions, while remaining robust to object rotation, illumination variations, or photon noise. When the recovered PSFs are used with a spatially-variant and regularized Richardson-Lucy (RL) deconvolution algorithm, we observed up to 2.1 dB better Signal-to-Noise Ratio (SNR) compared to other Blind Deconvolution (BD) techniques. Following microscope-specific calibration, we further demonstrate that the recovered PSF model parameters permit estimating surface depth with a precision of 2 micrometers and over an extended range when using engineered PSFs. Our method opens up multiple possibilities for enhancing images of non-flat objects with minimal need for a priori knowledge about the optical setup.

## Requirements
The following Python libraries are required. We recommend using the conda package manager.
- numpy
- scikit-image
- pytorch
- matplotlib
- PyQt5
- pandas
- scikit-learn

For example, you can install all the requirements with `conda install --file requirements.txt`.

## Generating the training dataset
Run `generate_training_set.py` with the appropriate parameters.

## Training
Run `train.py` after adjusting its parameters to match the training set folder.

## Deconvolution
The deconvolution code is in a separate repository: https://github.com/idiap/semiblindpsfdeconv

## Generating figures and tables
The code for the benchmark table is in `benchmark_models.py`, for the noise-resistance figure in `figure_noise_resistance.py`, and for the depth-from-focus figure in `figure_depth.py`.

## Citation
For any use of the code or parts of the code, please cite:

    @article{shajkofci_spatially-variant_2020,
      ids = {shajkofci\_spatially-variant\_2020},
      title = {Spatially-{{Variant CNN}}-{{Based Point Spread Function Estimation}} for {{Blind Deconvolution}} and {{Depth Estimation}} in {{Optical Microscopy}}},
      author = {Shajkofci, Adrian and Liebling, Michael},
      date = {2020},
      journaltitle = {IEEE Transactions on Image Processing},
      volume = {29},
      pages = {5848--5861},
      issn = {1941-0042},
      doi = {10.1109/TIP.2020.2986880},
      eventtitle = {{{IEEE Transactions}} on {{Image Processing}}},
      keywords = {blind deconvolution,Calibration,convolutional neural networks,Deconvolution,depth from focus,Estimation,Microscopy,Optical diffraction,Optical imaging,Optical microscopy,point spread function estimation}
    }


## Licence
This is free software: you can redistribute it and/or modify it under the terms of the BSD-3-Clause licence.

--------------------------------------------------------------------------------
/benchmark_models.py:
--------------------------------------------------------------------------------
'''
Code for the implementation of
"Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy"

Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/
Written by Adrian Shajkofci ,
All rights reserved.

This file is part of Spatially-Variant CNN-based Point Spread Function Estimation.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''


import glob
import logging
import math
import pandas
import pprint
import sys

import torch
import torch.nn as nn

from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *

import numpy as np

import datetime
from toolbox import random_open_crop, random_crop
from toolbox import to_32_bit, scale, to_16_bit
from toolbox import center_crop_pixel

from evaluation_accuracy import *
from train import *

torch.random.manual_seed(8)
pp = pprint.PrettyPrinter()

directory = "data/models/"

file_list = glob.glob(directory + "*")

# Parse the training configuration, run number and per-epoch errors from each model file name
all_params = []
for file in file_list:
    if 'texture' in file:
        continue
    parsed = file.replace(directory, '')
    has_nat = 0
    has_nat_mult = 0
    if '_natural_' in parsed:
        has_nat = 1
        has_nat_mult = 1
    elif '_nonoise_' in parsed:
        has_nat = 1
        has_nat_mult = 2
    parsed = parsed.replace('.pt', '')
    parsed = parsed.split("_")
    params = {}
    params['patch_size'] = int(parsed[1])
    params['train_natural'] = int(parsed[3]) + 2 * has_nat_mult
    params['train_synthetic'] = parsed[5]
    params['train_points'] = parsed[7]
    params['train_black'] = parsed[9]
    params['dataset_trained'] = parsed[has_nat+10]
    params['run'] = int(parsed[has_nat+11])
    params['model'] = parsed[has_nat+12]
    found = list(filter(lambda file: file['train_natural'] == params['train_natural'] and file['train_synthetic'] == params['train_synthetic'] and file['train_points'] == params['train_points']
                        and file['train_black'] == params['train_black'] and file['dataset_trained'] == params['dataset_trained'] and file['run'] == params['run'] and file['model'] == params['model'], all_params))
    # if len(found) > 0:
    #     found[0]['train_err'].append((int(parsed[has_nat+13][2:]), float(parsed[has_nat+14][8:])))
    #     found[0]['test_err'].append((int(parsed[has_nat+13][2:]), float(parsed[has_nat+15][7:])))
    #     found[0]['files'].append((int(parsed[has_nat+13][2:]),file))
    #     found[0]['train_err'] = sorted(found[0]['train_err'], key=lambda tup: tup[0])
    #     found[0]['test_err'] = sorted(found[0]['test_err'], key=lambda tup: tup[0])
    #     found[0]['files'] = sorted(found[0]['files'], key=lambda tup: tup[0])
    #
    # else:
    params['train_err'] = []
    params['test_err'] = []
    params['train_err'].append((int(parsed[has_nat+13][2:]), float(parsed[has_nat+14][8:])))
    params['test_err'].append((int(parsed[has_nat+13][2:]), float(parsed[has_nat+15][7:])))
    params['files'] = []
    params['files'].append((int(parsed[has_nat+13][2:]), file))
    all_params.append(params)

for p in all_params:
    train_err = [err[1] for err in p['train_err']]
    test_err = [err[1] for err in p['test_err']]
    files = [err[1] for err in p['files']]
    p.update({'train_err': train_err})
    p.update({'test_err': test_err})
    p.update({'files': files})
    p.update({'best_epoch_train': np.argmin(train_err)})
    p.update({'best_epoch_test': np.argmin(test_err)})

# len() gives the number of parsed entries (__sizeof__() would return a memory size in bytes)
print('Found {} different modalities'.format(len(all_params)))

def test(params):
    global log

    file = params['files'][params['best_epoch_train']]
    patch_size = params['patch_size']
    model_type = params['dataset_trained']
    synthetic = params['test_synthetic']
    natural = params['test_natural']
    points = params['test_points']
    black = params['test_black']

    if 'noise' in params:
        noise = params['noise']
        noise_type = params['noise_type']
    else:
        noise = 0.0
        noise_type = None
    if natural == 3:
        isnat = '_natural'
        natural = 1
    elif natural == 5:  # nonoise
        isnat = ''
        natural = 1
    else:
        isnat = ''

    run_nb = 'test'
    run_name = '{}_n_{}_s_{}_p_{}_b_{}{}_{}_{}/'.format(patch_size, natural, synthetic, points, black, isnat, model_type, run_nb)
    folder_prefix = "/idiap/temp/ashajkofci/"

    logging.basicConfig(
        format="%(asctime)s [{}] %(message)s".format(run_name),
        handlers=[
            logging.StreamHandler()
        ])

    log = logging.getLogger('')
    log.setLevel(logging.INFO)

    def add_running_mean(_node):
        for child in _node.children():
            if type(child) == nn.BatchNorm2d:
                if child.running_mean is None:
                    del child._parameters['running_mean']
                    del child._parameters['running_var']
                    del child._parameters['num_batches_tracked']
                    child.register_buffer('running_mean', torch.zeros(child.num_features).cuda())
                    child.register_buffer('running_var', torch.ones(child.num_features).cuda())
                    child.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long).cuda())
                    child.track_running_stats = True
            elif isinstance(child, nn.Module):
                add_running_mean(child)
        return


    def stop_running_mean(_node):
        for child in _node.children():
            if type(child) == nn.BatchNorm2d:
                child.track_running_stats = False
            elif isinstance(child, nn.Module):
                stop_running_mean(child)
        return

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        format="%(asctime)s [{}] %(message)s".format(run_name + file.split('/')[-1]),
        handlers=[
            logging.FileHandler("output_log_{}.log".format(run_nb)),
            logging.StreamHandler()
        ])

    log.info("Load test data for {}....".format(run_name))
    log.info("Loading file {}".format(file))
    model = torch.load('{}'.format(file))
    model.eval()

    test_loader_, test_header_ = load_crops(folder_prefix=folder_prefix, test=True, patch_size=patch_size, synthetic=synthetic, natural=natural, points=points, black=black, model_type=model_type, batch_size=84, noise=noise, isnew=False, noise_type=noise_type, suffix=isnat)
    results = eval_regression_accuracy(model, test_loader_, test_header_, max_iter=10000)

    return results

def test_everything():
    df = None
    for p in all_params:

        if int(p['run']) < 600 or int(p['run']) > 700:
            continue

        tested_parameters = []
        all_train_set = list(filter(lambda file: file['model'] == p['model'] and file['dataset_trained'] == p['dataset_trained'] and file['patch_size'] == p['patch_size'], all_params))
        for item in all_train_set:
            p2 = p.copy()

            p2['test_synthetic'] = item['train_synthetic']
            p2['test_natural'] = item['train_natural']
            p2['test_points'] = item['train_points']
            p2['test_black'] = item['train_black']

            # These assignments override the ones above, so each model is
            # evaluated on its own training configuration
            p2['test_synthetic'] = p2['train_synthetic']
            p2['test_natural'] = p2['train_natural']
            p2['test_points'] = p2['train_points']
            p2['test_black'] = p2['train_black']

            param_tuple = (item['train_synthetic'], item['train_natural'], item['train_points'], item['train_black'])
            if param_tuple in tested_parameters:
                continue
            tested_parameters.append(param_tuple)
            results = test(p2)
            p2.update(results)
            for key, val in p.items():
                if isinstance(val, list):
                    p2[key] = ' '.join(str(e) for e in val)
            if df is None:
                df = pandas.DataFrame(p2, index=[0])
            else:
                df = pandas.concat([df, pandas.DataFrame(p2, index=[0])], axis=0, sort=False)
            df.to_csv('results_nonoise_test.csv', float_format='%.10f')
            break


def test_noise():
    noise_levels = np.linspace(0.0, 1.0, 50)
    df = None
    runs = [263, 164, 102, 146, 605, 609]

    runsf = [str(i) for i in runs]

    noise_type = 'axial_luminosity'

    filename = 'results_noise_{}_{}_{}_{}_{}.csv'.format(noise_type, np.min(noise_levels), np.max(noise_levels), noise_levels.shape[0], '-'.join(runsf))
    print("Saving in {}".format(filename))

    all_train_set = list(filter(
        lambda file: file['run'] in runs, all_params))
    for p in all_train_set:
        for noise in noise_levels:
            log.info("Noise level: {}".format(noise))
            log.info("Noise type: {}".format(noise_type))
            p2 = p.copy()
            p2['noise'] = noise
            p2['noise_type'] = noise_type
            p2['test_synthetic'] = p['train_synthetic']
            p2['test_natural'] = p['train_natural']
            p2['test_points'] = p['train_points']
            p2['test_black'] = p['train_black']
            results = test(p2)
            p2.update(results)
            for key, val in p.items():
                if isinstance(val, list):
                    p2[key] = ' '.join(str(e) for e in val)
            if df is None:
                df = pandas.DataFrame(p2, index=[0])
            else:
                df = pandas.concat([df, pandas.DataFrame(p2, index=[0])], axis=0, sort=False)
            df.to_csv(filename, float_format='%.10f')


def print_noisy():
    images = random_open_crop('/media/adrian/ext4data/data/images_texture2/', None, num_images=3, as_grey=True)
    ranges_large = [0.1, 0.5, 0.9]
    i = 0
    def to_long(img):
        # Rescale to [0, 1], then convert to 16-bit integers for saving
        if img.min() < 0:
            img -= img.min()
        if img.max() > 1:
            img /= img.max()
        img *= 65535

        return img.astype(np.uint16)

    for img in images:

        # np.float was removed in NumPy 1.24; np.float64 is the equivalent type
        img = img.astype(np.float64)
        img = scale(img)
        if np.min(img) < 0.0:
            img -= np.min(img)
        if img.max() > 0:
            img -= img.mean()
            img /= img.std()
            img += 1
        if img.max() > 1:
            img /= img.max()
        if img.min() < 0:
            img -= img.min()
        if img.max() > 1:
            img /= img.max()
        image = Tensor(img)

        for r in ranges_large:
            io.imsave('noisy_gaussian_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(image, 'gauss', r).data.cpu().numpy(), 128)))
            io.imsave('noisy_poisson_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(image, 'poisson', r).data.cpu().numpy(), 128)))
            io.imsave('noisy_gaussianpoisson_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(noisy(image, 'poisson', r), 'gauss', r).data.cpu().numpy(), 128)))
            io.imsave('noisy_rotation_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(image, 'rotation', r).data.cpu().numpy(), 128)))
            io.imsave('noisy_luminosity_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(image, 'luminosity', r).data.cpu().numpy(), 128)))
            io.imsave('noisy_axialluminosity_{}_{}.png'.format(i, r), to_long(center_crop_pixel(noisy(image, 'axial_luminosity', r).data.cpu().numpy(), 128)))
        i += 1


def test_plane_graph(feat13_arr, feat23_arr, num_cases, step, figname='', degree=6, verbose=True):

    degrees_inclination = degree
    radians_inclination = degrees_inclination * math.pi / 180.
    resolution = 0.32
    size = feat13_arr[0].shape[0]
    max_view = resolution * size  # field of view in micrometers (0.32 um per pixel times the image size in pixels)
    max_depth = max_view * math.tan(radians_inclination) * 2.0
    print('Max depth : {} um'.format(max_depth))

    if verbose:
        plt.figure(dpi=600, figsize=(4, 3))
        plt.rc('text', usetex=True)
        plt.rc('font', family='serif')

    for idx, feat13 in enumerate(feat13_arr):

        feat13 = feat13[:, 1::step]
        y_13 = feat13.mean(axis=0)
        x = np.linspace(0, feat13.shape[1], feat13.shape[1])
        var = feat13.var(axis=0)
        alpha = max_depth / (y_13.max() - y_13.min())
        y = y_13 * alpha
        x = x * max_view / num_cases[idx]

        fit = np.polyfit(x, y, 1)
        fit_fn = np.poly1d(fit)
        fitted_x = fit_fn(x)
        rsquare = r2_score(y, fitted_x)
        error1 = np.abs(y - fitted_x)
        # if verbose:
        #     plt.plot(x,y, '.', label=r'$z(\mathbf{s})$')
        #     plt.plot(x, fitted_x, '--k', label='gt')
        #     plt.errorbar(x,y,yerr=var*max_view)
        #     print('Y range :{}'.format((y.max()-y.min())))
        #     print('Image {} | Feature 1+3 : R2= {} err= {}'.format(idx, rsquare, error1.mean()))

        feat23 = feat23_arr[idx]
        feat23 = feat23[:, 1::step]
        y_23 = feat23.mean(axis=0)
        x = np.linspace(0, feat23.shape[1], feat23.shape[1])
        var = feat23.var(axis=0)
        alpha = max_depth / (y_23.max() - y_23.min())
        y2 = y_23 * alpha
        x = x * max_view / num_cases[idx]

        fit = np.polyfit(x, y2, 1)
        fit_fn = np.poly1d(fit)
        fitted_x = fit_fn(x)
        rsquare2 = r2_score(y2, fitted_x)
        error2 = np.abs(y2 - fitted_x)
        if verbose:
            print('Y range :{}'.format((y2.max() - y2.min())))

            plt.plot(x, y2, '.', label=r'$z(\mathbf{s})$')
            plt.plot(x, fitted_x, '--k', label='gt')
            plt.errorbar(x, y2, yerr=var*max_view)
            print('Image {} | Feature 2+3 : R2= {} err= {}'.format(idx, rsquare2, error2.mean()))

    if verbose:
        # plt.title(r'$R^2=$' + '{:0.3f} / error = {:0.5f}'.format(rsquare, error))

        plt.xlabel(r'x position $(\mu m)$')
        plt.ylabel(r'z position $(\mu m)$')
        data = np.asarray([y2.max(), y2.min()])
        plt.ylim(data.min() - 5, data.max() + 5)
        # plt.title(figname)
        plt.legend()
        plt.tight_layout()
        plt.gcf().subplots_adjust(bottom=0.22)
        plt.savefig('line_plot.png')

        # plt.figure(dpi=300, figsize=(5,3))
        # plt.plot(x, error1, '.')
        # plt.plot(x, error2, '.')
        # plt.xlabel('Y direction in the input image [microns]')
        # plt.ylabel('Depth [microns]')
        # plt.legend(['error 1+3', 'error 2+3'])
        # plt.title(figname)
        plt.show()

    print('R2 1+3 = {} R2 2+3 = {}'.format(rsquare, rsquare2))
    print('Error 1+3 = {} error 2+3 = {}'.format(error1.mean(), error2.mean()))
    return rsquare, rsquare2, error1.mean(), error2.mean(), error1.std(), error2.std()


def test_plane_stats(model_file=None, step=128):
    all_data = []
    degrees = [3, 6, 10]
    folder_tpl = '/home/adrian/git/adrian-wip-git/gaussiandeconv/img/astigmatism_avril2019/proche/{}/*/*.tif'

    if model_file is None:
        folder_model_tpl = '/media/adrian/OMENDATA/data/trained_models_new/*2d*.pt'
    else:
        folder_model_tpl = model_file

    files = glob.glob(folder_model_tpl)
    for model_file in files:
        print("Load model {}".format(model_file))
        for degree in degrees:
            folder = folder_tpl.format(degree)
            print('Search for {}'.format(folder))
            img_list = glob.glob(folder)
            print('Found {} files'.format(len(img_list)))
            for im_file in img_list:
                feat13_arr = []
                feat23_arr = []
                num_cases_arr = []
                feat13, feat23, num_cases = test_moving_window(model_file, im_file, step=step, verbose=False)
                feat13_arr.append(feat13)
                feat23_arr.append(feat23)
                num_cases_arr.append(num_cases)
                rsquare, rsquare2, error1, error2, e1std, e2std = test_plane_graph(feat13_arr, feat23_arr, num_cases_arr, step, '{} degrees'.format(degree), degree=degree, verbose=True)
                all_data.append({'model_file': model_file, 'degree': degree, 'rsquare1': rsquare, 'rsquare2': rsquare2, 'error1': error1, 'error2': error2, 'error1_std': e1std, 'error2_std': e2std})
    df = pandas.DataFrame(all_data)
    df.to_csv('plane_stats_results_{}.csv'.format(datetime.datetime.now()), float_format='%.5f')


model_file = None
image_file = []

class Example(QMainWindow):

    def __init__(self):
        super().__init__()

        self.initUI()

    def initUI(self):
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)

        self.folderLayout = QWidget()

        self.pathRoot = QDir.rootPath()

        self.dirmodel = QFileSystemModel(self)
        self.dirmodel.setRootPath(QDir.currentPath())

        self.indexRoot = self.dirmodel.index(self.dirmodel.rootPath())

        self.folder_view = QTreeView()
        self.folder_view.setDragEnabled(True)
        self.folder_view.setModel(self.dirmodel)
        self.folder_view.setRootIndex(self.indexRoot)

        self.selectionModel = self.folder_view.selectionModel()

        self.left_layout = QVBoxLayout()
        self.left_layout.addWidget(self.folder_view)

        self.folderLayout.setLayout(self.left_layout)

        splitter_filebrowser = QSplitter(Qt.Horizontal)
        splitter_filebrowser.addWidget(self.folderLayout)
        splitter_filebrowser.addWidget(Figure_Canvas(self))
        splitter_filebrowser.setStretchFactor(1, 1)

        self.textbox = QLineEdit(self)
        self.textbox.resize(50, 40)
        self.textbox.setText('128')
        button = QPushButton('Load', self)
        button.clicked.connect(self.on_click)

        button2 = QPushButton('Line graph', self)
        button2.clicked.connect(self.on_click2)

        vbox = QVBoxLayout()
        vbox.addWidget(self.textbox)
        vbox.addWidget(button)
        vbox.addWidget(button2)

        # No parent here: QMainWindow already has its own layout, so the
        # layout is installed on the central widget below instead
        hbox = QHBoxLayout()
        hbox.addWidget(splitter_filebrowser)
        hbox.addLayout(vbox)

        self.centralWidget().setLayout(hbox)

        self.setWindowTitle('PSF detection map GUI')
        self.setGeometry(750, 100, 800, 600)

    @pyqtSlot()
    def on_click(self):
        print('load')
        step = int(self.textbox.text())
        feat13_arr = []
        feat23_arr = []
        num_cases_arr = []
        for im_file in image_file:
            feat13, feat23, num_cases = test_moving_window(model_file, im_file, step=step)
            feat13_arr.append(feat13)
            feat23_arr.append(feat23)
            num_cases_arr.append(num_cases)
        test_plane_graph(feat13_arr, feat23_arr, num_cases_arr, step)
        plt.show()

    def on_click2(self):
        step = int(self.textbox.text())
        test_plane_stats(model_file, step)
        plt.show()


class Figure_Canvas(QWidget):

    def __init__(self, parent):
        super().__init__(parent)

        self.setAcceptDrops(True)

        blabla = QLineEdit()

        self.right_layout = QVBoxLayout()
        self.right_layout.addWidget(blabla)

        self.buttonLayout = QWidget()
        self.buttonLayout.setLayout(self.right_layout)

    def dragEnterEvent(self, e):

        if e.mimeData().hasFormat('text/uri-list'):
            e.accept()
        else:
            e.ignore()

    def dropEvent(self, e):
        global image_file, model_file
        for url in e.mimeData().urls():
            # toLocalFile() strips the file:// scheme and decodes percent-encoded characters (e.g. spaces)
            path = url.toLocalFile()
            ext = path.split('.')[-1]
            if ext in ['tif', 'tiff', 'png']:
                image_file.append(path)
                print('Image loaded {}'.format(image_file))
            if ext in ['pt', 'model']:
                model_file = path
                print('Model loaded {}'.format(model_file))


if __name__ == '__main__':

    # GUI to test one model on one image
    app = QApplication(sys.argv)
    ex = Example()
    ex.show()
    app.exec_()

    # Test all models for regression
    test_everything()

    # Export noisy figures
    print_noisy()

    # Test all models for noise degradation
    test_noise()

    # Test plane depth detection and output graph
    test_plane_stats(model_file=None, step=128)
--------------------------------------------------------------------------------
/evaluation_accuracy.py:
--------------------------------------------------------------------------------
'''
Code for the implementation of
"Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy"

Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/
Written by Adrian Shajkofci ,
All rights reserved.

This file is part of Spatially-Variant CNN-based Point Spread Function Estimation.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 | POSSIBILITY OF SUCH DAMAGE. 34 | ''' 35 | 36 | from train import * 37 | 38 | def test_moving_window(model_file, image_file, step=128, verbose=True): 39 | ''' 40 | Output the PSF parameter map from a model and a file. It outputs as well different combinations of parameters for 41 | depth estimation 42 | :param model_file: file of the cnn model 43 | :param image_file: input file 44 | :param step: moving window step 45 | :param verbose: output logs 46 | :return: 47 | ''' 48 | model = torch.load('{}'.format(model_file)) 49 | num_classes = list(model._modules.items())[-1][1].out_features 50 | model.train(False) 51 | for child in model.children(): 52 | if type(child) == nn.BatchNorm2d: 53 | child.track_running_stats = False 54 | channels = list(model._modules.items())[0][1].in_channels 55 | size = 128 56 | real_size = size 57 | x = size 58 | y = size 59 | im = io.imread(image_file, as_gray=True) 60 | if im.dtype == np.uint16: 61 | im = im.astype(np.int32) 62 | 63 | def To_224(img): 64 | img = Tensor(img) 65 | img = img.repeat(3, 1, 1) 66 | img.unsqueeze_(0) 67 | m = nn.Upsample(size=(224, 224), mode='bilinear') 68 | img = m(img) 69 | return img[0].data.cpu().numpy() 70 | 71 | im = pytoolbox.image.utils.center_crop(im, 50) 72 | 73 | im = im[0:im.shape[0]//size * size, 0:im.shape[1]//size * size] 74 | weight_image = np.zeros((im.shape[0], im.shape[1], num_classes)) 75 | 76 | tile_dataset = [] 77 | y_size = 0 78 | i = 0 79 | 80 | while x <= im.shape[0]: 81 | 
x_size = 0 82 | while y <= im.shape[1]: 83 | a = im[x - size:x, y - size:y] 84 | 85 | if channels == 3: 86 | a = To_224(a) 87 | real_size = 224 88 | tile_dataset.append(a[:]) 89 | weight_image[x - size:x, y - size:y] += 1.0 90 | y += step 91 | x_size += 1 92 | i += 1 93 | y = size 94 | y_size += 1 95 | x += step 96 | 97 | tile_dataset = np.asarray(tile_dataset) 98 | tile_dataset = np.reshape(tile_dataset, (tile_dataset.shape[0], channels, real_size, real_size)) 99 | 100 | start = time.time() 101 | 102 | max_size = tile_dataset.shape[0] 103 | batch_size = 16 104 | it = 0 105 | output_npy = np.zeros((tile_dataset.shape[0], num_classes)) 106 | input_tensor = torch.FloatTensor(tile_dataset) 107 | 108 | while max_size > 0: 109 | num_batch = min(batch_size, max_size) 110 | out = model(Variable(input_tensor.narrow(0, it, num_batch), requires_grad=False ).cuda()) 111 | 112 | out = (out*100.).round() / 100. 113 | out[:,0] = out[:,0].round() 114 | 115 | 116 | output_npy[it:it+num_batch] = out.data.cpu().numpy() 117 | it += num_batch 118 | max_size -= num_batch 119 | 120 | end = time.time() 121 | output = np.zeros((im.shape[0], im.shape[1], output_npy.shape[1])) 122 | 123 | i = 0 124 | x = size 125 | y = size 126 | o = [] 127 | while x <= im.shape[0]: 128 | while y <= im.shape[1]: 129 | output[x - size:x, y - size:y] += output_npy[i, :] 130 | y += step 131 | i += 1 132 | y = size 133 | x += step 134 | output = output / weight_image 135 | output[:,:,0] = np.round(output[:,:,0]) 136 | dir = 'output_psfmap' 137 | 138 | print("Time elapsed : ") 139 | print(end - start) 140 | 141 | if output.shape[2] > 3: 142 | feat13 = output[:, :, 1] * ((output[:, :, 3]) - 0.5) 143 | feat23 = ((scale(output[:, :, 3]) - 0.5) * output[:, :, 2]) 144 | else : 145 | feat13 = output[:, :, 1] * ((output[:, :, 2]) - 0.5) 146 | feat23 = ((scale(output[:, :, 1]) - 0.5) * output[:, :, 2]) 147 | if verbose: 148 | plt.figure() 149 | plt.imshow(im[:, :]) 150 | plt.imsave(dir + '/ast_input.png', im[:, :]) 151 
| 152 | plt.title('Input') 153 | for i in range(num_classes): 154 | plt.figure() 155 | plt.imshow(output[:, :, i]) 156 | plt.title('Output feat {}'.format(i)) 157 | 158 | plt.figure() 159 | output[:,:,1] = output[:,:,1] / output[:,:,1].max() 160 | output[:, :, i] = (output[:, :, i] / output[:, :, i].max()) 161 | output[:, :, i-1] = (output[:, :, i-1] / output[:, :, i-1].max()) 162 | combined = output[:, :, i] * output[:, :, i-1] 163 | combined = combined / combined.max() 164 | combined -= 0.5 165 | plt.imshow(output[:, :, 1] * combined) 166 | plt.title('Output feat 1*2*3') 167 | 168 | 169 | plt.figure() 170 | plt.imshow(feat13) 171 | plt.title('Output feat 1*3') 172 | plt.figure() 173 | plt.imshow(feat23) 174 | plt.title('Output feat 2*3') 175 | plt.imsave(dir + '/ast_2and3.png',feat23) 176 | 177 | plt.imsave(dir + '/ast_focus.png', output[:, :, 1] ) 178 | plt.imsave(dir + '/ast_astdirection.png', output[:, :, i] ) 179 | plt.imsave(dir + '/ast_ast.png', output[:, :, i-1] ) 180 | plt.imsave(dir + '/ast_1and3.png', feat13) 181 | 182 | return feat13, feat23, np.sqrt(tile_dataset.shape[0]) 183 | 184 | def eval_regression_accuracy(model=None, test_loader=None, test_header=None, max_iter = 1000): 185 | ''' 186 | Estimates the regression accuracy of a model (mean error, variance and R2 score per parameter) 187 | :param model: trained CNN model 188 | :param test_loader: image loader for the test data 189 | :param test_header: names of the regressed parameters (CSV header) 190 | :param max_iter: maximum number of test images to evaluate 191 | ''' 192 | global log 193 | 194 | i = 0 195 | nb_batch = 0 196 | cumloss = 0 197 | cumvariance = 0 198 | 199 | all_labels = [] 200 | all_predictions = [] 201 | with torch.no_grad(): 202 | for image, label in test_loader: 203 | 204 | if i >= max_iter: 205 | break 206 | 207 | img = Variable(image, requires_grad=False).cuda() 208 | output = model(img) 209 | output = (output*1000.).round() / 1000.
210 | output[:,0] = output[:,0].round() 211 | 212 | for a in range(output.size(0)): 213 | 214 | all_predictions.append(output[a].cpu().data.numpy()) 215 | all_labels.append(label[a].data.numpy()) 216 | 217 | if label[a,0] == 1: 218 | output[a,1:] = 1000. 219 | 220 | 221 | 222 | loss = l2loss(output, (label).cuda().float()) 223 | variance = l2variance(output, (label).cuda().float()) 224 | cumvariance += variance.cpu().data.numpy() 225 | cumvariance /= 4 226 | i += output.size(0) 227 | nb_batch += 1 228 | cumloss += loss.cpu().data.numpy() 229 | 230 | log.info("Batch {}/{}, err: {}, avg err: {}, var: {:.3E}, avg var:{:.3E}".format(nb_batch,len(test_loader),loss.cpu().data.numpy(), cumloss/nb_batch, variance.cpu().data.numpy(), cumvariance)) 231 | 232 | all_predictions = np.asarray(all_predictions) 233 | all_labels = np.asarray(all_labels) 234 | res = {} 235 | res.update({'error':cumloss/nb_batch, 'variance':cumvariance}) 236 | 237 | for feature in range(0, label.shape[1]): 238 | pred = all_predictions[:,feature] 239 | lab = all_labels[:,feature] 240 | new_pred = [] 241 | new_lab = [] 242 | for idx, val in enumerate(lab): 243 | if val != 1000. 
and np.abs(lab[idx]- pred[idx]) < 1000.: 244 | new_lab.append(lab[idx]) 245 | new_pred.append(pred[idx]) 246 | 247 | score = r2_score(new_lab, new_pred) # sklearn signature is r2_score(y_true, y_pred) 248 | if test_header[feature] == 'fwmh': 249 | test_header[feature] = 'focus' 250 | log.info("Result for feature {} with {} samples, R2 = {}".format(test_header[feature], len(new_pred), score)) 251 | res.update({test_header[feature]:score}) 252 | 253 | log.info("Error on the full set: {}".format(cumloss/nb_batch)) 254 | log.info("Error variance on the full set: {}".format(cumvariance)) 255 | return res 256 | 257 | 258 | if __name__ == '__main__': 259 | 260 | patch_size = 128 261 | synthetic = 0 262 | natural = 1 263 | points = 0 264 | black = 5 265 | model_type = '1dzernike' 266 | noise = 0.0 267 | noise_type = 'poisson' 268 | model_name = 'resnet34' 269 | suffix = '_noise_0' 270 | run_name = '{}_n_{}_s_{}_p_{}_b_{}{}_{}_{}/'.format(patch_size, natural, synthetic, points, black, suffix, model_type, run_nb) 271 | 272 | import logging 273 | 274 | logging.basicConfig( 275 | format="%(asctime)s [{}] %(message)s".format(run_name), 276 | handlers=[ 277 | logging.FileHandler("output_log_{}.log"), 278 | logging.StreamHandler() 279 | ]) 280 | 281 | niter = 0 282 | log = logging.getLogger('') 283 | log.setLevel(logging.INFO) 284 | 285 | test_loader,test_header = load_crops(folder_prefix= folder_prefix, test=True, patch_size=patch_size, synthetic=synthetic, natural=natural, points=points, black=black, model_type=model_type, batch_size=16, noise= noise, isnew=False, noise_type=noise_type, suffix= suffix) 286 | train_loader,train_header = load_crops(folder_prefix= folder_prefix, test=False, patch_size=patch_size, synthetic=synthetic, natural=natural, points=points, black=black, model_type=model_type, batch_size=batch_size, noise=0.0, suffix= suffix) 287 | 288 | log.info("Starting model run {}....".format(run_name)) 289 | 290 | st = "model_{0}_{1}*".format(run_name.strip("/"), model_name) 291 | 292 | list_saves = glob.glob(st)
293 | if len(list_saves) > 0: 294 | list_saves = sorted(list_saves) 295 | idx = list_saves[-1].find("_ep")+3 296 | current_epoch = int(list_saves[-1][idx:idx+2]) 297 | log.info("Loading file {}... epoch {}".format(list_saves[-1], current_epoch)) 298 | model = torch.load(list_saves[-1]) 299 | elif model_name == 'resnet34': 300 | model, model_name = resnet34() 301 | elif model_name == 'resnext50': 302 | model, model_name = resnext50() 303 | elif model_name == 'resnet34pretrained': 304 | model, model_name = resnet34_pretrained() 305 | elif model_name == 'resnet50pretrained': 306 | model, model_name = resnet50_pretrained() 307 | else: 308 | log.error("MODEL {} NOT FOUND".format(model_name)) 309 | 310 | run_name = 2 311 | model, model_name = load(run_name) 312 | 313 | logging.basicConfig( 314 | format="%(asctime)s [{}] %(message)s".format(run_name+model_name), 315 | handlers=[ 316 | logging.FileHandler("output_log_{}.log"), 317 | logging.StreamHandler() 318 | ]) 319 | 320 | test_moving_window() 321 | eval_regression_accuracy(model, test_loader, test_header) -------------------------------------------------------------------------------- /figure_depth.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for the implementation of 3 | "Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy" 4 | 5 | Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/ 6 | Written by Adrian Shajkofci , 7 | All rights reserved. 8 | 9 | This file is part of Spatially-Variant CNN-based Point Spread Function Estimation. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. Redistributions of source code must retain the above copyright notice, 15 | this list of conditions and the following disclaimer. 16 | 2. 
Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in the 18 | documentation and/or other materials provided with the distribution. 19 | 3. Neither the name of mosquitto nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 26 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 | POSSIBILITY OF SUCH DAMAGE. 34 | ''' 35 | 36 | 37 | import numpy as np 38 | from matplotlib import pyplot as plt 39 | from sklearn.metrics import r2_score 40 | import math 41 | 42 | 43 | # Parameters of the microscope 44 | degrees_inclination = 6 45 | radians_inclination = degrees_inclination * math.pi / 180. 
46 | resolution = 0.32 47 | size = 2048 48 | num_cases = 16 49 | max_view = resolution*size #micrometers 0.32 micron per pixel * 2048 pixel 50 | max_depth = max_view * math.tan(radians_inclination) 51 | print('Max depth : {} um'.format(max_depth)) 52 | 53 | # This file is generated by test_plane_stats 54 | table = np.load('feat13.npy') 55 | x = np.linspace(0,num_cases,num_cases) 56 | table = table[:,1::128] 57 | y = table.mean(axis=0) 58 | 59 | alpha = max_depth / (y.max()-y.min()) 60 | y = y*alpha 61 | x = x*max_view/num_cases 62 | 63 | print('Y range :{}'.format((y.max()-y.min()))) 64 | 65 | 66 | var = table.var(axis=0) 67 | fit = np.polyfit(x,y,1) 68 | fit_fn = np.poly1d(fit) 69 | fitted_x = fit_fn(x) 70 | rsquare = r2_score(y, fitted_x) 71 | error = ((y - fitted_x).__abs__()).mean() 72 | plt.figure(dpi=300, figsize=(6,3)) 73 | plt.plot(x,y, '.', x, fitted_x, '--k') 74 | plt.title(r'$R^2=$' + '{:0.3f} / error = {:0.5f}'.format(rsquare, error)) 75 | plt.xlabel('Y direction in the input image [microns]') 76 | plt.ylabel('Depth [microns]') 77 | plt.errorbar(x,y,yerr=var*max_view) 78 | plt.show() -------------------------------------------------------------------------------- /figure_noise_resistance.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for the implementation of 3 | "Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy" 4 | 5 | Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/ 6 | Written by Adrian Shajkofci , 7 | All rights reserved. 8 | 9 | This file is part of Spatially-Variant CNN-based Point Spread Function Estimation. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. 
Redistributions of source code must retain the above copyright notice, 15 | this list of conditions and the following disclaimer. 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in the 18 | documentation and/or other materials provided with the distribution. 19 | 3. Neither the name of mosquitto nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 26 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 | POSSIBILITY OF SUCH DAMAGE. 
34 | ''' 35 | 36 | 37 | import numpy as np 38 | from matplotlib import pyplot as plt 39 | import pandas 40 | import glob 41 | import matplotlib 42 | 43 | # This file is generated by test_noise for different types of noise 44 | list_files = glob.glob('results_noise*.csv') 45 | 46 | # Every type of noise has one file 47 | for file in list_files: 48 | try: 49 | df2 = pandas.read_csv(file) 50 | except Exception: 51 | continue # skip files that pandas cannot parse instead of crashing on len(None) below 52 | 53 | data = {} 54 | for idx in range(len(df2)): 55 | #if '2dgaussian' == df2['dataset_trained'][idx]: 56 | # focus = df2['fwmhx'][idx] + df2['fwmhy'][idx] 57 | # focus /=2 58 | #elif '2dzernike' == df2['dataset_trained'][idx]: 59 | # focus = df2['focus'][idx] + df2['ast'][idx] + df2['ast_angle'][idx] 60 | # focus /= 3 61 | #else: 62 | focus = df2['focus'][idx] 63 | 64 | name = "{}_{}_{}_n_{}_s_{}_p_{}_b_{}".format(df2['run'][idx] , df2['model'][idx] , df2['dataset_trained'][idx] , df2['train_natural'][idx], df2['train_synthetic'][idx] , df2['train_points'][idx] , df2['train_black'][idx]) 65 | if name not in data: 66 | data2 = {} 67 | else: 68 | data2 = data[name] 69 | if 'noisedata' not in data2: 70 | data2['noisedata'] = [] 71 | data2['noisevar'] = [] 72 | 73 | data2['noisedata'].append((df2['noise'][idx], focus)) 74 | data2['noisedata'] = sorted(data2['noisedata'], key=lambda tup: tup[0]) 75 | 76 | data2['noisevar'].append((df2['noise'][idx], df2['variance'][idx] / 20.0)) 77 | data2['noisevar'] = sorted(data2['noisevar'], key=lambda tup: tup[0]) 78 | 79 | data.update({name: data2}) 80 | data_array = [] 81 | 82 | for idx, value in enumerate(data): 83 | for i in data[value]['noisedata']: 84 | data[value].update({i[0]:i[1]}) 85 | data[value]['name'] = value+'_mean' 86 | u = data[value].copy() 87 | u.pop('noisedata') 88 | u.pop('noisevar') 89 | data_array.append(u) 90 | for i in data[value]['noisevar']: 91 | data[value].update({i[0]: i[1]}) 92 | data[value]['name'] = value + '_variance' 93 | data[value].pop('noisedata') 94 | data[value].pop('noisevar')
95 | data_array.append(data[value]) 96 | 97 | df = pandas.DataFrame(data_array) 98 | cols = df.columns.tolist() 99 | cols = cols[-1:] + cols[:-1] 100 | df = df[cols] 101 | 102 | dfarr = df.values 103 | dfindex = df.columns.values[1:].astype(float) 104 | plt.rc('text', usetex=True) 105 | f = plt.figure(figsize=(4,3), dpi=400) 106 | font = {'family': 'Times New Roman', 107 | 'weight': 'normal', 108 | 'size': 19} 109 | 110 | matplotlib.rc('font', **font) 111 | 112 | legends = [] 113 | for i in range(dfarr.shape[0]//2): 114 | name_mean = dfarr[2*i, 0] 115 | name_var = dfarr[2*i+1, 0] 116 | a = ['606','255','157'] 117 | if any(u in name_mean for u in a): 118 | pass 119 | 120 | print(name_mean) 121 | legends.append(name_mean.split('_')[0]) 122 | mean = np.clip((dfarr[2*i, 1:]).astype(float), 0, 1.0) 123 | var = (dfarr[2*i+1, 1:]).astype(float)/2 124 | assert(len(dfindex) == len(mean)) 125 | plt.plot(dfindex, mean, '.-') 126 | 127 | for y in np.arange(0, 1, 0.1): 128 | plt.plot(dfindex, [y] * len(dfindex), "--", lw=0.5, color="black", alpha=0.3) 129 | 130 | ax = plt.gca() 131 | ax.locator_params(tight=True, axis='x', nbins=2) 132 | ax.locator_params(axis='y', nbins=1) 133 | 134 | if 'poigauss' in file: 135 | plt.xlabel('Degradation strength') 136 | else: 137 | for n, label in enumerate(ax.xaxis.get_ticklabels()): 138 | label.set_visible(False) 139 | plt.ylim(0, 1) 140 | plt.xlim(0.0, 1.0) 141 | 142 | f.canvas.set_window_title(file) 143 | plt.tight_layout() 144 | plt.gcf().subplots_adjust(bottom=0.45, right=0.85) 145 | plt.savefig('{}{}'.format(file,'.png')) 146 | 147 | plt.show() -------------------------------------------------------------------------------- /generate_training_set.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for the implementation of 3 | "Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy" 4 | 5 | Copyright (c)
2020 Idiap Research Institute, https://www.idiap.ch/ 6 | Written by Adrian Shajkofci , 7 | All rights reserved. 8 | 9 | This file is part of Spatially-Variant CNN-based Point Spread Function Estimation. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. Redistributions of source code must retain the above copyright notice, 15 | this list of conditions and the following disclaimer. 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in the 18 | documentation and/or other materials provided with the distribution. 19 | 3. Neither the name of mosquitto nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 26 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 | POSSIBILITY OF SUCH DAMAGE. 
34 | ''' 35 | 36 | from toolbox import * 37 | from pandas.io.parsers import read_csv 38 | import glob 39 | import random 40 | import numpy as np 41 | from skimage import io 42 | import argparse 43 | import h5py 44 | import imageio 45 | 46 | 47 | def get_parser(): 48 | parser = argparse.ArgumentParser(description='Create dataset for regression of PSFs') 49 | parser.add_argument('--patch_size', dest='patch_size', type=int, default=128) 50 | parser.add_argument('--psf_size', dest='psf_size', type=int, default=127) 51 | parser.add_argument('--name', dest='name', type=str, default='images') 52 | parser.add_argument('--synthetic', dest='synthetic', type=int, default=0) 53 | parser.add_argument('--natural', dest='natural', type=int, default=1) 54 | parser.add_argument('--points', dest='points', type=int, default=0) 55 | parser.add_argument('--black', dest='black', type=int, default=5) 56 | parser.add_argument('--noise', dest='noise', type=int, default=1) 57 | parser.add_argument('--type', dest='type', type=str, default='1dgaussian') 58 | return parser 59 | 60 | 61 | file_list = glob.glob("../input_images/*") 62 | random.shuffle(file_list) 63 | num = int(round(len(file_list)*0.9)) 64 | training_file_list, test_file_list = file_list[:num], file_list[num:] 65 | 66 | 67 | def handle_list(output_dir,is_train=True): 68 | 69 | if model_type == '1dgaussian': 70 | header = 'isfake,fwmh\n' 71 | elif model_type == '2dgaussian': 72 | header = 'isfake,fwmhx,fwmhy\n' 73 | elif model_type == '1dzernike': 74 | header = 'isfake,focus\n' 75 | elif model_type == '2dzernike': 76 | header = 'isfake,focus,ast,ast_angle\n' 77 | else: 78 | print('Undefined model type') 79 | exit() 80 | output_dir = input_dir+output_dir 81 | create_dir(output_dir) 82 | list_file = open("{}/parameters.txt".format(output_dir), 'w') 83 | list_file.write(header) 84 | 85 | i = 0 86 | 87 | def handle_file(original_img, filenb, i, is_synth=False, is_black=False): 88 | 89 | for a in range(30): 90 | img =
original_img.copy() 91 | if model_type == '1dgaussian': 92 | img, params = do_convolution_gaussian(img, dimensions=1) 93 | str_params_valid = '{},{}\n'.format(0, params['fwmhx']) 94 | str_params_invalid = '{},{}\n'.format(1, 1000) 95 | elif model_type == '2dgaussian': 96 | img, params = do_convolution_gaussian(img, dimensions=2) 97 | str_params_valid = '{},{},{}\n'.format(0, params['fwmhx'], params['fwmhy']) 98 | str_params_invalid = '{},{},{}\n'.format(1, 1000, 1000) 99 | elif model_type == '1dzernike': 100 | img, params = do_convolution_zernike(img, dimensions=1) 101 | str_params_valid = '{},{}\n'.format(0, params.focus) 102 | str_params_invalid = '{},{}\n'.format(1, 1000) 103 | elif model_type == '2dzernike': 104 | img, params = do_convolution_zernike(img, dimensions=2) 105 | str_params_valid = '{},{},{},{}\n'.format(0, params.focus, params.ast, params.ast_angle) 106 | str_params_invalid = '{},{},{},{}\n'.format(1, 1000, 1000, 1000) 107 | else: 108 | print('Undefined model type') 109 | exit() 110 | 111 | img = img.astype(float) 112 | 113 | if (np.min(img) < 0.0): 114 | img -= np.min(img) 115 | 116 | size = args.patch_size 117 | step = 48 118 | x = size 119 | y = size 120 | 121 | if img.max() > 0: 122 | img -= img.mean() 123 | img /= (img.std()) 124 | img += 1 125 | if img.max() > 1: 126 | img /= img.max() 127 | 128 | while x <= img.shape[0]: 129 | while y <= img.shape[1]: 130 | im = img[x-size:x,y-size:y].copy() # copy, so the in-place edits below do not alter overlapping patches 131 | 132 | sum_all_pixels = np.sum(im) 133 | nb_pix = size**2 134 | variance = np.var(im)/nb_pix/255.0 135 | 136 | ratio = sum_all_pixels/nb_pix/255.0 137 | 138 | if args.noise > 0.0 and is_train: 139 | im = noisy(im, 'gauss', rand_float(0.00001, 0.002)) 140 | im = noisy(im, 'poisson', rand_float(0.00001, 0.001)) 141 | if is_train: 142 | coeff = rand_float(0.4, 1.0, 1) 143 | im *= coeff 144 | if im.min() < 0: 145 | im -= im.min() 146 | if im.max() > 1: 147 | im /= im.max() 148 | im = to_8_bit(im) 149 | 150 | print ('{} Ratio : {} | variance :
{}'.format(i, ratio,variance)) 151 | filename = "{}/{:05d}/{:09d}.png".format(output_dir,filenb, i) 152 | saved = False 153 | if ((ratio > 0.0003 and variance > 1e-09) or is_synth) and not is_black: 154 | list_file.write(str_params_valid) 155 | print('Good min {} max {}'.format(im.min(), im.max())) 156 | saved = True 157 | 158 | else: 159 | if ratio < 0.00015 or variance < 4e-11 or is_black: 160 | list_file.write(str_params_invalid) 161 | print('As fake') 162 | saved = True 163 | else: 164 | pass 165 | print('Rejected') 166 | if saved: 167 | i += 1 168 | imageio.imsave(filename, im) 169 | y += step 170 | y = size 171 | x += step 172 | return i 173 | 174 | filenb = 0 175 | 176 | if args.natural > 0: 177 | if is_train: 178 | ff = training_file_list 179 | else: 180 | ff = test_file_list 181 | for file in ff: 182 | print('Loading file {}'.format(file)) 183 | original_img = io.imread(file) 184 | original_img = scale(original_img) 185 | create_dir(output_dir + '/{:05d}'.format(filenb)) 186 | i = handle_file(original_img, filenb, i) 187 | filenb += 1 188 | 189 | if args.synthetic > 0: 190 | 191 | if not is_train: 192 | numsynth = int(args.synthetic//10+1) 193 | else: 194 | numsynth = args.synthetic 195 | 196 | for a in range(0,numsynth): 197 | synthetic_image = random_generate((600,600), number_coef=35, size_coeff=0.25, size_variance_coeff = 1.5, noise='poisson', noise_coeff=1.0) 198 | create_dir(output_dir + '/{:05d}'.format(filenb)) 199 | i = handle_file(synthetic_image, filenb, i, True) 200 | filenb +=1 201 | 202 | if args.points > 0: 203 | 204 | if not is_train: 205 | numsynth = int(args.points//10+1) 206 | else: 207 | numsynth = args.points 208 | 209 | for a in range(0,numsynth): 210 | synthetic_image = random_generate((600, 600), number_coef=100, size_coeff=0, size_variance_coeff=0.7, 211 | noise='poisson', noise_coeff=1.0) 212 | create_dir(output_dir + '/{:05d}'.format(filenb)) 213 | i = handle_file(synthetic_image, filenb, i, True, is_black=False) 214 | filenb +=1 
215 | 216 | if args.black > 0: 217 | if not is_train: 218 | numsynth = int(args.black//10+1) 219 | else: 220 | numsynth = args.black 221 | 222 | for a in range(0,numsynth): 223 | synthetic_image = np.zeros((800,800)) 224 | create_dir(output_dir + '/{:05d}'.format(filenb)) 225 | i = handle_file(synthetic_image, filenb, i, is_black=True) 226 | filenb +=1 227 | 228 | list_file.close() 229 | 230 | 231 | def do_convolution_gaussian(img, dimensions=1): 232 | small_rand = rand_float(0, 2, 1) 233 | large_rand = rand_float(5, 20, 1) 234 | iso_rand = rand_float(0,20,1) 235 | choice = rand_int(0,3) 236 | if dimensions == 1: 237 | choice = 2 238 | 239 | if choice == 0: 240 | params = {'fwmhx': small_rand[0], 'fwmhy': large_rand[0]} 241 | elif choice == 1: 242 | params = {'fwmhx': large_rand[0], 'fwmhy': small_rand[0]} 243 | elif choice == 2: 244 | params = {'fwmhx': iso_rand[0], 'fwmhy': iso_rand[0]} 245 | else: 246 | return 247 | 248 | psf = gaussian_kernel(args.psf_size, params['fwmhx'], params['fwmhy']) 249 | return [convolve(img, psf, padding='reflect'), params] 250 | 251 | 252 | 253 | def do_convolution_zernike(img, dimensions=1): 254 | small_rand = rand_float(0, 3, 1) 255 | 256 | if(random.randint(0,100) > 50 or (dimensions == 2 and small_rand[0] > 0.5)): 257 | large_rand = rand_float(0, 1.0) 258 | else: 259 | large_rand = rand_float(0, 4.0) 260 | rand_angle = random.choice([0,math.pi/2.0]) 261 | params = Params() 262 | params.magnification = 20 263 | params.n = 1.33 264 | params.na = 0.45 265 | params.wavelength = 500 266 | params.pixelsize = 45 267 | params.tubelength = 200 268 | params.size = args.psf_size 269 | params.focus = large_rand[0] 270 | if dimensions == 1: 271 | params.ast = 0.0 272 | else: 273 | params.ast = small_rand[0] 274 | params.ast_angle = rand_angle 275 | params.sph = 0.0 276 | psf, wavefront, pupil_diameter = get_psf(params) 277 | return [convolve(img, psf), params] 278 | 279 | 280 | def create_hdf5(output_dir): 281 | output_dir = 
input_dir+output_dir 282 | all_files_list = sorted(glob.glob(output_dir+"/*/*.png"),key=lambda name: int(name[-13:-4])) 283 | num_files = len(all_files_list) 284 | print('{} files found'.format(num_files)) 285 | try: 286 | os.remove(output_dir+'data.h5') 287 | except OSError: 288 | pass 289 | 290 | with h5py.File(output_dir+'data.h5', 'w') as f: 291 | 292 | _file_csv = read_csv(os.path.expanduser(output_dir + "parameters.txt")) 293 | _file = _file_csv.values.astype(float) 294 | _header = list(_file_csv.columns) 295 | 296 | dt = h5py.special_dtype(vlen=np.dtype('uint8')) 297 | data = f.create_dataset('data', (num_files, ), dtype=dt) 298 | labels = f.create_dataset('labels', (num_files, len(_header))) 299 | 300 | i = 0 301 | a = 0 302 | num_chunk = 1 303 | max_files_chunk = 1000 304 | image_list = [] 305 | labels_list = [] 306 | for filename in all_files_list: 307 | with open(filename, 'rb') as fh: 308 | image = fh.read() 309 | 310 | index = int(filename[-13:-4]) 311 | labels_list.append(_file[index]) 312 | image_list.append(np.frombuffer(image, dtype='uint8')) 313 | 314 | a += 1 315 | i += 1 316 | 317 | # Flush a full chunk, or the last partial chunk, so that the final file is not dropped 318 | if a == max_files_chunk or i == num_files: 319 | f['data'][i-a:i] = image_list 320 | f['labels'][i-a:i] = labels_list 321 | image_list = [] 322 | labels_list = [] 323 | a = 0 324 | print("Chunk {}/{} written...".format(num_chunk, num_files//max_files_chunk +1)) 325 | num_chunk += 1 326 | print("Finished!") 327 | 328 | 329 | if __name__ == '__main__': 330 | parser = get_parser() 331 | args = parser.parse_args() 332 | model_type = args.type 333 | input_dir = 'data/' 334 | train_dir = 'psf_{}_n_{}_s_{}_p_{}_b_{}_{}_train/'.format(args.patch_size, args.natural, args.synthetic, args.points, args.black, model_type) 335 | print('Directory:{}'.format(train_dir)) 336 | handle_list(train_dir, True) 337 | create_hdf5(train_dir) 338 | test_dir = 'psf_{}_n_{}_s_{}_p_{}_b_{}_{}_test/'.format(args.patch_size, args.natural, args.synthetic, args.points, args.black,model_type) 339
| print('Directory:{}'.format(test_dir)) 340 | handle_list(test_dir, False) 341 | create_hdf5(test_dir) 342 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _pytorch_select=0.2=gpu_0 6 | absl-py=0.8.1=py36_0 7 | arxiv-collector=0.3.4=py_0 8 | asn1crypto=1.2.0=py36_0 9 | asteval=0.9.16=pyh5ca1d4c_0 10 | attrs=19.3.0=py_0 11 | backcall=0.1.0=py36_0 12 | beautifulsoup4=4.8.1=py36_0 13 | blas=1.0=mkl 14 | bottleneck=1.3.1=py36hdd07704_0 15 | c-ares=1.15.0=h7b6447c_1001 16 | ca-certificates=2020.1.1=0 17 | certifi=2019.11.28=py36_1 18 | cffi=1.13.2=py36h8022711_0 19 | chardet=3.0.4=py36_1003 20 | cloudpickle=1.2.2=py_0 21 | cryptography=2.8=py36h1ba5d50_0 22 | cudatoolkit=9.2=0 23 | cudnn=7.6.4=cuda9.2_0 24 | cycler=0.10.0=py_2 25 | cymem=2.0.2=py36he1b5a44_0 26 | cython-blis=0.2.4=py36h516909a_1 27 | cytoolz=0.10.1=py36h7b6447c_0 28 | dask-core=2.9.0=py_0 29 | dataclasses=0.6=py36_0 30 | dbus=1.13.12=h746ee38_0 31 | decorator=4.4.1=py_0 32 | docopt=0.6.2=pypi_0 33 | expat=2.2.6=he6710b0_0 34 | fastai=1.0.59=1 35 | fastprogress=0.1.22=py_0 36 | fontconfig=2.13.0=h9420a91_0 37 | freetype=2.9.1=h8a8886c_1 38 | glib=2.63.1=h5a9c865_0 39 | gpytorch=1.0.1=0 40 | grpcio=1.27.2=py36hf8bcb03_0 41 | gst-plugins-base=1.14.0=hbbd80ab_1 42 | gstreamer=1.14.0=hb453b48_1 43 | h5py=2.9.0=py36h7918eee_0 44 | hdf5=1.10.4=hb1b8bf9_0 45 | icu=58.2=h9c2bf20_1 46 | idna=2.8=py36_0 47 | imagecodecs-lite=2019.12.3=pypi_0 48 | imageio=2.6.1=py36_0 49 | importlib_metadata=1.2.0=py36_0 50 | intel-openmp=2019.4=243 51 | ipython=7.12.0=py36h5ca1d4c_0 52 | ipython_genutils=0.2.0=py36_0 53 | jedi=0.16.0=py36_0 54 | joblib=0.14.1=py_0 55 | jpeg=9b=h024ee3a_2 56 | jsonschema=3.2.0=py36_0 57 |
kiwisolver=1.1.0=py36hc9558a2_0 58 | libedit=3.1.20181209=hc058e9b_0 59 | libffi=3.2.1=hd88cf55_4 60 | libgcc-ng=9.1.0=hdf63c60_0 61 | libgfortran-ng=7.3.0=hdf63c60_0 62 | libpng=1.6.37=hbc83047_0 63 | libprotobuf=3.10.1=hd408876_0 64 | libstdcxx-ng=9.1.0=hdf63c60_0 65 | libtiff=4.1.0=h2733197_0 66 | libuuid=1.0.3=h1bed415_2 67 | libxcb=1.13=h1bed415_1 68 | libxml2=2.9.9=hea5a465_1 69 | line_profiler=2.1.2=py36h14c3975_0 70 | llvmlite=0.31.0=py36hd408876_0 71 | lmfit=1.0.0=py_0 72 | markdown=3.1.1=py36_0 73 | matplotlib=3.1.3=py36_0 74 | matplotlib-base=3.1.3=py36hef1b27d_0 75 | mkl=2019.4=243 76 | mkl-service=2.3.0=py36h516909a_0 77 | mkl_fft=1.0.15=py36h516909a_1 78 | mkl_random=1.1.0=py36hb3f55d8_0 79 | more-itertools=7.2.0=py36_0 80 | murmurhash=1.0.2=py36he6710b0_0 81 | ncurses=6.1=he6710b0_1 82 | networkx=2.4=py_0 83 | ninja=1.9.0=hc9558a2_1 84 | numba=0.48.0=py36h0573a6f_0 85 | numexpr=2.7.0=py36h9e4a6bb_0 86 | numpy=1.17.4=py36hc1035e2_0 87 | numpy-base=1.17.4=py36hde5b4d6_0 88 | nvidia-ml-py3=7.352.0=py_0 89 | olefile=0.46=py_0 90 | openssl=1.1.1f=h7b6447c_0 91 | packaging=19.2=py_0 92 | pandas=0.25.3=py36he6710b0_0 93 | parso=0.6.1=py_0 94 | pathvalidate=2.2.2=pypi_0 95 | pcre=8.43=he6710b0_0 96 | pexpect=4.8.0=py36_0 97 | pickleshare=0.7.5=py36_0 98 | pillow=6.2.1=py36h34e0f95_0 99 | pip=19.3.1=py36_0 100 | pipreqs=0.4.10=pypi_0 101 | plac=0.9.6=py36_0 102 | preshed=2.0.1=py36he6710b0_0 103 | prompt_toolkit=3.0.3=py_0 104 | protobuf=3.10.1=py36he6710b0_0 105 | ptyprocess=0.6.0=py36_0 106 | pycparser=2.19=py_0 107 | pydensecrf=1.0rc3=py36hf8a1672_1 108 | pygments=2.5.2=py_0 109 | pyopenssl=19.1.0=py36_0 110 | pyparsing=2.4.5=py_0 111 | pypng=0.0.20=pypi_0 112 | pyqt=5.9.2=py36hcca6a23_4 113 | pyqtgraph=0.10.0=py36h28b3542_3 114 | pyrsistent=0.15.6=py36h7b6447c_0 115 | pysocks=1.7.1=py36_0 116 | python=3.6.9=h265db76_0 117 | python-dateutil=2.8.1=py_0 118 | python_abi=3.6=1_cp36m 119 | pytoolbox=1.0=dev_0 120 | pytorch=1.3.1=cuda92py36hb0ba70e_0 121 | 
pytz=2019.3=py_0
pywavelets=1.1.1=py36h7b6447c_0
pyyaml=5.1.2=py36h7b6447c_0
qt=5.9.7=h5867ecd_1
ranger=0.0.1=dev_0
readline=7.0=h7b6447c_5
requests=2.22.0=py36_1
scikit-image=0.15.0=py36he6710b0_0
scikit-learn=0.22.1=py36hd81dba3_0
scipy=1.3.1=py36h7c811a0_0
setuptools=42.0.2=py36_0
sip=4.19.8=py36hf484d3e_1000
six=1.13.0=py36_0
soupsieve=1.9.5=py36_0
spacy=2.1.8=py36hc9558a2_0
sqlite=3.30.1=h7b6447c_0
srsly=0.1.0=py36he1b5a44_0
tbb=2020.0=hfd86e86_0
tensorboard=2.0.0=pyhb38c66f_1
tensorboardx=1.9=py_0
thinc=7.0.8=py36hc9558a2_0
tifffile=2019.7.26.2=pypi_0
tk=8.6.8=hbc83047_0
toolz=0.10.0=py_0
torchvision=0.4.2=cuda92py36h1667eeb_0
tornado=6.0.3=py36h516909a_0
tqdm=4.40.0=py_0
traitlets=4.3.3=py36_0
uncertainties=3.1.2=py36_0
urllib3=1.25.7=py36_0
wasabi=0.2.2=py_0
wcwidth=0.1.8=py_0
werkzeug=0.16.0=py_0
wheel=0.33.6=py36_0
xz=5.2.4=h14c3975_4
yaml=0.1.7=had09818_2
yarg=0.1.9=pypi_0
zipp=0.6.0=py_0
zlib=1.2.11=h7b6447c_3
zstd=1.3.7=h0b5b093_0

--------------------------------------------------------------------------------
/toolbox.py:
--------------------------------------------------------------------------------
'''
Code for the implementation of
"Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy"

Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/
Written by Adrian Shajkofci ,
All rights reserved.

This file is part of Spatially-Variant CNN-based Point Spread Function Estimation.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''

import numpy as np
import numpy
from scipy import signal
import scipy
import math
import copy
import sys
import os
from skimage.transform import rotate
import bz2
import torch
import tifffile as tiff
from imageio import imread as io_imread
import pickle
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import pyplot as plt

def multipage(filename, figs=None, dpi=200):
    '''
    Print all the plots in a PDF file
    '''
    pp = PdfPages(filename)
    if figs is None:
        figs = [plt.figure(n) for n in plt.get_fignums()]
    for fig in figs:
        fig.savefig(pp, format='pdf', dpi=dpi)
    pp.close()

def create_dir(dirName):
    if not os.path.exists(dirName):
        os.makedirs(dirName)

def write_tiff_stack(file_name, array, compression=None, rgb=False):
    """ Export an array as a (multi-page) TIFF file readable by ImageJ.

    usage: write_tiff_stack(out_file_name, array)

    The array is written with tifffile (imported above as `tiff`);
    `compression` and `rgb` are currently ignored.
    """
    # For some reason, setting imagej=True gets an error from ImageJ with
    # 3D stacks (t,X,Y), hence it is not passed here.
    tiff.imwrite(file_name, array)

    return None

def stretch_contrast(image, min_val=0.0, max_val=1.0):
    """ Rescales the greylevels in an image to [min_val, max_val] """
    curr_min = np.min(image)
    curr_max = np.max(image)
    image_ret = image - curr_min  # scale starts at zero
    # Guard against a flat image (curr_max == curr_min), which would divide by zero
    value_range = max(curr_max - curr_min, np.finfo(np.float64).eps)
    ratio = (max_val - min_val) / value_range
    image_ret = image_ret * ratio + min_val

    return image_ret


def scale3d(v):
    '''
    Rescale an N-D array so that its maximum absolute value is 1
    :param v: input array
    :return: rescaled array with the same shape
    '''
    shape = v.shape
    norm = np.linalg.norm(v.flatten(), ord=1)
    if norm == 0:
        norm = np.finfo(v.dtype).eps
    out = v.flatten() / norm
    return np.reshape(out * (1 / np.max(np.abs(out))), newshape=shape)


def pickle_save(filename, obj, compressed=True):
    """
    Save object to file using pickle (bz2-compressed by default)
    """

    try:
        if compressed:
            f = bz2.BZ2File(filename, 'wb')
        else:
            f = open(filename, 'wb')
    except IOError as details:
        sys.stderr.write('File {} cannot be written\n'.format(filename))
        # sys.stderr.write expects a string, not the exception object
        sys.stderr.write('{}\n'.format(details))
        return

    pickle.dump(obj, f, protocol=2)
    f.close()


def plot_images(image_list, fig=None, title_str='', figure_title='',
                axes_titles=None, sub_plot_shape=None,
                is_blob_format=False, channel_swap=None, transpose_order=None,
                show_type=None):
    '''
    Plot list of n images
    image_list: list of images
    is_blob_format: Caffe stores images as ch x h x w
                    True - convert the images into h x w x ch format
    transpose_order : If certain transpose order of channels is to be used
                      overrides is_blob_format
    show_type : imshow or matshow (by default imshow)
    '''
    image_list = copy.deepcopy(image_list)
    if transpose_order is not None:
        for i, im in enumerate(image_list):
            image_list[i] = im.transpose(transpose_order)
    if transpose_order is None and is_blob_format:
        for i, im in enumerate(image_list):
            image_list[i] = im.transpose((1, 2, 0))
    if channel_swap is not None:
        for i, im in enumerate(image_list):
            image_list[i] = im[:, :, channel_swap]
    plt.ion()
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)
    plt.clf()
    if sub_plot_shape is None:
        # add_subplot needs integer grid dimensions
        N = int(np.ceil(np.sqrt(len(image_list))))
        sub_plot_shape = (N, N)
    # gs = gridspec.GridSpec(N, N)
    ax = []
    for i in range(len(image_list)):
        shp = sub_plot_shape + (i + 1,)
        aa = fig.add_subplot(*shp)
        aa.autoscale(False)
        if (len(figure_title) == len(image_list)):
            aa.set_title(figure_title[i])
        ax.append(aa)
        # ax.append(plt.subplot(gs[i]))

    if show_type is None:
        show_type = ['imshow'] * len(image_list)
    else:
        assert len(show_type) == len(image_list)

    for i, im in enumerate(image_list):
        ax[i].set_ylim(im.shape[0], 0)
        ax[i].set_xlim(0, im.shape[1])
        if show_type[i] == 'imshow':
            ax[i].imshow(im)
        elif show_type[i] == 'matshow':
            res = ax[i].matshow(im)
            plt.colorbar(res, ax=ax[i])
        ax[i].axis('off')
        if axes_titles is not None:
            ax[i].set_title(axes_titles[i])
    if len(figure_title) == 1:
        fig.suptitle(figure_title[0])
    #plt.show(block=True)

    return ax


def random_crop(img, N, onlygood=False, randx=None, randy=None):
    """ Randomly crop a portion of image of size NxN
    onlygood -> selects a part of the image with a good mean and variance
                (mean > 0.12 and variance > 0.01, up to 100 attempts)
    randx, randy -> optional fixed top-left corner of the crop
    """
    xmax = img.shape[0] - N
    ymax = img.shape[1] - N
    if xmax == 0 and ymax == 0:
        return img
    if xmax < 0 or ymax < 0:
        print('ERROR: the size of the input image is smaller than the crop size.')
        return img
    fixed_corner = randx is not None and randy is not None
    if not onlygood:
        if not fixed_corner:
            # randint excludes its upper bound, hence the +1
            randx = np.random.randint(0, xmax + 1)
            randy = np.random.randint(0, ymax + 1)
        return img[randx:randx + N, randy:randy + N]
    else:
        for _ in range(100):
            if not fixed_corner:
                # draw a new corner on every attempt
                randx = np.random.randint(0, xmax + 1)
                randy = np.random.randint(0, ymax + 1)
            im = img[randx:randx + N, randy:randy + N]

            sum_all_pixels = np.sum(im)
            variance = np.var(im)
            nb_pix = N ** 2
            ratio = sum_all_pixels / nb_pix
            if ratio > 0.12 and variance > 0.01:
                return im
        print('Good ratio/variance not found.')
        return np.zeros((N, N))


def noisy(image, noise_type, param = 0.01):
    '''
    Add noise or an intensity/geometric perturbation to a torch image tensor.

    Parameters
    ----------
    image : torch.Tensor
        Input image data.
    noise_type : str
        One of the following strings, selecting the type of perturbation:

        'gauss'             Gaussian-distributed additive noise (variance param/150).
        'poisson'           Poisson (shot) noise generated from the data; a larger
                            param gives stronger noise.
        'poigauss'          Poisson noise followed by Gaussian noise.
        'luminosity'        Global intensity scaling by (1 - param).
        'rotation'          Rotation by 180*param degrees.
        'axial_luminosity'  Blend of the image with a copy multiplied by a centred
                            Gaussian illumination mask, with weight param.
    param : float
        Strength of the perturbation.
    '''
    noise_types = ['gauss', 'poisson', 'luminosity', 'rotation', 'axial_luminosity', 'poigauss']
    assert noise_type in noise_types, "ERROR: noise type {} does not exist".format(noise_type)

    if noise_type == 'gauss':
        var = param/150.0
        sigma = var ** 0.5
        mean = torch.zeros(image.size())
        sig = sigma*torch.ones(image.size())
        gauss = torch.normal(mean, sig)
        noisy_image = image + gauss

    elif noise_type == "poisson":

        vals = len(np.unique(image.data.cpu().numpy()))*1.0/(420*param)
        vals = 2 ** np.ceil(np.log2(vals))
        if float(vals) == 0.0:
            return image
        noisy_image = torch.Tensor(np.random.poisson(image * vals) / float(vals))
    elif noise_type == "poigauss":
        return noisy(noisy(image, 'poisson', param), 'gauss', param)
    elif noise_type == "luminosity":
        noisy_image = image * (1.0-param)
    elif noise_type == "rotation":
        noisy_image = rotate(image, 180.00*param, resize=False, order=2)
        noisy_image = torch.Tensor(noisy_image)
    elif noise_type == "axial_luminosity":
        size_im = np.max(list(image.size()))
        mask = scale(gaussian_kernel(size_im, fwhmx = size_im/2, fwhmy = size_im/2, verbose=False))
        mask = torch.FloatTensor(mask)
        noisy_image = (1-param) * image + param * mask * image

    return noisy_image

def rand_int(a, b, size=1):
    '''
    Return random integers in the half-open interval [a, b).
    '''
    return np.floor((b - a) * np.random.random_sample(size) + a).astype(dtype=np.int16)


def normalize(v):
    '''
    Normalize a 2D matrix to a sum of 1
    :param v: input matrix
    :return: normalized matrix
    '''
    norm = v.sum()
    if norm == 0:
        norm = np.finfo(v.dtype).eps
    return v / norm


def scale(v):
    '''
    Rescale a 2D matrix so that its maximum absolute value is 1
    :param v: input matrix
    :return: rescaled matrix
    '''
    norm = np.linalg.norm(v, ord=1)
    if norm == 0:
        norm = np.finfo(v.dtype).eps
    out = v / norm
    out = out * (1 / np.max(np.abs(out)))
    if np.all(np.isfinite(out)):
        return out
    else:
        print('Error, image is not finite (dividing by infinity on norm).')
        return np.zeros(v.shape)


def to_8_bit(v):
    '''
    Convert a 32 bit float [0 1] image to a [0 255] 8 bit integer image.
    '''
    if isdtype(v, np.uint8):
        print('Warning: input already 8 bit.')
        return v

    return np.asarray(np.round(v * 255.0), dtype=np.uint8)


def to_16_bit(v):
    '''
    Convert a 32 bit float [0 1] image to a [0 65535] 16 bit integer image.
    '''
    if isdtype(v, np.uint16):
        print('Warning: input already 16 bit.')
        return v

    # 65535 (not 65536): 1.0 must map to the largest value a uint16 can hold
    return np.asarray(np.round(v * 65535.0), dtype=np.uint16)


def to_32_bit(v):
    '''
    Convert a [0 255] 8 bit integer image to a 32 bit float [0 1] image.
    '''
    if isdtype(v, np.float32):
        print('Warning: input already 32 bit.')
        return v

    return np.asarray(v, dtype=np.float32) / 255.0


def unpad(img, npad):
    '''
    Revert the np.pad command
    '''
    if npad == 0:
        # img[0:-0] would return an empty array
        return img
    return img[npad:-npad, npad:-npad]


def to_radial(x, y):
    '''
    Squared radial distance x^2 + y^2
    '''
    return x ** 2 + y ** 2


def to_radian(x):
    return float(x) * np.pi / 180.
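The conversion helpers above (`to_8_bit`, `to_16_bit`, `to_32_bit`) map between float images in [0, 1] and unsigned integer images. The snippet below is a standalone sketch that assumes only NumPy. It uses hypothetical helper names, `to_uint` and `from_uint`, which are not part of toolbox.py.

The sketch scales by `np.iinfo(dtype).max`, so 1.0 lands on the largest representable value (255 or 65535). Scaling by 65536 would overflow, because a 16-bit integer cannot hold that value. The snippet also checks that a round trip is accurate to within half a quantization step.

```python
import numpy as np


def to_uint(v, dtype=np.uint16):
    """Map a float image in [0, 1] onto the full range of an unsigned integer type.

    Hypothetical helper (not part of toolbox.py): values are clipped to [0, 1]
    and multiplied by np.iinfo(dtype).max, i.e. 255 for uint8 and 65535 for
    uint16, so that 1.0 maps to the largest representable value.
    """
    v = np.clip(np.asarray(v, dtype=np.float64), 0.0, 1.0)
    return np.round(v * np.iinfo(dtype).max).astype(dtype)


def from_uint(v):
    """Hypothetical inverse of to_uint: map an unsigned integer image back to [0, 1]."""
    v = np.asarray(v)
    return v.astype(np.float64) / np.iinfo(v.dtype).max


img = np.array([[0.0, 0.25], [0.6, 1.0]])
u16 = to_uint(img)
print(u16.tolist())

# One quantization step of a 16-bit image; rounding errs by at most half of it.
step = 1.0 / np.iinfo(np.uint16).max
print(np.max(np.abs(from_uint(u16) - img)) <= step / 2)
```

Deriving the factor from `np.iinfo` keeps the forward and inverse conversions consistent for any unsigned integer type. Clipping first means that out-of-range floats saturate at 0 or the maximum instead of overflowing the integer type.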


def isdtype(a, dt=np.float64):
    '''
    Test whether array `a` has dtype `dt`
    '''
    try:
        return a.dtype.num == np.dtype(dt).num
    except AttributeError:
        return False


def center_crop(img, percentage):
    '''
    Extract center crop
    :param img: input (square) image
    :param percentage: percentage of area to keep
    '''
    assert (img.shape[0] == img.shape[1])
    a = img.shape[0]
    offset = int(round(0.5 * a * (1 - np.sqrt(percentage / 100.0))))
    if offset == 0:
        # img[0:-0] would return an empty array
        return img
    return img[offset:-offset, offset:-offset]


def center_crop_pixel(img, size):
    '''
    Extract a center crop of size x size pixels
    :param img: input (square) image
    :param size: side length of the patch in pixels
    '''
    assert (img.shape[0] == img.shape[1])
    if img.shape[0] == size:
        return img
    assert (img.shape[0] > size)
    # With an odd margin, the extra pixel is dropped on the bottom/right side
    npad = (img.shape[0] - size) // 2
    return img[npad:npad + size, npad:npad + size]


def gaussian_kernel(size, fwhmx = 3, fwhmy = 3, center=None, verbose=True):
    """ Make a square gaussian kernel.
    size is the length of a side of the square
    fwhmx and fwhmy are the full-width-half-maximum along x and y, which
    can be thought of as an effective radius.
    """

    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]

    if center is None:
        if size % 2 == 0 and verbose:
            print("WARNING gaussian_kernel : you have chosen an even kernel size and therefore the kernel is not centered.")
        x0 = y0 = size // 2
    else:
        x0 = center[0]
        y0 = center[1]
    if fwhmx == 0 and fwhmy == 0:
        # Degenerate kernel: a single pixel (Dirac) at the centre; rows are y, columns are x
        ret = np.zeros((size, size))
        ret[y0, x0] = 1.0
        return ret
    else:
        return normalize(np.exp(-4 * np.log(2) * ( ((x - x0) ** 2) / fwhmx**2 + ((y - y0) ** 2) / fwhmy**2 )))


def pickle_load(filename, compressed=True):
    """
    Load from filename using pickle (bz2-compressed by default)
    """

    try:
        if compressed:
            f = bz2.BZ2File(filename, 'rb')
        else:
            f = open(filename, 'rb')
    except IOError as details:
        sys.stderr.write('File {} cannot be read\n'.format(filename))
        # sys.stderr.write expects a string, not the exception object
        sys.stderr.write('{}\n'.format(details))
        return

    obj = pickle.load(f)
    f.close()
    return obj


def convolve(input, psf, padding = 'constant'):
    '''
    Convolve an image with a psf using FFT
    :param padding: np.pad mode, e.g. 'edge' (replicate), 'reflect' or 'constant'
    :return: output image
    '''
    psf = normalize(psf)
    npad = np.max(psf.shape)

    if len(input.shape) != len(psf.shape):
        # Multi-channel input: only the first channel is convolved
        input = input[:,:,0]

    input = np.pad(input, pad_width=npad, mode=padding)

    try:
        out = scipy.signal.fftconvolve(input, psf, mode='same')
    except Exception as e:
        print("Exception: FFT cannot be made on image ! ({})".format(e))
        out = np.zeros(input.shape)

    out = unpad(out, npad)
    return out


def rand_float(a, b, size=1):
    '''
    Return random floats in the half-open interval [a, b).
    '''
    return (b - a) * numpy.random.random_sample(size) + a


def rand_int(a, b, size=1):
    '''
    Return random integers in the half-open interval [a, b).
    '''
    return numpy.floor((b - a) * numpy.random.random_sample(size) + a).astype(dtype=numpy.int16)


def get_wavefront(x, y, params):
    '''
    Pupil-plane wavefront exp(2*i*pi*W), where the aberration W combines the
    spherical, defocus, astigmatism, coma and tilt terms of params (angles in radians).
    x and y are rescaled by 2/params.size.
    '''
    x = 2.*x/params.size
    y = 2.*y/params.size
    r2 = to_radial(x, y)

    aberration = params.sph * r2**2 + params.focus * r2 + params.ast * (x*np.cos(params.ast_angle) + y*np.sin(params.ast_angle))**2 + \
        params.coma * ( (x*r2)*np.cos(params.coma_angle) + (y*r2)*np.sin(params.coma_angle)) + \
        params.tilt*(x*np.cos(params.tilt_angle) + y*np.sin(params.tilt_angle))
    wavefront = np.exp(2*1j*np.pi*aberration)
    return wavefront


class Params:
    '''
    Parameters of the PSF model: aberration coefficients and optical setup
    '''
    def __init__(self):
        self.tilt = 0.
        self.tilt_angle = to_radian(0.)
        self.focus = 0.
        self.coma = 0.
        self.coma_angle = to_radian(0.)
        self.ast = 0.
        self.ast_angle = to_radian(0.)
        self.sph = 0.
        self.size = 127. # px
        self.wavelength=570. # nm
        self.tubelength=200. # mm
        self.na = 0.8 # numerical aperture
        self.n = 1.1 # refractive index
        self.magnification = 20.
        self.pixelsize=10 # um


def get_psf(params, centered = True):
    '''
    Compute the PSF as the squared magnitude of the Fourier transform of the
    aperture-limited, zero-padded pupil wavefront.
    :return: psf (float32, maximum 1), wavefront, pupil diameter (mm)
    '''
    # np.linspace needs an integer number of samples
    datapoints = int(params.size)
    padding = int(np.ceil(datapoints/2))
    totalpoints = datapoints + 2*padding
    center_point = int(np.floor(totalpoints/2))

    wavelength = params.wavelength * float(1e-9) # wavelength in m
    pupil_diameter = 2.0 * params.tubelength * params.na / (params.magnification * params.n)
    D = pupil_diameter*1e-3 # diameter in m
    d = 1.0*1e-2 # distance btw pupil plane and object
    PRw = D / (2 * wavelength * d) # unit = 1/m
    NT = params.size
    x = np.linspace(-NT/2, NT/2, datapoints)
    y = np.linspace(-NT/2, NT/2, datapoints)
    xx, yy = np.meshgrid(x, y)
    sums = np.power(xx,2) + np.power(yy,2)
    wavefront = get_wavefront(xx, yy, params)
    pixel_limit = PRw*params.size*params.pixelsize*1e-6

    # Zero the wavefront outside the aperture (sums is the squared radius)
    wavefront[sums > pixel_limit] = 0
    wavefront_padded = np.pad(wavefront, ((padding,padding),(padding,padding)), mode='constant',constant_values=(0))

    psf = np.power(np.abs(np.fft.fft2(wavefront_padded, norm='ortho')),2)
    psf = np.roll(psf, center_point, axis = (0,1))

    normalisation = np.power(np.sum(np.abs(wavefront)) / float(totalpoints),2)
    psf = unpad(psf, padding) / normalisation
    psf = scale(np.fliplr(psf)).astype(np.float32)

    return psf, wavefront, pupil_diameter

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
'''
Code for the implementation of
"Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy"

Copyright (c) 2020 Idiap Research Institute, https://www.idiap.ch/
Written by Adrian Shajkofci ,
All rights reserved.

This file is part of Spatially-Variant CNN-based Point Spread Function Estimation.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''

import torch.nn as nn
from pandas.io.parsers import read_csv
from torch import FloatTensor as Tensor
import glob
import h5py  # used by DatasetFromHdf5
import torch.utils as utils
from torch.autograd import Variable
import io as _io
import torchvision.transforms as transforms
import torchvision
from torchvision.datasets import folder
from skimage import io
import os
from toolbox import *

def grayloader(path):
    '''
    Image loader: read an image (path or file-like object) as a grayscale array
    '''
    img = io.imread(path, as_gray=True)
    return img


class Noise(object):
    '''
    Adds noise to a pytorch image.
    The input is expected in the 16-bit range [0, 65535]; the noise is applied to
    the image rescaled to [0, 1], and the result is rescaled back to [0, 65535].
    '''
    def __init__(self, probability, noise_type):
        # `probability` is the noise strength, passed as `param` to toolbox.noisy
        self.probability = probability
        self.noise_type = noise_type
    def __call__(self, img):
        img = Tensor(img)
        img = img[None, :,:]
        if self.noise_type is None:
            return img
        if self.probability > 0.0:
            img /= 65535.0
            img = noisy(img, self.noise_type, self.probability)
            if (img.max() > 1.0):
                img /= img.max()
            if (img.min() < 0):
                img -= img.min()
            if (img.max() > 1.0):
                img /= img.max()
            img *= 65535.0
        return img


class To_224(object):
    '''
    Upscale a 1-channel image to 3x224x224 (channel repeated 3 times, bilinear resize)
    '''
    def __init__(self):
        pass
    def __call__(self, img):
        img = img.repeat(3, 1, 1)
        img.unsqueeze_(0)
        m = nn.Upsample(size=(224,224), mode='bilinear', align_corners=False)
        img = m(img)
        return img[0]


current_epoch = 0


def has_file_allowed_extension(filename, extensions):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


class DatasetFolder(utils.data.Dataset):

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


class DatasetFromHdf5(utils.data.Dataset):
    '''
    Dataset of encoded images ('data') and PSF parameters ('labels') stored in an HDF5 file
    '''
    def __init__(self, data_file_name, loader =None, transform=None, target_transform=None):

        print('{}'.format(data_file_name))
        hf = h5py.File(data_file_name, mode='r')
        self.data = hf['data']
        self.labels = hf['labels']
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform

    def __getitem__(self, index):

        # If a sample cannot be decoded, fall back to the previous index
        while True:
            try:
                image = self.data[index]
                sample = self.loader(_io.BytesIO(image.tobytes()))
                target = self.labels[index]
                sample = sample.astype(np.float64)

                break
            except Exception:
                index = index - 1


        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample.float(), target

    def __len__(self):
        return self.data.shape[0]

class ImageFolder(DatasetFolder):

    def __init__(self, root, file_list=None, transform=None, target_transform=None,
                 loader=folder.default_loader):
        super(ImageFolder, self).__init__(root, loader, folder.IMG_EXTENSIONS,
                                          transform=transform,
                                          target_transform=target_transform)
        self.imgs = self.samples
        self.file_list = file_list

    def __getitem__(self, index):
        while True:
            try:
                path, _ = self.samples[index]
                # File names end with a 9-digit sample index and a 4-character extension (e.g. '.png')
                idx_extracted = int(path[-13:-4])
                target = self.file_list[idx_extracted]
                sample = self.loader(path)
                break
            except Exception:
                # Unreadable sample: retry with a random valid index
                index = np.random.randint(len(self.samples))
        if target[1] == 1000.:
            target[1] = 10.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample.float(), target


def load_crops(folder_prefix = '', test = False, patch_size=128, synthetic=10, natural=1, points=10, black=10, model_type='2dgaussian', batch_size=8, noise = 0.0, isnew = False, noise_type=None, suffix=''):
    '''
    Build a DataLoader over a generated PSF dataset folder (data.h5 + parameters.txt)
    and return it together with the header of the parameter table.
    '''
    global num_classes, log

    if test:
        log.info("Loading test set...")
    else:
        log.info("Loading train set...")

    # Number of PSF parameters regressed for each model type
    if model_type == '1dgaussian':
        num_classes = 2
    elif model_type == '2dgaussian':
        num_classes = 3
    elif model_type == '1dzernike':
        num_classes = 2
    elif model_type == '2dzernike':
        num_classes = 4
    elif model_type == 'astzernike':
        num_classes = 3
    elif model_type == 'astsphzernike':
        num_classes = 4
    else:
        print('Undefined model type')
        exit()


    if patch_size != 128:
        transform = transforms.Compose([Noise(probability=noise, noise_type=noise_type), To_224()])
        patch_size = 128
    else:
        transform = transforms.Compose([Noise(probability=noise, noise_type=noise_type)])

    if test and isnew:
        folder_name = '{}/psf_{}_n_{}_s_{}_p_{}_b_{}{}_{}_0_test/'.format(folder_prefix, patch_size, natural, synthetic, points, black, suffix, model_type)
    elif test:
        folder_name = '{}/psf_{}_n_{}_s_{}_p_{}_b_{}{}_{}_test/'.format(folder_prefix, patch_size, natural, synthetic, points, black, suffix, model_type)
    else:
        folder_name = '{}/psf_{}_n_{}_s_{}_p_{}_b_{}{}_{}_train/'.format(folder_prefix, patch_size, natural, synthetic,
                                                                         points, black, suffix, model_type)

    _file_csv = read_csv(os.path.expanduser(folder_name+"parameters.txt"))
    _header = _file_csv.head(0).columns.base

    #_set = ImageFolder(folder_name, transform=transforms.ToTensor(),
    #                   loader=grayloader, file_list=_file)

    _set = DatasetFromHdf5(folder_name+"data.h5", loader = grayloader, transform=transform)
    # Train and test loaders use the same settings (shuffled, no worker processes)
    _loader = torch.utils.data.DataLoader(dataset=_set, batch_size=batch_size, shuffle=True, num_workers=0)

    if noise_type is not None:
        # Save one noisy sample for visual inspection
        for test_images, test_labels in _loader:
            image = test_images[0][0].data.cpu().numpy()
            io.imsave('noise_sample_{}_{:.3f}.png'.format(noise_type, noise), image/65535.0)
            break

    return _loader, _header


def load_oneimage(size=64, filename='image.tif'):
    global patch_size, batch_size
    patch_size = size
    image = grayloader(filename)
    train_loader = torch.utils.data.DataLoader(dataset=image, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=image, batch_size=1, shuffle=True)
    return train_loader, test_loader

def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding (optionally grouped, as used by BasicBlockX)"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False, groups=groups)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs).cuda()
    name = "resnet18"
    return model, name

def resnet34(pretrained=False, **kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs).cuda()
    name = "resnet34"
    return model, name

class BasicBlockX(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, num_group=32):
        super(BasicBlockX, self).__init__()
        self.conv1 = conv3x3(inplanes, planes*2, stride)
        self.bn1 = nn.BatchNorm2d(planes*2)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes*2, planes*2, groups=num_group)
        self.bn2 = nn.BatchNorm2d(planes*2)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BottleneckX(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, num_group=32):
        super(BottleneckX, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes*2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes*2)
        self.conv2 = nn.Conv2d(planes*2, planes*2, kernel_size=3, stride=stride,
                               padding=1, bias=False, groups=num_group)
        self.bn2 = nn.BatchNorm2d(planes*2)
        self.conv3 = nn.Conv2d(planes*2, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNeXt(nn.Module):

    def __init__(self, block, layers, num_group=32):
        self.inplanes = 64
        super(ResNeXt, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], num_group)
        self.layer2 = self._make_layer(block, 128, layers[1], num_group, stride=2)
        self.layer3 = self._make_layer(block, 128, layers[2], num_group, stride=2)
        self.layer4 = self._make_layer(block, 128, layers[3], num_group, stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, num_group, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, num_group=num_group))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, num_group=num_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnext18( **kwargs):
    """Constructs a ResNeXt-18 model.
    """
    model = ResNeXt(BasicBlockX, [2, 2, 2, 2], **kwargs).cuda()
    name = "resnext18"
    return model, name


def resnext34(**kwargs):
    """Constructs a ResNeXt-34 model.
597 | """ 598 | model = ResNeXt(BasicBlockX, [3, 4, 6, 3], **kwargs).cuda() 599 | name = "resnext34" 600 | return model, name 601 | 602 | 603 | def resnext50(**kwargs): 604 | """Constructs a ResNeXt-50 model. 605 | """ 606 | model = ResNeXt(BottleneckX, [3, 4, 6, 3], **kwargs).cuda() 607 | name = "resnext50" 608 | return model, name 609 | 610 | 611 | def resnet34_pretrained(): 612 | model = torchvision.models.resnet34(pretrained=True) 613 | num_fts = model.fc.in_features 614 | model.fc = nn.Linear(num_fts, num_classes) 615 | model = model.cuda() 616 | return model, 'resnet34pretrained' 617 | 618 | 619 | def resnet50_pretrained(): 620 | model = torchvision.models.resnet50(pretrained=True) 621 | num_fts = model.fc.in_features 622 | model.fc = nn.Linear(num_fts, num_classes) 623 | model = model.cuda() 624 | return model, 'resnet50pretrained' 625 | 626 | 627 | def l2loss(outputs, labels): 628 | zero_or_one = (1.0 - labels[:,0]) 629 | loss_flag = ((outputs[:,0] - labels[:,0])**2).mean() 630 | loss_parameters = ((outputs - labels)**2).mean(1) 631 | loss = (zero_or_one * loss_parameters).mean() + loss_flag 632 | return loss 633 | 634 | def l2variance(outputs, labels): 635 | zero_or_one = (1.0 - labels[:,0]) 636 | loss_flag = ((outputs[:,0] - labels[:,0])**2) 637 | loss_parameters = ((outputs - labels)**2) 638 | loss = (zero_or_one * loss_parameters.mean(1)) + loss_flag 639 | loss_average = loss.mean() 640 | variance = ((loss-loss_average)**2).mean() 641 | return variance 642 | 643 | def train(run_name, train_loader, test_loader): 644 | global log 645 | model.train(True) 646 | global niter, current_epoch 647 | parameters = list(model.parameters()) 648 | optimizer = torch.optim.Adam(parameters, lr=learning_rate, weight_decay=1e-5) 649 | test_errors = [] 650 | def adjust_learning_rate(optimizer, epoch): 651 | lr = learning_rate * (0.9 ** (epoch // 2)) 652 | return lr 653 | 654 | for ep in range(current_epoch, epoch): 655 | current_epoch = ep 656 | log.info("Starting to train 
Epoch {}".format(ep)) 657 | batch_nb = 0 658 | cum_loss = 0 659 | lr = adjust_learning_rate(optimizer, ep) 660 | log.info("Learning rate : {}".format(lr)) 661 | 662 | for param_group in optimizer.param_groups: 663 | param_group["lr"] = lr 664 | 665 | image_num = 0 666 | for image, label in train_loader: 667 | image = Variable(image).cuda() 668 | optimizer.zero_grad() 669 | output = model(image) 670 | loss = l2loss(output, Variable(label.float()).cuda()) 671 | loss.backward() 672 | optimizer.step() 673 | 674 | niter = ep * len(train_loader) + batch_nb 675 | 676 | cum_loss += loss.cpu().data.numpy() 677 | batch_nb += 1 678 | image_num += output.size(0) 679 | average_error = np.around(cum_loss / batch_nb, 3) 680 | log.info( 681 | "Ep {0}/{1}, lr {6:.1E}, bt {2}/{3}, loss {4:.2E}, avg loss {5:.2E}".format(ep, epoch, batch_nb, len(train_loader), 682 | np.around(loss.cpu().data.numpy(), 3), 683 | average_error, lr)) 684 | if len(test_errors) > 0: 685 | log.info("Test errors : {}".format(test_errors[-1])) 686 | 687 | err = test_image(model, test_loader) 688 | save("{0}_{1}_ep{2:0>2}_trainerr{3:.2}_testerr{4:.2}".format(run_name.strip("/"), model_name, current_epoch, average_error, err)) 689 | test_errors.append(err) 690 | log.info("The training is complete.") 691 | 692 | 693 | def save(run_name): 694 | log.info("Saving the model...") 695 | torch.save(model, 'model_{}.pt'.format(run_name.strip("/"))) 696 | 697 | 698 | def load(run_name): 699 | log.info("Loading the model {}... 
".format(run_name)) 700 | model = torch.load('model_{}.pt'.format(run_name.strip("/"))) 701 | return model 702 | 703 | 704 | def test_image(model, loader, max_image=1000): 705 | global log 706 | model.eval() 707 | for child in model.children(): 708 | if type(child) == nn.BatchNorm2d: 709 | child.track_running_stats = False 710 | i = 0 711 | nb_batch = 0 712 | cumloss = 0 713 | 714 | for image, label in loader: 715 | if i > max_image: 716 | break 717 | output = model(Variable(image, requires_grad=False).cuda()) 718 | loss = l2loss(output, Variable(label).cuda().float()) 719 | output = (output*1000.).round() / 1000. 720 | output[:,0] = output[:,0].round() 721 | i += output.size(0) 722 | nb_batch += 1 723 | cumloss += loss.cpu().data.numpy() 724 | error_average = cumloss/nb_batch 725 | log.info("Batch {}/{}, error : {}, avg error {}".format(nb_batch,len(loader),loss.cpu().data.numpy(), error_average)) 726 | 727 | log.info("error on the full set : {}".format(error_average)) 728 | 729 | return error_average 730 | 731 | import logging 732 | 733 | log = logging.getLogger('') 734 | log.setLevel(logging.INFO) 735 | 736 | if __name__ == '__main__': 737 | 738 | epoch = 20 739 | learning_rate = 0.00001 740 | 741 | folder_prefix = "data" 742 | batch_size = 32 743 | 744 | patch_size = 128 745 | synthetic = 0 746 | natural = 1 747 | points = 0 748 | black = 5 749 | model_type = '1dzernike' 750 | noise = 0.0 751 | noise_type = 'poisson' 752 | model_name = 'resnet34' 753 | suffix = '_noise_0' 754 | run_nb = 201 755 | run_name = '{}_n_{}_s_{}_p_{}_b_{}{}_{}_{}/'.format(patch_size, natural, synthetic, points, black, suffix, model_type, run_nb) 756 | 757 | logging.basicConfig( 758 | format="%(asctime)s [{}] %(message)s".format(run_name), 759 | handlers=[ 760 | logging.FileHandler("output_log_{}.log".format(run_nb)), 761 | logging.StreamHandler() 762 | ]) 763 | 764 | niter = 0 765 | log = logging.getLogger('') 766 | log.setLevel(logging.INFO) 767 | test_loader,test_header = 
load_crops(folder_prefix= folder_prefix, test=True, patch_size=patch_size, synthetic=synthetic, natural=natural, points=points, black=black, model_type=model_type, batch_size=16, noise= noise, isnew=False, noise_type=noise_type, suffix= suffix) 768 | 769 | 770 | train_loader,train_header = load_crops(folder_prefix= folder_prefix, test=False, patch_size=patch_size, synthetic=synthetic, natural=natural, points=points, black=black, model_type=model_type, batch_size=batch_size, noise=0.0, suffix= suffix) 771 | 772 | log.info("Starting model run {}....".format(run_name)) 773 | 774 | st = "model_{0}_{1}*".format(run_name.strip("/"), model_name) 775 | list_saves = glob.glob(st) 776 | if len(list_saves) > 0: 777 | list_saves = sorted(list_saves) 778 | idx = list_saves[-1].find("_ep")+3 779 | current_epoch = int(list_saves[-1][idx:idx+2]) 780 | log.info("Loading file {}... epoch {}".format(list_saves[-1], current_epoch)) 781 | model = torch.load(list_saves[-1]) 782 | elif model_name == 'resnet34': 783 | model, model_name = resnet34() 784 | elif model_name == 'resnext50': 785 | model, model_name = resnext50() 786 | elif model_name == 'resnet34pretrained': 787 | model, model_name = resnet34_pretrained() 788 | elif model_name == 'resnet50pretrained': 789 | model, model_name = resnet50_pretrained() 790 | else: 791 | log.error("MODEL {} NOT FOUND".format(model_name)) 792 | 793 | 794 | logging.basicConfig( 795 | format="%(asctime)s [{}] %(message)s".format(run_name+model_name), 796 | handlers=[ 797 | logging.FileHandler("output_log_{}.log".format(run_nb)), 798 | logging.StreamHandler() 799 | ]) 800 | 801 | train(run_name, train_loader, test_loader) --------------------------------------------------------------------------------
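The `l2loss` objective in `train.py` treats column 0 of each label vector as a flag: the squared error on that flag is always averaged over the batch, while the mean squared error over the whole vector is kept only for samples whose flag label is 0. The sketch below is a hypothetical plain-Python mirror (the name `masked_l2_loss` is not part of the repository) that makes this masking arithmetic explicit without needing a GPU or PyTorch; like the tensor code, its per-sample mean also includes column 0.

```python
from typing import Sequence


def masked_l2_loss(outputs: Sequence[Sequence[float]],
                   labels: Sequence[Sequence[float]]) -> float:
    """Plain-Python mirror of l2loss (hypothetical helper, for illustration only).

    Column 0 of each label is a flag. The flag's squared error is always
    averaged over the batch; the per-sample mean squared error over all
    columns is kept only when the flag label is 0.
    """
    if not outputs or len(outputs) != len(labels):
        raise ValueError("outputs and labels must be non-empty and of equal length")
    n = len(outputs)
    flag_term = 0.0
    param_term = 0.0
    for out, lab in zip(outputs, labels):
        weight = 1.0 - lab[0]  # 1.0 keeps the parameter error, 0.0 masks it out
        flag_term += (out[0] - lab[0]) ** 2
        per_sample = sum((o - t) ** 2 for o, t in zip(out, lab)) / len(lab)
        param_term += weight * per_sample
    return param_term / n + flag_term / n


if __name__ == "__main__":
    outputs = [[0.0, 1.0, 2.0], [1.0, 5.0, 5.0]]
    labels = [[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
    # Sample 0 contributes (0 + 0 + 1) / 3; sample 1 is flagged, so its large error is ignored.
    print(masked_l2_loss(outputs, labels))  # (1/3 + 0) / 2, about 0.1667
```

One consequence of this design is that a flagged sample still produces a gradient through its flag output, so the network learns to predict the flag even when the remaining parameters are undefined for that patch.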
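The ResNeXt head hard-codes `nn.Linear(2048, num_classes)`. Assuming 128x128 single-channel patches (the `patch_size` set in the `__main__` block) and that every stride-2 3x3 convolution uses padding 1 (true for `BottleneckX`; assumed for `conv3x3`, which is defined elsewhere), the usual output-size formula floor((n + 2p - k) / s) + 1 traces the feature map from 128 down to 2x2 after `AvgPool2d(7, stride=1)`. With `BottleneckX` (expansion 4, 128 planes in `layer4`) that gives 512 x 2 x 2 = 2048 features, matching the head; with `BasicBlockX` (expansion 1) the same arithmetic gives 512, so `resnext18` and `resnext34` would appear to need a different head size at this patch size. The helper names below are illustrative and not part of the repository.

```python
def conv_out(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a conv/pool layer along one axis (no dilation, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def resnext_head_features(patch: int, expansion: int, final_planes: int = 128) -> int:
    """Hypothetical helper: flattened feature count entering the ResNeXt fc layer."""
    n = conv_out(patch, kernel=5, stride=1, padding=2)  # conv1: 5x5, stride 1, padding 2
    n = conv_out(n, kernel=3, stride=2, padding=1)      # maxpool: 3x3, stride 2, padding 1
    # layer1 keeps the size; layer2..layer4 each start with a stride-2 3x3 conv (padding 1).
    for _ in range(3):
        n = conv_out(n, kernel=3, stride=2, padding=1)
    n = conv_out(n, kernel=7, stride=1)                 # AvgPool2d(7, stride=1)
    return final_planes * expansion * n * n


if __name__ == "__main__":
    print(resnext_head_features(128, expansion=4))  # BottleneckX: 2048
    print(resnext_head_features(128, expansion=1))  # BasicBlockX: 512
```

The spatial sizes along the way are 128, 64, 32, 16, 8 and finally 2, which is why the head size is tied to the patch size: training on a different `patch_size` would change the flattened feature count as well.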