634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Diffusion models meet image counter-forensics
6 |
7 | [Link to code](https://www.github.com/mtailanian/diff-cf)
8 |
9 | ### Links to download the paper:
10 |
11 | [WACV 2024 paper (CVF Open Access)](https://openaccess.thecvf.com/content/WACV2024/papers/Tailanian_Diffusion_Models_Meet_Image_Counter-Forensics_WACV_2024_paper.pdf)
12 | [arXiv](https://arxiv.org/abs/2311.13629)
13 |
14 |
15 |
16 | ### Citation:
17 | ```
18 | @InProceedings{Tailanian_2024_WACV,
19 | author = {Tailani\'an, Mat{\'\i}as and Gardella, Marina and Pardo, Alvaro and Mus\'e, Pablo},
20 | title = {Diffusion Models Meet Image Counter-Forensics},
21 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
22 | month = {January},
23 | year = {2024},
24 | pages = {3925-3935}
25 | }
26 | ```
27 | ## Abstract
28 | _From its acquisition in the camera sensors to its storage, different operations are performed to generate the final image. This pipeline imprints specific traces into the image to form a natural watermark. Tampering with an image disturbs these traces; these disruptions are clues that are used by most methods to detect and locate forgeries. In this article, we assess the capabilities of diffusion models to erase the traces left by forgers and, therefore, deceive forensics methods. Such an approach has been recently introduced for adversarial purification, achieving significant performance. We show that diffusion purification methods are well suited for counter-forensics tasks. Such approaches outperform already existing counter-forensics techniques both in deceiving forensics methods and in preserving the natural look of the purified images. The source code is publicly available at https://github.com/mtailanian/diff-cf._
29 |
30 |
31 |
![Teaser](assets/teaser.png)
32 |
33 |
34 | # Setup
35 |
36 | This code is based on [guided-diffusion](https://github.com/openai/guided-diffusion), so we first add that
37 | repo as a submodule:
38 |
39 | ```bash
40 | git submodule add git@github.com:openai/guided-diffusion.git guided_diffusion
41 | ```
42 |
43 | We also need the latest version of `pytorch_ssim`:
44 | ```bash
45 | git submodule add https://github.com/Po-Hsun-Su/pytorch-ssim.git pytorch_ssim
46 | ```
47 |
48 | ## Environment
49 |
50 | Then, we create a conda virtual environment with the required packages:
51 |
52 | ```bash
53 | conda create -n diff-cf python=3.10
54 | conda activate diff-cf
55 |
56 | cd guided_diffusion
57 | pip install -e .
58 | cd ..
59 |
60 | cd pytorch_ssim
61 | pip install -e .
62 | cd ..
63 |
64 | pip install -r requirements.txt
65 | conda install mpi4py
66 | ```
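
A quick, optional sanity check that the editable installs resolve (assumes the setup above completed without errors):

```bash
python -c "import torch, pytorch_ssim, guided_diffusion; print(torch.__version__)"
```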
67 |
68 | ## Download the pre-trained diffusion model
69 | ```bash
70 | mkdir models
71 | cd models
72 | wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt
73 | cd ..
74 | ```
75 |
76 | # Data
77 | Save the original images from each dataset (we tested Korus, FAU, DSO-1, and COVERAGE) in the following structure:
78 |
79 | ```text
80 | diff-cf
81 | ├── images
82 | │ ├── korus
83 | │ │ ├── original
84 | │ │ │ ├── 0001.png
85 | │ │ │ ├── 0002.png
86 | │ │ │ ├── ...
87 | │ ├── fau
88 | │ │ ├── original
89 | │ │ │ ├── ...
90 | ├── ...
91 | ```
92 |
93 | # Run
94 |
95 | This script runs Diff-CF and Diff-CFG over all tested datasets:
96 | ```bash
97 | python main.py
98 | ```
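
To run a single dataset or configuration instead, `main_diff_cf` can be called directly, with the same arguments used in the `__main__` block:

```python
from main import main_diff_cf

# Unguided purification (Diff-CF) with 40 reverse diffusion steps
main_diff_cf(dataset="korus", guided=False, t=40)

# SSIM-guided purification (Diff-CFG)
main_diff_cf(dataset="korus", guided=True, t=40, guide_mode="SSIM", guide_scale=1_000_000)
```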
99 |
100 | The purified images are saved next to the `original` folder, in one subfolder per variant.
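
For example, with the defaults in `main.py` (`t=40`, SSIM guidance with scale `1_000_000`), the Korus outputs are laid out as follows (the zero-padded scale comes from the `{guide_scale:015.0f}` format in `main_diff_cf`):

```text
images
└── korus
    ├── original
    ├── diff-cf
    │   └── t040
    └── diff-cfg
        └── SSIM_gs000000001000000
            └── t040
```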
101 |
102 |
103 | ## Results
104 |
105 | ### Traces removal
106 |
107 | The first point to evaluate is how well the proposed approaches remove the forgery traces. To do so, we ran several state-of-the-art forgery detection methods on the original datasets as well as on their counter-forensic versions (images purified using the different techniques). To evaluate the capability of deceiving the forensics methods, we look at the difference in detection performance before and after purification. The forensics methods used are: ZERO [44], Noiseprint [17], Splicebuster [15], ManTraNet [58], Choi [3, 11], Bammey [2], Shin [47], Comprint [38], CAT-Net [33, 34], and TruFor [25]. A brief description of each method can be found in the supplementary material.
108 |
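As a concrete sketch of this protocol (illustrative only, not part of this repo: `detector` and the ground-truth masks are placeholder assumptions), the deception capability can be scored as the drop in pixel-level detection AUC after purification:

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def auc_drop(detector, originals, purified, masks):
    """Mean pixel-level ROC AUC on the original forgeries minus the AUC on their purified versions."""

    def mean_auc(images):
        scores = []
        for img, mask in zip(images, masks):
            heatmap = detector(img)  # assumed API: returns an [H, W] forgery-probability map
            scores.append(roc_auc_score(mask.ravel() > 0, heatmap.ravel()))
        return float(np.mean(scores))

    # A large positive drop means the purification successfully erased the forgery traces
    return mean_auc(originals) - mean_auc(purified)
```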
109 |
![Traces removal results (table)](assets/results_table_01.png)
110 |
111 |
112 |
113 |
![Traces removal results (qualitative)](assets/results_image_01.png)
114 |
115 |
116 | ### Image quality assessment
117 | Another important point when evaluating the pertinence of counter-forensic methods is the quality of the resulting images. We evaluate this quality in two senses. Firstly, we are interested in how natural the purified images look; for this we use the reference-free image quality assessment techniques NIQE [42] and BRISQUE [41]. Secondly, it is also important to measure the similarity between the input image and the one obtained after the counter-forensic attack, since we want these two images to be perceptually close; for this we use the full-reference image quality assessment methods LPIPS [62], SSIM [53], and PSNR. For all the metrics, we use the implementations provided by the PyIQA library [8].
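
A minimal sketch of how these metrics can be computed with PyIQA (metric names follow PyIQA's registry; the image paths below are placeholders):

```python
import pyiqa
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# No-reference metrics: how natural does the purified image look?
niqe = pyiqa.create_metric("niqe", device=device)
brisque = pyiqa.create_metric("brisque", device=device)

# Full-reference metrics: how close is the purified image to the input?
lpips = pyiqa.create_metric("lpips", device=device)
ssim = pyiqa.create_metric("ssim", device=device)
psnr = pyiqa.create_metric("psnr", device=device)

original = "images/korus/original/0001.png"
purified = "images/korus/diff-cf/t040/0001.png"

print("NIQE:", niqe(purified).item(), "BRISQUE:", brisque(purified).item())
print("LPIPS:", lpips(purified, original).item(),
      "SSIM:", ssim(purified, original).item(),
      "PSNR:", psnr(purified, original).item())
```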
118 |
120 |
121 |
![Image quality results (table)](assets/results_table_02.png)
122 |
123 |
124 |
125 |
![Image quality results (qualitative)](assets/results_image_02.png)
126 |
127 |
128 | # Copyright and License
129 |
130 | This program is free software: you can redistribute it and/or modify
131 | it under the terms of the GNU Affero General Public License as
132 | published by the Free Software Foundation, either version 3 of the
133 | License, or (at your option) any later version.
134 |
135 | This program is distributed in the hope that it will be useful,
136 | but WITHOUT ANY WARRANTY; without even the implied warranty of
137 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
138 | GNU Affero General Public License for more details.
139 |
140 | You should have received a copy of the GNU Affero General Public License
141 | along with this program. If not, see <https://www.gnu.org/licenses/>.
--------------------------------------------------------------------------------
/assets/results_image_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtailanian/diff-cf/d8a82fe79fdda38ec8494462698c602a2f62552f/assets/results_image_01.png
--------------------------------------------------------------------------------
/assets/results_image_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtailanian/diff-cf/d8a82fe79fdda38ec8494462698c602a2f62552f/assets/results_image_02.png
--------------------------------------------------------------------------------
/assets/results_table_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtailanian/diff-cf/d8a82fe79fdda38ec8494462698c602a2f62552f/assets/results_table_01.png
--------------------------------------------------------------------------------
/assets/results_table_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtailanian/diff-cf/d8a82fe79fdda38ec8494462698c602a2f62552f/assets/results_table_02.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtailanian/diff-cf/d8a82fe79fdda38ec8494462698c602a2f62552f/assets/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import inspect
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pytorch_ssim
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision.transforms
10 | import yaml
11 | from PIL import Image
12 | from rich.progress import Progress
13 | from torch import nn
14 | from torchvision import transforms
15 |
16 | import guided_diffusion.dist_util as dist_util
17 | from guided_diffusion.script_util import create_model_and_diffusion
18 | from utils import image2patches, patches2image, imshow_tensor
19 |
20 | DEVICE = dist_util.dev()
21 | # DEVICE = 'cpu'
22 | # DEVICE = "cuda:1"
23 |
24 |
25 | class Diffusion(nn.Module):
26 | def __init__(self, model_path, **kwargs):
27 | super(Diffusion, self).__init__()
28 | self._load_pretrained_model(model_path, **kwargs)
29 | # print(f"Diffusion pretrained model is successfully loaded from {model_path}")
30 |
31 | def _load_pretrained_model(self, model_path: str, **kwargs):
32 | # Needed to pass only expected args to the function
33 | argnames = inspect.getfullargspec(create_model_and_diffusion)[0]
34 | expected_args = {name: kwargs[name] for name in argnames}
35 | self.model, self.diffusion = create_model_and_diffusion(**expected_args)
36 |
37 | self.model.load_state_dict(
38 | dist_util.load_state_dict(str(model_path), map_location="cpu")
39 | )
40 |
41 | self.model.to(DEVICE)
42 | if kwargs['use_fp16']:
43 | self.model.convert_to_fp16()
44 |
45 | self.model.eval()
46 |
47 | @torch.no_grad()
48 | def purify(self, image_batch, t, show=False):
49 | batch_size = image_batch.shape[0]
50 |         x = self.diffusion.q_sample(image_batch, torch.tensor([t] * batch_size, device=image_batch.device))  # forward-noise the batch to timestep t, one index per batch element
51 |
52 | if show:
53 | imshow_tensor(image_batch, "original")
54 | imshow_tensor(x, "noisy")
55 |
56 | for i in reversed(range(t)):
57 | ti = torch.tensor([i] * batch_size, device=x.device)
58 | diffusion_output = self.diffusion.p_sample(
59 | self.model, x, ti,
60 | clip_denoised=True,
61 | denoised_fn=None,
62 | cond_fn=None,
63 | model_kwargs=None
64 | )
65 | x = diffusion_output["sample"]
66 | predicted_x0 = diffusion_output["pred_xstart"]
67 |
68 | if show:
69 | imshow_tensor(x)
70 | imshow_tensor(predicted_x0)
71 |
72 | denoised = x
73 | if show:
74 | imshow_tensor(denoised, "denoised")
75 | imshow_tensor(image_batch - denoised, "noise")
76 |
77 | return denoised
78 |
79 | def compute_scale(self, t, m):
80 | alpha_bar = self.diffusion.alphas_cumprod[t]
81 |         return np.sqrt(1 - alpha_bar) / (m * np.sqrt(alpha_bar))  # guidance weight grows with the noise level at step t
82 |
83 | def purify_guided(self, image_batch, t, guide_scale=50_000, guide_mode='MSE'):
84 | batch_size = image_batch.shape[0]
85 | x_guide = image_batch.clone()
86 |
87 |         def cond_fn(x_t, cond_t, **kwargs):
88 |             # Noise the guide image to the current timestep so it is compared in the same domain as x_t
89 |             x_guide_t = self.diffusion.q_sample(x_guide, cond_t.detach().to(x_guide.device))
90 |             scale = self.compute_scale(int(cond_t[0].item()), 1. / guide_scale)
91 |
92 | with torch.enable_grad():
93 | x_in = x_t.detach().requires_grad_(True)
94 | if guide_mode == 'MSE':
95 | similarity = -1 * F.mse_loss(x_in, x_guide_t)
96 | elif guide_mode == 'SSIM':
97 | similarity = pytorch_ssim.ssim(x_in, x_guide_t)
98 | elif guide_mode == 'CORR':
99 | similarity = torch.corrcoef(torch.cat((x_in.reshape((1, -1)), x_guide_t.reshape((1, -1)))))
100 | else:
101 | raise ValueError(f"Unknown guide mode: {guide_mode}")
102 | gradient = torch.autograd.grad(similarity.sum(), x_in)[0] * scale
103 | return gradient
104 |
105 | with torch.no_grad():
106 |             x = self.diffusion.q_sample(image_batch, torch.tensor([t] * batch_size, device=image_batch.device))
107 | for i in reversed(range(t)):
108 | ti = torch.tensor([i] * batch_size, device=x.device)
109 | x = self.diffusion.p_sample(
110 | self.model, x, ti,
111 | clip_denoised=True,
112 | denoised_fn=None,
113 | cond_fn=cond_fn,
114 | model_kwargs={}
115 | )["sample"]
116 |
117 | return x
118 |
119 |
120 | def main_diff_cf(dataset='korus', guided=False, t=10, guide_mode='SSIM', guide_scale=1_000_000):
121 |
122 | print(f"Dataset: {dataset}")
123 | print(f"t: {t}")
124 | print(f"Guided: {guided}")
125 | if guided:
126 | print(f"Guide mode: {guide_mode}")
127 | print(f"Guide scale: {guide_scale}")
128 |
129 | # Diffusion model
130 | diffusion_args = yaml.safe_load(open("parameters.yaml", "r"))
131 | diffusion_args['model_path'] = Path("models") / diffusion_args['model_path']
132 | diffusion_model = Diffusion(**diffusion_args)
133 |
134 | patch_size = diffusion_args['image_size']
135 |
136 | # Input/Output directories
137 | base_dir = Path("images")
138 |
139 | if guided:
140 | variation = Path("diff-cfg") / f"{guide_mode}_gs{guide_scale:015.0f}" / f"t{t:03d}"
141 | else:
142 | variation = Path("diff-cf") / f"t{t:03d}"
143 | images_dir = base_dir / dataset / "original"
144 | results_dir = base_dir / dataset / variation
145 | results_dir.mkdir(parents=True, exist_ok=True)
146 |
147 | img_transforms = torchvision.transforms.Compose([
148 | torchvision.transforms.ToTensor(),
149 | transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
150 | ])
151 |
152 | print(f"Diff-CF{'G' if guided else ''}")
153 |
154 | images_paths = list(images_dir.glob('*'))
155 | with Progress() as pb:
156 | ds_progress = pb.add_task(dataset.upper(), total=len(images_paths))
157 | for j, img_path in enumerate(images_paths):
158 |
159 | out_path = results_dir / f"{img_path.stem}.png"
160 |
161 | if out_path.exists():
162 | pb.update(ds_progress, completed=j + 1)
163 | continue
164 |
165 | img = Image.open(img_path).convert('RGB')
166 | img = img_transforms(img).unsqueeze(0).to(DEVICE)
167 |
168 | patches, patching_args = image2patches(img, patch_size=patch_size, complete_patches_only=False)
169 | denoised_patches = []
170 |
171 | patch_progress = pb.add_task(str(img_path.stem), total=len(patches), transient=True)
172 |
173 | for i in range(len(patches)):
174 | if guided:
175 | denoised_patches.append(diffusion_model.purify_guided(patches[i:i + 1], t, guide_scale, guide_mode))
176 | else:
177 | denoised_patches.append(diffusion_model.purify(patches[i:i + 1], t, show=False))
178 |
179 | pb.update(patch_progress, completed=i + 1)
180 | pb.remove_task(patch_progress)
181 | pb.update(ds_progress, completed=j + 1)
182 |
183 | denoised_patches = torch.cat(denoised_patches, dim=0)
184 | denoised = patches2image(denoised_patches, patching_args).cpu()
185 |
186 | # Save the denoised image
187 |             out_img = transforms.ToPILImage()((denoised[0] / 2 + 0.5).clamp(0, 1))  # map [-1, 1] back to [0, 1]
188 | out_img.save(out_path)
189 |
190 | del img, out_img, patches, denoised_patches, denoised
191 | gc.collect()
192 |
193 |
194 | if __name__ == '__main__':
195 | for dataset in [
196 | "korus",
197 | "FAU",
198 | "DSO-1",
199 | "COVERAGE",
200 | ]:
201 | main_diff_cf(dataset=dataset, guided=False, t=40)
202 | main_diff_cf(dataset=dataset, guided=True, t=40, guide_mode='SSIM', guide_scale=1_000_000)
203 |
--------------------------------------------------------------------------------
/parameters.yaml:
--------------------------------------------------------------------------------
1 | model_path: 256x256_diffusion_uncond.pt
2 | steps: [150, 250] # [50, 150, 250]
3 | blocks: [5, 6, 7, 8, 12]
4 | input_activations: false
5 | image_size: 256
6 | class_cond: false
7 | learn_sigma: true
8 | num_channels: 256
9 | num_res_blocks: 2
10 | channel_mult: ''
11 | num_heads: 4
12 | num_head_channels: 64
13 | num_heads_upsample: -1
14 | attention_resolutions: '32,16,8'
15 | dropout: 0.1
16 | diffusion_steps: 1000
17 | noise_schedule: 'linear'
18 | timestep_respacing: ''
19 | use_kl: false
20 | predict_xstart: false
21 | rescale_timesteps: false
22 | rescale_learned_sigmas: false
23 | use_checkpoint: false
24 | use_scale_shift_norm: true
25 | resblock_updown: true
26 | use_fp16: true
27 | use_new_attention_order: false
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | numpy
3 | matplotlib
4 | Pillow
5 | rich
6 | torchvision
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def image2patches(image, patch_size, complete_patches_only=True):
6 | """
7 | Fold 1 image into patches of fixed size
8 | image: [1, C, H, W]
9 | patches: [n_patches, C, patch_size, patch_size]
10 |
11 | Parameters
12 | ----------
13 | image: [1, C, H, W]
14 | patch_size: size of the patch
15 | complete_patches_only: if True, only return patches that are complete (i.e. no padding)
16 |
17 | Returns
18 | -------
19 | patches: [n_patches, C, patch_size, patch_size]
20 | patching_args: dict containing useful information for reconstructing the image from patches
21 | """
22 | patching_args = {'original_img_shape': image.shape[-2:]}
23 |
24 |     if not complete_patches_only:
25 |         pad_w = (patch_size - image.shape[-1] % patch_size) % patch_size  # pad right/bottom to the next multiple of patch_size (no-op when already divisible)
26 |         pad_h = (patch_size - image.shape[-2] % patch_size) % patch_size
27 |         image = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h))
28 |
29 | patching_args['padded_img_shape'] = image.shape[-2:]
30 |
31 | patches_fold_h = image.unfold(2, patch_size, patch_size)
32 | patches_fold_hw = patches_fold_h.unfold(3, patch_size, patch_size)
33 | patches = patches_fold_hw.permute(0, 2, 3, 1, 4, 5).reshape(-1, image.shape[1], patch_size, patch_size)
34 |
35 | return patches, patching_args
36 |
37 |
38 | def patches2image(patches, patching_args):
39 | """
40 |
41 | Parameters
42 | ----------
43 | patches: [n_patches, C, patch_size, patch_size]
44 | patching_args: dict containing useful information for reconstructing the image from patches
45 |
46 | Returns
47 | -------
48 | image: [1, C, H, W]
49 | """
50 |
51 | patch_size = patches.shape[-1]
52 |     n_patches_per_row = patching_args['padded_img_shape'][-1] // patch_size
53 |     unfolded = patches.unfold(0, n_patches_per_row, n_patches_per_row).permute(0, 4, 2, 3, 1)
54 |     stitch_v = torch.cat(tuple(unfolded), dim=1)   # stack the rows of patches along the height axis
55 |     stitch_vh = torch.cat(tuple(stitch_v), dim=1)  # then stitch along the width axis -> [H, W, C]
56 | image = stitch_vh.permute(2, 0, 1).unsqueeze(0)
57 | image = image[..., :patching_args['original_img_shape'][0], :patching_args['original_img_shape'][1]]
58 | return image
59 |
60 |
61 | def normalize(img):
62 | return (img - img.min()) / (img.max() - img.min())
63 |
64 |
65 | def imshow_tensor(img, title=None):
66 | plt.imshow(normalize(img[0].permute(1, 2, 0).detach().cpu().numpy()))
67 | if title is not None:
68 | plt.title(title)
69 | plt.axis('off')
70 | plt.tight_layout()
71 | plt.show()
72 |
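
if __name__ == "__main__":
    # Illustrative round-trip check for the patching helpers above (not part of
    # the paper's pipeline): fold a non-square image into padded patches, stitch
    # it back, and verify the reconstruction matches the input exactly.
    img = torch.rand(1, 3, 300, 500)
    patches, args = image2patches(img, patch_size=256, complete_patches_only=False)
    assert torch.equal(patches2image(patches, args), img)
    print("image2patches / patches2image round trip OK")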
--------------------------------------------------------------------------------