├── flare_removal ├── requirements.txt ├── matlab │ ├── CropCenter.m │ ├── CropRandom.m │ ├── EqualizeChannels.m │ ├── RandomSpectralResponse.m │ ├── GetDefocusPhase.m │ ├── GetPsf.m │ ├── RandomDirtyAperture.m │ └── main.m ├── run.sh ├── python │ ├── models.py │ ├── calculate_metrics.py │ ├── u_net_pp.py │ ├── vgg_test.py │ ├── u_net_test.py │ ├── losses_test.py │ ├── u_net.py │ ├── pretrain_ViT.py │ ├── losses.py │ ├── data_provider.py │ ├── evaluate.py │ ├── test_model.py │ ├── synthesis.py │ ├── remove_flare.py │ ├── train.py │ ├── vgg.py │ └── utils.py └── README.md ├── calc_metrics.sh ├── train_script.sh ├── lens-flare ├── combine_flares.sh └── downsample.sh ├── flare_free ├── combine_images.sh ├── downsample.sh └── split.sh ├── .gitignore ├── evaluation_script.sh ├── test_flare_remove.sh └── README.md /flare_removal/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | numpy 3 | opencv-python 4 | scikit-image 5 | scipy 6 | tensorflow>=2.6 7 | tensorflow-addons 8 | tqdm 9 | -------------------------------------------------------------------------------- /flare_removal/matlab/CropCenter.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function cropped = CropCenter(im, crop) 3 | 4 | if length(crop) == 1 5 | crop = [crop, crop]; 6 | end 7 | window = centerCropWindow2d(size(im), crop); 8 | cropped = imcrop(im, window); 9 | 10 | end 11 | -------------------------------------------------------------------------------- /flare_removal/matlab/CropRandom.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function cropped = CropRandom(im, crop) 3 | 4 | if length(crop) == 1 5 | crop = [crop, crop]; 6 | end 7 | window = randomCropWindow2d(size(im), crop); 8 | cropped = imcrop(im, window); 9 | 10 | end 11 | -------------------------------------------------------------------------------- /calc_metrics.sh: -------------------------------------------------------------------------------- 1 | python3 -m flare_removal.python.calculate_metrics \ 2 | --gt_dir=flare_free/real/downsampled_ground_truth \ 3 | --blended_dir=flare_removal/model_tests/unet_3plus_entire_data/test_real/output_blend \ 4 | --out_dir=flare_removal/model_tests/unet_3plus_entire_data/test_real 5 | -------------------------------------------------------------------------------- /train_script.sh: -------------------------------------------------------------------------------- 1 | srun -p csc2529 --nodelist squid05 -c 2 --gres gpu:1 --pty python3 -m flare_removal.python.train \ 2 | --train_dir=flare_removal/trained_models/unet_entire_data/logs \ 3 | --scene_dir=flare_free/downsampled_data_v2 \ 4 | --flare_dir=lens-flare/downsampled_flares_v2 \ 5 | --epochs=150 \ 6 | --batch_size=2 \ 7 | --learning_rate=1e-4 \ 8 | --training_res=224 \ 9 | --flare_res_h=353 \ 10 | --flare_res_w=263 \ 11 | --model=unet \ 12 | --exp_name=_baseline_unet_entire_data 13 | -------------------------------------------------------------------------------- /lens-flare/combine_flares.sh: -------------------------------------------------------------------------------- 1 | # Create the 'combined' directory if it doesn't exist 2 | mkdir -p combined 3 | 4 | # Initialize the counter 5 | counter=1 6 | 7 | # Copy and rename images from 'captured' 8 | for file in captured/*; do 9 | cp "$file" combined/flare_$(printf "%04d" "$counter").png 10 | ((counter++)) 11 | done 12 | 13 | # Copy 
and rename images from 'simulated' 14 | for file in simulated/*; do 15 | cp "$file" combined/flare_$(printf "%04d" "$counter").png 16 | ((counter++)) 17 | done 18 | -------------------------------------------------------------------------------- /flare_free/combine_images.sh: -------------------------------------------------------------------------------- 1 | # Create the 'data' directory if it doesn't exist 2 | mkdir -p data 3 | 4 | # Initialize the counter 5 | counter=1 6 | 7 | # Copy and rename images from 'reflection_layer' 8 | for file in reflection_layer/*; do 9 | mv "$file" data/img_$(printf "%05d" "$counter").jpg 10 | ((counter++)) 11 | done 12 | 13 | # Copy and rename images from 'transmission_layer' 14 | for file in transmission_layer/*; do 15 | mv "$file" data/img_$(printf "%05d" "$counter").jpg 16 | ((counter++)) 17 | done 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | flare_removal/training 2 | flare_removal/python/remove_flare_224.py 3 | .idea/ 4 | wandb 5 | flare_free/data/ 6 | flare_free/downsampled_data_v2/ 7 | flare_free/downsampled_test/ 8 | flare_free/downsampled_train/ 9 | flare_free/downsampled_val/ 10 | flare_free/real/ 11 | flare_free/synthetic/ 12 | flare_removal/python/__pycache__ 13 | flare_removal/trained_models/ 14 | flare_removal/model_tests/ 15 | lens-flare/combined/ 16 | lens-flare/downsampled_flares 17 | lens-flare/downsampled_flares_v2/ 18 | .nfs00000000039600380000ca3c 19 | -------------------------------------------------------------------------------- /evaluation_script.sh: -------------------------------------------------------------------------------- 1 | srun -p csc2529 -c 8 --nodelist=squid03 --gres gpu:1 --pty python3 -m flare_removal.python.evaluate \ 2 | --eval_dir=flare_removal/unet_3plus_entire_data_eval \ 3 | --train_dir=flare_removal/unet_3plus_entire_data_training/logs \ 4 | --scene_dir=flare_free/real/downsampled_input \ 5 | --flare_dir=lens-flare/downsampled_flares_v2 \ 6 | --training_res=224 \ 7 | --flare_res_h=353 \ 8 | --flare_res_w=263 \ 9 | --model=unet_3plus_2d 10 | 11 | srun -p csc2529 -c 8 --nodelist=squid03 --gres gpu:1 --pty python3 -m flare_removal.python.evaluate \ 12 | --eval_dir=flare_removal/unet_3plus_entire_data_eval \ 13 | --train_dir=flare_removal/unet_3plus_entire_data_training/logs \ 14 | --scene_dir=flare_free/synthetic/downsampled_input \ 15 | --flare_dir=lens-flare/downsampled_flares_v2 \ 16 | --training_res=224 \ 17 | --flare_res_h=353 \ 18 | --flare_res_w=263 \ 19 | --model=unet_3plus_2d 20 | -------------------------------------------------------------------------------- /flare_removal/matlab/EqualizeChannels.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function equalized = EqualizeChannels(im) 3 | % EqualizeChannels Equalizes channel means. 4 | % 5 | % equalized = EqualizeChannels(im) 6 | % Applies scaling to each channel individually, such that the resulting mean 7 | % value of each channel is the same as the minimum channel. This can be thought 8 | % of as a white balance in some sense. Note that we apply a gain that's < 1 9 | % here, which may introduce color artifacts if any input pixel is clipped due to 10 | % saturation. 11 | % 12 | % Arguments: 13 | % 14 | % im: An [H, W, C]-array where C is the channel dimension. 
15 | % 16 | % Returns: 17 | % 18 | % equalized: Same shape and type as `im`, with channel means equalized. 19 | % 20 | % Required toolboxes: none. 21 | 22 | channel_means = mean(im, [1, 2]); 23 | channel_gains = min(channel_means) ./ channel_means; 24 | equalized = im .* channel_gains; 25 | 26 | end 27 | -------------------------------------------------------------------------------- /flare_free/downsample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the input and output directories 4 | img_format="png" 5 | input_dir="data" 6 | output_dir="downsampled_data" 7 | res=224 8 | 9 | # Create the output directory if it doesn't exist 10 | mkdir -p "$output_dir" 11 | 12 | # Loop through all $img_format files in the input folder 13 | for image in "$input_dir"/*.$img_format; do 14 | # Extract the base filename without the extension 15 | base_name=$(basename "$image" .$img_format) 16 | 17 | # Set the output filename with "downsampled" postfix and the specified format 18 | output_image="$output_dir/${base_name}_downsampled.$img_format" 19 | 20 | # Downsample the image and copy it to the output folder at the specified resolution 21 | ffmpeg -i "$image" -vf scale="${res}:${res}" "$output_image" -loglevel quiet 22 | 23 | echo "Processed: $image -> $output_image" 24 | done 25 | 26 | echo "All images have been downsampled and copied." 27 | -------------------------------------------------------------------------------- /lens-flare/downsample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the input and output directories 4 | img_format="png" 5 | input_dir="combined" 6 | output_dir="downsampled_flares_v2" 7 | res_h=353 8 | res_w=263 9 | 10 | # Create the output directory if it doesn't exist 11 | mkdir -p "$output_dir" 12 | 13 | # Loop through all $img_format files in the input folder 14 | for image in "$input_dir"/*.$img_format; do 15 | # Extract the base filename without the extension 16 | base_name=$(basename "$image" .$img_format) 17 | 18 | # Set the output filename with "downsampled" postfix and the specified format 19 | output_image="$output_dir/${base_name}_downsampled.$img_format" 20 | 21 | # Downsample the image and copy it to the output folder at the specified resolution 22 | ffmpeg -i "$image" -vf scale="${res_h}:${res_w}" "$output_image" -loglevel quiet 23 | 24 | echo "Processed: $image -> $output_image" 25 | done 26 | 27 | echo "All images have been downsampled and copied."
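# Note: ffmpeg's scale filter takes width:height, so res_h above is passed as the output width (353)
# and res_w as the output height (263). Keep these values consistent with the --flare_res_h/--flare_res_w
# flags used in train_script.sh and evaluation_script.sh.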
28 | -------------------------------------------------------------------------------- /flare_free/split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create the target directories 4 | mkdir -p downsampled_train_v2 downsampled_val_v2 downsampled_test_v2 5 | 6 | # Count the total number of images in the data directory 7 | total_images=$(ls downsampled_data_v2 | wc -l) 8 | 9 | # Calculate the number of images for each set 10 | train_count=$((total_images * 80 / 100)) 11 | val_count=$((total_images * 10 / 100)) 12 | test_count=$((total_images - train_count - val_count)) 13 | 14 | # Shuffle the images and distribute them 15 | shuf -e downsampled_data_v2/* | { 16 | # Move images to the train set 17 | counter=0 18 | while [ $counter -lt $train_count ]; do 19 | read file 20 | mv "$file" downsampled_train_v2/ 21 | ((counter++)) 22 | done 23 | 24 | # Move images to the val set 25 | counter=0 26 | while [ $counter -lt $val_count ]; do 27 | read file 28 | mv "$file" downsampled_val_v2/ 29 | ((counter++)) 30 | done 31 | 32 | # Move remaining images to the test set 33 | while read file; do 34 | mv "$file" downsampled_test_v2/ 35 | done 36 | } 37 | -------------------------------------------------------------------------------- /test_flare_remove.sh: -------------------------------------------------------------------------------- 1 | srun -p csc2529 -c 8 --nodelist=squid06 --gres gpu:1 --pty \ 2 | python3 -m flare_removal.python.remove_flare \ 3 | --ckpt=flare_removal/trained_models/unet_3plus_entire_data/logs \ 4 | --model='unet_3plus_2d' \ 5 | --input_dir=flare_free/real/downsampled_input \ 6 | --out_dir=flare_removal/model_tests/unet_3plus_entire_data/test_real 7 | 8 | python3 -m flare_removal.python.calculate_metrics \ 9 | --gt_dir=flare_free/real/downsampled_ground_truth \ 10 | --blended_dir=flare_removal/model_tests/unet_3plus_entire_data/test_real/output_blend \ 11 | --out_dir=flare_removal/model_tests/unet_3plus_entire_data/test_real 12 | 13 | srun -p csc2529 -c 8 --nodelist=squid06 --gres gpu:1 --pty \ 14 | python3 -m flare_removal.python.remove_flare \ 15 | --ckpt=flare_removal/trained_models/unet_3plus_entire_data/logs \ 16 | --model='unet_3plus_2d' \ 17 | --input_dir=flare_free/synthetic/downsampled_input \ 18 | --out_dir=flare_removal/model_tests/unet_3plus_entire_data/test_synthetic 19 | 20 | python3 -m flare_removal.python.calculate_metrics \ 21 | --gt_dir=flare_free/synthetic/downsampled_ground_truth \ 22 | --blended_dir=flare_removal/model_tests/unet_3plus_entire_data/test_synthetic/output_blend \ 23 | --out_dir=flare_removal/model_tests/unet_3plus_entire_data/test_synthetic 24 | -------------------------------------------------------------------------------- /flare_removal/run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | #!/bin/bash 16 | # Note: this script must be run from the git repository root (google_research/). 17 | set -e 18 | set -x 19 | 20 | # Create and activate a new virtual environment. 21 | python3 -m venv env 22 | source ./env/bin/activate 23 | 24 | # Install necessary dependencies. 25 | pip install -r flare_removal/requirements.txt 26 | 27 | # The following command should execute without missing dependencies. However, 28 | # it's expected to fail unless you provide appropriate model and data paths in 29 | # additional command-line arguments (or change the default argument values in 30 | # the source file). See README.md for more details. 31 | python3 -m flare_removal.python.remove_flare 32 | -------------------------------------------------------------------------------- /flare_removal/matlab/RandomSpectralResponse.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function response_matrix = RandomSpectralResponse(wavelengths) 3 | % RandomSpectralResponse Random RGB spectral response matrix. 4 | % 5 | % response_matrix = RandomSpectralResponse(wavelengths) 6 | % Returns the RGB response coefficients at the given wavelengths, with 7 | % reasonable random perturbations to simulate the uncertainty in a real-world 8 | % imaging system. 9 | % 10 | % Arguments 11 | % 12 | % wavelengths: an N-vector specifying N wavelengths where the spectral response 13 | % is sampled. Unit: meters. 14 | % 15 | % Returns 16 | % 17 | % response_matrix: a 3 x N matrix where each column contains the RGB response 18 | % factors for the corresponding wavelength in the input 19 | % argument. 20 | % 21 | % Required toolboxes: none. 22 | 23 | rgb_centers = [620; 540; 460] * 1e-9 + rand([3, 1]) * 20e-9; 24 | passband = 50e-9 + rand([3, 1]) * 10e-9; 25 | 26 | r_response = EvaluateGaussian(wavelengths, rgb_centers(1), passband(1)); 27 | g_response = EvaluateGaussian(wavelengths, rgb_centers(2), passband(2)); 28 | b_response = EvaluateGaussian(wavelengths, rgb_centers(3), passband(3)); 29 | 30 | response_matrix = [r_response; g_response; b_response]; 31 | 32 | end 33 | 34 | function val = EvaluateGaussian(x, mu, sigma) 35 | val = exp(-(x - mu) .^ 2 / (2 * sigma .^ 2)); 36 | end 37 | -------------------------------------------------------------------------------- /flare_removal/matlab/GetDefocusPhase.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function [phase, mask] = GetDefocusPhase(n, r) 3 | % GetDefocusPhase Phase shift due to defocus for a round aperture. 4 | % 5 | % [phase, mask] = GetDefocusPhase(n, aperture_r) 6 | % Computes the phase shift per unit defocus in the Fourier domain. Also returns 7 | % the corresponding circular mask on the Fourier plane that defines the valid 8 | % region of the frequency response. 9 | % 10 | % Arguments 11 | % 12 | % n: Number of samples in each direction for the image and spectrum. The output 13 | % will be an [n, n]-array. 14 | % 15 | % r: Radius of the circular low-pass filter applied on the spectrum, assuming 16 | % the spectrum is a unit square. 17 | % 18 | % Returns 19 | % 20 | % phase: Amount of (complex) phase shift in the spectrum for each unit (1) of 21 | % defocus. Zero outside the disk of radius `r`. [n, n]-array. 22 | % 23 | % mask: A centered disk of 1 surrounded by 0, representing the low-pass filter 24 | % that is applied to the spectrum (including the `phase` array above). 25 | % [n, n]-array. 
26 | % 27 | % Required toolboxes: none. 28 | 29 | %% Pixel center coordinates in Cartesian and polar forms. 30 | sample_x = linspace(-(n - 1) / 2, (n - 1) / 2, n) / n / r; 31 | [xx, yy] = meshgrid(sample_x); 32 | [~, rr] = cart2pol(xx, yy); 33 | 34 | %% The mask is simply a centered unit disk. 35 | % Zernike polynomials below are only defined on the unit disk. 36 | mask = rr <= 1; 37 | 38 | %% Compute the Zernike polynomial of degree 2, order 0. 39 | % Zernike polynomials form a complete, orthogonal basis over the unit disk. The 40 | % "degree 2, order 0" component represents defocus, and is defined as (in 41 | % unnormalized form): 42 | % 43 | % Z = 2 * r^2 - 1. 44 | % 45 | % Reference: 46 | % Paul Fricker (2021). Analyzing LASIK Optical Data Using Zernike Functions. 47 | % https://www.mathworks.com/company/newsletters/articles/analyzing-lasik-optical-data-using-zernike-functions.html 48 | phase = single(2 * rr .^ 2 - 1); 49 | phase(~mask) = 0; 50 | 51 | end 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # computational_imaging 2 | CSC2529 project 3 | 4 | Data preparation: 5 | 1) Go to flare_free and, in downsample.sh, set the input dir (the full original-resolution dataset of 27,449 images) and the output dir, and verify that res is 224. 6 | 2) Run downsample.sh. 7 | 3) Download the real and synthetic test datasets from https://drive.google.com/drive/folders/1_gi3W8fOEusfmglJdiKCUwk3IA7B2jfQ and place them in flare_free. 8 | 4) Repeat steps 1 and 2 on the ground_truth and input directories inside the real (20 images) and synthetic (37 images) directories. 9 | 5) Go to lens-flare and, in downsample.sh, set the input dir (the full original-resolution flare dataset of 5,001 images) and the output dir, and verify that res_h and res_w are 353 and 263, respectively. 10 | 6) Run downsample.sh. 11 | 12 | 13 | Model training: 14 | 1) Go to google-research and, in train_script.sh, set train_dir (where to save logs), scene_dir (the entire downsampled scene data), flare_dir (the entire downsampled flares), epochs=200, training_res=224, flare_res_h=353, flare_res_w=263, model (name of the model you are running), and exp_name (experiment name; the provided script uses a leading underscore, e.g. _baseline_unet_entire_data). 15 | 2) Run train_script.sh. 16 | 17 | 18 | Evaluating the model (optional, if there is separate evaluation data): 19 | 1) After the model has finished training, go to google-research and, in evaluation_script.sh, set eval_dir (where to save results), train_dir (logs dir of the trained model), scene_dir (path to the downsampled ground_truth test dir, once for real and once for synthetic), training_res=224, flare_res_h=353, flare_res_w=263, and model (name of the model type you are evaluating). 20 | 2) Run evaluation_script.sh. 21 | 22 | 23 | Testing the model to produce visual and quantitative results: 24 | 1) In google-research, specify in test_flare_remove.sh: 25 | a) for the commands starting with srun: ckpt (path to the logs dir of the trained model), model (model type), input_dir (path to the downsampled input test dir, once for real and once for synthetic), and out_dir (path where results are saved). 26 | b) for the commands starting with python3: gt_dir (path to the ground truth images), blended_dir (path to the blended images dir), and out_dir (output path of the text file with metrics). 27 | 2) Run test_flare_remove.sh.
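For reference, the directory layout implied by the default paths in the provided scripts is roughly the following (names are taken from the scripts themselves; "unet_3plus_entire_data" stands for whatever experiment name you use):

```
flare_free/
  downsampled_data_v2/           # downsampled flare-free scenes (224x224)
  real/
    downsampled_input/           # real test inputs with flare
    downsampled_ground_truth/    # real test ground truth
  synthetic/
    downsampled_input/
    downsampled_ground_truth/
lens-flare/
  combined/                      # captured + simulated flares merged by combine_flares.sh
  downsampled_flares_v2/         # downsampled flares (353x263)
flare_removal/
  trained_models/unet_3plus_entire_data/logs/   # training checkpoints and logs
  model_tests/unet_3plus_entire_data/           # remove_flare.py outputs and metrics files
```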
28 | -------------------------------------------------------------------------------- /flare_removal/matlab/GetPsf.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function psf = GetPsf(aperture, phase, wavelengths, spectral_response, crop) 3 | % GetPsf Computes the RGB response of an aperture under a point light source. 4 | % 5 | % psf = GetPsf(aperture, phase, wavelengths, spectral_response, crop) 6 | % Computes the point spread function (PSF) of the given aperture in response to 7 | % a white point light source (i.e., it has a flat spectrum). 8 | % 9 | % Arguments 10 | % 11 | % aperture: A grayscale image representing the aperture, where 0 means total 12 | % opacity and 1 means total transparency. 13 | % 14 | % phase: An image of the same size as `aperture` representing the phase shift. 15 | % 16 | % wavelengths: An L-vector of wavelengths at which the light spectrum is 17 | % sampled. They are normalized by the wavelength at which `phase` 18 | % is computed. 19 | % 20 | % spectral_response: Sensitivity of RGB pixels at `wavelengths`. Size [3, L]. 21 | % 22 | % crop: Side length of the output array. It should be smaller than the input due 23 | % to wavelength-dependent resizing - otherwise we would get out-of-range 24 | % samples. 25 | % 26 | % Returns 27 | % 28 | % psf: An RGB image of size [crop, crop]. 29 | % 30 | % Required toolboxes: none. 31 | 32 | % Expand to 3-D array of size [H, W, C] where C is the size of the `wavelengths` 33 | % vector (i.e., the number of samples in the spectrum). This accounts for the 34 | % phase term's dependency on the wavelength. 35 | phase_wl = phase ./ reshape(wavelengths, 1, 1, []); 36 | 37 | % Pupil function in the frequency domain. 38 | pupil_wl = aperture .* exp(1j * phase_wl); 39 | 40 | % Point spread function (PSF) in the spatial domain is related to the pupil 41 | % function by a Fourier transform. 42 | psf_wl = abs(fft2(pupil_wl)) .^ 2; 43 | psf_wl = fftshift(fftshift(psf_wl, 1), 2); 44 | 45 | % Apart from affecting the phase term `phase_wl`, the wavelength also has an 46 | % effect on the spatial scaling. 47 | num_wl = length(wavelengths); 48 | psf_cropped = zeros(crop, crop, num_wl, 'single'); 49 | for i = 1:num_wl 50 | wl = wavelengths(i); 51 | psf_resized = imresize(psf_wl(:, :, i), wl, 'bilinear'); 52 | psf_cropped(:, :, i) = CropCenter(psf_resized, crop); 53 | end 54 | 55 | % Finally, apply the spectral response matrix to convert to RGB. 56 | psf_cropped = reshape(psf_cropped, [], num_wl); 57 | psf = reshape(psf_cropped * spectral_response, crop, crop, 3); 58 | 59 | end 60 | -------------------------------------------------------------------------------- /flare_removal/matlab/RandomDirtyAperture.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | function im = RandomDirtyAperture(mask) 3 | % RandomDirtyAperture Synthetic dirty aperture with random dots and scratches. 4 | % 5 | % im = RandomDirtyAperture(mask) 6 | % Returns an N x N monochromatic image emulating a dirty aperture plane. 7 | % Specifically, we add disks and polylines of random size and opacity to an 8 | % otherwise white image, in an attempt to model random dust and scratches. 9 | % 10 | % TODO(qiurui): the spatial scale of the random dots and polylines are currently 11 | % hard-coded in order to match the paper. They should instead be relative to 12 | % the requested resolution, n. 
13 | % 14 | % Arguments 15 | % 16 | % mask: An [N, N]-logical matrix representing the aperture mask. Typically, this 17 | % should be a centered disk of 1 surrounded by 0. 18 | % 19 | % Returns 20 | % 21 | % im: An [N, N]-matrix of values in [0, 1] where 0 means completely opaque and 1 22 | % means completely transparent. The returned matrix is real-valued (i.e., we 23 | % ignore the phase shift that may be introduced by the "dust" and 24 | % "scratches"). 25 | % 26 | % Required toolboxes: Computer Vision Toolbox. 27 | 28 | n = size(mask, 1); 29 | im = ones(size(mask), 'single'); 30 | 31 | %% Add dots (circles), simulating dust. 32 | num_dots = max(0, round(30 + randn * 5)); 33 | max_radius = max(0, 100 + randn * 50); 34 | for i = 1:num_dots 35 | circle_xyr = rand(1, 3, 'single') .* [n, n, max_radius]; 36 | opacity = 0.5 + rand * 0.5; 37 | im = insertShape(im, 'FilledCircle', circle_xyr, 'Color', 'black', ... 38 | 'Opacity', opacity); 39 | end 40 | 41 | %% Add polylines, simulating scratches. 42 | num_lines = max(0, round(30 + randn * 5)); 43 | max_width = max(0, round(20 + randn * 5)); 44 | for i = 1:num_lines 45 | num_segments = randi(16); 46 | start_xy = rand(2, 1) * n; 47 | segment_length = rand * 600; 48 | segments_xy = RandomPointsInUnitCircle(num_segments) * segment_length; 49 | vertices_xy = cumsum([start_xy, segments_xy], 2); 50 | vertices_xy = reshape(vertices_xy, 1, []); 51 | width = randi(max_width); 52 | % Note: the 'Opacity' option doesn't apply to lines, so we have to change the 53 | % line color to achieve a similar effect. Also note that [0.5 .. 1] opacity 54 | % maps to [0.5 .. 0] in color values. 55 | color = rand * 0.5; 56 | im = insertShape(im, 'Line', vertices_xy, 'LineWidth', width, ... 57 | 'Color', [color, color, color]); 58 | end 59 | 60 | im = single(mask) .* rgb2gray(im); 61 | 62 | end 63 | 64 | function xy = RandomPointsInUnitCircle(num_points) 65 | r = rand(1, num_points, 'single'); 66 | theta = rand(1, num_points, 'single') * 2 * pi; 67 | xy = [r .* cos(theta); r .* sin(theta)]; 68 | end 69 | -------------------------------------------------------------------------------- /flare_removal/matlab/main.m: -------------------------------------------------------------------------------- 1 | % // clang-format off 2 | clear 3 | close all 4 | 5 | %% Typical parameters for a smartphone camera. 6 | % Nominal wavelength (m). 7 | lambda = 550e-9; 8 | % Focal length (m). 9 | f = 2.2e-3; 10 | % Pixel pitch on the sensor (m). 11 | delta = 1e-6; 12 | % Sensor size (width & height, m). 13 | l = 6e-3; 14 | % Simulation resolution, in both spatial and frequency domains. 15 | res = l / delta; 16 | 17 | %% Compute defocus phase shift and aperture mask in the Fourier domain. 18 | % Frequency range (extent) of the Fourier transform (m ^ -1). 19 | lf = lambda * f / delta; 20 | % Diameter of the circular low-pass filter on the Fourier plane. 21 | df = 1e-3; 22 | % Low-pass radius, normalized by simulation resolution. 23 | rf_norm = df / 2 / lf; 24 | [defocus_phase, aperture_mask] = GetDefocusPhase(res, rf_norm); 25 | 26 | %% Wavelengths at which the spectral response is sampled. 27 | num_wavelengths = 73; 28 | wavelengths = linspace(380, 740, num_wavelengths) * 1e-9; 29 | 30 | %% Create output directories. 
31 | out_dir = 'streaks/'; 32 | mkdir(out_dir); 33 | aperture_dir = 'apertures/'; 34 | mkdir(aperture_dir); 35 | out_crop = 800; 36 | 37 | %% generate the PSFs 38 | parfor tt = 1:1000 39 | aperture = RandomDirtyAperture(aperture_mask); 40 | imwrite(aperture, strcat(aperture_dir, sprintf('%03d.png',tt - 1))); 41 | 42 | %% Random RGB spectral response. 43 | wl_to_rgb = RandomSpectralResponse(wavelengths).'; 44 | 45 | for ii = 1:4 46 | %% Random defocus. 47 | defocus_crop = 4000; 48 | defocus = randn * 5; 49 | psf_rgb = GetPsf(aperture, defocus_phase * defocus, ... 50 | wavelengths ./ lambda, wl_to_rgb, defocus_crop); 51 | 52 | for kk = 1:4 53 | %% Randomly crop and distort the PSF. 54 | focal_length_px = f / delta * [1, 1]; 55 | sensor_crop = [2400, 2400]; 56 | principal_point = sensor_crop / 2; 57 | radial_distortion = [randn * 0.8, 0]; 58 | camera_params = cameraIntrinsics( ... 59 | focal_length_px, principal_point, sensor_crop, ... 60 | 'RadialDistortion', radial_distortion); 61 | psf_cropped = CropRandom(psf_rgb, sensor_crop); 62 | psf_distorted = undistortImage(psf_cropped, camera_params); 63 | 64 | %% Apply global tone curve (gamma) and write to disk. 65 | psf_ds = imresize(psf_distorted, 0.5, 'box'); 66 | psf_out = EqualizeChannels(CropCenter(psf_ds, out_crop)); 67 | psf_gamma = abs(psf_out .^ (1/2.2)); 68 | psf_gamma = min(psf_gamma, 2^16 - 1); 69 | psf_u16 = uint16(psf_gamma); 70 | 71 | output_file_name = sprintf('aperture%04d_blur%02d_crop%02d.png', ... 72 | tt - 1, ii - 1, kk - 1); 73 | imwrite(psf_u16, strcat(out_dir, output_file_name)); 74 | fprintf('Written to disk: %s\n', output_file_name); 75 | 76 | end 77 | end 78 | end 79 | -------------------------------------------------------------------------------- /flare_removal/python/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Models for flare removal.""" 17 | 18 | from flare_removal.python import u_net 19 | from flare_removal.python import vgg 20 | from keras_unet_collection.models import swin_unet_2d, unet_3plus_2d 21 | from keras_unet_collection import base, utils 22 | from flare_removal.python.ViT import ViT, ViTWithDecoder 23 | from flare_removal.python.pretrain_ViT import Pretrain_ViT 24 | from flare_removal.python.u_net_pp import get_model as unet_pp 25 | 26 | 27 | def build_model(model_type, batch_size, res): 28 | """Returns a Keras model specified by name.""" 29 | if model_type == 'unet': 30 | return u_net.get_model( 31 | input_shape=(res, res, 3), 32 | scales=4, 33 | bottleneck_depth=1024, 34 | bottleneck_layers=2) 35 | 36 | elif model_type == 'unet_pp': 37 | return unet_pp(input_shape=(res, res, 3)) 38 | 39 | elif model_type == 'unet_3plus_2d': 40 | return unet_3plus_2d( 41 | (res, res, 3), 42 | n_labels=3, 43 | filter_num_down=[64, 128, 256, 512], 44 | batch_norm=True, 45 | pool='max', 46 | unpool=False, 47 | #deep_supervision=True, 48 | weights='imagenet', 49 | name='unet3plus') 50 | 51 | elif model_type == 'TransUNET': 52 | return models.transunet_2d((res, res, 3), filter_num=[64, 128, 256, 512], n_labels=3, stack_num_down=2, stack_num_up=2, 53 | embed_dim=512, num_mlp=256, num_heads=6, num_transformer=6, 54 | activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid', 55 | batch_norm=True, pool=True, unpool='bilinear', name='transunet') 56 | 57 | elif model_type == 'vit': 58 | return Pretrain_ViT( 59 | image_size=224, 60 | patch_size=16, 61 | encoder_dim=64, 62 | decoder_dim=128, 63 | depth=6, 64 | heads=4, 65 | mlp_dim=128, 66 | dropout=0.1, 67 | decoder_depth=4 68 | ) 69 | 70 | elif model_type == 'can': 71 | return vgg.build_can( 72 | input_shape=(512, 512, 3), conv_channels=64, out_channels=3) 73 | else: 74 | raise ValueError(model_type) 75 | -------------------------------------------------------------------------------- /flare_removal/python/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | # calculate_metric.py 2 | import os 3 | import cv2 4 | import numpy as np 5 | from skimage.metrics import structural_similarity as compare_ssim 6 | import math 7 | from absl import app 8 | from absl import flags 9 | 10 | FLAGS = flags.FLAGS 11 | flags.DEFINE_string('gt_dir', None, 12 | 'The directory contains all ground truth test images.') 13 | flags.DEFINE_string('blended_dir', None, 14 | 'The directory contains all blended images.') 15 | flags.DEFINE_string('out_dir', None, 'Output directory.') 16 | 17 | 18 | def calculate_psnr(img1, img2): 19 | mse = np.mean((img1 - img2) ** 2) 20 | if mse == 0: # Images are identical 21 | return float('inf') 22 | pixel_max = 255.0 23 | psnr = 20 * math.log10(pixel_max / math.sqrt(mse)) 24 | return psnr 25 | 26 | def calculate_ssim(img1, img2): 27 | img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 28 | img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY) 29 | ssim, _ = compare_ssim(img1_gray, img2_gray, full=True) 30 | return ssim 31 | 32 | def main(_): 33 | # Directories 34 | gt_dir = FLAGS.gt_dir # Change this to your reference image directory if needed 35 | blended_dir = FLAGS.blended_dir # Directories to compare 36 | out_dir = FLAGS.out_dir 37 | results = [] 38 | 39 | #for out_dir in blended_dir: 40 | #results[out_dir] = [] 41 | #output_path = blended_dir 42 | for img_name in os.listdir(blended_dir): 43 | # Load reference and corresponding output images 44 | input_img_path = 
os.path.join(gt_dir, img_name) 45 | output_img_path = os.path.join(blended_dir, img_name) 46 | 47 | if not os.path.exists(input_img_path) or not os.path.exists(output_img_path): 48 | continue # Skip if files don't match 49 | 50 | input_img = cv2.imread(input_img_path) 51 | output_img = cv2.imread(output_img_path) 52 | 53 | if input_img is None or output_img is None: 54 | continue # Skip invalid images 55 | 56 | # Compute PSNR and SSIM 57 | psnr = calculate_psnr(input_img, output_img) 58 | ssim = calculate_ssim(input_img, output_img) 59 | results.append((img_name, psnr, ssim)) 60 | 61 | output_file = os.path.join(out_dir, "psnr_ssim_results.txt") 62 | with open(output_file, "w") as file: 63 | i = 0 64 | avg_psnr = 0 65 | avg_ssim = 0 66 | file.write(f"Results for {out_dir}:\n") 67 | for img_name, psnr, ssim in results: 68 | file.write(f"Image: {img_name} | PSNR: {psnr:.2f} | SSIM: {ssim:.4f}\n") 69 | i += 1 70 | avg_psnr += psnr 71 | avg_ssim += ssim 72 | 73 | avg_psnr /= i 74 | avg_ssim /= i 75 | file.write(f"Average PSNR: {avg_psnr:.2f} | Average SSIM: {avg_ssim:.4f}\n") 76 | 77 | print("Results have been saved to their respective model testing folders.") 78 | 79 | if __name__ == '__main__': 80 | app.run(main) 81 | -------------------------------------------------------------------------------- /flare_removal/python/u_net_pp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .u_net import _down_block, _up_block 3 | 4 | 5 | def _nested_block(x, skips, depth, stage, interpolation='bilinear', name_prefix='nested'): 6 | """Handles the dense skip connections for U-Net++. 7 | 8 | Args: 9 | x: The input tensor for the current stage. 10 | skips: List of skip connections from previous downscaling levels. 11 | depth: Number of output channels. 12 | stage: Current depth level. 13 | interpolation: Interpolation method for upsampling. 14 | name_prefix: Layer name prefix. 15 | 16 | Returns: 17 | A tensor that represents the fused output of this nested block. 18 | """ 19 | outputs = [x] 20 | for i, skip in enumerate(skips): 21 | # Fuse the input tensor and skip connection at this level. 22 | upsampled = tf.keras.layers.UpSampling2D( 23 | size=(2, 2), interpolation=interpolation, 24 | name=f'{name_prefix}_upsample_stage{stage}_level{i}' 25 | )(outputs[-1]) 26 | merged = tf.keras.layers.concatenate( 27 | [upsampled, skip], name=f'{name_prefix}_concat_stage{stage}_level{i}' 28 | ) 29 | conv = tf.keras.layers.Conv2D( 30 | filters=depth, 31 | kernel_size=3, 32 | padding='same', 33 | activation='relu', 34 | name=f'{name_prefix}_conv_stage{stage}_level{i}' 35 | )(merged) 36 | outputs.append(conv) 37 | return outputs[-1] # Return the most refined output for this stage. 38 | 39 | 40 | def get_model(input_shape=(512, 512, 3), scales=4, bottleneck_depth=1024, bottleneck_layers=2): 41 | """Builds a U-Net++ with dense skip connections. 42 | 43 | Args: 44 | input_shape: Shape of the input tensor without batch dimension. 45 | scales: Number of downscaling/upscaling blocks. 46 | bottleneck_depth: Number of channels in the bottleneck. 47 | bottleneck_layers: Number of Conv2D layers in the bottleneck. 48 | 49 | Returns: 50 | A Keras model instance representing a U-Net++. 51 | """ 52 | input_layer = tf.keras.Input(shape=input_shape, name='input') 53 | previous_output = input_layer 54 | 55 | # Downscaling arm with skip connections. 
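# With the defaults (scales=4, bottleneck_depth=1024), the depths computed below are [64, 128, 256, 512], matching u_net.get_model.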
56 | skips = [] 57 | depths = [bottleneck_depth // 2**i for i in range(scales, 0, -1)] 58 | for depth in depths: 59 | skip, previous_output = _down_block( 60 | previous_output, depth, name_prefix=f'down{depth}' 61 | ) 62 | skips.append(skip) 63 | 64 | # Bottleneck. 65 | for i in range(bottleneck_layers): 66 | previous_output = tf.keras.layers.Conv2D( 67 | filters=bottleneck_depth, 68 | kernel_size=3, 69 | padding='same', 70 | activation='relu', 71 | name=f'bottleneck_conv{i + 1}' 72 | )(previous_output) 73 | 74 | # Upscaling arm with nested dense skip connections. 75 | nested_skips = [[] for _ in range(len(skips))] 76 | for i, (depth, skip) in enumerate(zip(reversed(depths), reversed(skips))): 77 | nested_skips[i].append(skip) 78 | previous_output = _nested_block( 79 | previous_output, nested_skips[i], depth, 80 | stage=i, name_prefix=f'nested_up{depth}' 81 | ) 82 | nested_skips[i].append(previous_output) 83 | 84 | # Squash output to (0, 1). 85 | output_layer = tf.keras.layers.Conv2D( 86 | filters=input_shape[-1], 87 | kernel_size=1, 88 | activation='sigmoid', 89 | name='output' 90 | )(previous_output) 91 | 92 | return tf.keras.Model(input_layer, output_layer, name='unet_pp') 93 | 94 | 95 | # Helper functions (_down_block and _up_block) remain unchanged. 96 | -------------------------------------------------------------------------------- /flare_removal/python/vgg_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests the `vgg` module.""" 17 | 18 | import tensorflow as tf 19 | 20 | from flare_removal.python import vgg 21 | 22 | 23 | class Vgg19Test(tf.test.TestCase): 24 | 25 | def test_duplicate_layers(self): 26 | with self.assertRaises(ValueError): 27 | vgg.Vgg19(tap_out_layers=['block1_conv1', 'block1_conv1'], weights=None) 28 | 29 | def test_invalid_layers(self): 30 | with self.assertRaisesRegex(ValueError, 'block1_conv3'): 31 | vgg.Vgg19(tap_out_layers=['block1_conv3'], weights=None) 32 | 33 | def test_output_shape(self): 34 | vgg_19 = vgg.Vgg19( 35 | tap_out_layers=['block1_conv2', 'block4_conv2', 'block5_conv2'], 36 | weights=None) 37 | images = tf.ones((4, 384, 512, 3)) * 0.5 38 | features = vgg_19(images) 39 | self.assertAllEqual(features[0].shape, [4, 384, 512, 64]) 40 | self.assertAllEqual(features[1].shape, [4, 48, 64, 512]) 41 | self.assertAllEqual(features[2].shape, [4, 24, 32, 512]) 42 | 43 | 44 | class IdentityInitializerTest(tf.test.TestCase): 45 | 46 | def setUp(self): 47 | super(IdentityInitializerTest, self).setUp() 48 | self._initializer = vgg.IdentityInitializer() 49 | 50 | def test_one_channel(self): 51 | kernel = self._initializer([5, 5, 1, 1]) 52 | x = tf.random.uniform([2, 512, 512, 1], seed=0) 53 | y = tf.nn.conv2d(x, kernel, strides=1, padding='SAME') 54 | self.assertAllClose(y, x) 55 | 56 | def test_equal_input_output_channels(self): 57 | kernel = self._initializer([3, 5, 64, 64]) 58 | x = tf.random.uniform([2, 256, 256, 64], seed=0) 59 | y = tf.nn.conv2d(x, kernel, strides=1, padding='SAME') 60 | self.assertAllClose(y, x) 61 | 62 | def test_more_output_channels_than_input(self): 63 | kernel = self._initializer([5, 3, 3, 64]) 64 | x = tf.random.uniform([2, 512, 512, 3], seed=0) 65 | y = tf.nn.conv2d(x, kernel, strides=1, padding='SAME') 66 | self.assertAllClose(y[Ellipsis, :3], x) 67 | self.assertAllEqual(tf.math.count_nonzero(y[Ellipsis, 3:]), 0) 68 | 69 | def test_more_input_channels_than_output(self): 70 | kernel = self._initializer([3, 3, 64, 3]) 71 | x = tf.random.uniform([2, 256, 256, 64], seed=0) 72 | y = tf.nn.conv2d(x, kernel, strides=1, padding='SAME') 73 | self.assertAllClose(y, x[Ellipsis, :3]) 74 | 75 | 76 | class ContextAggregationNetworkTest(tf.test.TestCase): 77 | 78 | def test_output_shape(self): 79 | x = tf.random.uniform([2, 256, 256, 3], seed=0) 80 | can = vgg.build_can(input_shape=x.shape[1:]) 81 | y = can(x) 82 | self.assertAllEqual(y.shape, x.shape) 83 | 84 | def test_contains_named_conv_blocks(self): 85 | can = vgg.build_can(name='can') 86 | for i in range(9): 87 | self.assertIsNotNone(can.get_layer(name=f'can_g_conv{i}')) 88 | self.assertIsNotNone(can.get_layer(name='can_g_conv_last')) 89 | 90 | def test_first_conv_block_shapes(self): 91 | can = vgg.build_can(input_shape=[512, 512, 3], name='can') 92 | conv0 = can.get_layer(name='can_g_conv0') 93 | # The following shapes are explicitly described by Zhang et al. in Section 3 94 | # of the paper. 95 | self.assertAllEqual(conv0.input.shape, [None, 512, 512, 1475]) 96 | self.assertAllEqual(conv0.output.shape, [None, 512, 512, 64]) 97 | 98 | 99 | if __name__ == '__main__': 100 | tf.test.main() 101 | -------------------------------------------------------------------------------- /flare_removal/python/u_net_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests the `u_net` module.""" 17 | import tensorflow as tf 18 | 19 | from flare_removal.python import u_net 20 | 21 | 22 | class UNetTest(tf.test.TestCase): 23 | 24 | def test_zero_scale(self): 25 | model = u_net.get_model( 26 | input_shape=(128, 128, 1), scales=0, bottleneck_depth=32) 27 | model.summary() 28 | 29 | input_layer = model.get_layer('input') 30 | bottleneck_conv1 = model.get_layer('bottleneck_conv1') 31 | bottleneck_conv2 = model.get_layer('bottleneck_conv2') 32 | output_layer = model.get_layer('output') 33 | self.assertIs(input_layer.output, bottleneck_conv1.input) 34 | self.assertIs(bottleneck_conv1.output, bottleneck_conv2.input) 35 | self.assertIs(bottleneck_conv2.output, output_layer.input) 36 | self.assertAllEqual(model.input_shape, [None, 128, 128, 1]) 37 | self.assertAllEqual(bottleneck_conv1.output_shape, [None, 128, 128, 32]) 38 | self.assertAllEqual(bottleneck_conv2.output_shape, [None, 128, 128, 32]) 39 | self.assertAllEqual(model.output_shape, [None, 128, 128, 1]) 40 | 41 | def test_one_scale(self): 42 | model = u_net.get_model( 43 | input_shape=(64, 64, 3), scales=1, bottleneck_depth=128) 44 | model.summary() 45 | 46 | # Downscaling arm. 47 | input_layer = model.get_layer('input') 48 | down_conv1 = model.get_layer('down64_conv1') 49 | down_conv2 = model.get_layer('down64_conv2') 50 | down_pool = model.get_layer('down64_pool') 51 | bottleneck_conv1 = model.get_layer('bottleneck_conv1') 52 | self.assertIs(input_layer.output, down_conv1.input) 53 | self.assertIs(down_conv1.output, down_conv2.input) 54 | self.assertIs(down_conv2.output, down_pool.input) 55 | self.assertIs(down_pool.output, bottleneck_conv1.input) 56 | self.assertAllEqual(model.input_shape, [None, 64, 64, 3]) 57 | self.assertAllEqual(down_conv1.output_shape, [None, 64, 64, 64]) 58 | self.assertAllEqual(down_conv2.output_shape, [None, 64, 64, 64]) 59 | self.assertAllEqual(down_pool.output_shape, [None, 32, 32, 64]) 60 | self.assertAllEqual(bottleneck_conv1.output_shape, [None, 32, 32, 128]) 61 | 62 | # Upscaling arm. 
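# Layer names follow the f'up{depth}' prefix used by u_net._up_block; with scales=1 and bottleneck_depth=128 the single upscaling stage has depth 64, hence the 'up64_*' names.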
63 | bottleneck_conv2 = model.get_layer('bottleneck_conv2') 64 | up_2x = model.get_layer('up64_2x') 65 | up_2xconv = model.get_layer('up64_2xconv') 66 | up_concat = model.get_layer('up64_concat') 67 | up_conv1 = model.get_layer('up64_conv1') 68 | up_conv2 = model.get_layer('up64_conv2') 69 | output_layer = model.get_layer('output') 70 | self.assertIs(bottleneck_conv2.output, up_2x.input) 71 | self.assertIs(up_2x.output, up_2xconv.input) 72 | self.assertIs(up_2xconv.output, up_concat.input[0]) 73 | self.assertIs(up_concat.output, up_conv1.input) 74 | self.assertIs(up_conv1.output, up_conv2.input) 75 | self.assertIs(up_conv2.output, output_layer.input) 76 | self.assertAllEqual(bottleneck_conv2.output_shape, [None, 32, 32, 128]) 77 | self.assertAllEqual(up_2x.output_shape, [None, 64, 64, 128]) 78 | self.assertAllEqual(up_2xconv.output_shape, [None, 64, 64, 64]) 79 | self.assertAllEqual(up_concat.output_shape, [None, 64, 64, 128]) 80 | self.assertAllEqual(up_conv1.output_shape, [None, 64, 64, 64]) 81 | self.assertAllEqual(up_conv2.output_shape, [None, 64, 64, 64]) 82 | self.assertAllEqual(output_layer.output_shape, [None, 64, 64, 3]) 83 | 84 | # Skip connection. 85 | self.assertIs(down_conv2.output, up_concat.input[1]) 86 | 87 | 88 | if __name__ == '__main__': 89 | tf.test.main() 90 | -------------------------------------------------------------------------------- /flare_removal/python/losses_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests the `losses` module.""" 17 | import tensorflow as tf 18 | 19 | from flare_removal.python import losses 20 | 21 | 22 | class PerceptualLossTest(tf.test.TestCase): 23 | 24 | def test_identical_inputs(self): 25 | loss = losses.PerceptualLoss() 26 | images = tf.random.uniform((2, 192, 256, 3)) 27 | self.assertAllClose(loss(images, images), 0.0) 28 | 29 | def test_different_inputs(self): 30 | loss = losses.PerceptualLoss() 31 | image_1 = tf.zeros((2, 192, 256, 3)) 32 | image_2 = tf.random.uniform((2, 192, 256, 3)) 33 | self.assertAllGreater(loss(image_1, image_2), 1.0) 34 | 35 | def test_similar_vs_different_inputs(self): 36 | loss = losses.PerceptualLoss() 37 | pure_bright = tf.ones((3, 256, 256, 3)) * tf.constant([0.9, 0.7, 0.7]) 38 | pure_dark = tf.ones((3, 256, 256, 3)) * tf.constant([0.5, 0.2, 0.2]) 39 | speckles = tf.random.uniform((3, 256, 256, 3)) 40 | self.assertAllGreater( 41 | loss(pure_bright, speckles), loss(pure_bright, pure_dark)) 42 | 43 | 44 | class CompositeLossTest(tf.test.TestCase): 45 | 46 | def test_l1_with_weight(self): 47 | composite = losses.CompositeLoss() 48 | composite.add_loss('l1', 2.0) 49 | 50 | y_true = tf.constant(0.3, shape=(2, 192, 256, 3), dtype=tf.float32) 51 | y_pred = tf.constant(0.5, shape=(2, 192, 256, 3), dtype=tf.float32) 52 | self.assertAllClose( 53 | composite(y_true, y_pred), 54 | tf.reduce_mean(tf.abs(y_true - y_pred)) * 2.0) 55 | 56 | def test_l1_l2_different_weights(self): 57 | composite = losses.CompositeLoss() 58 | composite.add_loss('L1', 1.0) 59 | composite.add_loss('L2', 0.5) 60 | 61 | y_true = tf.constant(127, shape=(2, 192, 256, 3), dtype=tf.int32) 62 | y_pred = tf.constant(215, shape=(2, 192, 256, 3), dtype=tf.int32) 63 | l1 = tf.cast(tf.reduce_mean(tf.abs(y_true - y_pred)), tf.float32) 64 | l2 = tf.cast(tf.reduce_mean(tf.square(y_true - y_pred)), tf.float32) 65 | self.assertAllClose(composite(y_true, y_pred), l1 * 1.0 + l2 * 0.5) 66 | 67 | def test_composite_loss_equals_sum_of_components(self): 68 | composite = losses.CompositeLoss() 69 | mae = tf.keras.losses.MAE 70 | vgg = losses.PerceptualLoss() 71 | composite.add_loss(mae, 1.0) 72 | composite.add_loss(vgg, 2.0) 73 | 74 | y_true = tf.random.uniform((1, 192, 256, 3)) 75 | y_pred = tf.random.uniform((1, 192, 256, 3)) 76 | loss_value = composite(y_true, y_pred) 77 | mae_loss_value = tf.math.reduce_mean(mae(y_true, y_pred)) 78 | vgg_loss_value = vgg(y_true, y_pred) 79 | self.assertAllClose(loss_value, mae_loss_value * 1.0 + vgg_loss_value * 2.0) 80 | 81 | def test_duplicate_component_raises_error(self): 82 | composite = losses.CompositeLoss() 83 | composite.add_loss('l1', 1.0) 84 | with self.assertRaisesRegex(ValueError, 'exist'): 85 | composite.add_loss('l1', 2.0) 86 | 87 | def test_call_before_adding_component_raises_error(self): 88 | composite = losses.CompositeLoss() 89 | y_true = tf.random.uniform((1, 192, 256, 3)) 90 | y_pred = tf.random.uniform((1, 192, 256, 3)) 91 | with self.assertRaises(AssertionError): 92 | composite(y_true, y_pred) 93 | 94 | def test_invalid_weight(self): 95 | composite = losses.CompositeLoss() 96 | with self.assertRaisesRegex(ValueError, r'-1\.0'): 97 | composite.add_loss('l2', -1.0) 98 | 99 | 100 | if __name__ == '__main__': 101 | tf.test.main() 102 | -------------------------------------------------------------------------------- /flare_removal/python/u_net.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implements a custom U-Net. 17 | 18 | Reference: 19 | Ronneberger O., Fischer P., Brox T. (2015) U-Net: Convolutional Networks for 20 | Biomedical Image Segmentation, MICCAI 2015. 21 | https://doi.org/10.1007/978-3-319-24574-4_28 22 | """ 23 | from typing import Sequence, Tuple 24 | 25 | import tensorflow as tf 26 | 27 | 28 | def _down_block(x, 29 | depth, 30 | name_prefix = 'down'): 31 | """Applies a U-Net downscaling block to the previous stage's output. 32 | 33 | Args: 34 | x: Output from the previous stage, with shape [B, H, W, C]. 35 | depth: Number of channels in the output tensor. 36 | name_prefix: Prefix to each layer's name. Each block's prefix must be unique 37 | in the same model. 38 | 39 | Returns: 40 | Two tensors: 41 | - Output of the Conv2D layer used for the skip connection. Has shape [B, H, 42 | W, `depth`]. 43 | - Output of the MaxPool2D layer used as the input to the next block. Has 44 | shape [B, H/2, W/2, `depth`]. 45 | """ 46 | conv = tf.keras.layers.Conv2D( 47 | filters=depth, 48 | kernel_size=3, 49 | padding='same', 50 | activation='relu', 51 | name=f'{name_prefix}_conv1')( 52 | x) 53 | skip = tf.keras.layers.Conv2D( 54 | filters=depth, 55 | kernel_size=3, 56 | padding='same', 57 | activation='relu', 58 | name=f'{name_prefix}_conv2')( 59 | conv) 60 | down_2x = tf.keras.layers.MaxPool2D( 61 | pool_size=(2, 2), name=f'{name_prefix}_pool')( 62 | skip) 63 | return skip, down_2x 64 | 65 | 66 | def _up_block(x, 67 | skip, 68 | depth, 69 | interpolation = 'bilinear', 70 | name_prefix = 'up'): 71 | """Applies a U-Net upscaling block to the previous stage's output. 72 | 73 | Args: 74 | x: Output from the previous stage, with shape [B, H, W, C]. 75 | skip: Output from the corresponding downscaling block, with shape [B, 2H, 76 | 2W, C']. Normally C' = C / 2. 77 | depth: Number of channels in the output tensor. 78 | interpolation: Interpolation method. Must be "neareat" or "bilinear". 79 | name_prefix: Prefix to each layer's name. Each block's prefix must be unique 80 | in the same model. 81 | 82 | Returns: 83 | Output of the upscaling block. Has shape [B, 2H, 2W, `depth`]. 
84 | """ 85 | up_2x = tf.keras.layers.UpSampling2D( 86 | size=(2, 2), interpolation=interpolation, name=f'{name_prefix}_2x')( 87 | x) 88 | up_2x = tf.keras.layers.Conv2D( 89 | filters=depth, 90 | kernel_size=2, 91 | padding='same', 92 | activation='relu', 93 | name=f'{name_prefix}_2xconv')( 94 | up_2x) 95 | concat = tf.keras.layers.concatenate([up_2x, skip], 96 | name=f'{name_prefix}_concat') 97 | conv = tf.keras.layers.Conv2D( 98 | filters=depth, 99 | kernel_size=3, 100 | padding='same', 101 | activation='relu', 102 | name=f'{name_prefix}_conv1')( 103 | concat) 104 | conv = tf.keras.layers.Conv2D( 105 | filters=depth, 106 | kernel_size=3, 107 | padding='same', 108 | activation='relu', 109 | name=f'{name_prefix}_conv2')( 110 | conv) 111 | return conv 112 | 113 | 114 | def get_model(input_shape = (512, 512, 3), 115 | scales = 4, 116 | bottleneck_depth = 1024, 117 | bottleneck_layers = 2): 118 | """Builds a U-Net with given parameters. 119 | 120 | The output of this model has the same shape as the input tensor. 121 | 122 | Args: 123 | input_shape: Shape of the input tensor, without the batch dimension. For a 124 | typical RGB image, this should be [height, width, 3]. 125 | scales: Number of downscaling/upscaling blocks in the network. The width and 126 | height of the input tensor are 2**`scales` times those of the bottleneck. 127 | 0 means no rescaling is applied and a simple feed-forward network is 128 | returned. 129 | bottleneck_depth: Number of channels in the bottleneck tensors. 130 | bottleneck_layers: Number of Conv2D layers in the bottleneck. 131 | 132 | Returns: 133 | A Keras model instance representing a U-Net. 134 | """ 135 | input_layer = tf.keras.Input(shape=input_shape, name='input') 136 | previous_output = input_layer 137 | 138 | # Downscaling arm. Produces skip connections. 139 | skips = [] 140 | depths = [bottleneck_depth // 2**i for i in range(scales, 0, -1)] 141 | for depth in depths: 142 | skip, previous_output = _down_block( 143 | previous_output, depth, name_prefix=f'down{depth}') 144 | skips.append(skip) 145 | 146 | # Bottleneck. 147 | for i in range(bottleneck_layers): 148 | previous_output = tf.keras.layers.Conv2D( 149 | filters=bottleneck_depth, 150 | kernel_size=3, 151 | padding='same', 152 | activation='relu', 153 | name=f'bottleneck_conv{i + 1}')( 154 | previous_output) 155 | 156 | # Upscaling arm. Consumes skip connections. 157 | for depth, skip in zip(reversed(depths), reversed(skips)): 158 | previous_output = _up_block( 159 | previous_output, skip, depth, name_prefix=f'up{depth}') 160 | 161 | # Squash output to (0, 1). 
162 | output_layer = tf.keras.layers.Conv2D( 163 | filters=input_shape[-1], 164 | kernel_size=1, 165 | activation='sigmoid', 166 | name='output')( 167 | previous_output) 168 | 169 | return tf.keras.Model(input_layer, output_layer, name='unet') 170 | -------------------------------------------------------------------------------- /flare_removal/python/pretrain_ViT.py: -------------------------------------------------------------------------------- 1 | from transformers import ViTFeatureExtractor, TFAutoModel, ViTConfig 2 | import tensorflow as tf 3 | from einops.layers.tensorflow import Rearrange 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from tensorflow.keras.layers import Layer 7 | from einops import rearrange, repeat 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | class PreNorm(Layer): 13 | def __init__(self, fn): 14 | super(PreNorm, self).__init__() 15 | self.norm = nn.LayerNormalization() 16 | self.fn = fn 17 | 18 | def call(self, x, training=True): 19 | return self.fn(self.norm(x), training=training) 20 | 21 | class MLP(Layer): 22 | def __init__(self, dim, hidden_dim, dropout=0.0): 23 | super(MLP, self).__init__() 24 | self.net = Sequential([ 25 | nn.Dense(units=hidden_dim, activation='gelu'), 26 | nn.Dropout(rate=dropout), 27 | nn.Dense(units=dim), 28 | nn.Dropout(rate=dropout) 29 | ]) 30 | 31 | def call(self, x, training=True): 32 | return self.net(x, training=training) 33 | 34 | class Attention(Layer): 35 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 36 | super(Attention, self).__init__() 37 | inner_dim = dim_head * heads 38 | self.heads = heads 39 | self.scale = dim_head ** -0.5 40 | self.attend = nn.Softmax() 41 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 42 | self.to_out = Sequential([ 43 | nn.Dense(units=dim), 44 | nn.Dropout(rate=dropout) 45 | ]) 46 | 47 | def call(self, x, training=True): 48 | qkv = tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1) 49 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 50 | dots = tf.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 51 | attn = self.attend(dots) 52 | x = tf.einsum('b h i j, b h j d -> b h i d', attn, v) 53 | x = rearrange(x, 'b h n d -> b n (h d)') 54 | return self.to_out(x, training=training) 55 | 56 | class Transformer(Layer): 57 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 58 | super(Transformer, self).__init__() 59 | self.layers = [PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)) for _ in range(depth)] 60 | self.mlp_layers = [PreNorm(MLP(dim, mlp_dim, dropout=dropout)) for _ in range(depth)] 61 | 62 | def call(self, x, training=True): 63 | for attn, mlp in zip(self.layers, self.mlp_layers): 64 | x = attn(x, training=training) + x 65 | x = mlp(x, training=training) + x 66 | return x 67 | 68 | class Pretrain_ViT(tf.keras.Model): 69 | def __init__(self, image_size, patch_size, encoder_dim, decoder_dim, depth, heads, mlp_dim, dim_head=64, 70 | dropout=0.0, decoder_depth=4): 71 | super(Pretrain_ViT, self).__init__() 72 | 73 | # Image dimensions 74 | image_height, image_width = image_size, image_size 75 | patch_height, patch_width = patch_size, patch_size 76 | num_patches = (image_height // patch_height) * (image_width // patch_width) 77 | patch_dim = patch_height * patch_width * 3 # RGB channels 78 | 79 | # Load pretrained encoder (from Hugging Face) 80 | self.encoder = 
TFAutoModel.from_pretrained("google/vit-base-patch16-224-in21k", from_pt=True) 81 | 82 | self.encoder_projection = tf.keras.layers.Dense(units=decoder_dim, name="encoder_projection") 83 | 84 | # Decoder positional embedding 85 | self.decoder_pos_embedding = tf.Variable(tf.random.normal([1, num_patches, decoder_dim])) 86 | 87 | # Decoder Transformer 88 | self.decoder = Transformer(dim=decoder_dim, depth=decoder_depth, heads=heads, dim_head=dim_head, 89 | mlp_dim=decoder_dim * 4, dropout=dropout) 90 | 91 | # Reconstruction head for output 92 | self.reconstruction_head = Sequential([ 93 | tf.keras.layers.Dense(units=patch_dim), 94 | Rearrange('b (h w) (p1 p2 c) -> b (h p1) (w p2) c', 95 | h=image_height // patch_height, p1=patch_height, p2=patch_width) 96 | ]) 97 | 98 | def call(self, img, training=True): 99 | # Preprocess input image for the pretrained encoder 100 | # shape = tf.shape(img) 101 | # if shape[1] == 224 and shape[2] == 224 and shape[3] == 3: 102 | img = tf.transpose(img, perm=[0, 3, 1, 2]) 103 | encoder_features = self.encoder(img, training=training)["last_hidden_state"] 104 | # Remove [CLS] token 105 | encoder_features = encoder_features[:, 1:, :] # Shape: [1, 196, 64] 106 | 107 | encoder_features = self.encoder_projection(encoder_features) 108 | 109 | # Add positional embeddings (match number of patches) 110 | x = encoder_features + self.decoder_pos_embedding # Shape: [1, 196, 128] 111 | 112 | # Pass through decoder 113 | x = self.decoder(x, training=training) 114 | 115 | # Reconstruct patches and return 116 | reconstructed = self.reconstruction_head(x) 117 | return reconstructed 118 | 119 | 120 | # #Instantiate the model 121 | # model =Pretrain_ViT( 122 | # image_size=224, 123 | # patch_size=16, 124 | # encoder_dim=64, 125 | # decoder_dim=128, 126 | # depth=6, 127 | # heads=4, 128 | # mlp_dim=128, 129 | # dropout=0.1, 130 | # decoder_depth=4 131 | # ) 132 | 133 | # # Use ViTFeatureExtractor to preprocess the input 134 | # feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") 135 | 136 | # # Fix input image range 137 | # raw_img = tf.random.uniform([1, 224, 224, 3], minval=0, maxval=255, dtype=tf.float32) 138 | # raw_img_uint8 = tf.cast(raw_img, tf.uint8) # Cast to uint8 for compatibility with feature extractor 139 | 140 | # # Preprocess the image 141 | # processed_img = feature_extractor(images=raw_img_uint8.numpy(), return_tensors="tf")["pixel_values"] 142 | 143 | # # Forward pass 144 | # output = model(processed_img) 145 | # print("Input shape:", processed_img.shape) 146 | # print("Output shape:", output.shape) 147 | -------------------------------------------------------------------------------- /flare_removal/python/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
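# Usage sketch (illustrative, not part of the original file): the factory
# defined below is typically driven by the --loss flag, e.g.
#
#   loss_fn = get_loss('percep')     # VGG-19 perceptual term + L1 term
#   value = loss_fn(y_true, y_pred)  # y_true/y_pred are [B, H, W, C] in [0, 1]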
15 | 16 | """Loss functions for training a lens flare reduction model.""" 17 | from typing import Callable, Dict, Mapping, Optional, Union 18 | 19 | import tensorflow as tf 20 | 21 | from flare_removal.python import vgg 22 | 23 | 24 | def get_loss(name): 25 | """Returns the loss function object for the given name. 26 | 27 | Supported configs: 28 | - "l1": 29 | Pixel-wise MAE. 30 | - "l2": 31 | Pixel-wise MSE. 32 | - "perceptual" (or "percep"): 33 | Perceptual loss implemented using a pre-trained VGG19 network, plus L1 loss. 34 | The two losses have equal weights. 35 | 36 | Args: 37 | name: One of the three configs above. Not case-sensitive. 38 | 39 | Returns: 40 | A Keras `Loss` object. 41 | """ 42 | name = name.lower() 43 | if name == 'l2': 44 | return tf.keras.losses.MeanSquaredError() 45 | elif name == 'l1': 46 | return tf.keras.losses.MeanAbsoluteError() 47 | elif name in ['percep', 'perceptual']: 48 | loss_fn = CompositeLoss() 49 | loss_fn.add_loss(PerceptualLoss(), weight=1.0) 50 | # Note that `PerceptualLoss` uses [0, 255] range internally. Since our API 51 | # assumes [0, 1] range for input images, we actually need to scale the L1 52 | # loss by 255 to achieve a true 1:1 weighting. 53 | loss_fn.add_loss('L1', weight=255.0) 54 | return loss_fn 55 | else: 56 | raise ValueError(f'Unrecognized loss function name: {name}') 57 | 58 | 59 | class PerceptualLoss(tf.keras.losses.Loss): 60 | """A perceptual loss function based on the VGG-19 model. 61 | 62 | The loss function is defined as a weighted sum of the L1 loss at various 63 | tap-out layers of the network. 64 | """ 65 | DEFAULT_COEFFS = { 66 | 'block1_conv2': 1 / 2.6, 67 | 'block2_conv2': 1 / 4.8, 68 | 'block3_conv2': 1 / 3.7, 69 | 'block4_conv2': 1 / 5.6, 70 | 'block5_conv2': 10 / 1.5, 71 | } 72 | 73 | def __init__(self, 74 | coeffs = None, 75 | name = 'perceptual'): 76 | """Initializes a perceptual loss instance. 77 | 78 | Args: 79 | coeffs: Key-value pairs where the keys are the tap-out layer names, and 80 | the values are their coefficients in the weighted sum. Defaults to the 81 | `self.DEFAULT_COEFFS`. 82 | name: Name of this Tensorflow object. 83 | """ 84 | super(PerceptualLoss, self).__init__(name=name) 85 | coeffs = coeffs or self.DEFAULT_COEFFS 86 | layers, self._coeffs = zip(*coeffs.items()) 87 | self._model = vgg.Vgg19(tap_out_layers=layers) 88 | 89 | def call(self, y_true, y_pred): 90 | """Invokes the loss function. 91 | 92 | See base class for details. 93 | 94 | Do not call this method directly. Use the __call__() method instead. 95 | 96 | Args: 97 | y_true: ground-truth image batch, with shape [B, H, W, C]. 98 | y_pred: predicted image batch, with the same shape. 99 | 100 | Returns: 101 | A [B, 1, 1] tensor containing the perceptual loss values. Note that 102 | according to the base class's specs, if the inputs have D dimensions, the 103 | output must have D-1 dimensions. Hence the [B, 1, 1] shape. 104 | """ 105 | true_features = self._model(y_true) 106 | pred_features = self._model(y_pred) 107 | total_loss = tf.constant(0.0) 108 | for ft, fp, coeff in zip(true_features, pred_features, self._coeffs): 109 | # MAE only reduces the last dimension, leading to a [B, H, W]-tensor. 110 | loss = tf.keras.losses.MAE(ft, fp) 111 | # Further reduce on the H and W dimensions. 112 | loss = tf.reduce_mean(loss, axis=[1, 2], keepdims=True) 113 | total_loss += loss * coeff 114 | return total_loss 115 | 116 | 117 | class CompositeLoss(tf.keras.losses.Loss): 118 | """A weighted sum of individual loss functions for images. 
119 | 120 | Attributes: 121 | losses: Mapping from Keras loss objects to weights. 122 | """ 123 | 124 | def __init__(self, name = 'composite'): 125 | """Initializes an instance with given weights. 126 | 127 | Args: 128 | name: Optional name for this Tensorflow object. 129 | """ 130 | super(CompositeLoss, self).__init__(name=name) 131 | self.losses: Dict[tf.keras.losses.Loss, float] = {} 132 | 133 | def add_loss(self, loss, weight): 134 | """Adds a component loss to the composite with specific weight. 135 | 136 | Args: 137 | loss: A Keras loss object or identifier. All standard Keras loss 138 | identifiers are supported (e.g., string like "mse", loss functions, and 139 | `tf.keras.losses.Loss` objects). In addition, strings "l1" and "l2" are 140 | also supported. Cannot be a loss that is already added to this 141 | `CompositeLoss`. 142 | weight: Weight associated with this loss. Must be > 0. 143 | 144 | Raises: 145 | ValueError: If the given `loss` already exists, or if `weight` is empty or 146 | <= 0. 147 | """ 148 | if weight <= 0.0: 149 | raise ValueError(f'Weight must be > 0, but is {weight}.') 150 | if isinstance(loss, str): 151 | loss = loss.lower() 152 | loss = {'l1': 'mae', 'l2': 'mse'}.get(loss, loss) 153 | loss_fn = tf.keras.losses.get(loss) 154 | else: 155 | loss_fn = loss 156 | if loss_fn in self.losses: 157 | raise ValueError('The same loss already exists.') 158 | self.losses[loss_fn] = weight # pytype: disable=container-type-mismatch # typed-keras 159 | 160 | def call(self, y_true, y_pred): 161 | """See base class.""" 162 | assert self.losses, 'At least one component loss must be added.' 163 | loss_sum = tf.constant(0.0) 164 | y_true = tf.cast(y_true, tf.float32) 165 | y_pred = tf.cast(y_pred, tf.float32) 166 | for loss, weight in self.losses.items(): 167 | loss_sum = loss(y_true, y_pred) * weight + loss_sum 168 | return loss_sum 169 | -------------------------------------------------------------------------------- /flare_removal/python/data_provider.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Loads data from the dataset.""" 17 | import os.path 18 | from typing import Iterable, Tuple, Union 19 | 20 | import tensorflow as tf 21 | 22 | _SHUFFLE_BUFFER_SIZE = 10_000 23 | 24 | 25 | def image_dataset_from_files(data_dir, 26 | image_shape, 27 | batch_size = 0, 28 | shuffle = True, 29 | repeat = -1): 30 | """Loads images from individual JPG or PNG files. 31 | 32 | Args: 33 | data_dir: Parent directory where input images are located. All JPEG and PNG 34 | files under this directory (either directly or indirectly) will be 35 | included. 36 | image_shape: Shape of the images in (H, W, C) format. 37 | batch_size: 0 means images are not batched. Positive values define the batch 38 | size. The batched images have shape (B, H, W, C). 
39 | shuffle: Whether to randomize the order of the images. 40 | repeat: 0 means the dataset is not repeated. -1 means it's repeated 41 | indefinitely. A positive value means it's repeated for the specified 42 | number of times (epochs). 43 | 44 | Returns: 45 | A Dataset object containing (H, W, C) or (B, H, W, C) image tensors. 46 | """ 47 | extensions = ['jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG'] 48 | # Images directly under the given directory. 49 | globs = [os.path.join(data_dir, f'*.{e}') for e in extensions] 50 | # Images under subdirectories. 51 | globs += [os.path.join(data_dir, '**', f'*.{e}') for e in extensions] 52 | files = tf.data.Dataset.list_files(globs, shuffle, seed=0) 53 | 54 | def _parser(file_name): 55 | blob = tf.io.read_file(file_name) 56 | image = tf.io.decode_image(blob, dtype=tf.float32) 57 | image.set_shape(image_shape) 58 | return image 59 | 60 | images = files.map( 61 | _parser, num_parallel_calls=tf.data.AUTOTUNE, deterministic=not shuffle) 62 | 63 | if repeat < 0: 64 | images = images.repeat() 65 | elif repeat > 0: 66 | images = images.repeat(repeat) 67 | 68 | if batch_size > 0: 69 | images = images.batch(batch_size, drop_remainder=True) 70 | 71 | images = images.prefetch(tf.data.AUTOTUNE) 72 | 73 | return images 74 | 75 | 76 | def image_dataset_from_tfrecords(globs, 77 | tag, 78 | image_shape, 79 | batch_size = 0, 80 | shuffle = True, 81 | repeat = -1): 82 | """Loads images from sharded TFRecord files. 83 | 84 | Args: 85 | globs: One or more glob pattern matching the TFRecord files. 86 | tag: Name of the TFExample "feature" to decode. 87 | image_shape: Shape of the images in (H, W, C) format. 88 | batch_size: 0 means images are not batched. Positive values define the batch 89 | size. The batched images have shape (B, H, W, C). 90 | shuffle: Whether to randomize the order of the images. 91 | repeat: 0 means the dataset is not repeated. -1 means it's repeated 92 | indefinitely. A positive value means it's repeated for the specified 93 | number of times (epochs). 94 | 95 | Returns: 96 | A Dataset object containing (H, W, C) or (B, H, W, C) image tensors. 
97 | """ 98 | files = tf.data.Dataset.list_files(globs, shuffle, seed=0) 99 | examples = files.interleave( 100 | tf.data.TFRecordDataset, 101 | num_parallel_calls=tf.data.AUTOTUNE, 102 | deterministic=not shuffle) 103 | 104 | if shuffle: 105 | examples = examples.shuffle( 106 | buffer_size=_SHUFFLE_BUFFER_SIZE, seed=0, reshuffle_each_iteration=True) 107 | 108 | def _parser(example): 109 | features = tf.io.parse_single_example( 110 | example, features={tag: tf.io.FixedLenFeature([], tf.string)}) 111 | image_u8 = tf.reshape( 112 | tf.io.decode_raw(features[tag], tf.uint8), image_shape) 113 | image_f32 = tf.image.convert_image_dtype(image_u8, tf.float32) 114 | return image_f32 115 | 116 | images = examples.map( 117 | _parser, num_parallel_calls=tf.data.AUTOTUNE, deterministic=not shuffle) 118 | 119 | if repeat < 0: 120 | images = images.repeat() 121 | elif repeat > 0: 122 | images = images.repeat(repeat) 123 | 124 | if batch_size > 0: 125 | images = images.batch(batch_size, drop_remainder=True) 126 | 127 | images = images.prefetch(tf.data.AUTOTUNE) 128 | 129 | return images 130 | 131 | 132 | def get_scene_dataset(path, 133 | source, 134 | batch_size, 135 | input_shape = (640, 640, 3), 136 | repeat = 0): 137 | """Returns scene images according to configuration.""" 138 | if source == 'tfrecord': 139 | return image_dataset_from_tfrecords( 140 | globs=os.path.join(path, '*.tfrecord'), 141 | tag='image', 142 | image_shape=input_shape, 143 | batch_size=batch_size, 144 | repeat=repeat) 145 | 146 | elif source == 'jpg': 147 | return image_dataset_from_files( 148 | data_dir=path, 149 | image_shape=input_shape, 150 | batch_size=batch_size, 151 | repeat=repeat) 152 | 153 | else: 154 | raise ValueError('Unrecognized data source', source) 155 | 156 | 157 | def get_flare_dataset(path, 158 | source, 159 | batch_size, 160 | input_shape = (752, 1008, 3), 161 | repeat = -1): 162 | """Returns flare images according to configuration.""" 163 | if source == 'tfrecord': 164 | return image_dataset_from_tfrecords( 165 | globs=path, 166 | tag='flare', 167 | image_shape=input_shape, 168 | batch_size=batch_size, 169 | repeat=repeat) 170 | 171 | elif source == 'jpg': 172 | return image_dataset_from_files( 173 | data_dir=path, 174 | image_shape=input_shape, 175 | batch_size=batch_size, 176 | repeat=repeat) 177 | 178 | else: 179 | raise ValueError('Unrecognized data source', source) 180 | -------------------------------------------------------------------------------- /flare_removal/python/evaluate.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Evaluation script for flare removal.""" 17 | 18 | import os.path 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | import tensorflow as tf 24 | 25 | from flare_removal.python import data_provider 26 | from flare_removal.python import losses 27 | from flare_removal.python import models 28 | from flare_removal.python import synthesis 29 | 30 | flags.DEFINE_string( 31 | 'eval_dir', '/tmp/eval', 32 | 'Directory where evaluation summaries and outputs are written.') 33 | flags.DEFINE_string( 34 | 'train_dir', '/tmp/train', 35 | 'Directory where training checkpoints are written. This script will ' 36 | 'repeatedly poll and evaluate the latest checkpoint.') 37 | flags.DEFINE_string('scene_dir', None, 38 | 'Full path to the directory containing scene images.') 39 | flags.DEFINE_string('flare_dir', None, 40 | 'Full path to the directory containing flare images.') 41 | flags.DEFINE_enum( 42 | 'data_source', 'jpg', ['tfrecord', 'jpg'], 43 | 'Source of training data. Use "jpg" for individual image files, such as ' 44 | 'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.') 45 | flags.DEFINE_string('model', 'unet', 'the name of the training model') 46 | flags.DEFINE_string('loss', 'percep', 'the name of the loss for training') 47 | flags.DEFINE_integer('batch_size', 2, 'Evaluation batch size.') 48 | flags.DEFINE_float( 49 | 'learning_rate', 1e-4, 50 | 'Unused placeholder. The flag has to be defined to satisfy parameter sweep ' 51 | 'requirements.') 52 | flags.DEFINE_float( 53 | 'scene_noise', 0.01, 54 | 'Gaussian noise sigma added in the scene in synthetic data. The actual ' 55 | 'Gaussian variance for each image will be drawn from a Chi-squared ' 56 | 'distribution with a scale of scene_noise.') 57 | flags.DEFINE_float( 58 | 'flare_max_gain', 10.0, 59 | 'Max digital gain applied to the flare patterns during synthesis.') 60 | flags.DEFINE_float('flare_loss_weight', 1.0, 61 | 'Weight added on the flare loss (scene loss is 1).') 62 | flags.DEFINE_integer('training_res', 512, 63 | 'Image resolution at which the network is trained.') 64 | flags.DEFINE_integer('flare_res_h', 1008, 'Height of flare image') 65 | flags.DEFINE_integer('flare_res_w', 752, 'Width of flare image') 66 | FLAGS = flags.FLAGS 67 | 68 | 69 | def main(_): 70 | eval_dir = FLAGS.eval_dir 71 | assert eval_dir, 'Flag --eval_dir must not be empty.' 72 | train_dir = FLAGS.train_dir 73 | assert train_dir, 'Flag --train_dir must not be empty.' 74 | summary_dir = os.path.join(eval_dir, 'summary') 75 | 76 | # Load data. 77 | scenes = data_provider.get_scene_dataset( 78 | FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=0, 79 | input_shape=(FLAGS.training_res, FLAGS.training_res, 3)) 80 | flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source, 81 | FLAGS.batch_size, input_shape=(FLAGS.flare_res_w, FLAGS.flare_res_h, 3)) 82 | 83 | # Make a model. 
84 | model = models.build_model(FLAGS.model, FLAGS.batch_size, FLAGS.training_res) 85 | loss_fn = losses.get_loss(FLAGS.loss) 86 | 87 | ckpt = tf.train.Checkpoint( 88 | step=tf.Variable(0, dtype=tf.int64), 89 | training_finished=tf.Variable(False, dtype=tf.bool), 90 | model=model) 91 | 92 | summary_writer = tf.summary.create_file_writer(summary_dir) 93 | 94 | # The checkpoints_iterator keeps polling the latest training checkpoints, 95 | # until: 96 | # 1) `timeout` seconds have passed waiting for a new checkpoint; and 97 | # 2) `timeout_fn` (in this case, the flag indicating the last training 98 | # checkpoint) evaluates to true. 99 | for ckpt_path in tf.train.checkpoints_iterator( 100 | train_dir, timeout=30, timeout_fn=lambda: ckpt.training_finished): 101 | try: 102 | status = ckpt.restore(ckpt_path) 103 | # Assert that all model variables are restored, but allow extra unmatched 104 | # variables in the checkpoint. (For example, optimizer states are not 105 | # needed for evaluation.) 106 | status.assert_existing_objects_matched() 107 | # Suppress warnings about unmatched variables. 108 | status.expect_partial() 109 | logging.info('Restored checkpoint %s @ step %d.', ckpt_path, ckpt.step) 110 | except (tf.errors.NotFoundError, AssertionError): 111 | logging.exception('Failed to restore checkpoint from %s.', ckpt_path) 112 | continue 113 | 114 | total_psnr = 0.0 115 | total_ssim = 0.0 116 | batch_counter = 0 117 | for scene, flare in tf.data.Dataset.zip((scenes, flares)): 118 | loss_value, summary = synthesis.run_step( 119 | scene, 120 | flare, 121 | model, 122 | loss_fn, 123 | noise=FLAGS.scene_noise, 124 | flare_max_gain=FLAGS.flare_max_gain, 125 | flare_loss_weight=FLAGS.flare_loss_weight, 126 | training_res=FLAGS.training_res) 127 | 128 | scene_batch = summary[:, :, 512: 768, :] 129 | predict_batch = summary[:, :, 256: 512, :] 130 | 131 | batch_psnr = tf.image.psnr(scene_batch, predict_batch, max_val=1.0) 132 | total_psnr += tf.reduce_sum(batch_psnr).numpy() 133 | 134 | batch_ssim = tf.image.ssim(scene_batch, predict_batch, max_val=1.0) 135 | total_ssim += tf.reduce_sum(batch_ssim).numpy() 136 | 137 | batch_counter += batch_psnr.shape[0] 138 | 139 | avg_psnr = total_psnr / batch_counter if batch_counter > 0 else 0.0 140 | average_ssim = total_ssim / batch_counter if batch_counter > 0 else 0.0 141 | 142 | with summary_writer.as_default(): 143 | tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step) 144 | tf.summary.scalar('loss', loss_value, step=ckpt.step) 145 | tf.summary.scalar("average_psnr", avg_psnr, step=batch_counter) 146 | tf.summary.scalar("average_ssim", average_ssim, step=batch_counter) 147 | 148 | logging.info('Done!') 149 | 150 | 151 | if __name__ == '__main__': 152 | app.run(main) 153 | -------------------------------------------------------------------------------- /flare_removal/python/test_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Evaluation script for flare removal.""" 17 | 18 | import os.path 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | import tensorflow as tf 24 | 25 | from flare_removal.python import data_provider 26 | from flare_removal.python import losses 27 | from flare_removal.python import models 28 | from flare_removal.python import synthesis 29 | 30 | flags.DEFINE_string( 31 | 'eval_dir', '/tmp/eval', 32 | 'Directory where evaluation summaries and outputs are written.') 33 | flags.DEFINE_string( 34 | 'train_dir', '/tmp/train', 35 | 'Directory where training checkpoints are written. This script will ' 36 | 'repeatedly poll and evaluate the latest checkpoint.') 37 | flags.DEFINE_string('scene_dir', None, 38 | 'Full path to the directory containing scene images.') 39 | flags.DEFINE_string('flare_dir', None, 40 | 'Full path to the directory containing flare images.') 41 | flags.DEFINE_enum( 42 | 'data_source', 'jpg', ['tfrecord', 'jpg'], 43 | 'Source of training data. Use "jpg" for individual image files, such as ' 44 | 'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.') 45 | flags.DEFINE_string('model', 'unet', 'the name of the training model') 46 | flags.DEFINE_string('loss', 'percep', 'the name of the loss for training') 47 | flags.DEFINE_integer('batch_size', 2, 'Evaluation batch size.') 48 | flags.DEFINE_float( 49 | 'learning_rate', 1e-4, 50 | 'Unused placeholder. The flag has to be defined to satisfy parameter sweep ' 51 | 'requirements.') 52 | flags.DEFINE_float( 53 | 'scene_noise', 0.01, 54 | 'Gaussian noise sigma added in the scene in synthetic data. The actual ' 55 | 'Gaussian variance for each image will be drawn from a Chi-squared ' 56 | 'distribution with a scale of scene_noise.') 57 | flags.DEFINE_float( 58 | 'flare_max_gain', 10.0, 59 | 'Max digital gain applied to the flare patterns during synthesis.') 60 | flags.DEFINE_float('flare_loss_weight', 1.0, 61 | 'Weight added on the flare loss (scene loss is 1).') 62 | flags.DEFINE_integer('training_res', 512, 63 | 'Image resolution at which the network is trained.') 64 | flags.DEFINE_integer('flare_res_h', 1008, 'Height of flare image') 65 | flags.DEFINE_integer('flare_res_w', 752, 'Width of flare image') 66 | FLAGS = flags.FLAGS 67 | 68 | 69 | def main(_): 70 | eval_dir = FLAGS.eval_dir 71 | assert eval_dir, 'Flag --eval_dir must not be empty.' 72 | train_dir = FLAGS.train_dir 73 | assert train_dir, 'Flag --train_dir must not be empty.' 74 | summary_dir = os.path.join(eval_dir, 'summary') 75 | 76 | # Load data. 77 | scenes = data_provider.get_scene_dataset( 78 | FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=0, 79 | input_shape=(FLAGS.training_res, FLAGS.training_res, 3)) 80 | flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source, 81 | FLAGS.batch_size, input_shape=(FLAGS.flare_res_w, FLAGS.flare_res_h, 3)) 82 | 83 | # Make a model. 
84 | model = models.build_model(FLAGS.model, FLAGS.batch_size, FLAGS.training_res) 85 | loss_fn = losses.get_loss(FLAGS.loss) 86 | 87 | ckpt = tf.train.Checkpoint( 88 | step=tf.Variable(0, dtype=tf.int64), 89 | training_finished=tf.Variable(False, dtype=tf.bool), 90 | model=model) 91 | 92 | summary_writer = tf.summary.create_file_writer(summary_dir) 93 | 94 | # The checkpoints_iterator keeps polling the latest training checkpoints, 95 | # until: 96 | # 1) `timeout` seconds have passed waiting for a new checkpoint; and 97 | # 2) `timeout_fn` (in this case, the flag indicating the last training 98 | # checkpoint) evaluates to true. 99 | for ckpt_path in tf.train.checkpoints_iterator( 100 | train_dir, timeout=30, timeout_fn=lambda: ckpt.training_finished): 101 | try: 102 | status = ckpt.restore(ckpt_path) 103 | # Assert that all model variables are restored, but allow extra unmatched 104 | # variables in the checkpoint. (For example, optimizer states are not 105 | # needed for evaluation.) 106 | status.assert_existing_objects_matched() 107 | # Suppress warnings about unmatched variables. 108 | status.expect_partial() 109 | logging.info('Restored checkpoint %s @ step %d.', ckpt_path, ckpt.step) 110 | except (tf.errors.NotFoundError, AssertionError): 111 | logging.exception('Failed to restore checkpoint from %s.', ckpt_path) 112 | continue 113 | 114 | total_psnr = 0.0 115 | total_ssim = 0.0 116 | batch_counter = 0 117 | for scene, flare in tf.data.Dataset.zip((scenes, flares)): 118 | loss_value, summary = synthesis.run_test( 119 | scene, 120 | flare, 121 | model, 122 | loss_fn, 123 | noise=FLAGS.scene_noise, 124 | flare_max_gain=FLAGS.flare_max_gain, 125 | flare_loss_weight=FLAGS.flare_loss_weight, 126 | training_res=FLAGS.training_res) 127 | 128 | scene_batch = summary[:, :, 512: 768, :] 129 | predict_batch = summary[:, :, 256: 512, :] 130 | 131 | batch_psnr = tf.image.psnr(scene_batch, predict_batch, max_val=1.0) 132 | total_psnr += tf.reduce_sum(batch_psnr).numpy() 133 | 134 | batch_ssim = tf.image.ssim(scene_batch, predict_batch, max_val=1.0) 135 | total_ssim += tf.reduce_sum(batch_ssim).numpy() 136 | 137 | batch_counter += batch_psnr.shape[0] 138 | 139 | avg_psnr = total_psnr / batch_counter if batch_counter > 0 else 0.0 140 | average_ssim = total_ssim / batch_counter if batch_counter > 0 else 0.0 141 | 142 | with summary_writer.as_default(): 143 | tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step) 144 | tf.summary.scalar('loss', loss_value, step=ckpt.step) 145 | tf.summary.scalar("average_psnr", avg_psnr, step=batch_counter) 146 | tf.summary.scalar("average_ssim", average_ssim, step=batch_counter) 147 | 148 | logging.info('Done!') 149 | 150 | 151 | if __name__ == '__main__': 152 | app.run(main) 153 | -------------------------------------------------------------------------------- /flare_removal/python/synthesis.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Generates synthetic scenes containing lens flare.""" 17 | import math 18 | 19 | import tensorflow as tf 20 | 21 | from flare_removal.python import utils 22 | 23 | 24 | def add_flare(scene, 25 | flare, 26 | noise, 27 | flare_max_gain = 10.0, 28 | apply_affine = True, 29 | training_res = 512): 30 | """Adds flare to natural images. 31 | 32 | Here the natural images are in sRGB. They are first linearized before flare 33 | patterns are added. The result is then converted back to sRGB. 34 | 35 | Args: 36 | scene: Natural image batch in sRGB. 37 | flare: Lens flare image batch in sRGB. 38 | noise: Strength of the additive Gaussian noise. For each image, the Gaussian 39 | variance is drawn from a scaled Chi-squared distribution, where the scale 40 | is defined by `noise`. 41 | flare_max_gain: Maximum gain applied to the flare images in the linear 42 | domain. RGB gains are applied randomly and independently, not exceeding 43 | this maximum. 44 | apply_affine: Whether to apply affine transformation. 45 | training_res: Resolution of training images. Images must be square, and this 46 | value specifies the side length. 47 | 48 | Returns: 49 | - Flare-free scene in sRGB. 50 | - Flare-only image in sRGB. 51 | - Scene with flare in sRGB. 52 | - Gamma value used during synthesis. 53 | """ 54 | batch_size, flare_input_height, flare_input_width, _ = flare.shape 55 | 56 | # Since the gamma encoding is unknown, we use a random value so that the model 57 | # will hopefully generalize to a reasonable range of gammas. 58 | gamma = tf.random.uniform([], 1.8, 2.2) 59 | flare_linear = tf.image.adjust_gamma(flare, gamma) 60 | 61 | # Remove DC background in flare. 62 | flare_linear = utils.remove_background(flare_linear) 63 | 64 | if apply_affine: 65 | rotation = tf.random.uniform([batch_size], minval=-math.pi, maxval=math.pi) 66 | shift = tf.random.normal([batch_size, 2], mean=0.0, stddev=10.0) 67 | shear = tf.random.uniform([batch_size, 2], 68 | minval=-math.pi / 9, 69 | maxval=math.pi / 9) 70 | scale = tf.random.uniform([batch_size, 2], minval=0.9, maxval=1.2) 71 | 72 | flare_linear = utils.apply_affine_transform( 73 | flare_linear, 74 | rotation=rotation, 75 | shift_x=shift[:, 0], 76 | shift_y=shift[:, 1], 77 | shear_x=shear[:, 0], 78 | shear_y=shear[:, 1], 79 | scale_x=scale[:, 0], 80 | scale_y=scale[:, 1]) 81 | 82 | flare_linear = tf.clip_by_value(flare_linear, 0.0, 1.0) 83 | flare_linear = tf.image.crop_to_bounding_box( 84 | flare_linear, 85 | offset_height=(flare_input_height - training_res) // 2, 86 | offset_width=(flare_input_width - training_res) // 2, 87 | target_height=training_res, 88 | target_width=training_res) 89 | flare_linear = tf.image.random_flip_left_right( 90 | tf.image.random_flip_up_down(flare_linear)) 91 | 92 | # First normalize the white balance. Then apply random white balance. 93 | flare_linear = utils.normalize_white_balance(flare_linear) 94 | rgb_gains = tf.random.uniform([3], 0, flare_max_gain, dtype=tf.float32) 95 | flare_linear *= rgb_gains 96 | 97 | # Further augmentation on flare patterns: random blur and DC offset. 
98 | blur_size = tf.random.uniform([], 0.1, 3) 99 | flare_linear = utils.apply_blur(flare_linear, blur_size) 100 | offset = tf.random.uniform([], -0.02, 0.02) 101 | flare_linear = tf.clip_by_value(flare_linear + offset, 0.0, 1.0) 102 | 103 | flare_srgb = tf.image.adjust_gamma(flare_linear, 1.0 / gamma) 104 | 105 | # Scene augmentation: random crop and flips. 106 | scene_linear = tf.image.adjust_gamma(scene, gamma) 107 | scene_linear = tf.image.random_crop(scene_linear, flare_linear.shape) 108 | scene_linear = tf.image.random_flip_left_right( 109 | tf.image.random_flip_up_down(scene_linear)) 110 | 111 | # Additive Gaussian noise. The Gaussian's variance is drawn from a Chi-squared 112 | # distribution. This is equivalent to drawing the Gaussian's standard 113 | # deviation from a truncated normal distribution, as shown below. 114 | sigma = tf.abs(tf.random.normal([], 0, noise)) 115 | noise = tf.random.normal(scene_linear.shape, 0, sigma) 116 | scene_linear += noise 117 | 118 | # Random digital gain. 119 | gain = tf.random.uniform([], 0, 1.2) # varying the intensity scale 120 | scene_linear = tf.clip_by_value(gain * scene_linear, 0.0, 1.0) 121 | 122 | scene_srgb = tf.image.adjust_gamma(scene_linear, 1.0 / gamma) 123 | 124 | # Combine the flare-free scene with a flare pattern to produce a synthetic 125 | # training example. 126 | combined_linear = scene_linear + flare_linear 127 | combined_srgb = tf.image.adjust_gamma(combined_linear, 1.0 / gamma) 128 | combined_srgb = tf.clip_by_value(combined_srgb, 0.0, 1.0) 129 | 130 | return (utils.quantize_8(scene_srgb), utils.quantize_8(flare_srgb), 131 | utils.quantize_8(combined_srgb), gamma) 132 | 133 | 134 | def run_step(scene, 135 | flare, 136 | model, 137 | loss_fn, 138 | noise = 0.0, 139 | flare_max_gain = 10.0, 140 | flare_loss_weight = 0.0, 141 | training_res = 512): 142 | """Executes a forward step.""" 143 | scene, flare, combined, gamma = add_flare( 144 | scene, 145 | flare, 146 | flare_max_gain=flare_max_gain, 147 | noise=noise, 148 | training_res=training_res) 149 | 150 | pred_scene = model(combined) 151 | pred_flare = utils.remove_flare(combined, pred_scene, gamma) 152 | 153 | flare_mask = utils.get_highlight_mask(flare) 154 | # Fill the saturation region with the ground truth, so that no L1/L2 loss 155 | # and better for perceptual loss since it matches the surrounding scenes. 156 | masked_scene = pred_scene * (1 - flare_mask) + scene * flare_mask 157 | loss_value = loss_fn(scene, masked_scene) 158 | if flare_loss_weight > 0: 159 | masked_flare = pred_flare * (1 - flare_mask) + flare * flare_mask 160 | loss_value += flare_loss_weight * loss_fn(flare, masked_flare) 161 | 162 | image_summary = tf.concat([combined, pred_scene, scene, pred_flare, flare], 163 | axis=2) 164 | 165 | return loss_value, image_summary 166 | 167 | def run_test(scene, 168 | flare, 169 | model, 170 | loss_fn, 171 | noise = 0.0, 172 | flare_max_gain = 10.0, 173 | flare_loss_weight = 0.0, 174 | training_res = 512): 175 | """Executes a forward step.""" 176 | # scene, flare, combined, gamma = add_flare( 177 | # scene, 178 | # flare, 179 | # flare_max_gain=flare_max_gain, 180 | # noise=noise, 181 | # training_res=training_res) 182 | 183 | combined = flare 184 | pred_scene = model(combined) 185 | pred_flare = utils.remove_flare(combined, pred_scene, gamma) 186 | 187 | flare_mask = utils.get_highlight_mask(flare) 188 | # Fill the saturation region with the ground truth, so that no L1/L2 loss 189 | # and better for perceptual loss since it matches the surrounding scenes. 
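  # (With this masking, saturated pixels contribute zero L1/L2 error, and the
  # perceptual loss sees ground-truth content there rather than arbitrary
  # predictions.)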
190 | masked_scene = pred_scene * (1 - flare_mask) + scene * flare_mask 191 | loss_value = loss_fn(scene, masked_scene) 192 | if flare_loss_weight > 0: 193 | masked_flare = pred_flare * (1 - flare_mask) + flare * flare_mask 194 | loss_value += flare_loss_weight * loss_fn(flare, masked_flare) 195 | 196 | image_summary = tf.concat([combined, pred_scene, scene, pred_flare, flare], 197 | axis=2) 198 | 199 | return loss_value, image_summary 200 | -------------------------------------------------------------------------------- /flare_removal/README.md: -------------------------------------------------------------------------------- 1 | # How to Train Neural Networks for Flare Removal 2 | 3 | This repository contains code that accompanies the following paper: 4 | 5 | > Yicheng Wu, Qiurui He, Tianfan Xue, Rahul Garg, Jiawen Chen, Ashok 6 | > Veeraraghavan, and Jonathan T. Barron. **How to train neural networks for 7 | > flare removal**. *Proceedings of the IEEE/CVF International Conference on 8 | > Computer Vision (ICCV)*, 2021. 9 | 10 | - The paper (including the supplemental materials) is available on 11 | [arXiv](https://arxiv.org/abs/2011.12485) as well as 12 | [CVF Open Access](https://openaccess.thecvf.com/content/ICCV2021/html/Wu_How_To_Train_Neural_Networks_for_Flare_Removal_ICCV_2021_paper.html). 13 | 14 | - The [main project page](https://yichengwu.github.io/flare-removal/) contains 15 | more information, including a recorded presentation. 16 | 17 | ## Announcements 18 | 19 | - We have made a small fix to the VGG loss in an attempt to fix the issue 20 | below. We thank the reader for reporting this issue. 21 | 22 | - ~~**1/30/2022:** It has been brought to our attention that there might be an 23 | issue with the training code that causes the trained model to perform worse 24 | than what we show on the 25 | [test images](https://drive.google.com/corp/drive/folders/1_gi3W8fOEusfmglJdiKCUwk3IA7B2jfQ). 26 | This issue was likely introduced when we cleaned up the repository prior to 27 | open-sourcing. We are actively investigating this issue, and will submit a 28 | patch to this repository as soon as possible. The issue does not affect the 29 | testing script (`remove_flare.py`). We can also confirm that our published 30 | results (both quantitative and qualitative) are accurate and reproducible 31 | using an older (internal) version of the code.~~ 32 | 33 | ## Dataset 34 | 35 | ### Flare-only images 36 | 37 | A total of 5,001 RGB flare images are 38 | [released](https://research.google/tools/datasets/lens-flare/) via Google 39 | Research's public dataset repository under the 40 | [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license. Among them: 41 | 42 | - 2,001 are from lab captures (1,001 captures + interpolation between frames). 43 | These images are placed under the `captured` subdirectory. 44 | 45 | - 3,000 are simulated computationally. These images are placed under the 46 | `simulated` subdirectory. 47 | 48 | To obtain this data: 49 | 50 | 1. Install [Google Cloud SDK](https://cloud.google.com/sdk/docs/quickstart). 51 | This should automatically install the `gsutil` tool which is required to 52 | access the Google Cloud Storage bucket. 53 | 54 | 2. 
Run the following command: 55 | 56 | ```shell 57 | $ gsutil cp -r gs://gresearch/lens-flare /your/local/path 58 | ``` 59 | 60 | ### Flare-free (scene) images 61 | 62 | We use the same image dataset as 63 | [*Single Image Reflection Removal with Perceptual Losses*](https://people.eecs.berkeley.edu/~cecilia77/project-pages/reflection.html) 64 | (Zhang et al., CVPR 2018). Please follow 65 | [their instructions](https://github.com/ceciliavision/perceptual-reflection-removal#dataset) 66 | to access this data. Note that we do *not* make the distinction between the 67 | *reflection layer* and the *transmission layer* - we shuffle the entire dataset 68 | and treat it as a unified set of natural images. You may want to make an 69 | appropriate train-test split before using this dataset. 70 | 71 | ## Code 72 | 73 | ### Synthesizing scattering flare 74 | 75 | The code for synthesizing random scattering flare ("streaks") is written in 76 | Matlab and located under the `matlab` directory. Simply execute the `main.m` 77 | script to reproduce our results. 78 | 79 | By default, it writes to the following directories: 80 | 81 | - **`matlab/apertures`**: Simulated defective apertures with dots (resembling 82 | dust) and polylines (resembling scratches). 83 | 84 | - **`matlab/streaks`**: Flare patterns resulting from the simulated defective 85 | apertures above. Multiple flare patterns are generated for each aperture, 86 | accounting for varying light source locations, defocus, and distortion. 87 | These images are used to further synthesize flare-corrupted photographs. 88 | 89 | ### Training a flare removal model 90 | 91 | **WARNING:** Commands below are executed from the repository root 92 | `google_research/`. Otherwise, Python may not be able to resolve the module 93 | paths correctly. 94 | 95 | The training and testing programs require certain dependencies (see 96 | `requirements.txt`). You may create a 97 | [virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) 98 | and install these dependencies using `pip`, as demonstrated in `run.sh`. Note 99 | that running `flare_removal/run.sh` directly will fail due to missing arguments 100 | (see [below](#testing-the-model-on-images) for details), but will at least 101 | install the correct dependencies. 102 | 103 | The training script is `python/train.py`. A separate evaluation script 104 | `python/evaluate.py` is also available, so an additional job can be started to 105 | monitor the training progress in parallel (optional). 106 | 107 | ```shell 108 | $ python3 -m flare_removal.python.train \ 109 | --train_dir=/path/to/training/logs/dir \ 110 | --scene_dir=/path/to/flare-free/training/image/dir \ 111 | --flare_dir=/path/to/flare-only/image/dir 112 | 113 | # Optional. 114 | $ python3 -m flare_removal.python.evaluate \ 115 | --eval_dir=/path/to/evaluation/logs/dir \ 116 | --train_dir=/path/to/training/logs/dir \ 117 | --scene_dir=/path/to/flare-free/evaluation/image/dir \ 118 | --flare_dir=/path/to/flare-only/image/dir 119 | ``` 120 | 121 | A few notes on the arguments: 122 | 123 | - **`--train_dir`/`--eval_dir`**: This is where all training/evaluation states 124 | are preserved, including metrics, summary images, and model weight 125 | checkpoints. When a training job restarts, it will also try to pick up from 126 | its last state from this directory. Hence, you should use a new (empty) 127 | directory for each new experiment. 
128 | 129 | - **`--scene_dir`**: Parent directory of all flare-free images. You may use 130 | any natural image dataset, as long as all images are RGB and have the same 131 | size. See [above](#flare-free-scene-images) for an example. Note that the 132 | scene images used for `evaluate.py` should be different from those for 133 | `train.py`. 134 | 135 | - **`--flare_dir`**: Parent directory of all flare-only images. If you 136 | downloaded our dataset using the instructions [above](#flare-only-images), 137 | the argument should be `--flare_dir=/your/local/path/lens-flare`. 138 | 139 | - **Other arguments** are provided to further customize the training 140 | configuration, e.g., alternative forms of input data, hyperparameters, etc. 141 | Please refer to the source code for additional documentation. 142 | 143 | The training job will write the following contents to disk, under 144 | `path/to/training/logs/dir`: 145 | 146 | - **`model/`**: latest model files. 147 | 148 | - **`summary/`**: training metrics and summary images, to be visualized using 149 | [TensorBoard](https://www.tensorflow.org/tensorboard). 150 | 151 | - **`ckpt-*`**: model checkpoints, for restoration of previous model weights 152 | 153 | ### Testing the model on images 154 | 155 | We also provide a Python script to test a trained model on images in the wild. 156 | Suppose you have followed the steps above to train a flare removal model, you 157 | could invoke the testing script as follows: 158 | 159 | ```shell 160 | $ python3 -m flare_removal.python.remove_flare \ 161 | --ckpt=/path/to/training/logs/dir/model \ 162 | --input_dir=/path/to/test/image/dir \ 163 | --out_dir=/path/to/output/dir 164 | ``` 165 | 166 | The `--ckpt` argument locates the model directory saved by the training script. 167 | The other arguments are self-explanatory. For more details, including additional 168 | arguments, please refer to the source file. 169 | 170 | ## Pre-trained model 171 | 172 | Unfortunately, due to licensing constraints, we cannot release the pre-trained 173 | model. However, you should be able to reproduce our results using the code and 174 | datasets described above. 175 | 176 | ## Citation 177 | 178 | If you find this work useful, please cite: 179 | 180 | ``` 181 | @InProceedings{flareremvoal2021, 182 | author = {Wu, Yicheng and He, Qiurui and Xue, Tianfan and Garg, Rahul and 183 | Chen, Jiawen and Veeraraghavan, Ashok and Barron, Jonathan T.}, 184 | title = {How To Train Neural Networks for Flare Removal}, 185 | booktitle = {Proceedings of the IEEE/CVF International Conference on 186 | Computer Vision (ICCV)}, 187 | month = {October}, 188 | year = {2021}, 189 | pages = {2239-2247} 190 | } 191 | ``` 192 | -------------------------------------------------------------------------------- /flare_removal/python/remove_flare.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Remove flares from RGB images. 17 | 18 | This script exercises the following paper on RGB images: 19 | Yicheng Wu, Qiurui He, Tianfan Xue, Rahul Garg, Jiawen Chen, Ashok 20 | Veeraraghavan, and Jonathan T. Barron. How to train neural networks for flare 21 | removal. ICCV, 2021. 22 | 23 | Input images: 24 | 25 | - Images larger than 512 x 512 will be center-cropped to 512 x 512 before being 26 | passed to the model. 27 | 28 | - Images larger than 2048 x 2048 will be center-cropped to 2048 x 2048 first. 29 | Next, they will be downsampled to 512 x 512 and passed into the model. The 30 | inferred flare-free images will be upsampled back to 2048 x 2048. (Section 6.4 31 | of the paper.) 32 | 33 | - Images smaller than 512 x 512 are not supported. 34 | 35 | Output images: 36 | 37 | - By default, output images will be written to separate directories: 38 | - Preprocessed input 39 | - Inferred scene 40 | - Inferred flare 41 | - Inferred scene with blended light source (Section 5.2 of the paper) 42 | 43 | - Alternatively, use `--separate_out_dirs=0` to write output images to the same 44 | directory as the input. The output images will have different suffixes. 45 | """ 46 | 47 | import os.path 48 | from typing import Optional 49 | 50 | from absl import app 51 | from absl import flags 52 | import tensorflow as tf 53 | import tqdm 54 | 55 | from flare_removal.python import models 56 | from flare_removal.python import utils 57 | 58 | FLAGS = flags.FLAGS 59 | 60 | _DEFAULT_CKPT = None 61 | flags.DEFINE_string( 62 | 'ckpt', _DEFAULT_CKPT, 63 | 'Location of the model checkpoint. May be a SavedModel dir, in which case ' 64 | 'the model architecture & weights are both loaded, and "--model" is ' 65 | 'ignored. May also be a TF checkpoint path, in which case only the latest ' 66 | 'model weights are loaded (this is much faster), and "--model" is ' 67 | 'required. To load a specific checkpoint, use the checkpoint prefix ' 68 | 'instead of the checkpoint directory for this argument.') 69 | flags.DEFINE_string( 70 | 'model', None, 71 | 'Only required when "--ckpt" points to a TF checkpoint or checkpoint dir. ' 72 | 'Must be one of "unet" or "can".') 73 | flags.DEFINE_integer( 74 | 'batch_size', 1, 75 | 'Number of images in each batch. 
Some networks (e.g., the rain removal ' 76 | 'network) can only accept predefined batch sizes.') 77 | flags.DEFINE_string('input_dir', None, 78 | 'The directory contains all input images.') 79 | flags.DEFINE_string('out_dir', None, 'Output directory.') 80 | flags.DEFINE_boolean( 81 | 'separate_out_dirs', True, 82 | 'Whether the results are saved in separate folders under different names ' 83 | '(True), or the same folder under different names (False).') 84 | 85 | 86 | def center_crop(image, width, height): 87 | """Returns the center crop of a given image.""" 88 | old_height, old_width, _ = image.shape 89 | x_offset = (old_width - width) // 2 90 | y_offset = (old_height - height) // 2 91 | if x_offset < 0 or y_offset < 0: 92 | raise ValueError('The specified output size is bigger than the image size.') 93 | return image[y_offset:(y_offset + height), x_offset:(x_offset + width), :] 94 | 95 | 96 | def write_outputs_same_dir(out_dir, 97 | name_prefix, 98 | input_image = None, 99 | pred_scene = None, 100 | pred_flare = None, 101 | pred_blend = None): 102 | """Writes various outputs to the same directory on disk.""" 103 | if not tf.io.gfile.isdir(out_dir): 104 | raise ValueError(f'{out_dir} is not a directory.') 105 | path_prefix = os.path.join(out_dir, name_prefix) 106 | if input_image is not None: 107 | utils.write_image(input_image, path_prefix + '_input.png') 108 | if pred_scene is not None: 109 | utils.write_image(pred_scene, path_prefix + '_output.png') 110 | if pred_flare is not None: 111 | utils.write_image(pred_flare, path_prefix + '_output_flare.png') 112 | if pred_blend is not None: 113 | utils.write_image(pred_blend, path_prefix + '_output_blend.png') 114 | 115 | 116 | def write_outputs_separate_dir(out_dir, 117 | file_name, 118 | input_image = None, 119 | pred_scene = None, 120 | pred_flare = None, 121 | pred_blend = None): 122 | """Writes various outputs to separate subdirectories on disk.""" 123 | if not tf.io.gfile.isdir(out_dir): 124 | raise ValueError(f'{out_dir} is not a directory.') 125 | if input_image is not None: 126 | utils.write_image(input_image, os.path.join(out_dir, 'input', file_name)) 127 | if pred_scene is not None: 128 | utils.write_image(pred_scene, os.path.join(out_dir, 'output', file_name)) 129 | if pred_flare is not None: 130 | utils.write_image(pred_flare, 131 | os.path.join(out_dir, 'output_flare', file_name)) 132 | if pred_blend is not None: 133 | utils.write_image(pred_blend, 134 | os.path.join(out_dir, 'output_blend', file_name)) 135 | 136 | 137 | def process_one_image(model, image_path, out_dir, separate_out_dirs): 138 | """Reads one image and writes inference results to disk.""" 139 | with tf.io.gfile.GFile(image_path, 'rb') as f: 140 | blob = f.read() 141 | input_u8 = tf.image.decode_image(blob)[Ellipsis, :3] 142 | input_f32 = tf.image.convert_image_dtype(input_u8, tf.float32, saturate=True) 143 | h, w, _ = input_f32.shape 144 | 145 | if min(h, w) >= 2048: 146 | input_image = center_crop(input_f32, 2048, 2048)[None, Ellipsis] 147 | input_low = tf.image.resize( 148 | input_image, [224, 224], method=tf.image.ResizeMethod.AREA) 149 | pred_scene_low = tf.clip_by_value(model(input_low), 0.0, 1.0) 150 | pred_flare_low = utils.remove_flare(input_low, pred_scene_low) 151 | pred_flare = tf.image.resize(pred_flare_low, [2048, 2048], antialias=True) 152 | pred_scene = utils.remove_flare(input_image, pred_flare) 153 | else: 154 | input_image = center_crop(input_f32, 224, 224)[None, Ellipsis] 155 | input_image = tf.concat([input_image] * FLAGS.batch_size, axis=0) 
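    # The single crop is tiled along the batch dimension because some model
    # architectures only accept the predefined --batch_size (see the flag's
    # help string above); only element 0 of the batch is written to disk below.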
156 | pred_scene = tf.clip_by_value(model(input_image), 0.0, 1.0) 157 | pred_flare = utils.remove_flare(input_image, pred_scene) 158 | pred_blend = utils.blend_light_source(input_image[0, Ellipsis], pred_scene[0, Ellipsis]) 159 | 160 | out_filename_stem = os.path.splitext(os.path.basename(image_path))[0] 161 | if separate_out_dirs: 162 | write_outputs_separate_dir( 163 | out_dir, 164 | out_filename_stem + '.png', 165 | input_image=input_image[0, Ellipsis], 166 | pred_scene=pred_scene[0, Ellipsis], 167 | pred_flare=pred_flare[0, Ellipsis], 168 | pred_blend=pred_blend) 169 | else: 170 | write_outputs_same_dir( 171 | out_dir, 172 | out_filename_stem, 173 | input_image=input_image[0, Ellipsis], 174 | pred_scene=pred_scene[0, Ellipsis], 175 | pred_flare=pred_flare[0, Ellipsis], 176 | pred_blend=pred_blend) 177 | 178 | 179 | def load_model(path, 180 | model_type = None, 181 | batch_size = None): 182 | """Loads a model from SavedModel or standard TF checkpoint.""" 183 | try: 184 | return tf.keras.models.load_model(path) 185 | except (ImportError, IOError): 186 | print(f'Didn\'t find SavedModel at "{path}". ' 187 | 'Trying latest checkpoint next.') 188 | model = models.build_model(model_type, batch_size, res=224) 189 | ckpt = tf.train.Checkpoint(model=model) 190 | ckpt_path = tf.train.latest_checkpoint(path) or path 191 | ckpt.restore(ckpt_path).assert_existing_objects_matched() 192 | return model 193 | 194 | 195 | def main(_): 196 | out_dir = FLAGS.out_dir or os.path.join(FLAGS.input_dir, 'model_output') 197 | tf.io.gfile.makedirs(out_dir) 198 | 199 | model = load_model(FLAGS.ckpt, FLAGS.model, FLAGS.batch_size) 200 | 201 | # The following grep works for both png and jpg. 202 | input_files = sorted(tf.io.gfile.glob(os.path.join(FLAGS.input_dir, '*.*g'))) 203 | for input_file in tqdm.tqdm(input_files): 204 | process_one_image(model, input_file, out_dir, FLAGS.separate_out_dirs) 205 | 206 | print('done') 207 | 208 | 209 | if __name__ == '__main__': 210 | app.run(main) 211 | -------------------------------------------------------------------------------- /flare_removal/python/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Training script for flare removal. 17 | 18 | This script trains a model that outputs a flare-free image from a flare-polluted 19 | image. 
20 | """ 21 | import os.path 22 | import time 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | import tensorflow as tf 28 | 29 | from flare_removal.python import data_provider 30 | from flare_removal.python import losses 31 | from flare_removal.python import models 32 | from flare_removal.python import synthesis 33 | 34 | import wandb 35 | 36 | flags.DEFINE_string( 37 | 'train_dir', '/tmp/train', 38 | 'Directory where training checkpoints and summaries are written.') 39 | flags.DEFINE_string('scene_dir', None, 40 | 'Full path to the directory containing scene images.') 41 | flags.DEFINE_string('flare_dir', None, 42 | 'Full path to the directory containing flare images.') 43 | flags.DEFINE_enum( 44 | 'data_source', 'jpg', ['tfrecord', 'jpg'], 45 | 'Source of training data. Use "jpg" for individual image files, such as ' 46 | 'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.') 47 | # flags.DEFINE_string('model', 'unet', 'the name of the training model') 48 | flags.DEFINE_string('model', 'swin_unet_2d', 'the name of the training model') 49 | flags.DEFINE_string('loss', 'percep', 'the name of the loss for training') 50 | flags.DEFINE_integer('batch_size', 4, 'Training batch size.') 51 | flags.DEFINE_integer('epochs', 100, 'Training config: epochs.') 52 | flags.DEFINE_integer( 53 | 'ckpt_period', 10980, 54 | 'Write model checkpoint and summary to disk every ckpt_period steps.') 55 | flags.DEFINE_float('learning_rate', 1e-4, 'Initial learning rate.') 56 | flags.DEFINE_float( 57 | 'scene_noise', 0.01, 58 | 'Gaussian noise sigma added in the scene in synthetic data. The actual ' 59 | 'Gaussian variance for each image will be drawn from a Chi-squared ' 60 | 'distribution with a scale of scene_noise.') 61 | flags.DEFINE_float( 62 | 'flare_max_gain', 10.0, 63 | 'Max digital gain applied to the flare patterns during synthesis.') 64 | flags.DEFINE_float('flare_loss_weight', 1.0, 65 | 'Weight added on the flare loss (scene loss is 1).') 66 | flags.DEFINE_integer('training_res', 512, 'Training resolution.') 67 | flags.DEFINE_integer('flare_res_h', 1008, 'Height of flare image') 68 | flags.DEFINE_integer('flare_res_w', 752, 'Width of flare image') 69 | flags.DEFINE_string('exp_name', '', 'Experiment name') 70 | FLAGS = flags.FLAGS 71 | 72 | 73 | @tf.function 74 | def train_step(model, scene, flare, loss_fn, optimizer): 75 | """Executes one step of gradient descent.""" 76 | with tf.GradientTape() as tape: 77 | loss_value, summary = synthesis.run_step( 78 | scene, 79 | flare, 80 | model, 81 | loss_fn, 82 | noise=FLAGS.scene_noise, 83 | flare_max_gain=FLAGS.flare_max_gain, 84 | flare_loss_weight=FLAGS.flare_loss_weight, 85 | training_res=FLAGS.training_res) 86 | grads = tape.gradient(loss_value, model.trainable_weights) 87 | grads, _ = tf.clip_by_global_norm(grads, 5.0) 88 | optimizer.apply_gradients(zip(grads, model.trainable_weights)) 89 | 90 | return loss_value, summary 91 | 92 | 93 | def main(_): 94 | wandb.login() 95 | wandb.init( 96 | project="computational_imaging_project", 97 | name="David" + FLAGS.exp_name, 98 | config={ 99 | "batch_size": FLAGS.batch_size, 100 | "epochs": FLAGS.epochs, 101 | "model": FLAGS.model, 102 | "loss": FLAGS.loss, 103 | "learning_rate": FLAGS.learning_rate 104 | }) 105 | 106 | train_dir = FLAGS.train_dir 107 | assert train_dir, 'Flag --train_dir must not be empty.' 108 | summary_dir = os.path.join(train_dir, 'summary') 109 | model_dir = os.path.join(train_dir, 'model') 110 | 111 | # Load data. 
112 | scenes = data_provider.get_scene_dataset( 113 | FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, 114 | repeat=FLAGS.epochs, input_shape=(FLAGS.training_res, FLAGS.training_res, 3)) 115 | flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source, 116 | FLAGS.batch_size, input_shape=(FLAGS.flare_res_w, FLAGS.flare_res_h, 3)) 117 | # Make a model. 118 | model = models.build_model(FLAGS.model, FLAGS.batch_size, FLAGS.training_res) 119 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 120 | loss_fn = losses.get_loss(FLAGS.loss) 121 | 122 | # Model checkpoints. Checkpoints don't contain model architecture, but 123 | # weights only. We use checkpoints to keep track of the training progress. 124 | ckpt = tf.train.Checkpoint( 125 | step=tf.Variable(0, dtype=tf.int64), 126 | training_finished=tf.Variable(False, dtype=tf.bool), 127 | optimizer=optimizer, 128 | model=model) 129 | ckpt_mgr = tf.train.CheckpointManager( 130 | ckpt, train_dir, max_to_keep=3, keep_checkpoint_every_n_hours=3) 131 | 132 | # Restore the latest checkpoint (model weights), if any. This is helpful if 133 | # the training job gets restarted from an unexpected termination. 134 | latest_ckpt = ckpt_mgr.latest_checkpoint 135 | restore_status = None 136 | if latest_ckpt is not None: 137 | # Note that due to lazy initialization, not all checkpointed variables can 138 | # be restored at this point. Hence 'expect_partial()'. Full restoration is 139 | # checked in the first training step below. 140 | restore_status = ckpt.restore(latest_ckpt).expect_partial() 141 | logging.info('Restoring latest checkpoint @ step %d from: %s', ckpt.step, 142 | latest_ckpt) 143 | else: 144 | logging.info('Previous checkpoints not found. Starting afresh.') 145 | 146 | summary_writer = tf.summary.create_file_writer(summary_dir) 147 | 148 | step_time_metric = tf.keras.metrics.Mean('step_time') 149 | step_start_time = time.time() 150 | i = 1 151 | e = 0 152 | wandb.log({"epoch": e}) 153 | epoch_len = len([f for f in os.listdir(FLAGS.scene_dir)]) // FLAGS.batch_size 154 | 155 | for scene, flare in tf.data.Dataset.zip((scenes, flares)): 156 | if i % epoch_len == 0: 157 | e += 1 158 | wandb.log({"epoch": e}) 159 | 160 | # Perform one training step. 161 | loss_value, summary = train_step(model, scene, flare, loss_fn, optimizer) 162 | 163 | # By this point, all lazily initialized variables should have been 164 | # restored by the checkpoint if one was available. 165 | if restore_status is not None: 166 | restore_status.assert_consumed() 167 | restore_status = None 168 | 169 | # Write training summaries and checkpoints to disk. 170 | ckpt.step.assign_add(1) 171 | if ckpt.step % FLAGS.ckpt_period == 0: 172 | # Write model checkpoint to disk. 173 | ckpt_mgr.save() 174 | 175 | # Also save the full model using the latest weights. To restore previous 176 | # weights, you'd have to load the model and restore a previously saved 177 | # checkpoint. 178 | tf.keras.models.save_model(model, model_dir, save_format='tf') 179 | 180 | # Write summaries to disk, which can be visualized with TensorBoard. 181 | with summary_writer.as_default(): 182 | tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step) 183 | tf.summary.scalar('loss', loss_value, step=ckpt.step) 184 | tf.summary.scalar( 185 | 'step_time', step_time_metric.result(), step=ckpt.step) 186 | step_time_metric.reset_state() 187 | 188 | # Record elapsed time in this training step. 
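    # The interval measured here spans the entire loop iteration, so it also
    # includes data loading and any checkpoint/summary writing performed this
    # step, not just the forward/backward pass.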
189 | step_end_time = time.time() 190 | step_time_metric.update_state(step_end_time - step_start_time) 191 | step_start_time = step_end_time 192 | 193 | wandb.log({"loss": loss_value.numpy(), 194 | "step": int(ckpt.step.numpy()), 195 | "step_time": step_time_metric.result()}) 196 | i += 1 197 | ckpt.training_finished.assign(True) 198 | ckpt_mgr.save() 199 | logging.info('Done!') 200 | 201 | 202 | if __name__ == '__main__': 203 | app.run(main) 204 | -------------------------------------------------------------------------------- /flare_removal/python/vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Wrappers and extensions for a pre-trained VGG-19 network.""" 17 | 18 | from typing import List, Optional, Sequence, Tuple 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | 24 | class Vgg19(tf.keras.Model): 25 | """A modified VGG-19 network with configurable tap-outs. 26 | 27 | The network is modified such that all max pooling are replaced by average 28 | pooling. 29 | 30 | Supported layers and their output shapes are: 31 | - block1_conv1 .. 2: [B, H, W, 64] 32 | - block1_pool: [B, H/2, W/2, 64] 33 | - block2_conv1 .. 2: [B, H/2, W/2, 128] 34 | - block2_pool: [B, H/4, W/4, 128] 35 | - block3_conv1 .. 4: [B, H/4, W/4, 256] 36 | - block3_pool: [B, H/8, W/8, 256] 37 | - block4_conv1 .. 4: [B, H/8, W/8, 512] 38 | - block4_pool: [B, H/16, W/16, 512] 39 | - block5_conv1 .. 4: [B, H/16, W/16, 512] 40 | - block5_pool: [B, H/32, W/32, 512] 41 | where [B, H, W, 3] is the batched input image tensor. 42 | """ 43 | 44 | def __init__(self, 45 | tap_out_layers, 46 | trainable = False, 47 | weights = 'imagenet'): 48 | """Initializes a modified VGG-19 network, with optional pre-trained weights. 49 | 50 | Args: 51 | tap_out_layers: Names of the layers used as tap-out points. The output 52 | tensors of these layers will be returned when model is called. Must be a 53 | subset of the supported layers listed above. 54 | trainable: Whether the network's weights are frozen. 55 | weights: Source of the pre-trained weights. Use None if the network is to 56 | be initialized randomly. See `tf.keras.applications.VGG19` for details. 57 | 58 | Raises: 59 | ValueError: If `tap_out_layers` has duplicate or invalid entries. 60 | """ 61 | super(Vgg19, self).__init__(name='vgg19') 62 | if len(set(tap_out_layers)) != len(tap_out_layers): 63 | raise ValueError(f'There are duplicates in the provided layers: ' 64 | f'{tap_out_layers}') 65 | 66 | # Load pre-trained weights. 67 | model = tf.keras.applications.VGG19(include_top=False, weights=weights) 68 | 69 | # Replace max pooling by average pooling according to the following paper: 70 | # Zhang et al., Single Image Reflection Removal with Perceptual Losses, 71 | # CVPR 2018. 
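    # Rebuilding the graph with average pooling leaves the pre-trained
    # convolution weights untouched; only the downsampling between blocks
    # changes.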
72 | model = self._replace_max_by_average_pool(model) 73 | model.trainable = trainable 74 | 75 | # Configure tap-out layers. Note that we need `layer.get_output_at(1)` below 76 | # to use the modified graph (at node 1) with average pooling. The 77 | # `layer.output` attribute will default to node 0, which is the unmodified 78 | # model. 79 | invalid_layers = set(tap_out_layers) - set(l.name for l in model.layers) 80 | if invalid_layers: 81 | raise ValueError(f'Unrecognized layers: {invalid_layers}') 82 | tap_outs = [model.get_layer(l).get_output_at(1) for l in tap_out_layers] 83 | self._model = tf.keras.Model(inputs=model.inputs, outputs=tap_outs) 84 | 85 | def call(self, images, **kwargs): 86 | """Invokes the model on batched images. 87 | 88 | Args: 89 | images: A [B, H, W, C]-tensor of type float32, in range [0, 1]. 90 | **kwargs: Other arguments in the base class are ignored. 91 | 92 | Returns: 93 | Output tensors of the tap-out layers, in the same order as 94 | `self.tap_out_layers`. 95 | """ 96 | # Scale from [0, 1] to [0, 255], convert to BGR channel order, and subtract 97 | # channel means. 98 | x = tf.keras.applications.vgg19.preprocess_input(images * 255.0) 99 | return self._model(x) 100 | 101 | @staticmethod 102 | def _replace_max_by_average_pool(model): 103 | """Replaces MaxPooling2D layers in a model with AveragePooling2D.""" 104 | input_layer, *other_layers = model.layers 105 | if not isinstance(input_layer, tf.keras.layers.InputLayer): 106 | raise ValueError('The first layer should be InputLayer, but is:', 107 | input_layer) 108 | x = input_layer.output 109 | for layer in other_layers: 110 | if isinstance(layer, tf.keras.layers.MaxPooling2D): 111 | layer = tf.keras.layers.AveragePooling2D( 112 | pool_size=layer.pool_size, 113 | strides=layer.strides, 114 | padding=layer.padding, 115 | data_format=layer.data_format, 116 | name=layer.name, 117 | ) 118 | x = layer(x) 119 | return tf.keras.models.Model(inputs=input_layer.input, outputs=x) 120 | 121 | 122 | class IdentityInitializer(tf.keras.initializers.Initializer): 123 | """Initializes a Conv2D kernel as an identity transform. 124 | 125 | Specifically, the identity kernel does the following (assuming M input 126 | channels and N output channels): 127 | - If M >= N, the first N channels of the input are copied over to the output. 128 | - If M < N, the input is copied to the first M channels of the output, and the 129 | rest of the output is zero. 130 | 131 | The kernel weight matrix is assumed to have 4 dimensions: [H, W, M, N], where 132 | (H, W) are the size of each 2-D kernel, and (M, N) are the number of 133 | input/output channels. 134 | 135 | Note that this differs from the `tf.keras.initializers.Identity` initializer, 136 | which works on 2-D weight matrices. 137 | """ 138 | 139 | def __call__(self, 140 | shape, 141 | dtype = tf.float32, 142 | **kwargs): 143 | array = np.zeros(shape, dtype=dtype.as_numpy_dtype) 144 | kernel_height, kernel_width, in_channels, out_channels = shape 145 | cy, cx = kernel_height // 2, kernel_width // 2 146 | for i in range(np.minimum(in_channels, out_channels)): 147 | array[cy, cx, i, i] = 1 148 | return tf.constant(array) 149 | 150 | 151 | class _CanBlock(tf.keras.layers.Layer): 152 | """A convolutional block in the context aggregation network.""" 153 | 154 | def __init__(self, channels, size, rate, **kwargs): 155 | """Initializes a convolutional block. 156 | 157 | Args: 158 | channels: Number of output channels. 159 | size: Side length of the square kernel. 160 | rate: Dilation rate. 
161 | **kwargs: Other args passed into `Layer`. 162 | """ 163 | super(_CanBlock, self).__init__(**kwargs) 164 | self.channels = channels 165 | self.size = size 166 | self.rate = rate 167 | 168 | def build(self, input_shape): 169 | self.conv = tf.keras.layers.Conv2D( 170 | filters=self.channels, 171 | kernel_size=self.size, 172 | dilation_rate=self.rate, 173 | padding='same', 174 | use_bias=False, 175 | kernel_initializer=IdentityInitializer(), 176 | input_shape=input_shape) 177 | # Trainable weights for normalization. 178 | self.w0 = self.add_weight( 179 | 'w0', 180 | dtype=tf.float32, 181 | initializer=tf.keras.initializers.Constant(1.0), 182 | trainable=True) 183 | self.w1 = self.add_weight( 184 | 'w1', 185 | dtype=tf.float32, 186 | initializer=tf.keras.initializers.Constant(0.0), 187 | trainable=True) 188 | self.batch_norm = tf.keras.layers.BatchNormalization(scale=False) 189 | self.activation = tf.keras.layers.LeakyReLU(0.2) 190 | 191 | def call(self, inputs): 192 | convolved = self.conv(inputs) 193 | normalized = self.w0 * convolved + self.w1 * self.batch_norm(convolved) 194 | outputs = self.activation(normalized) 195 | return outputs 196 | 197 | 198 | def build_can(input_shape = (512, 512, 3), 199 | conv_channels=64, 200 | out_channels=3, 201 | name='can'): 202 | """A context aggregation network based on the pre-trained VGG-19 network. 203 | 204 | Reference: 205 | X. Zhang, R. Ng, and Q. Chen. Single image reflection removal with perceptual 206 | loss. CVPR, 2018. 207 | 208 | Args: 209 | input_shape: Shape of the input tensor, without the batch dimension. For a 210 | typical RGB image, this should be [height, width, 3]. 211 | conv_channels: Number of channels in the intermediate convolution blocks. 212 | out_channels: Number of output channels. 213 | name: Name of this model. Will also be added as a prefix to the weight 214 | variable names. 215 | 216 | Returns: 217 | A Keras Model object. 218 | """ 219 | input_layer = tf.keras.Input(shape=input_shape, name='input') 220 | 221 | vgg = Vgg19( 222 | tap_out_layers=[f'block{i}_conv2' for i in range(1, 6)], trainable=False) 223 | features = vgg(input_layer) 224 | features = [tf.image.resize(f, input_shape[:2]) / 255.0 for f in features] 225 | 226 | x = tf.concat([input_layer] + features, axis=-1) 227 | 228 | x = _CanBlock(conv_channels, size=1, rate=1, name=f'{name}_g_conv0')(x) 229 | 230 | for i, rate in enumerate([1, 2, 4, 8, 16, 32, 64, 1]): 231 | x = _CanBlock( 232 | conv_channels, size=3, rate=rate, name=f'{name}_g_conv{i + 1}')( 233 | x) 234 | 235 | output_layer = tf.keras.layers.Conv2D( 236 | out_channels, 237 | kernel_size=1, 238 | dilation_rate=1, 239 | padding='same', 240 | use_bias=False, 241 | name=f'{name}_g_conv_last')( 242 | x) 243 | 244 | return tf.keras.Model(input_layer, output_layer, name=name) 245 | -------------------------------------------------------------------------------- /flare_removal/python/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """General utility functions.""" 17 | import os.path 18 | 19 | import cv2 20 | import numpy as np 21 | import skimage 22 | import skimage.morphology 23 | import tensorflow as tf 24 | from tensorflow_addons import image as tfa_image 25 | from tensorflow_addons.utils import types as tfa_types 26 | 27 | # Small number added to near-zero quantities to avoid numerical instability. 28 | _EPS = 1e-7 29 | 30 | 31 | def _gaussian_kernel(kernel_size, sigma, n_channels, 32 | dtype): 33 | x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype) 34 | g = tf.math.exp(-(tf.pow(x, 2) / (2 * tf.pow(tf.cast(sigma, dtype), 2)))) 35 | g_norm2d = tf.pow(tf.reduce_sum(g), 2) 36 | g_kernel = tf.tensordot(g, g, axes=0) / g_norm2d 37 | g_kernel = tf.expand_dims(g_kernel, axis=-1) 38 | return tf.expand_dims(tf.tile(g_kernel, (1, 1, n_channels)), axis=-1) 39 | 40 | 41 | def apply_blur(im, sigma): 42 | """Applies a Gaussian blur to an image tensor.""" 43 | blur = _gaussian_kernel(21, sigma, im.shape[-1], im.dtype) 44 | im = tf.nn.depthwise_conv2d(im, blur, [1, 1, 1, 1], 'SAME') 45 | return im 46 | 47 | 48 | def remove_flare(combined, flare, gamma = 2.2): 49 | """Subtracts flare from the image in linear space. 50 | 51 | Args: 52 | combined: gamma-encoded image of a flare-polluted scene. 53 | flare: gamma-encoded image of the flare. 54 | gamma: [value in linear domain] = [gamma-encoded value] ^ gamma. 55 | 56 | Returns: 57 | Gamma-encoded flare-free scene. 58 | """ 59 | # Avoid zero. Otherwise, the gradient of pow() below will be undefined when 60 | # gamma < 1. 61 | combined = tf.clip_by_value(combined, _EPS, 1.0) 62 | flare = tf.clip_by_value(flare, _EPS, 1.0) 63 | 64 | combined_linear = tf.pow(combined, gamma) 65 | flare_linear = tf.pow(flare, gamma) 66 | 67 | scene_linear = combined_linear - flare_linear 68 | # Avoid zero. Otherwise, the gradient of pow() below will be undefined when 69 | # gamma > 1. 
70 | scene_linear = tf.clip_by_value(scene_linear, _EPS, 1.0) 71 | scene = tf.pow(scene_linear, 1.0 / gamma) 72 | return scene 73 | 74 | 75 | def quantize_8(image): 76 | """Converts and quantizes an image to 2^8 discrete levels in [0, 1].""" 77 | q8 = tf.image.convert_image_dtype(image, tf.uint8, saturate=True) 78 | return tf.cast(q8, tf.float32) * (1.0 / 255.0) 79 | 80 | 81 | def write_image(image, path, overwrite = True): 82 | """Writes an image represented by a tensor to a PNG or JPG file.""" 83 | if not os.path.basename(path): 84 | raise ValueError(f'The given path doesn\'t represent a file: {path}') 85 | if tf.io.gfile.exists(path): 86 | if tf.io.gfile.isdir(path): 87 | raise ValueError(f'The given path is an existing directory: {path}') 88 | if not overwrite: 89 | print(f'Not overwriting an existing file at {path}') 90 | return False 91 | tf.io.gfile.remove(path) 92 | else: 93 | tf.io.gfile.makedirs(os.path.dirname(path)) 94 | 95 | image_u8 = tf.image.convert_image_dtype(image, tf.uint8, saturate=True) 96 | if path.lower().endswith('.png'): 97 | encoded = tf.io.encode_png(image_u8) 98 | elif path.lower().endswith('.jpg') or path.lower().endswith('.jpeg'): 99 | encoded = tf.io.encode_jpeg(image_u8, progressive=True) 100 | else: 101 | raise ValueError(f'Unsupported image format: {os.path.basename(path)}') 102 | with tf.io.gfile.GFile(path, 'wb') as f: 103 | f.write(encoded.numpy()) 104 | return True 105 | 106 | 107 | def _center_transform(t, height, width): 108 | """Modifies a homography such that the origin is at the image center. 109 | 110 | The transform matrices are represented using 8-vectors, following the 111 | `tensorflow_addons,image` package. 112 | 113 | Args: 114 | t: A [8]- or [B, 8]-tensor representing projective transform(s) defined 115 | relative to the origin (0, 0). 116 | height: Image height, in pixels. 117 | width: Image width, in pixels. 118 | 119 | Returns: 120 | The same transform(s), but applied relative to the image center (width / 2, 121 | height / 2) instead. 122 | """ 123 | center_to_origin = tfa_image.translations_to_projective_transforms( 124 | [-width / 2, -height / 2]) 125 | origin_to_center = tfa_image.translations_to_projective_transforms( 126 | [width / 2, height / 2]) 127 | t = tfa_image.compose_transforms([center_to_origin, t, origin_to_center]) 128 | return t 129 | 130 | 131 | def scales_to_projective_transforms(scales, height, 132 | width): 133 | """Returns scaling transform matrices for a batched input. 134 | 135 | The scaling is applied relative to the image center, instead of (0, 0). 136 | 137 | Args: 138 | scales: 2-element tensor [sx, sy], or a [B, 2]-tensor reprenting a batch of 139 | such inputs. `sx` and `sy` are the scaling ratio in x and y respectively. 140 | height: Image height, in pixels. 141 | width: Image width, in pixels. 142 | 143 | Returns: 144 | A [B, 8]-tensor representing the transform that can be passed to 145 | `tensorflow_addons.image.transform`. 146 | """ 147 | scales = tf.convert_to_tensor(scales) 148 | if tf.rank(scales) == 1: 149 | scales = scales[None, :] 150 | scales_x = tf.reshape(scales[:, 0], (-1, 1)) 151 | scales_y = tf.reshape(scales[:, 1], (-1, 1)) 152 | zeros = tf.zeros_like(scales_x) 153 | transform = tf.concat( 154 | [scales_x, zeros, zeros, zeros, scales_y, zeros, zeros, zeros], axis=-1) 155 | return _center_transform(transform, height, width) 156 | 157 | 158 | def shears_to_projective_transforms(shears, height, 159 | width): 160 | """Returns shear transform matrices for a batched input. 
161 | 162 | The shear is applied relative to the image center, instead of (0, 0). 163 | 164 | Args: 165 | shears: 2-element tensor [sx, sy], or a [B, 2]-tensor reprenting a batch of 166 | such inputs. `sx` and `sy` are the shear angle (in radians) in x and y 167 | respectively. 168 | height: Image height, in pixels. 169 | width: Image width, in pixels. 170 | 171 | Returns: 172 | A [B, 8]-tensor representing the transform that can be passed to 173 | `tensorflow_addons.image.transform`. 174 | """ 175 | shears = tf.convert_to_tensor(shears) 176 | if tf.rank(shears) == 1: 177 | shears = shears[None, :] 178 | shears_x = tf.reshape(tf.tan(shears[:, 0]), (-1, 1)) 179 | shears_y = tf.reshape(tf.tan(shears[:, 1]), (-1, 1)) 180 | ones = tf.ones_like(shears_x) 181 | zeros = tf.zeros_like(shears_x) 182 | transform = tf.concat( 183 | [ones, shears_x, zeros, shears_y, ones, zeros, zeros, zeros], axis=-1) 184 | return _center_transform(transform, height, width) 185 | 186 | 187 | def apply_affine_transform(image, 188 | rotation = 0., 189 | shift_x = 0., 190 | shift_y = 0., 191 | shear_x = 0., 192 | shear_y = 0., 193 | scale_x = 1., 194 | scale_y = 1., 195 | interpolation = 'bilinear'): 196 | """Applies affine transform(s) on the input images. 197 | 198 | The rotation, shear, and scaling transforms are applied relative to the image 199 | center, instead of (0, 0). The transform parameters can either be scalars 200 | (applied to all images in the batch) or [B]-tensors (applied to each image 201 | individually). 202 | 203 | Args: 204 | image: Input images in [B, H, W, C] format. 205 | rotation: Rotation angle in radians. Positive value rotates the image 206 | counter-clockwise. 207 | shift_x: Translation in x direction, in pixels. 208 | shift_y: Translation in y direction, in pixels. 209 | shear_x: Shear angle (radians) in x direction. 210 | shear_y: Shear angle (radians) in y direction. 211 | scale_x: Scaling factor in x direction. 212 | scale_y: Scaling factor in y direction. 213 | interpolation: Interpolation mode. Supported values: 'nearest', 'bilinear'. 214 | 215 | Returns: 216 | The transformed images in [B, H, W, C] format. 217 | """ 218 | height, width = image.shape[1:3] 219 | 220 | rotation = tfa_image.angles_to_projective_transforms(rotation, height, width) 221 | shear = shears_to_projective_transforms([shear_x, shear_y], height, width) 222 | scaling = scales_to_projective_transforms([scale_x, scale_y], height, width) 223 | translation = tfa_image.translations_to_projective_transforms( 224 | [shift_x, shift_y]) 225 | 226 | t = tfa_image.compose_transforms([rotation, shear, scaling, translation]) 227 | transformed = tfa_image.transform(image, t, interpolation=interpolation) 228 | 229 | return transformed 230 | 231 | 232 | def get_highlight_mask(im, 233 | threshold = 0.99, 234 | dtype = tf.float32): 235 | """Returns a binary mask indicating the saturated regions in the input image. 236 | 237 | Args: 238 | im: Image tensor with shape [H, W, C], or [B, H, W, C]. 239 | threshold: A pixel is considered saturated if its channel-averaged intensity 240 | is above this value. 241 | dtype: Expected output data type. 242 | 243 | Returns: 244 | A `dtype` tensor with shape [H, W, 1] or [B, H, W, 1]. 245 | """ 246 | binary_mask = tf.reduce_mean(im, axis=-1, keepdims=True) > threshold 247 | mask = tf.cast(binary_mask, dtype) 248 | return mask 249 | 250 | 251 | def refine_mask(mask, morph_size = 0.01): 252 | """Refines a mask by applying mophological operations. 
253 | 254 | Args: 255 | mask: A float array of shape [H, W] or [B, H, W]. 256 | morph_size: Size of the morphological kernel relative to the long side of 257 | the image. 258 | 259 | Returns: 260 | Refined mask of shape [H, W] or [B, H, W]. 261 | """ 262 | mask_size = max(np.shape(mask)) 263 | kernel_radius = .5 * morph_size * mask_size 264 | kernel = skimage.morphology.disk(np.ceil(kernel_radius)) 265 | opened = skimage.morphology.binary_opening(mask, kernel) 266 | return opened 267 | 268 | 269 | def _create_disk_kernel(kernel_size): 270 | x = np.arange(kernel_size) - (kernel_size - 1) / 2 271 | xx, yy = np.meshgrid(x, x) 272 | rr = np.sqrt(xx**2 + yy**2) 273 | kernel = np.float32(rr <= np.max(x)) + _EPS 274 | kernel = kernel / np.sum(kernel) 275 | return kernel 276 | 277 | 278 | def blend_light_source(scene_input, scene_pred): 279 | """Adds suspected light source in the input to the flare-free image.""" 280 | binary_mask = get_highlight_mask(scene_input, dtype=tf.bool).numpy() 281 | binary_mask = np.squeeze(binary_mask, axis=-1) 282 | binary_mask = refine_mask(binary_mask) 283 | 284 | labeled = skimage.measure.label(binary_mask) 285 | properties = skimage.measure.regionprops(labeled) 286 | max_diameter = 0 287 | for p in properties: 288 | max_diameter = max(max_diameter, p['equivalent_diameter']) 289 | 290 | mask = np.float32(binary_mask) 291 | 292 | kernel_size = round(1.5 * max_diameter) 293 | if kernel_size > 0: 294 | kernel = _create_disk_kernel(kernel_size) 295 | mask = cv2.filter2D(mask, -1, kernel) 296 | mask = np.clip(mask * 3.0, 0.0, 1.0) 297 | mask_rgb = np.stack([mask] * 3, axis=-1) 298 | else: 299 | mask_rgb = 0 300 | 301 | blend = scene_input * mask_rgb + scene_pred * (1 - mask_rgb) 302 | 303 | return blend 304 | 305 | 306 | def normalize_white_balance(im): 307 | """Normalizes the RGB channels so the image appears neutral in color. 308 | 309 | Args: 310 | im: Image tensor with shape [H, W, C], or [B, H, W, C]. 311 | 312 | Returns: 313 | Image(s) with equal channel mean. (The channel mean may be different across 314 | images for batched input.) 315 | """ 316 | channel_mean = tf.reduce_mean(im, axis=(-3, -2), keepdims=True) 317 | max_of_mean = tf.reduce_max(channel_mean, axis=(-3, -2, -1), keepdims=True) 318 | normalized = max_of_mean * im / (channel_mean + _EPS) 319 | return normalized 320 | 321 | 322 | def remove_background(im): 323 | """Removes the DC component in the background. 324 | 325 | Args: 326 | im: Image tensor with shape [H, W, C], or [B, H, W, C]. 327 | 328 | Returns: 329 | Image(s) with DC background removed. The white level (maximum pixel value) 330 | stays the same. 331 | """ 332 | im_min = tf.reduce_min(im, axis=(-3, -2), keepdims=True) 333 | im_max = tf.reduce_max(im, axis=(-3, -2), keepdims=True) 334 | return (im - im_min) * im_max / (im_max - im_min + _EPS) 335 | --------------------------------------------------------------------------------
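For reference, the sketch below shows how the inference-side utilities above are typically composed on a single image. It is not part of the repository: the SavedModel path and the input/output file names are hypothetical placeholders, and the input image is assumed to already match the model's training resolution.

import tensorflow as tf

from flare_removal.python import utils

# Read a gamma-encoded RGB image, scale it to [0, 1], and add a batch axis.
image = tf.io.decode_png(tf.io.read_file('example_input.png'), channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)[tf.newaxis, ...]

# Hypothetical path to a model exported by train.py via save_model().
model = tf.keras.models.load_model('flare_removal/trained_models/model')

# Predict the flare-free scene, then recover the flare estimate by
# subtracting the prediction from the input in linear space.
pred_scene = tf.clip_by_value(model(image), 0.0, 1.0)
pred_flare = utils.remove_flare(image, pred_scene)

# Paste the saturated light source from the input back into the prediction
# and write the results to disk.
blended = utils.blend_light_source(image[0, ...], pred_scene[0, ...])
utils.write_image(blended, 'example_output.png')
utils.write_image(pred_flare[0, ...], 'example_flare.png')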