├── requirements ├── dev.in ├── prod.in ├── dev.txt └── prod.txt ├── src ├── models │ ├── __init__.py │ └── exif_sc │ │ ├── __init__.py │ │ ├── postprocess.py │ │ ├── README.md │ │ └── exif_sc.py ├── trainers │ ├── __init__.py │ ├── README.md │ └── exif_trainer.py ├── attacks │ ├── __init__.py │ ├── jpeg_compressor.py │ ├── README.md │ └── lots.py ├── evaluation │ ├── __init__.py │ ├── metrics.py │ ├── non_adv_evaluators.py │ ├── evaluators.py │ └── evaluators_test.py ├── datasets │ ├── __init__.py │ ├── utils.py │ ├── in_the_wild.py │ ├── scene_completion.py │ ├── realistic_tampering.py │ ├── dso_1.py │ ├── columbia.py │ └── mirflickr_25k.py ├── utils.py └── structures.py ├── report.pdf ├── data ├── demo.png └── raw │ ├── mirflickr_25k │ ├── metadata.toml │ └── README.md │ ├── in_the_wild │ ├── metadata.toml │ └── README.md │ ├── scene_completion │ ├── metadata.toml │ └── README.md │ ├── dso_1 │ ├── metadata.toml │ └── README.md │ ├── realistic_tampering │ ├── metadata.toml │ └── README.md │ └── columbia │ ├── README.md │ └── metadata.toml ├── assets ├── lots_examples │ ├── dso_1.png │ ├── dso_2.png │ ├── columbia_1.png │ ├── columbia_2.png │ ├── dso1_jpeg_1.png │ ├── dso1_jpeg_2.png │ ├── dso1_mean_1.png │ ├── dso1_mean_2.png │ ├── dso1_sample_1.png │ ├── dso1_sample_2.png │ ├── columbia_jpeg_1.png │ ├── columbia_jpeg_2.png │ ├── columbia_mean_1.png │ ├── columbia_mean_2.png │ ├── columbia_sample_1.png │ └── columbia_sample_2.png └── exif_sc_examples │ ├── columbia.png │ ├── in_the_wild.png │ ├── scene_completion.png │ └── realistic_tampering.png ├── environment.yml ├── configs ├── train │ └── exif_sc.yaml └── evaluate │ ├── non_adv.yaml │ └── adv.yaml ├── Makefile ├── LICENSE ├── non_adv_evaluate.py ├── evaluate.py ├── .gitignore ├── train.py ├── README.md └── notebooks ├── exif_sc_port.ipynb └── viz.ipynb /requirements/dev.in: -------------------------------------------------------------------------------- 1 | -c prod.txt 2 | pylama 3 | black -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .exif_sc import EXIF_SC, EXIF_Net -------------------------------------------------------------------------------- /src/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .exif_trainer import EXIF_Trainer1, EXIF_Trainer2 2 | -------------------------------------------------------------------------------- /report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/report.pdf -------------------------------------------------------------------------------- /data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/data/demo.png -------------------------------------------------------------------------------- /src/models/exif_sc/__init__.py: -------------------------------------------------------------------------------- 1 | from .exif_sc import EXIF_SC 2 | from .networks import EXIF_Net -------------------------------------------------------------------------------- /src/attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .lots import PatchLOTS 2 | from .jpeg_compressor import JPEG_Compressor 3 | 
-------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluators import Evaluator 2 | from .non_adv_evaluators import NonAdvEvaluator 3 | -------------------------------------------------------------------------------- /assets/lots_examples/dso_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso_1.png -------------------------------------------------------------------------------- /assets/lots_examples/dso_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso_2.png -------------------------------------------------------------------------------- /assets/exif_sc_examples/columbia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/exif_sc_examples/columbia.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_1.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_2.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_jpeg_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_jpeg_1.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_jpeg_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_jpeg_2.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_mean_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_mean_1.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_mean_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_mean_2.png -------------------------------------------------------------------------------- /assets/exif_sc_examples/in_the_wild.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/exif_sc_examples/in_the_wild.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_sample_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_sample_1.png -------------------------------------------------------------------------------- /assets/lots_examples/dso1_sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/dso1_sample_2.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_jpeg_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_jpeg_1.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_jpeg_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_jpeg_2.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_mean_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_mean_1.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_mean_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_mean_2.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_sample_1.png -------------------------------------------------------------------------------- /assets/lots_examples/columbia_sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/lots_examples/columbia_sample_2.png -------------------------------------------------------------------------------- /assets/exif_sc_examples/scene_completion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/exif_sc_examples/scene_completion.png -------------------------------------------------------------------------------- /assets/exif_sc_examples/realistic_tampering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhe-ang/fake-detection-lab/HEAD/assets/exif_sc_examples/realistic_tampering.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fake-detection-lab 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7 6 | - cudatoolkit=10.2 7 | - cudnn=7.6.5 8 | - pip 9 | - pip: 10 | - pip-tools -------------------------------------------------------------------------------- /data/raw/mirflickr_25k/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'mirflickr25k.zip' 2 | 
sha256 = 'e0d5b222ddd078a9be551c2a2b90e17cb3aa90d0a67f775545f31aad05787881' 3 | url = 'http://press.liacs.nl/mirflickr/mirflickr25k.v3/mirflickr25k.zip' -------------------------------------------------------------------------------- /data/raw/in_the_wild/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'in_wild.tar.gz' 2 | sha256 = '6a00d4c66742bd50068c739edd67d99b346a51db797778140a568763ef49c2c5' 3 | url = 'https://minyoungg.github.io/selfconsistency/in_wild/in_wild.tar.gz' 4 | -------------------------------------------------------------------------------- /data/raw/scene_completion/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'test_set.zip' 2 | sha256 = '5b219cf3262844ec6b89af021b2d86a615a94a0f0245800bd1540254e1180d76' 3 | url = 'http://graphics.cs.cmu.edu/projects/scene-completion/test_set.zip' 4 | -------------------------------------------------------------------------------- /requirements/prod.in: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | opencv-python 4 | scipy 5 | scikit-learn 6 | matplotlib 7 | tqdm 8 | jupyter 9 | pyyaml 10 | wandb 11 | toml 12 | gdown 13 | pandas 14 | pytorch-lightning 15 | torchmetrics -------------------------------------------------------------------------------- /data/raw/dso_1/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'tifs-database.zip' 2 | sha256 = 'f1b2c66c8e0a1e13444990c980d0d569e57243efe64d83c54fff94455e5b1746' 3 | url = 'http://ic.unicamp.br/~rocha/pub/downloads/2014-tiago-carvalho-thesis/tifs-database.zip' 4 | -------------------------------------------------------------------------------- /data/raw/realistic_tampering/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = 'realistic-tampering-dataset.zip' 2 | sha256 = 'c86e664027c3c67136db2a871f1cb19ab9646938675d0077d50808bd7378b8ef' 3 | url = 'https://drive.google.com/u/0/uc?id=0B73Fq3C_nT4aOThud0NYWUR2MTQ' 4 | -------------------------------------------------------------------------------- /data/raw/in_the_wild/README.md: -------------------------------------------------------------------------------- 1 | # In-the-Wild Image Splice Dataset 2 | 3 | Website: 4 | - https://minyoungg.github.io/selfconsistency/ 5 | 6 | Publication: 7 | - M. Huh, A. Liu, A. Owens, A. A. Efros. Fighting Fake News: Image Splice Detection via Learned Self-Consistency. ECCV, 2018. 
-------------------------------------------------------------------------------- /configs/train/exif_sc.yaml: -------------------------------------------------------------------------------- 1 | name: train_0 # Name of experiment 2 | 3 | n_epochs_1: 1 4 | n_epochs_2: 1 5 | 6 | learning_rate_1: 0.0003 7 | learning_rate_2: 0.0003 8 | 9 | datamodule: datasets.MIRFLICKR_25kDataModule 10 | datamodule_args: 11 | n_exif_attr: 80 12 | patch_size: 128 13 | batch_size: 128 14 | iters_per_epoch: 500 -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .realistic_tampering import RealisticTamperingDataset 2 | from .columbia import ColumbiaDataset 3 | from .in_the_wild import InTheWildDataset 4 | from .scene_completion import SceneCompletionDataset 5 | from .dso_1 import DSO_1_Dataset 6 | from .mirflickr_25k import MIRFLICKR_25kDataset, MIRFLICKR_25kDataModule 7 | -------------------------------------------------------------------------------- /data/raw/scene_completion/README.md: -------------------------------------------------------------------------------- 1 | # Scene Completion Using Millions of Photographs Dataset 2 | 3 | Website: 4 | - http://graphics.cs.cmu.edu/projects/scene-completion/ 5 | 6 | Publication: 7 | - James Hays, Alexei A. Efros. Scene Completion Using Millions of Photographs. ACM Transactions on Graphics (SIGGRAPH 2007). August 2007, vol. 26, No. 3. -------------------------------------------------------------------------------- /configs/evaluate/non_adv.yaml: -------------------------------------------------------------------------------- 1 | name: dso1 # Name of experiment 2 | 3 | model: models.EXIF_SC 4 | model_args: 5 | device: cuda:0 6 | 7 | dataset: datasets.DSO_1_Dataset 8 | dataset_args: 9 | spliced_only: False 10 | 11 | # resize: [768, 1152] # For Columbia 12 | # resize: [600, 700] # For Hays 13 | # resize: [1000, 1400] # For In-The-Wild 14 | resize: [1600, 1900] # For DSO-1 15 | -------------------------------------------------------------------------------- /data/raw/columbia/README.md: -------------------------------------------------------------------------------- 1 | # Columbia Dataset 2 | Columbia Uncompressed Image Splicing Detection Evaluation Dataset 3 | 4 | Website: 5 | - https://www.ee.columbia.edu/ln/dvmm/downloads/authsplcuncmp/ 6 | 7 | Publication: 8 | - Yu-Feng Hsu, Shih-Fu Chang. Detecting Image Splicing Using Geometry Invariants And Camera Characteristics Consistency. International Conference on Multimedia and Expo (ICME), 2006. 
-------------------------------------------------------------------------------- /data/raw/columbia/metadata.toml: -------------------------------------------------------------------------------- 1 | filename = ['4cam_auth.tar.bz2', '4cam_splc.tar.bz2'] 2 | sha256 = ['561dcc477824817e23e8f4457e9139096f05139d8440725112c0281b7f9ebac3', 'b47411de8cd73283d0525f6e03e33d98adef6e7f73b27422ac5079837692679f'] 3 | url = ['https://www.dropbox.com/sh/786qv3yhvc7s9ki/AABaQvI-lPiM3Zl64RQoDCiMa/4cam_auth.tar.bz2?dl=1', 'https://www.dropbox.com/sh/786qv3yhvc7s9ki/AAAESATxO7wncDMKkl1XjyNaa/4cam_splc.tar.bz2?dl=1'] 4 | -------------------------------------------------------------------------------- /data/raw/mirflickr_25k/README.md: -------------------------------------------------------------------------------- 1 | # MIRFLICKR-25k Dataset 2 | Offered by the LIACS Medialab at Leiden University, The Netherlands 3 | Introduced by the ACM MIR Committee in 2008 as an ACM sponsored image retrieval evaluation 4 | 5 | Website: 6 | - https://press.liacs.nl/mirflickr/ 7 | 8 | Publication: 9 | - M. J. Huiskes, M. S. Lew (2008). The MIR Flickr Retrieval Evaluation. ACM International Conference on Multimedia Information Retrieval (MIR'08), Vancouver, Canada -------------------------------------------------------------------------------- /data/raw/realistic_tampering/README.md: -------------------------------------------------------------------------------- 1 | # Realistic Tampering Dataset 2 | 3 | Website: 4 | - http://pkorus.pl/downloads/dataset-realistic-tampering 5 | 6 | Publication: 7 | - P. Korus & J. Huang, Multi-scale Analysis Strategies in PRNU-based Tampering Localization, IEEE Trans. Information Forensics & Security, 2017 8 | - P. Korus & J. Huang, Evaluation of Random Field Models in Multi-modal Unsupervised Tampering Localization, Proc. of IEEE Int. Workshop on Inf. Forensics and Security, 2016 -------------------------------------------------------------------------------- /data/raw/dso_1/README.md: -------------------------------------------------------------------------------- 1 | # DSO-1 Dataset 2 | 3 | Website: 4 | - https://recodbr.wordpress.com/code-n-data/#dso1_dsi1 5 | 6 | Publication: 7 | - T. J. d. Carvalho, C. Riess, E. Angelopoulou, H. Pedrini and A. d. R. Rocha, “Exposing Digital Image Forgeries by Illumination Color Classification,” in IEEE Transactions on Information Forensics and Security, vol. 8, no. 7, pp. 1182-1194, July 2013. doi: doi: 10.1109/TIFS.2013.2265677 8 | - T. Carvalho, F. A. Faria, H. Pedrini, R. da S. Torres and A. Rocha, “Illuminant-Based Transformed Spaces for Image Forensics,” in IEEE Transactions on Information Forensics and Security, vol. 11, no. 4, pp. 720-733, April 2016. 
doi: doi: 10.1109/TIFS.2015.2506548 -------------------------------------------------------------------------------- /configs/evaluate/adv.yaml: -------------------------------------------------------------------------------- 1 | name: columbia_jpeg_0 # Name of experiment 2 | 3 | model: models.EXIF_SC 4 | model_args: 5 | device: cuda:1 6 | 7 | # dataset: datasets.DSO_1_Dataset 8 | dataset: datasets.ColumbiaDataset 9 | dataset_args: 10 | spliced_only: True 11 | 12 | resize: [768, 1152] # For Columbia 13 | # resize: [600, 700] # For Hays 14 | # resize: [1000, 1400] # For In-The-Wild 15 | # resize: [1600, 1900] # For DSO-1 16 | 17 | # Adversarial attack arguments 18 | # attacker: attacks.PatchLOTS 19 | # attacker_args: 20 | # adv_step_size: 10000 21 | # adv_n_iter: 50 22 | # method: sample 23 | 24 | attacker: attacks.JPEG_Compressor 25 | attacker_args: 26 | quality: 30 27 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements/dev.in 6 | # 7 | appdirs==1.4.4 8 | # via black 9 | black==20.8b1 10 | # via -r requirements/dev.in 11 | click==7.1.2 12 | # via 13 | # -c requirements/prod.txt 14 | # black 15 | mccabe==0.6.1 16 | # via pylama 17 | mypy-extensions==0.4.3 18 | # via black 19 | pathspec==0.8.1 20 | # via black 21 | pycodestyle==2.7.0 22 | # via pylama 23 | pydocstyle==6.0.0 24 | # via pylama 25 | pyflakes==2.3.1 26 | # via pylama 27 | pylama==7.7.1 28 | # via -r requirements/dev.in 29 | regex==2021.4.4 30 | # via black 31 | snowballstemmer==2.1.0 32 | # via pydocstyle 33 | toml==0.10.2 34 | # via 35 | # -c requirements/prod.txt 36 | # black 37 | typed-ast==1.4.2 38 | # via black 39 | typing-extensions==3.7.4.3 40 | # via 41 | # -c requirements/prod.txt 42 | # black 43 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Thanks to http://blog.ianpreston.ca/2020/05/13/conda_envs.html for working some of this out! 
2 | 3 | # Oneshell means all lines in a recipe run in the same shell 4 | .ONESHELL: 5 | 6 | # Need to specify bash in order for conda activate to work 7 | SHELL=/bin/bash 8 | 9 | # Note that the extra activate is needed to ensure that the activate floats env to the front of PATH 10 | CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate 11 | 12 | # Same name as in environment.yml 13 | CONDA_ENV=fake-detection-lab 14 | 15 | all: conda-env-update pip-compile pip-sync 16 | 17 | # Create or update conda env 18 | conda-env-update: 19 | conda env update --prune 20 | 21 | # Compile exact pip packages 22 | pip-compile: 23 | $(CONDA_ACTIVATE) $(CONDA_ENV) 24 | pip-compile -v requirements/prod.in && pip-compile -v requirements/dev.in 25 | 26 | # Install pip packages 27 | pip-sync: 28 | $(CONDA_ACTIVATE) $(CONDA_ENV) 29 | pip-sync requirements/prod.txt requirements/dev.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 lemonwaffle 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/models/exif_sc/postprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import sklearn.cluster 4 | 5 | 6 | def mean_shift(points_, heat_map, iters=5): 7 | points = np.copy(points_) 8 | kdt = scipy.spatial.cKDTree(points) 9 | eps_5 = np.percentile( 10 | scipy.spatial.distance.cdist(points, points, metric="euclidean"), 10 11 | ) 12 | 13 | for epis in range(iters): 14 | for point_ind in range(points.shape[0]): 15 | point = points[point_ind] 16 | nearest_inds = kdt.query_ball_point(point, r=eps_5) 17 | points[point_ind] = np.mean(points[nearest_inds], axis=0) 18 | val = [] 19 | for i in range(points.shape[0]): 20 | val.append( 21 | kdt.count_neighbors(scipy.spatial.cKDTree(np.array([points[i]])), r=eps_5) 22 | ) 23 | mode_ind = np.argmax(val) 24 | ind = np.nonzero(val == np.max(val)) 25 | return np.mean(points[ind[0]], axis=0).reshape(heat_map.shape[0], heat_map.shape[1]) 26 | 27 | 28 | def normalized_cut(res): 29 | sc = sklearn.cluster.SpectralClustering( 30 | n_clusters=2, n_jobs=-1, affinity="precomputed" 31 | ) 32 | out = sc.fit_predict(res.reshape((res.shape[0] * res.shape[1], -1))) 33 | vis = out.reshape((res.shape[0], res.shape[1])) 34 | return vis 35 | -------------------------------------------------------------------------------- /non_adv_evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from src.evaluation import NonAdvEvaluator 5 | from src.utils import ConfigManager, load_yaml 6 | 7 | 8 | def main(config, args): 9 | config_manager = ConfigManager(config) 10 | 11 | model = config_manager.init_object("model", weight_file=args.weights_path) 12 | dataset = config_manager.init_object("dataset") 13 | 14 | # Run evaluation 15 | evaluator = NonAdvEvaluator(model, dataset) 16 | results = evaluator.evaluate(tuple(config["resize"])) 17 | 18 | # Save results 19 | print(results) 20 | with open(args.results_path, "w") as f: 21 | json.dump(results, f) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--config", 28 | help="path to the config file", 29 | default="configs/evaluate/config.yaml", 30 | ) 31 | parser.add_argument( 32 | "--weights_path", 33 | help="path to the weights / checkpoint to load", 34 | default="artifacts/exif_sc.npy", 35 | ) 36 | parser.add_argument( 37 | "--results_path", 38 | help="path to store evaluation results as JSON file", 39 | default="results.json", 40 | ) 41 | args = parser.parse_args() 42 | 43 | # Load config file 44 | config = load_yaml(args.config) 45 | 46 | main(config, args) 47 | -------------------------------------------------------------------------------- /src/attacks/jpeg_compressor.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from typing import Any, Dict 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | class JPEG_Compressor: 10 | def __init__(self, quality: int = 30) -> None: 11 | """ 12 | Parameters 13 | ---------- 14 | quality : int, optional 15 | Quality of compressed image, from [0, 100], by default 30 16 | """ 17 | self.quality = quality 18 | 19 | def __call__(self, model, data: Dict[str, Any]) -> torch.ByteTensor: 20 | """ 21 | Parameters 22 | ---------- 23 | model : [type] 24 | data : Dict[str, Any] 25 | From dataloader 26 | 
27 | Returns 28 | ------- 29 | torch.ByteTensor 30 | [C, H, W], the compressed tensor 31 | """ 32 | clean_img = data["img"] 33 | # Convert to PIL image 34 | np_img = clean_img.permute(1, 2, 0).numpy() 35 | pil_img = Image.fromarray(np_img) 36 | 37 | # Compress to JPEG 38 | with BytesIO() as f: 39 | pil_img.save(f, format="JPEG", optimize=True, quality=self.quality) 40 | f.seek(0) 41 | jpg_img = Image.open(f) 42 | jpg_img.load() 43 | 44 | # Convert back to torch tensor 45 | jpg_img_np = np.array(jpg_img) 46 | jpg_img_t = torch.tensor(jpg_img_np).permute(2, 0, 1) 47 | 48 | return jpg_img_t 49 | -------------------------------------------------------------------------------- /src/models/exif_sc/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation Results 2 | There are likely to be some hidden hyperparameter choices that are not replicated exactly in these experiments. 3 | 4 | In the paper, every dataset is split in half into train / test sets. Since the exact split is unknown, all my experiments simply evaluate on the entire dataset. 5 | 6 | ## Splice Detection 7 | Average precision scores are reported. 8 | 9 | | | Results from Paper | This Implementation | 10 | | -------- | ------------------ | ------------------- | 11 | | RT | 0.55 | 0.54 | 12 | | Columbia | 0.98 | 0.95 | 13 | 14 | ## Splice Localization 15 | Class-balanced IOU (cIOU) scores are reported. 16 | 17 | I resize all the ground-truth and prediction maps to the same size in order to compute the optimal threshold and corresponding IoU score for each image in a vectorized and efficient manner. 18 | 19 | | | Results from Paper | This Implementation | 20 | | ----------- | ------------------ | ------------------- | 21 | | RT | 0.54 | 0.54 | 22 | | Columbia | 0.85 | 0.67 | 23 | | In-the-Wild | 0.58 | 0.64 | 24 | | Hays | 0.65 | 0.54 | 25 | 26 | ## Qualitative Results 27 | ### Realistic Tampering 28 | ![](/assets/exif_sc_examples/realistic_tampering.png) 29 | 30 | ### Columbia 31 | ![](/assets/exif_sc_examples/columbia.png) 32 | 33 | ### In-the-Wild 34 | ![](/assets/exif_sc_examples/in_the_wild.png) 35 | 36 | ### Scene Completion (Hays) 37 | ![](/assets/exif_sc_examples/scene_completion.png) -------------------------------------------------------------------------------- /src/trainers/README.md: -------------------------------------------------------------------------------- 1 | # EXIF-SC Training 2 | The code implements an approximation of the self-consistency training algorithm, which comprises two stages: 3 | 4 | ## First Stage 5 | The first stage trains the network to predict EXIF attribute consistency (multi-label classification) from a pair of image patches. 6 | 7 | The sampling process for a training batch is as follows: 8 | 9 | 1. Select a specific EXIF attribute value: first randomly sample an EXIF attribute, then randomly sample one of its values. 10 | 2. The first half of the batch is consistent, i.e. both patches in a pair carry that specific attribute value. We randomly sample from the set of images with that attribute value. 11 | 3. The second half of the batch is inconsistent, i.e. the first image has that specific value, while the second image it is paired with has a different value. We sample these second images from the remaining images.
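As a concrete illustration of the three steps above, here is a minimal sketch of the stage-1 pair sampling. The helper structures (`images_by_value`, `exif_values`) and the function name are hypothetical, and the real datamodule additionally builds a per-attribute consistency label vector rather than the single label used here:

```python
import random


def sample_stage1_batch(images_by_value, exif_values, batch_size):
    """Illustrative stage-1 pair sampling (hypothetical helpers, not the repo's code).

    images_by_value[(attr, value)] -> images whose EXIF `attr` equals `value`
    exif_values[attr]              -> observed values for that attribute
    """
    # Step 1: pick a specific EXIF attribute value.
    attr = random.choice(list(exif_values))
    value = random.choice(exif_values[attr])

    positives = images_by_value[(attr, value)]
    # Images whose value for `attr` differs from the sampled value.
    negatives = [
        img
        for (a, v), imgs in images_by_value.items()
        if a == attr and v != value
        for img in imgs
    ]

    pairs, labels = [], []
    for i in range(batch_size):
        first = random.choice(positives)
        if i < batch_size // 2:
            # Step 2: consistent half, both patches share the attribute value.
            pairs.append((first, random.choice(positives)))
            labels.append(1)
        else:
            # Step 3: inconsistent half, the second patch has a different value.
            pairs.append((first, random.choice(negatives)))
            labels.append(0)
    return pairs, labels
```

Stage 2 (next section) reuses the same pairing idea, but labels each pair by whether its two patches come from the same image.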
12 | 13 | ## Second Stage 14 | The second stage attaches another MLP on top of the EXIF attribute predictions and trains it to predict whether the two patches in a pair come from the same image (binary classification). 15 | 16 | The rest of the network weights are frozen, and only this MLP is trained. 17 | 18 | A training batch is constructed by ensuring that the first half of the batch is consistent (pairs of patches that come from the same image), and the second half of the batch is inconsistent (pairs of patches that come from different images). 19 | 20 | ## Implementation Notes 21 | - The EXIF attributes to predict are chosen dynamically based on the dataset, e.g. the top 80 EXIF attributes with the fewest missing values. 22 | - If an attribute is missing for either image, it is immediately assigned a target of 0. 23 | - Binary cross-entropy is used as the loss function. 24 | - The incorporation of post-processing consistency attributes is yet to be implemented. -------------------------------------------------------------------------------- /src/attacks/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation Results 2 | Localization metrics are reported. 3 | 4 | Only the spliced images in the datasets were used for evaluation. For the MCC metric, which requires binary decision maps instead of score maps, a default threshold of 0.5 was chosen (most papers instead search for the optimal threshold and report the corresponding highest metric score). The F1 score was computed by finding the optimal threshold. 5 | 6 | Types of attack: 7 | - **AdvMean**: sets all target features to be the mean feature of all authentic patches. 8 | - **AdvSample**: samples uniformly from the set of features of authentic patches to be the target features for non-authentic patches. 9 | - **JPEG**: JPEG compression. 10 | 11 | A step size of 10000 was used, with 50 iterations.
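As a rough sketch (not the actual `PatchLOTS` implementation in `src/attacks/lots.py`), the AdvMean / AdvSample target selection and a LOTS-style update could be combined as follows; `extract_feats` (a callable mapping the image to the features of its spliced patches) and the tensor shapes are assumptions:

```python
import torch


def choose_target_features(authentic_feats, n_spliced, method="mean"):
    """Pick target features for the spliced patches (sketch of AdvMean / AdvSample).

    authentic_feats: [N, D] features extracted from authentic patches.
    """
    if method == "mean":  # AdvMean: every spliced patch targets the mean authentic feature
        return authentic_feats.mean(dim=0, keepdim=True).expand(n_spliced, -1)
    if method == "sample":  # AdvSample: each spliced patch targets a randomly drawn authentic feature
        idx = torch.randint(authentic_feats.size(0), (n_spliced,))
        return authentic_feats[idx]
    raise ValueError(f"unknown method: {method}")


def lots_attack(img, extract_feats, target, step_size=10000.0, n_iter=50):
    """LOTS-style loop: nudge `img` so its spliced-patch features approach `target`."""
    adv = img.detach().clone()
    for _ in range(n_iter):
        adv.requires_grad_(True)
        loss = 0.5 * (extract_feats(adv) - target).pow(2).sum()
        (grad,) = torch.autograd.grad(loss, adv)
        # Step against the gradient, scaled by its largest magnitude (as in LOTS).
        adv = (adv - step_size * grad / grad.abs().max().clamp_min(1e-12)).detach()
    return adv
```

With the values used in `configs/evaluate/adv.yaml` (`adv_step_size: 10000`, `adv_n_iter: 50`), this loop matches the step size and iteration count quoted above.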
12 | 13 | | Dataset | F1 ↑ | MCC ↑ | mAP ↑ | AUC ↑ | cIoU ↑ | 14 | | ------------------ | ---------- | ---------- | ---------- | ---------- | ---------- | 15 | | Columbia | 0.8703 | 0.6971 | 0.8958 | 0.9697 | 0.8490 | 16 | | AdvMean-Columbia | 0.7014 | **0.0004** | 0.6984 | 0.8773 | 0.7194 | 17 | | JPEG-Columbia | 0.6397 | 0.2417 | 0.6084 | 0.8476 | 0.6528 | 18 | | AdvSample-Columbia | **0.5067** | 0.0081 | **0.3832** | **0.7213** | **0.5363** | 19 | | | 20 | | DSO-1 | 0.9473 | 0.3650 | 0.9652 | 0.8439 | 0.5038 | 21 | | AdvMean-DSO-1 | 0.9263 | **0.0221** | 0.9313 | 0.7303 | 0.5263 | 22 | | JPEG-DSO-1 | **0.9253** | 0.1209 | 0.9195 | **0.6774** | 0.5124 | 23 | | AdvSample-DSO-1 | 0.9281 | 0.0541 | **0.9129** | 0.6877 | **0.5041** | 24 | 25 | # Qualitative Results 26 | 27 | ## Columbia Dataset 28 | 29 | ![](/assets/lots_examples/columbia_1.png) 30 | 31 | ![](/assets/lots_examples/columbia_2.png) 32 | 33 | ## DSO-1 Dataset 34 | 35 | ![](/assets/lots_examples/dso_1.png) 36 | 37 | ![](/assets/lots_examples/dso_2.png) 38 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import wandb 5 | from src.evaluation import Evaluator 6 | from src.utils import ConfigManager, load_yaml 7 | 8 | 9 | def main(config, args): 10 | # Initialize logger 11 | if args.wandb: 12 | wandb.init(project="exif-sc-attack", config=config, name=config["name"]) 13 | logger = wandb 14 | else: 15 | logger = None 16 | 17 | config_manager = ConfigManager(config) 18 | 19 | model = config_manager.init_object("model", weight_file=args.weights_path) 20 | dataset = config_manager.init_object("dataset") 21 | attacker = config_manager.init_object("attacker") 22 | 23 | evaluator = Evaluator( 24 | model, 25 | dataset, 26 | attacker, 27 | vis_dir=args.vis_dir, 28 | logger=logger, 29 | ) 30 | 31 | # Run evaluation 32 | results = evaluator(tuple(config["resize"])) 33 | 34 | # Save results 35 | print(results) 36 | with open(args.results_path, "w") as f: 37 | json.dump(results, f) 38 | 39 | # Log results 40 | if args.wandb: 41 | # Flatten nested dict 42 | log_results = {} 43 | for type, r in results.items(): 44 | for metric, value in r.items(): 45 | log_results[f"{type}/{metric}"] = value 46 | 47 | wandb.log(log_results) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument( 53 | "--config", 54 | help="path to the config file", 55 | default="configs/evaluate/config.yaml", 56 | ) 57 | parser.add_argument( 58 | "--weights_path", 59 | help="path to the weights / checkpoint to load", 60 | default="artifacts/exif_sc.npy", 61 | ) 62 | parser.add_argument( 63 | "--results_path", 64 | help="path to store evaluation results as JSON file", 65 | default="results.json", 66 | ) 67 | parser.add_argument( 68 | "--vis_dir", 69 | help="directory to save visualization results", 70 | ) 71 | parser.add_argument( 72 | "--wandb", 73 | action="store_true", 74 | help="whether to log to Weights & Biases", 75 | ) 76 | args = parser.parse_args() 77 | 78 | # Load config file 79 | config = load_yaml(args.config) 80 | 81 | main(config, args) 82 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import numpy as np 4 | from sklearn.metrics import ( 5 | average_precision_score, 
6 | precision_recall_curve, 7 | f1_score, 8 | matthews_corrcoef, 9 | roc_auc_score, 10 | ) 11 | 12 | 13 | class LocalizationMetric: 14 | def __init__( 15 | self, metric: Callable[[np.ndarray, np.ndarray], float], thresh=False 16 | ) -> None: 17 | self.metric = metric 18 | # Whether the metric takes in binary decision maps 19 | self.thresh = thresh 20 | 21 | self.scores = [] 22 | 23 | def update(self, label_map: np.ndarray, score_map: np.ndarray) -> None: 24 | # FIXME Search for the optimal threshold? 25 | if self.thresh: 26 | pred_map = (score_map > 0.5).astype(int) 27 | else: 28 | pred_map = score_map 29 | 30 | score = self.metric(label_map.flatten(), pred_map.flatten()) 31 | 32 | # Consider inverted map 33 | inverted_map = 1 - score_map 34 | if self.thresh: 35 | pred_map = (inverted_map > 0.5).astype(int) 36 | else: 37 | pred_map = inverted_map 38 | 39 | inverted_score = self.metric(label_map.flatten(), pred_map.flatten()) 40 | 41 | # Take the better score 42 | self.scores.append(max(score, inverted_score)) 43 | 44 | def compute(self) -> float: 45 | return sum(self.scores) / len(self.scores) 46 | 47 | 48 | class mAP_Metric(LocalizationMetric): 49 | def __init__(self): 50 | super().__init__(average_precision_score) 51 | 52 | 53 | class F1_Metric(LocalizationMetric): 54 | def __init__(self): 55 | # Compute optimal f1 score 56 | def optimal_f1(y_true, y_score): 57 | precision, recall, thresholds = precision_recall_curve(y_true, y_score) 58 | f1_scores = 2 * recall * precision / (recall + precision) 59 | # Account for nan values 60 | f1_scores[np.isnan(f1_scores)] = 0 61 | 62 | return f1_scores.max() 63 | 64 | super().__init__(optimal_f1) 65 | 66 | 67 | class MCC_Metric(LocalizationMetric): 68 | def __init__(self): 69 | super().__init__(matthews_corrcoef, thresh=True) 70 | 71 | 72 | class AUC_Metric(LocalizationMetric): 73 | def __init__(self): 74 | super().__init__(roc_auc_score) 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | test/ 3 | NOTES.md 4 | wandb/ 5 | .vscode/ 6 | data/downloaded/ 7 | artifacts/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /src/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from pathlib import Path 3 | from typing import Dict, Union 4 | from urllib.request import urlretrieve 5 | 6 | import gdown 7 | from tqdm import tqdm 8 | 9 | 10 | class TqdmUpTo(tqdm): 11 | """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 12 | 13 | def update_to(self, blocks=1, bsize=1, tsize=None): 14 | """ 15 | Parameters 16 | ---------- 17 | blocks: int, optional 18 | Number of blocks transferred so far [default: 1]. 19 | bsize: int, optional 20 | Size of each block (in tqdm units) [default: 1]. 21 | tsize: int, optional 22 | Total size (in tqdm units). If [default: None] remains unchanged. 
23 | """ 24 | if tsize is not None: 25 | self.total = tsize # pylint: disable=attribute-defined-outside-init 26 | self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 27 | 28 | 29 | def compute_sha256(filename: Union[Path, str]) -> str: 30 | """Return SHA256 checksum of a file.""" 31 | with open(filename, "rb") as f: 32 | return hashlib.sha256(f.read()).hexdigest() 33 | 34 | 35 | def download_url(url: str, filename: Path) -> None: 36 | """Download a file from url to filename, with a progress bar.""" 37 | with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 38 | urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 39 | 40 | 41 | def check_and_download_url( 42 | dl_dirname: Path, sha256: str, filename: str, url: str, gdrive: bool = False 43 | ) -> None: 44 | # If already exists, don't have to download 45 | filename = dl_dirname / filename 46 | if filename.exists(): 47 | return filename 48 | 49 | # Download file 50 | print(f"Downloading raw dataset from {url} to {filename}...") 51 | if gdrive: 52 | gdown.download(url, str(filename), quiet=False) 53 | else: 54 | download_url(url, filename) 55 | 56 | # Compute and check SHA256 57 | print("Computing SHA-256...") 58 | 59 | sha256_check = compute_sha256(filename) 60 | if sha256_check != sha256: 61 | raise ValueError( 62 | f"Downloaded data file SHA-256 ({sha256_check}) does not match that listed in metadata document." 63 | ) 64 | 65 | 66 | def download_raw_dataset( 67 | metadata: Dict, dl_dirname: Path, gdrive: bool = False 68 | ) -> None: 69 | dl_dirname.mkdir(parents=True, exist_ok=True) 70 | 71 | # Download multiple files 72 | if isinstance(metadata["filename"], list): 73 | for sha256, filename, url in zip( 74 | metadata["sha256"], metadata["filename"], metadata["url"] 75 | ): 76 | check_and_download_url(dl_dirname, sha256, filename, url, gdrive) 77 | 78 | # Download single file 79 | else: 80 | check_and_download_url( 81 | dl_dirname, 82 | metadata["sha256"], 83 | metadata["filename"], 84 | metadata["url"], 85 | gdrive, 86 | ) 87 | -------------------------------------------------------------------------------- /src/datasets/in_the_wild.py: -------------------------------------------------------------------------------- 1 | """In-the-Wild Image Splice Dataset 2 | 3 | - https://minyoungg.github.io/selfconsistency/ 4 | - M. Huh, A. Liu, A. Owens, A. A. Efros, Fighting Fake News: Image Splice Detection via Learned Self-Consistency In ECCV, 2018 5 | """ 6 | import tarfile 7 | from pathlib import Path 8 | from typing import Any, Dict 9 | 10 | import cv2 11 | import numpy as np 12 | import toml 13 | import torch 14 | from src.datasets.utils import download_raw_dataset 15 | from torch.utils.data import Dataset 16 | 17 | METADATA_FILENAME = Path("data/raw/in_the_wild/metadata.toml") 18 | DL_DATA_DIRNAME = Path("data/downloaded/in_the_wild") 19 | PROCESSED_DATA_DIRNAME = DL_DATA_DIRNAME / "label_in_wild" 20 | 21 | 22 | class InTheWildDataset(Dataset): 23 | def __init__(self, root_dir=PROCESSED_DATA_DIRNAME) -> None: 24 | self._prepare_data() 25 | 26 | root_dir = Path(root_dir) 27 | 28 | # Get list of all image paths 29 | img_dir = root_dir / "images" 30 | self.img_paths = list(img_dir.glob("*.jpg")) 31 | 32 | assert ( 33 | len(self.img_paths) == 201 34 | ), "Incorrect expected number of images in dataset!" 
35 | 36 | def _prepare_data(self) -> None: 37 | if not PROCESSED_DATA_DIRNAME.exists(): 38 | metadata = toml.load(METADATA_FILENAME) 39 | # Download dataset 40 | download_raw_dataset(metadata, DL_DATA_DIRNAME) 41 | 42 | # Process downloaded dataset 43 | print("Unzipping In The Wild...") 44 | tar = tarfile.open(DL_DATA_DIRNAME / metadata["filename"], "r:gz") 45 | tar.extractall(DL_DATA_DIRNAME) 46 | tar.close() 47 | 48 | def __getitem__(self, idx) -> Dict[str, Any]: 49 | """ 50 | Returns 51 | ------- 52 | Dict[str, Any] 53 | img : torch.ByteTensor 54 | [C, H, W], range [0, 255] 55 | label : int 56 | One of {0, 1}. No meaningful labels for this dataset (all manipulated) 57 | map : np.ndarray (uint8) 58 | [H, W], values one of {0, 1} 59 | """ 60 | img_path = self.img_paths[idx] 61 | 62 | # Get image 63 | img = cv2.imread(str(img_path))[:, :, [2, 1, 0]] # [H, W, C] 64 | assert img.dtype == np.uint8, "Image should be of type int!" 65 | assert ( 66 | img.min() >= 0 and img.max() <= 255 67 | ), "Image should be bounded between [0, 255]!" 68 | 69 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 70 | 71 | # Get localization map 72 | img_name = img_path.stem 73 | map_path = img_path.parent.parent / "masks" / f"{img_name}.png" 74 | map = cv2.imread(str(map_path), cv2.IMREAD_GRAYSCALE) 75 | assert map.dtype == np.uint8, "Ground-truth should be of type int!" 76 | assert ( 77 | map.min() >= 0 and map.max() <= 255 78 | ), "Ground-truth should be bounded between [0, 255]!" 79 | 80 | map[map > 0] = 1 81 | 82 | return {"img": img, "label": 1, "map": map} 83 | 84 | def __len__(self): 85 | return len(self.img_paths) 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pytorch_lightning as pl 4 | 5 | from src.models import EXIF_Net 6 | from src.trainers import EXIF_Trainer1, EXIF_Trainer2 7 | from src.utils import ConfigManager, load_yaml 8 | 9 | pl.seed_everything(42, workers=True) 10 | 11 | 12 | def main(config, args): 13 | # Initialize logger 14 | if args.wandb: 15 | logger = pl.loggers.WandbLogger( 16 | name=config["name"], 17 | project="exif-sc-train", 18 | ) 19 | logger.log_hyperparams(config) 20 | else: 21 | logger = None 22 | 23 | config_manager = ConfigManager(config) 24 | 25 | net = EXIF_Net(n_attrs=config["datamodule_args"]["n_exif_attr"]) 26 | 27 | # Stage 1 Training ######################################################### 28 | dm = config_manager.init_object("datamodule", label="attr") 29 | dm.prepare_data() 30 | dm.setup() 31 | exif_trainer = EXIF_Trainer1(net, dm, config) 32 | 33 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint( 34 | dirpath=f"{args.checkpoints_dir}/{config['name']}/stage_1", 35 | monitor="train/loss_1", 36 | mode="min", 37 | save_weights_only=True, 38 | ) 39 | callbacks = [model_checkpoint_callback] 40 | 41 | # Fit model 42 | trainer = pl.Trainer( 43 | callbacks=callbacks, 44 | logger=logger, 45 | gpus=[args.gpu], 46 | max_epochs=config["n_epochs_1"], 47 | deterministic=True, 48 | benchmark=True, 49 | # fast_dev_run=True, 50 | ) 51 | trainer.fit(exif_trainer, datamodule=dm) 52 | 53 | # Stage 2 Training ######################################################### 54 | dm = config_manager.init_object("datamodule", label="img") 55 | dm.prepare_data() 56 | dm.setup() 57 | exif_trainer = EXIF_Trainer2(net, dm, config) 58 | 59 | model_checkpoint_callback = pl.callbacks.ModelCheckpoint( 60 | 
dirpath=f"{args.checkpoints_dir}/{config['name']}/stage_2", 61 | monitor="train/loss_2", 62 | mode="min", 63 | save_weights_only=True, 64 | ) 65 | callbacks = [model_checkpoint_callback] 66 | 67 | # Fit model 68 | trainer = pl.Trainer( 69 | callbacks=callbacks, 70 | logger=logger, 71 | gpus=[args.gpu], 72 | max_epochs=config["n_epochs_2"], 73 | deterministic=True, 74 | benchmark=True, 75 | # fast_dev_run=True, 76 | ) 77 | trainer.fit(exif_trainer, datamodule=dm) 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | 83 | parser.add_argument( 84 | "--config", 85 | help="path to the config file", 86 | default="configs/train/exif_sc.yaml", 87 | ) 88 | parser.add_argument( 89 | "--checkpoints_dir", 90 | help="directory to save checkpoint weights", 91 | default="checkpoints", 92 | ) 93 | parser.add_argument("--gpu", help="which gpu id to use", type=int, default=0) 94 | parser.add_argument( 95 | "--wandb", 96 | action="store_true", 97 | help="whether to log to Weights & Biases", 98 | ) 99 | args = parser.parse_args() 100 | 101 | # Load config file 102 | config = load_yaml(args.config) 103 | 104 | main(config, args) 105 | -------------------------------------------------------------------------------- /src/datasets/scene_completion.py: -------------------------------------------------------------------------------- 1 | """Scene Completion Using Millions of Photographs Dataset 2 | 3 | - http://graphics.cs.cmu.edu/projects/scene-completion/ 4 | - James Hays, Alexei A. Efros. Scene Completion Using Millions of Photographs. ACM Transactions on Graphics (SIGGRAPH 2007). August 2007, vol. 26, No. 3. 5 | """ 6 | import zipfile 7 | from pathlib import Path 8 | from typing import Any, Dict 9 | 10 | import cv2 11 | import toml 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import Dataset 15 | from src.datasets.utils import download_raw_dataset 16 | 17 | METADATA_FILENAME = Path("data/raw/scene_completion/metadata.toml") 18 | DL_DATA_DIRNAME = Path("data/downloaded/scene_completion") 19 | PROCESSED_DATA_DIRNAME = DL_DATA_DIRNAME / "processed" 20 | 21 | 22 | class SceneCompletionDataset(Dataset): 23 | def __init__(self, root_dir=PROCESSED_DATA_DIRNAME) -> None: 24 | self._prepare_data() 25 | 26 | # Get list of all image names 27 | img_dir = Path(root_dir) 28 | self.img_paths = [p for p in img_dir.iterdir() if p.stem[-4:] != "mask"] 29 | 30 | assert ( 31 | len(self.img_paths) == 51 32 | ), "Incorrect expected number of images in dataset!" 33 | 34 | def _prepare_data(self) -> None: 35 | if not PROCESSED_DATA_DIRNAME.exists(): 36 | metadata = toml.load(METADATA_FILENAME) 37 | # Download dataset 38 | download_raw_dataset(metadata, DL_DATA_DIRNAME) 39 | 40 | # Process downloaded dataset 41 | print("Unzipping Scene Completion...") 42 | zip = zipfile.ZipFile(DL_DATA_DIRNAME / metadata["filename"]) 43 | PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 44 | zip.extractall(PROCESSED_DATA_DIRNAME) 45 | zip.close() 46 | 47 | def __getitem__(self, idx) -> Dict[str, Any]: 48 | """ 49 | Returns 50 | ------- 51 | Dict[str, Any] 52 | img : torch.ByteTensor 53 | [C, H, W], range [0, 255] 54 | label : int 55 | One of {0, 1}. No meaningful labels for this dataset (all manipulated) 56 | map : np.ndarray (uint8) 57 | [H, W], values one of {0, 1} 58 | """ 59 | img_path = self.img_paths[idx] 60 | 61 | # Get image 62 | img = cv2.imread(str(img_path))[:, :, [2, 1, 0]] # [H, W, C] 63 | assert img.dtype == np.uint8, "Image should be of type int!" 
64 | assert ( 65 | img.min() >= 0 and img.max() <= 255 66 | ), "Image should be bounded between [0, 255]!" 67 | 68 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 69 | 70 | # Get localization map 71 | img_name = img_path.stem 72 | img_ext = img_path.suffix 73 | 74 | map_path = img_path.parent / f"{img_name}_mask{img_ext}" 75 | # HACK 76 | if not map_path.is_file(): 77 | # Correct extension 78 | map_path = img_path.parent / f"{img_name}_mask.jpg" 79 | 80 | map = cv2.imread(str(map_path), cv2.IMREAD_GRAYSCALE) 81 | assert map.dtype == np.uint8, "Ground-truth should be of type int!" 82 | assert ( 83 | map.min() >= 0 and map.max() <= 255 84 | ), "Ground-truth should be bounded between [0, 255]!" 85 | 86 | map[map > 0] = 1 87 | 88 | return {"img": img, "label": 1, "map": map} 89 | 90 | def __len__(self): 91 | return len(self.img_paths) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fake-detection-lab 2 | Media Forensics / Fake Image Detection experiments in PyTorch. 3 | 4 | Project report can be found [here](report.pdf). 5 | 6 | # Installation 7 | We use `conda` for managing Python and CUDA versions, and `pip-tools` for managing Python package dependencies. 8 | 1. Specify the appropriate `cudatoolkit` and `cudnn` versions to install on your machine in the `environment.yml` file. 9 | 2. To create the `conda` environment, run: `conda env create` 10 | 3. Activate the environment: `conda activate fake-detection-lab` 11 | 4. Install all necessary packages: `pip-sync requirements/prod.txt` 12 | 13 | # Model Artifacts 14 | All model artifacts can be accessed and downloaded [here](https://drive.google.com/drive/folders/1Qm1WUUithm0dE1qnJXGfoCbMG37jq3mW?usp=sharing). 15 | - `exif_sc.npy`: EXIF-SC model weights 16 | 17 | # Project Structure 18 | ``` 19 | ├── artifacts 20 | │   └── exif_sc.npy <-- Store model weights here 21 | ├── assets 22 | ├── configs <-- Configuration files for scripts 23 | ├── data 24 | │   ├── downloaded <-- To store downloaded data 25 | │   └── raw <-- Dataset metadata 26 | ├── notebooks 27 | ├── requirements 28 | ├── src 29 | │   ├── attacks <-- Implementation of adversarial attacks 30 | │   ├── datasets <-- Data loading classes 31 | │   ├── evaluation <-- Evaluation classes and utilities 32 | │   ├── models <-- Implementation of detection models 33 | │   ├── trainers <-- Classes for model training 34 | │   ├── structures.py 35 | │   └── utils.py 36 | ├── evaluate.py <-- Main entry point for evaluation 37 | ├── non_adv_evaluate.py <-- Main entry point for evaluation 38 | ├── train.py <-- Main entry point for training 39 | └── ... 40 | ``` 41 | 42 | # Usage 43 | 44 | ## Training 45 | ``` 46 | python train.py \ 47 | --config configs/train/exif_sc.yaml \ 48 | --checkpoints_dir checkpoints \ 49 | --gpu 0 50 | ``` 51 | Runs training on a dataset, based on the settings specified in the configuration file. Weights are saved as a torch `.ckpt` file in the specified directory. 52 | 53 | More [info](src/trainers/README.md). 54 | 55 | ## Evaluation 56 | More info [here](src/models/exif_sc/README.md) and [here](src/attacks/README.md) 57 | ### Without Adversarial Attack 58 | ``` 59 | python non_adv_evaluate.py \ 60 | --config configs/evaluate/non_adv.yaml \ 61 | --weights_path path/to/weights.{npy, ckpt} 62 | ``` 63 | Runs the evaluation on a dataset, based on the settings specified in the configuration file. 
64 | 65 | ### With Adversarial Attack 66 | ``` 67 | python evaluate.py \ 68 | --config configs/evaluate/adv.yaml \ 69 | --weights_path path/to/weights.{npy, ckpt} 70 | ``` 71 | Runs the evaluation on a clean dataset, and also on the dataset after it has been adversarially perturbed, based on the settings specified in the configuration file. 72 | 73 | # Datasets 74 | All metadata for the datasets used can be found [here](data/raw). 75 | 76 | # Resources 77 | ### Model Conversion 78 | - Microsoft's [MMdnn](https://github.com/microsoft/MMdnn) 79 | - [ONNX](https://github.com/onnx/onnx) 80 | 81 | ### Survey Papers 82 | - Media Forensics and DeepFakes: an overview ([Luisa Verdoliva, 2020](https://arxiv.org/abs/2001.06564)) 83 | - A Survey of Machine Learning Techniques in Adversarial Image Forensics ([Nowroozia et al., 2020](https://arxiv.org/abs/2010.09680)) 84 | 85 | ### Fake Detectors 86 | - Fighting Fake News: Image Splice Detection via Learned Self-Consistency ([Huh et al., ECCV 2018](https://minyoungg.github.io/selfconsistency/)) 87 | 88 | ### Adversarial Machine Learning 89 | - Adversarial Attack on Deep Learning-Based Splice Localization ([Rozsa et al., 2020](https://arxiv.org/abs/2004.08443)) 90 | -------------------------------------------------------------------------------- /src/trainers/exif_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pytorch_lightning import LightningDataModule, LightningModule 7 | from torchmetrics import Accuracy 8 | 9 | 10 | class EXIF_Trainer1(LightningModule): 11 | def __init__( 12 | self, net: nn.Module, datamodule: LightningDataModule, config: Dict[str, Any] 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.net = net 17 | self.dm = datamodule 18 | self.config = config 19 | 20 | # Initialize metrics 21 | self.exif_attrs = self.dm.exif_attrs 22 | self.metrics = nn.ModuleList([Accuracy() for _ in range(len(self.exif_attrs))]) 23 | 24 | def configure_optimizers(self): 25 | params = [p for p in self.parameters() if p.requires_grad] 26 | optimizer = torch.optim.Adam(params, lr=self.config["learning_rate_1"]) 27 | 28 | return optimizer 29 | 30 | def training_step(self, batch, batch_idx): 31 | # [2, B, C, H, W], [B, n_exif_attr] 32 | imgs, labels = batch 33 | 34 | _, B, C, H, W = imgs.shape 35 | imgs = imgs.view(-1, C, H, W) # [2*B, C, H, W] 36 | 37 | feats = self.net(imgs) # [B*2, 4096] 38 | feats = torch.cat([feats[:B], feats[B:]], dim=-1) # [B, 8192] 39 | 40 | logits = self.net.exif_fc(feats) # [B, n_exif_attr] 41 | 42 | labels_float = labels.float() 43 | loss = F.binary_cross_entropy_with_logits(logits, labels_float) 44 | 45 | # Log metrics 46 | self.log("train/loss_1", loss, prog_bar=True) 47 | 48 | # Compute accuracy for each attr 49 | with torch.no_grad(): 50 | probs = torch.sigmoid(logits) 51 | 52 | metrics_dict = {} 53 | for i, (attr, m) in enumerate(zip(self.exif_attrs, self.metrics)): 54 | m(probs[:, i], labels[:, i]) 55 | metrics_dict[f"val/{attr}_acc"] = m 56 | 57 | self.log_dict(metrics_dict, on_step=False, on_epoch=True) 58 | 59 | return loss 60 | 61 | 62 | class EXIF_Trainer2(LightningModule): 63 | def __init__( 64 | self, net: nn.Module, datamodule: LightningDataModule, config: Dict[str, Any] 65 | ) -> None: 66 | super().__init__() 67 | 68 | # Freeze entire network except for final classification MLP 69 | for name, params in net.named_parameters(): 70 | if "classifier_fc" 
not in name: 71 | params.requires_grad = False 72 | self.net = net 73 | 74 | self.dm = datamodule 75 | self.config = config 76 | 77 | # Initialize metrics 78 | self.acc = Accuracy() 79 | 80 | def configure_optimizers(self): 81 | params = [p for p in self.parameters() if p.requires_grad] 82 | optimizer = torch.optim.Adam(params, lr=self.config["learning_rate_2"]) 83 | 84 | return optimizer 85 | 86 | def training_step(self, batch, batch_idx): 87 | # [2, B, C, H, W], [B,] 88 | imgs, labels = batch 89 | 90 | _, B, C, H, W = imgs.shape 91 | imgs = imgs.view(-1, C, H, W) # [2*B, C, H, W] 92 | 93 | feats = self.net(imgs) # [B*2, 4096] 94 | feats = torch.cat([feats[:B], feats[B:]], dim=-1) # [B, 8192] 95 | 96 | logits = self.net.exif_fc(feats) # [B, n_exif_attr] 97 | binary_logit = self.net.classifier_fc(logits).view(-1) # [B,] 98 | 99 | labels_float = labels.float() 100 | loss = F.binary_cross_entropy_with_logits(binary_logit, labels_float) 101 | 102 | # Log metrics 103 | self.log("train/loss_2", loss, prog_bar=True) 104 | 105 | # Compute accuracy for image pair prediction 106 | with torch.no_grad(): 107 | prob = torch.sigmoid(binary_logit) 108 | 109 | self.acc(prob, labels) 110 | self.log("val/img_pred_acc", self.acc, on_step=False, on_epoch=True) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /src/datasets/realistic_tampering.py: -------------------------------------------------------------------------------- 1 | """Realistic Tampering Dataset 2 | 3 | - http://pkorus.pl/downloads/dataset-realistic-tampering 4 | - P. Korus & J. Huang, Multi-scale Analysis Strategies in PRNU-based Tampering Localization, IEEE Trans. Information Forensics & Security, 2017 5 | - P. Korus & J. Huang, Evaluation of Random Field Models in Multi-modal Unsupervised Tampering Localization, Proc. of IEEE Int. Workshop on Inf. Forensics and Security, 2016 6 | """ 7 | import zipfile 8 | from pathlib import Path 9 | from typing import Any, Dict 10 | 11 | import cv2 12 | import numpy as np 13 | import toml 14 | import torch 15 | from src.datasets.utils import download_raw_dataset 16 | from torch.utils.data import Dataset 17 | 18 | METADATA_FILENAME = Path("data/raw/realistic_tampering/metadata.toml") 19 | DL_DATA_DIRNAME = Path("data/downloaded/realistic_tampering") 20 | PROCESSED_DATA_DIRNAME = DL_DATA_DIRNAME / "data-images" 21 | 22 | 23 | class RealisticTamperingDataset(Dataset): 24 | def __init__(self, root_dir=PROCESSED_DATA_DIRNAME) -> None: 25 | self._prepare_data() 26 | 27 | self.to_label = {"pristine": 0, "tampered-realistic": 1} 28 | root_dir = Path(root_dir) 29 | 30 | # Get list of all image paths 31 | self.img_paths = [] 32 | 33 | folders = ["Canon_60D", "Nikon_D90", "Nikon_D7000", "Sony_A57"] 34 | sub_folders = ["pristine", "tampered-realistic"] 35 | 36 | for folder in folders: 37 | for sub_folder in sub_folders: 38 | img_dir = root_dir / folder / sub_folder 39 | # Grab all .TIF images 40 | self.img_paths.extend(img_dir.glob("*.TIF")) 41 | 42 | assert ( 43 | len(self.img_paths) == 440 44 | ), "Incorrect expected number of images in dataset!" 
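# Note: 440 should correspond to the full Korus release: 4 cameras x 2 classes
# (pristine / tampered-realistic) x 55 images each; this assert would need
# adjusting if only a subset of the archive is downloaded.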
45 | 46 | def _prepare_data(self) -> None: 47 | if not PROCESSED_DATA_DIRNAME.exists(): 48 | metadata = toml.load(METADATA_FILENAME) 49 | # Download dataset 50 | download_raw_dataset(metadata, DL_DATA_DIRNAME, gdrive=True) 51 | 52 | # Process downloaded dataset 53 | print("Unzipping Realistic Tampering...") 54 | zip = zipfile.ZipFile(DL_DATA_DIRNAME / metadata["filename"]) 55 | zip.extractall(DL_DATA_DIRNAME) 56 | zip.close() 57 | 58 | def __getitem__(self, idx) -> Dict[str, Any]: 59 | """ 60 | Returns 61 | ------- 62 | Dict[str, Any] 63 | img : torch.ByteTensor 64 | [C, H, W], range [0, 255] 65 | label : int 66 | One of {0, 1} 67 | map : np.ndarray (uint8) 68 | [H, W], values one of {0, 1} 69 | """ 70 | img_path = self.img_paths[idx] 71 | 72 | # Get image 73 | img = cv2.imread(str(img_path))[:, :, [2, 1, 0]] # [H, W, C] 74 | assert img.dtype == np.uint8, "Image should be of type int!" 75 | assert ( 76 | img.min() >= 0 and img.max() <= 255 77 | ), "Image should be bounded between [0, 255]!" 78 | 79 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 80 | 81 | # Get label 82 | label = self.to_label[img_path.parent.name] 83 | 84 | # Get localization map 85 | if label: 86 | img_name = img_path.stem 87 | map_path = img_path.parent.parent / "ground-truth" / f"{img_name}.PNG" 88 | map = cv2.imread(str(map_path), cv2.IMREAD_GRAYSCALE) # [H, W] 89 | assert map.dtype == np.uint8, "Ground-truth should be of type int!" 90 | assert ( 91 | map.min() >= 0 and map.max() <= 255 92 | ), "Ground-truth should be bounded between [0, 255]!" 93 | 94 | # Turn all greys into black 95 | map[map > 0] = 1 96 | 97 | # If clean image 98 | else: 99 | _, height, width = img.shape 100 | map = np.zeros((height, width), dtype=np.uint8) 101 | 102 | return {"img": img, "label": label, "map": map} 103 | 104 | def __len__(self): 105 | return len(self.img_paths) 106 | -------------------------------------------------------------------------------- /src/datasets/dso_1.py: -------------------------------------------------------------------------------- 1 | """DSO-1 Dataset 2 | 3 | - https://recodbr.wordpress.com/code-n-data/#dso1_dsi1 4 | - T. J. d. Carvalho, C. Riess, E. Angelopoulou, H. Pedrini and A. d. R. Rocha, “Exposing Digital Image Forgeries by Illumination Color Classification,” in IEEE Transactions on Information Forensics and Security, vol. 8, no. 7, pp. 1182-1194, July 2013. 
doi: doi: 10.1109/TIFS.2013.2265677 5 | """ 6 | import zipfile 7 | from pathlib import Path 8 | from typing import Any, Dict 9 | 10 | import cv2 11 | import numpy as np 12 | import toml 13 | import torch 14 | from src.datasets.utils import download_raw_dataset 15 | from torch.utils.data import Dataset 16 | 17 | METADATA_FILENAME = Path("data/raw/dso_1/metadata.toml") 18 | DL_DATA_DIRNAME = Path("data/downloaded/dso_1") 19 | PROCESSED_DATA_DIRNAME = DL_DATA_DIRNAME / "tifs-database" 20 | 21 | 22 | class DSO_1_Dataset(Dataset): 23 | def __init__(self, root_dir=PROCESSED_DATA_DIRNAME, spliced_only=False) -> None: 24 | self._prepare_data() 25 | 26 | self.to_label = {"normal": 0, "splicing": 1} 27 | self.root_dir = Path(root_dir) 28 | 29 | # Get list of all image paths 30 | img_dir = self.root_dir / "DSO-1" 31 | self.img_paths = list(img_dir.glob("*.png")) 32 | 33 | # Filter out authentic images 34 | if spliced_only: 35 | self.img_paths = [ 36 | p for p in self.img_paths if p.stem.split("-")[0] == "splicing" 37 | ] 38 | 39 | dataset_len = 100 if spliced_only else 200 40 | assert ( 41 | len(self.img_paths) == dataset_len 42 | ), "Incorrect expected number of images in dataset!" 43 | 44 | def _prepare_data(self) -> None: 45 | if not PROCESSED_DATA_DIRNAME.exists(): 46 | metadata = toml.load(METADATA_FILENAME) 47 | # Download dataset 48 | download_raw_dataset(metadata, DL_DATA_DIRNAME) 49 | 50 | # Process downloaded dataset 51 | print("Unzipping DSO-1...") 52 | zip = zipfile.ZipFile(DL_DATA_DIRNAME / metadata["filename"]) 53 | zip.extractall(DL_DATA_DIRNAME) 54 | zip.close() 55 | 56 | def __getitem__(self, idx) -> Dict[str, Any]: 57 | """ 58 | Returns 59 | ------- 60 | Dict[str, Any] 61 | img : torch.ByteTensor 62 | [C, H, W], range [0, 255] 63 | label : int 64 | One of {0, 1} 65 | map : np.ndarray (uint8) 66 | [H, W], values one of {0, 1} 67 | """ 68 | img_path = self.img_paths[idx] 69 | 70 | # Get image 71 | img = cv2.imread(str(img_path))[:, :, [2, 1, 0]] # [H, W, C] 72 | assert img.dtype == np.uint8, "Image should be of type int!" 73 | assert ( 74 | img.min() >= 0 and img.max() <= 255 75 | ), "Image should be bounded between [0, 255]!" 76 | 77 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 78 | 79 | # Get label 80 | cat = img_path.stem.split("-")[0] 81 | label = self.to_label[cat] 82 | 83 | # Get spliced map 84 | map_dir = self.root_dir / "DSO-1-Fake-Images-Masks" 85 | _, height, width = img.shape 86 | 87 | if label: 88 | img_name = img_path.name 89 | map_path = map_dir / img_name 90 | map = cv2.imread(str(map_path), cv2.IMREAD_GRAYSCALE) 91 | assert map.dtype == np.uint8, "Ground-truth should be of type int!" 92 | assert ( 93 | map.min() >= 0 and map.max() <= 255 94 | ), "Ground-truth should be bounded between [0, 255]!" 
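# Note: the bilinear resize below can introduce intermediate grey values at
# mask borders; the `map > 0` binarization afterwards counts any non-zero
# pixel as spliced, which slightly dilates the mask compared to a
# nearest-neighbour resize.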
95 | 96 | # Resize map if doesn't match image 97 | if (height, width) != map.shape: 98 | map = cv2.resize(map, (width, height), interpolation=cv2.INTER_LINEAR) 99 | 100 | map[map > 0] = 1 101 | 102 | # If authentic image 103 | else: 104 | map = np.zeros((height, width), dtype=np.uint8) 105 | 106 | return {"img": img, "label": label, "map": map} 107 | 108 | def __len__(self): 109 | return len(self.img_paths) 110 | -------------------------------------------------------------------------------- /notebooks/exif_sc_port.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": 3 14 | }, 15 | "orig_nbformat": 2 16 | }, 17 | "nbformat": 4, 18 | "nbformat_minor": 2, 19 | "cells": [ 20 | { 21 | "source": [ 22 | "# Porting EXIF-SC implementation to PyTorch\n", 23 | "\n", 24 | "Official code repository: https://github.com/minyoungg/selfconsistency\n", 25 | "\n", 26 | "1. Get PyTorch model building code for a TensorFlow-slim ResNet50 model using [MMdnn](https://github.com/Microsoft/MMdnn/blob/master/docs/tf2pytorch.md)\n", 27 | "\n", 28 | "```\n", 29 | "pip install mmdnn\n", 30 | "mmdownload -f tensorflow -n resnet_v2_50\n", 31 | "mmtoir -f tensorflow -n imagenet_resnet_v2_50.ckpt.meta -w imagenet_resnet_v2_50.ckpt --dstNode MMdnn_Output -o converted\n", 32 | "mmtocode -f pytorch -n converted.pb -w converted.npy -d converted_pytorch.py -dw converted_pytorch.npy\n", 33 | "```\n", 34 | "\n", 35 | "2. Download the EXIF-SC model checkpoint from the [official repo](https://github.com/minyoungg/selfconsistency)\n", 36 | "\n", 37 | "3. Examine the variables in the TensorFlow checkpoint `exif_final.ckpt`.\n", 38 | " - Extract all relevant weights, and make any necessary modifications in order to load them into PyTorch layers. \n", 39 | " - Modify the model building code `converted_pytorch.py` in order to load those weights into the PyTorch model." 
40 | ], 41 | "cell_type": "markdown", 42 | "metadata": {} 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import tensorflow as tf\n", 51 | "from tqdm import tqdm" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "ckpt_path = 'ckpt/exif_final/exif_final.ckpt'" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "tf_vars = tf.train.list_variables(ckpt_path)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "for name, shape in tf_vars:\n", 79 | " print(name, shape)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Remove unncessary variables\n", 89 | "# Modify the weights into a format suitable for PyTorch\n", 90 | "weights_dict = {}\n", 91 | "\n", 92 | "for name, _ in tqdm(tf_vars):\n", 93 | "\n", 94 | " name_split = name.split('/')\n", 95 | " weight_type = name_split[-1]\n", 96 | " \n", 97 | " # Exclude unnecessary variables\n", 98 | " if weight_type in ['beta1_power', 'beta2_power', 'Adam', 'Adam_1']:\n", 99 | " continue\n", 100 | "\n", 101 | " weight_name = '/'.join(name_split[:-1])\n", 102 | "\n", 103 | " weights = tf.train.load_variable(ckpt_path, name)\n", 104 | " \n", 105 | " # Transpose CNN weights\n", 106 | " # [H, W, C, F] -> [F, C, H, W]\n", 107 | " if len(weights.shape) == 4:\n", 108 | " weights = np.transpose(weights, (3, 2, 0, 1))\n", 109 | " # Tranpose linear matrices\n", 110 | " if len(weights.shape) == 2:\n", 111 | " weights = np.transpose(weights, (1, 0))\n", 112 | "\n", 113 | " if weight_name not in weights_dict:\n", 114 | " weights_dict[weight_name] = {}\n", 115 | " weights_dict[weight_name][weight_type] = weights" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Save the weights into a separate file\n", 125 | "np.save('ckpt/resnet_50_pt/exif_final.npy', weights_dict)" 126 | ] 127 | } 128 | ] 129 | } -------------------------------------------------------------------------------- /src/datasets/columbia.py: -------------------------------------------------------------------------------- 1 | """Columbia Uncompressed Image Splicing Detection Evaluation Dataset 2 | 3 | - https://www.ee.columbia.edu/ln/dvmm/downloads/authsplcuncmp/ 4 | - Detecting Image Splicing Using Geometry Invariants And Camera Characteristics Consistency, Yu-Feng Hsu, Shih-Fu Chang 5 | """ 6 | import tarfile 7 | from pathlib import Path 8 | from typing import Any, Dict 9 | 10 | import cv2 11 | import numpy as np 12 | import toml 13 | import torch 14 | from src.datasets.utils import download_raw_dataset 15 | from torch.utils.data import Dataset 16 | 17 | METADATA_FILENAME = Path("data/raw/columbia/metadata.toml") 18 | DL_DATA_DIRNAME = Path("data/downloaded/columbia") 19 | PROCESSED_DATA_DIRNAMES = [DL_DATA_DIRNAME / "4cam_auth", DL_DATA_DIRNAME / "4cam_splc"] 20 | 21 | 22 | class ColumbiaDataset(Dataset): 23 | def __init__(self, root_dir=DL_DATA_DIRNAME, spliced_only=False) -> None: 24 | self._prepare_data() 25 | 26 | self.to_label = {"4cam_auth": 0, "4cam_splc": 1} 27 | root_dir = Path(root_dir) 28 | 29 | # Get list of all image paths 30 | self.img_paths = 
[] 31 | 32 | # Grab authentic images 33 | if not spliced_only: 34 | auth_dir = root_dir / "4cam_auth" 35 | auth_paths = list(auth_dir.glob("*.tif")) 36 | assert ( 37 | len(auth_paths) == 183 38 | ), "Incorrect expected number of authentic images in dataset!" 39 | 40 | self.img_paths.extend(auth_paths) 41 | 42 | # Grab spliced images 43 | splc_dir = root_dir / "4cam_splc" 44 | splc_paths = list(splc_dir.glob("*.tif")) 45 | assert ( 46 | len(splc_paths) == 180 47 | ), "Incorrect expected number of spliced images in dataset!" 48 | 49 | self.img_paths.extend(splc_paths) 50 | 51 | def _prepare_data(self) -> None: 52 | if not all(p.exists() for p in PROCESSED_DATA_DIRNAMES): 53 | metadata = toml.load(METADATA_FILENAME) 54 | # Download dataset 55 | download_raw_dataset(metadata, DL_DATA_DIRNAME) 56 | 57 | # Process downloaded dataset 58 | print("Unzipping Columbia...") 59 | for filename in metadata["filename"]: 60 | tar = tarfile.open(DL_DATA_DIRNAME / filename, "r:bz2") 61 | tar.extractall(DL_DATA_DIRNAME) 62 | tar.close() 63 | 64 | def __getitem__(self, idx) -> Dict[str, Any]: 65 | """ 66 | Returns 67 | ------- 68 | Dict[str, Any] 69 | img : torch.ByteTensor 70 | [C, H, W], range [0, 255] 71 | label : int 72 | One of {0, 1} 73 | map : np.ndarray (uint8) 74 | [H, W], values one of {0, 1} 75 | """ 76 | img_path = self.img_paths[idx] 77 | 78 | # Get image 79 | img = cv2.imread(str(img_path))[:, :, [2, 1, 0]] # [H, W, C] 80 | assert img.dtype == np.uint8, "Image should be of type int!" 81 | assert ( 82 | img.min() >= 0 and img.max() <= 255 83 | ), "Image should be bounded between [0, 255]!" 84 | 85 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 86 | 87 | # Get label 88 | label = self.to_label[img_path.parent.name] 89 | 90 | # Get localization map 91 | BRIGHT_GREEN = np.array([0, 255, 0]) 92 | REGULAR_GREEN = np.array([0, 200, 0]) 93 | 94 | _, height, width = img.shape 95 | 96 | if label: 97 | img_name = img_path.stem 98 | map_path = img_path.parent / "edgemask" / f"{img_name}_edgemask.jpg" 99 | map = cv2.imread(str(map_path))[:, :, [2, 1, 0]] # [H, W, C] 100 | 101 | # FIXME Should I include bright red too? 102 | # Find spliced region, i.e. 
green regions 103 | binary_map = np.zeros((height, width), dtype=np.uint8) 104 | bright_green_mask = (map == BRIGHT_GREEN).all(axis=-1) 105 | regular_green_mask = (map == REGULAR_GREEN).all(axis=-1) 106 | binary_map[bright_green_mask | regular_green_mask] = 1 107 | 108 | # If authentic image 109 | else: 110 | binary_map = np.zeros((height, width), dtype=np.uint8) 111 | 112 | return {"img": img, "label": label, "map": binary_map} 113 | 114 | def __len__(self): 115 | return len(self.img_paths) 116 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from typing import Dict, List 3 | 4 | import yaml 5 | 6 | 7 | # def read_image(image_uri: Union[Path, str], grayscale=False) -> np.array: 8 | # """Read image_uri.""" 9 | 10 | # def read_image_from_filename(image_filename, imread_flag): 11 | # # FIXME Change order of channels 12 | # return cv2.imread(str(image_filename), imread_flag) 13 | 14 | # def read_image_from_url(image_url, imread_flag): 15 | # url_response = urlopen(str(image_url)) # nosec 16 | # img_array = np.array(bytearray(url_response.read()), dtype=np.uint8) 17 | # return cv2.imdecode(img_array, imread_flag) 18 | 19 | # imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR 20 | # local_file = os.path.exists(image_uri) 21 | # try: 22 | # img = None 23 | # if local_file: 24 | # img = read_image_from_filename(image_uri, imread_flag) 25 | # else: 26 | # img = read_image_from_url(image_uri, imread_flag) 27 | # assert img is not None 28 | # except Exception as e: 29 | # raise ValueError("Could not load image at {}: {}".format(image_uri, e)) 30 | # return img 31 | 32 | 33 | # def write_image(image: np.ndarray, filename: Union[Path, str]) -> None: 34 | # """Write image to file.""" 35 | # cv2.imwrite(str(filename), image) 36 | 37 | 38 | # class TqdmUpTo(tqdm): 39 | # """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" 40 | 41 | # def update_to(self, blocks=1, bsize=1, tsize=None): 42 | # """ 43 | # Parameters 44 | # ---------- 45 | # blocks : int, optional 46 | # Number of blocks transferred so far [default: 1]. 47 | # bsize : int, optional 48 | # Size of each block (in tqdm units) [default: 1]. 49 | # tsize : int, optional 50 | # Total size (in tqdm units). If [default: None] remains unchanged. 
51 | # """ 52 | # if tsize is not None: 53 | # self.total = tsize # pylint: disable=attribute-defined-outside-init 54 | # self.update(blocks * bsize - self.n) # will also set self.n = b * bsize 55 | 56 | 57 | # def download_url(url, filename): 58 | # """Download a file from url to filename, with a progress bar.""" 59 | # with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: 60 | # urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec 61 | 62 | 63 | class ConfigManager: 64 | def __init__(self, config: Dict): 65 | self.config = config 66 | 67 | def init_object(self, name: str, has_args=True, *args, **kwargs) -> object: 68 | # Root module 69 | root = "src" 70 | 71 | object_path = self.config[name] 72 | if object_path is None: 73 | return None 74 | 75 | module_name, object_name = object_path.rsplit(".", 1) 76 | 77 | module = import_module(f"{root}.{module_name}") 78 | 79 | if has_args: 80 | object_args = self.config[f"{name}_args"] or {} 81 | else: 82 | object_args = {} 83 | kwargs = {**kwargs, **object_args} 84 | 85 | return getattr(module, object_name)(*args, **kwargs) 86 | 87 | def init_objects(self, name: str, *args, **kwargs) -> List[object]: 88 | # Root module 89 | root = "src" 90 | 91 | objects = [] 92 | 93 | object_paths = self.config[name] 94 | n_objects = len(object_paths) 95 | object_args = self.config[f"{name}_args"] or [{}] * n_objects 96 | 97 | # Repeat single args across objects 98 | args = [arg if isinstance(arg, list) else [arg] * n_objects for arg in args] 99 | # print(args) 100 | args = list(zip(*args)) 101 | # FIXME Figure out something for kwargs 102 | 103 | for object_path, object_arg, arg in zip(object_paths, object_args, args): 104 | module_name, object_name = object_path.rsplit(".", 1) 105 | module = import_module(f"{root}.{module_name}") 106 | 107 | objects.append(getattr(module, object_name)(*arg, **object_arg)) 108 | 109 | return objects 110 | 111 | 112 | def load_yaml(path): 113 | with open(path, "r") as file: 114 | try: 115 | yaml_file = yaml.safe_load(file) 116 | except yaml.YAMLError as exc: 117 | print(exc) 118 | 119 | return yaml_file 120 | 121 | 122 | # def _import_class(module_and_class_name: str) -> type: 123 | # """Import class from a module, e.g. 
'text_recognizer.models.MLP'""" 124 | # module_name, class_name = module_and_class_name.rsplit(".", 1) 125 | # module = importlib.import_module(module_name) 126 | # class_ = getattr(module, class_name) 127 | # return class_ 128 | -------------------------------------------------------------------------------- /src/structures.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class PatchedImage: 8 | def __init__( 9 | self, 10 | data: torch.Tensor, 11 | patch_size=128, 12 | num_per_dim=30, 13 | ) -> None: 14 | """Representation of an image that is sliced into patches""" 15 | self.data = data.float() 16 | 17 | # Initialize image attributes 18 | self.patch_size = patch_size 19 | self.num_per_dim = num_per_dim 20 | 21 | self.shape = data.shape 22 | _, height, width = data.shape 23 | 24 | # Compute patch stride on image 25 | self.stride = (max(height, width) - self.patch_size) // self.num_per_dim 26 | 27 | # Compute total number of patches along height and width dimension 28 | self.max_h_idx = 1 + floor((height - self.patch_size) / self.stride) 29 | self.max_w_idx = 1 + floor((width - self.patch_size) / self.stride) 30 | 31 | def get_patch(self, h_idx: int, w_idx: int) -> torch.Tensor: 32 | """Get a patch from the image 33 | 34 | Parameters 35 | ---------- 36 | h_idx : int 37 | w_idx : int 38 | 39 | Returns 40 | ------- 41 | torch.Tensor 42 | [3, patch_size, patch_size] 43 | """ 44 | h_coord = h_idx * self.stride 45 | w_coord = w_idx * self.stride 46 | 47 | return self.data[ 48 | :, h_coord : h_coord + self.patch_size, w_coord : w_coord + self.patch_size 49 | ] 50 | 51 | def get_patch_map(self, h_idx: int, w_idx: int) -> torch.ByteTensor: 52 | """ 53 | Parameters 54 | ---------- 55 | h_idx : int 56 | w_idx : int 57 | 58 | Returns 59 | ------- 60 | torch.ByteTensor 61 | [H, W], values of {0, 1} 62 | """ 63 | h_coord = h_idx * self.stride 64 | w_coord = w_idx * self.stride 65 | 66 | _, height, width = self.shape 67 | 68 | binary_map = torch.zeros(height, width, dtype=torch.bool) 69 | binary_map[ 70 | h_coord : h_coord + self.patch_size, w_coord : w_coord + self.patch_size 71 | ] = True 72 | 73 | return binary_map 74 | 75 | def get_patches(self, idxs: torch.Tensor) -> torch.Tensor: 76 | """Get patches from image given its indices 77 | 78 | Parameters 79 | ---------- 80 | idxs : torch.Tensor 81 | [n_patches, 2], [n_patches, (h_idx, w_idx)] 82 | 83 | Returns 84 | ------- 85 | torch.Tensor 86 | [n_patches, 3, patch_size, patch_size] 87 | """ 88 | n_patches = idxs.shape[0] 89 | patches = torch.zeros( 90 | n_patches, 3, self.patch_size, self.patch_size, device=self.data.device 91 | ) 92 | 93 | # FIXME Any way to vectorize this? 
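# A possible (untested) vectorized sketch using broadcasted advanced indexing,
# assuming `idxs` is a LongTensor on the same device as `self.data`:
#   h0 = idxs[:, 0] * self.stride                    # [n_patches]
#   w0 = idxs[:, 1] * self.stride
#   offs = torch.arange(self.patch_size, device=self.data.device)
#   rows = h0[:, None] + offs                        # [n_patches, patch_size]
#   cols = w0[:, None] + offs
#   patches = self.data[:, rows[:, :, None], cols[:, None, :]].permute(1, 0, 2, 3)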
94 | # https://discuss.pytorch.org/t/advanced-fancy-indexing-across-batches/103445 95 | for i, idx in enumerate(idxs): 96 | h_idx, w_idx = idx 97 | 98 | patches[i] = self.get_patch(h_idx, w_idx) 99 | 100 | return patches 101 | 102 | def get_patch_maps(self, idxs: torch.Tensor) -> torch.Tensor: 103 | n_patches = idxs.shape[0] 104 | _, height, width = self.shape 105 | 106 | maps = torch.zeros(n_patches, height, width, dtype=torch.bool) 107 | 108 | for i, idx in enumerate(idxs): 109 | h_idx, w_idx = idx 110 | 111 | maps[i] = self.get_patch_map(h_idx, w_idx) 112 | 113 | return maps 114 | 115 | def patches_gen(self, batch_size=32) -> torch.Tensor: 116 | """Generator for all patches in an image, in raster scan order 117 | 118 | Parameters 119 | ---------- 120 | batch_size : int, optional 121 | Number of patches in each iteration, by default 32 122 | 123 | Returns 124 | ------- 125 | torch.Tensor 126 | [batch_size, 3, patch_size, patch_size] 127 | """ 128 | count = 0 129 | 130 | # Initialize indices / coords of all patches, [n_patches, 2] 131 | h_idxs = torch.arange(self.max_h_idx) 132 | w_idxs = torch.arange(self.max_w_idx) 133 | idxs = torch.stack(torch.meshgrid([h_idxs, w_idxs])).view(2, -1).T 134 | 135 | n_patches = len(idxs) 136 | 137 | while True: 138 | # Break when run out of patches 139 | if count * batch_size >= n_patches: 140 | break 141 | 142 | # Yield a batch of patches 143 | patches = self.get_patches( 144 | idxs[count * batch_size : (count + 1) * batch_size] 145 | ) 146 | yield patches 147 | count += 1 148 | 149 | def patch_maps_gen(self, batch_size=32) -> torch.Tensor: 150 | count = 0 151 | 152 | # Initialize indices / coords of all patches, [n_patches, 2] 153 | h_idxs = torch.arange(self.max_h_idx) 154 | w_idxs = torch.arange(self.max_w_idx) 155 | idxs = torch.stack(torch.meshgrid([h_idxs, w_idxs])).view(2, -1).T 156 | 157 | n_patches = len(idxs) 158 | 159 | while True: 160 | # Break when run out of patches 161 | if count * batch_size >= n_patches: 162 | break 163 | 164 | # Yield a batch of patches 165 | patches = self.get_patch_maps( 166 | idxs[count * batch_size : (count + 1) * batch_size] 167 | ) 168 | yield patches 169 | count += 1 170 | 171 | def pred_idxs_gen(self, batch_size=32) -> torch.Tensor: 172 | """Generator for all prediction map indices 173 | 174 | Parameters 175 | ---------- 176 | batch_size : int, optional 177 | Number of indices in each iteration, by default 32 178 | 179 | Returns 180 | ------- 181 | torch.Tensor 182 | [batch_size, 4] 183 | """ 184 | # h_idxs = torch.arange(self.max_h_idx) 185 | # w_idxs = torch.arange(self.max_w_idx) 186 | 187 | # # All possible pairs of patches, [n_idxs, 4] 188 | # # n_idxs: max_h_ind ^ 2 * max_w_ind ^ 2 189 | # idxs = ( 190 | # torch.stack(torch.meshgrid([h_idxs, w_idxs, h_idxs, w_idxs])).view(4, -1).T 191 | # ) 192 | 193 | # # All possible pairs of patches, [n_idxs, 4] 194 | # # n_idxs: max_h_ind ^ 2 * max_w_ind ^ 2 195 | idxs = ( 196 | np.mgrid[ 197 | 0 : self.max_h_idx, 198 | 0 : self.max_w_idx, 199 | 0 : self.max_h_idx, 200 | 0 : self.max_w_idx, 201 | ] 202 | .reshape((4, -1)) 203 | .T 204 | ) 205 | 206 | count = 0 207 | while True: 208 | if count * batch_size >= len(idxs): 209 | break 210 | 211 | yield idxs[count * batch_size : (count + 1) * batch_size] 212 | count += 1 213 | -------------------------------------------------------------------------------- /src/attacks/lots.py: -------------------------------------------------------------------------------- 1 | """Adaptation of the LOTS algorithm for patch-based features 2 
| 3 | - Rozsa, A., Zhong, Z., & Boult, T. (2020). Adversarial Attack on Deep Learning-Based Splice Localization. 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2757-2765. 4 | - https://arxiv.org/abs/2004.08443 5 | """ 6 | from typing import Any, Dict 7 | 8 | import numpy as np 9 | import torch 10 | from src.structures import PatchedImage 11 | from tqdm import tqdm 12 | 13 | 14 | class PatchLOTS: 15 | def __init__( 16 | self, 17 | step_size=5, 18 | n_iter=50, 19 | feat_batch_size=32, 20 | # pred_batch_size=1024, 21 | method="mean", # One of {'mean', 'sample'} 22 | ) -> None: 23 | """ 24 | Parameters 25 | ---------- 26 | step_size : int, optional 27 | Step size of gradient update, by default 5 28 | n_iter : int, optional 29 | Number of gradient updates, by default 50 30 | feat_batch_size : int, optional 31 | , by default 32 32 | method : str, optional 33 | One of {'mean', 'sample'}, by default "mean" 34 | """ 35 | self.step_size = step_size 36 | self.n_iter = n_iter 37 | self.feat_batch_size = feat_batch_size 38 | self.method = method 39 | 40 | def __call__( 41 | self, 42 | model, 43 | data: Dict[str, Any], 44 | ) -> torch.ByteTensor: 45 | """ 46 | Parameters 47 | ---------- 48 | model : [type] 49 | data : Dict[str, Any] 50 | From dataloader 51 | 52 | Returns 53 | ------- 54 | torch.ByteTensor 55 | [C, H, W], the adversarially perturbed image 56 | """ 57 | # Make a copy cos will be modifying it 58 | img = data["img"].detach().clone() 59 | gt_map = data["map"] 60 | 61 | # Perform prediction on clean image 62 | # print("Performing prediction on clean image...") 63 | # clean_preds = model.predict(img) 64 | 65 | print("Performing adversarial attack...") 66 | 67 | img = model.init_img(img) 68 | n_patches = img.max_h_idx * img.max_w_idx 69 | 70 | # Get patch features, [N, D] 71 | patch_feats = model.get_patch_feats(img, batch_size=self.feat_batch_size) 72 | 73 | # Identify all authentic patches 74 | # Find patches with no overlap with spliced regions 75 | auth_feats, is_auth = self._get_auth_feats( 76 | img, gt_map, patch_feats, self.feat_batch_size 77 | ) 78 | if len(auth_feats) == 0: 79 | return img.data.detach().clone().round().byte().cpu() 80 | 81 | # Determine target features 82 | if self.method == "mean": 83 | # Get mean feature representation, t, of all authentic patches 84 | # [N, D] 85 | target_feats = auth_feats.mean(dim=0, keepdim=True).expand(n_patches, -1) 86 | elif self.method == "sample": 87 | # Instead of fixing a target feature 88 | # Model feature distribution of authentic regions 89 | # Make features of spliced region close to that distribution 90 | # Simply perform sampling? 91 | 92 | target_feats = patch_feats 93 | n_auths = len(auth_feats) 94 | n_not_auths = len(target_feats) - n_auths 95 | 96 | # Sample auth features 97 | sample_idx = torch.randint(n_auths, size=(n_not_auths,)) 98 | 99 | # Replace not-auth regions with samples from auth regions 100 | target_feats[~is_auth] = auth_feats[sample_idx] 101 | 102 | # Make all the patches close to target features 103 | # Compute perturbation for each patch 104 | 105 | # Cache the best perturbed image thus far 106 | best_img = img.data.detach().clone() 107 | best_loss = float("inf") 108 | 109 | for _ in tqdm(range(self.n_iter)): 110 | # FIXME Normalize image instead? 
Then step size will be smaller 111 | img.data.requires_grad = True 112 | 113 | total_loss = 0 114 | 115 | # Have to split patches into batches (to fit in GPU memory) 116 | for i, patches in enumerate(img.patches_gen(self.feat_batch_size)): 117 | patch_feats = model.net(patches) 118 | # If missing batch dimension 119 | if len(patch_feats.shape) == 1: 120 | patch_feats = patch_feats.view(1, -1) 121 | 122 | # Compute distance from target feature 123 | curr_target_feats = target_feats[ 124 | i * self.feat_batch_size : (i + 1) * self.feat_batch_size 125 | ] 126 | adv_loss_per_patch = ((curr_target_feats - patch_feats) ** 2).sum( 127 | -1 128 | ) / 2 129 | adv_loss = adv_loss_per_patch.sum() 130 | 131 | # FIXME How to combine gradients from overlapping patches? 132 | # Just accumulate? Have to normalize? 133 | adv_loss.backward() 134 | 135 | total_loss += adv_loss.detach() 136 | 137 | # Perform update 138 | img_grad = img.data.grad.detach() 139 | with torch.no_grad(): 140 | grad_norm = torch.linalg.norm(img.data.flatten(), ord=float("inf")) 141 | img.data = img.data - self.step_size * (img_grad / grad_norm) 142 | 143 | # Clip pixels 144 | img.data = img.data.clamp(0, 255) 145 | 146 | # Reset gradients 147 | img.data.grad = None 148 | 149 | # Choose the perturbed image that has features closest to target 150 | if total_loss < best_loss: 151 | # print(f"Iter {i}: Found better adversarial example") 152 | best_loss = total_loss 153 | best_img = img.data.detach().clone() 154 | 155 | # Round pixel values to be discrete 156 | adv_img = best_img.round().byte().cpu() 157 | 158 | # Perform prediction on adversarial image 159 | # print("Performing prediction on adversarial image...") 160 | # adv_preds = model.predict(best_img) 161 | 162 | # Convert into numpy image 163 | # adv_img = best_img.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 164 | 165 | # return clean_preds, adv_preds, adv_img 166 | return adv_img 167 | 168 | def _get_auth_feats( 169 | self, 170 | img: PatchedImage, 171 | gt_map: np.ndarray, 172 | patch_feats: torch.Tensor, 173 | batch_size=32, 174 | ) -> torch.Tensor: 175 | 176 | gt_map = torch.BoolTensor(gt_map) 177 | # Keep track of which patches are authentic 178 | is_auth = torch.zeros(img.max_h_idx * img.max_w_idx, dtype=torch.bool) 179 | 180 | # Find all authentic patches 181 | # FIXME Vectorize this 182 | # FIXME Put onto GPU? 
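# A patch is treated as authentic only if its coverage map has zero overlap
# with the ground-truth splice mask, i.e. the patch lies entirely within the
# untouched region of the image.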
183 | for i, patch_maps in enumerate(img.patch_maps_gen(batch_size)): 184 | # Check whether each patch overlaps with the spliced ground-truth 185 | is_auth[i * batch_size : (i + 1) * batch_size] = ~( 186 | (patch_maps & gt_map).flatten(1, 2).any(dim=-1) 187 | ) 188 | 189 | auth_feats = patch_feats[is_auth] 190 | 191 | # [N_auth, D], [N,] 192 | return auth_feats, is_auth 193 | -------------------------------------------------------------------------------- /notebooks/viz.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.7.10" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3710jvsc74a57bd0f8cb47d26f652eae6d609ded3532fd9bf573c023c853475c39cb6ed10b9f3c5e", 18 | "display_name": "Python 3.7.10 64-bit ('fake-detection-lab': conda)" 19 | } 20 | }, 21 | "nbformat": 4, 22 | "nbformat_minor": 2, 23 | "cells": [ 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import os\n", 31 | "os.chdir('..')\n", 32 | "\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from src.datasets import (\n", 44 | " ColumbiaDataset,\n", 45 | " DSO_1_Dataset,\n", 46 | " InTheWildDataset,\n", 47 | " RealisticTamperingDataset,\n", 48 | " SceneCompletionDataset\n", 49 | ")\n", 50 | "from src.models.exif_sc import EXIF_SC\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "from tqdm import tqdm" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def plot_pred(data, pred, save_path=None):\n", 62 | " plt.subplots(figsize=(16, 8))\n", 63 | " plt.subplot(2, 2, 1)\n", 64 | " plt.title('Input Image')\n", 65 | " plt.imshow(data['img'].permute(1, 2, 0))\n", 66 | " plt.axis('off')\n", 67 | "\n", 68 | " plt.subplot(2, 2, 2)\n", 69 | " plt.title('Cluster w/ MeanShift')\n", 70 | " plt.axis('off')\n", 71 | " plt.imshow(pred['ms'], cmap='jet', vmin=0.0, vmax=1.0)\n", 72 | "\n", 73 | " plt.subplot(2, 2, 3)\n", 74 | " plt.title('Ground-truth Segment')\n", 75 | " plt.axis('off')\n", 76 | " plt.imshow(data['map'], vmin=0.0, vmax=1.0, cmap=\"gray\")\n", 77 | "\n", 78 | " plt.subplot(2, 2, 4)\n", 79 | " plt.title('Segment with NCuts')\n", 80 | " plt.axis('off')\n", 81 | " plt.imshow(pred['ncuts'], vmin=0.0, vmax=1.0, cmap=\"gray\")\n", 82 | "\n", 83 | " if save_path:\n", 84 | " plt.savefig(save_path + f'/{i}.png')" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "output_type": "stream", 94 | "name": "stderr", 95 | "text": [ 96 | " 0%| | 0/363 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'img'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mplot_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msave_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'artifacts/columbia_egs'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 111 | "\u001b[0;32m~/code/fake-detection-lab/src/models/exif_sc/exif_sc.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, img, feat_batch_size, pred_batch_size, blue_high)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;31m# Predict consistency maps\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m pred_maps = self._predict_consistency_maps(\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpred_batch_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m )\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 112 | "\u001b[0;32m~/code/fake-detection-lab/src/models/exif_sc/exif_sc.py\u001b[0m in \u001b[0;36m_predict_consistency_maps\u001b[0;34m(self, img, patch_features, batch_size)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# [B, 1]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;31m# FIXME Is it possible to vectorize this?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 113 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "ds = ColumbiaDataset()\n", 119 | "\n", 120 | "model = EXIF_SC(\"artifacts/exif_sc.npy\", device=\"cuda:1\")\n", 121 | "for param in model.net.parameters():\n", 122 | " param.requires_grad = False\n", 123 | "\n", 124 | "for i in tqdm(range(len(ds))):\n", 125 | " data = ds[i]\n", 126 | "\n", 127 | " pred = model.predict(data['img'])\n", 128 | " plot_pred(data, pred, save_path='artifacts/columbia_egs')" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [] 137 | } 138 | ] 139 | } -------------------------------------------------------------------------------- /src/evaluation/non_adv_evaluators.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import average_precision_score 8 | from torch.utils.data import Dataset 9 | from tqdm import tqdm 10 | 11 | 12 | class NonAdvEvaluator: 13 | def 
__init__(self, model, dataset: Dataset) -> None:
14 | self.model = model
15 | self.dataset = dataset
16 |
17 | self.metrics = {}
18 |
19 | def evaluate(self, resize: Tuple[int, int] = None) -> Dict[str, Any]:
20 | """
21 | Parameters
22 | ----------
23 | save : bool, optional
24 | Whether to save prediction arrays, by default False
25 | resize : Tuple[int, int], optional
26 | [H, W], whether to resize images / maps to a consistent shape
27 |
28 | Returns
29 | -------
30 | Dict[str, Any]
31 | AP : float
32 | Average precision score, for detection
33 | IoU : float
34 | Class-balanced IoU, for localization
35 | """
36 | y_true = []
37 | label_map = []
38 |
39 | y_score = []
40 | score_map = []
41 | ncut = []
42 |
43 | for i in tqdm(range(len(self.dataset))):
44 | data = self.dataset[i]
45 | # Perform prediction
46 | pred = self.model.predict(data["img"])
47 |
48 | # If image sizes different, resize to a consistent shape
49 | if resize:
50 | data["map"] = cv2.resize(
51 | data["map"], resize, interpolation=cv2.INTER_LINEAR
52 | )
53 | pred["ms"] = cv2.resize(
54 | pred["ms"], resize, interpolation=cv2.INTER_LINEAR
55 | )
56 |
57 | # Store ground-truths
58 | y_true.append(data["label"])
59 | label_map.append(data["map"])
60 |
61 | # Store predictions
62 | y_score.append(pred["score"])
63 | score_map.append(pred["ms"])
64 | # ncut.append(pred["ncuts"])
65 |
66 | y_true = np.array(y_true)
67 | label_map = np.stack(label_map, axis=0)
68 |
69 | y_score = np.array(y_score)
70 | score_map = np.stack(score_map, axis=0)
71 | # ncut = np.stack(ncut, axis=0)
72 |
73 | # Save predictions
74 | # if save:
75 | # save_path = Path("artifacts/predictions")
76 | # np.save(save_path / "scores.npy", y_score)
77 | # np.save(save_path / "score_maps.npy", score_map)
78 | # # np.save(save_path / "rt_ncuts.npy", ncut)
79 |
80 | # Compute localization metrics
81 | self._compute_localization_metrics(label_map, score_map)
82 |
83 | # Compute detection metrics
84 | self._compute_detection_metrics(y_true, y_score)
85 |
86 | return self.metrics
87 |
88 | # @staticmethod
89 | # def compute_optimal_iou(y_true, y_pred, batch_size=256):
90 | # # Check whether NaN values
91 | # if np.isnan(y_pred).any():
92 | # print("WARNING: NaN values in localization prediction scores!")
93 | # y_pred[np.isnan(y_pred)] = 0
94 |
95 | # # Store all possible iou scores
96 | # thresholds = y_pred.flatten()
97 | # scores = np.zeros_like(thresholds)
98 |
99 | # for i in range(0, len(scores), batch_size):
100 | # threshs = thresholds[i : i + batch_size] # [B]
101 |
102 | # y_preds = y_pred.copy()
103 | # # [H, W, B]
104 | # y_preds = np.repeat(y_preds[..., None], batch_size, axis=-1)
105 |
106 | # # Perform thresholding
107 | # y_preds[y_preds < threshs] = 0
108 | # y_preds[y_preds >= threshs] = 1
109 |
110 | # # Compute scores
111 | # return iou(
112 | # torch.from_numpy(y_preds.transpose(2, 0, 1)),
113 | # torch.from_numpy(np.repeat(y_true[None, ...], batch_size, axis=0))
114 | # )
115 |
116 | # Compute iou score for each threshold
117 | # for i, thresh in tqdm(enumerate(thresholds)):
118 | # y_pred_thresh = np.zeros(y_pred.shape, dtype=np.uint8)
119 | # y_pred_thresh[y_pred >= thresh] = 1
120 |
121 | # scores[i] = jaccard_score(y_true.flatten(), y_pred_thresh.flatten())
122 |
123 | # return scores  # leftover from the commented-out method above
124 |
125 | def _compute_localization_metrics(self, label_map, score_map) -> None:
126 | # Check whether NaN values
127 | if np.isnan(score_map).any():
128 | print("WARNING: NaN values in localization prediction scores!")
129 |
score_map[np.isnan(score_map)] = 0 130 | 131 | # Find optimal threshold, and the corresponding score for each image 132 | 133 | # Compute for spliced regions 134 | _, iou_spliced = self.find_optimal_threshold(score_map, label_map) 135 | iou_spliced = iou_spliced.mean().item() 136 | 137 | # Compute for non-spliced regions 138 | invert_label_map = 1 - label_map 139 | invert_score_map = 1 - score_map 140 | 141 | _, iou_non_spliced = self.find_optimal_threshold( 142 | invert_score_map, invert_label_map 143 | ) 144 | iou_non_spliced = iou_non_spliced.mean().item() 145 | 146 | self.metrics["IoU-spliced"] = iou_spliced 147 | self.metrics["IoU-non-spliced"] = iou_non_spliced 148 | # Compute mean IoU 149 | self.metrics["IoU"] = (iou_spliced + iou_non_spliced) / 2 150 | 151 | # FIXME Report per-class scores 152 | 153 | def _compute_detection_metrics(self, y_true, y_score) -> None: 154 | # Check whether NaN values 155 | if np.isnan(y_score).any(): 156 | print("WARNING: NaN values in detection prediction scores!") 157 | y_score[np.isnan(y_score)] = 0 158 | 159 | self.metrics["AP"] = average_precision_score(y_true, y_score) 160 | 161 | @staticmethod 162 | def find_optimal_threshold( 163 | pred_mask: np.ndarray, groundtruth_masks: np.ndarray 164 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 165 | """https://codereview.stackexchange.com/questions/229341/pytorch-vectorized-implementation-for-thresholding-and-computing-jaccard-index 166 | 167 | Parameters 168 | ---------- 169 | pred_mask : np.ndarray (float32) 170 | [B, H, W], range [0, 1], probability prediction map 171 | groundtruth_masks : np.ndarray (uint8) 172 | [B, H, W], values one of {0, 1}, binary label map 173 | 174 | Returns 175 | ------- 176 | Tuple[torch.FloatTensor, torch.FloatTensor] 177 | [B], optimal thresholds for each image 178 | [B], corresponding jaccard scores for each image 179 | """ 180 | n_patch = groundtruth_masks.shape[0] 181 | 182 | groundtruth_masks_tensor = torch.from_numpy(groundtruth_masks) 183 | pred_mask_tensor = torch.from_numpy(pred_mask) 184 | 185 | # if USE_CUDA: 186 | # groundtruth_masks_tensor = groundtruth_masks_tensor.cuda() 187 | # pred_mask_tensor = pred_mask_tensor.cuda() 188 | 189 | vector_pred = pred_mask_tensor.view(n_patch, -1) 190 | vector_gt = groundtruth_masks_tensor.view(n_patch, -1) 191 | vector_pred, sort_pred_idx = torch.sort(vector_pred, descending=True) 192 | vector_gt = vector_gt[torch.arange(vector_gt.shape[0])[:, None], sort_pred_idx] 193 | gt_cumsum = torch.cumsum(vector_gt, dim=1) 194 | gt_total = gt_cumsum[:, -1].reshape(n_patch, 1) 195 | predicted = torch.arange(start=1, end=vector_pred.shape[1] + 1) 196 | # if USE_CUDA: 197 | # predicted = predicted.cuda() 198 | gt_cumsum = gt_cumsum.type(torch.float) 199 | gt_total = gt_total.type(torch.float) 200 | predicted = predicted.type(torch.float) 201 | jaccard_idx = gt_cumsum / (gt_total + predicted - gt_cumsum) 202 | max_jaccard_idx, max_indices = torch.max(jaccard_idx, dim=1) 203 | max_indices = max_indices.reshape(-1, 1) 204 | best_threshold = vector_pred[ 205 | torch.arange(vector_pred.shape[0])[:, None], max_indices 206 | ] 207 | best_threshold = best_threshold.reshape(-1) 208 | 209 | return best_threshold, max_jaccard_idx 210 | -------------------------------------------------------------------------------- /requirements/prod.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements/prod.in 6 | # 7 
| absl-py==0.12.0 8 | # via tensorboard 9 | aiohttp==3.7.4.post0 10 | # via fsspec 11 | argon2-cffi==20.1.0 12 | # via notebook 13 | async-generator==1.10 14 | # via nbclient 15 | async-timeout==3.0.1 16 | # via aiohttp 17 | attrs==20.3.0 18 | # via 19 | # aiohttp 20 | # jsonschema 21 | backcall==0.2.0 22 | # via ipython 23 | bleach==3.3.0 24 | # via nbconvert 25 | cachetools==4.2.2 26 | # via google-auth 27 | certifi==2020.12.5 28 | # via 29 | # requests 30 | # sentry-sdk 31 | cffi==1.14.5 32 | # via argon2-cffi 33 | chardet==4.0.0 34 | # via 35 | # aiohttp 36 | # requests 37 | click==7.1.2 38 | # via wandb 39 | configparser==5.0.2 40 | # via wandb 41 | cycler==0.10.0 42 | # via matplotlib 43 | decorator==5.0.5 44 | # via ipython 45 | defusedxml==0.7.1 46 | # via nbconvert 47 | docker-pycreds==0.4.0 48 | # via wandb 49 | entrypoints==0.3 50 | # via nbconvert 51 | filelock==3.0.12 52 | # via gdown 53 | fsspec[http]==2021.5.0 54 | # via pytorch-lightning 55 | future==0.18.2 56 | # via pytorch-lightning 57 | gdown==3.12.2 58 | # via -r requirements/prod.in 59 | gitdb==4.0.7 60 | # via gitpython 61 | gitpython==3.1.14 62 | # via wandb 63 | google-auth-oauthlib==0.4.4 64 | # via tensorboard 65 | google-auth==1.30.0 66 | # via 67 | # google-auth-oauthlib 68 | # tensorboard 69 | grpcio==1.37.1 70 | # via tensorboard 71 | idna==2.10 72 | # via 73 | # requests 74 | # yarl 75 | importlib-metadata==3.10.0 76 | # via 77 | # jsonschema 78 | # markdown 79 | ipykernel==5.5.3 80 | # via 81 | # ipywidgets 82 | # jupyter 83 | # jupyter-console 84 | # notebook 85 | # qtconsole 86 | ipython-genutils==0.2.0 87 | # via 88 | # nbformat 89 | # notebook 90 | # qtconsole 91 | # traitlets 92 | ipython==7.22.0 93 | # via 94 | # ipykernel 95 | # ipywidgets 96 | # jupyter-console 97 | ipywidgets==7.6.3 98 | # via jupyter 99 | jedi==0.18.0 100 | # via ipython 101 | jinja2==2.11.3 102 | # via 103 | # nbconvert 104 | # notebook 105 | joblib==1.0.1 106 | # via scikit-learn 107 | jsonschema==3.2.0 108 | # via nbformat 109 | jupyter-client==6.1.12 110 | # via 111 | # ipykernel 112 | # jupyter-console 113 | # nbclient 114 | # notebook 115 | # qtconsole 116 | jupyter-console==6.4.0 117 | # via jupyter 118 | jupyter-core==4.7.1 119 | # via 120 | # jupyter-client 121 | # nbconvert 122 | # nbformat 123 | # notebook 124 | # qtconsole 125 | jupyter==1.0.0 126 | # via -r requirements/prod.in 127 | jupyterlab-pygments==0.1.2 128 | # via nbconvert 129 | jupyterlab-widgets==1.0.0 130 | # via ipywidgets 131 | kiwisolver==1.3.1 132 | # via matplotlib 133 | markdown==3.3.4 134 | # via tensorboard 135 | markupsafe==1.1.1 136 | # via jinja2 137 | matplotlib==3.4.1 138 | # via -r requirements/prod.in 139 | mistune==0.8.4 140 | # via nbconvert 141 | multidict==5.1.0 142 | # via 143 | # aiohttp 144 | # yarl 145 | nbclient==0.5.3 146 | # via nbconvert 147 | nbconvert==6.0.7 148 | # via 149 | # jupyter 150 | # notebook 151 | nbformat==5.1.3 152 | # via 153 | # ipywidgets 154 | # nbclient 155 | # nbconvert 156 | # notebook 157 | nest-asyncio==1.5.1 158 | # via nbclient 159 | notebook==6.3.0 160 | # via 161 | # jupyter 162 | # widgetsnbextension 163 | numpy==1.20.2 164 | # via 165 | # matplotlib 166 | # opencv-python 167 | # pandas 168 | # pytorch-lightning 169 | # scikit-learn 170 | # scipy 171 | # tensorboard 172 | # torch 173 | # torchvision 174 | oauthlib==3.1.0 175 | # via requests-oauthlib 176 | opencv-python==4.5.1.48 177 | # via -r requirements/prod.in 178 | packaging==20.9 179 | # via 180 | # bleach 181 | # pytorch-lightning 182 | # 
torchmetrics 183 | pandas==1.2.4 184 | # via -r requirements/prod.in 185 | pandocfilters==1.4.3 186 | # via nbconvert 187 | parso==0.8.2 188 | # via jedi 189 | pathtools==0.1.2 190 | # via wandb 191 | pexpect==4.8.0 192 | # via ipython 193 | pickleshare==0.7.5 194 | # via ipython 195 | pillow==8.2.0 196 | # via 197 | # matplotlib 198 | # torchvision 199 | prometheus-client==0.10.0 200 | # via notebook 201 | promise==2.3 202 | # via wandb 203 | prompt-toolkit==3.0.18 204 | # via 205 | # ipython 206 | # jupyter-console 207 | protobuf==3.15.7 208 | # via 209 | # tensorboard 210 | # wandb 211 | psutil==5.8.0 212 | # via wandb 213 | ptyprocess==0.7.0 214 | # via 215 | # pexpect 216 | # terminado 217 | pyasn1-modules==0.2.8 218 | # via google-auth 219 | pyasn1==0.4.8 220 | # via 221 | # pyasn1-modules 222 | # rsa 223 | pycparser==2.20 224 | # via cffi 225 | pydeprecate==0.3.0 226 | # via pytorch-lightning 227 | pygments==2.8.1 228 | # via 229 | # ipython 230 | # jupyter-console 231 | # jupyterlab-pygments 232 | # nbconvert 233 | # qtconsole 234 | pyparsing==2.4.7 235 | # via 236 | # matplotlib 237 | # packaging 238 | pyrsistent==0.17.3 239 | # via jsonschema 240 | pysocks==1.7.1 241 | # via requests 242 | python-dateutil==2.8.1 243 | # via 244 | # jupyter-client 245 | # matplotlib 246 | # pandas 247 | # wandb 248 | pytorch-lightning==1.3.1 249 | # via -r requirements/prod.in 250 | pytz==2021.1 251 | # via pandas 252 | pyyaml==5.4.1 253 | # via 254 | # -r requirements/prod.in 255 | # pytorch-lightning 256 | # wandb 257 | pyzmq==22.0.3 258 | # via 259 | # jupyter-client 260 | # notebook 261 | # qtconsole 262 | qtconsole==5.0.3 263 | # via jupyter 264 | qtpy==1.9.0 265 | # via qtconsole 266 | requests-oauthlib==1.3.0 267 | # via google-auth-oauthlib 268 | requests[socks]==2.25.1 269 | # via 270 | # fsspec 271 | # gdown 272 | # requests-oauthlib 273 | # tensorboard 274 | # wandb 275 | rsa==4.7.2 276 | # via google-auth 277 | scikit-learn==0.24.1 278 | # via -r requirements/prod.in 279 | scipy==1.6.2 280 | # via 281 | # -r requirements/prod.in 282 | # scikit-learn 283 | send2trash==1.5.0 284 | # via notebook 285 | sentry-sdk==1.0.0 286 | # via wandb 287 | shortuuid==1.0.1 288 | # via wandb 289 | six==1.15.0 290 | # via 291 | # absl-py 292 | # argon2-cffi 293 | # bleach 294 | # cycler 295 | # docker-pycreds 296 | # gdown 297 | # google-auth 298 | # grpcio 299 | # jsonschema 300 | # promise 301 | # protobuf 302 | # python-dateutil 303 | # tensorboard 304 | # wandb 305 | smmap==4.0.0 306 | # via gitdb 307 | subprocess32==3.5.4 308 | # via wandb 309 | tensorboard-plugin-wit==1.8.0 310 | # via tensorboard 311 | tensorboard==2.4.1 312 | # via pytorch-lightning 313 | terminado==0.9.4 314 | # via notebook 315 | testpath==0.4.4 316 | # via nbconvert 317 | threadpoolctl==2.1.0 318 | # via scikit-learn 319 | toml==0.10.2 320 | # via -r requirements/prod.in 321 | torch==1.8.1 322 | # via 323 | # -r requirements/prod.in 324 | # pytorch-lightning 325 | # torchmetrics 326 | # torchvision 327 | torchmetrics==0.3.2 328 | # via 329 | # -r requirements/prod.in 330 | # pytorch-lightning 331 | torchvision==0.9.1 332 | # via -r requirements/prod.in 333 | tornado==6.1 334 | # via 335 | # ipykernel 336 | # jupyter-client 337 | # notebook 338 | # terminado 339 | tqdm==4.59.0 340 | # via 341 | # -r requirements/prod.in 342 | # gdown 343 | # pytorch-lightning 344 | traitlets==5.0.5 345 | # via 346 | # ipykernel 347 | # ipython 348 | # ipywidgets 349 | # jupyter-client 350 | # jupyter-core 351 | # nbclient 352 | # nbconvert 353 
| # nbformat 354 | # notebook 355 | # qtconsole 356 | typing-extensions==3.7.4.3 357 | # via 358 | # aiohttp 359 | # importlib-metadata 360 | # torch 361 | # yarl 362 | urllib3==1.26.4 363 | # via 364 | # requests 365 | # sentry-sdk 366 | wandb==0.10.24 367 | # via -r requirements/prod.in 368 | wcwidth==0.2.5 369 | # via prompt-toolkit 370 | webencodings==0.5.1 371 | # via bleach 372 | werkzeug==2.0.1 373 | # via tensorboard 374 | wheel==0.36.2 375 | # via tensorboard 376 | widgetsnbextension==3.5.1 377 | # via ipywidgets 378 | yarl==1.6.3 379 | # via aiohttp 380 | zipp==3.4.1 381 | # via importlib-metadata 382 | 383 | # The following packages are considered to be unsafe in a requirements file: 384 | # setuptools 385 | -------------------------------------------------------------------------------- /src/models/exif_sc/exif_sc.py: -------------------------------------------------------------------------------- 1 | """EXIF-SC overall inference model 2 | 3 | From: 4 | - Fighting Fake News: Image Splice Detection via Learned Self-Consistency (Huh et al., ECCV 2018) 5 | - https://minyoungg.github.io/selfconsistency/ 6 | - https://github.com/minyoungg/selfconsistency 7 | 8 | Network building file adapted from: 9 | - https://github.com/Microsoft/MMdnn/blob/master/docs/tf2pytorch.md 10 | """ 11 | from typing import Any, Dict 12 | 13 | import cv2 14 | import numpy as np 15 | import torch 16 | from src.structures import PatchedImage 17 | 18 | # FIXME Something wrong with network!! 19 | # FIXME Something wrong with image preprocessing?!! 20 | from .networks import EXIF_Net 21 | from .postprocess import mean_shift, normalized_cut 22 | 23 | # TODO Check out PyTorch multiprocessing 24 | # FIXME Careful of image shape, i.e. [C, H, W] vs [H, W, C] 25 | # FIXME `no_grad` when running network 26 | # FIXME Normalize image! 27 | 28 | 29 | class EXIF_SC: 30 | def __init__( 31 | self, weight_file: str, patch_size=128, num_per_dim=30, device="cuda:0" 32 | ) -> None: 33 | """ 34 | Parameters 35 | ---------- 36 | weight_file : str 37 | Path to network weights file 38 | patch_size : int, optional 39 | Size of patches, by default 128 40 | num_per_dim : int, optional 41 | Number of patches to use along the largest dimension, by default 30 42 | device : str, optional 43 | , by default "cuda:0" 44 | """ 45 | self.patch_size = patch_size 46 | self.num_per_dim = num_per_dim 47 | self.device = torch.device(device) 48 | 49 | self.net = EXIF_Net(weight_file) 50 | self.net.eval() 51 | self.net.to(device) 52 | 53 | def predict( 54 | self, 55 | img: torch.Tensor, 56 | feat_batch_size=32, # Does not affect compute time much? 57 | pred_batch_size=1024, # Affects up to a certain extent 58 | blue_high=True, 59 | ) -> Dict[str, Any]: 60 | """ 61 | Parameters 62 | ---------- 63 | img : torch.Tensor 64 | [C, H, W], range: [0, 255] 65 | feat_batch_size : int, optional 66 | , by default 32 67 | pred_batch_size : int, optional 68 | , by default 1024 69 | blue_high : bool 70 | , by default True 71 | 72 | Returns 73 | ------- 74 | Dict[str, Any] 75 | ms : np.ndarray (float32) 76 | Consistency map, [H, W], range [0, 1] 77 | ncuts : np.ndarray (float32) 78 | Localization map, [H, W], range [0, 1] 79 | score : float 80 | Prediction score, higher indicates existence of manipulation 81 | """ 82 | _, height, width = img.shape 83 | assert ( 84 | min(height, width) > self.patch_size 85 | ), "Image must be bigger than patch size!" 
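# Illustrative sizing with the defaults (patch_size=128, num_per_dim=30):
# a 1000x1500 image gives stride = (1500 - 128) // 30 = 45, so the patch grid
# is 1 + (1000 - 128) // 45 = 20 rows by 1 + (1500 - 128) // 45 = 31 columns,
# i.e. 620 patches are embedded by the network.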
86 | 87 | # Initialize image and attributes 88 | img = self.init_img(img) 89 | 90 | # Precompute features for each patch 91 | with torch.no_grad(): 92 | patch_features = self.get_patch_feats(img, batch_size=feat_batch_size) 93 | 94 | # Predict consistency maps 95 | pred_maps = self._predict_consistency_maps( 96 | img, patch_features, batch_size=pred_batch_size 97 | ) 98 | 99 | # Produce a single response map 100 | ms = mean_shift( 101 | pred_maps.reshape((-1, pred_maps.shape[0] * pred_maps.shape[1])), pred_maps 102 | ) 103 | 104 | # As a heuristic, the anomalous areas are smaller than the normal areas 105 | if np.mean(ms > 0.5) > 0.5: 106 | # majority of the image is above .5 107 | if blue_high: 108 | # Reverse heat map 109 | ms = 1 - ms 110 | 111 | # Run clustering to get localization map 112 | ncuts = normalized_cut(pred_maps) 113 | if np.mean(ncuts > 0.5) > 0.5: 114 | # majority of the image is white 115 | # flip so spliced is white 116 | ncuts = 1 - ncuts 117 | out_ncuts = cv2.resize( 118 | ncuts.astype(np.float32), 119 | (width, height), 120 | interpolation=cv2.INTER_LINEAR, 121 | ) 122 | 123 | out_ms = cv2.resize(ms, (width, height), interpolation=cv2.INTER_LINEAR) 124 | 125 | return {"ms": out_ms, "ncuts": out_ncuts, "score": out_ms.mean()} 126 | 127 | def init_img(self, img: torch.Tensor) -> PatchedImage: 128 | # Initialize image and attributes 129 | img = img.to(self.device) 130 | img = PatchedImage(img, self.patch_size, self.num_per_dim) 131 | 132 | return img 133 | 134 | def _predict_consistency_maps( 135 | self, img: PatchedImage, patch_features: torch.Tensor, batch_size=64 136 | ): 137 | # For each patch, how many overlapping patches? 138 | spread = max(1, img.patch_size // img.stride) 139 | 140 | # Aggregate prediction maps; for each patch, compared to each other patch 141 | responses = torch.zeros( 142 | ( 143 | img.max_h_idx + spread - 1, 144 | img.max_w_idx + spread - 1, 145 | img.max_h_idx + spread - 1, 146 | img.max_w_idx + spread - 1, 147 | ) 148 | ) 149 | # Number of predictions for each patch 150 | vote_counts = ( 151 | torch.zeros( 152 | ( 153 | img.max_h_idx + spread - 1, 154 | img.max_w_idx + spread - 1, 155 | img.max_h_idx + spread - 1, 156 | img.max_w_idx + spread - 1, 157 | ) 158 | ) 159 | + 1e-4 160 | ) 161 | 162 | # Perform prediction 163 | for idxs in img.pred_idxs_gen(batch_size=batch_size): 164 | # a to be compared to b 165 | patch_a_idxs = idxs[:, :2] # [B, 2] 166 | patch_b_idxs = idxs[:, 2:] # [B, 2] 167 | 168 | # Convert 2D index into its 1D version 169 | a_idxs = torch.from_numpy( 170 | np.ravel_multi_index(patch_a_idxs.T, [img.max_h_idx, img.max_w_idx]) 171 | ) # [B] 172 | b_idxs = torch.from_numpy( 173 | np.ravel_multi_index(patch_b_idxs.T, [img.max_h_idx, img.max_w_idx]) 174 | ) 175 | 176 | # Grab corresponding features 177 | a_feats = patch_features[a_idxs] # [B, 4096] 178 | b_feats = patch_features[b_idxs] 179 | 180 | feats = torch.cat([a_feats, b_feats], dim=-1) # [B, 8192] 181 | 182 | # Get predictions 183 | with torch.no_grad(): 184 | exif_logits = self.net.exif_fc(feats) # [B, 83] 185 | # FIXME Sigmoid or nay? 186 | # exif_preds = torch.sigmoid(exif_logits) 187 | exif_preds = exif_logits 188 | 189 | logits = self.net.classifier_fc(exif_preds) 190 | preds = torch.sigmoid(logits) # [B, 1] 191 | 192 | preds = preds.cpu() 193 | 194 | # FIXME Is it possible to vectorize this? 
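            # A possible vectorization of the loop below (untested sketch, not part
            # of the original code): enumerate the spread**4 window offsets once and
            # scatter-add with `index_put_(..., accumulate=True)` instead of slicing
            # per prediction:
            #   off = torch.arange(spread)
            #   o = torch.stack(torch.meshgrid(off, off, off, off), dim=-1).reshape(-1, 4)
            #   flat = (torch.as_tensor(idxs)[:, None, :] + o[None, :, :]).reshape(-1, 4)
            #   vals = preds.reshape(-1).repeat_interleave(spread ** 4)
            #   responses.index_put_(tuple(flat.T), vals, accumulate=True)
            #   vote_counts.index_put_(tuple(flat.T), torch.ones_like(vals), accumulate=True)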
195 | # Accumulate predictions for overlapping patches 196 | for i in range(len(preds)): 197 | responses[ 198 | idxs[i][0] : (idxs[i][0] + spread), 199 | idxs[i][1] : (idxs[i][1] + spread), 200 | idxs[i][2] : (idxs[i][2] + spread), 201 | idxs[i][3] : (idxs[i][3] + spread), 202 | ] += preds[i] 203 | vote_counts[ 204 | idxs[i][0] : (idxs[i][0] + spread), 205 | idxs[i][1] : (idxs[i][1] + spread), 206 | idxs[i][2] : (idxs[i][2] + spread), 207 | idxs[i][3] : (idxs[i][3] + spread), 208 | ] += 1 209 | 210 | # Normalize predictions 211 | return responses / vote_counts 212 | 213 | def get_patch_feats( 214 | self, img: PatchedImage, batch_size=32 215 | ) -> torch.Tensor: 216 | """Get features for every patch in the image. 217 | Features used to compute if two patches share the same EXIF attributes. 218 | 219 | Parameters 220 | ---------- 221 | batch_size : int, optional 222 | Batch size to be fed into the network, by default 32 223 | 224 | Returns 225 | ------- 226 | torch.Tensor 227 | [n_patches, 4096] 228 | """ 229 | # Compute feature vector for each image patch 230 | patch_features = [] 231 | 232 | # Generator for patches; raster scan order 233 | for patches in img.patches_gen(batch_size): 234 | feat = self.net(patches) 235 | # If missing batch dimension 236 | if len(feat.shape) == 1: 237 | feat = feat.view(1, -1) 238 | patch_features.append(feat) 239 | 240 | # [n_patches, n_features] 241 | patch_features = torch.cat(patch_features, dim=0) 242 | 243 | # Try preallocate tensor instead 244 | # patch_features = torch.zeros(self.max_h_idx * self.max_w_idx, n_features) 245 | 246 | # with torch.no_grad(): 247 | # for i, patches in enumerate(self._patches_gen(batch_size)): 248 | # patch_features[i * batch_size : (i + 1) * batch_size] = self.net( 249 | # patches 250 | # ) 251 | 252 | # patch_features = (patch_features.T).view( 253 | # n_features, self.max_h_idx, self.max_w_idx 254 | # ) 255 | 256 | return patch_features 257 | 258 | 259 | if __name__ == "__main__": 260 | import argparse 261 | 262 | parser = argparse.ArgumentParser() 263 | parser.add_argument( 264 | "--weights_path", 265 | help="path to the weights file", 266 | default="artifacts/exif_sc.npy", 267 | ) 268 | parser.add_argument( 269 | "--img_path", 270 | help="path to the input image file", 271 | default="data/demo.png", 272 | ) 273 | args = parser.parse_args() 274 | 275 | model = EXIF_SC(args.weights_path) 276 | 277 | img = cv2.imread(args.img_path)[:, :, [2, 1, 0]] # [H, W, C] 278 | img = torch.from_numpy(img).permute(2, 0, 1) # [C, H, W] 279 | -------------------------------------------------------------------------------- /src/datasets/mirflickr_25k.py: -------------------------------------------------------------------------------- 1 | """MIRFLICKR-25k Dataset 2 | 3 | - https://press.liacs.nl/mirflickr/ 4 | - M. J. Huiskes, M. S. Lew (2008). The MIR Flickr Retrieval Evaluation. 
ACM International Conference on Multimedia Information Retrieval (MIR'08), Vancouver, Canada 5 | """ 6 | import zipfile 7 | from pathlib import Path 8 | from typing import Tuple, Optional 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import toml 13 | import torch 14 | from src.datasets.utils import download_raw_dataset 15 | from torch.utils.data import Dataset, DataLoader 16 | from torchvision.io import read_image 17 | from torchvision.transforms.functional import resize 18 | from pytorch_lightning import LightningDataModule 19 | 20 | METADATA_FILENAME = Path("data/raw/mirflickr_25k/metadata.toml") 21 | DL_DATA_DIRNAME = Path("data/downloaded/mirflickr_25k") 22 | PROCESSED_DATA_DIRNAME = DL_DATA_DIRNAME / "mirflickr" 23 | 24 | 25 | class MIRFLICKR_25kDataset(Dataset): 26 | def __init__( 27 | self, 28 | root_dir: Path = DL_DATA_DIRNAME, 29 | n_exif_attr: int = 80, 30 | patch_size: int = 128, 31 | batch_size: int = 32, 32 | iters_per_epoch: int = 5_000, 33 | label: str = "attr", # One of {"attr", "img"} 34 | ) -> None: 35 | self.n_exif_attr = n_exif_attr 36 | self.patch_size = patch_size 37 | self.iters_per_epoch = iters_per_epoch 38 | self.label = label 39 | 40 | assert batch_size % 2 == 0, "Make sure `batch_size` is divisible by 2!" 41 | self.batch_size = batch_size 42 | 43 | self._prepare_data() 44 | 45 | def _prepare_data(self) -> None: 46 | """Download dataset files, 47 | and processes them into suitable data structures 48 | """ 49 | if not PROCESSED_DATA_DIRNAME.exists(): 50 | metadata = toml.load(METADATA_FILENAME) 51 | # Download dataset 52 | download_raw_dataset(metadata, DL_DATA_DIRNAME) 53 | 54 | # Process downloaded dataset 55 | print("Unzipping MIRFLICKR-25k...") 56 | zip = zipfile.ZipFile(DL_DATA_DIRNAME / metadata["filename"]) 57 | zip.extractall(DL_DATA_DIRNAME) 58 | zip.close() 59 | 60 | self._init_exif_data() 61 | 62 | def _init_exif_data(self) -> None: 63 | # Compile EXIF information from dataset 64 | exif_dir = PROCESSED_DATA_DIRNAME / "meta" / "exif_raw" 65 | exif_paths = list(exif_dir.glob("*.txt")) 66 | 67 | data_dicts = [] 68 | 69 | for p in exif_paths: 70 | d = {} 71 | 72 | idx = int(p.stem[4:]) 73 | d["img_path"] = str(PROCESSED_DATA_DIRNAME / f"im{idx}.jpg") 74 | 75 | with p.open("r", errors="replace") as f: 76 | lines = f.readlines() 77 | 78 | for i in range(int(len(lines) / 2)): 79 | attr = lines[i * 2][1:].strip() 80 | value = lines[(i * 2) + 1].strip() 81 | 82 | d[attr] = value 83 | 84 | data_dicts.append(d) 85 | 86 | df = pd.DataFrame(data_dicts) 87 | 88 | self.img_paths = df["img_path"] 89 | 90 | # Determine EXIF attributes to predict 91 | # Select the attributes with the least missing values 92 | exif_attrs = list( 93 | df.drop("img_path", axis=1) 94 | .isnull() 95 | .mean(0) 96 | .sort_values()[: self.n_exif_attr] 97 | .index 98 | ) 99 | self.exif_data = df[exif_attrs] 100 | self.exif_attrs = list(self.exif_data.columns) 101 | 102 | # TODO For a given EXIF attribute, 103 | # discard values that occur less than N times? 104 | 105 | # TODO Train / Val split? 
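    # One way to address the "discard values that occur less than N times" TODO
    # above (untested sketch; `min_count` is an assumed parameter, not something
    # defined in this repo): mask rare EXIF values to NaN so they are never
    # chosen as the target value in `_get_attr_batch`.
    #
    # def _drop_rare_values(self, min_count: int = 100) -> None:
    #     for col in self.exif_data.columns:
    #         counts = self.exif_data[col].value_counts()
    #         rare = counts[counts < min_count].index
    #         self.exif_data.loc[self.exif_data[col].isin(rare), col] = np.nan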
106 | 107 | def _resize_img(self, img: torch.ByteTensor) -> torch.ByteTensor: 108 | """Resizes img if smaller than required patch size""" 109 | _, H, W = img.shape 110 | 111 | if H < self.patch_size or W < self.patch_size: 112 | return resize(img, size=self.patch_size) 113 | 114 | else: 115 | return img 116 | 117 | def _get_random_patch(self, img: torch.ByteTensor) -> torch.ByteTensor: 118 | _, H, W = img.shape 119 | rand_H = np.random.randint(H - self.patch_size + 1) 120 | rand_W = np.random.randint(W - self.patch_size + 1) 121 | 122 | return img[ 123 | :, rand_H : rand_H + self.patch_size, rand_W : rand_W + self.patch_size 124 | ] 125 | 126 | def __getitem__(self, idx: int) -> Tuple[torch.ByteTensor, torch.LongTensor]: 127 | if self.label == "attr": 128 | return self._get_attr_batch() 129 | elif self.label == "img": 130 | return self._get_img_batch() 131 | 132 | def _get_attr_batch(self) -> Tuple[torch.ByteTensor, torch.LongTensor]: 133 | """Get pairs of image patches, and the EXIF values predictions 134 | 135 | Returns 136 | ------- 137 | Tuple[torch.ByteTensor, torch.LongTensor] 138 | [2, batch_size, C, H, W], [batch_size, n_exif_attr] 139 | Range [0, 255], One of {0, 1} 140 | """ 141 | # FIXME Disable automatic batching? Or define a sampler? 142 | 143 | # TODO Include post-processing consistency pipeline 144 | 145 | n_rows, n_cols = self.exif_data.shape 146 | 147 | # Randomly choose an EXIF value 148 | exif_idx = np.random.randint(n_cols) 149 | 150 | # FIXME Cache EXIF values? 151 | exif_col = self.exif_data.iloc[:, exif_idx] 152 | while True: 153 | exif_value = np.random.choice(exif_col.unique()) 154 | if exif_value is not np.nan: 155 | break 156 | 157 | # Get all images with / w/o that `exif_value` 158 | is_exif_value = exif_col == exif_value 159 | imgs_with_value = self.img_paths[is_exif_value] 160 | imgs_wo_value = self.img_paths[~is_exif_value] 161 | 162 | # [2, B, C, H, W] 163 | img_batch = torch.zeros( 164 | 2, self.batch_size, 3, self.patch_size, self.patch_size, dtype=torch.uint8 165 | ) 166 | # [2, B, n_exif_attr] 167 | attrs_batch = np.empty((2, self.batch_size, self.n_exif_attr), dtype=object) 168 | 169 | # FIXME Possible to vectorize this? 
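        # Partial vectorization is possible here (untested sketch): the image paths
        # for a whole batch could be drawn in one call, e.g.
        #   imgs_with_value.sample(n=self.batch_size, replace=True)
        # but decoding would still loop, since `read_image` reads one file at a time.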
170 | # Create batch 171 | for batch_idx in range(self.batch_size): 172 | for pair_idx in (0, 1): 173 | # Create negative pairs for second half of the batch 174 | if batch_idx >= int(self.batch_size / 2): 175 | imgs_to_sample = imgs_wo_value if pair_idx else imgs_with_value 176 | else: 177 | imgs_to_sample = imgs_with_value 178 | 179 | img_sample = imgs_to_sample.sample() 180 | img_idx = img_sample.index.values[0] 181 | 182 | # Get attributes 183 | attrs = self.exif_data.loc[img_idx].values 184 | attrs_batch[pair_idx, batch_idx] = attrs 185 | 186 | # Get image 187 | img_path = img_sample.values[0] 188 | img = read_image(img_path) 189 | # Resize image if smaller than patch size 190 | img = self._resize_img(img) 191 | img_patch = self._get_random_patch(img) 192 | 193 | img_batch[pair_idx, batch_idx] = img_patch 194 | 195 | # Compute labels; by comparing the attrs of each pair 196 | labels_batch = attrs_batch[0] == attrs_batch[1] 197 | labels_batch = torch.tensor(labels_batch, dtype=torch.int64) 198 | 199 | return img_batch, labels_batch 200 | 201 | def _get_img_batch(self) -> Tuple[torch.ByteTensor, torch.FloatTensor]: 202 | """Get pairs of image patches, 203 | and prediction for whether each pair came from the same image 204 | 205 | Returns 206 | ------- 207 | Tuple[torch.ByteTensor, torch.LongTensor] 208 | [2, batch_size, C, H, W], [batch_size] 209 | Range [0, 255], One of {0, 1} 210 | """ 211 | n_rows, n_cols = self.exif_data.shape 212 | 213 | # Batch contains half positive pairs, and half negative pairs 214 | labels_batch = torch.zeros(self.batch_size, dtype=torch.int64) 215 | labels_batch[:int(self.batch_size / 2)] = 1 216 | 217 | # [2, B, C, H, W] 218 | img_batch = torch.zeros( 219 | 2, self.batch_size, 3, self.patch_size, self.patch_size, dtype=torch.uint8 220 | ) 221 | 222 | # Create positive pairs 223 | for batch_idx in range(int(self.batch_size / 2)): 224 | # Choose a random image to be the current pair 225 | img_idx = np.random.randint(n_rows) 226 | img_path = self.img_paths[img_idx] 227 | 228 | img = read_image(img_path) 229 | # Resize image if smaller than patch size 230 | img = self._resize_img(img) 231 | 232 | for pair_idx in (0, 1): 233 | img_patch = self._get_random_patch(img) 234 | img_batch[pair_idx, batch_idx] = img_patch 235 | 236 | # Create negative pairs 237 | for batch_idx in range(int(self.batch_size / 2), self.batch_size): 238 | # Choose a random pair of images 239 | img_idxs = np.random.choice(np.arange(n_rows), size=(2,), replace=False) 240 | 241 | for pair_idx in (0, 1): 242 | img_path = self.img_paths[img_idxs[pair_idx]] 243 | 244 | img = read_image(img_path) 245 | # Resize image if smaller than patch size 246 | img = self._resize_img(img) 247 | img_patch = self._get_random_patch(img) 248 | 249 | img_batch[pair_idx, batch_idx] = img_patch 250 | 251 | return img_batch, labels_batch 252 | 253 | def __len__(self): 254 | # Determines how many iterations per epoch 255 | return self.iters_per_epoch 256 | 257 | 258 | class MIRFLICKR_25kDataModule(LightningDataModule): 259 | def __init__( 260 | self, 261 | root_dir: Path = DL_DATA_DIRNAME, 262 | n_exif_attr: int = 80, 263 | patch_size: int = 128, 264 | batch_size: int = 32, 265 | iters_per_epoch: int = 5_000, 266 | label: str = "attr", # One of {"attr", "img"} 267 | n_workers: int = 18, 268 | pin_memory: bool = True, 269 | ) -> None: 270 | super().__init__() 271 | 272 | self.root_dir = root_dir 273 | self.n_exif_attr = n_exif_attr 274 | self.patch_size = patch_size 275 | self.batch_size = batch_size 276 | 
self.iters_per_epoch = iters_per_epoch 277 | self.label = label 278 | self.n_workers = n_workers 279 | self.pin_memory = pin_memory 280 | 281 | def prepare_data(self, *args, **kwargs) -> None: 282 | self.dataset = MIRFLICKR_25kDataset( 283 | root_dir=self.root_dir, 284 | n_exif_attr=self.n_exif_attr, 285 | patch_size=self.patch_size, 286 | batch_size=self.batch_size, 287 | iters_per_epoch=self.iters_per_epoch, 288 | label=self.label, 289 | ) 290 | 291 | def setup(self, stage: Optional[str] = None) -> None: 292 | self.exif_attrs = self.dataset.exif_attrs 293 | 294 | def train_dataloader(self): 295 | return DataLoader( 296 | self.dataset, 297 | batch_size=None, # Disable automatic batching 298 | num_workers=self.n_workers, 299 | pin_memory=self.pin_memory, 300 | ) 301 | -------------------------------------------------------------------------------- /src/evaluation/evaluators.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Any, Dict, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from sklearn.metrics import average_precision_score 9 | from torch.utils.data import Dataset 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | 13 | from .metrics import AUC_Metric, F1_Metric, MCC_Metric, mAP_Metric 14 | 15 | 16 | class Evaluator: 17 | def __init__( 18 | self, 19 | model, 20 | dataset: Dataset, 21 | attacker, 22 | vis_dir: str = None, 23 | vis_every=1, 24 | logger=None, 25 | method="mean", 26 | ) -> None: 27 | # Freeze all network weights 28 | for parameter in model.net.parameters(): 29 | parameter.requires_grad = False 30 | model.net.eval() 31 | self.model = model 32 | 33 | self.dataset = dataset 34 | self.method = method 35 | 36 | self.vis_dir = vis_dir 37 | self.vis_every = vis_every 38 | self.logger = logger 39 | 40 | self.results = {"clean": {}, "adv": {}} 41 | 42 | self.attacker = attacker 43 | 44 | def __call__(self, resize: Tuple[int, int] = None) -> Dict[str, Any]: 45 | """ 46 | Parameters 47 | ---------- 48 | save : bool, optional 49 | Whether to save prediction arrays, by default False 50 | resize : Tuple[int, int], optional 51 | [H, W], whether to resize images / maps to a consistent shape 52 | 53 | Returns 54 | ------- 55 | Dict[str, Any] 56 | AP : float 57 | Average precision score, for detection 58 | IoU : float 59 | Class-balanced IoU, for localization 60 | f1_score : float 61 | for localization 62 | mcc : float 63 | Matthews Correlation Coefficient, for localization 64 | mAP : float 65 | Mean Average Precision, for localization 66 | auc : float 67 | Area under the Receiving Operating Characteristic Curve, for localization 68 | """ 69 | # Initialize per-image metrics 70 | metrics = defaultdict(dict) 71 | metric_classes = { 72 | "f1_score": F1_Metric, 73 | "mcc": MCC_Metric, 74 | "mAP": mAP_Metric, 75 | "auc": AUC_Metric, 76 | } 77 | for type in ["clean", "adv"]: 78 | for name, cls in metric_classes.items(): 79 | metrics[type][name] = cls() 80 | 81 | # Cache all predictions 82 | all_preds = {"clean": defaultdict(list), "adv": defaultdict(list)} 83 | 84 | # Loop through dataset 85 | for i in tqdm(range(len(self.dataset))): 86 | # Store per-image predictions 87 | img_pred = {} 88 | 89 | data = self.dataset[i] 90 | clean_img = data["img"] 91 | 92 | # Perform prediction on clean image 93 | img_pred["clean"] = self.model.predict(clean_img) 94 | 95 | # Generate adversarial image 96 | adv_img = self.attacker(self.model, data) 
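            # `attacker` is a callable such as PatchLOTS (src.attacks); it is
            # expected to return a perturbed image with the same shape and value
            # range as `clean_img`, so it can be fed straight into
            # `self.model.predict` below.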
97 | 98 | # Perform prediction on adversarial image 99 | img_pred["adv"] = self.model.predict(adv_img) 100 | 101 | # Account for NaN values 102 | for pred in img_pred.values(): 103 | if np.isnan(pred["ms"]).any(): 104 | print("WARNING: NaN values in localization prediction scores!") 105 | pred["ms"][np.isnan(pred["ms"])] = 0 106 | 107 | if np.isnan(pred["score"]): 108 | print("WARNING: NaN values in detection prediction scores!") 109 | pred["score"] = 0 110 | 111 | # Perform per-image evaluations 112 | for type, ms in metrics.items(): 113 | for _, m in ms.items(): 114 | m.update(data["map"], img_pred[type]["ms"]) 115 | 116 | # Visualize some examples 117 | if self.vis_dir and i % self.vis_every == 0: 118 | self._vis_preds(i, data, img_pred, clean_img, adv_img) 119 | 120 | # If image sizes different, resize to a consistent shape 121 | if resize: 122 | data["map"] = cv2.resize( 123 | data["map"], resize[::-1], interpolation=cv2.INTER_LINEAR 124 | ) 125 | img_pred["clean"]["ms"] = cv2.resize( 126 | img_pred["clean"]["ms"], 127 | resize[::-1], 128 | interpolation=cv2.INTER_LINEAR, 129 | ) 130 | img_pred["adv"]["ms"] = cv2.resize( 131 | img_pred["adv"]["ms"], resize[::-1], interpolation=cv2.INTER_LINEAR 132 | ) 133 | 134 | # Cache predictions 135 | for type, preds in all_preds.items(): 136 | # Store ground-truths 137 | preds["y_true"].append(data["label"]) 138 | preds["label_map"].append(data["map"]) 139 | 140 | # Store predictions 141 | preds["y_score"].append(img_pred[type]["score"]) 142 | preds["score_map"].append(img_pred[type]["ms"]) 143 | 144 | # Compute per-image evaluation metrics 145 | for type, ms in metrics.items(): 146 | for metric_name, m in ms.items(): 147 | self.results[type][metric_name] = m.compute() 148 | 149 | # Consolidate cached predictions 150 | for type, preds in all_preds.items(): 151 | preds["y_true"] = np.array(preds["y_true"]) 152 | preds["label_map"] = np.stack(preds["label_map"], axis=0) 153 | 154 | preds["y_score"] = np.array(preds["y_score"]) 155 | preds["score_map"] = np.stack(preds["score_map"], axis=0) 156 | 157 | # Save predictions 158 | # if save: 159 | # save_path = Path("artifacts/predictions") 160 | # np.save(save_path / "scores.npy", y_score) 161 | # np.save(save_path / "score_maps.npy", score_map) 162 | # np.save(save_path / "rt_ncuts.npy", ncut) 163 | 164 | # Compute rest of the metrics on cached predictions 165 | for type, r in self.results.items(): 166 | # Compute per-class IoU 167 | iou_spliced, iou_non_spliced, iou = self._compute_class_iou( 168 | all_preds[type]["label_map"], all_preds[type]["score_map"] 169 | ) 170 | r["iou_spliced"] = iou_spliced 171 | r["iou_non_spliced"] = iou_non_spliced 172 | r["iou"] = iou 173 | 174 | # Compute detection metrics 175 | r["AP"] = average_precision_score( 176 | all_preds[type]["y_true"], all_preds[type]["y_score"] 177 | ) 178 | 179 | return self.results 180 | 181 | # @staticmethod 182 | # def compute_optimal_iou(y_true, y_pred, batch_size=256): 183 | # # Check whether NaN values 184 | # if np.isnan(y_pred).any(): 185 | # print("WARNING: NaN values in localization prediction scores!") 186 | # y_pred[np.isnan(y_pred)] = 0 187 | 188 | # # Store all possible iou scores 189 | # thresholds = y_pred.flatten() 190 | # scores = np.zeros_like(thresholds) 191 | 192 | # for i in range(0, len(scores), batch_size): 193 | # threshs = thresholds[i : i + batch_size] # [B] 194 | 195 | # y_preds = y_pred.copy() 196 | # # [H, W, B] 197 | # y_preds = np.repeat(y_preds[..., None], batch_size, axis=-1) 198 | 199 | # # Perform 
thresholding 200 | # y_preds[y_preds < threshs] = 0 201 | # y_preds[y_preds >= threshs] = 1 202 | 203 | # # Compute scores 204 | # return iou( 205 | # torch.from_numpy(y_preds.transpose(2, 0, 1)), 206 | # torch.from_numpy(np.repeat(y_true[None, ...], batch_size, axis=0)) 207 | # ) 208 | 209 | # Compute iou score for each threshold 210 | # for i, thresh in tqdm(enumerate(thresholds)): 211 | # y_pred_thresh = np.zeros(y_pred.shape, dtype=np.uint8) 212 | # y_pred_thresh[y_pred >= thresh] = 1 213 | 214 | # scores[i] = jaccard_score(y_true.flatten(), y_pred_thresh.flatten()) 215 | 216 | return scores 217 | 218 | def _compute_class_iou(self, label_map, score_map) -> None: 219 | # FIXME Consider inverted score maps? 220 | 221 | # Find optimal threshold, and the corresponding score for each image 222 | 223 | # Compute for spliced regions 224 | _, iou_spliced = self.find_optimal_threshold(score_map, label_map) 225 | iou_spliced = iou_spliced.mean().item() 226 | 227 | # Compute for non-spliced regions 228 | invert_label_map = 1 - label_map 229 | invert_score_map = 1 - score_map 230 | 231 | _, iou_non_spliced = self.find_optimal_threshold( 232 | invert_score_map, invert_label_map 233 | ) 234 | iou_non_spliced = iou_non_spliced.mean().item() 235 | 236 | # Compute mean IoU 237 | iou = (iou_spliced + iou_non_spliced) / 2 238 | 239 | # FIXME Report per-class scores 240 | 241 | return iou_spliced, iou_non_spliced, iou 242 | 243 | @staticmethod 244 | def find_optimal_threshold( 245 | pred_mask: np.ndarray, groundtruth_masks: np.ndarray 246 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 247 | """https://codereview.stackexchange.com/questions/229341/pytorch-vectorized-implementation-for-thresholding-and-computing-jaccard-index 248 | 249 | Parameters 250 | ---------- 251 | pred_mask : np.ndarray (float32) 252 | [B, H, W], range [0, 1], probability prediction map 253 | groundtruth_masks : np.ndarray (uint8) 254 | [B, H, W], values one of {0, 1}, binary label map 255 | 256 | Returns 257 | ------- 258 | Tuple[torch.FloatTensor, torch.FloatTensor] 259 | [B], optimal thresholds for each image 260 | [B], corresponding jaccard scores for each image 261 | """ 262 | n_patch = groundtruth_masks.shape[0] 263 | 264 | groundtruth_masks_tensor = torch.from_numpy(groundtruth_masks) 265 | pred_mask_tensor = torch.from_numpy(pred_mask) 266 | 267 | # if USE_CUDA: 268 | # groundtruth_masks_tensor = groundtruth_masks_tensor.cuda() 269 | # pred_mask_tensor = pred_mask_tensor.cuda() 270 | 271 | vector_pred = pred_mask_tensor.view(n_patch, -1) 272 | vector_gt = groundtruth_masks_tensor.view(n_patch, -1) 273 | vector_pred, sort_pred_idx = torch.sort(vector_pred, descending=True) 274 | vector_gt = vector_gt[torch.arange(vector_gt.shape[0])[:, None], sort_pred_idx] 275 | gt_cumsum = torch.cumsum(vector_gt, dim=1) 276 | gt_total = gt_cumsum[:, -1].reshape(n_patch, 1) 277 | predicted = torch.arange(start=1, end=vector_pred.shape[1] + 1) 278 | # if USE_CUDA: 279 | # predicted = predicted.cuda() 280 | gt_cumsum = gt_cumsum.type(torch.float) 281 | gt_total = gt_total.type(torch.float) 282 | predicted = predicted.type(torch.float) 283 | jaccard_idx = gt_cumsum / (gt_total + predicted - gt_cumsum) 284 | max_jaccard_idx, max_indices = torch.max(jaccard_idx, dim=1) 285 | max_indices = max_indices.reshape(-1, 1) 286 | best_threshold = vector_pred[ 287 | torch.arange(vector_pred.shape[0])[:, None], max_indices 288 | ] 289 | best_threshold = best_threshold.reshape(-1) 290 | 291 | return best_threshold, max_jaccard_idx 292 | 293 | def 
_vis_preds(self, i, data, img_pred, clean_img, adv_img): 294 | plt.subplots(figsize=(32, 8)) 295 | plt.subplot(2, 4, 1) 296 | plt.title("Input Image") 297 | plt.imshow(clean_img.permute(1, 2, 0)) 298 | plt.axis("off") 299 | 300 | plt.subplot(2, 4, 2) 301 | plt.title("Adv Image") 302 | plt.imshow(adv_img.permute(1, 2, 0)) 303 | plt.axis("off") 304 | 305 | plt.subplot(2, 4, 3) 306 | plt.title("Cluster w/ MeanShift") 307 | plt.axis("off") 308 | plt.imshow(img_pred["clean"]["ms"], cmap="jet", vmin=0.0, vmax=1.0) 309 | 310 | plt.subplot(2, 4, 4) 311 | plt.title("Adv Cluster w/ MeanShift") 312 | plt.axis("off") 313 | plt.imshow(img_pred["adv"]["ms"], cmap="jet", vmin=0.0, vmax=1.0) 314 | 315 | plt.subplot(2, 4, 5) 316 | plt.title("Ground-truth Segment") 317 | plt.axis("off") 318 | plt.imshow(data["map"], vmin=0.0, vmax=1.0, cmap="gray") 319 | 320 | plt.subplot(2, 4, 6) 321 | plt.title("Ground-truth Segment") 322 | plt.axis("off") 323 | plt.imshow(data["map"], vmin=0.0, vmax=1.0, cmap="gray") 324 | 325 | plt.subplot(2, 4, 7) 326 | plt.title("Segment with NCuts") 327 | plt.axis("off") 328 | plt.imshow(img_pred["clean"]["ncuts"], vmin=0.0, vmax=1.0, cmap="gray") 329 | 330 | plt.subplot(2, 4, 8) 331 | plt.title("Adv Segment with NCuts") 332 | plt.axis("off") 333 | plt.imshow(img_pred["adv"]["ncuts"], vmin=0.0, vmax=1.0, cmap="gray") 334 | 335 | plt.tight_layout() 336 | plt.show() 337 | 338 | vis_dir = Path(self.vis_dir) 339 | plt.savefig(vis_dir / f"{self.dataset.__class__.__name__}_{i}.png") 340 | 341 | if self.logger: 342 | self.logger.log({"adv_example": self.logger.Image(plt)}) 343 | -------------------------------------------------------------------------------- /src/evaluation/evaluators_test.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Any, Dict, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from sklearn.metrics import average_precision_score 9 | from src.attacks import PatchLOTS 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | import matplotlib.pyplot as plt 13 | 14 | from .metrics import AUC_Metric, F1_Metric, MCC_Metric, mAP_Metric 15 | 16 | 17 | class Evaluator: 18 | def __init__( 19 | self, 20 | model, 21 | dataset: Dataset, 22 | perform_adv: bool = False, 23 | adv_step_size: int = None, 24 | adv_n_iter: int = None, 25 | vis_dir: str = None, 26 | vis_every: int = 1, 27 | logger=None, 28 | ) -> None: 29 | """Performs evaluation given the model and datasets. 30 | Performs adversarial perturbation on the dataset if specified. 
31 | 32 | Parameters 33 | ---------- 34 | model : 35 | dataset : Dataset 36 | perform_adv : bool, optional 37 | Whether to perform adversarial perturbation on dataset, by default False 38 | adv_step_size : int, optional 39 | "Learning rate" of adversarial attack, by default None 40 | adv_n_iter : int, optional 41 | Number of iterations of adversarial attack, by default None 42 | vis_dir : str, optional 43 | Directory to save visualizations, not saved if not specified, by default None 44 | vis_every : int, optional 45 | To save visualizations every `vis_every` images, by default 1 46 | logger : , optional 47 | Wandb logger to log results, by default None 48 | """ 49 | # Freeze all network weights 50 | for parameter in model.net.parameters(): 51 | parameter.requires_grad = False 52 | model.net.eval() 53 | self.model = model 54 | 55 | self.dataset = dataset 56 | 57 | self.perform_adv = perform_adv 58 | self.adv_step_size = adv_step_size 59 | self.adv_n_iter = adv_n_iter 60 | 61 | self.vis_dir = vis_dir 62 | self.vis_every = vis_every 63 | self.logger = logger 64 | 65 | self.results = {"clean": {}} 66 | if self.perform_adv: 67 | self.results["adv"] = {} 68 | self.attacker = PatchLOTS() 69 | 70 | def __call__(self, resize: Tuple[int, int] = None) -> Dict[str, Any]: 71 | """ 72 | Parameters 73 | ---------- 74 | save : bool, optional 75 | Whether to save prediction arrays, by default False 76 | resize : Tuple[int, int], optional 77 | [H, W], whether to resize images / maps to a consistent shape 78 | 79 | Returns 80 | ------- 81 | Dict[str, Any] 82 | AP : float 83 | Average precision score, for detection 84 | IoU : float 85 | Class-balanced IoU, for localization 86 | f1_score : float 87 | for localization 88 | mcc : float 89 | Matthews Correlation Coefficient, for localization 90 | mAP : float 91 | Mean Average Precision, for localization 92 | auc : float 93 | Area under the Receiving Operating Characteristic Curve, for localization 94 | """ 95 | # E.g. 
["clean", "adv"] 96 | metric_types = list(self.results.keys()) 97 | 98 | # Initialize per-image metrics 99 | metrics = defaultdict(dict) 100 | metric_classes = { 101 | "f1_score": F1_Metric, 102 | "mcc": MCC_Metric, 103 | "mAP": mAP_Metric, 104 | "auc": AUC_Metric, 105 | } 106 | for type in metric_types: 107 | for name, cls in metric_classes.items(): 108 | metrics[type][name] = cls() 109 | 110 | # Cache all predictions 111 | all_preds = {t: defaultdict(list) for t in metric_types} 112 | 113 | # Loop through dataset 114 | for i in tqdm(range(len(self.dataset))): 115 | # Store per-image predictions 116 | img_pred = {} 117 | 118 | data = self.dataset[i] 119 | clean_img = data["img"] 120 | 121 | # Perform prediction on clean image 122 | img_pred["clean"] = self.model.predict(clean_img) 123 | 124 | if self.perform_adv: 125 | # Generate adversarial image 126 | adv_img = self.attacker( 127 | self.model, data, self.adv_step_size, self.adv_n_iter 128 | ) 129 | 130 | # Perform prediction on adversarial image 131 | img_pred["adv"] = self.model.predict(adv_img) 132 | 133 | # Account for NaN values 134 | for pred in img_pred.values(): 135 | if np.isnan(pred["ms"]).any(): 136 | print("WARNING: NaN values in localization prediction scores!") 137 | pred["ms"][np.isnan(pred["ms"])] = 0 138 | 139 | if np.isnan(pred["score"]): 140 | print("WARNING: NaN values in detection prediction scores!") 141 | pred["score"] = 0 142 | 143 | # Perform per-image evaluations 144 | for type, ms in metrics.items(): 145 | for _, m in ms.items(): 146 | m.update(data["map"], img_pred[type]["ms"]) 147 | 148 | # Visualize some examples 149 | if self.perform_adv and self.vis_dir and i % self.vis_every == 0: 150 | self._vis_preds(i, data, img_pred, clean_img, adv_img) 151 | 152 | # If image sizes different, resize to a consistent shape 153 | if resize: 154 | data["map"] = cv2.resize( 155 | data["map"], resize[::-1], interpolation=cv2.INTER_LINEAR 156 | ) 157 | img_pred["clean"]["ms"] = cv2.resize( 158 | img_pred["clean"]["ms"], 159 | resize[::-1], 160 | interpolation=cv2.INTER_LINEAR, 161 | ) 162 | if self.perform_adv: 163 | img_pred["adv"]["ms"] = cv2.resize( 164 | img_pred["adv"]["ms"], 165 | resize[::-1], 166 | interpolation=cv2.INTER_LINEAR, 167 | ) 168 | 169 | # Cache predictions 170 | for type, preds in all_preds.items(): 171 | # Store ground-truths 172 | preds["y_true"].append(data["label"]) 173 | preds["label_map"].append(data["map"]) 174 | 175 | # Store predictions 176 | preds["y_score"].append(img_pred[type]["score"]) 177 | preds["score_map"].append(img_pred[type]["ms"]) 178 | 179 | # Compute per-image evaluation metrics 180 | for type, ms in metrics.items(): 181 | for metric_name, m in ms.items(): 182 | self.results[type][metric_name] = m.compute() 183 | 184 | # Consolidate cached predictions 185 | for type, preds in all_preds.items(): 186 | preds["y_true"] = np.array(preds["y_true"]) 187 | preds["label_map"] = np.stack(preds["label_map"], axis=0) 188 | 189 | preds["y_score"] = np.array(preds["y_score"]) 190 | preds["score_map"] = np.stack(preds["score_map"], axis=0) 191 | 192 | # Save predictions 193 | # if save: 194 | # save_path = Path("artifacts/predictions") 195 | # np.save(save_path / "scores.npy", y_score) 196 | # np.save(save_path / "score_maps.npy", score_map) 197 | # np.save(save_path / "rt_ncuts.npy", ncut) 198 | 199 | # Compute rest of the metrics on cached predictions 200 | for type, r in self.results.items(): 201 | # Compute per-class IoU 202 | iou_spliced, iou_non_spliced, iou = self._compute_class_iou( 203 
| all_preds[type]["label_map"], all_preds[type]["score_map"] 204 | ) 205 | r["iou_spliced"] = iou_spliced 206 | r["iou_non_spliced"] = iou_non_spliced 207 | r["iou"] = iou 208 | 209 | # Compute detection metrics 210 | r["AP"] = average_precision_score( 211 | all_preds[type]["y_true"], all_preds[type]["y_score"] 212 | ) 213 | 214 | return self.results 215 | 216 | # @staticmethod 217 | # def compute_optimal_iou(y_true, y_pred, batch_size=256): 218 | # # Check whether NaN values 219 | # if np.isnan(y_pred).any(): 220 | # print("WARNING: NaN values in localization prediction scores!") 221 | # y_pred[np.isnan(y_pred)] = 0 222 | 223 | # # Store all possible iou scores 224 | # thresholds = y_pred.flatten() 225 | # scores = np.zeros_like(thresholds) 226 | 227 | # for i in range(0, len(scores), batch_size): 228 | # threshs = thresholds[i : i + batch_size] # [B] 229 | 230 | # y_preds = y_pred.copy() 231 | # # [H, W, B] 232 | # y_preds = np.repeat(y_preds[..., None], batch_size, axis=-1) 233 | 234 | # # Perform thresholding 235 | # y_preds[y_preds < threshs] = 0 236 | # y_preds[y_preds >= threshs] = 1 237 | 238 | # # Compute scores 239 | # return iou( 240 | # torch.from_numpy(y_preds.transpose(2, 0, 1)), 241 | # torch.from_numpy(np.repeat(y_true[None, ...], batch_size, axis=0)) 242 | # ) 243 | 244 | # Compute iou score for each threshold 245 | # for i, thresh in tqdm(enumerate(thresholds)): 246 | # y_pred_thresh = np.zeros(y_pred.shape, dtype=np.uint8) 247 | # y_pred_thresh[y_pred >= thresh] = 1 248 | 249 | # scores[i] = jaccard_score(y_true.flatten(), y_pred_thresh.flatten()) 250 | 251 | return scores 252 | 253 | def _compute_class_iou(self, label_map, score_map) -> None: 254 | # FIXME Consider inverted score maps? 255 | 256 | # Find optimal threshold, and the corresponding score for each image 257 | 258 | # Compute for spliced regions 259 | _, iou_spliced = self.find_optimal_threshold(score_map, label_map) 260 | iou_spliced = iou_spliced.mean().item() 261 | 262 | # Compute for non-spliced regions 263 | invert_label_map = 1 - label_map 264 | invert_score_map = 1 - score_map 265 | 266 | _, iou_non_spliced = self.find_optimal_threshold( 267 | invert_score_map, invert_label_map 268 | ) 269 | iou_non_spliced = iou_non_spliced.mean().item() 270 | 271 | # Compute mean IoU 272 | iou = (iou_spliced + iou_non_spliced) / 2 273 | 274 | # FIXME Report per-class scores 275 | 276 | return iou_spliced, iou_non_spliced, iou 277 | 278 | @staticmethod 279 | def find_optimal_threshold( 280 | pred_mask: np.ndarray, groundtruth_masks: np.ndarray 281 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 282 | """https://codereview.stackexchange.com/questions/229341/pytorch-vectorized-implementation-for-thresholding-and-computing-jaccard-index 283 | 284 | Parameters 285 | ---------- 286 | pred_mask : np.ndarray (float32) 287 | [B, H, W], range [0, 1], probability prediction map 288 | groundtruth_masks : np.ndarray (uint8) 289 | [B, H, W], values one of {0, 1}, binary label map 290 | 291 | Returns 292 | ------- 293 | Tuple[torch.FloatTensor, torch.FloatTensor] 294 | [B], optimal thresholds for each image 295 | [B], corresponding jaccard scores for each image 296 | """ 297 | n_patch = groundtruth_masks.shape[0] 298 | 299 | groundtruth_masks_tensor = torch.from_numpy(groundtruth_masks) 300 | pred_mask_tensor = torch.from_numpy(pred_mask) 301 | 302 | # if USE_CUDA: 303 | # groundtruth_masks_tensor = groundtruth_masks_tensor.cuda() 304 | # pred_mask_tensor = pred_mask_tensor.cuda() 305 | 306 | vector_pred = 
pred_mask_tensor.view(n_patch, -1) 307 | vector_gt = groundtruth_masks_tensor.view(n_patch, -1) 308 | vector_pred, sort_pred_idx = torch.sort(vector_pred, descending=True) 309 | vector_gt = vector_gt[torch.arange(vector_gt.shape[0])[:, None], sort_pred_idx] 310 | gt_cumsum = torch.cumsum(vector_gt, dim=1) 311 | gt_total = gt_cumsum[:, -1].reshape(n_patch, 1) 312 | predicted = torch.arange(start=1, end=vector_pred.shape[1] + 1) 313 | # if USE_CUDA: 314 | # predicted = predicted.cuda() 315 | gt_cumsum = gt_cumsum.type(torch.float) 316 | gt_total = gt_total.type(torch.float) 317 | predicted = predicted.type(torch.float) 318 | jaccard_idx = gt_cumsum / (gt_total + predicted - gt_cumsum) 319 | max_jaccard_idx, max_indices = torch.max(jaccard_idx, dim=1) 320 | max_indices = max_indices.reshape(-1, 1) 321 | best_threshold = vector_pred[ 322 | torch.arange(vector_pred.shape[0])[:, None], max_indices 323 | ] 324 | best_threshold = best_threshold.reshape(-1) 325 | 326 | return best_threshold, max_jaccard_idx 327 | 328 | def _vis_preds(self, i, data, img_pred, clean_img, adv_img): 329 | plt.subplots(figsize=(32, 8)) 330 | plt.subplot(2, 4, 1) 331 | plt.title("Input Image") 332 | plt.imshow(clean_img.permute(1, 2, 0)) 333 | plt.axis("off") 334 | 335 | plt.subplot(2, 4, 2) 336 | plt.title("Adv Image") 337 | plt.imshow(adv_img.permute(1, 2, 0)) 338 | plt.axis("off") 339 | 340 | plt.subplot(2, 4, 3) 341 | plt.title("Cluster w/ MeanShift") 342 | plt.axis("off") 343 | plt.imshow(img_pred["clean"]["ms"], cmap="jet", vmin=0.0, vmax=1.0) 344 | 345 | plt.subplot(2, 4, 4) 346 | plt.title("Adv Cluster w/ MeanShift") 347 | plt.axis("off") 348 | plt.imshow(img_pred["adv"]["ms"], cmap="jet", vmin=0.0, vmax=1.0) 349 | 350 | plt.subplot(2, 4, 5) 351 | plt.title("Ground-truth Segment") 352 | plt.axis("off") 353 | plt.imshow(data["map"], vmin=0.0, vmax=1.0, cmap="gray") 354 | 355 | plt.subplot(2, 4, 6) 356 | plt.title("Ground-truth Segment") 357 | plt.axis("off") 358 | plt.imshow(data["map"], vmin=0.0, vmax=1.0, cmap="gray") 359 | 360 | plt.subplot(2, 4, 7) 361 | plt.title("Segment with NCuts") 362 | plt.axis("off") 363 | plt.imshow(img_pred["clean"]["ncuts"], vmin=0.0, vmax=1.0, cmap="gray") 364 | 365 | plt.subplot(2, 4, 8) 366 | plt.title("Adv Segment with NCuts") 367 | plt.axis("off") 368 | plt.imshow(img_pred["adv"]["ncuts"], vmin=0.0, vmax=1.0, cmap="gray") 369 | 370 | plt.tight_layout() 371 | plt.show() 372 | 373 | vis_dir = Path(self.vis_dir) 374 | plt.savefig(vis_dir / f"{self.dataset.__class__.__name__}_{i}.png") 375 | 376 | if self.logger: 377 | self.logger.log({"adv_example": self.logger.Image(plt)}) 378 | --------------------------------------------------------------------------------
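Taken together, the pieces above compose into a small evaluation run. The snippet below is a sketch only: `SomeSplicingDataset` is a hypothetical stand-in for one of the evaluation datasets under src/datasets, the weights and visualization paths are assumed, and the resize shape is arbitrary. Evaluator reads each dataset item as a dict with an "img" tensor ([C, H, W], range [0, 255]), a binary ground-truth "map", and a 0/1 "label", and returns its metrics nested under "clean" and "adv".

    from src.attacks import PatchLOTS
    from src.evaluation import Evaluator
    from src.models import EXIF_SC

    model = EXIF_SC(weight_file="artifacts/exif_sc.npy", device="cuda:0")
    attacker = PatchLOTS()
    dataset = SomeSplicingDataset()  # hypothetical: yields {"img", "map", "label"} dicts

    evaluator = Evaluator(model, dataset, attacker, vis_dir="artifacts/vis")
    results = evaluator(resize=(1024, 1024))  # metrics for both the clean and attacked images
    print(results["clean"]["AP"], results["adv"]["AP"])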