├── .gitignore
├── .gitattributes
├── data
│   ├── kodim13.png
│   └── illustration.gif
├── example.py
├── LICENSE
├── README.md
└── cas.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.pyc
3 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/data/kodim13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamy-L/Pytorch-Contrast-Adaptive-Sharpening/HEAD/data/kodim13.png
--------------------------------------------------------------------------------
/data/illustration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamy-L/Pytorch-Contrast-Adaptive-Sharpening/HEAD/data/illustration.gif
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Aug 28 11:57:34 2023
4 | 
5 | @author: jamyl
6 | """
7 | 
8 | import torch as th
9 | import torch.nn.functional as F
10 | import matplotlib.pyplot as plt
11 | from cas import contrast_adaptive_sharpening
12 | 
13 | BINOMIAL_KERNEL = th.tensor([1, 2, 1]).float().view(1, 1, 1, 3).repeat(3, 1, 1, 1)/4
14 | 
15 | 
16 | im = th.from_numpy(plt.imread("data/kodim13.png").copy())
17 | im = im.unsqueeze(0).permute(0, -1, 1, 2)
18 | 
19 | blurry = F.conv2d(im, BINOMIAL_KERNEL, padding=(0, 1), groups=3)
20 | blurry = F.conv2d(blurry, BINOMIAL_KERNEL.permute(0, 1, 3, 2), padding=(1, 0), groups=3)[0]
21 | 
22 | amount = 0.8
23 | 
24 | plt.imsave('data/amount={:3.2f}.png'.format(amount),
               contrast_adaptive_sharpening(blurry, amount).permute(1, 2, 0).numpy())
25 | plt.imsave('data/blurry.png', blurry.permute(1, 2, 0).numpy())
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Jamy L
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch Contrast Adaptive Sharpening
2 | 
3 | This repository is an unofficial PyTorch implementation of the [Contrast Adaptive Sharpening](https://github.com/GPUOpen-Effects/FidelityFX-CAS/tree/master) (CAS) featured in AMD's FidelityFX.
4 | It is designed to be the lightweight final step of larger image-processing pipelines, as it attenuates the blur typically introduced by image upscaling while minimizing artifacts.
5 | Below, you will find information about the code, installation requirements, and how to run the provided example.
6 | 
7 | ## Illustration
8 | 
9 | Here is a quick overview of the sharpening on a famous image from the [Kodak dataset](https://www.r0k.us/graphics/kodak/):
10 | 
11 | ![Feature Illustration](https://github.com/Jamy-L/Pytorch-Contrast-Adaptive-Sharpening/blob/main/data/illustration.gif)
12 | 
13 | _Description: A blurry image is sharpened at two different strengths. Notice how the clouds and the distant mountain remain untouched._
14 | 
15 | ## Requirements
16 | 
17 | PyTorch is the only requirement. You can install it using the following command:
18 | 
19 | ```
20 | pip install torch
21 | ```
22 | 
23 | 
24 | Make sure to install a compatible version of PyTorch based on your system and preferences.
25 | 
26 | ## Example
27 | 
28 | To run the provided example, follow these steps:
29 | 
30 | 1. Make sure you have fulfilled the requirements mentioned above by installing PyTorch.
31 | 
32 | 2. Additionally, the example script `example.py` uses Matplotlib for visualizations. To install Matplotlib, you can use the following command:
33 | 
34 | ```
35 | pip install matplotlib
36 | ```
37 | 
38 | 3. Once you have PyTorch and Matplotlib installed, you can run the example script using the following command:
39 | 
40 | ```
41 | python example.py
42 | ```
43 | 
44 | ## Issues and Contributions
45 | 
46 | If you encounter any issues while using this code or have suggestions for improvements, please feel free to open an issue on this repository. Contributions are also welcome! Please follow the standard GitHub workflow for making pull requests.
47 | 
48 | ## Contact
49 | 
50 | If you have any questions or encounter an issue, you can contact me at [jamy.lafenetre@ens-paris-saclay.fr].
51 | 
52 | 
--------------------------------------------------------------------------------
/cas.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Aug 27 19:30:28 2023
4 | 
5 | @author: jamyl
6 | """
7 | 
8 | import torch as th
9 | import torch.nn.functional as F
10 | 
11 | EPSILON = 1e-6
12 | 
13 | def min_(tensor_list):
14 |     # Return the element-wise min of the tensor list.
15 |     x = th.stack(tensor_list)
16 |     mn = x.min(dim=0)[0]
17 |     return mn
18 | 
19 | def max_(tensor_list):
20 |     # Return the element-wise max of the tensor list.
21 |     x = th.stack(tensor_list)
22 |     mx = x.max(dim=0)[0]
23 |     return mx
24 | 
25 | 
26 | def contrast_adaptive_sharpening(x, amount=0.8, better_diagonals=True):
27 |     """
28 |     Performs contrast adaptive sharpening on the batch of images x.
29 |     The algorithm is ported directly from FidelityFX's source code,
30 |     which can be found here:
31 |     https://github.com/GPUOpen-Effects/FidelityFX-CAS/blob/master/ffx-cas/ffx_cas.h
32 | 
33 |     Parameters
34 |     ----------
35 |     x : Tensor
36 |         Image or stack of images, of shape [burst, channels, ny, nx].
37 |         Burst and channel dimensions can be omitted.
38 |     amount : float, optional
39 |         Amount of sharpening, in [0, 1]; 0 is minimum, 1 is maximum. The default is 0.8.
40 |     better_diagonals : bool, optional
41 |         If False, the algorithm runs slightly faster, but
42 |         won't consider diagonals. The default is True.
43 | 
44 |     Returns
45 |     -------
46 |     Tensor
47 |         Processed stack of images.
48 | 
49 |     """
50 |     assert x.dim() >= 2
51 |     assert 0 <= amount <= 1
52 |     assert x.max() <= 1
53 |     assert x.min() >= 0
54 | 
55 |     x_padded = F.pad(x, pad=(1, 1, 1, 1))
56 |     # each side gets padded with 1 pixel
57 |     # (zero padding, F.pad's default mode)
58 | 
59 |     # Extracting the 3x3 neighborhood around each pixel
60 |     # a b c
61 |     # d e f
62 |     # g h i
63 | 
64 |     b = x_padded[..., :-2, 1:-1]
65 |     d = x_padded[..., 1:-1, :-2]
66 |     e = x_padded[..., 1:-1, 1:-1]
67 |     f = x_padded[..., 1:-1, 2:]
68 |     h = x_padded[..., 2:, 1:-1]
69 | 
70 |     if better_diagonals:
71 |         a = x_padded[..., :-2, :-2]
72 |         c = x_padded[..., :-2, 2:]
73 |         g = x_padded[..., 2:, :-2]
74 |         i = x_padded[..., 2:, 2:]
75 | 
76 |     # Computing contrast
77 |     cross = (b, d, e, f, h)
78 |     mn = min_(cross)
79 |     mx = max_(cross)
80 | 
81 |     if better_diagonals:
82 |         diag = (a, c, g, i)
83 |         mn2 = min_(diag)
84 |         mx2 = max_(diag)
85 | 
86 |         mx = mx + mx2
87 |         mn = mn + mn2
88 | 
89 |     # Computing local weight
90 |     inv_mx = th.reciprocal(mx + EPSILON)  # 1/mx
91 | 
92 |     if better_diagonals:
93 |         amp = inv_mx * th.minimum(mn, (2 - mx))
94 |     else:
95 |         amp = inv_mx * th.minimum(mn, (1 - mx))
96 | 
97 |     # scaling
98 |     amp = th.sqrt(amp)
99 | 
100 |     w = -amp * (amount * (1/5 - 1/8) + 1/8)
101 |     # w scales from 0 when amp=0 to K when amp=1, where
102 |     # K goes from -1/8 for amount=0 to -1/5 for amount=1
103 | 
104 |     # The local conv filter is
105 |     # 0 w 0
106 |     # w 1 w
107 |     # 0 w 0
108 |     div = th.reciprocal(1 + 4*w)
109 |     output = ((b + d + f + h)*w + e) * div
110 | 
111 |     # Clipping between 0 and 1; this also fixes any earlier division by 0
112 |     return output.clamp(0, 1)
113 | 
--------------------------------------------------------------------------------
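As a quick, self-contained illustration of the algorithm in `cas.py`, the sketch below condenses the cross-only path (`better_diagonals=False`) into a single function. `cas_cross` is a hypothetical stand-in for the repository's `contrast_adaptive_sharpening`, written so it can run without cloning the repository; only PyTorch is assumed, and only the math shown in `cas.py` is used:

```python
import torch as th
import torch.nn.functional as F

EPSILON = 1e-6

def cas_cross(x, amount=0.8):
    # Condensed CAS using only the cross neighborhood (better_diagonals=False).
    x_padded = F.pad(x, pad=(1, 1, 1, 1))  # zero-pad each side by one pixel
    b = x_padded[..., :-2, 1:-1]   # north neighbor
    d = x_padded[..., 1:-1, :-2]   # west neighbor
    e = x_padded[..., 1:-1, 1:-1]  # center pixel
    f = x_padded[..., 1:-1, 2:]    # east neighbor
    h = x_padded[..., 2:, 1:-1]    # south neighbor

    cross = th.stack((b, d, e, f, h))
    mn, mx = cross.min(dim=0)[0], cross.max(dim=0)[0]

    # Per-pixel sharpening amplitude: shrinks where local contrast is already high
    amp = th.sqrt(th.reciprocal(mx + EPSILON) * th.minimum(mn, 1 - mx))
    w = -amp * (amount * (1 / 5 - 1 / 8) + 1 / 8)

    # Apply the 3x3 cross filter [0 w 0; w 1 w; 0 w 0], normalized by 1 + 4w
    return (((b + d + f + h) * w + e) / (1 + 4 * w)).clamp(0, 1)

img = th.rand(3, 32, 32)         # random "image" with values in [0, 1]
out = cas_cross(img, amount=0.8)  # same shape as the input
```

Because the filter is normalized by `1 + 4w`, a perfectly flat region is left unchanged: the neighbors and the center are equal, so the weighted sum collapses back to the center value. The full version in `cas.py` additionally folds the four diagonal neighbors into the contrast estimate, which is slightly slower but accounts for diagonal structure.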