├── scripts
│   ├── run_stylegan.sh
│   ├── plot_ffhq.sh
│   ├── preprocess_ffhq_1024.sh
│   ├── multi_train_ffhq_run.sh
│   ├── train_ffhq_stylegan.sh
│   ├── multi_train_lsun_run.sh
│   ├── setup-conda-env.sh
│   ├── prepare_ffhq_1024.sh
│   ├── prepare_dataset_celebA.sh
│   ├── prepare_dataset_celebA_log_haar.sh
│   ├── preprocess_ffhq_style_style2.sh
│   ├── unzinp_ffhq.py
│   ├── prepare_dataset_lsun_log_db2_boundary.sh
│   ├── prepare_dataset_lsun_log_db4_boundary.sh
│   ├── prepare_dataset_lsun_log_sym2_boundary.sh
│   ├── prepare_dataset_lsun_log_sym4_boundary.sh
│   ├── prepare_dataset_celebA_log_db2_boundary.sh
│   ├── prepare_dataset_celebA_log_db4_boundary.sh
│   ├── prepare_dataset_celebA_log_sym2_boundary.sh
│   ├── prepare_dataset_celebA_log_sym4_boundary.sh
│   ├── train_celebA_cnn.sh
│   ├── calc_acc_from_conf_mat.py
│   ├── multi_train_celabA.py
│   ├── multi_train_lsun.py
│   ├── run_wavelet_plots.sh
│   ├── prepare_dataset.sh
│   ├── confusion_matrix.sh
│   ├── train_celebA.sh
│   ├── multi_train_ffhq.py
│   ├── train_lsun.sh
│   └── baseline.sh
├── src
│   └── freqdect
│       ├── baselines
│       │   ├── __init__.py
│       │   ├── README.md
│       │   ├── utils.py
│       │   ├── knn.py
│       │   ├── classifier.py
│       │   ├── prnu.py
│       │   ├── eigenface.py
│       │   └── baselines.py
│       ├── __init__.py
│       ├── __main__.py
│       ├── version.py
│       ├── fourier_math.py
│       ├── plot_linear_classifier.py
│       ├── resize_images_folder.py
│       ├── crop_celeba.py
│       ├── corruption.py
│       ├── plot_accuracy_simple.py
│       ├── crop_lsun.py
│       ├── data_loader.py
│       ├── wavelet_math.py
│       ├── saliency_process.py
│       ├── models.py
│       ├── wavelet_plot.py
│       ├── saliency.py
│       ├── plot_mean_packets.py
│       ├── confusion_matrix.py
│       ├── train_classifier.py
│       └── plot_accuracy_results.py
├── tests
│   ├── __init__.py
│   ├── test_version.py
│   ├── test_welford.py
│   ├── test_corruption.py
│   └── test_packets.py
├── img
│   └── packet_visualization2.png
├── setup.py
├── .readthedocs.yml
├── MANIFEST.in
├── CITATION.bib
├── .bumpversion.cfg
├── .flake8
├── .github
│   └── workflows
│       └── tests.yml
├── setup.cfg
├── CONTRIBUTING.rst
├── tox.ini
├── .gitignore
└── README.md

/scripts/run_stylegan.sh:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/freqdect/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Tests for :mod:`freqdect`."""
4 | 
--------------------------------------------------------------------------------
/scripts/plot_ffhq.sh:
--------------------------------------------------------------------------------
1 | python -m freqdect.plot_accuracy_results ./log/source_data_raw_regression
--------------------------------------------------------------------------------
/src/freqdect/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Detect GANs using frequency domain methods."""
4 | 
--------------------------------------------------------------------------------
/img/packet_visualization2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gan-police/frequency-forensics/HEAD/img/packet_visualization2.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: 
utf-8 -*- 2 | 3 | """Setup module for :mod:`freqdect`.""" 4 | 5 | import setuptools 6 | 7 | if __name__ == "__main__": 8 | setuptools.setup() 9 | -------------------------------------------------------------------------------- /src/freqdect/baselines/README.md: -------------------------------------------------------------------------------- 1 | Baseline code licensed under: 2 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/LICENSE 3 | 4 | We thank @Joool (Joel Frank) for making the code available. -------------------------------------------------------------------------------- /scripts/preprocess_ffhq_1024.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python src/freqdect/prepare_dataset_batched.py ./data/ffhq_stylegan_large --batch-size 100 --train-size 60000 --val-size 2000 --test-size 8000 --packets 2 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # See: https://docs.readthedocs.io/en/latest/config-file/v2.html 2 | 3 | version: 2 4 | 5 | build: 6 | image: latest 7 | 8 | python: 9 | version: 3.8 10 | install: 11 | - method: pip 12 | path: . 13 | extra_requirements: 14 | - docs 15 | -------------------------------------------------------------------------------- /src/freqdect/__main__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Entrypoint module, in case you use `python -m freqdect`. 4 | 5 | Why does this file exist, and why ``__main__``? For more info, read: 6 | 7 | - https://www.python.org/dev/peps/pep-0338/ 8 | - https://docs.python.org/3/using/cmdline.html#cmdoption-m 9 | """ 10 | 11 | from .cli import main 12 | 13 | if __name__ == "__main__": 14 | main() 15 | -------------------------------------------------------------------------------- /scripts/multi_train_ffhq_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=multi_train_ffhq_log 6 | #SBATCH --output=train_multi_ffhq-%j.out 7 | #SBATCH --error=train_multi_ffhq-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=5 11 | #SBATCH --time=12:00:00 12 | #SBATCH --mem=10gb 13 | 14 | PYTHONPATH=. 
python ./scripts/multi_train_ffhq.py
15 | 
--------------------------------------------------------------------------------
/scripts/train_ffhq_stylegan.sh:
--------------------------------------------------------------------------------
1 | for i in 0 1 2 3 4
2 | do
3 |   echo "packet experiment no: $i "
4 |   python -m freqdect.train_classifier --features packets --seed $i --data-prefix /nvme/mwolter/source_data_log_packets
5 | done
6 | 
7 | for i in 0 1 2 3 4
8 | do
9 |   echo "raw experiment no: $i "
10 |   python -m freqdect.train_classifier --features raw --seed $i --data-prefix /nvme/mwolter/source_data_raw
11 | done
12 | 
--------------------------------------------------------------------------------
/scripts/multi_train_lsun_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=multi_train_lsun_log
6 | #SBATCH --output=train_multi_lsun-%j.out
7 | #SBATCH --error=train_multi_lsun-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=4
11 | #SBATCH --time=10:00:00
12 | #SBATCH --mem=10gb
13 | 
14 | PYTHONPATH=. python ./scripts/multi_train_lsun.py
15 | 
16 | 
--------------------------------------------------------------------------------
/scripts/setup-conda-env.sh:
--------------------------------------------------------------------------------
1 | # Create a conda environment named intel38 under ~/env with python 3.8.
2 | # scikit-learn-intelex is the intel sklearn extension; the packages come
3 | # from the intel and conda-forge channels.
4 | conda create \
5 |     python=3.8 \
6 |     scikit-learn-intelex \
7 |     tqdm \
8 |     matplotlib \
9 |     seaborn \
10 |     PyWavelets \
11 |     scipy \
12 |     pillow \
13 |     opencv \
14 |     scikit-learn \
15 |     -c intel \
16 |     -c conda-forge \
17 |     --prefix=~/env/intel38
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft src
2 | graft tests
3 | prune scripts
4 | prune notebooks
5 | prune log
6 | prune img
7 | 
8 | recursive-include docs/source *.py
9 | recursive-include docs/source *.rst
10 | recursive-include docs/source *.png
11 | 
12 | global-exclude *.py[cod] __pycache__ *.so *.dylib .DS_Store *.gpickle
13 | 
14 | include README.md LICENSE docs/Makefile
15 | exclude tox.ini .flake8 .bumpversion.cfg .readthedocs.yml CONTRIBUTING.rst CITATION.bib
16 | 
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Trivial version test."""
4 | 
5 | import unittest
6 | 
7 | from src.freqdect.version import get_version
8 | 
9 | 
10 | class TestVersion(unittest.TestCase):
11 |     """Trivially test a version."""
12 | 
13 |     def test_version_type(self):
14 |         """Test the version is a string.
15 | 
16 |         This is only meant to be an example test. 
17 | """ 18 | version = get_version() 19 | self.assertIsInstance(version, str) 20 | -------------------------------------------------------------------------------- /scripts/prepare_ffhq_1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=prepare_dataset 6 | #SBATCH --output=prepare_dataset-%j.out 7 | #SBATCH --error=prepare_dataset-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=25 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=200gb 13 | 14 | python -m freqdect.prepare_dataset /nvme/mwolter/ffhq1024x1024 --train-size 5000 --val-size 2000 --test-size 500 --batch-size 100 --log-packets 15 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_celebA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=celeba_prepare_dataset 6 | #SBATCH --output=prepare_dataset-%j.out 7 | #SBATCH --error=prepare_dataset-%j.err 8 | #SBATCH --ntasks=1 9 | #SBATCH -p gpu 10 | #SBATCH --gres gpu:v100:1 11 | #SBATCH --cpus-per-task=16 12 | 13 | python src/freqdect/prepare_dataset_batched.py \ 14 | /home/ndv/projects/wavelets/datasets_moritz/celeba_align_png_cropped \ 15 | --train-size 100000 --test-size 30000 --val-size 20000 -p -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @article{wolter2022waveletpacket, 2 | title = {Wavelet-Packets for Deepfake Image Analysis and Detection}, 3 | author = {Moritz Wolter and Felix Blanke and Raoul Heese and Jochen Garcke}, 4 | journal = {Machine Learning}, 5 | year = {2022}, 6 | volume = {Special Issue of the ECML PKDD 2022 Journal Track}, 7 | pages = {1-33}, 8 | month = {August}, 9 | url = {https://rdcu.be/cUIRt}, 10 | issn = {0885-6125}, 11 | doi = {https://doi.org/10.1007/s10994-022-06225-5} 12 | } 13 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_celebA_log_haar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=celebA_log_haar 6 | #SBATCH --output=prepare_dataset_celebA_log_haar-%j.out 7 | #SBATCH --error=prepare_dataset_celebA_log_haar-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=90gb 13 | 14 | python -m freqdect.prepare_dataset \ 15 | /nvme/mwolter/celeba/celeba_align_png_cropped \ 16 | --train-size 100000 --test-size 30000 --val-size 20000 --log-packets -------------------------------------------------------------------------------- /scripts/preprocess_ffhq_style_style2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=prepare_ffhq 6 | #SBATCH --output=prepare_ffhq-%j.out 7 | #SBATCH --error=prepare_ffhq-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=10 11 | #SBATCH --time=12:00:00 12 | #SBATCH --mem=90gb 13 | 14 | python -m freqdect.prepare_dataset /nvme/mwolter/ffhq128/source_data --log-packets --wavelet haar --boundary boundary 
--train-size 62000 15 | python -m freqdect.prepare_dataset /nvme/mwolter/ffhq128/source_data --raw --train-size 62000 -------------------------------------------------------------------------------- /scripts/unzinp_ffhq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | def main(): 6 | os.system("pwd") 7 | ffhq_folder = "./data/ffhq_large/ffhq_1024/" 8 | ffhq_zip_files = ffhq_folder + "*.zip" 9 | 10 | files = glob.glob(ffhq_zip_files) 11 | files.sort() 12 | 13 | for file in files: 14 | folder = file.split("/")[-1].split("-")[0] 15 | os.system("unzip " + file) 16 | os.system("mv " + folder + "/*.png " + ffhq_folder) 17 | os.system("rmdir " + folder) 18 | os.system("rm " + file) 19 | print("file done") 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_lsun_log_db2_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=lsun_log_db2_boundary 6 | #SBATCH --output=prepare_dataset_lsun_log_db2_boundary-%j.out 7 | #SBATCH --error=prepare_dataset_lsun_log_db2_boundary-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=90gb 13 | 14 | python -m freqdect.prepare_dataset \ 15 | /nvme/mwolter/lsun/lsun_bedroom_200k_png \ 16 | --train-size 100000 --test-size 30000 --val-size 20000 \ 17 | --log-packets --wavelet db2 --boundary boundary 18 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_lsun_log_db4_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=lsun_log_db4_boundary 6 | #SBATCH --output=prepare_dataset_lsun_log_db4_boundary-%j.out 7 | #SBATCH --error=prepare_dataset_lsun_log_db4_boundary-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=90gb 13 | 14 | python -m freqdect.prepare_dataset \ 15 | /nvme/mwolter/lsun/lsun_bedroom_200k_png \ 16 | --train-size 100000 --test-size 30000 --val-size 20000 \ 17 | --log-packets --wavelet db4 --boundary boundary 18 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_lsun_log_sym2_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=lsun_log_sym2_boundary 6 | #SBATCH --output=prepare_dataset_lsun_log_sym2_boundary-%j.out 7 | #SBATCH --error=prepare_dataset_lsun_log_sym2_boundary-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=90gb 13 | 14 | python -m freqdect.prepare_dataset \ 15 | /nvme/mwolter/lsun/lsun_bedroom_200k_png \ 16 | --train-size 100000 --test-size 30000 --val-size 20000 \ 17 | --log-packets --wavelet sym2 --boundary boundary 18 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_lsun_log_sym4_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | 
#SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=lsun_log_sym4_boundary
6 | #SBATCH --output=prepare_dataset_lsun_log_sym4_boundary-%j.out
7 | #SBATCH --error=prepare_dataset_lsun_log_sym4_boundary-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=9
11 | #SBATCH --time=48:00:00
12 | #SBATCH --mem=90gb
13 | 
14 | python -m freqdect.prepare_dataset \
15 |     /nvme/mwolter/lsun/lsun_bedroom_200k_png \
16 |     --train-size 100000 --test-size 30000 --val-size 20000 \
17 |     --log-packets --wavelet sym4 --boundary boundary
18 | 
--------------------------------------------------------------------------------
/scripts/prepare_dataset_celebA_log_db2_boundary.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=celebA_log_db2_boundary
6 | #SBATCH --output=prepare_dataset_celebA_log_db2_boundary-%j.out
7 | #SBATCH --error=prepare_dataset_celebA_log_db2_boundary-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=9
11 | #SBATCH --time=48:00:00
12 | #SBATCH --mem=90gb
13 | 
14 | python -m freqdect.prepare_dataset \
15 |     /nvme/mwolter/celeba/celeba_align_png_cropped \
16 |     --train-size 100000 --test-size 30000 --val-size 20000 \
17 |     --log-packets --wavelet db2 --boundary boundary
18 | 
--------------------------------------------------------------------------------
/scripts/prepare_dataset_celebA_log_db4_boundary.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=celebA_log_db4_boundary
6 | #SBATCH --output=prepare_dataset_celebA_log_db4_boundary-%j.out
7 | #SBATCH --error=prepare_dataset_celebA_log_db4_boundary-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=9
11 | #SBATCH --time=48:00:00
12 | #SBATCH --mem=90gb
13 | 
14 | python -m freqdect.prepare_dataset \
15 |     /nvme/mwolter/celeba/celeba_align_png_cropped \
16 |     --train-size 100000 --test-size 30000 --val-size 20000 \
17 |     --log-packets --wavelet db4 --boundary boundary
18 | 
--------------------------------------------------------------------------------
/scripts/prepare_dataset_celebA_log_sym2_boundary.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=celebA_log_sym2_boundary
6 | #SBATCH --output=prepare_dataset_celebA_log_sym2_boundary-%j.out
7 | #SBATCH --error=prepare_dataset_celebA_log_sym2_boundary-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=9
11 | #SBATCH --time=48:00:00
12 | #SBATCH --mem=90gb
13 | 
14 | python -m freqdect.prepare_dataset \
15 |     /nvme/mwolter/celeba/celeba_align_png_cropped \
16 |     --train-size 100000 --test-size 30000 --val-size 20000 \
17 |     --log-packets --wavelet sym2 --boundary boundary
18 | 
--------------------------------------------------------------------------------
/scripts/prepare_dataset_celebA_log_sym4_boundary.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=celebA_log_sym4_boundary
6 | #SBATCH --output=prepare_dataset_celebA_log_sym4_boundary-%j.out
7 | #SBATCH --error=prepare_dataset_celebA_log_sym4_boundary-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=9
11 | #SBATCH --time=48:00:00
12 | #SBATCH --mem=90gb
13 | 
14 | python -m freqdect.prepare_dataset \
15 |     /nvme/mwolter/celeba/celeba_align_png_cropped \
16 |     --train-size 100000 --test-size 30000 --val-size 20000 \
17 |     --log-packets --wavelet sym4 --boundary boundary
18 | 
--------------------------------------------------------------------------------
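Note on the --wavelet flags above: haar, db2, db4, sym2 and sym4 are PyWavelets wavelet names. A minimal sanity check, assuming pywt is installed (as in setup-conda-env.sh):

    import pywt

    # wavelist() enumerates all wavelet names PyWavelets knows about
    assert {"haar", "db2", "db4", "sym2", "sym4"} <= set(pywt.wavelist())
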
/tests/test_welford.py:
--------------------------------------------------------------------------------
1 | """Code testing the welford estimator against numpy functions."""
2 | 
3 | import numpy as np
4 | import torch
5 | from src.freqdect.prepare_dataset import WelfordEstimator
6 | 
7 | 
8 | def test_welford() -> None:
9 |     """Test the welford estimator."""
10 |     test_data = np.random.randn(2000, 128, 128, 3)
11 |     np_mean = np.mean(test_data, axis=(0, 1, 2))
12 |     np_std = np.std(test_data, axis=(0, 1, 2))
13 | 
14 |     welford = WelfordEstimator()
15 |     for test_el in test_data:
16 |         welford.update(torch.from_numpy(test_el))
17 | 
18 |     welford_mean, welford_std = welford.finalize()
19 |     assert np.allclose(np_mean, welford_mean)  # noqa: S101
20 |     assert np.allclose(np_std, welford_std)  # noqa: S101
--------------------------------------------------------------------------------
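The WelfordEstimator under test lives in src/freqdect/prepare_dataset.py, which this snapshot does not show. As a minimal sketch of the single-pass running-statistics recurrence the test assumes (Welford's algorithm, written from scratch here rather than copied from the repository):

    import numpy as np

    class WelfordSketch:
        """Running mean/std in one pass, without storing the data."""

        def __init__(self):
            self.count = 0
            self.mean = 0.0
            self.m2 = 0.0

        def update(self, x):
            # standard Welford update for one new sample x
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

        def finalize(self):
            # population std (ddof=0), matching np.std's default
            return self.mean, np.sqrt(self.m2 / self.count)

The real estimator updates whole (128, 128, 3) tensors at once and reduces over the batch and spatial axes; the sketch above only illustrates the recurrence.
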
/scripts/train_celebA_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=celeba_train_cnn
6 | #SBATCH --output=train_celeba_packets_log_packets_cnn-%j.out
7 | #SBATCH --error=train_celeba_packets_log_packets_cnn-%j.err
8 | #SBATCH -p gpu
9 | #SBATCH --gres gpu:v100:1
10 | #SBATCH --cpus-per-task=16
11 | #SBATCH --time=12:00:00
12 | #SBATCH --mem=200gb
13 | 
14 | source activate ~/env/idp38
15 | 
16 | for i in 0 1 2 3 4
17 | do
18 |   echo "packet experiment no: $i "
19 |   python -m freqdect.train_classifier \
20 |     --features packets \
21 |     --seed $i \
22 |     --epochs 20 \
23 |     --data-prefix /nvme/fblanke/celeba_align_png_cropped_log_packets_db4_boundary \
24 |     --nclasses 5 \
25 |     --calc-normalization \
26 |     --model cnn
27 | done
28 | 
--------------------------------------------------------------------------------
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.0.1-dev
3 | commit = True
4 | tag = False
5 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?
6 | serialize =
7 |     {major}.{minor}.{patch}-{release}+{build}
8 |     {major}.{minor}.{patch}+{build}
9 |     {major}.{minor}.{patch}-{release}
10 |     {major}.{minor}.{patch}
11 | 
12 | [bumpversion:part:release]
13 | optional_value = production
14 | first_value = dev
15 | values =
16 |     dev
17 |     production
18 | 
19 | [bumpversion:part:build]
20 | values = [0-9A-Za-z-]+
21 | 
22 | [bumpversion:file:setup.cfg]
23 | search = version = {current_version}
24 | replace = version = {new_version}
25 | 
26 | [bumpversion:file:docs/source/conf.py]
27 | search = release = '{current_version}'
28 | replace = release = '{new_version}'
29 | 
30 | [bumpversion:file:src/freqdect/version.py]
31 | search = VERSION = '{current_version}'
32 | replace = VERSION = '{new_version}'
33 | 
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | #########################
2 | # Flake8 Configuration  #
3 | # (.flake8)             #
4 | #########################
5 | [flake8]
6 | ignore =
7 |     S301 # pickle
8 |     S403 # pickle
9 |     S404
10 |     S603
11 |     W503 # Line break before binary operator (flake8 is wrong)
12 |     E203 # Ignore the spaces black puts before columns.
13 |     E402 # allow path extensions for testing.
14 |     DAR101
15 |     DAR201
16 |     N400 # flake and pylance disagree on linebreaks in strings.
17 | exclude =
18 |     .tox,
19 |     .git,
20 |     __pycache__,
21 |     docs/source/conf.py,
22 |     build,
23 |     dist,
24 |     tests/fixtures/*,
25 |     *.pyc,
26 |     *.bib,
27 |     *.egg-info,
28 |     .cache,
29 |     .eggs,
30 |     data,
31 |     src/freqdect/baselines/*
32 | max-line-length = 120
33 | max-complexity = 20
34 | import-order-style = pycharm
35 | application-import-names =
36 |     freqdect
37 |     tests
38 | format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s
39 | 
--------------------------------------------------------------------------------
/src/freqdect/version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Version information for :mod:`freqdect`. 
4 | 5 | Run with ``python -m freqdect.version`` 6 | """ 7 | 8 | import os 9 | from subprocess import CalledProcessError, check_output # noqa: S404 10 | 11 | __all__ = [ 12 | "VERSION", 13 | "get_version", 14 | "get_git_hash", 15 | ] 16 | 17 | VERSION = "0.0.1-dev" 18 | 19 | 20 | def get_git_hash() -> str: 21 | """Get the :mod:`freqdect` git hash.""" 22 | with open(os.devnull, "w") as devnull: 23 | try: 24 | ret = check_output( # noqa: S603,S607 25 | ["git", "rev-parse", "HEAD"], 26 | cwd=os.path.dirname(__file__), 27 | stderr=devnull, 28 | ) 29 | except CalledProcessError: 30 | return "UNHASHED" 31 | else: 32 | return ret.strip().decode("utf-8")[:8] 33 | 34 | 35 | def get_version(with_git_hash: bool = False): 36 | """Get the :mod:`freqdect` version string, including a git hash.""" 37 | return f"{VERSION}-{get_git_hash()}" if with_git_hash else VERSION 38 | 39 | 40 | if __name__ == "__main__": 41 | print(get_version(with_git_hash=True)) 42 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | lint: 7 | name: Lint 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [ 3.8, 3.9 ] 12 | toxenv: [ manifest, flake8, pyroma, mypy] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: pip install tox 21 | - name: Run tox 22 | run: tox -e ${{ matrix.toxenv }} 23 | tests: 24 | name: Tests 25 | runs-on: ${{ matrix.os }} 26 | strategy: 27 | matrix: 28 | os: [ ubuntu-latest ] 29 | python-version: [ 3.8, 3.9 ] 30 | steps: 31 | - uses: actions/checkout@v2 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | - name: Install dependencies 37 | run: pip install tox 38 | - name: Test with pytest 39 | run: 40 | tox -e py 41 | -------------------------------------------------------------------------------- /scripts/calc_acc_from_conf_mat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def calc_mean_std_accs(wavelet, mode): 4 | acc_lst = [] 5 | known_acc_lst = [] 6 | unknown_acc_lst = [] 7 | 8 | for seed in range(4): 9 | matrix = np.load(open(f"confusion-matrix-lsun_bedroom_200k_png_log_packets_{wavelet}_boundary_missing_4_{mode}_{seed}-generalized.npy", "rb")) 10 | 11 | acc_lst.append((matrix[0, 0] + matrix[1:, 1].sum()) / matrix.sum()) 12 | known_acc_lst.append((matrix[0, 0] + matrix[1:-1, 1].sum()) / matrix[:-1, :].sum()) 13 | unknown_acc_lst.append(matrix[-1, 1] / matrix[-1, :].sum()) 14 | 15 | max = (np.max(acc_lst), np.max(known_acc_lst), np.max(unknown_acc_lst)) 16 | mean = (np.mean(acc_lst), np.mean(known_acc_lst), np.mean(unknown_acc_lst)) 17 | std = (np.std(acc_lst), np.std(known_acc_lst), np.std(unknown_acc_lst)) 18 | 19 | return max, mean, std 20 | 21 | for mode in ["regression"]: 22 | print(f"{mode}:") 23 | for wavelet in ["db4"]: 24 | max, mean, std = calc_mean_std_accs(wavelet, mode) 25 | print(f"{wavelet} & {100*max[0]:.2f}\,\% & {100*mean[0]:.2f} $\pm$ {100*std[0]:.2f}\,\% & \ 26 | {100*max[1]:.2f}\,\% & {100*mean[1]:.2f} $\pm$ {100*std[1]:.2f}\,\% & \ 27 | {100*max[2]:.2f}\,\% & {100*mean[2]:.2f} $\pm$ 
{100*std[2]:.2f}\,\%")
28 | 
--------------------------------------------------------------------------------
/scripts/multi_train_celabA.py:
--------------------------------------------------------------------------------
1 | # Running this script in sbatch will train multiple neural networks on the same gpu.
2 | import time
3 | import datetime
4 | import subprocess
5 | subprocess.call('pwd')
6 | 
7 | print('running jobs in parallel')
8 | 
9 | experiment_lst = []
10 | for model in ['cnn']:
11 |     for wavelet_str in ["sym3", "sym4", "sym5", "db3", "db4", "db5"]:
12 |         for seed in [0, 1, 2, 3, 4]:
13 |             experiment_lst.append(
14 |                 (["python", "-m", "freqdect.train_classifier", "--features", "packets",
15 |                   "--model", str(model),
16 |                   "--seed", str(seed),
17 |                   "--data-prefix",
18 |                   "/nvme/mwolter/celeba/celeba_align_png_cropped_log_packets_" + wavelet_str + "_boundary",
19 |                   "--nclasses", "5"], (str(model), str(seed), wavelet_str)))
20 | jobs = []
21 | for exp_no, experiment in enumerate(experiment_lst):
22 |     time.sleep(10)
23 |     time_str = str(datetime.datetime.today())
24 |     print(experiment, ' at time:', time_str)
25 |     file_name = f"{time_str}_celebA_{experiment[1][0]}_{experiment[1][1]}_{experiment[1][2]}.txt"
26 |     with open(f"./log/out/{file_name}", "w") as file:
27 |         jobs.append(subprocess.Popen(experiment[0], stdout=file))
28 |     if exp_no % 5 == 0 and exp_no > 0:
29 |         for job in jobs:
30 |             job.wait()
31 |         jobs = []
32 | 
--------------------------------------------------------------------------------
/scripts/multi_train_lsun.py:
--------------------------------------------------------------------------------
1 | # Running this script in sbatch will train multiple neural networks on the same gpu.
2 | import time
3 | import datetime
4 | import subprocess
5 | subprocess.call('pwd')
6 | 
7 | print('running jobs in parallel')
8 | 
9 | experiment_lst = []
10 | for model in ['cnn']:
11 |     for wavelet_str in ["db4"]:
12 |         for seed in range(5):
13 |             experiment_lst.append(
14 |                 (["python", "-m", "freqdect.train_classifier", "--features", "packets",
15 |                   "--model", str(model),
16 |                   "--seed", str(seed),
17 |                   "--data-prefix",
18 |                   "/nvme/fblanke/celeba_align_png_cropped_log_packets_" + wavelet_str + "_boundary",
19 |                   "--nclasses", "5",
20 |                   "--epochs", "20", "--calc-normalization"], (str(model), str(seed), wavelet_str)))
21 | jobs = []
22 | for exp_no, experiment in enumerate(experiment_lst):
23 |     time.sleep(100)
24 |     time_str = str(datetime.datetime.today())
25 |     print(experiment, ' at time:', time_str)
26 |     file_name = f"{time_str}_celeba_{experiment[1][0]}_{experiment[1][1]}_{experiment[1][2]}.txt"
27 |     with open(f"./log/out/{file_name}", "w") as file:
28 |         jobs.append(subprocess.Popen(experiment[0], stdout=file))
29 |     if exp_no % 4 == 0 and exp_no > 0:
30 |         for job in jobs:
31 |             job.wait()
32 |         jobs = []
33 | 
--------------------------------------------------------------------------------
/src/freqdect/fourier_math.py:
--------------------------------------------------------------------------------
1 | """Module implementing Fourier related math functions.
2 | 
3 | The goal here is to make the Fourier transform useful for
4 | image analysis and gan-content recognition.
5 | """
6 | 
7 | import numpy as np
8 | import torch
9 | 
10 | 
11 | def batch_fourier_preprocessing(image_batch, eps=1e-12, log_scale=False):
12 |     """Preprocess image batches by computing the Fourier representation.
13 | 
14 |     Either the raw coefficients or an absolute log-scaled version can be computed.
15 | 
16 |     Args:
17 |         image_batch (np.array): An image batch of shape (B, H, W, C).
18 |         eps: A small number to stabilize the logarithm.
19 |         log_scale: Use log-scaling if True.
20 |             Log-scaled coefficients aren't invertible.
21 |             Default: False.
22 | 
23 |     Returns:
24 |         [np.array]: The Fourier coefficients of shape [B, H, W, C];
25 |             in raw mode the real and imaginary parts are concatenated
26 |             along the channel axis.
27 |     """
28 |     image_batch = torch.from_numpy(image_batch.astype(np.float32)).cuda()
29 |     # compute the 2d FFT for every channel of the (B, H, W, C) batch
30 |     channels = []
31 |     for channel in range(image_batch.shape[-1]):
32 |         channels.append(torch.fft.fft2((image_batch[..., channel])))
33 |     freq = torch.stack(channels, -1)
34 |     del channels
35 |     if log_scale:
36 |         freq = torch.abs(freq)
37 |         freq = torch.log(freq + eps)
38 |     else:
39 |         freq = torch.cat([torch.real(freq), torch.imag(freq)], -1)
40 |     return freq.cpu().numpy()
--------------------------------------------------------------------------------
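A quick shape check may clarify the two output modes of batch_fourier_preprocessing. This usage sketch is hypothetical: it assumes the package is installed (e.g. via pip install -e .) and needs a CUDA device, since the function moves the batch to the GPU:

    import numpy as np
    import torch

    from freqdect.fourier_math import batch_fourier_preprocessing

    if torch.cuda.is_available():
        batch = np.random.rand(8, 128, 128, 3)  # B, H, W, C
        log_freq = batch_fourier_preprocessing(batch, log_scale=True)
        raw_freq = batch_fourier_preprocessing(batch)
        assert log_freq.shape == (8, 128, 128, 3)  # abs + log keeps C
        assert raw_freq.shape == (8, 128, 128, 6)  # real/imag parts double C
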
/scripts/run_wavelet_plots.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=train_lsun
6 | #SBATCH --output=train_lsun-%j.out
7 | #SBATCH --error=train_lsun-%j.err
8 | #SBATCH --ntasks=1
9 | #SBATCH -p gpu
10 | #SBATCH --gres gpu:v100:1
11 | #SBATCH --cpus-per-task=16
12 | 
13 | DATASETS="/home/ndv/projects/wavelets/datasets_moritz"
14 | DATASET_RAW="lsun_bedroom_200k_png_raw"
15 | DATASET_PACKETS="lsun_bedroom_200k_png_packets"
16 | 
17 | ANACONDA_ENV="$HOME/myconda-env"
18 | 
19 | # save working directory
20 | ORIG_PWD=${PWD}
21 | 
22 | RAW_PREFIX=$DATASETS
23 | PACKETS_PREFIX=$DATASETS
24 | 
25 | module load CUDA
26 | module load Anaconda3
27 | module load PyTorch
28 | source activate "$ANACONDA_ENV"
29 | 
30 | pip install -q -e .
31 | 
32 | 
33 | if [ -f ${DATASETS}/${DATASET_RAW}.tar ]; then
34 |     echo "Tarred raw input folder exists, copying to $TMPDIR"
35 |     cp "${DATASETS}/${DATASET_RAW}.tar" "${TMPDIR}"
36 |     cd "$TMPDIR"
37 |     tar -xf "${DATASET_RAW}.tar"
38 |     RAW_PREFIX="${TMPDIR}/${DATASET_RAW}"
39 |     rm "${DATASET_RAW}.tar"
40 | fi
41 | 
42 | cd "$ORIG_PWD"
43 | 
44 | for i in 0 1 2 3 4
45 | do
46 |   echo "raw experiment no: $i "
47 |   python -m freqdect.train_classifier \
48 |     --features raw \
49 |     --seed $i \
50 |     --data-prefix "$RAW_PREFIX" \
51 |     --nclasses 4 \
52 |     --calc-normalization
53 | done
54 | 
55 | if [ -f ${DATASETS}/${DATASET_RAW}.tar ]; then
56 |     rm -r "${TMPDIR}/${DATASET_RAW}"_*
57 | fi
--------------------------------------------------------------------------------
/scripts/prepare_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=prepare_dataset
6 | #SBATCH --output=prepare_dataset-%j.out
7 | #SBATCH --error=prepare_dataset-%j.err
8 | #SBATCH --ntasks=1
9 | #SBATCH -p gpu
10 | #SBATCH --gres gpu:v100:1
11 | #SBATCH --cpus-per-task=16
12 | 
13 | echo prepare_dataset.sh started at `date +"%T"`
14 | 
15 | ANACONDA_ENV="$HOME/myconda-env"
16 | 
17 | DATASETS="/home/ndv/projects/wavelets/datasets"
18 | DATASET="lsun_bedroom_200k_png"
19 | DATA_PREFIX="${DATASETS}/${DATASET}"
20 | 
21 | # save working directory
22 | ORIG_PWD=${PWD}
23 | 
24 | module load PyTorch
25 | module load Pillow
26 | module load Anaconda3
27 | source activate "$ANACONDA_ENV"
28 | 
29 | pip install -q -e . 
30 | 
31 | if [ -f ${DATASETS}/${DATASET}.tar ]; then
32 |     echo "Tarred input folder exists, copying to $TMPDIR"
33 |     cp "${DATASETS}/${DATASET}.tar" "$TMPDIR"/
34 |     cd "$TMPDIR"
35 |     echo "Unpacking tarred input folder"
36 |     tar xf "${DATASET}.tar"
37 |     DATA_PREFIX="${TMPDIR}/${DATASET}"
38 | fi
39 | 
40 | cd $ORIG_PWD
41 | 
42 | echo "Preparing data"
43 | python -m freqdect.prepare_dataset_batched "$DATA_PREFIX" \
44 |     --train-size 100000 \
45 |     --test-size 30000 \
46 |     --val-size 20000 \
47 |     --packets
48 | 
49 | if ls ${TMPDIR}/${DATASET}_* > /dev/null 2>&1; then
50 |     echo "Copying results back to ${DATASETS}"
51 |     cp -r ${TMPDIR}/${DATASET}_* ${DATASETS}
52 | fi
53 | 
54 | echo prepare_dataset.sh finished at `date +"%T"`
55 | 
--------------------------------------------------------------------------------
/scripts/confusion_matrix.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --nodes=1
4 | #SBATCH --tasks-per-node=1
5 | #SBATCH --job-name=confusion-matrix-celeba-CNN-packet
6 | #SBATCH --output=conf-db2-cnn-generalized.out
7 | #SBATCH --error=conf-db2-cnn-generalized.err
8 | #SBATCH --ntasks=1
9 | #SBATCH -p gpu
10 | #SBATCH --gres gpu:v100:1
11 | #SBATCH --cpus-per-task=4
12 | # Set time limit to override default limit
13 | #SBATCH --time=2:00:00
14 | 
15 | ANACONDA_ENV="$HOME/env/idp38"
16 | 
17 | DATASETS_DIR="/nvme/fblanke"
18 | 
19 | LSUN_DATASET_LOGPACKETS="lsun_bedroom_200k_png_log_packets"
20 | LSUN_DATASET_PACKETS="lsun_bedroom_200k_png_packets"
21 | LSUN_DATASET_RAW="lsun_bedroom_200k_png_raw"
22 | 
23 | WAVELET="db2"
24 | CHOSEN_DATASET=${LSUN_DATASET_LOGPACKETS}_${WAVELET}_boundary_missing_4
25 | BATCH_SIZE="2048"
26 | MODEL="cnn"
27 | SUFFIX=""
28 | NCLASSES="2"
29 | 
30 | module load Anaconda3
31 | source activate "$ANACONDA_ENV"
32 | 
33 | echo "confusion matrices for ${CHOSEN_DATASET}:"
34 | for i in 0 1 2 3 4
35 | do
36 |   CLASSIFIER_FILE="${CHOSEN_DATASET}_${MODEL}_${i}${SUFFIX}"
37 |   echo "confusion matrix for: $i "
38 |   python -m freqdect.confusion_matrix \
39 |     --features packets \
40 |     --classifier-path log/${CLASSIFIER_FILE}.pt \
41 |     --data-prefix ${DATASETS_DIR}/${CHOSEN_DATASET} \
42 |     --calc-normalization \
43 |     --batch-size $BATCH_SIZE \
44 |     --nclasses ${NCLASSES} \
45 |     --model $MODEL \
46 |     --store-path confusion-matrix-${CHOSEN_DATASET}_${MODEL}_${i}-generalized.npy \
47 |     --generalized
48 | done
49 | echo "ended at `date +"%T"`"
50 | 
--------------------------------------------------------------------------------
/tests/test_corruption.py:
--------------------------------------------------------------------------------
1 | """Test the corruption code used for robustness testing."""
2 | import sys
3 | 
4 | import numpy as np
5 | from PIL import Image
6 | from scipy import misc
7 | 
8 | sys.path.append("./src")
9 | from src.freqdect.corruption import (  # noqa: I202
10 |     jpeg_compression,
11 |     random_resized_crop,
12 |     random_rotation,
13 | )
14 | 
15 | 
16 | def test_jpeg_compression():
17 |     """Test the jpeg compression function."""
18 |     face = Image.fromarray(misc.face())
19 |     compressed = np.array(jpeg_compression(face))
20 |     assert np.array(face).shape == compressed.shape  # noqa: S101
21 | 
22 | 
23 | def test_rotation():
24 |     """Test the random rotation function from freqdect.corruption."""
25 |     face = Image.fromarray(misc.face())
26 |     rotated = random_rotation(face)
27 |     assert face.size == rotated.size  # noqa: S101
28 | 
29 | 
30 | def test_crop():
31 |     """Test the random cropping function from 
freqdect.corruption.""" 32 | face = Image.fromarray(misc.face()) 33 | crop = random_resized_crop(face) 34 | assert crop.size == face.size # noqa: S101 35 | 36 | 37 | if __name__ == "__main__": 38 | test_jpeg_compression() 39 | test_rotation() 40 | test_crop() 41 | 42 | import matplotlib.pyplot as plt 43 | 44 | face = Image.fromarray(misc.face()) 45 | plt.imshow(np.array(face)) 46 | plt.show() 47 | 48 | second_face = np.array(jpeg_compression(face)) 49 | 50 | plt.imshow(second_face) 51 | plt.show() 52 | 53 | third_face = random_resized_crop(face) 54 | plt.imshow(third_face) 55 | plt.show() 56 | -------------------------------------------------------------------------------- /scripts/train_celebA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=celeba_train_log 6 | #SBATCH --output=train_celeba-%j.out 7 | #SBATCH --error=train_celeba-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=90gb 13 | 14 | for i in 0 1 2 3 4 15 | do 16 | echo "packet experiment no: $i " 17 | python -m freqdect.train_classifier \ 18 | --features packets \ 19 | --seed $i \ 20 | --data-prefix /nvme/mwolter/celeba/celeba_align_png_cropped_log_packets_sym2_boundary \ 21 | --nclasses 5 \ 22 | --calc-normalization 23 | done 24 | 25 | for i in 0 1 2 3 4 26 | do 27 | echo "packet experiment no: $i " 28 | python -m freqdect.train_classifier \ 29 | --features packets \ 30 | --seed $i \ 31 | --data-prefix /nvme/mwolter/celeba/celeba_align_png_cropped_log_packets_db2_boundary \ 32 | --nclasses 5 \ 33 | --calc-normalization 34 | done 35 | 36 | 37 | for i in 0 1 2 3 4 38 | do 39 | echo "packet experiment no: $i " 40 | python -m freqdect.train_classifier \ 41 | --features packets \ 42 | --seed $i \ 43 | --data-prefix /nvme/mwolter/celeba/celeba_align_png_cropped_log_packets_db2_boundary \ 44 | --nclasses 5 \ 45 | --calc-normalization \ 46 | --model cnn 47 | done 48 | 49 | for i in 0 1 2 3 4 50 | do 51 | echo "packet experiment no: $i " 52 | python -m freqdect.train_classifier \ 53 | --features packets \ 54 | --seed $i \ 55 | --data-prefix /nvme/mwolter/celeba/celeba_align_png_cropped_log_packets_sym2_boundary \ 56 | --nclasses 5 \ 57 | --calc-normalization \ 58 | --model cnn 59 | done 60 | -------------------------------------------------------------------------------- /tests/test_packets.py: -------------------------------------------------------------------------------- 1 | """Ensure ptwt and pywt packets are equivalent.""" 2 | 3 | import sys 4 | from itertools import product 5 | 6 | import numpy as np 7 | import pytest 8 | import pywt 9 | import torch 10 | from scipy import misc 11 | 12 | sys.path.append("./src") 13 | from src.freqdect.wavelet_math import ( # noqa: I202 14 | compute_pytorch_packet_representation_2d_tensor, 15 | ) 16 | 17 | 18 | def compute_pywt_packet_representation_2d_tensor( 19 | data, wavelet_str: str = "db5", max_lev: int = 5 20 | ): 21 | """To Ensure pywt and ptwt equivalence compute pywt packets.""" 22 | wavelet = pywt.Wavelet(wavelet_str) 23 | pywt_wp_tree = pywt.WaveletPacket2D(data=data, wavelet=wavelet, mode="reflect") 24 | 25 | # get the pytorch decomposition 26 | # batch_size = pt_data.shape[0] 27 | wp_keys = list(product(["a", "h", "v", "d"], repeat=max_lev)) 28 | packet_list = [] 29 | for node in wp_keys: 30 | packet = pywt_wp_tree["".join(node)].data 31 | packet_list.append(packet) 
32 | 
33 |     wp_py = np.stack(packet_list, axis=0)
34 |     return wp_py
35 | 
36 | 
37 | @pytest.mark.slow
38 | def test_packets():
39 |     """Runs the pywt ptwt comparison test."""
40 |     face = misc.face()[256:512, 256:512]
41 |     grey_face = np.mean(face, axis=-1).astype(np.float64)
42 |     # add batch dimension.
43 |     pt_face = torch.unsqueeze(torch.from_numpy(grey_face), 0)
44 |     py_packets = compute_pywt_packet_representation_2d_tensor(
45 |         pt_face.squeeze(0).numpy(), "haar"
46 |     )
47 |     pt_packets = (
48 |         compute_pytorch_packet_representation_2d_tensor(pt_face, "haar")
49 |         .squeeze(0)
50 |         .numpy()
51 |     )
52 |     assert np.allclose(py_packets, pt_packets)  # noqa: S101
53 | 
54 | 
55 | if __name__ == "__main__":
56 |     # Run the test.
57 |     test_packets()
--------------------------------------------------------------------------------
/src/freqdect/plot_linear_classifier.py:
--------------------------------------------------------------------------------
1 | """Code to visualize a linear classifier trained on wavelet packets."""
2 | 
3 | import argparse
4 | 
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import tikzplotlib as tikz
8 | import torch
9 | 
10 | from .models import Regression
11 | from .plot_mean_packets import generate_frequency_packet_image
12 | 
13 | 
14 | def _parse_args():
15 |     """Parse the command line."""
16 |     parser = argparse.ArgumentParser(
17 |         description="Plot the weights of our linear regressor."
18 |     )
19 |     parser.add_argument(
20 |         "model_path", type=str, help="path to the linear model to plot."
21 |     )
22 |     parser.add_argument(
23 |         "--classes", type=int, help="number of classes in the classifier.", default=2
24 |     )
25 |     return parser.parse_args()
26 | 
27 | 
28 | def main(args):
29 |     """Plot the weights from the linear classifier defined in the models-module."""
30 |     model = Regression(args.classes)
31 |     model.load_state_dict(torch.load(args.model_path))
32 |     mat = torch.reshape(
33 |         model.linear.weight.cpu().detach(), [args.classes, 64, 16, 16, 3]
34 |     )
35 |     mat = torch.mean(mat, -1)
36 |     real_weights = generate_frequency_packet_image(mat[0].numpy(), 3)
37 |     fake_weights = []
38 |     for c in range(1, args.classes):
39 |         fake_weights.append(generate_frequency_packet_image(mat[c].numpy(), 3))
40 | 
41 |     cat = real_weights
42 |     for fake in fake_weights:
43 |         cat = np.concatenate([cat, fake], axis=1)
44 | 
45 |     plt.imshow(cat)
46 |     plt.title("Real and fake class weights side by side.")
47 |     plt.colorbar()
48 |     plt.axis("off")
49 |     plt.show()
50 | 
51 |     plt.imshow(real_weights, cmap=plt.cm.viridis, vmax=0.2, vmin=-0.2)
52 |     plt.title("Real classifier weights")
53 |     plt.colorbar()
54 |     plt.axis("off")
55 |     tikz.save("real_classifier_weights.tex")
56 |     plt.show()
57 | 
58 |     for c, fake in enumerate(fake_weights):
59 |         plt.imshow(fake, cmap=plt.cm.viridis, vmax=0.2, vmin=-0.2)
60 |         plt.title(f"fake classifier {c+1} weights")
61 |         plt.colorbar()
62 |         tikz.save(f"fake_classifier_weights_{c+1}.tex")
63 |         plt.axis("off")
64 |         plt.show()
65 | 
66 |     print("stop")
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     main(_parse_args())
--------------------------------------------------------------------------------
/src/freqdect/resize_images_folder.py:
--------------------------------------------------------------------------------
1 | """Preprocessing module with code to resize all images in a folder."""
2 | 
3 | import argparse
4 | import os
5 | from concurrent.futures import ProcessPoolExecutor
6 | from functools import partial
7 | from typing import Tuple
8 | 
9 | from PIL import 
Image 10 | 11 | 12 | def resize_image(packed: Tuple[int, str, str, str], shape: Tuple[int, int]): 13 | """Resize an image. 14 | 15 | Args: 16 | packed (Tuple[int, str, str, str]): Packed args as tuple. 17 | The first entry is the image index. 18 | The second entry is the path of the directory containing all original CelebA images. 19 | The third entry is the file path of the original image file, which is cropped. 20 | The fourth entry is the path of the directory where the cropped image is stored. 21 | shape (Tuple(int, int)): The shape for the resized images. 22 | """ 23 | i, directory, file_path, output = packed 24 | if ( 25 | file_path.endswith("png") 26 | or file_path.endswith("jpeg") 27 | or file_path.endswith("jpg") 28 | ): 29 | image = Image.open(f"{directory}/{file_path}") 30 | image = image.resize(shape) 31 | image.save(f"{output}/resize_{file_path}") 32 | return i 33 | 34 | 35 | def main(args): 36 | """Resizes a number of images in a directory and stores the resized images.""" 37 | os.makedirs(args.OUTPUT, exist_ok=True) 38 | paths = os.listdir(args.DIRECTORY)[: args.SIZE] 39 | packed = map(lambda x: (x[0], args.DIRECTORY, x[1], args.OUTPUT), enumerate(paths)) 40 | packed_list = list(packed) 41 | print("image total", len(packed_list)) 42 | resize_shape = partial(resize_image, shape=(args.SHAPE, args.SHAPE)) 43 | 44 | with ProcessPoolExecutor() as pool: 45 | _ = pool.map(resize_shape, packed_list) 46 | 47 | 48 | def _parse_args(): 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument("DIRECTORY", help="Source directory.", type=str) 52 | parser.add_argument("OUTPUT", help="Output directory.", type=str) 53 | parser.add_argument("SHAPE", help="Shape for the new images.", type=int) 54 | parser.add_argument("SIZE", help="Amount of data to convert.", type=int) 55 | 56 | return parser.parse_args() 57 | 58 | 59 | if __name__ == "__main__": 60 | main(_parse_args()) 61 | -------------------------------------------------------------------------------- /src/freqdect/baselines/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | As found at: 3 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/baselines/utils.py 4 | """ 5 | import json 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 9 | 10 | class PersistentDefaultDict: 11 | """ 12 | Nested defaultdict that gets synced transparently to disk. 
13 | 
14 |     Init:
15 |         results = PersistentDefaultDict(path_to_dict)
16 | 
17 |     Add result:
18 |         results['key1', 'key2'] = result
19 | 
20 |     """
21 | 
22 |     def __init__(self, path_to_dict):
23 |         """Creates the dictionary at the given path."""
24 |         self.path = Path(path_to_dict)
25 |         if self.path.is_file():
26 |             stored_data = json.loads(self.path.read_text())
27 |             self.data = PersistentDefaultDict.redefault_dict(stored_data)
28 |         else:
29 |             self.data = defaultdict(PersistentDefaultDict.rec_default_dict)
30 | 
31 |     def __str__(self):
32 |         return str(json.dumps(self.data, indent=4))
33 | 
34 |     def __setitem__(self, keys, item):
35 |         d = self.data
36 |         if isinstance(keys, str):
37 |             d[keys] = item
38 |         elif isinstance(keys, tuple):
39 |             for key in keys[:-1]:
40 |                 d = d[key]
41 |             d[keys[-1]] = item
42 |         else:
43 |             raise NotImplementedError()
44 |         self.path.write_text(json.dumps(self.data, indent=4))
45 | 
46 |     def __getitem__(self, key):
47 |         return self.data[key]
48 | 
49 |     def __iter__(self):
50 |         return self.data.__iter__()
51 | 
52 |     def as_dict(self):
53 |         """Return the stored json as dict."""
54 |         if self.path.is_file():
55 |             return json.loads(self.path.read_text())
56 |         else:
57 |             return {}
58 | 
59 |     @staticmethod
60 |     def rec_default_dict():
61 |         """Wrap the rec_default_dict property."""
62 |         return defaultdict(PersistentDefaultDict.rec_default_dict)
63 | 
64 |     @staticmethod
65 |     def redefault_dict(data):
66 |         """Loads data into the dictionary."""
67 |         if isinstance(data, dict):
68 |             return defaultdict(
69 |                 PersistentDefaultDict.rec_default_dict,
70 |                 {k: PersistentDefaultDict.redefault_dict(v) for k, v in data.items()},
71 |             )
72 |         else:
73 |             return data
--------------------------------------------------------------------------------
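A hypothetical usage sketch for PersistentDefaultDict; the file name is an assumption, any writable path works:

    import os
    import tempfile

    from freqdect.baselines.utils import PersistentDefaultDict

    path = os.path.join(tempfile.mkdtemp(), "results.json")
    results = PersistentDefaultDict(path)
    results["knn", "accuracy"] = 0.93  # tuple keys create nested dicts
    assert results["knn"]["accuracy"] == 0.93

    reloaded = PersistentDefaultDict(path)  # state is read back from disk
    assert reloaded.as_dict() == {"knn": {"accuracy": 0.93}}
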
24 | """ 25 | i, directory, file_path, output = packed 26 | if ( 27 | file_path.endswith("png") 28 | or file_path.endswith("jpeg") 29 | or file_path.endswith("jpg") 30 | ): 31 | image = np.asarray(Image.open(f"{directory}/{file_path}")) 32 | 33 | if image.shape[0] != 128 or image.shape[1] != 128: 34 | x, y, _ = image.shape 35 | image = np.copy(image) 36 | x_upper = min(121 + 64, x) 37 | y_upper = min(89 + 64, y) 38 | image = image[x_upper - 128 : x_upper, y_upper - 128 : y_upper] 39 | image = np.clip(image, 0, 255.0).astype(np.uint8) # type: ignore 40 | 41 | if not (image.shape[0] == 128 and image.shape[1] == 128): 42 | print("Aborting") 43 | return i 44 | 45 | Image.fromarray(image).save(f"{output}/celeba_{file_path}") 46 | return i 47 | 48 | 49 | def main(args): 50 | """Center-crops a number of CelebA images in a directory to 128x128 pixels and stores the cropped images.""" 51 | os.makedirs(args.OUTPUT, exist_ok=True) 52 | paths = os.listdir(args.DIRECTORY)[: args.SIZE] 53 | packed = map(lambda x: (x[0], args.DIRECTORY, x[1], args.OUTPUT), enumerate(paths)) 54 | 55 | with ProcessPoolExecutor() as pool: 56 | _ = pool.map(crop_image, packed) 57 | 58 | 59 | def _parse_args(): 60 | parser = argparse.ArgumentParser() 61 | 62 | parser.add_argument("DIRECTORY", help="Source directory.", type=str) 63 | parser.add_argument("OUTPUT", help="Output directory.", type=str) 64 | parser.add_argument("SIZE", help="Amount of data to convert.", type=int) 65 | 66 | return parser.parse_args() 67 | 68 | 69 | if __name__ == "__main__": 70 | main(_parse_args()) 71 | -------------------------------------------------------------------------------- /src/freqdect/corruption.py: -------------------------------------------------------------------------------- 1 | """Corrupt images for robustness testing.""" 2 | from io import BytesIO 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from torchvision.transforms import RandomResizedCrop, RandomRotation 8 | 9 | 10 | def jpeg_compression(image: Image) -> Image: 11 | """Compute a compressed version of the input image. 12 | 13 | Args: 14 | image (Image): The input image. 15 | jpeg_compression (int): Compression factor 16 | on a scale from 0 (worst) to 95 (best). 17 | 18 | Returns: 19 | Image: The compressed image. 20 | """ 21 | out = BytesIO() 22 | factor = np.random.randint(low=70, high=90) 23 | image.save(out, format="JPEG", quality=factor, subsampling=0) 24 | return Image.open(out) 25 | 26 | 27 | def random_rotation(image: Image, angle=15) -> Image: 28 | """Randomly rotates an Image. 29 | 30 | Args: 31 | image (Image): The input image. 32 | angle (int, optional): The max rotation angle. Defaults to 15. 33 | 34 | Returns: 35 | Image: The rotated image. 36 | """ 37 | return RandomRotation(angle)(image) 38 | 39 | 40 | def random_resized_crop(image: Image) -> Image: 41 | """Randomly resize and crop the input Image. 42 | 43 | Args: 44 | image (Image): The input image. 45 | 46 | Returns: 47 | Image: The processed output image. 48 | """ 49 | return RandomResizedCrop((image.size[1], image.size[0]), scale=(0.8, 1.0))(image) 50 | 51 | 52 | def noise(image: Image) -> Image: 53 | """Add random variance noise with to test classifier resilience. 54 | 55 | Adapted from: 56 | https://github.com/RUB-SysSec/GANDCTAnalysis/ 57 | -> create_perturbed_imagedata.py 58 | 59 | Args: 60 | image (Image): The input PIL.Image . 61 | 62 | Returns: 63 | Image: Output image with added noise. 
64 | """ 65 | image = np.array(image) 66 | # variance from U[5.0,20.0] 67 | variance = np.random.uniform(low=5.0, high=20.0) 68 | image = np.copy(image).astype(np.float64) 69 | noise = variance * np.random.randn(*image.shape) 70 | image += noise 71 | return Image.fromarray(np.clip(image, 0.0, 255.0).astype(np.uint8)) 72 | 73 | 74 | def blur(image: Image) -> Image: 75 | """Apply a gaussian blur for resilience testing. 76 | 77 | Adapted from: 78 | https://github.com/RUB-SysSec/GANDCTAnalysis/ 79 | -> create_perturbed_imagedata.py 80 | 81 | Args: 82 | image (Image): The PIL.Image input. 83 | 84 | Returns: 85 | Image: Blurred output. 86 | """ 87 | # kernel size from [1, 3, 5, 7, 9] 88 | image = np.array(image) 89 | kernel_size = np.random.choice([3, 5, 7, 9]) 90 | blurred = cv2.GaussianBlur( 91 | image, (kernel_size, kernel_size), sigmaX=cv2.BORDER_DEFAULT 92 | ) 93 | return Image.fromarray(np.clip(blurred, 0.0, 255.0).astype(np.uint8)) 94 | -------------------------------------------------------------------------------- /scripts/multi_train_ffhq.py: -------------------------------------------------------------------------------- 1 | # Running this script in sbatch will train multiple neural networks on the same gpu. 2 | import time 3 | import datetime 4 | 5 | import subprocess 6 | subprocess.call('pwd') 7 | 8 | print('running jobs in parallel') 9 | 10 | # experiment_lst = \ 11 | # [["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 12 | # "0", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_log_packets_haar_boundary", "--nclasses", "3"], 13 | # ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 14 | # "1", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_log_packets_haar_boundary", "--nclasses", "3"], 15 | # ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 16 | # "2", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_log_packets_haar_boundary", "--nclasses", "3"], 17 | # ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 18 | # "3", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_log_packets_haar_boundary", "--nclasses", "3"], 19 | # ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 20 | # "4", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_log_packets_haar_boundary", "--nclasses", "3"]] 21 | # jobs = [] 22 | # for exp_no, experiment in enumerate(experiment_lst): 23 | # time.sleep(10) 24 | # time_str = str(datetime.datetime.today()) 25 | # print(experiment, ' at time:', time_str) 26 | # with open("./log/out/" + time_str + ".txt", "w") as f: 27 | # jobs.append(subprocess.Popen(experiment, stdout=f)) 28 | # for job in jobs: 29 | # job.wait() 30 | 31 | 32 | experiment_lst = \ 33 | [["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 34 | "0", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_raw", "--nclasses", "3"], 35 | ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 36 | "1", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_raw", "--nclasses", "3"], 37 | ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 38 | "2", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_raw", "--nclasses", "3"], 39 | ["python", "-m", "freqdect.train_classifier", "--features", "packets", "--seed", 40 | "3", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_raw", "--nclasses", "3"], 41 | ["python", "-m", "freqdect.train_classifier", "--features", 
"packets", "--seed", 42 | "4", "--data-prefix", "/nvme/mwolter/ffhq128/source_data_raw", "--nclasses", "3"]] 43 | jobs = [] 44 | for exp_no, experiment in enumerate(experiment_lst): 45 | time.sleep(10) 46 | time_str = str(datetime.datetime.today()) 47 | print(experiment, ' at time:', time_str) 48 | with open("./log/out/" + time_str + ".txt", "w") as f: 49 | jobs.append(subprocess.Popen(experiment, stdout=f)) 50 | for job in jobs: 51 | job.wait() 52 | 53 | print('done') -------------------------------------------------------------------------------- /src/freqdect/plot_accuracy_simple.py: -------------------------------------------------------------------------------- 1 | """Code to plot training mean accuracy as well as the standard deviation.""" 2 | import argparse 3 | import pickle 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | from .plot_accuracy_results import ( 8 | get_plot_tuple, 9 | get_test_acc_mean_std_max, 10 | ) 11 | 12 | 13 | def plot_mean_std(steps, mean, std, color, label="", marker="."): 14 | """Plot means and standard deviations with shaded areas.""" 15 | plt.plot(steps, mean, label=label, color=color, marker=marker) 16 | plt.fill_between(steps, mean - std, mean + std, color=color, alpha=0.2) 17 | 18 | 19 | def _parse_args(): 20 | """Parse the command line.""" 21 | parser = argparse.ArgumentParser(description="Simply plot validation accuracy") 22 | parser.add_argument("prefix_one", type=str, help="prefix to a first logfile group") 23 | parser.add_argument("prefix_two", type=str, help="prefix to a second logfile group") 24 | parser.add_argument( 25 | "--seeds", 26 | type=int, 27 | nargs="+", 28 | help="The seeds, defaults to 0 1 2 3 4", 29 | default=[0, 1, 2, 3, 4], 30 | ) 31 | return parser.parse_args() 32 | 33 | 34 | def main(args): 35 | """Plot two experiments.""" 36 | print(args.prefix_one) 37 | print(args.prefix_two) 38 | 39 | first_logs = [] 40 | for seed in args.seeds: 41 | with open(f"./log/{args.prefix_one}_{seed}.pkl", "rb") as f: 42 | first_logs.append(pickle.load(f)[0]) 43 | second_logs = [] 44 | for seed in args.seeds: 45 | with open(f"./log/{args.prefix_two}_{seed}.pkl", "rb") as f: 46 | second_logs.append(pickle.load(f)[0]) 47 | 48 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] 49 | 50 | steps, mean, std = get_plot_tuple(second_logs, "train_acc") 51 | steps, mean, std = get_plot_tuple(second_logs, "val_acc") 52 | plot_mean_std(steps, mean, std, color=colors[0], label="second validation acc") 53 | 54 | steps, mean, std = get_plot_tuple(first_logs, "val_acc") 55 | plot_mean_std(steps, mean, std, color=colors[1], label="first validation acc") 56 | 57 | pt_mean, pt_std, pt_max = get_test_acc_mean_std_max(first_logs, "test_acc") 58 | rt_mean, rt_std, rt_max = get_test_acc_mean_std_max(second_logs, "test_acc") 59 | print("first_mean", pt_mean, "first_std", pt_std, "first_max", pt_max) 60 | print("second_mean", rt_mean, "second_std", rt_std, "second_max", rt_max) 61 | plt.errorbar( 62 | steps[-1], pt_mean, pt_std, color=colors[2], label="first test acc", marker="x" 63 | ) 64 | plt.errorbar( 65 | steps[-1], rt_mean, rt_std, color=colors[3], label="second test acc", marker="x" 66 | ) 67 | 68 | plt.ylabel("mean accuracy") 69 | plt.xlabel("training steps") 70 | plt.title("Accuracy source identification") 71 | plt.legend() 72 | if 0: 73 | import tikzplotlib as tikz 74 | 75 | tikz.save("ffhq_style.tex", standalone=True) 76 | plt.show() 77 | print("done") 78 | 79 | 80 | if __name__ == "__main__": 81 | main(_parse_args()) 82 | 
-------------------------------------------------------------------------------- /scripts/train_lsun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | #SBATCH --tasks-per-node=1 5 | #SBATCH --job-name=train_lsun_log_packets_db2_boundary 6 | #SBATCH --output=train_lsun_log_packets_db2_boundary-%j.out 7 | #SBATCH --error=train_lsun_log_packets_db2_boundary-%j.err 8 | #SBATCH -p gpu 9 | #SBATCH --gres gpu:v100:1 10 | #SBATCH --cpus-per-task=9 11 | #SBATCH --time=48:00:00 12 | #SBATCH --mem=200gb 13 | 14 | for i in 0 1 2 3 4 15 | do 16 | echo "packet experiment no: $i " 17 | python -m freqdect.train_classifier \ 18 | --features packets \ 19 | --seed $i \ 20 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_db2_boundary \ 21 | --nclasses 5 \ 22 | --calc-normalization 23 | done 24 | 25 | for i in 0 1 2 3 4 26 | do 27 | echo "packet experiment no: $i " 28 | python -m freqdect.train_classifier \ 29 | --features packets \ 30 | --seed $i \ 31 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_sym2_boundary \ 32 | --nclasses 5 \ 33 | --calc-normalization 34 | done 35 | 36 | for i in 0 1 2 3 4 37 | do 38 | echo "packet experiment no: $i " 39 | python -m freqdect.train_classifier \ 40 | --features packets \ 41 | --seed $i \ 42 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_db4_boundary \ 43 | --nclasses 5 \ 44 | --calc-normalization 45 | done 46 | 47 | for i in 0 1 2 3 4 48 | do 49 | echo "packet experiment no: $i " 50 | python -m freqdect.train_classifier \ 51 | --features packets \ 52 | --seed $i \ 53 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_sym4_boundary \ 54 | --nclasses 5 \ 55 | --calc-normalization 56 | done 57 | 58 | 59 | for i in 0 1 2 3 4 60 | do 61 | echo "packet experiment no: $i " 62 | python -m freqdect.train_classifier \ 63 | --features packets \ 64 | --seed $i \ 65 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_db2_boundary \ 66 | --nclasses 5 \ 67 | --calc-normalization \ 68 | --model cnn 69 | done 70 | 71 | for i in 0 1 2 3 4 72 | do 73 | echo "packet experiment no: $i " 74 | python -m freqdect.train_classifier \ 75 | --features packets \ 76 | --seed $i \ 77 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_sym2_boundary \ 78 | --nclasses 5 \ 79 | --calc-normalization \ 80 | --model cnn 81 | done 82 | 83 | for i in 0 1 2 3 4 84 | do 85 | echo "packet experiment no: $i " 86 | python -m freqdect.train_classifier \ 87 | --features packets \ 88 | --seed $i \ 89 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_db4_boundary \ 90 | --nclasses 5 \ 91 | --calc-normalization \ 92 | --model cnn 93 | done 94 | 95 | for i in 0 1 2 3 4 96 | do 97 | echo "packet experiment no: $i " 98 | python -m freqdect.train_classifier \ 99 | --features packets \ 100 | --seed $i \ 101 | --data-prefix /nvme/mwolter/lsun/lsun_bedroom_200k_png_log_packets_sym4_boundary \ 102 | --nclasses 5 \ 103 | --calc-normalization \ 104 | --model cnn 105 | done -------------------------------------------------------------------------------- /src/freqdect/baselines/knn.py: -------------------------------------------------------------------------------- 1 | """ 2 | KNN baseline code as found at: 3 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/baselines/knn.py 4 | """ 5 | from .classifier import Classifier, read_dataset 6 | from sklearn.neighbors import KNeighborsClassifier 7 | from .utils 
import PersistentDefaultDict 8 | 9 | 10 | class KNNClassifier(Classifier): 11 | """K-nearest neighbors classification""" 12 | 13 | def __init__(self, n_neighbors, n_jobs, **kwargs): 14 | """Create the classifier.""" 15 | super().__init__(**kwargs) 16 | self.knn = KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=n_jobs) 17 | 18 | def _fit(self, train_data, train_labels): 19 | self.knn.fit(train_data, train_labels) 20 | 21 | def _score(self, test_data, test_labels): 22 | return self.knn.score(test_data, test_labels) 23 | 24 | @staticmethod 25 | def grid_search( 26 | dataset_name, datasets_dir, output_dir, n_jobs, mean=None, std=None 27 | ): 28 | """Determine reasonable hyperparameters.""" 29 | # hyperparameter grid 30 | knn_grid = [1] + [(2**x) + 1 for x in range(1, 11)] 31 | 32 | # init results 33 | results = PersistentDefaultDict(output_dir.joinpath(f"knn_grid_search.json")) 34 | 35 | # load data 36 | train_data, train_labels = read_dataset( 37 | datasets_dir, f"{dataset_name}_train", mean=mean, std=std 38 | ) 39 | val_data, val_labels = read_dataset( 40 | datasets_dir, f"{dataset_name}_val", mean=mean, std=std 41 | ) 42 | 43 | for n_neighbors in knn_grid: 44 | knn_params_str = f"n_neighbors.{n_neighbors}" 45 | print(f"[+] {knn_params_str}") 46 | 47 | # skip if result already exists 48 | if ( 49 | dataset_name in results.as_dict() 50 | and knn_params_str in results.as_dict()[dataset_name] 51 | ): 52 | continue 53 | 54 | # train and test classifier 55 | knn = KNNClassifier(n_neighbors, n_jobs) 56 | knn.fit(train_data, train_labels) 57 | score = knn.score(val_data, val_labels) 58 | 59 | # store result 60 | results[dataset_name, knn_params_str] = score 61 | 62 | return results 63 | 64 | @staticmethod 65 | def train_classifier( 66 | dataset_name, datasets_dir, output_dir, n_jobs, n_neighbors, mean=None, std=None 67 | ): 68 | """Run the training code.""" 69 | results = PersistentDefaultDict(output_dir.joinpath(f"knn_test.json")) 70 | # classifier name 71 | classifier_name = f"classifier_{dataset_name}_knn_n_neighbors.{n_neighbors}" 72 | # load data 73 | train_data, train_labels = read_dataset( 74 | datasets_dir, f"{dataset_name}_train", mean=mean, std=std 75 | ) 76 | test_data, test_labels = read_dataset( 77 | datasets_dir, f"{dataset_name}_test", mean=mean, std=std 78 | ) 79 | # train classifier 80 | knn = KNNClassifier(n_neighbors, n_jobs) 81 | knn.fit(train_data, train_labels) 82 | # test classifier 83 | score = knn.score(test_data, test_labels) 84 | results[classifier_name] = score 85 | -------------------------------------------------------------------------------- /src/freqdect/crop_lsun.py: -------------------------------------------------------------------------------- 1 | """Script for cropping LSUN. 2 | 3 | Adopted from: https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/crop_lsun.py 4 | which is based on: https://github.com/ningyu1991/GANFingerprints/ 5 | """ 6 | 7 | import argparse 8 | import os 9 | from concurrent.futures import ProcessPoolExecutor 10 | from typing import Tuple 11 | 12 | from PIL import Image 13 | 14 | 15 | def transform_image(packed: Tuple[str, str, str]): 16 | """Center-crops and resizes an LSUN image to 128x128 pixels. 17 | 18 | Args: 19 | packed (Tuple[str, str, str]): Packed args as tuple. 20 | The first entry is the file path of the original image file, which is cropped and resized. 21 | The second entry is the path of the directory containing all original LSUN images. 
22 |             The third entry is the path of the directory where the cropped image is stored.
23 |     """
24 |     file_path, directory, output = packed
25 |     # catch errors and continue with different files
26 |     try:
27 |         if (
28 |             file_path.endswith("png")
29 |             or file_path.endswith("jpeg")
30 |             or file_path.endswith("jpg")
31 |             or file_path.endswith("webp")
32 |         ):
33 |             image = Image.open(f"{directory}/{file_path}")
34 |             x, y = image.size
35 |             if y < x:
36 |                 crop_height = y
37 |                 crop_width = y
38 | 
39 |                 crop_left = (x - y) // 2
40 |                 crop_top = 0
41 |                 image = image.crop(
42 |                     (
43 |                         crop_left,
44 |                         crop_top,
45 |                         crop_left + crop_width,
46 |                         crop_top + crop_height,
47 |                     )
48 |                 )
49 |             elif x < y:
50 |                 crop_height = x
51 |                 crop_width = x
52 | 
53 |                 crop_left = 0
54 |                 crop_top = (y - x) // 2
55 |                 image = image.crop(
56 |                     (
57 |                         crop_left,
58 |                         crop_top,
59 |                         crop_left + crop_width,
60 |                         crop_top + crop_height,
61 |                     )
62 |                 )
63 | 
64 |             image = image.resize((128, 128))
65 | 
66 |             # store .webp images as .png files
67 |             if file_path.endswith("webp"):
68 |                 file_path = file_path.replace(".webp", ".png")
69 |             image.save(f"{output}/{file_path}")
70 |         else:
71 |             print(f"Skipped {file_path}")
72 |     except ValueError as exc:
73 |         print(file_path, exc)  # x and y may be unbound here, so print only path and error
74 | 
75 | 
76 | def main(args):
77 |     """Center-crops and resizes a number of LSUN images.
78 | 
79 |     Images are resized to 128x128 pixels and stored.
80 |     """
81 |     os.makedirs(args.OUTPUT, exist_ok=True)
82 | 
83 |     # only consider the specified number of files
84 |     paths = os.listdir(args.DIRECTORY)[: args.SIZE]
85 |     packed = map(lambda p: (p, args.DIRECTORY, args.OUTPUT), paths)
86 |     with ProcessPoolExecutor() as pool:
87 |         list(pool.map(transform_image, packed))
88 | 
89 | 
90 | def _parse_args():
91 |     parser = argparse.ArgumentParser()
92 | 
93 |     parser.add_argument("DIRECTORY", help="Source directory.", type=str)
94 |     parser.add_argument("OUTPUT", help="Output directory.", type=str)
95 | 
96 |     # added argument "SIZE"
97 |     parser.add_argument("SIZE", help="Amount of data to convert.", type=int)
98 | 
99 |     return parser.parse_args()
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     main(_parse_args())
104 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | ##########################
2 | # Setup.py Configuration #
3 | ##########################
4 | [metadata]
5 | name = freqdect
6 | version = 0.0.1-dev
7 | description = Detect GANs using frequency domain methods.
8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | 11 | # URLs associated with the project 12 | url = https://github.com/gan-police/frequency-detection 13 | download_url = https://github.com/gan-police/frequency-detection/releases 14 | project_urls = 15 | Bug Tracker = https://github.com/gan-police/frequency-detection/issues 16 | Source Code = https://github.com/gan-police/frequency-detection 17 | 18 | # Author information 19 | author = Moritz Wolter 20 | author_email = moritz@wolter.tech 21 | maintainer = Moritz Wolter 22 | maintainer_email = moritz@wolter.tech 23 | 24 | # License Information 25 | license = GPLv3 26 | license_file = LICENSE 27 | 28 | # Search tags 29 | classifiers = 30 | Development Status :: 1 - Planning 31 | Environment :: Console 32 | Intended Audience :: Developers 33 | License :: OSI Approved :: GNU General Public License v3 (GPLv3) 34 | Operating System :: OS Independent 35 | Framework :: Pytest 36 | Framework :: tox 37 | Framework :: Sphinx 38 | Programming Language :: Python 39 | Programming Language :: Python :: 3.8 40 | Programming Language :: Python :: 3.9 41 | Programming Language :: Python :: 3 :: Only 42 | # TODO add your topics from the Trove controlled vocabulary (see https://pypi.org/classifiers) 43 | keywords = 44 | cookiecutter 45 | snekpack 46 | deepfakes 47 | GANs 48 | wavelets 49 | fft 50 | fwt 51 | CNNs 52 | classification 53 | deep learning 54 | 55 | [options] 56 | install_requires = 57 | # Missing itertools from the standard library you didn't know you needed 58 | more_itertools 59 | # Use progress bars excessively 60 | tqdm 61 | # Command line tools 62 | click 63 | more_click 64 | # TODO your requirements go here 65 | torch 66 | matplotlib 67 | seaborn 68 | PyWavelets 69 | scipy 70 | pillow 71 | opencv-python 72 | ptwt==0.1.2 73 | torchvision 74 | 75 | # Random options 76 | zip_safe = false 77 | include_package_data = True 78 | python_requires = >=3.6 79 | 80 | # Where is my code 81 | packages = find: 82 | package_dir = 83 | = src 84 | 85 | [options.packages.find] 86 | where = src 87 | 88 | [options.extras_require] 89 | docs = 90 | sphinx 91 | sphinx-rtd-theme 92 | sphinx-click 93 | sphinx-autodoc-typehints 94 | sphinx_automodapi 95 | # To include LaTeX comments easily in your docs 96 | texext 97 | 98 | [options.entry_points] 99 | console_scripts = 100 | freqdect = freqdect.cli:main 101 | 102 | ###################### 103 | # Doc8 Configuration # 104 | # (doc8.ini) # 105 | ###################### 106 | [doc8] 107 | max-line-length = 120 108 | 109 | ########################## 110 | # Coverage Configuration # 111 | # (.coveragerc) # 112 | ########################## 113 | [coverage:run] 114 | branch = True 115 | source = freqdect 116 | omit = 117 | tests/* 118 | docs/* 119 | img/* 120 | 121 | [coverage:paths] 122 | source = 123 | src/freqdect 124 | .tox/*/lib/python*/site-packages/freqdect 125 | 126 | [coverage:report] 127 | show_missing = True 128 | exclude_lines = 129 | pragma: no cover 130 | raise NotImplementedError 131 | if __name__ == .__main__.: 132 | def __str__ 133 | def __repr__ 134 | 135 | ########################## 136 | # Darglint Configuration # 137 | ########################## 138 | [darglint] 139 | docstring_style = sphinx 140 | strictness = short 141 | -------------------------------------------------------------------------------- /scripts/baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --nodes=1 4 | # Use all 
CPUs on the node 5 | #SBATCH --cpus-per-task=32 6 | #SBATCH --job-name=l-eigenfaces-lsun 7 | #SBATCH --output=l-eigenfaces-lsun-%j.out 8 | #SBATCH --error=l-eigenfaces-lsun-%j.err 9 | # Send the USR1 signal 120 seconds before end of time limit 10 | #SBATCH --signal=B:USR1@120 11 | # Set time limit to override default limit 12 | #SBATCH --time=48:00:00 13 | 14 | 15 | echo baseline.sh started at `date +"%T"` 16 | 17 | ANACONDA_ENV="$HOME/env/intel38" 18 | 19 | OUTPUT_DIR="baselines/results" 20 | DATASETS_DIR="/home/ndv/projects/wavelets/frequency-forensics_felix/data" 21 | 22 | LSUN_DATASET_LOGPACKETS="lsun_bedroom_200k_png_baseline_logpackets" 23 | LSUN_DATASET_PACKETS="lsun_bedroom_200k_png_baseline_packets" 24 | LSUN_DATASET_RAW="lsun_bedroom_200k_png_baseline_raw" 25 | 26 | CELEBA_DATASET_LOGPACKETS="celeba_align_png_cropped_baselines_logpackets" 27 | CELEBA_DATASET_PACKETS="celeba_align_png_cropped_baselines_packets" 28 | CELEBA_DATASET_RAW="celeba_align_png_cropped_baselines_raw" 29 | 30 | 31 | # select baseline to compute from {"knn", "prnu", "eigenfaces"} 32 | BASELINE="eigenfaces" 33 | 34 | # first three are channelwise mean, last three channelwise std 35 | MEAN_STD_RAW_CHANNELWISE="175.4984 163.5837 152.6461 56.3215 60.2467 64.2528" 36 | MEAN_STD_PACKETS_CHANNELWISE="0.1826 0.2155 0.2154 4.4256 4.3896 4.3595" 37 | 38 | # first is overall mean, last is overall std 39 | LSUN_MEAN_STD_LOGPACKETS="0.3281 4.2175" 40 | LSUN_MEAN_STD_PACKETS="19.8486 168.9453" 41 | LSUN_MEAN_STD_RAW="157.9363 63.1872" 42 | 43 | CELEBA_MEAN_STD_LOGPACKETS="0.7375 3.4890" 44 | CELEBA_MEAN_STD_PACKETS="17.5999 155.5985" 45 | CELEBA_MEAN_STD_RAW="140.8967 68.3285" 46 | 47 | CHOSEN_DATASET=$LSUN_DATASET_LOGPACKETS 48 | CHOSEN_NORMALIZATION=$LSUN_MEAN_STD_LOGPACKETS 49 | 50 | 51 | cp_results_from_tmp() 52 | { 53 | if [ -d ${TMPDIR}/${OUTPUT_DIR} ]; then 54 | echo "Copying results back to ${SLURM_SUBMIT_DIR}" 55 | 56 | # make sure that the output dir exists 57 | mkdir -p ${SLURM_SUBMIT_DIR}/${OUTPUT_DIR} 58 | cp -r ${TMPDIR}/${OUTPUT_DIR}/. ${SLURM_SUBMIT_DIR}/${OUTPUT_DIR} 59 | fi 60 | } 61 | 62 | # Define the signal handler function 63 | finalize_job() 64 | { 65 | echo Signal USR1 trapped at `date +"%T"` 66 | cp_results_from_tmp 67 | exit 68 | } 69 | 70 | # Call finalize_job function as soon as we receive USR1 signal (2 min before timeout) 71 | trap 'finalize_job' USR1 72 | 73 | module load Anaconda3 74 | source activate "$ANACONDA_ENV" 75 | 76 | if [ -f ${DATASETS_DIR}/${CHOSEN_DATASET}.tar ]; then 77 | echo "Tarred raw input folder exists, copying to $TMPDIR" 78 | cp "${DATASETS_DIR}/${CHOSEN_DATASET}.tar" "$TMPDIR"/ 79 | cd "$TMPDIR" 80 | echo "Unpacking tarred input folder" 81 | tar xf ${CHOSEN_DATASET}.tar 82 | DATASETS_DIR=${TMPDIR} 83 | 84 | # delete .tar file, which is not needed anymore 85 | rm ${CHOSEN_DATASET}.tar 86 | fi 87 | 88 | # work on scratch dir 89 | cd $TMPDIR 90 | 91 | # copy existing results to avoid repetition 92 | if [ -d ${SLURM_SUBMIT_DIR}/${OUTPUT_DIR} ]; then 93 | echo "Copying existing results to ${TMPDIR}" 94 | mkdir -p ${TMPDIR}/${OUTPUT_DIR} 95 | cp -r ${SLURM_SUBMIT_DIR}/${OUTPUT_DIR}/. 
${TMPDIR}/${OUTPUT_DIR}
96 | fi
97 | 
98 | echo "Calculating baseline data"
99 | 
100 | # -u: unbuffered stdout for "live" updates in the output file
101 | python -u -m freqdect.baselines.baselines \
102 |     --command grid_search \
103 |     --output_dir $OUTPUT_DIR \
104 |     --datasets_dir $DATASETS_DIR \
105 |     --datasets $CHOSEN_DATASET \
106 |     --normalize $CHOSEN_NORMALIZATION \
107 |     --n_jobs 32 \
108 |     $BASELINE
109 | 
110 | # release signal
111 | trap - USR1
112 | 
113 | # save the results
114 | cp_results_from_tmp
115 | 
116 | echo baseline.sh finished at `date +"%T"`
117 | 
118 | exit
119 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 | Contributions, whether big or small, are appreciated! You can get involved by submitting an issue, making a suggestion,
4 | or adding code to the project.
5 | 
6 | Having a Problem? Submit an Issue.
7 | ----------------------------------
8 | 1. Check that you have the latest version of :code:`freqdect`
9 | 2. Check that StackOverflow hasn't already solved your problem
10 | 3. Go here: https://github.com/gan-police/frequency-detection/issues
11 | 4. Check that this issue hasn't been solved
12 | 5. Click "new issue"
13 | 6. Add a short, but descriptive title
14 | 7. Add a full description of the problem, including the code that caused it and any support files related to this code
15 |    so others can reproduce your problem
16 | 8. Copy the output and error message you're getting
17 | 
18 | Have a Question or Suggestion?
19 | ------------------------------
20 | Same drill! Submit an issue and we'll have a nice conversation in the thread.
21 | 
22 | Want to Contribute?
23 | -------------------
24 | 1. Get the code. Fork the repository from GitHub using the big green button in the top-right corner of
25 |    https://github.com/gan-police/frequency-detection
26 | 2. Clone your directory with
27 | 
28 | .. code-block:: sh
29 | 
30 |     $ git clone https://github.com/<username>/frequency-detection
31 | 
32 | 3. Install with :code:`pip`. The flag, :code:`-e`, makes your installation editable, so your changes will be reflected
33 |    automatically in your installation.
34 | 
35 | .. code-block:: sh
36 | 
37 |     $ cd frequency-detection
38 |     $ python3 -m pip install -e .
39 | 
40 | 4. Make a branch off of develop, then make contributions! This line makes a new branch and checks it out:
41 | 
42 | .. code-block:: sh
43 | 
44 |     $ git checkout -b feature/<feature-name>
45 | 
46 | 5. This project should be well tested, so write unit tests in the :code:`tests/` directory
47 | 6. Check that all tests are passing and code coverage is good with :code:`tox` before committing.
48 | 
49 | .. code-block:: sh
50 | 
51 |     $ tox
52 | 
53 | Pull Requests
54 | ~~~~~~~~~~~~~
55 | Once you've got your feature or bugfix finished (or if it's in a partially complete state but you want to publish it
56 | for comment), push it to your fork of the repository and open a pull request against the develop branch on GitHub.
57 | 
58 | Make a descriptive comment about your pull request, perhaps referencing the issue it is meant to fix (something along
59 | the lines of "fixes issue #10" will cause GitHub to automatically link to that issue). The maintainers will review your
60 | pull request and perhaps make comments about it, request changes, or may pull it into the develop branch!
If you need
61 | to make changes to your pull request, simply push more commits to the feature branch in your fork to GitHub and they
62 | will automatically be added to the pull. You do not need to close and reissue your pull request to make changes!
63 | 
64 | If you spend a while working on your changes, further commits may be made to the main :code:`freqdect`
65 | repository (called "upstream") before you can make your pull request. Keep your fork up to date with upstream by
66 | pulling the changes--if your fork has diverged too much, it becomes difficult to properly merge pull requests without
67 | conflicts.
68 | 
69 | To pull in upstream changes:
70 | 
71 | .. code-block:: sh
72 | 
73 |     $ git remote add upstream https://github.com/gan-police/frequency-detection
74 |     $ git fetch upstream develop
75 | 
76 | Check the log to make sure the upstream changes don't affect your work too much:
77 | 
78 | .. code-block:: sh
79 | 
80 |     $ git log upstream/develop
81 | 
82 | Then merge in the new changes:
83 | 
84 | .. code-block:: sh
85 | 
86 |     $ git merge upstream/develop
87 | 
88 | More information about this whole fork-pull-merge process can be found
89 | `here on Github's website `_.
90 | 
--------------------------------------------------------------------------------
/src/freqdect/baselines/classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Classifier interface code as found at:
3 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/baselines/classifier.py
4 | """
5 | import pickle
6 | import time
7 | from collections import defaultdict
8 | from pathlib import Path
9 | 
10 | import numpy as np
11 | from tqdm import tqdm
12 | 
13 | 
14 | class Classifier(object):
15 |     """Classifier interface for the eigenfaces and k-nearest neighbour."""
16 | 
17 |     def __init__(self):
18 |         """Instantiate a classifier."""
19 |         super().__init__()
20 | 
21 |     def _fit(self, train_data, train_labels):
22 |         raise NotImplementedError()
23 | 
24 |     def _score(self, test_data, test_labels):
25 |         raise NotImplementedError()
26 | 
27 |     def fit(self, train_data, train_labels):
28 |         """Fit the classifier to training data."""
29 |         print(" fit")
30 |         start = time.time()
31 |         self._fit(train_data, train_labels)
32 |         end = time.time()
33 |         runtime = int(end - start)
34 |         print(
35 |             f" completed in {runtime // 3600}h {(runtime % 3600) // 60}m {(runtime % 60)}s"
36 |         )
37 |         return self
38 | 
39 |     def score(self, test_data, test_labels):
40 |         """Measure classifier performance."""
41 |         print(" score")
42 |         start = time.time()
43 |         score = self._score(test_data, test_labels)
44 |         end = time.time()
45 |         runtime = int(end - start)
46 |         print(f" -> {score}")
47 |         print(
48 |             f" completed in {runtime // 3600}h {(runtime % 3600) // 60}m {(runtime % 60)}s"
49 |         )
50 |         return score
51 | 
52 |     def save(self, output_path):
53 |         """Save the classifier to disk."""
54 |         Path(output_path).parent.mkdir(exist_ok=True, parents=True)
55 |         output_path.write_bytes(pickle.dumps(self))
56 | 
57 |     @staticmethod
58 |     def load(in_path):
59 |         """Load classifier from disk."""
60 |         instance = pickle.loads(Path(in_path).read_bytes())
61 |         return instance
62 | 
63 | 
64 | def read_dataset(
65 |     datasets_dir, dataset_name, subset_to_size=None, flatten=True, mean=None, std=None
66 | ):
67 |     """Load data from disk."""
68 |     print(f"[+] Read from {dataset_name}")
69 |     dataset_dir = datasets_dir / dataset_name
70 | 
71 |     labels = np.load(dataset_dir.joinpath("labels.npy"))
72 |     if not subset_to_size:
73 |         #
read full dataset 74 | imgs = [] 75 | for idx in tqdm(range(labels.size), bar_format=" {l_bar}{bar:30}{r_bar}"): 76 | img_path = dataset_dir.joinpath(f"{idx:06}.npy") 77 | img = np.load(img_path) 78 | if mean is not None: 79 | img = (img - mean) / std 80 | imgs.append(img) 81 | imgs = np.stack(imgs, 0) 82 | if flatten: 83 | imgs = imgs.reshape(labels.size, -1) 84 | return imgs, labels 85 | 86 | else: 87 | # subset dataset 88 | size_per_label = subset_to_size // np.unique(labels).size 89 | 90 | subset_data = [] 91 | subset_labels = [] 92 | 93 | sizes_per_label = defaultdict(int) 94 | p_bar = tqdm(total=subset_to_size, bar_format=" {l_bar}{bar:30}{r_bar}") 95 | for idx, label in enumerate(labels): 96 | 97 | if sizes_per_label[label] < size_per_label: 98 | img_path = dataset_dir.joinpath(f"{idx:06}.npy") 99 | img = np.load(img_path) 100 | if mean is not None: 101 | img = (img - mean) / std 102 | subset_data.append(img) 103 | subset_labels.append(label) 104 | p_bar.update(1) 105 | sizes_per_label[label] += 1 106 | 107 | if len(subset_data) == subset_to_size: 108 | p_bar.close() 109 | break 110 | 111 | else: 112 | raise Exception("[!] ran out of images") 113 | 114 | subset_data = np.stack(subset_data, 0) 115 | if flatten: 116 | subset_data = subset_data.reshape(subset_to_size, -1) 117 | return subset_data, subset_labels 118 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = 8 | # always keep coverage-clean first 9 | # coverage-clean 10 | # code linters/stylers 11 | manifest 12 | pyroma 13 | flake8 14 | mypy 15 | # documentation linters/checkers 16 | doc8 17 | docstr-coverage 18 | docs 19 | # the actual tests 20 | py 21 | # always keep coverage-report last 22 | # coverage-report 23 | 24 | [testenv] 25 | commands = coverage run -p -m pytest --durations=20 {posargs:tests} 26 | deps = 27 | coverage 28 | pytest 29 | 30 | [pytest] 31 | markers = 32 | slow: marks tests as slow (deselect with '-m "not slow"') 33 | 34 | [testenv:coverage-clean] 35 | deps = coverage 36 | skip_install = true 37 | commands = coverage erase 38 | 39 | [testenv:manifest] 40 | deps = check-manifest 41 | skip_install = true 42 | commands = check-manifest 43 | 44 | [testenv:black] 45 | skip_install = true 46 | deps = 47 | black 48 | commands = 49 | black src/freqdect/ scripts/ tests/ setup.py 50 | description = Apply Black to python source code. 51 | 52 | [testenv:flake8] 53 | skip_install = true 54 | deps = 55 | flake8==4.0.1 56 | flake8-bandit==2.1.2 57 | bandit==1.7.2 58 | flake8-colors 59 | flake8-black 60 | flake8-docstrings 61 | flake8-import-order 62 | flake8-bugbear 63 | flake8-broken-line 64 | pep8-naming 65 | pydocstyle 66 | darglint 67 | commands = 68 | flake8 src/freqdect/ tests/ setup.py 69 | description = Run the flake8 tool with several plugins (bandit, docstrings, import order, pep8 naming). 70 | 71 | [testenv:pyroma] 72 | deps = 73 | pygments 74 | pyroma 75 | skip_install = true 76 | commands = pyroma --min=10 . 77 | description = Run the pyroma tool to check the package friendliness of the project. 
78 | 79 | [testenv:mypy] 80 | deps = mypy 81 | commands = mypy --ignore-missing-imports src/freqdect/ 82 | description = Run the mypy tool to check static typing on the project. 83 | 84 | [testenv:doc8] 85 | skip_install = true 86 | deps = 87 | sphinx 88 | doc8 89 | commands = 90 | doc8 docs/source/ 91 | description = Run the doc8 tool to check the style of the RST files in the project docs. 92 | 93 | [testenv:docstr-coverage] 94 | skip_install = true 95 | deps = 96 | docstr-coverage 97 | commands = 98 | docstr-coverage src/freqdect/ tests/ setup.py --skip-private --skip-magic 99 | description = Run the docstr-coverage tool to check documentation coverage 100 | 101 | [testenv:docs] 102 | changedir = docs 103 | extras = 104 | docs 105 | commands = 106 | mkdir -p {envtmpdir} 107 | cp -r source {envtmpdir}/source 108 | sphinx-build -W -b html -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/html 109 | sphinx-build -W -b coverage -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/coverage 110 | cat {envtmpdir}/build/coverage/c.txt 111 | cat {envtmpdir}/build/coverage/python.txt 112 | whitelist_externals = 113 | /bin/cp 114 | /bin/cat 115 | /bin/mkdir 116 | 117 | [testenv:coverage-report] 118 | deps = coverage 119 | skip_install = true 120 | commands = 121 | coverage combine 122 | coverage report 123 | 124 | #################### 125 | # Deployment tools # 126 | #################### 127 | 128 | [testenv:bumpversion] 129 | commands = bumpversion {posargs} 130 | skip_install = true 131 | passenv = HOME 132 | deps = 133 | bumpversion 134 | 135 | [testenv:build] 136 | skip_install = true 137 | deps = 138 | wheel 139 | setuptools 140 | commands = 141 | python setup.py -q sdist bdist_wheel 142 | 143 | [testenv:release] 144 | skip_install = true 145 | deps = 146 | {[testenv:build]deps} 147 | twine >= 1.5.0 148 | commands = 149 | {[testenv:build]commands} 150 | twine upload --skip-existing dist/* 151 | 152 | [testenv:finish] 153 | skip_install = true 154 | passenv = HOME 155 | deps = 156 | {[testenv:build]deps} 157 | {[testenv:release]deps} 158 | bumpversion 159 | commands = 160 | bumpversion release 161 | {[testenv:release]commands} 162 | git push 163 | bumpversion patch 164 | git push 165 | whitelist_externals = 166 | /usr/bin/git 167 | -------------------------------------------------------------------------------- /src/freqdect/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to load numpy files into memory for further processing with PyTorch. 3 | 4 | Written with the numpy based data format 5 | of https://github.com/RUB-SysSec/GANDCTAnalysis in mind. 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | __all__ = ["NumpyDataset", "CombinedDataset"] 16 | 17 | 18 | class NumpyDataset(Dataset): 19 | """Create a data loader to load pre-processed numpy arrays into memory.""" 20 | 21 | def __init__( 22 | self, 23 | data_dir: str, 24 | mean: Optional[float] = None, 25 | std: Optional[float] = None, 26 | key: Optional[str] = "image", 27 | ): 28 | """Create a Numpy-dataset object. 29 | 30 | Args: 31 | data_dir: A path to a pre-processed folder with numpy files. 32 | mean: Pre-computed mean to normalize with. Defaults to None. 33 | std: Pre-computed standard deviation to normalize with. Defaults to None. 34 | key: The key for the input or 'x' component of the dataset. 35 | Defaults to "image". 
36 | 
37 |         Raises:
38 |             ValueError: If an unexpected file name is given or directory is empty.
39 | 
40 |         # noqa: DAR401
41 |         """
42 |         self.data_dir = data_dir
43 |         self.file_lst = sorted(Path(data_dir).glob("./*.npy"))
44 |         print("Loading ", data_dir)
45 |         if len(self.file_lst) == 0:
46 |             raise ValueError("empty directory")
47 |         if self.file_lst[-1].name != "labels.npy":
48 |             raise ValueError("unexpected file name")
49 |         self.labels = np.load(self.file_lst[-1])
50 |         self.images = self.file_lst[:-1]
51 |         self.mean = mean
52 |         self.std = std
53 |         self.key = key
54 | 
55 |     def __len__(self) -> int:
56 |         """Return the data set length."""
57 |         return len(self.labels)
58 | 
59 |     def __getitem__(self, idx: int) -> dict:
60 |         """Get a dataset element.
61 | 
62 |         Args:
63 |             idx (int): The element index of the data pair to return.
64 | 
65 |         Returns:
66 |             [dict]: Returns a dictionary with the self.key
67 |                 default ("image") and "label" keys.
68 |         """
69 |         img_path = self.images[idx]
70 |         image = np.load(img_path)
71 |         image = torch.from_numpy(image.astype(np.float32))
72 |         # normalize the data.
73 |         if self.mean is not None:
74 |             image = (image - self.mean) / self.std
75 |         label = self.labels[idx]
76 |         label = torch.tensor(int(label))
77 |         sample = {self.key: image, "label": label}
78 |         return sample
79 | 
80 | 
81 | class CombinedDataset(Dataset):
82 |     """Load data from multiple Numpy-Datasets using a single object."""
83 | 
84 |     def __init__(self, sets: list):
85 |         """Create a merged dataset, combining many numpy datasets.
86 | 
87 |         Args:
88 |             sets (list): A list of NumpyDataset objects.
89 |         """
90 |         self.sets = sets
91 |         self.len = len(sets[0])
92 |         # assert not any(self.len != len(s) for s in sets)
93 | 
94 |     @property
95 |     def key(self) -> list:
96 |         """Return the keys for all features in this dataset."""
97 |         return [d.key for d in self.sets]
98 | 
99 |     def __len__(self) -> int:
100 |         """Return the data set length."""
101 |         return self.len
102 | 
103 |     def __getitem__(self, idx: int) -> dict:
104 |         """Get a dataset element.
105 | 
106 |         Args:
107 |             idx (int): The element index of the data pair to return.
108 | 
109 |         Returns:
110 |             [dict]: Returns a dictionary with the self.key
111 |                 default ("image") and "label" keys.
112 |                 The key property will return a keylist.
113 |         """
114 |         label_list = [s.__getitem__(idx)["label"] for s in self.sets]
115 |         # the labels should all be the same
116 |         # assert not any([label_list[0] != l for l in label_list])
117 |         label = label_list[0]
118 |         sample = {d.key: d.__getitem__(idx)[d.key] for d in self.sets}
119 |         sample["label"] = label
120 |         return sample
121 | 
122 | 
123 | def main():
124 |     """Compute dataset mean and standard deviation and store it."""
125 |     import argparse
126 |     import pickle
127 | 
128 |     parser = argparse.ArgumentParser(description="Calculate mean and std")
129 |     parser.add_argument(
130 |         "dir",
131 |         type=str,
132 |         help="path of training data for which mean and std are computed",
133 |     )
134 |     args = parser.parse_args()
135 | 
136 |     print(args)
137 | 
138 |     data = NumpyDataset(args.dir)
139 | 
140 |     def compute_mean_std(data_set: Dataset) -> tuple:
141 |         """Compute mean and std values by looping over a dataset.
142 | 
143 |         Args:
144 |             data_set (Dataset): A torch style dataset.
145 | 
146 |         Returns:
147 |             tuple: the raw_data, as well as mean and std values.
148 |         """
149 |         # compute mean and std
150 |         img_lst = []
151 |         for img_no in range(data_set.__len__()):  # type: ignore[attr-defined]
152 |             img_lst.append(data_set.__getitem__(img_no)["image"])
153 |         img_data = torch.stack(img_lst, 0)
154 | 
155 |         # average over all axes except the color channel
156 |         axis = tuple(np.arange(len(img_data.shape[:-1])))
157 |         # calculate mean and std in double to avoid precision problems
158 |         mean = torch.mean(img_data.double(), axis).float()
159 |         std = torch.std(img_data.double(), axis).float()
160 |         return img_data, mean, std
161 | 
162 |     data, mean, std = compute_mean_std(data)
163 | 
164 |     print("mean", mean)
165 |     print("std", std)
166 |     file_name = f"{args.dir}/mean_std.pkl"
167 |     with open(file_name, "wb") as f:
168 |         pickle.dump([mean.numpy(), std.numpy()], f)
169 |     print("stored in", file_name)
170 | 
171 | 
172 | if __name__ == "__main__":
173 |     main()
174 | 
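175 | # Usage sketch (illustrative; the data directory is hypothetical):
176 | # train_set = NumpyDataset("/data/source_data_raw_train")
177 | # sample = train_set[0]
178 | # print(sample["image"].shape, sample["label"])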
148 | """ 149 | # compute mean and std 150 | img_lst = [] 151 | for img_no in range(data_set.__len__()): # type: ignore[attr-defined] 152 | img_lst.append(data_set.__getitem__(img_no)["image"]) 153 | img_data = torch.stack(img_lst, 0) 154 | 155 | # average all axis except the color channel 156 | axis = tuple(np.arange(len(img_data.shape[:-1]))) 157 | # calculate mean and std in double to avoid precision problems 158 | mean = torch.mean(img_data.double(), axis).float() 159 | std = torch.std(img_data.double(), axis).float() 160 | return img_data, mean, std 161 | 162 | data, mean, std = compute_mean_std(data) 163 | 164 | print("mean", mean) 165 | print("std", std) 166 | file_name = f"{args.dir}/mean_std.pkl" 167 | with open(file_name, "wb") as f: 168 | pickle.dump([mean.numpy(), std.numpy()], f) 169 | print("stored in", file_name) 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /src/freqdect/baselines/prnu.py: -------------------------------------------------------------------------------- 1 | """ 2 | As found at: 3 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/baselines/prnu.py 4 | """ 5 | from collections import defaultdict 6 | from itertools import product 7 | from multiprocessing import cpu_count 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from .classifier import Classifier, read_dataset 13 | from .prnu_functions import aligned_cc, extract_multiple_aligned, extract_single 14 | from .utils import PersistentDefaultDict 15 | 16 | 17 | class PRNUClassifier(Classifier): 18 | """Photoresponse non-uniformity classification as described in: 19 | Francesco Marra, Diego Gragnaniello, Luisa Verdoliva, and Giovanni Poggi. Do gans leave 20 | artificial fingerprints? In 2019 IEEE Conference on Multimedia Information Processing and 21 | Retrieval (MIPR), pages 506–511. IEEE, 2019. 
22 | """ 23 | 24 | def __init__(self, levels, sigma, **kwargs): 25 | """Create the classifier.""" 26 | super().__init__(**kwargs) 27 | self.levels = levels 28 | self.sigma = sigma 29 | self.gan_fingerprints = None 30 | 31 | def _fit(self, train_data, train_labels): 32 | # sort training data by labels 33 | unique_labels = np.unique(train_labels) 34 | train_data_by_label = defaultdict(list) 35 | for img, label in zip(train_data, train_labels): 36 | train_data_by_label[label].append(img) 37 | # extract fingerprints 38 | self.gan_fingerprints = [] 39 | for label in unique_labels: 40 | imgs = train_data_by_label[label] 41 | gan_fingerprint = extract_multiple_aligned( 42 | imgs, self.levels, self.sigma, processes=cpu_count() 43 | ) 44 | self.gan_fingerprints.append(gan_fingerprint) 45 | 46 | def _score(self, test_data, test_labels): 47 | # extract fingerprints 48 | img_fingerprints = [] 49 | for img, label in tqdm( 50 | zip(test_data, test_labels), 51 | bar_format=" {l_bar}{bar:30}{r_bar}", 52 | total=len(test_labels), 53 | ): 54 | img_fingerprint = extract_single(img, self.levels, self.sigma) 55 | img_fingerprints.append(img_fingerprint) 56 | # correlate images with GAN fingerprints 57 | cc = aligned_cc( 58 | np.stack(self.gan_fingerprints, 0), np.stack(img_fingerprints, 0) 59 | )["ncc"] 60 | # calculate score 61 | predictions = np.argmax(cc, axis=0) 62 | correct = 0 63 | incorrect = 0 64 | for prediction, label in zip(predictions, test_labels): 65 | if str(prediction) == str(label): 66 | correct += 1 67 | else: 68 | incorrect += 1 69 | score = correct / (correct + incorrect) 70 | return score 71 | 72 | @staticmethod 73 | def grid_search( 74 | dataset_name, datasets_dir, output_dir, n_jobs, mean=None, std=None 75 | ): 76 | """Determine hyperparameters.""" 77 | # init results 78 | results = PersistentDefaultDict(output_dir.joinpath(f"prnu_grid_search.json")) 79 | 80 | # load data 81 | train_data, train_labels = read_dataset( 82 | datasets_dir, f"{dataset_name}_train", flatten=False 83 | ) 84 | val_data, val_labels = read_dataset( 85 | datasets_dir, f"{dataset_name}_val", flatten=False 86 | ) 87 | train_data = train_data.astype(np.dtype("uint8")) 88 | val_data = val_data.astype(np.dtype("uint8")) 89 | 90 | # hyperparameter grid 91 | levels_range = range(1, 5, 1) 92 | sigma_range = np.arange(0.05, 1, 0.05) 93 | 94 | for levels, sigma in product(levels_range, sigma_range): 95 | # classifier name 96 | prnu_params_str = f"levels.{levels}_sigma.{sigma}" 97 | print(f"[+] {prnu_params_str}") 98 | 99 | # skip if result already exists 100 | if ( 101 | dataset_name in results.as_dict() 102 | and prnu_params_str in results.as_dict()[dataset_name] 103 | ): 104 | continue 105 | 106 | # train and test classifier 107 | prnu = PRNUClassifier(levels, sigma) 108 | prnu.fit(train_data, train_labels) 109 | score = prnu.score(val_data, val_labels) 110 | 111 | # store result 112 | results[dataset_name, prnu_params_str] = score 113 | 114 | return results 115 | 116 | @staticmethod 117 | def train_classifier(dataset_name, datasets_dir, output_dir, n_jobs, levels, sigma): 118 | """Run the training code.""" 119 | # classifier name 120 | classifier_name = ( 121 | f"classifier_{dataset_name}_prnu_levels.{levels}_sigma.{sigma}" 122 | ) 123 | print(f"\n{classifier_name.upper()}") 124 | # load data 125 | train_data, train_labels = read_dataset( 126 | datasets_dir, f"{dataset_name}_train", flatten=False 127 | ) 128 | train_data = train_data.astype(np.dtype("uint8")) 129 | # train 130 | prnu = PRNUClassifier(levels, sigma) 131 | 
prnu.fit(train_data, train_labels) 132 | prnu.save(output_dir.joinpath(f"{classifier_name}.pickle")) 133 | # test 134 | PRNUClassifier.test_classifier( 135 | classifier_name, dataset_name, datasets_dir, output_dir, n_jobs 136 | ) 137 | 138 | @staticmethod 139 | def test_classifier( 140 | classifier_name, dataset_name, datasets_dir, output_dir, n_jobs 141 | ): 142 | """Test the classifier.""" 143 | print(f"\n{classifier_name.upper()}") 144 | results = PersistentDefaultDict(output_dir.joinpath(f"prnu_test.json")) 145 | # load data 146 | test_data, test_labels = read_dataset( 147 | datasets_dir, f"{dataset_name}_test", flatten=False 148 | ) 149 | test_data = test_data.astype(np.dtype("uint8")) 150 | # load classifier 151 | prnu = PRNUClassifier.load(output_dir.joinpath(classifier_name + ".pickle")) 152 | # score 153 | score = prnu.score(test_data, test_labels) 154 | results[classifier_name] = score 155 | -------------------------------------------------------------------------------- /src/freqdect/wavelet_math.py: -------------------------------------------------------------------------------- 1 | """Module implementing wavelet related math functions. 2 | 3 | The idea is to provide functionality to make the packet transform useful 4 | for image analysis and gan-content recognition. 5 | """ 6 | 7 | from itertools import product 8 | 9 | import numpy as np 10 | import ptwt 11 | import pywt 12 | import torch 13 | 14 | 15 | def compute_packet_rep_2d( 16 | image: np.ndarray, wavelet_str: str = "haar", max_lev: int = 3 17 | ) -> np.ndarray: 18 | """Numpy based computation of a 2d full-packet representation. 19 | 20 | Args: 21 | image (np.ndarray): Image of shape [height, width]. 22 | wavelet_str (str, optional): The wavelet to use. Defaults to "haar". 23 | max_lev (int, optional): The number of levels in the representation. 24 | Defaults to 3. 25 | 26 | Returns: 27 | np.ndarray: A ready to plot wavelet packet image. 
28 | """ 29 | wavelet = pywt.Wavelet(wavelet_str) 30 | wp_tree = pywt.WaveletPacket2D(image, wavelet=wavelet, mode="reflect") 31 | # Get the full decomposition 32 | wp_keys = list(product(["a", "h", "v", "d"], repeat=max_lev)) 33 | count = 0 34 | img_rows = None 35 | img = [] 36 | for node in wp_keys: 37 | packet = np.squeeze(wp_tree["".join(node)].data) 38 | if img_rows is not None: 39 | img_rows = np.concatenate([img_rows, packet], axis=1) 40 | else: 41 | img_rows = packet 42 | count += 1 43 | if count >= np.sqrt(len(wp_keys)): 44 | count = 0 45 | img.append(img_rows) 46 | img_rows = None 47 | 48 | img_pywt = np.concatenate(img, axis=0) 49 | return img_pywt 50 | 51 | 52 | def compute_pytorch_packet_representation_2d_image( 53 | pt_data: torch.Tensor, wavelet_str: str = "db5", max_lev: int = 5 54 | ): 55 | """Create a packet image to plot.""" 56 | wavelet = pywt.Wavelet(wavelet_str) 57 | ptwt_wp_tree = ptwt.WaveletPacket2D(data=pt_data, wavelet=wavelet, mode="reflect") 58 | 59 | # get the pytorch decomposition 60 | wp_keys = list(product(["a", "h", "v", "d"], repeat=max_lev)) 61 | count = 0 62 | img_pt = [] 63 | img_rows_pt = None 64 | for node in wp_keys: 65 | packet = torch.squeeze(ptwt_wp_tree["".join(node)], dim=1) 66 | if img_rows_pt is not None: 67 | img_rows_pt = torch.cat([img_rows_pt, packet], dim=2) 68 | else: 69 | img_rows_pt = packet 70 | count += 1 71 | if count >= np.sqrt(len(wp_keys)): 72 | count = 0 73 | img_pt.append(img_rows_pt) 74 | img_rows_pt = None 75 | 76 | wp_pt = torch.cat(img_pt, dim=1) 77 | return wp_pt 78 | 79 | 80 | def compute_pytorch_packet_representation_2d_tensor( 81 | pt_data: torch.Tensor, 82 | wavelet_str: str = "db5", 83 | max_lev: int = 5, 84 | mode: str = "reflect", 85 | ) -> torch.Tensor: 86 | """Compute the wavelet packet representation tensor for a batch of input images. 87 | 88 | Args: 89 | pt_data: Image tensor of shape [batch, height, width] 90 | wavelet_str: Wavelet description string. Must be Pywt compatible. Defaults to "db5". 91 | max_lev: The maximum decomposition level to compute. Defaults to 5. 92 | mode: The desired boundary treatment approach. Choose zero, reflect or 93 | boundary. Defaults to reflect. 94 | 95 | Returns: 96 | : The packet tensor of shape [batch_size, packet_no, packet_height, packet_width] 97 | """ 98 | wavelet = pywt.Wavelet(wavelet_str) 99 | # print('wavelet', wavelet_str) 100 | ptwt_wp_tree = ptwt.WaveletPacket2D(data=pt_data, wavelet=wavelet, mode=mode) 101 | 102 | # get the pytorch decomposition 103 | # batch_size = pt_data.shape[0] 104 | wp_keys = list( 105 | product( 106 | ["a", "h", "v", "d"], 107 | repeat=max_lev, 108 | ) 109 | ) 110 | packet_list = [] 111 | for node in wp_keys: 112 | packet = torch.squeeze(ptwt_wp_tree["".join(node)], dim=1) 113 | packet_list.append(packet) 114 | 115 | wp_pt = torch.stack(packet_list, dim=1) 116 | return wp_pt 117 | 118 | 119 | def batch_packet_preprocessing( 120 | image_batch: np.ndarray, 121 | wavelet: str = "db1", 122 | max_lev: int = 3, 123 | eps: float = 1e-12, 124 | log_scale: bool = False, 125 | mode: str = "reflect", 126 | cuda: bool = True, 127 | ) -> np.ndarray: 128 | """Preprocess image batches by computing the wavelet packet representation. 129 | 130 | The raw as well as an absolute log scaled version can be computed. 131 | 132 | Args: 133 | image_batch (np.ndarray): An image of shape (B, H, W, C) 134 | wavelet (str, optional): A pywt-compatible wavelet string. 135 | Defaults to 'db1'. 136 | max_lev (int, optional): The number of decomposition scales 137 | to use. 
Defaults to 3.
138 |         eps (float, optional): A small number to stabilize the logarithm.
139 |             Defaults to 1e-12.
140 |         log_scale (bool, optional): Use log-scaling if True.
141 |             Log-scaled coefficients aren't invertible. Defaults to False.
142 |         mode (str, optional): The boundary treatment method. Defaults to reflect.
143 |         cuda (bool, optional): If False computations take place on the cpu.
144 | 
145 |     Returns:
146 |         [np.ndarray]: The wavelet packets [B, N, H, W, C].
147 |     """
148 |     image_batch_tensor = torch.from_numpy(image_batch.astype(np.float32))
149 |     if cuda:
150 |         image_batch_tensor = image_batch_tensor.cuda()
151 |     # compute the packet representation for each color channel of the (B, H, W, C) batch
152 |     channels = []
153 |     for channel in range(image_batch_tensor.shape[-1]):
154 |         with torch.no_grad():
155 |             channel_packets = compute_pytorch_packet_representation_2d_tensor(
156 |                 image_batch_tensor[:, :, :, channel],
157 |                 wavelet_str=wavelet,
158 |                 max_lev=max_lev,
159 |                 mode=mode,
160 |             )
161 |         channels.append(channel_packets)
162 |     packets = torch.stack(channels, -1)
163 |     del channels
164 |     if log_scale:
165 |         packets = torch.abs(packets)
166 |         packets = torch.log(packets + eps)
167 |     return packets.cpu().numpy()
168 | 
169 | 
170 | def identity_processing(image_batch):
171 |     """Return the input unchanged."""
172 |     return image_batch
173 | 
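174 | 
175 | # Usage sketch (illustrative addition): log-scaled Haar packets for a small
176 | # random batch on the cpu. With max_lev=3 the 4**3 = 64 packets of a 128x128
177 | # image are 16x16 pixels each, so the result has shape [4, 64, 16, 16, 3].
178 | # batch = np.random.randint(0, 256, size=(4, 128, 128, 3))
179 | # packets = batch_packet_preprocessing(
180 | #     batch, wavelet="haar", max_lev=3, log_scale=True, cuda=False
181 | # )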
18 | """ 19 | 20 | def __init__(self, pca_target_variance, svm_params, **kwargs): 21 | """Create the PCA classifier.""" 22 | super().__init__(**kwargs) 23 | self.pca = PCA(n_components=pca_target_variance, svd_solver="full") 24 | self.svm = LinearSVC(**svm_params, max_iter=10000) 25 | 26 | def fit_pca(self, train_data): 27 | """Fit to target data.""" 28 | print(f" -> pca") 29 | start = time.time() 30 | self.pca.fit(train_data) 31 | end = time.time() 32 | runtime = int(end - start) 33 | print(f" pca_components = {self.pca.components_.shape[0]}") 34 | print( 35 | f" completed in {runtime // 3600}h {(runtime % 3600) // 60}m {(runtime % 60)}s" 36 | ) 37 | 38 | def _fit(self, train_data, train_labels): 39 | if not hasattr(self.pca, "components_"): 40 | self.fit_pca(train_data) 41 | train_data_pca_transformed = self.pca.transform(train_data) 42 | self.svm.fit(train_data_pca_transformed, train_labels) 43 | 44 | def _score(self, test_data, test_labels): 45 | test_data_pca_transformed = self.pca.transform(test_data) 46 | score = self.svm.score(test_data_pca_transformed, test_labels) 47 | return score 48 | 49 | @staticmethod 50 | def generate_params(svm_grid): 51 | """Loop over parameter grids.""" 52 | for grid in svm_grid: 53 | for param_values in product(*tuple(grid.values())): 54 | params = {} 55 | for param_name, param_value in zip(grid.keys(), param_values): 56 | params[param_name] = param_value 57 | yield params 58 | 59 | @staticmethod 60 | def grid_search( 61 | dataset_name, datasets_dir, output_dir, n_jobs, mean=None, std=None 62 | ): 63 | """Determine reasonable input parameters.""" 64 | # hyperparameter grid 65 | pca_target_variances = [0.25, 0.5, 0.75, 0.95] 66 | svm_grid = [{"C": [0.0001, 0.001, 0.01, 0.1]}] 67 | 68 | # init results 69 | results = PersistentDefaultDict( 70 | output_dir.joinpath(f"eigenfaces_grid_search.json") 71 | ) 72 | 73 | # load data 74 | train_data, train_labels = read_dataset( 75 | datasets_dir, f"{dataset_name}_train", mean=mean, std=std 76 | ) 77 | val_data, val_labels = read_dataset( 78 | datasets_dir, f"{dataset_name}_val", mean=mean, std=std 79 | ) 80 | 81 | for pca_target_variance in pca_target_variances: 82 | # enumerate svm params 83 | for svm_params in PCAClassifier.generate_params(svm_grid): 84 | svm_params_str = "_".join([f"{k}.{v}" for k, v in svm_params.items()]) 85 | params_str = ( 86 | f"pca_target_variance.{pca_target_variance}_{svm_params_str}" 87 | ) 88 | print(f"[+] {params_str}") 89 | 90 | # skip if result already exists 91 | if ( 92 | dataset_name in results.as_dict() 93 | and params_str in results.as_dict()[dataset_name] 94 | ): 95 | continue 96 | 97 | # load training data for PCA 98 | train_data_pca, _ = read_dataset( 99 | datasets_dir, 100 | f"{dataset_name}_train", 101 | subset_to_size=10000, 102 | mean=mean, 103 | std=std, 104 | ) 105 | 106 | # train and test classifier 107 | pca = PCAClassifier(pca_target_variance, svm_params) 108 | pca.fit_pca(train_data_pca) 109 | pca.fit(train_data, train_labels) 110 | score = pca.score(val_data, val_labels) 111 | 112 | # store result 113 | results[dataset_name, params_str] = score 114 | 115 | return results 116 | 117 | @staticmethod 118 | def train_classifier( 119 | dataset_name, 120 | datasets_dir, 121 | output_dir, 122 | n_jobs, 123 | pca_target_variance, 124 | C, 125 | mean=None, 126 | std=None, 127 | ): 128 | """Run the training code.""" 129 | # classifier name 130 | classifier_name = ( 131 | f"classifier_{dataset_name}_eigenfaces_v.{pca_target_variance}_c.{C}" 132 | ) 133 | # load data 134 | 
train_data, train_labels = read_dataset( 135 | datasets_dir, f"{dataset_name}_train", mean=mean, std=std 136 | ) 137 | train_data_pca, _ = read_dataset( 138 | datasets_dir, 139 | f"{dataset_name}_train", 140 | subset_to_size=10000, 141 | mean=mean, 142 | std=std, 143 | ) 144 | # train 145 | pca = PCAClassifier( 146 | pca_target_variance=pca_target_variance, svm_params={"C": C} 147 | ) 148 | pca.fit_pca(train_data_pca) 149 | pca.fit(train_data, train_labels) 150 | pca.save(output_dir.joinpath(f"{classifier_name}.pickle")) 151 | # test 152 | PCAClassifier.test_classifier( 153 | classifier_name, 154 | dataset_name, 155 | datasets_dir, 156 | output_dir, 157 | n_jobs, 158 | mean=mean, 159 | std=std, 160 | ) 161 | 162 | @staticmethod 163 | def test_classifier( 164 | classifier_name, 165 | dataset_name, 166 | datasets_dir, 167 | output_dir, 168 | n_jobs, 169 | mean=None, 170 | std=None, 171 | ): 172 | """Run the test code.""" 173 | results = PersistentDefaultDict(output_dir.joinpath(f"eigenfaces_test.json")) 174 | # load data 175 | test_data, test_labels = read_dataset( 176 | datasets_dir, f"{dataset_name}_test", mean=mean, std=std 177 | ) 178 | # load classifier 179 | pca = PCAClassifier.load(output_dir.joinpath(f"{classifier_name}.pickle")) 180 | # score 181 | score = pca.score(test_data, test_labels) 182 | results[classifier_name] = score 183 | -------------------------------------------------------------------------------- /src/freqdect/saliency_process.py: -------------------------------------------------------------------------------- 1 | """Sensitivity analysis results processing module. 2 | 3 | Written by https://github.com/RaoulHeese . 4 | """ 5 | 6 | import argparse 7 | import os 8 | from pathlib import Path 9 | from typing import Tuple 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from .plot_mean_packets import generate_frequency_packet_image 15 | 16 | 17 | def process_results_dir_avg( 18 | data_dir: str, 19 | img_shape: Tuple[int, int] = (128, 128), 20 | degree: int = 3, 21 | ) -> np.ndarray: 22 | """Compute average gradients of most probable class. 23 | 24 | Args: 25 | data_dir (str): Directory in which the results from saliency.py are stored. 26 | img_shape (Tuple[int, int]): Input (image) shapes. Defaults to (128, 128). 27 | degree (int): wavelet degree (required to compose single image from wavelets). Defaults to 3. 
28 | 29 | Returns: 30 | np.ndarray : The resulting gradient image [img_shape[0], img_shape[1]] 31 | """ 32 | 33 | def _process_raw_image(s, o): 34 | # s.shape = (classes, H, W, C) 35 | p = np.exp(o) # p(class) 36 | i = np.argmax(p) 37 | x = s[i] # take more probable class 38 | x = np.abs(np.mean(x, axis=-1)) # abs mean over channel 39 | x = (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else x # [0,1] 40 | return x.astype(float) 41 | 42 | def _process_wavelet_image(s, o): 43 | # s.shape = (classes, wavelets, H, W, C) 44 | p = np.exp(o) # p(class) 45 | i = np.argmax(p) 46 | x = s[i] # take more probable class 47 | x = np.abs(np.mean(x, axis=-1)) # abs mean over channel 48 | x = generate_frequency_packet_image(x, degree) # compose 49 | x = (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else x # [0,1] 50 | return x.astype(float) 51 | 52 | def _process_image(s, o): 53 | if len(s.shape) == 4: 54 | return _process_raw_image(s, o) 55 | elif len(s.shape) == 5: 56 | return _process_wavelet_image(s, o) 57 | else: 58 | raise NotImplementedError(f"Data shape {s.shape} cannot be processed") 59 | 60 | def _process_batch(s_in, o_in): 61 | batch_picture = np.zeros(img_shape, dtype=float) 62 | for s, o in zip(s_in, o_in): 63 | batch_picture += _process_image(s, o) 64 | return batch_picture, s_in.shape[0] 65 | 66 | print("avg") 67 | avg_picture = np.zeros(img_shape, dtype=float) 68 | counter = 0 69 | file_lst = sorted(Path(data_dir).glob("./*.npy")) 70 | for file in tqdm(file_lst, desc="process files"): 71 | data = np.load(file) 72 | batch_picture, batch_counter = _process_batch(data["S"], data["O"]) 73 | avg_picture += batch_picture 74 | counter += batch_counter 75 | 76 | print(f"\tprocessed {counter} images") 77 | return avg_picture / counter 78 | 79 | 80 | def process_results_dir_std( 81 | data_dir: str, 82 | avg_picture: np.ndarray, 83 | img_shape: Tuple[int, int] = (128, 128), 84 | degree: int = 3, 85 | ) -> np.ndarray: 86 | """Compute standard deviation of gradients of most probable class. 87 | 88 | Args: 89 | data_dir (str): Directory in which the results from saliency.py are stored. 90 | avg_picture (np.ndarray): Average gradient from process_results_dir_avg. 91 | img_shape (Tuple[int, int]): Input (image) shapes. Defaults to (128, 128). 92 | degree (int): wavelet degree (required to compose single image from wavelets). Defaults to 3. 
93 | 94 | Returns: 95 | np.ndarray : The resulting gradient image [img_shape[0], img_shape[1]] 96 | """ 97 | 98 | def _process_raw_image(s, o): 99 | # s.shape = (classes, H, W, C) 100 | p = np.exp(o) # p(class) 101 | i = np.argmax(p) 102 | x = s[i] # take more probable class 103 | x = np.abs(np.mean(x, axis=-1)) # abs mean over channel 104 | x = (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else x # [0,1] 105 | return x.astype(float) 106 | 107 | def _process_wavelet_image(s, o): 108 | # s.shape = (classes, wavelets, H, W, C) 109 | p = np.exp(o) # p(class) 110 | i = np.argmax(p) 111 | x = s[i] # take more probable class 112 | x = np.abs(np.mean(x, axis=-1)) # abs mean over channel 113 | x = generate_frequency_packet_image(x, degree) # compose 114 | x = (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else x # [0,1] 115 | return x.astype(float) 116 | 117 | def _process_image(s, o): 118 | if len(s.shape) == 4: 119 | return _process_raw_image(s, o) 120 | elif len(s.shape) == 5: 121 | return _process_wavelet_image(s, o) 122 | else: 123 | raise NotImplementedError(f"Data shape {s.shape} cannot be processed") 124 | 125 | def _process_batch(s_in, o_in): 126 | batch_picture = np.zeros(img_shape, dtype=float) 127 | for s, o in zip(s_in, o_in): 128 | batch_picture += (_process_image(s, o) - avg_picture) ** 2 129 | return batch_picture, s_in.shape[0] 130 | 131 | print("std") 132 | std_picture = np.zeros(img_shape, dtype=float) 133 | counter = 0 134 | file_lst = sorted(Path(data_dir).glob("./*.npy")) 135 | for file in tqdm(file_lst, desc="process files"): 136 | data = np.load(file) 137 | batch_picture, batch_counter = _process_batch(data["S"], data["O"]) 138 | std_picture += batch_picture 139 | counter += batch_counter 140 | 141 | print(f"\tprocessed {counter} images") 142 | return np.sqrt(std_picture / counter) 143 | 144 | 145 | def main(args): 146 | """Process results from saliency.py .""" 147 | print(f"Process '{args.sal_dir}'") 148 | 149 | avg_picture = process_results_dir_avg(args.sal_dir) 150 | std_picture = process_results_dir_std(args.sal_dir, avg_picture) 151 | 152 | array_dict = {"avg": avg_picture, "std": std_picture} 153 | if not os.path.exists(args.result_dir): 154 | print("creating", args.result_dir) 155 | os.mkdir(args.result_dir) 156 | prefix = os.path.split(args.sal_dir)[-1] 157 | filename = f"{prefix}-result.npy" 158 | with open(os.path.join(args.result_dir, filename), "wb") as numpy_file: 159 | np.savez(numpy_file, **array_dict) 160 | 161 | print(f"Finished: {filename}.") 162 | 163 | 164 | def _parse_args(): 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument( 167 | "--sal-dir", 168 | type=str, 169 | required=True, 170 | help="Path to saliency result directory.", 171 | ) 172 | parser.add_argument( 173 | "--result-dir", 174 | type=str, 175 | required=True, 176 | help="Path to processed result directory.", 177 | ) 178 | return parser.parse_args() 179 | 180 | 181 | if __name__ == "__main__": 182 | args = _parse_args() 183 | print(args) 184 | main(args) 185 | -------------------------------------------------------------------------------- /src/freqdect/baselines/baselines.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basline code as found at: 3 | https://github.com/RUB-SysSec/GANDCTAnalysis/blob/master/baselines/baselines.py 4 | """ 5 | 6 | import argparse 7 | from multiprocessing import cpu_count 8 | from pathlib import Path 9 | from typing import Type, Dict 10 | 11 | import numpy as np 12 | 13 | from 
.classifier import Classifier, read_dataset
14 | from .eigenface import PCAClassifier
15 | from .knn import KNNClassifier
16 | from .prnu import PRNUClassifier
17 |
18 | CLASSIFIER_CLS: Dict[str, Type[Classifier]] = {
19 |     "prnu": PRNUClassifier,
20 |     "eigenfaces": PCAClassifier,
21 |     "knn": KNNClassifier,
22 | }
23 |
24 |
25 | def main(
26 |     command,
27 |     baseline,
28 |     datasets,
29 |     datasets_dir,
30 |     output_dir,
31 |     n_jobs,
32 |     classifier_name,
33 |     normalize,
34 |     calc_normalization,
35 |     **classifier_args,
36 | ):
37 |     """Run baselines."""
38 |     print("[+] ARGUMENTS")
39 |     print(f" -> command @ {command}")
40 |     print(f" -> baseline @ {baseline}")
41 |     if command == "train":
42 |         print(
43 |             " * ",
44 |             "\n * ".join([f"{k} = {v}" for k, v in classifier_args.items()]),
45 |             sep="",
46 |         )
47 |     if command == "test":
48 |         print(" * ", classifier_name)
49 |     print(f" -> n_jobs @ {n_jobs}")
50 |     print(f" -> datasets @ {len(datasets)}")
51 |     print(" * ", "\n * ".join(list(datasets)), sep="")
52 |     print(f" -> datasets_dir @ {datasets_dir}")
53 |     assert datasets_dir.is_dir()
54 |     print(f" -> output_dir @ {output_dir}")
55 |     output_dir.mkdir(exist_ok=True, parents=True)
56 |
57 |     # select classifier class of given baseline
58 |     classifier_cls = CLASSIFIER_CLS[baseline]
59 |
60 |     # grid search
61 |     if command == "grid_search":
62 |         print("\n[+] GRID SEARCH")
63 |         best_results = {}
64 |         for dataset_name in datasets:
65 |             print(f"\n{dataset_name.upper()}")
66 |
67 |             if normalize:
68 |                 num_of_norm_vals = len(normalize)
69 |                 assert num_of_norm_vals == 2 or num_of_norm_vals == 6
70 |                 mean = np.array(normalize[: num_of_norm_vals // 2])
71 |                 std = np.array(normalize[(num_of_norm_vals // 2) :])
72 |             elif calc_normalization:
73 |                 # load train data and compute mean and std
74 |                 train_data, _ = read_dataset(
75 |                     datasets_dir, f"{dataset_name}_train", flatten=False
76 |                 )
77 |
78 |                 # average all axes except the color channel
79 |                 axis = tuple(np.arange(len(train_data.shape[:-1])))
80 |
81 |                 # calculate mean and std in float64 (numpy) to avoid precision problems
82 |                 mean = np.mean(train_data.astype(np.float64), axis).astype(np.float32)
83 |                 std = np.std(train_data.astype(np.float64), axis).astype(np.float32)
84 |             else:
85 |                 mean = None
86 |                 std = None
87 |
88 |             print(f"\t\tmean: {mean}")
89 |             print(f"\t\tstd: {std}")
90 |
91 |             results = classifier_cls.grid_search(
92 |                 dataset_name, datasets_dir, output_dir, n_jobs, mean=mean, std=std
93 |             )
94 |             # get best result
95 |             best_results[dataset_name] = sorted(
96 |                 results.as_dict()[dataset_name].items(), key=lambda e: e[1]
97 |             ).pop()
98 |
99 |         print("\n[+] Best Results")
100 |         for dataset_name, (params, acc) in best_results.items():
101 |             print(f" -> {dataset_name}")
102 |             print(f" {params} @ {acc}")
103 |
104 |     # train
105 |     if command == "train":
106 |         for dataset_name in datasets:
107 |             classifier_cls.train_classifier(
108 |                 dataset_name, datasets_dir, output_dir, n_jobs, **classifier_args
109 |             )
110 |
111 |     # test
112 |     if command == "test":
113 |         assert len(datasets) == 1
114 |         classifier_cls.test_classifier(
115 |             classifier_name, datasets[0], datasets_dir, output_dir, n_jobs
116 |         )
117 |
118 |
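# Example invocation (an illustrative sketch; dataset names and paths depend
# on how the data was prepared):
#
#   python -m freqdect.baselines.baselines \
#       --command grid_search \
#       --datasets lsun_raw \
#       --datasets_dir ./data \
#       --output_dir ./baseline_results \
#       knn
#
# The trailing positional argument selects one of the subparsers defined
# below (knn, prnu or eigenfaces) and thereby the CLASSIFIER_CLS entry.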
119 | def parse_args():
120 |     """Read command line arguments."""
121 |     parser = argparse.ArgumentParser()
122 |     parser.add_argument(
123 |         "--command",
124 |         help="Command to execute. If choice is `train`, the hyperparameters need to be set accordingly.",
125 |         choices=["train", "test", "grid_search"],
126 |         type=str,
127 |     )
128 |     parser.add_argument(
129 |         "--n_jobs",
130 |         help="Limits the number of cores used (if available).",
131 |         type=int,
132 |         default=cpu_count(),
133 |     )
134 |     parser.add_argument(
135 |         "--datasets",
136 |         help="Name of dataset(s).",
137 |         action="append",
138 |         type=str,
139 |         required=True,
140 |     )
141 |     parser.add_argument(
142 |         "--datasets_dir",
143 |         help="Directory containing the dataset(s).",
144 |         type=Path,
145 |         required=True,
146 |     )
147 |     parser.add_argument(
148 |         "--output_dir",
149 |         help="Working directory containing results and classifiers.",
150 |         type=Path,
151 |         required=True,
152 |     )
153 |     group = parser.add_mutually_exclusive_group()
154 |     group.add_argument(
155 |         "--normalize",
156 |         nargs="+",
157 |         type=float,
158 |         metavar=("MEAN", "STD"),
159 |         help="normalize with specified values for mean and standard deviation (either 2 or 6 values "
160 |         "are accepted)",
161 |     )
162 |     group.add_argument(
163 |         "--calc-normalization",
164 |         action="store_true",
165 |         help="calculates mean and standard deviation used in normalization from the training data",
166 |     )
167 |
168 |     parser.add_argument(
169 |         "--classifier_name",
170 |         help="Name of classifier (located within output directory). Only used when command set to "
171 |         "'test'.",
172 |         type=str,
173 |     )
174 |
175 |     subparsers = parser.add_subparsers(dest="baseline", help="Name of classifier.")
176 |     subparsers.required = True
177 |
178 |     knn_parser = subparsers.add_parser("knn", help="kNN-based classifier.")
179 |     knn_parser.add_argument("--n_neighbors", type=int)
180 |
181 |     prnu_parser = subparsers.add_parser("prnu", help="PRNU-based classifier.")
182 |     prnu_parser.add_argument("--levels", type=int)
183 |     prnu_parser.add_argument("--sigma", type=float)
184 |
185 |     eigenfaces_parser = subparsers.add_parser(
186 |         "eigenfaces", help="Eigenfaces-based classifier."
187 | ) 188 | eigenfaces_parser.add_argument("--pca_target_variance", type=float) 189 | eigenfaces_parser.add_argument("--C", type=float) 190 | 191 | return parser.parse_args() 192 | 193 | 194 | if __name__ == "__main__": 195 | import sklearnex 196 | 197 | sklearnex.patch_sklearn() 198 | main(**vars(parse_args())) 199 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/data_download/ffhq-dataset/thumbnails128x128 2 | src/data_download/lsun 3 | data 4 | runs 5 | log 6 | log2 7 | scripts 8 | 9 | # Created by https://www.toptal.com/developers/gitignore/api/macos,linux,pycharm,python,windows 10 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,linux,pycharm,python,windows 11 | 12 | ### Linux ### 13 | *~ 14 | 15 | # temporary files which can be created if a process still has a handle open of a deleted file 16 | .fuse_hidden* 17 | 18 | # KDE directory preferences 19 | .directory 20 | 21 | # Linux trash folder which might appear on any partition or disk 22 | .Trash-* 23 | 24 | # .nfs files are created when an open file is removed but is still being accessed 25 | .nfs* 26 | 27 | ### macOS ### 28 | # General 29 | .DS_Store 30 | .AppleDouble 31 | .LSOverride 32 | 33 | # Icon must end with two \r 34 | Icon 35 | 36 | 37 | # Thumbnails 38 | ._* 39 | 40 | # Files that might appear in the root of a volume 41 | .DocumentRevisions-V100 42 | .fseventsd 43 | .Spotlight-V100 44 | .TemporaryItems 45 | .Trashes 46 | .VolumeIcon.icns 47 | .com.apple.timemachine.donotpresent 48 | 49 | # Directories potentially created on remote AFP share 50 | .AppleDB 51 | .AppleDesktop 52 | Network Trash Folder 53 | Temporary Items 54 | .apdisk 55 | 56 | ### PyCharm ### 57 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 58 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 59 | 60 | # User-specific stuff 61 | .idea/**/workspace.xml 62 | .idea/**/tasks.xml 63 | .idea/**/usage.statistics.xml 64 | .idea/**/dictionaries 65 | .idea/**/shelf 66 | 67 | # Generated files 68 | .idea/**/contentModel.xml 69 | 70 | # Sensitive or high-churn files 71 | .idea/**/dataSources/ 72 | .idea/**/dataSources.ids 73 | .idea/**/dataSources.local.xml 74 | .idea/**/sqlDataSources.xml 75 | .idea/**/dynamic.xml 76 | .idea/**/uiDesigner.xml 77 | .idea/**/dbnavigator.xml 78 | 79 | # Gradle 80 | .idea/**/gradle.xml 81 | .idea/**/libraries 82 | 83 | # Gradle and Maven with auto-import 84 | # When using Gradle or Maven with auto-import, you should exclude module files, 85 | # since they will be recreated, and may cause churn. Uncomment if using 86 | # auto-import. 
87 | # .idea/artifacts 88 | # .idea/compiler.xml 89 | # .idea/jarRepositories.xml 90 | # .idea/modules.xml 91 | # .idea/*.iml 92 | # .idea/modules 93 | # *.iml 94 | # *.ipr 95 | 96 | # CMake 97 | cmake-build-*/ 98 | 99 | # Mongo Explorer plugin 100 | .idea/**/mongoSettings.xml 101 | 102 | # File-based project format 103 | *.iws 104 | 105 | # IntelliJ 106 | out/ 107 | 108 | # mpeltonen/sbt-idea plugin 109 | .idea_modules/ 110 | 111 | # JIRA plugin 112 | atlassian-ide-plugin.xml 113 | 114 | # Cursive Clojure plugin 115 | .idea/replstate.xml 116 | 117 | # Crashlytics plugin (for Android Studio and IntelliJ) 118 | com_crashlytics_export_strings.xml 119 | crashlytics.properties 120 | crashlytics-build.properties 121 | fabric.properties 122 | 123 | # Editor-based Rest Client 124 | .idea/httpRequests 125 | 126 | # Android studio 3.1+ serialized cache file 127 | .idea/caches/build_file_checksums.ser 128 | 129 | ### PyCharm Patch ### 130 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 131 | 132 | # *.iml 133 | # modules.xml 134 | # .idea/misc.xml 135 | # *.ipr 136 | 137 | # Sonarlint plugin 138 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 139 | .idea/**/sonarlint/ 140 | 141 | # SonarQube Plugin 142 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 143 | .idea/**/sonarIssues.xml 144 | 145 | # Markdown Navigator plugin 146 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 147 | .idea/**/markdown-navigator.xml 148 | .idea/**/markdown-navigator-enh.xml 149 | .idea/**/markdown-navigator/ 150 | 151 | # Cache file creation bug 152 | # See https://youtrack.jetbrains.com/issue/JBR-2257 153 | .idea/$CACHE_FILE$ 154 | 155 | # CodeStream plugin 156 | # https://plugins.jetbrains.com/plugin/12206-codestream 157 | .idea/codestream.xml 158 | 159 | ### Python ### 160 | # Byte-compiled / optimized / DLL files 161 | __pycache__/ 162 | *.py[cod] 163 | *$py.class 164 | 165 | # C extensions 166 | *.so 167 | 168 | # Distribution / packaging 169 | .Python 170 | build/ 171 | develop-eggs/ 172 | dist/ 173 | downloads/ 174 | eggs/ 175 | .eggs/ 176 | lib/ 177 | lib64/ 178 | parts/ 179 | sdist/ 180 | var/ 181 | wheels/ 182 | pip-wheel-metadata/ 183 | share/python-wheels/ 184 | *.egg-info/ 185 | .installed.cfg 186 | *.egg 187 | MANIFEST 188 | 189 | # PyInstaller 190 | # Usually these files are written by a python script from a template 191 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
192 | *.manifest 193 | *.spec 194 | 195 | # Installer logs 196 | pip-log.txt 197 | pip-delete-this-directory.txt 198 | 199 | # Unit test / coverage reports 200 | htmlcov/ 201 | .tox/ 202 | .nox/ 203 | .coverage 204 | .coverage.* 205 | .cache 206 | nosetests.xml 207 | coverage.xml 208 | *.cover 209 | *.py,cover 210 | .hypothesis/ 211 | .pytest_cache/ 212 | pytestdebug.log 213 | 214 | # Translations 215 | *.mo 216 | *.pot 217 | 218 | # Django stuff: 219 | *.log 220 | local_settings.py 221 | db.sqlite3 222 | db.sqlite3-journal 223 | 224 | # Flask stuff: 225 | instance/ 226 | .webassets-cache 227 | 228 | # Scrapy stuff: 229 | .scrapy 230 | 231 | # Sphinx documentation 232 | docs/_build/ 233 | doc/_build/ 234 | 235 | # PyBuilder 236 | target/ 237 | 238 | # Jupyter Notebook 239 | .ipynb_checkpoints 240 | 241 | # IPython 242 | profile_default/ 243 | ipython_config.py 244 | 245 | # pyenv 246 | .python-version 247 | 248 | # pipenv 249 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 250 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 251 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 252 | # install all needed dependencies. 253 | #Pipfile.lock 254 | 255 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 256 | __pypackages__/ 257 | 258 | # Celery stuff 259 | celerybeat-schedule 260 | celerybeat.pid 261 | 262 | # SageMath parsed files 263 | *.sage.py 264 | 265 | # Environments 266 | .env 267 | .venv 268 | env/ 269 | venv/ 270 | ENV/ 271 | env.bak/ 272 | venv.bak/ 273 | pythonenv* 274 | 275 | # Spyder project settings 276 | .spyderproject 277 | .spyproject 278 | 279 | # Rope project settings 280 | .ropeproject 281 | 282 | # mkdocs documentation 283 | /site 284 | 285 | # mypy 286 | .mypy_cache/ 287 | .dmypy.json 288 | dmypy.json 289 | 290 | # Pyre type checker 291 | .pyre/ 292 | 293 | # pytype static type analyzer 294 | .pytype/ 295 | 296 | # profiling data 297 | .prof 298 | 299 | ### Windows ### 300 | # Windows thumbnail cache files 301 | Thumbs.db 302 | Thumbs.db:encryptable 303 | ehthumbs.db 304 | ehthumbs_vista.db 305 | 306 | # Dump file 307 | *.stackdump 308 | 309 | # Folder config file 310 | [Dd]esktop.ini 311 | 312 | # Recycle Bin used on file shares 313 | $RECYCLE.BIN/ 314 | 315 | # Windows Installer files 316 | *.cab 317 | *.msi 318 | *.msix 319 | *.msm 320 | *.msp 321 | 322 | # Windows shortcuts 323 | *.lnk 324 | 325 | # End of https://www.toptal.com/developers/gitignore/api/macos,linux,pycharm,python,windows 326 | 327 | scratch/ 328 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | ## Wavelet-Packets for Deepfake Image Analysis and Detection 11 | 12 |

13 | <!-- badges: GitHub Actions (tests) · PyPI - Project -->
20 |

21 |
22 |
23 | This is the supplementary source code for our paper [Wavelet-Packets for Deepfake Image Analysis and Detection, Machine Learning, Special Issue of the ECML PKDD 2022 Journal Track](https://rdcu.be/cUIRt).
24 |
25 | ![packet plot](./img/packet_visualization2.png)
26 |
27 | The plot above illustrates the fundamental principle.
28 | It shows an FFHQ and a StyleGAN-generated image on the very left.
29 | In the center and on the right, packet coefficients and their standard deviation are depicted.
30 | We computed mean and standard deviation values using 5k images from each source.
31 |
32 | ## Installation
33 |
34 | The latest code can be installed in development mode with:
35 |
36 | ```shell
37 | $ git clone https://github.com/gan-police/frequency-forensics
38 | $ cd frequency-forensics
39 | $ pip install -e .
40 | ```
41 |
42 | ## Assets
43 |
44 | ### GAN Architectures
45 |
46 | We utilize pre-trained models from the following repositories:
47 |
48 | - [StyleGAN](https://github.com/NVlabs/stylegan)
49 | - [GANFingerprints](https://github.com/ningyu1991/GANFingerprints)
50 |
51 | For our wavelet-packet computations, we use:
52 | - [PyTorch-Wavelet-Toolbox: ptwt](https://github.com/v0lta/PyTorch-Wavelet-Toolbox)
53 |
54 | In the paper, we compare our approach to the DCT-method from:
55 | - [GANDCTAnalysis](https://github.com/RUB-SysSec/GANDCTAnalysis)
56 |
57 | ### Datasets
58 |
59 | We utilize three datasets that have commonly appeared in previous work:
60 |
61 | - [FFHQ](https://github.com/NVlabs/ffhq-dataset)
62 | - [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
63 | - [LSUN bedroom](https://github.com/fyu/lsun)
64 |
65 | ## Reproduction
66 |
67 | The following section of the README serves as a guide to reproducing our paper. Data for the 128-pixel FFHQ-StyleGAN pair is available via
68 | [google-drive](https://drive.google.com/file/d/1MOHKuEVqURfCKAN9dwp1o2tuR19OTQCF/view?usp=sharing) [4.2 GB].
69 |
70 | ### Preparation
71 |
72 | We work with images of size 128x128 pixels. Hence, the raw images have to be cropped
73 | and/or resized to this size. To do this, run `freqdect.crop_celeba` or `freqdect.crop_lsun`, depending on the dataset.
74 | This will create a new folder with the transformed images. The FFHQ dataset is already distributed in the required image
75 | size.
76 |
77 | Use the pretrained GAN models to generate images. In the case of StyleGAN, there is only a pre-trained model generating
78 | images of size 1024x1024, so one has to resize the GAN-generated images to 128x128 pixels, e.g. by inserting
79 |
80 | ```python
81 | PIL.Image.fromarray(images[0], 'RGB').resize((128, 128)).save(png_filename)
82 | ```
83 |
84 | into the
85 | [ffhq-stylegan](https://github.com/NVlabs/stylegan/blob/03563d18a0cf8d67d897cc61e44479267968716b/pretrained_example.py)
86 | example script.
87 |
88 | Store all images (cropped originals and GAN-generated) in separate subdirectories of a common directory, i.e. the directory
89 | structure should look like this:
90 |
91 | ```
92 | source_data
93 | ├── A_original
94 | ├── B_CramerGAN
95 | ├── C_MMDGAN
96 | ├── D_ProGAN
97 | └── E_SNGAN
98 | ```
99 |
100 | For the FFHQ case, we have only two subdirectories: `ffhq_stylegan/A_ffhq` and `ffhq_stylegan/B_stylegan`. The prefixes
101 | of the folders are important, since the directories get their labels in lexicographic order of their prefix, i.e.
102 | directory `A_...` gets label 0, `B_...` label 1, etc.
103 |
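The label assignment can be reproduced with a few lines of Python (an illustrative sketch, not part of the package):

```python
from pathlib import Path

# Folders are sorted by name; their enumeration order yields the labels.
folders = sorted(p.name for p in Path("source_data").iterdir() if p.is_dir())
labels = {name: label for label, name in enumerate(folders)}
# e.g. {'A_original': 0, 'B_CramerGAN': 1, 'C_MMDGAN': 2, 'D_ProGAN': 3, 'E_SNGAN': 4}
```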
104 | Now, to prepare the data sets, run `freqdect.prepare_dataset`. It reads in the data sets, splits them into training,
105 | validation and test sets, applies the specified transformation (to wavelet packets, log-scaled wavelet packets or just
106 | the raw image data) and stores the result as numpy arrays.
107 |
108 | Afterwards, run e.g.:
109 |
110 | ```shell
111 | $ python -m freqdect.prepare_dataset ./data/source_data/ --log-packets
112 | $ python -m freqdect.prepare_dataset ./data/source_data/
113 | ```
114 |
115 | The dataset preparation script accepts additional arguments. For example, it is possible to change the sizes of the
116 | train, test or validation sets. For a list of all optional arguments, open the help page via the `-h` argument.
117 |
118 | ### Training the Classifier
119 |
120 | Now you should be able to train a classifier using, for example:
121 |
122 | ```shell
123 | $ python -m freqdect.train_classifier \
124 |     --data-prefix ./data/source_data_log_packets_haar_reflect_3 \
125 |     --calc-normalization \
126 |     --features packets
127 | ```
128 |
129 | This trains a regression classifier using default hyperparameters. The training, validation and test accuracy and loss
130 | values are stored in a file placed in a `log` folder. The state dict of the trained model is stored there as well. For a
131 | list of all optional arguments, open the help page via the `-h` argument.
132 |
133 | ### Evaluating the Classifier
134 |
135 | #### Plotting the Metrics
136 |
137 | To plot the accuracy results, run:
138 |
139 | ```shell
140 | $ python -m freqdect.plot_accuracy_results {shared, lsun, celeba} {regression, CNN} ...
141 | ```
142 |
143 | For a list of all optional arguments, open the help page via the `-h` argument.
144 |
145 | #### Calculating the confusion matrix
146 |
147 | To calculate the confusion matrix, run `freqdect.confusion_matrix`. For a list of all arguments, open the help page via
148 | the `-h` argument.
149 |
150 | ## ⚖️ Licensing
151 |
152 | This project is licensed under the [GNU GPLv3 license](LICENSE).
153 |
154 | ## Acknowledgements
155 |
156 | ### 📖 Citation
157 | If you find this work useful, please consider citing:
158 | ```
159 | @article{wolter2022waveletpacket,
160 |   title = {Wavelet-Packets for Deepfake Image Analysis and Detection},
161 |   author = {Moritz Wolter and Felix Blanke and Raoul Heese and Jochen Garcke},
162 |   journal = {Machine Learning},
163 |   year = {2022},
164 |   volume = {Special Issue of the ECML PKDD 2022 Journal Track},
165 |   pages = {1-33},
166 |   month = {August},
167 |   url = {https://rdcu.be/cUIRt},
168 |   issn = {0885-6125},
169 |   doi = {https://doi.org/10.1007/s10994-022-06225-5}
170 | }
171 | ```
172 |
173 | ### 🙏 Support
174 |
175 | This project has been supported by the following organizations (in alphabetical order):
176 |
177 | - [Fraunhofer Institute for Algorithms and Scientific Computing (SCAI)](https://www.scai.fraunhofer.de)
178 | - [Fraunhofer Cluster of Excellence Cognitive Internet Technologies (CCIT)](https://www.cit.fraunhofer.de/en.html)
179 |
180 | ### 🍪 Cookiecutter
181 |
182 | This package was created with [@audreyfeldroy](https://github.com/audreyfeldroy)'s
183 | [cookiecutter](https://github.com/cookiecutter/cookiecutter) package using [@cthoyt](https://github.com/cthoyt)'s
184 | [cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) template.
185 | -------------------------------------------------------------------------------- /src/freqdect/models.py: -------------------------------------------------------------------------------- 1 | """This module contains code for deepfake detection models.""" 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def compute_parameter_total(net: torch.nn.Module) -> int: 7 | """Compute the parameter total of the input net. 8 | 9 | Args: 10 | net (torch.nn.Module): The model containing the 11 | parameters to count. 12 | 13 | Returns: 14 | int: The parameter total. 15 | """ 16 | total = 0 17 | for p in net.parameters(): 18 | if p.requires_grad: 19 | print(p.shape) 20 | total += np.prod(p.shape) # type: ignore 21 | return total 22 | 23 | 24 | class CNN(torch.nn.Module): 25 | """CNN models used for packet or pixel classification.""" 26 | 27 | def __init__(self, classes: int, feature: str = "image"): 28 | """Create a convolutional neural network (CNN) model. 29 | 30 | Args: 31 | classes (int): The number of classes or sources to classify. 32 | feature (str)): A string which tells us the input feature 33 | we are using. 34 | """ 35 | super().__init__() 36 | self.feature = feature 37 | 38 | if feature == "packets": 39 | self.layers = torch.nn.Sequential( 40 | torch.nn.Conv2d(192, 24, 3), 41 | torch.nn.ReLU(), 42 | torch.nn.Conv2d(24, 24, 6), 43 | torch.nn.ReLU(), 44 | torch.nn.Conv2d(24, 24, 9), 45 | torch.nn.ReLU(), 46 | ) 47 | self.linear = torch.nn.Linear(24, classes) 48 | elif feature == "all-packets" or feature == "all-packets-fourier": 49 | if feature == "all-packets-fourier": 50 | self.scale1 = torch.nn.Sequential( 51 | torch.nn.Conv2d(6, 8, 3, padding=1), 52 | torch.nn.ReLU(), 53 | torch.nn.AvgPool2d(2, 2), 54 | ) 55 | else: 56 | self.scale1 = torch.nn.Sequential( 57 | torch.nn.Conv2d(3, 8, 3, padding=1), 58 | torch.nn.ReLU(), 59 | torch.nn.AvgPool2d(2, 2), 60 | ) 61 | self.scale2 = torch.nn.Sequential( 62 | torch.nn.Conv2d(20, 16, 3, padding=1), 63 | torch.nn.ReLU(), 64 | torch.nn.AvgPool2d(2, 2), 65 | ) 66 | self.scale3 = torch.nn.Sequential( 67 | torch.nn.Conv2d(64, 32, 3, padding=1), 68 | torch.nn.ReLU(), 69 | torch.nn.AvgPool2d(2, 2), 70 | ) 71 | self.scale4 = torch.nn.Sequential( 72 | torch.nn.Conv2d(224, 32, 3, 1, padding=1), torch.nn.ReLU() 73 | ) 74 | self.linear = torch.nn.Linear(32 * 16 * 16, classes) 75 | else: 76 | # assume an 128x128x3 image input. 77 | self.layers = torch.nn.Sequential( 78 | torch.nn.Conv2d(3, 8, 3), 79 | torch.nn.ReLU(), 80 | torch.nn.Conv2d(8, 8, 3), 81 | torch.nn.ReLU(), 82 | torch.nn.AvgPool2d(2, 2), 83 | torch.nn.Conv2d(8, 16, 3), 84 | torch.nn.ReLU(), 85 | torch.nn.AvgPool2d(2, 2), 86 | torch.nn.Conv2d(16, 32, 3), 87 | torch.nn.ReLU(), 88 | ) 89 | self.linear = torch.nn.Linear(32 * 28 * 28, classes) 90 | self.logsoftmax = torch.nn.LogSoftmax(dim=-1) 91 | 92 | def forward(self, x) -> torch.Tensor: 93 | """Compute the CNN forward pass. 94 | 95 | Args: 96 | x (torch.Tensor or dict): An input image of shape 97 | [batch_size, packets, height, width, channels] 98 | for packet inputs and 99 | [batch_size, height, width, channels] 100 | else. 101 | 102 | Returns: 103 | torch.Tensor: A logsoftmax scaled output of shape 104 | [batch_size, classes]. 
105 |
106 |         """
107 |         # x = generate_packet_image_tensor(x)
108 |         if self.feature == "packets":
109 |             # batch_size, packets, height, width, channels
110 |             shape = x.shape
111 |             # batch_size, height, width, packets, channels
112 |             x = x.permute([0, 2, 3, 1, 4])
113 |             # batch_size, height, width, packets*channels
114 |             to_net = x.reshape([shape[0], shape[2], shape[3], shape[1] * shape[4]])
115 |             # batch_size, packets*channels, height, width
116 |         elif self.feature == "all-packets":
117 |             to_net = x["raw"]
118 |         elif self.feature == "all-packets-fourier":
119 |             to_net = torch.cat([x["raw"], x["fourier"]], dim=-1)
120 |         else:
121 |             to_net = x
122 |
123 |         to_net = to_net.permute([0, 3, 1, 2])
124 |
125 |         if self.feature == "all-packets" or self.feature == "all-packets-fourier":
126 |             res = self.scale1(to_net)
127 |             packets = [
128 |                 torch.reshape(
129 |                     x[key].permute([0, 2, 3, 1, 4]),
130 |                     [x[key].shape[0], x[key].shape[2], x[key].shape[3], -1],
131 |                 ).permute(0, 3, 1, 2)
132 |                 for key in ["packets1", "packets2", "packets3"]
133 |             ]
134 |             # shape: batch_size, packet_channels, height, width, color_channels
135 |             # cat along channel dim1.
136 |             to_net = torch.cat([packets[0], res], dim=1)
137 |             res = self.scale2(to_net)
138 |             to_net = torch.cat([packets[1], res], dim=1)
139 |             res = self.scale3(to_net)
140 |             to_net = torch.cat([packets[2], res], dim=1)
141 |             out = self.scale4(to_net)
142 |             out = torch.reshape(out, [out.shape[0], -1])
143 |             out = self.linear(out)
144 |         else:
145 |             out = self.layers(to_net)
146 |             out = torch.reshape(out, [out.shape[0], -1])
147 |             out = self.linear(out)
148 |         return self.logsoftmax(out)
149 |
150 |
151 | class Regression(torch.nn.Module):
152 |     """A shallow linear-regression model."""
153 |
154 |     def __init__(self, classes: int):
155 |         """Create the regression model.
156 |
157 |         Args:
158 |             classes (int): The number of classes or sources to classify.
159 |         """
160 |         super().__init__()
161 |         self.linear = torch.nn.Linear(49152, classes)
162 |
163 |         # self.activation = torch.nn.Sigmoid()
164 |         self.logsoftmax = torch.nn.LogSoftmax(dim=-1)
165 |
166 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
167 |         """Compute the regression forward pass.
168 |
169 |         Args:
170 |             x (torch.Tensor): An input tensor of shape
171 |                 [batch_size, ...]
172 |
173 |         Returns:
174 |             torch.Tensor: A logsoftmax scaled output of shape
175 |                 [batch_size, classes].
176 |         """
177 |         x_flat = torch.reshape(x, [x.shape[0], -1])
178 |         return self.logsoftmax(self.linear(x_flat))
179 |
180 |
181 | def save_model(model: torch.nn.Module, path):
182 |     """Save the state dict of the model to the specified path.
183 |
184 |     Args:
185 |         model (torch.nn.Module): model to store
186 |         path: file path of the storage file
187 |     """
188 |     torch.save(model.state_dict(), path)
189 |
190 |
191 | def initialize_model(model: torch.nn.Module, path):
192 |     """Initialize the given model from a stored state dict file.
193 |
194 |     Args:
195 |         model (torch.nn.Module): model to initialize
196 |         path: file path of the storage file
197 |     """
198 |     model.load_state_dict(torch.load(path))
199 |
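# A quick shape sanity check for the packet CNN above (an illustrative
# sketch, not part of the original module):
#
#   import torch
#   net = CNN(classes=2, feature="packets")
#   # A degree-3 packet transform of a 128x128 RGB image yields
#   # 4**3 = 64 packet nodes of size 16x16 with three color channels,
#   # i.e. the 192 input channels expected by the first convolution.
#   dummy = torch.randn(8, 64, 16, 16, 3)  # [batch, packets, height, width, channels]
#   out = net(dummy)  # log-probabilities of shape [8, 2]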
-------------------------------------------------------------------------------- /src/freqdect/wavelet_plot.py: --------------------------------------------------------------------------------
1 | """Code to create wavelet packet plots."""
2 |
3 | import argparse
4 |
5 | import cv2
6 | import matplotlib.colors as colors
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | from pywt._doc_utils import _2d_wp_basis_coords
11 |
12 | from .wavelet_math import (
13 |     compute_packet_rep_2d,
14 |     compute_pytorch_packet_representation_2d_tensor,
15 | )
16 |
17 |
18 | def draw_2d_wp_basis(shape, keys, fmt="k", plot_kwargs=None, ax=None, label_levels=0):
19 |     """Plot a 2D representation of a WaveletPacket2D basis.
20 |
21 |     Based on: pywt._doc_utils.draw_2d_wp_basis
22 |     """
23 |     coords, centers = _2d_wp_basis_coords(shape, keys)
24 |     if ax is None:
25 |         fig, ax = plt.subplots(1, 1)
26 |     else:
27 |         fig = ax.get_figure()
28 |     for coord in coords:
29 |         ax.plot(coord[0], coord[1], fmt)
30 |     ax.set_axis_off()
31 |     ax.axis("square")
32 |     if label_levels > 0:
33 |         for key, c in centers.items():
34 |             if len(key) <= label_levels:
35 |                 ax.text(
36 |                     c[0],
37 |                     c[1],
38 |                     "".join(key),
39 |                     horizontalalignment="center",
40 |                     verticalalignment="center",
41 |                     fontsize=6,
42 |                 )
43 |     return fig, ax
44 |
45 |
46 | def read_pair(path_real, path_fake):
47 |     """Load an image pair into numpy arrays.
48 |
49 |     Args:
50 |         path_real (str): A path to a real image.
51 |         path_fake (str): Another path to a generated image.
52 |
53 |     Returns:
54 |         tuple: Two numpy arrays, one for each image.
55 |     """
56 |     face = cv2.cvtColor(cv2.imread(path_real), cv2.COLOR_BGR2RGB) / 255.0
57 |     fake_face = cv2.cvtColor(cv2.imread(path_fake), cv2.COLOR_BGR2RGB) / 255.0
58 |     return face, fake_face
59 |
60 |
61 | def compute_packet_rep_img(image, wavelet_str, max_lev):
62 |     """Compute the packet representation of an input image.
63 |
64 |     Args:
65 |         image (np.array): An image of shape [height, width, channel]
66 |         wavelet_str (str): A string indicating the desired wavelet according
67 |             to the pywt convention, e.g. 'haar'.
68 |         max_lev (int): The level up to which the packet representation should be
69 |             computed, e.g. 3.
70 |
71 |     Raises:
72 |         ValueError: If the image shape does not have 2 or 3 dimensions.
73 |
74 |     Returns:
75 |         np.array: A stacked version of the wavelet packet representation. 
76 | 77 | # noqa: DAR401 78 | """ 79 | if len(image.shape) == 3: 80 | channels_lst = [] 81 | for channel in range(3): 82 | channels_lst.append( 83 | compute_packet_rep_2d(image[:, :, channel], wavelet_str, max_lev) 84 | ) 85 | return np.stack(channels_lst, axis=-1) 86 | elif len(image.shape) != 2: 87 | raise ValueError(f"invalid image shape: {image.shape}") 88 | else: 89 | return compute_packet_rep_2d(image, wavelet_str, max_lev) 90 | 91 | 92 | def main(): 93 | """Compute some wavelet packets of real and generated images for visual comparison.""" 94 | parser = argparse.ArgumentParser( 95 | description="Plot wavelet decomposition of real and fake imgs" 96 | ) 97 | parser.add_argument( 98 | "--data-dir", 99 | type=str, 100 | default="./data/", 101 | help="path of folder containing the data (default: ./data/)", 102 | ) 103 | parser.add_argument( 104 | "--real-data", 105 | type=str, 106 | default="A_ffhq", 107 | help="name of folder with real data (default: A_ffhq)", 108 | ) 109 | parser.add_argument( 110 | "--fake-data", 111 | type=str, 112 | default="B_stylegan", 113 | help="name of folder with fake data (default: B_stylegan)", 114 | ) 115 | args = parser.parse_args() 116 | 117 | print(args) 118 | 119 | pairs = [] 120 | pairs.append( 121 | read_pair( 122 | args.data_dir + args.real_data + "/00000.png", 123 | args.data_dir + args.fake_data + "/style_gan_ffhq_example0.png", 124 | ) 125 | ) 126 | pairs.append( 127 | read_pair( 128 | args.data_dir + args.real_data + "/00001.png", 129 | args.data_dir + args.fake_data + "/style_gan_ffhq_example1.png", 130 | ) 131 | ) 132 | pairs.append( 133 | read_pair( 134 | args.data_dir + args.real_data + "/00002.png", 135 | args.data_dir + args.fake_data + "/style_gan_ffhq_example2.png", 136 | ) 137 | ) 138 | pairs.append( 139 | read_pair( 140 | args.data_dir + args.real_data + "/00003.png", 141 | args.data_dir + args.fake_data + "/style_gan_ffhq_example3.png", 142 | ) 143 | ) 144 | pairs.append( 145 | read_pair( 146 | args.data_dir + args.real_data + "/00004.png", 147 | args.data_dir + args.fake_data + "/style_gan_ffhq_example4.png", 148 | ) 149 | ) 150 | pairs.append( 151 | read_pair( 152 | args.data_dir + args.real_data + "/00005.png", 153 | args.data_dir + args.fake_data + "/style_gan_ffhq_example5.png", 154 | ) 155 | ) 156 | 157 | wavelet = "db1" 158 | max_lev = 3 159 | for real, fake in pairs: 160 | real = ( 161 | torch.from_numpy(np.mean(real, -1).astype(np.float32)).unsqueeze(0).cuda() 162 | ) 163 | fake = ( 164 | torch.from_numpy(np.mean(fake, -1).astype(np.float32)).unsqueeze(0).cuda() 165 | ) 166 | # plt.imshow(np.concatenate([real, fake], axis=1)) 167 | # plt.show() 168 | real_packets = compute_pytorch_packet_representation_2d_tensor( 169 | real, wavelet_str=wavelet, max_lev=max_lev 170 | ) 171 | fake_packets = compute_pytorch_packet_representation_2d_tensor( 172 | fake, wavelet_str=wavelet, max_lev=max_lev 173 | ) 174 | 175 | real_packets = torch.squeeze(real_packets) 176 | fake_packets = torch.squeeze(fake_packets) 177 | 178 | # merge_packets = np.concatenate([real_packets, fake_packets], axis=1) 179 | abs_real_packets = np.abs(real_packets.cpu().numpy()) 180 | abs_fake_packets = np.abs(fake_packets.cpu().numpy()) 181 | # scaled_packets = abs_packets/np.max(abs_packets) 182 | # log_scaled_packets = np.log(abs_packets) 183 | # scaled_packets = np. 
184 |
185 |         scale_min = np.min([abs_real_packets.min(), abs_fake_packets.min()]) + 2e-4
186 |         scale_max = np.max([abs_real_packets.max(), abs_fake_packets.max()])
187 |
188 |         cmap = "cividis"  # alternatives: 'magma', 'inferno', 'viridis'
189 |         fig = plt.figure(figsize=(20, 6))
190 |         ax1 = fig.add_subplot(121)
191 |         ax2 = fig.add_subplot(122)
192 |         # ax3 = fig.add_subplot(133)
193 |         ax1.set_title("real img " + wavelet + " packet decomposition")
194 |         ax1.imshow(
195 |             abs_real_packets,
196 |             norm=colors.LogNorm(vmin=scale_min, vmax=scale_max),
197 |             cmap=cmap,
198 |         )
199 |         ax2.set_title("fake img " + wavelet + " packet decomposition")
200 |         _ = ax2.imshow(
201 |             abs_fake_packets,
202 |             norm=colors.LogNorm(vmin=scale_min, vmax=scale_max),
203 |             cmap=cmap,
204 |         )
205 |         # fig.colorbar(im)
206 |         # shape = real.shape
207 |         # keys = list(product(['a', 'h', 'v', 'd'], repeat=max_lev))
208 |         # draw_2d_wp_basis(shape, keys, ax=ax3, label_levels=max_lev)
209 |         # ax3.set_title('packet labels')
210 |         plt.show()
211 |
212 |         plt.semilogy(np.mean(abs_real_packets, 0), label="real")
213 |         plt.semilogy(np.mean(abs_fake_packets, 0), label="fake")
214 |         plt.legend()
215 |         plt.show()
216 |
217 |
218 | if __name__ == "__main__":
219 |     main()
220 |
-------------------------------------------------------------------------------- /src/freqdect/saliency.py: --------------------------------------------------------------------------------
1 | """Sensitivity analysis/explainability module for trained models.
2 |
3 | Written by https://github.com/RaoulHeese .
4 | """
5 |
6 | import argparse
7 | import os
8 | import pickle
9 |
10 |
11 | import numpy as np
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | from .data_loader import NumpyDataset
17 | from .models import CNN, Regression
18 |
19 |
20 | def save_to_disk(
21 |     array_dict: dict,
22 |     directory: str,
23 |     sub_dir: str = "",
24 |     counter: int = 0,
25 | ) -> None:
26 |     """Save data batches (dicts of numpy arrays) to disk.
27 |
28 |     Args:
29 |         array_dict (dict): The data dict to store of the form {str: np.ndarray}.
30 |         directory (str): The place to store the data.
31 |         sub_dir (str): Subdirectory to store the images at.
32 |         counter (int): Index of the file (sets filename).
33 |
34 |     Returns:
35 |         None
36 |     """
37 |     # one file per batch
38 |     path = os.path.join(directory, sub_dir)
39 |     if not os.path.exists(path):
40 |         # print("creating", path)
41 |         os.mkdir(path)
42 |     with open(os.path.join(path, f"{counter:06}.npy"), "wb") as numpy_file:
43 |         np.savez(numpy_file, **array_dict)
44 |
45 |
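# On-disk layout produced by save_to_disk (as used by saliency() below and
# read back by saliency_process.py): one npz archive per batch,
#
#   <directory>/<sub_dir>/000000.npy, 000001.npy, ...
#
# each holding the keys "S" (gradients), "O" (model outputs) and
# "I" (global image indices).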
46 | def saliency(
47 |     model,
48 |     data_loader: DataLoader,
49 |     directory: str,
50 |     sub_dir: str = "",
51 | ) -> int:
52 |     """Calculate raw gradients (d out/d in) for each data point of a data set given one model.
53 |
54 |     Args:
55 |         model: PyTorch model to use for predictions.
56 |         data_loader: DataLoader instance in which the data is stored.
57 |         directory (str): The place to store the results at.
58 |         sub_dir (str): Subdirectory to store the results at.
59 |
60 |     Returns:
61 |         int: The total number of processed gradients.
62 |     """
63 |     print(f"Evaluate gradients -> {sub_dir}...")
64 |
65 |     if torch.cuda.is_available():
66 |         model = model.cuda()
67 |     for param in model.parameters():
68 |         param.requires_grad = False
69 |
70 |     counter_total = 0
71 |     counter_image = 0
72 |     with tqdm(desc="process") as prog:
73 |         for it, batch in enumerate(iter(data_loader)):
74 |
75 |             model.eval()
76 |             batch_images = batch["image"]
77 |             # batch_labels = batch["label"]
78 |             if torch.cuda.is_available():
79 |                 batch_images = batch_images.cuda(non_blocking=True)
80 |                 # batch_labels = batch_labels.cuda(non_blocking=True)
81 |             batch_images.requires_grad = True
82 |             out_batch = model(batch_images)
83 |             n_inputs = out_batch.shape[0]
84 |             n_classes = out_batch.shape[1]
85 |             slc_batch = np.empty((n_inputs, n_classes) + batch_images[0, :].shape)
86 |
87 |             image_index = []
88 |             for i in range(n_inputs):
89 |                 for c in range(n_classes):
90 |                     if batch_images.grad is not None:
91 |                         batch_images.grad.zero_()
92 |                     out_batch[i, c].backward(retain_graph=True)
93 |                     slc = batch_images.grad[i].cpu().detach().numpy()
94 |                     slc_batch[i, c, :] = slc
95 |                     counter_total += 1
96 |                 image_index.append(counter_image)
97 |                 counter_image += 1
98 |
99 |             slc_batch = np.asarray(slc_batch)  # [idx_in_batch, class_idx, data...]
100 |             out_batch = out_batch.cpu().detach().numpy()  # [idx_in_batch, label]
101 |             image_array = np.array(image_index)  # (index1, ..., indexN)
102 |             save_to_disk(
103 |                 {"S": slc_batch, "O": out_batch, "I": image_array},
104 |                 directory,
105 |                 sub_dir,
106 |                 it,
107 |             )
108 |
109 |             prog.set_postfix(
110 |                 {"img": counter_image, "all": counter_total}, refresh=False
111 |             )
112 |             prog.update(1)
113 |
114 |     return counter_image
115 |
116 |
117 | def main(args):
118 |     """Compute gradients (d out/d in) for trained models."""
119 |     # torch seed for reproducible results
120 |     torch.manual_seed(args.seed)
121 |
122 |     # normalization: either one mean/std pair (2 values) or per-channel values (6)
123 |     if args.normalize:
124 |         num_of_norm_vals = len(args.normalize)
125 |         if num_of_norm_vals not in (2, 6):
126 |             raise ValueError("Either two or six normalization values are required.")
127 |         mean = torch.tensor(args.normalize[: (num_of_norm_vals // 2)])
128 |         std = torch.tensor(args.normalize[(num_of_norm_vals // 2) :])
129 |     elif args.calc_normalization:
130 |         # load train data and compute mean and std
131 |         try:
132 |             with open(f"{args.data_prefix}_train/mean_std.pkl", "rb") as file:
133 |                 mean, std = pickle.load(file)
134 |             mean = torch.from_numpy(mean.astype(np.float32))
135 |             std = torch.from_numpy(std.astype(np.float32))
136 |         except BaseException:
137 |             print("loading mean and std from file failed. 
Re-computing.") 138 | train_data_set = NumpyDataset(args.data_prefix + "_train") 139 | 140 | img_lst = [] 141 | for img_no in range(train_data_set.__len__()): 142 | img_lst.append(train_data_set.__getitem__(img_no)["image"]) 143 | img_data = torch.stack(img_lst, 0) 144 | 145 | # average all axis except the color channel 146 | axis = tuple(np.arange(len(img_data.shape[:-1]))) 147 | 148 | # calculate mean and std in double to avoid precision problems 149 | mean = torch.mean(img_data.double(), axis).float() 150 | std = torch.std(img_data.double(), axis).float() 151 | del img_data 152 | else: 153 | mean = None 154 | std = None 155 | 156 | print("mean", mean, "std", std) 157 | 158 | # Load data 159 | train_data_set = NumpyDataset(args.data_prefix + "_train", mean=mean, std=std) 160 | val_data_set = NumpyDataset(args.data_prefix + "_val", mean=mean, std=std) 161 | test_data_set = NumpyDataset(args.data_prefix + "_test", mean=mean, std=std) 162 | train_data_loader = DataLoader( 163 | train_data_set, 164 | batch_size=args.batch_size, 165 | shuffle=False, 166 | num_workers=2, 167 | drop_last=False, 168 | ) 169 | val_data_loader = DataLoader( 170 | val_data_set, 171 | batch_size=args.batch_size, 172 | shuffle=False, 173 | num_workers=2, 174 | drop_last=False, 175 | ) 176 | test_data_loader = DataLoader( 177 | test_data_set, 178 | batch_size=args.batch_size, 179 | shuffle=False, 180 | num_workers=2, 181 | drop_last=False, 182 | ) 183 | 184 | # Build model 185 | if args.model == "cnn": 186 | model = CNN(args.nclasses, args.features).cuda() 187 | else: 188 | model = Regression(args.nclasses).cuda() 189 | 190 | # Load model parameters 191 | if torch.cuda.is_available(): 192 | # saved on GPU, load on GPU 193 | model.load_state_dict(torch.load(args.model_pt_path)) 194 | else: 195 | # saved on GPU, load on CPU 196 | map_location = torch.device("cpu") 197 | model.load_state_dict(torch.load(args.model_pt_path, map_location=map_location)) 198 | print(f"Model loaded: {args.model_pt_path}") 199 | 200 | # Saliency 201 | directory = args.result_dir 202 | count = saliency( 203 | model, train_data_loader, directory, f"{args.model}_{args.features}_train" 204 | ) 205 | print(f"Processed {count} train images.") 206 | count = saliency( 207 | model, test_data_loader, directory, f"{args.model}_{args.features}_test" 208 | ) 209 | print(f"Processed {count} test images.") 210 | count = saliency( 211 | model, val_data_loader, directory, f"{args.model}_{args.features}_val" 212 | ) 213 | print(f"Processed {count} val images.") 214 | 215 | print("Finished.") 216 | 217 | 218 | def _parse_args(): 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument( 221 | "--model-pt-path", 222 | type=str, 223 | required=True, 224 | help="Path to model pt file (required).", 225 | ) 226 | parser.add_argument( 227 | "--result-dir", 228 | type=str, 229 | required=True, 230 | help="Shared result dir (required).", 231 | ) 232 | parser.add_argument( 233 | "--features", 234 | choices=["raw", "packets"], 235 | default="packets", 236 | help="the representation type", 237 | ) 238 | parser.add_argument( 239 | "--batch-size", 240 | type=int, 241 | default=512, 242 | help="input batch size (default: 512)", 243 | ) 244 | parser.add_argument( 245 | "--model", 246 | choices=["regression", "cnn"], 247 | default="regression", 248 | help="The model type: regression, cnn. 
(default: regression).",
249 |     )
250 |     parser.add_argument(
251 |         "--data-prefix",
252 |         type=str,
253 |         default="./data/source_data_packets",
254 |         help="shared prefix of the data paths (default: ./data/source_data_packets)",
255 |     )
256 |     parser.add_argument(
257 |         "--nclasses", type=int, default=2, help="number of classes (default: 2)"
258 |     )
259 |     parser.add_argument(
260 |         "--seed", type=int, default=42, help="the random seed pytorch works with."
261 |     )
262 |     group = parser.add_mutually_exclusive_group()
263 |     group.add_argument(
264 |         "--normalize",
265 |         nargs="+",
266 |         type=float,
267 |         metavar=("MEAN", "STD"),
268 |         help="normalize with specified values for mean and standard deviation (either 2 or 6 values "
269 |         "are accepted)",
270 |     )
271 |     group.add_argument(
272 |         "--calc-normalization",
273 |         action="store_true",
274 |         help="calculates mean and standard deviation used in normalization "
275 |         "from the training data",
276 |     )
277 |     return parser.parse_args()
278 |
279 |
280 | if __name__ == "__main__":
281 |     args = _parse_args()
282 |     print(args)
283 |     main(args)
284 |
-------------------------------------------------------------------------------- /src/freqdect/plot_mean_packets.py: --------------------------------------------------------------------------------
1 | """Source code to visualize mean wavelet packets and their standard deviation for visual inspection."""
2 |
3 | from itertools import product
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 |
9 | from .data_loader import NumpyDataset
10 |
11 |
12 | def _plot_mean_std(x, mean, std, color, label="", marker="."):
13 |     plt.plot(x, mean, label=label, color=color, marker=marker)
14 |     plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)
15 |
16 |
17 | def generate_packet_image_tensor(packet_array: torch.Tensor) -> torch.Tensor:
18 |     """Arrange a packet tensor as an image for imshow.
19 |
20 |     Args:
21 |         packet_array (torch.Tensor): The [batch_size, packet_no, height, width, channels] packets
22 |     Returns:
23 |         torch.Tensor: The image of shape [batch_size, height, width, channels]
24 |     """
25 |     packet_count = packet_array.shape[1]
26 |     count = 0
27 |     img_rows = None
28 |     img = []
29 |     for node_no in range(packet_count):
30 |         packet = packet_array[:, node_no]
31 |         if img_rows is not None:
32 |             img_rows = torch.cat([img_rows, packet], axis=2)
33 |         else:
34 |             img_rows = packet
35 |         count += 1
36 |         if count >= np.sqrt(packet_count):
37 |             count = 0
38 |             img.append(img_rows)
39 |             img_rows = None
40 |     return torch.cat(img, dim=1)
41 |
42 |
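# Shape sketch for generate_packet_image_tensor above (illustrative): a
# degree-3 transform of 128x128 images yields 64 packets of 16x16 pixels,
# which tile into an 8x8 grid, i.e.
#   [batch, 64, 16, 16, channels] -> [batch, 128, 128, channels].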
43 | def generate_natural_packet_image(packet_array: np.ndarray, degree: int):
44 |     """Arrange a packet array as an image for imshow.
45 |
46 |     Args:
47 |         packet_array (np.ndarray): The [packet_no, packet_height, packet_width] packets
48 |         degree (int): The degree of the transformation.
49 |
50 |     Returns:
51 |         np.ndarray: The image of shape [original_height, original_width]
52 |     """
53 |
54 |     def _cat_sector(elements: np.ndarray, level: int, max_level: int):
55 |         element_lst = np.split(elements, 4)
56 |         if level < max_level - 1:
57 |             img0 = _cat_sector(element_lst[0], level + 1, max_level)
58 |             img1 = _cat_sector(element_lst[1], level + 1, max_level)
59 |             img2 = _cat_sector(element_lst[2], level + 1, max_level)
60 |             img3 = _cat_sector(element_lst[3], level + 1, max_level)
61 |             return np.concatenate(
62 |                 [
63 |                     np.concatenate([img0, img1], axis=2),
64 |                     np.concatenate([img2, img3], axis=2),
65 |                 ],
66 |                 1,
67 |             )
68 |         else:
69 |             img = np.concatenate(
70 |                 [
71 |                     np.concatenate([element_lst[0], element_lst[1]], axis=2),
72 |                     np.concatenate([element_lst[2], element_lst[3]], axis=2),
73 |                 ],
74 |                 1,
75 |             )
76 |             return img
77 |
78 |     return _cat_sector(packet_array, 0, degree).squeeze()
79 |
80 |
81 | def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
82 |     """Create a ready-to-plot image with frequency-ordered packets.
83 |
84 |     Given a packet array in natural order, create an image which is
85 |     ready to plot in frequency order.
86 |
87 |     Args:
88 |         packet_array (np.ndarray): [packet_no, packet_height, packet_width]
89 |             in natural order.
90 |         degree (int): The degree of the packet decomposition.
91 |
92 |     Returns:
93 |         np.ndarray: The image of shape [original_height, original_width]
94 |     """
95 |     wp_freq_path, wp_natural_path = get_freq_order(degree)
96 |
97 |     image = []
98 |     # go through the rows.
99 |     for row_paths in wp_freq_path:
100 |         row = []
101 |         for row_path in row_paths:
102 |             index = wp_natural_path.index(row_path)
103 |             packet = packet_array[index]
104 |             row.append(packet)
105 |         image.append(np.concatenate(row, -1))
106 |     return np.concatenate(image, 0)
107 |
108 |
109 | def get_freq_order(level: int):
110 |     """Get the frequency order for a given packet decomposition level.
111 |
112 |     Adapted from:
113 |     https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
114 |
115 |     The code elements denote the filter application order. 
The filters 116 | are named following the pywt convention as: 117 | a - LL, low-low coefficients 118 | h - LH, low-high coefficients 119 | v - HL, high-low coefficients 120 | d - HH, high-high coefficients 121 | """ 122 | wp_natural_path = list(product(["a", "h", "v", "d"], repeat=level)) 123 | 124 | def _get_graycode_order(level, x="a", y="d"): 125 | graycode_order = [x, y] 126 | for _ in range(level - 1): 127 | graycode_order = [x + path for path in graycode_order] + [ 128 | y + path for path in graycode_order[::-1] 129 | ] 130 | return graycode_order 131 | 132 | def _expand_2d_path(path): 133 | expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"} 134 | return ( 135 | "".join([expanded_paths[p][0] for p in path]), 136 | "".join([expanded_paths[p][1] for p in path]), 137 | ) 138 | 139 | nodes: dict = {} 140 | for (row_path, col_path), node in [ 141 | (_expand_2d_path(node), node) for node in wp_natural_path 142 | ]: 143 | nodes.setdefault(row_path, {})[col_path] = node 144 | graycode_order = _get_graycode_order(level, x="l", y="h") 145 | nodes_list: list = [nodes[path] for path in graycode_order if path in nodes] 146 | wp_frequency_path = [] 147 | for row in nodes_list: 148 | wp_frequency_path.append([row[path] for path in graycode_order if path in row]) 149 | return wp_frequency_path, wp_natural_path 150 | 151 | 152 | def main(): 153 | """Compute mean wavelet packets and the standard deviation for a NumPy dataset.""" 154 | import matplotlib.pyplot as plt 155 | 156 | # raw images - use only the training set. 157 | # train_packet_set = NumpyDataset( 158 | # "/nvme/mwolter/ffhq1024x1024_log_packets_haar_reflect_train" 159 | # ) 160 | train_packet_set = NumpyDataset( 161 | "/nvme/mwolter/ffhq128_hard/source_data_log_packets_db4_boundary_3_train" 162 | ) 163 | 164 | fake_labels = [2] # [1, 2, 3, 4] 165 | 166 | fake_list = [] 167 | real_list = [] 168 | for img_no in range(train_packet_set.__len__()): 169 | train_element = train_packet_set.__getitem__(img_no) 170 | packets = train_element["image"].numpy() 171 | label = train_element["label"].numpy() 172 | if label in fake_labels: 173 | fake_list.append(packets) 174 | elif label == 0: 175 | real_list.append(packets) 176 | else: 177 | print("skipping label", label) 178 | 179 | if img_no % 500 == 0 and img_no > 0: 180 | print(img_no, "of", train_packet_set.__len__(), "loaded") 181 | # break 182 | 183 | fake_array = np.array(fake_list) 184 | del fake_list 185 | real_array = np.array(real_list) 186 | del real_list 187 | print("train set loaded.", fake_array.shape, real_array.shape) 188 | 189 | # mean image plots 190 | fake_mean_packet_image = generate_frequency_packet_image( 191 | np.mean(fake_array, axis=(0, -1)), degree=3 192 | ) 193 | real_mean_packet_image = generate_frequency_packet_image( 194 | np.mean(real_array, axis=(0, -1)), degree=3 195 | ) 196 | # std image plots 197 | fake_std_packet_image = generate_frequency_packet_image( 198 | np.std(fake_array, axis=(0, -1)), degree=3 199 | ) 200 | real_std_packet_image = generate_frequency_packet_image( 201 | np.std(real_array, axis=(0, -1)), degree=3 202 | ) 203 | 204 | fig = plt.figure(figsize=(8, 6)) 205 | columns = 3 206 | rows = 2 207 | plot_count = 1 208 | cmap = "cividis" # 'magma' #'inferno' # 'viridis 209 | 210 | mean_vmin = np.min((np.min(fake_mean_packet_image), np.min(real_mean_packet_image))) 211 | mean_vmax = np.max((np.max(fake_mean_packet_image), np.max(real_mean_packet_image))) 212 | std_vmin = np.min((np.min(fake_std_packet_image), np.min(real_std_packet_image))) 213 | 
std_vmax = np.max((np.max(fake_std_packet_image), np.max(real_std_packet_image))) 214 | 215 | def _plot_image(image, title, vmax=None, vmin=None): 216 | fig.add_subplot(rows, columns, plot_count) 217 | plt.imshow(image, cmap=cmap, vmax=vmax, vmin=vmin) 218 | plt.xticks([], []) 219 | plt.yticks([], []) 220 | plt.title(title) 221 | plt.colorbar() 222 | 223 | _plot_image(fake_mean_packet_image, "gan mean packets", mean_vmax, mean_vmin) 224 | plot_count += 1 225 | _plot_image(real_mean_packet_image, "data-set mean packets", mean_vmax, mean_vmin) 226 | plot_count += 1 227 | _plot_image( 228 | np.abs(fake_mean_packet_image - real_mean_packet_image), 229 | "absolute mean difference", 230 | ) 231 | plot_count += 1 232 | _plot_image(fake_std_packet_image, "gan std packets", std_vmax, std_vmin) 233 | plot_count += 1 234 | _plot_image(real_std_packet_image, "data-set std packets", std_vmax, std_vmin) 235 | plot_count += 1 236 | _plot_image( 237 | np.abs(fake_std_packet_image - real_std_packet_image), "absolute std difference" 238 | ) 239 | plot_count += 1 240 | 241 | plt.savefig("plot.png") 242 | if 0: 243 | import tikzplotlib 244 | 245 | tikzplotlib.save("ffhq_style_packet_mean_std_plot.tex", standalone=True) 246 | plt.show() 247 | print("first plot done") 248 | 249 | # mean packet plots 250 | style_gan_mean = np.mean(fake_array, axis=(0, 2, 3, 4)) 251 | style_gan_std = np.std(fake_array, axis=(0, 2, 3, 4)) 252 | ffhq_mean = np.mean(real_array, axis=(0, 2, 3, 4)) 253 | ffhq_std = np.std(real_array, axis=(0, 2, 3, 4)) 254 | 255 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] 256 | x = np.array(range(len(style_gan_mean))) 257 | wp_keys = list(product(["a", "h", "v", "d"], repeat=3)) 258 | wp_labels = ["".join(key) for key in wp_keys] 259 | _plot_mean_std(x, ffhq_mean, ffhq_std, colors[0], "real data") 260 | _plot_mean_std(x, style_gan_mean, style_gan_std, colors[1], "gan") 261 | plt.legend() 262 | plt.xlabel("filter") 263 | plt.xticks(x, labels=wp_labels) 264 | plt.xticks(rotation=80) 265 | plt.ylabel("mean absolute coefficient magnitude") 266 | plt.title("Mean absolute coefficient comparison real data-GAN") 267 | 268 | # plt.savefig("plot2.png") 269 | if 1: 270 | import tikzplotlib 271 | 272 | tikzplotlib.save( 273 | "absolute_coeff_comparison_stylegan" + str(fake_labels) + ".tex", 274 | standalone=True, 275 | ) 276 | plt.show() 277 | print("done") 278 | 279 | 280 | if __name__ == "__main__": 281 | main() 282 | -------------------------------------------------------------------------------- /src/freqdect/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | """Calculating confusion matrices from trained models that classify deepfake image data.""" 2 | import argparse 3 | import pickle 4 | from collections import defaultdict 5 | from typing import List 6 | 7 | import numpy as np 8 | import torch 9 | from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix 10 | from torch.utils.data import DataLoader 11 | 12 | from .data_loader import NumpyDataset 13 | from .models import CNN, Regression, initialize_model 14 | 15 | 16 | def calculate_confusion_matrix(args): 17 | """Calculate the confusion matrix. 18 | 19 | A test data set specified in the cmd line args is loaded (and normalized if specified). 20 | A model is loaded from a state dict file and used to classify the loaded test data. 21 | Then, a confusion matrix is computed from the predicted labels and the correct labels. 
22 | 23 | Args: 24 | args: Command line args, in which settings such as the test data set path, the model file path, 25 | the normalization, etc. are specified. 26 | 27 | Raises: 28 | ValueError: If mean or std arguments are missing values. 29 | 30 | Returns: 31 | a confusion matrix, comparing the predicted and the actual labels for each class 32 | 33 | # noqa: DAR401 34 | """ 35 | if args.calc_normalization: 36 | # try the cached mean and std first; recompute them from the training data on failure 37 | try: 38 | with open(f"{args.data_prefix}_train/mean_std.pkl", "rb") as file: 39 | mean, std = pickle.load(file) 40 | mean = torch.from_numpy(mean.astype(np.float32)) 41 | std = torch.from_numpy(std.astype(np.float32)) 42 | except Exception: 43 | print("loading mean and std from file failed. Re-computing.") 44 | train_data_set = NumpyDataset(args.data_prefix + "_train") 45 | 46 | img_lst = [] 47 | for img_no in range(len(train_data_set)): 48 | img_lst.append(train_data_set[img_no]["image"]) 49 | img_data = torch.stack(img_lst, 0) 50 | 51 | # average over all axes except the color channel 52 | axis = tuple(np.arange(len(img_data.shape[:-1]))) 53 | 54 | # calculate mean and std in double to avoid precision problems 55 | mean = torch.mean(img_data.double(), axis).float() 56 | std = torch.std(img_data.double(), axis).float() 57 | del img_data 58 | elif args.normalize: 59 | num_of_norm_vals = len(args.normalize) 60 | if not (num_of_norm_vals == 2 or num_of_norm_vals == 6): 61 | raise ValueError("incorrect mean and standard deviation input values.") 62 | mean = torch.tensor(args.normalize[: num_of_norm_vals // 2]) 63 | std = torch.tensor(args.normalize[(num_of_norm_vals // 2) :]) 64 | else: 65 | mean, std = [None, None] 66 | 67 | print("Mean: {}, std: {}".format(mean, std)) 68 | test_data_set = NumpyDataset(args.data_prefix + "_test", mean=mean, std=std) 69 | test_data_loader = DataLoader( 70 | test_data_set, batch_size=args.batch_size, shuffle=False, num_workers=2 71 | ) 72 | 73 | if args.model == "regression": 74 | model = Regression(args.nclasses).cuda() 75 | else: 76 | model = CNN(args.nclasses, args.features).cuda() 77 | 78 | initialize_model(model, args.classifier_path) 79 | model.eval() 80 | 81 | correct_labels = [] 82 | predicted_labels = [] 83 | 84 | with torch.no_grad(): 85 | for test_batch in iter(test_data_loader): 86 | batch_images = test_batch["image"].cuda(non_blocking=True) 87 | batch_labels = test_batch["label"] 88 | 89 | out = model(batch_images) 90 | out_labels = torch.max(out, dim=-1)[1] 91 | 92 | if args.nclasses == 2: 93 | batch_labels[batch_labels > 0] = 1 94 | 95 | correct_labels.extend(batch_labels.cpu()) 96 | predicted_labels.extend(out_labels.cpu()) 97 | 98 | return confusion_matrix(correct_labels, predicted_labels) 99 | 100 | 101 | def calculate_generalized_confusion_matrix(args): 102 | """Calculate a generalized confusion matrix for binary classification of fake/real images. 103 | 104 | A test data set specified in the cmd line args is loaded (and normalized if specified). 105 | A model is loaded from a state dict file and used to classify the loaded test data into 106 | the classes 'fake' and 'real'. 107 | Then, a generalized confusion matrix is computed from the predicted labels and the correct labels. 108 | The confusion matrix is 'generalized' in the sense that the actual labels for the 'fake' class are split into 109 | subgroups according to the GAN that was used to generate the fake images. 
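    For the default label names and nclasses=2 the result is a 5 x 2 array with one row per source [Original, CramerGAN, MMDGAN, ProGAN, SNGAN] and one column per predicted class ('real', 'fake'), e.g. (illustrative counts only): [[4900, 100], [120, 4880], [90, 4910], [200, 4800], [350, 4650]]. 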
110 | 111 | Args: 112 | args: Command line args, in which settings such as the test data set path, the model file path, 113 | the normalization, etc. are specified. 114 | 115 | Raises: 116 | ValueError: If mean or std arguments are missing values. 117 | 118 | Returns: 119 | a 'generalized' confusion matrix, containing for each image source \ 120 | (i.e. real and different GANs) the number of images that 121 | were classified as 'real' or 'fake'. 122 | 123 | # noqa: DAR401 124 | """ 125 | if args.calc_normalization: 126 | # try the cached mean and std first; recompute them from the training data on failure 127 | try: 128 | with open(f"{args.data_prefix}_train/mean_std.pkl", "rb") as file: 129 | mean, std = pickle.load(file) 130 | mean = torch.from_numpy(mean.astype(np.float32)) 131 | std = torch.from_numpy(std.astype(np.float32)) 132 | except Exception: 133 | print("loading mean and std from file failed. Re-computing.") 134 | train_data_set = NumpyDataset(args.data_prefix + "_train") 135 | 136 | img_lst = [] 137 | for img_no in range(len(train_data_set)): 138 | img_lst.append(train_data_set[img_no]["image"]) 139 | img_data = torch.stack(img_lst, 0) 140 | 141 | # average over all axes except the color channel 142 | axis = tuple(np.arange(len(img_data.shape[:-1]))) 143 | 144 | # calculate mean and std in double to avoid precision problems 145 | mean = torch.mean(img_data.double(), axis).float() 146 | std = torch.std(img_data.double(), axis).float() 147 | del img_data 148 | elif args.normalize: 149 | num_of_norm_vals = len(args.normalize) 150 | if not (num_of_norm_vals == 2 or num_of_norm_vals == 6): 151 | raise ValueError("incorrect mean and standard deviation arguments.") 152 | mean = torch.tensor(args.normalize[: num_of_norm_vals // 2]) 153 | std = torch.tensor(args.normalize[(num_of_norm_vals // 2) :]) 154 | else: 155 | mean, std = [None, None] 156 | 157 | test_data_set = NumpyDataset(args.data_prefix + "_test", mean=mean, std=std) 158 | test_data_loader = DataLoader( 159 | test_data_set, batch_size=args.batch_size, shuffle=False, num_workers=2 160 | ) 161 | 162 | if args.model == "regression": 163 | model = Regression(args.nclasses).cuda() 164 | else: 165 | model = CNN(args.nclasses, args.features).cuda() 166 | 167 | initialize_model(model, args.classifier_path) 168 | model.eval() 169 | 170 | predicted_dict = defaultdict(list) 171 | 172 | label_names = np.array(["Original", "CramerGAN", "MMDGAN", "ProGAN", "SNGAN"]) 173 | 174 | with torch.no_grad(): 175 | for test_batch in iter(test_data_loader): 176 | batch_images = test_batch["image"].cuda(non_blocking=True) 177 | batch_labels = test_batch["label"].cpu() 178 | 179 | batch_names = label_names[batch_labels] 180 | 181 | out = model(batch_images) 182 | out_labels = torch.max(out, dim=-1)[1].cpu() 183 | 184 | for k, v in zip(batch_names, out_labels): 185 | predicted_dict[k].append(v) 186 | 187 | matrix = np.zeros((len(label_names), args.nclasses), dtype=int) 188 | 189 | for label_idx, label in enumerate(label_names): 190 | predicted_labels = np.array(predicted_dict[label]) 191 | 192 | for class_idx in range(args.nclasses): 193 | matrix[label_idx, class_idx] = len( 194 | predicted_labels[predicted_labels == class_idx] 195 | ) 196 | 197 | return matrix 198 | 199 | 200 | def output_confusion_matrix_stats(matrix, label_names: List[str], plot: bool = False): 201 | """Output stats about the confusion matrix. 202 | 203 | Args: 204 | matrix: The confusion matrix from which the stats are calculated. 205 | label_names (List[str]): String representations of the labels. 
206 | plot (bool): If this flag is set, the confusion matrix is plotted. 207 | The plot is shown and stored in the current working directory. 208 | """ 209 | print("accuracy: ", np.trace(matrix) / matrix.sum()) 210 | 211 | diag = np.diag(matrix) 212 | 213 | worst_index = np.argmin(diag) 214 | best_index = np.argmax(diag) 215 | print( 216 | f"worst index: {worst_index} ({label_names[worst_index]}) " 217 | f"with an accuracy of {diag[worst_index] / matrix[worst_index].sum() * 100:.2f}%" 218 | ) 219 | print( 220 | f"best index: {best_index} ({label_names[best_index]}) " 221 | f"with an accuracy of {diag[best_index] / matrix[best_index].sum() * 100:.2f}%" 222 | ) 223 | 224 | if plot: 225 | import matplotlib.pyplot as plt 226 | 227 | disp = ConfusionMatrixDisplay( 228 | confusion_matrix=matrix, display_labels=label_names 229 | ) 230 | disp.plot() 231 | plt.savefig("confusion_matrix.png") 232 | plt.show() 233 | 234 | 235 | def output_generalized_stats(matrix): 236 | """Compute and print overall, known-source and unknown-source accuracy percentages.""" 237 | accuracy = (matrix[0, 0] + matrix[1:, 1].sum()) / matrix.sum() 238 | known_acc = (matrix[0, 0] + matrix[1:-1, 1].sum()) / matrix[:-1, :].sum() 239 | unknown_acc = matrix[-1, 1] / matrix[-1, :].sum() 240 | 241 | print(f"{accuracy * 100:.2f}% {known_acc * 100:.2f}% {unknown_acc * 100:.2f}% (corrected)") 242 | 243 | 244 | def _parse_args(): 245 | parser = argparse.ArgumentParser(description="Calculate the confusion matrix") 246 | parser.add_argument( 247 | "--classifier-path", type=str, help="path to classifier model file" 248 | ) 249 | parser.add_argument( 250 | "--data-prefix", 251 | type=str, 252 | help="shared prefix of the path of folders containing the train/test data", 253 | ) 254 | parser.add_argument( 255 | "--model", 256 | choices=["regression", "cnn"], 257 | help="The model type. Choose regression or cnn.", 258 | ) 259 | parser.add_argument( 260 | "--features", 261 | choices=["raw", "packets"], 262 | default="packets", 263 | help="the representation type", 264 | ) 265 | parser.add_argument( 266 | "--batch-size", 267 | type=int, 268 | default=512, 269 | help="input batch size for testing (default: 512)", 270 | ) 271 | parser.add_argument( 272 | "--label-names", 273 | nargs="+", 274 | type=str, 275 | default=["Original", "CramerGAN", "MMDGAN", "ProGAN", "SNGAN"], 276 | help="string representation of the class labels. Only used when '--generalized' is not selected.", 277 | ) 278 | parser.add_argument( 279 | "--plot", 280 | action="store_true", 281 | help="plot the confusion matrix and store the plot as png. 
Only has an effect when \ 282 | '--generalized' is not selected.", 283 | ) 284 | parser.add_argument( 285 | "--nclasses", type=int, default=2, help="number of classes (default: 2)" 286 | ) 287 | parser.add_argument( 288 | "--generalized", 289 | action="store_true", 290 | help="Calculates a generalized confusion matrix for the binary classification \ 291 | task differentiating fake from real images.", 292 | ) 293 | parser.add_argument("--store-path", type=str, default=None) 294 | # one should not specify normalization parameters and request their calculation at the same time 295 | group = parser.add_mutually_exclusive_group() 296 | group.add_argument( 297 | "--normalize", 298 | nargs="+", 299 | type=float, 300 | metavar=("MEAN", "STD"), 301 | help="normalize with specified values for mean and standard deviation (either 2 or 6 values " 302 | "are accepted)", 303 | ) 304 | group.add_argument( 305 | "--calc-normalization", 306 | action="store_true", 307 | help="calculates mean and standard deviation used in normalization " 308 | "from the training data", 309 | ) 310 | return parser.parse_args() 311 | 312 | 313 | def _main(): 314 | args = _parse_args() 315 | print(args) 316 | 317 | if args.generalized: 318 | matrix = calculate_generalized_confusion_matrix(args) 319 | print(matrix) 320 | output_generalized_stats(matrix)  # prints its results, returns None 321 | 322 | if args.store_path is not None: 323 | np.save(open(args.store_path, "wb"), matrix) 324 | 325 | else: 326 | matrix = calculate_confusion_matrix(args) 327 | print(matrix) 328 | 329 | output_confusion_matrix_stats(matrix, args.label_names, args.plot) 330 | 331 | if args.store_path is not None: 332 | np.save(open(args.store_path, "wb"), matrix) 333 | 334 | 335 | if __name__ == "__main__": 336 | _main() 337 | -------------------------------------------------------------------------------- /src/freqdect/train_classifier.py: -------------------------------------------------------------------------------- 1 | """Source code to train deepfake detectors in wavelet and pixel space.""" 2 | 3 | import argparse 4 | import os 5 | import pickle 6 | from typing import Any, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard.writer import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from .data_loader import CombinedDataset, NumpyDataset 15 | from .models import CNN, Regression, compute_parameter_total, save_model 16 | 17 | 18 | def val_test_loop( 19 | data_loader, 20 | model: torch.nn.Module, 21 | loss_fun, 22 | make_binary_labels: bool = False, 23 | _description: str = "Validation", 24 | pbar: bool = False, 25 | ) -> Tuple[float, Any]: 26 | """Test the performance of a model on a data set by calculating the prediction accuracy and loss of the model. 27 | 28 | Args: 29 | data_loader (DataLoader): A DataLoader loading the data set on which the performance should be measured, 30 | e.g. a test or validation set in a data split. 31 | model (torch.nn.Module): The model to evaluate. 32 | loss_fun: The loss function, which is used to measure the loss of the model on the data set 33 | make_binary_labels (bool): If flag is set, we only classify binarily, i.e. whether an image is real or fake. 34 | In this case, the label 0 encodes 'real'. All other labels are considered fake data, and are set to 1. 35 | 36 | Returns: 37 | Tuple[float, Any]: The accuracy over the full data set and the loss of the final batch. 
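    Note: with make_binary_labels set, the labels are collapsed in place before the loss is computed, e.g. a label tensor [0, 2, 3, 0] becomes [0, 1, 1, 0]. 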
38 | """ 39 | with torch.no_grad(): 40 | model.eval() 41 | val_total = 0 42 | 43 | val_ok = 0.0 44 | for val_batch in iter(data_loader): 45 | if type(data_loader.dataset) is CombinedDataset: 46 | batch_images = { 47 | key: val_batch[key].cuda(non_blocking=True) 48 | for key in data_loader.dataset.key 49 | } 50 | else: 51 | batch_images = val_batch[data_loader.dataset.key].cuda( 52 | non_blocking=True 53 | ) 54 | batch_labels = val_batch["label"].cuda(non_blocking=True) 55 | out = model(batch_images) 56 | if make_binary_labels: 57 | batch_labels[batch_labels > 0] = 1 58 | val_loss = loss_fun(torch.squeeze(out), batch_labels)  # loss of the current batch 59 | ok_mask = torch.eq(torch.max(out, dim=-1)[1], batch_labels) 60 | val_ok += torch.sum(ok_mask).item() 61 | val_total += batch_labels.shape[0] 62 | val_acc = val_ok / val_total 63 | print("acc", val_acc, "ok", val_ok, "total", val_total) 64 | return val_acc, val_loss 65 | 66 | 67 | def _parse_args(): 68 | """Parse cmd line args for training an image classifier.""" 69 | parser = argparse.ArgumentParser(description="Train an image classifier") 70 | parser.add_argument( 71 | "--features", 72 | choices=["raw", "packets", "all-packets", "fourier", "all-packets-fourier"], 73 | default="packets", 74 | help="the representation type", 75 | ) 76 | parser.add_argument( 77 | "--batch-size", 78 | type=int, 79 | default=512, 80 | help="input batch size for training and testing (default: 512)", 81 | ) 82 | parser.add_argument( 83 | "--learning-rate", 84 | type=float, 85 | default=1e-3, 86 | help="learning rate for optimizer (default: 1e-3)", 87 | ) 88 | parser.add_argument( 89 | "--weight-decay", 90 | type=float, 91 | default=0, 92 | help="weight decay for optimizer (default: 0)", 93 | ) 94 | parser.add_argument( 95 | "--epochs", type=int, default=20, help="number of epochs (default: 20)" 96 | ) 97 | parser.add_argument( 98 | "--validation-interval", 99 | type=int, 100 | default=200, 101 | help="number of training steps after which the model is tested on the validation data set (default: 200)", 102 | ) 103 | parser.add_argument( 104 | "--data-prefix", 105 | type=str, 106 | nargs="+", 107 | default=["./data/source_data_packets"], 108 | help="shared prefix of the data paths (default: ./data/source_data_packets)", 109 | ) 110 | parser.add_argument( 111 | "--nclasses", type=int, default=2, help="number of classes (default: 2)" 112 | ) 113 | parser.add_argument( 114 | "--seed", type=int, default=42, help="the random seed pytorch works with." 115 | ) 116 | 117 | parser.add_argument( 118 | "--model", 119 | choices=["regression", "cnn"], 120 | default="regression", 121 | help="The model type: choose regression or cnn (default: regression).", 122 | ) 123 | 124 | parser.add_argument( 125 | "--tensorboard", 126 | action="store_true", 127 | help="enables a tensorboard visualization.", 128 | ) 129 | 130 | parser.add_argument( 131 | "--pbar", 132 | action="store_true", 133 | help="enables progress bars", 134 | ) 135 | 136 | parser.add_argument( 137 | "--num-workers", 138 | type=int, 139 | default=2, 140 | help="Number of worker processes started by the test and validation data loaders. The training data_loader " 141 | "uses three times this argument many workers. Hence, this argument should probably be chosen below 10. " 142 | "Defaults to 2.", 143 | ) 144 | 145 | parser.add_argument( 146 | "--class-weights", 147 | type=float, 148 | metavar="CLASS_WEIGHT", 149 | nargs="+", 150 | default=None, 151 | help="If specified, training samples are weighted based on their class " 152 | "in the loss calculation. 
Expects one weight per class.", 153 | ) 154 | 155 | # one should not specify normalization parameters and request their calculation at the same time 156 | group = parser.add_mutually_exclusive_group() 157 | group.add_argument( 158 | "--calc-normalization", 159 | action="store_true", 160 | help="calculates mean and standard deviation used in normalization " 161 | "from the training data", 162 | ) 163 | return parser.parse_args() 164 | 165 | 166 | def create_data_loaders(data_prefix: list, batch_size: int) -> tuple: 167 | """Create the data loaders needed for training. 168 | 169 | The test data is returned as a data set rather than wrapped in a loader. 170 | 171 | Args: 172 | data_prefix (list): Shared prefixes of the paths where to look for the data. 173 | batch_size (int): Batch size of the train and validation loaders. 174 | Raises: 175 | RuntimeError: Raised if no data is found for the given prefixes. 176 | 177 | Returns: 178 | tuple: train_data_loader, val_data_loader, test_data_set 179 | 180 | # noqa: DAR401 181 | """ 182 | data_set_list = [] 183 | for data_prefix_el in data_prefix: 184 | with open(f"{data_prefix_el}_train/mean_std.pkl", "rb") as file: 185 | mean, std = pickle.load(file) 186 | mean = torch.from_numpy(mean.astype(np.float32)) 187 | std = torch.from_numpy(std.astype(np.float32)) 188 | 189 | print("mean", mean, "std", std) 190 | key = "image" 191 | if "raw" in data_prefix_el.split("_"): 192 | key = "raw" 193 | elif "packets" in data_prefix_el.split("_"): 194 | key = "packets" + data_prefix_el.split("_")[-1] 195 | elif "fourier" in data_prefix_el.split("_"): 196 | key = "fourier" 197 | 198 | train_data_set = NumpyDataset( 199 | data_prefix_el + "_train", mean=mean, std=std, key=key 200 | ) 201 | val_data_set = NumpyDataset( 202 | data_prefix_el + "_val", mean=mean, std=std, key=key 203 | ) 204 | test_data_set = NumpyDataset( 205 | data_prefix_el + "_test", mean=mean, std=std, key=key 206 | ) 207 | data_set_list.append((train_data_set, val_data_set, test_data_set)) 208 | 209 | if len(data_set_list) == 1: 210 | train_data_loader = DataLoader( 211 | train_data_set, batch_size=batch_size, shuffle=True, num_workers=3 212 | ) 213 | val_data_loader = DataLoader( 214 | val_data_set, batch_size=batch_size, shuffle=False, num_workers=3 215 | ) 216 | test_data_sets: Any = test_data_set 217 | elif len(data_set_list) > 1: 218 | train_data_sets = [el[0] for el in data_set_list] 219 | val_data_sets = [el[1] for el in data_set_list] 220 | test_data_sets = [el[2] for el in data_set_list] 221 | train_data_loader = DataLoader( 222 | CombinedDataset(train_data_sets), 223 | batch_size=batch_size, 224 | shuffle=True, 225 | num_workers=3, 226 | ) 227 | val_data_loader = DataLoader( 228 | CombinedDataset(val_data_sets), 229 | batch_size=batch_size, 230 | shuffle=False, 231 | num_workers=3, 232 | ) 233 | else: 234 | raise RuntimeError("Failed to load data from the specified prefixes.") 235 | 236 | return train_data_loader, val_data_loader, test_data_sets 237 | 238 | 239 | def main(): 240 | """Trains a model to classify images. 241 | 242 | All settings such as which model to use, parameters, normalization, data set path, 243 | seed etc. are specified via cmd line args. 244 | All training, validation and testing results are printed to stdout. 245 | After the training is done, the results are stored in a pickle dump in the 'log' folder. 246 | The state_dict of the trained model is stored there as well. 247 | 248 | Raises: 249 | ValueError: Raised if mean and std values are incomplete or if the number of 250 | specified class weights does not match the number of classes. 
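    A hypothetical invocation (the data path matches the parser default, the remaining values are illustrative) could be: python -m freqdect.train_classifier --features packets --model cnn --data-prefix ./data/source_data_packets --nclasses 2 --epochs 20 --batch-size 512 --seed 0 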
251 | 252 | # noqa: DAR401 253 | """ 254 | args = _parse_args() 255 | print(args) 256 | 257 | if args.class_weights and len(args.class_weights) != args.nclasses: 258 | raise ValueError( 259 | f"The number of class_weights ({len(args.class_weights)}) must equal " 260 | f"the number of classes ({args.nclasses})" 261 | ) 262 | 263 | # fix the seed in the interest of reproducible results. 264 | torch.manual_seed(args.seed) 265 | 266 | make_binary_labels = args.nclasses == 2 267 | train_data_loader, val_data_loader, test_data_set = create_data_loaders( 268 | args.data_prefix, args.batch_size 269 | ) 270 | 271 | validation_list = [] 272 | loss_list = [] 273 | accuracy_list = [] 274 | step_total = 0 275 | 276 | if args.model == "cnn": 277 | model = CNN(args.nclasses, args.features).cuda() 278 | else: 279 | model = Regression(args.nclasses).cuda() 280 | 281 | print("model parameter count:", compute_parameter_total(model)) 282 | 283 | if args.tensorboard: 284 | writer_str = "runs/" 285 | writer_str += "params_test2/" 286 | writer_str += f"{args.model}/" 287 | writer_str += f"{args.batch_size}/" 288 | writer_str += str(args.data_prefix[0].split("/")[-1]) + "/" 289 | writer_str += f"{args.learning_rate}_" 290 | writer_str += f"{args.seed}" 291 | writer = SummaryWriter(writer_str, max_queue=100) 292 | 293 | if args.class_weights: 294 | loss_fun = torch.nn.NLLLoss(weight=torch.tensor(args.class_weights).cuda()) 295 | else: 296 | loss_fun = torch.nn.NLLLoss() 297 | optimizer = torch.optim.Adam( 298 | model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay 299 | ) 300 | 301 | for e in tqdm( 302 | range(args.epochs), desc="Epochs", unit="epochs", disable=not args.pbar 303 | ): 304 | # iterate over training data. 305 | for it, batch in enumerate( 306 | tqdm( 307 | iter(train_data_loader), 308 | desc="Training", 309 | unit="batches", 310 | disable=not args.pbar, 311 | ) 312 | ): 313 | model.train() 314 | optimizer.zero_grad() 315 | # move the batch to the GPU; CombinedDataset batches hold one tensor per key. 316 | if type(train_data_loader.dataset) is CombinedDataset: 317 | batch_images = { 318 | key: batch[key].cuda(non_blocking=True) 319 | for key in train_data_loader.dataset.key 320 | } 321 | else: 322 | batch_images = batch[train_data_loader.dataset.key].cuda( 323 | non_blocking=True 324 | ) 325 | 326 | batch_labels = batch["label"].cuda(non_blocking=True) 327 | if make_binary_labels: 328 | batch_labels[batch_labels > 0] = 1 329 | 330 | out = model(batch_images) 331 | loss = loss_fun(torch.squeeze(out), batch_labels) 332 | ok_mask = torch.eq(torch.max(out, dim=-1)[1], batch_labels) 333 | acc = torch.sum(ok_mask.type(torch.float32)) / len(batch_labels) 334 | 335 | if it % 10 == 0: 336 | print( 337 | "e", 338 | e, 339 | "it", 340 | it, 341 | "total", 342 | step_total, 343 | "loss", 344 | loss.item(), 345 | "acc", 346 | acc.item(), 347 | ) 348 | loss.backward() 349 | optimizer.step() 350 | step_total += 1 351 | loss_list.append([step_total, e, loss.item()]) 352 | accuracy_list.append([step_total, e, acc.item()]) 353 | 354 | if args.tensorboard: 355 | writer.add_scalar("loss/train", loss.item(), step_total) 356 | writer.add_scalar("accuracy/train", acc.item(), step_total) 357 | if step_total == 1:  # log the graph once, at the very first step 358 | writer.add_graph(model, batch_images) 359 | 360 | # iterate over val batches. 
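            # Every `validation_interval` training steps, score the model on the validation split; a perfect accuracy of 1.0 ends the current epoch early (see the break below). 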
361 | if step_total % args.validation_interval == 0: 362 | val_acc, val_loss = val_test_loop( 363 | val_data_loader, 364 | model, 365 | loss_fun, 366 | make_binary_labels=make_binary_labels, 367 | pbar=args.pbar, 368 | ) 369 | validation_list.append([step_total, e, val_acc]) 370 | if validation_list[-1][-1] == 1.0:  # entries are [step, epoch, acc] 371 | print("val acc ideal, stopping the current epoch early.") 372 | break 373 | 374 | if args.tensorboard: 375 | writer.add_scalar("loss/validation", val_loss, step_total) 376 | writer.add_scalar("accuracy/validation", val_acc, step_total) 377 | 378 | if args.tensorboard: 379 | writer.add_scalar("epochs", e, step_total) 380 | 381 | print(validation_list) 382 | 383 | if not os.path.exists("./log/"): 384 | os.makedirs("./log/") 385 | model_file = ( 386 | "./log/" 387 | + args.data_prefix[0].split("/")[-1] 388 | + "_" 389 | + str(args.learning_rate) 390 | + "_" 391 | + f"{args.epochs}e" 392 | + "_" 393 | + str(args.model) 394 | ) 395 | save_model(model, model_file + "_" + str(args.seed) + ".pt") 396 | print(model_file, "saved.") 397 | 398 | # Run over the test set. 399 | print("Training done, testing....") 400 | if type(test_data_set) is list: 401 | test_data_set = CombinedDataset(test_data_set) 402 | 403 | test_data_loader = DataLoader( 404 | test_data_set, 405 | args.batch_size, 406 | shuffle=False, 407 | num_workers=args.num_workers, 408 | ) 409 | with torch.no_grad(): 410 | test_acc, test_loss = val_test_loop( 411 | test_data_loader, 412 | model, 413 | loss_fun, 414 | make_binary_labels=make_binary_labels, 415 | pbar=args.pbar, 416 | _description="Testing", 417 | ) 418 | print("test acc", test_acc) 419 | 420 | if args.tensorboard: 421 | writer.add_scalar("accuracy/test", test_acc, step_total) 422 | writer.add_scalar("loss/test", test_loss, step_total) 423 | 424 | _save_stats( 425 | model_file, 426 | loss_list, 427 | accuracy_list, 428 | validation_list, 429 | test_acc, 430 | args, 431 | len(iter(train_data_loader)), 432 | ) 433 | 434 | if args.tensorboard: 435 | writer.close() 436 | 437 | 438 | def _save_stats( 439 | model_file: str, 440 | loss_list: list, 441 | accuracy_list: list, 442 | validation_list: list, 443 | test_acc: float, 444 | args, 445 | iterations_per_epoch: int, 446 | ): 447 | stats_file = model_file + "_" + str(args.seed) + ".pkl" 448 | try: 449 | res = pickle.load(open(stats_file, "rb")) 450 | except OSError as e: 451 | res = [] 452 | print( 453 | e, 454 | "stats.pickle does not exist, " 455 | "creating a new file.", 456 | ) 457 | res.append( 458 | { 459 | "train_loss": loss_list, 460 | "train_acc": accuracy_list, 461 | "val_acc": validation_list, 462 | "test_acc": test_acc, 463 | "args": args, 464 | "iterations_per_epoch": iterations_per_epoch, 465 | } 466 | ) 467 | pickle.dump(res, open(stats_file, "wb")) 468 | print(stats_file, "saved.") 469 | 470 | 471 | if __name__ == "__main__": 472 | main() 473 | -------------------------------------------------------------------------------- /src/freqdect/plot_accuracy_results.py: -------------------------------------------------------------------------------- 1 | """Code to plot training mean accuracy as well as the standard deviation.""" 2 | import argparse 3 | import pickle 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from matplotlib.axes import Axes 8 | 9 | 10 | def stack_list(dict_list, key: str): 11 | """Extract time series data from a logfile-list. 12 | 13 | Args: 14 | dict_list (list): A list as stored by train_classifier.py 15 | key (str): The key for a logfile entry. 
16 | 17 | Returns: 18 | tuple: A tuple of step and accuracy numpy arrays. 19 | """ 20 | step_lst = [] 21 | acc_lst = [] 22 | for current_dictionary in dict_list: 23 | if len(current_dictionary[key][0]) == 2: 24 | steps, acc = zip(*current_dictionary[key]) 25 | elif len(current_dictionary[key][0]) == 3: 26 | steps, epochs, acc = zip(*current_dictionary[key]) 27 | step_lst.append(steps) 28 | acc_lst.append(acc) 29 | return np.stack(step_lst), np.stack(acc_lst) 30 | 31 | 32 | def _get_steps_mean_std(step_lst, cost_lst): 33 | mean = np.mean(cost_lst, axis=0) 34 | std = np.std(cost_lst, axis=0) 35 | return step_lst[0, :], mean, std 36 | 37 | 38 | def get_plot_tuple(dict_list, key: str): 39 | """Extract time series data from a logfile-list. 40 | 41 | Args: 42 | dict_list (list): A list as stored by train_classifier.py 43 | key (str): The key for a logfile entry. 44 | 45 | Returns: 46 | tuple: A tuple of steps, mean accuracy and standard deviation arrays. 47 | """ 48 | steps, loss = stack_list(dict_list, key) 49 | steps, mean, std = _get_steps_mean_std(steps, loss) 50 | return steps, mean, std 51 | 52 | 53 | def _plot_mean_std(axs, steps, mean, std, color, label="", marker="."): 54 | axs.plot(steps, mean, label=label, color=color, marker=marker) 55 | axs.fill_between(steps, mean - std, mean + std, color=color, alpha=0.2) 56 | 57 | 58 | def get_test_acc_mean_std_max(dict_list: list, key: str): 59 | """Compute the mean test accuracy and standard deviation over multiple runs. 60 | 61 | Args: 62 | dict_list (list): A list of dicts as stored by train_classifier.py 63 | key (str): The dictionary key we are interested in. 64 | 65 | Returns: 66 | tuple: The mean, standard deviation and max in that order. 67 | """ 68 | test_accs = [] 69 | for experiment_dict in dict_list: 70 | test_accs.append(experiment_dict[key]) 71 | return np.mean(test_accs), np.std(test_accs), np.max(test_accs) 72 | 73 | 74 | def _plot_on_ax( # noqa: C901 75 | dataset: str, 76 | param_str: str, 77 | axs: Axes, 78 | logpacket_logs, 79 | packet_logs, 80 | raw_logs, 81 | epochs: int = None, 82 | batch_size: int = None, 83 | ylabel: str = None, 84 | ylim: float = None, 85 | title: str = None, 86 | place_legend: bool = False, 87 | ): 88 | # convert logs to ndarrays to allow better indexing 89 | if raw_logs: 90 | raw_logs = np.array(raw_logs) 91 | if packet_logs: 92 | packet_logs = np.array(packet_logs) 93 | if logpacket_logs: 94 | logpacket_logs = np.array(logpacket_logs) 95 | 96 | log_names = ["raw", "packets", "log_packets"] 97 | 98 | # filter out all log entries that do not match the specified epoch number 99 | if epochs is not None: 100 | if raw_logs is not None: 101 | indices_raw = [vars(run["args"])["epochs"] == epochs for run in raw_logs] 102 | raw_logs = raw_logs[indices_raw] 103 | if packet_logs is not None: 104 | indices_packets = [ 105 | vars(run["args"])["epochs"] == epochs for run in packet_logs 106 | ] 107 | packet_logs = packet_logs[indices_packets] 108 | if logpacket_logs is not None: 109 | indices_logpackets = [ 110 | vars(run["args"])["epochs"] == epochs for run in logpacket_logs 111 | ] 112 | logpacket_logs = logpacket_logs[indices_logpackets] 113 | 114 | for logs, logs_name in zip([raw_logs, packet_logs, logpacket_logs], log_names): 115 | if logs is not None and logs.size == 0: 116 | print(f"No runs found for {epochs} epochs for {logs_name}") 117 | 118 | # filter out all log entries that do not match the specified batch_size number 119 | if batch_size is not None: 120 | if raw_logs is not None: 121 | indices_raw = [ 122 
| vars(run["args"])["batch_size"] == batch_size for run in raw_logs 123 | ] 124 | raw_logs = raw_logs[indices_raw] 125 | if packet_logs is not None: 126 | indices_packets = [ 127 | vars(run["args"])["batch_size"] == batch_size for run in packet_logs 128 | ] 129 | packet_logs = packet_logs[indices_packets] 130 | if logpacket_logs is not None: 131 | indices_logpackets = [ 132 | vars(run["args"])["batch_size"] == batch_size for run in logpacket_logs 133 | ] 134 | logpacket_logs = logpacket_logs[indices_logpackets] 135 | 136 | for logs, logs_name in zip([raw_logs, packet_logs, logpacket_logs], log_names): 137 | if logs is not None and logs.size == 0: 138 | print(f"No runs found for batch size {batch_size} for {logs_name}") 139 | 140 | print(f"{dataset} {param_str}:") 141 | 142 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] 143 | 144 | def print_results(name, logs, logs_mean, logs_std, logs_max): 145 | """Print the max, mean and std of the accuracy of the runs on one feature.""" 146 | print(f"{name} ({len(logs)} runs):") 147 | print( 148 | f"\t{name} seeds:", 149 | ", ".join([str(vars(run["args"])["seed"]) for run in logs]), 150 | ) 151 | print( 152 | f"\t\tmax: {logs_max * 100:.2f}%\n\t\tmean: {logs_mean * 100:.2f}%\n\t\tstd: {logs_std * 100:.2f}" 153 | ) 154 | 155 | def _process_logs(name, logs, idx): 156 | steps, mean, std = get_plot_tuple(logs, "val_acc") 157 | _plot_mean_std( 158 | axs, steps, mean, std, color=colors[idx], label=f"{name} validation acc" 159 | ) 160 | t_mean, t_std, t_max = get_test_acc_mean_std_max(logs, "test_acc") 161 | print_results(name, logs, t_mean, t_std, t_max) 162 | 163 | axs.errorbar( 164 | logs[0]["train_acc"][-1][0], 165 | t_mean, 166 | t_std, 167 | color=colors[3 + idx], 168 | label=f"{name} test acc", 169 | marker="_", 170 | ) 171 | 172 | for idx, (logs, logs_name) in enumerate( 173 | zip([raw_logs, packet_logs, logpacket_logs], log_names) 174 | ): 175 | if logs is not None: 176 | _process_logs(logs_name, logs, idx) 177 | 178 | axs.set_xlabel("training steps") 179 | if ylabel is not None: 180 | axs.set_ylabel(ylabel) 181 | 182 | if ylim is not None: 183 | axs.set_ylim(top=ylim) 184 | 185 | if title is not None: 186 | axs.set_title(title) 187 | else: 188 | axs.set_title(f"{dataset}-GAN") 189 | 190 | if place_legend: 191 | axs.legend() 192 | 193 | 194 | def export_plots(args, output_prefix: str): 195 | """Export the plot as png or tikz plot. 196 | 197 | Shows the plot, if not specified otherwise in the cmd line args. 198 | 199 | Args: 200 | args: The cmd line args settings. 201 | output_prefix (str): A prefix, with which the file names of the exported plots start. 202 | """ 203 | output_str = "" 204 | if args.wavelet: 205 | output_str += f"_{args.wavelet}" 206 | if args.mode: 207 | output_str += f"_{args.mode}" 208 | if args.learning_rate: 209 | output_str += f"_{args.learning_rate}" 210 | if args.epochs: 211 | output_str += f"_{args.epochs}e" 212 | output_str += f"_{args.model}" 213 | 214 | if args.png: 215 | print(f"saving {output_prefix}{output_str}_accuracy.png") 216 | plt.savefig(f"{output_prefix}{output_str}_accuracy.png") 217 | if args.tikz: 218 | import tikzplotlib 219 | 220 | print(f"saving {output_prefix}{output_str}_accuracy.tex") 221 | tikzplotlib.save(f"{output_prefix}{output_str}_accuracy.tex", standalone=True) 222 | if not args.hide: 223 | plt.show() 224 | 225 | 226 | def skip_every_second_val_acc(logs): 227 | """Halve the validation accuracy resolution by skipping every second validation accuracy entry. 
228 | 229 | If the interval between the validation accuracy measurements is too small, the resulting plot is too loaded. 230 | In this case, this function is useful. 231 | 232 | Args: 233 | logs: The log of the runs, from which every second validation accuracy measurement is skipped. 234 | """ 235 | for run in logs: 236 | run["val_acc"] = run["val_acc"][1::2] 237 | 238 | 239 | def plot_shared(args): 240 | """Plot the validation and test accuracy. 241 | 242 | Both LSUN and CelebA are shown side by side for better comparison. 243 | """ 244 | logpacket_logs_lsun = pickle.load( 245 | open(f"{args.prefix_lsun}_logpackets_{args.model}.pkl", "rb") 246 | ) 247 | packet_logs_lsun = pickle.load( 248 | open(f"{args.prefix_lsun}_packets_{args.model}.pkl", "rb") 249 | ) 250 | raw_logs_lsun = pickle.load(open(f"{args.prefix_lsun}_raw_{args.model}.pkl", "rb")) 251 | logpacket_logs_celeba = pickle.load( 252 | open(f"{args.prefix_celeba}_logpackets_{args.model}.pkl", "rb") 253 | ) 254 | packet_logs_celeba = pickle.load( 255 | open(f"{args.prefix_celeba}_packets_{args.model}.pkl", "rb") 256 | ) 257 | raw_logs_celeba = pickle.load( 258 | open(f"{args.prefix_celeba}_raw_{args.model}.pkl", "rb") 259 | ) 260 | 261 | if args.skip_val_acc_indices is not None: 262 | log_list = [ 263 | raw_logs_celeba, 264 | packet_logs_celeba, 265 | logpacket_logs_celeba, 266 | raw_logs_lsun, 267 | packet_logs_lsun, 268 | logpacket_logs_lsun, 269 | ] 270 | for idx in args.skip_val_acc_indices: 271 | skip_every_second_val_acc(log_list[idx]) 272 | 273 | fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True, figsize=(10, 5)) 274 | 275 | _plot_on_ax( 276 | dataset="CelebA", 277 | param_str=args.model, 278 | axs=ax2, 279 | logpacket_logs=logpacket_logs_celeba, 280 | packet_logs=packet_logs_celeba, 281 | raw_logs=raw_logs_celeba, 282 | epochs=args.epochs[1], 283 | batch_size=args.batch_size[1], 284 | ylim=args.ylim, 285 | ) 286 | 287 | _plot_on_ax( 288 | dataset="LSUN", 289 | param_str=args.model, 290 | axs=ax1, 291 | logpacket_logs=logpacket_logs_lsun, 292 | packet_logs=packet_logs_lsun, 293 | raw_logs=raw_logs_lsun, 294 | epochs=args.epochs[0], 295 | batch_size=args.batch_size[0], 296 | ylim=args.ylim, 297 | ylabel="accuracy", 298 | ) 299 | 300 | plt.suptitle("source identification") 301 | handles, labels = ax2.get_legend_handles_labels() 302 | fig.legend(handles, labels, loc="center right", bbox_to_anchor=(1.0, 0.30)) 303 | plt.tight_layout() 304 | 305 | export_plots(args, output_prefix="lsun_celeba") 306 | 307 | 308 | def plot_single(args): 309 | """Plot the validation and test accuracy for one data set.""" 310 | plt.figure(figsize=(8, 5)) 311 | 312 | suffix_str_packets = "" 313 | suffix_str_raw = "" 314 | params_str = f"{args.model}" 315 | if args.wavelet: 316 | suffix_str_packets += f"_{args.wavelet}" 317 | params_str += f" {args.wavelet}" 318 | if args.mode: 319 | suffix_str_packets += f"_{args.mode}" 320 | params_str += f" {args.mode}" 321 | if args.learning_rate: 322 | suffix_str_packets += f"_{args.learning_rate}" 323 | suffix_str_raw += f"_{args.learning_rate}" 324 | params_str += f" {args.learning_rate}" 325 | if args.epochs: 326 | suffix_str_packets += f"_{args.epochs}e" 327 | suffix_str_raw += f"_{args.epochs}e" 328 | params_str += f" {args.epochs}e" 329 | suffix_str_packets += f"_{args.model}" 330 | suffix_str_raw += f"_{args.model}" 331 | 332 | try: 333 | logpacket_logs = pickle.load( 334 | open(f"{args.prefix}_log_packets{suffix_str_packets}.pkl", "rb") 335 | ) 336 | except FileNotFoundError: 337 | 
print(f"{args.prefix}_log_packets{suffix_str_packets}.pkl not found!") 338 | logpacket_logs = None 339 | try: 340 | packet_logs = pickle.load( 341 | open(f"{args.prefix}_packets{suffix_str_packets}.pkl", "rb") 342 | ) 343 | except FileNotFoundError: 344 | print(f"{args.prefix}_packets{suffix_str_packets}.pkl not found!") 345 | packet_logs = None 346 | try: 347 | raw_logs = pickle.load(open(f"{args.prefix}_raw{suffix_str_raw}.pkl", "rb")) 348 | except FileNotFoundError: 349 | raw_logs = None 350 | print(f"{args.prefix}_raw{suffix_str_raw}.pkl not found!") 351 | 352 | if not any([logpacket_logs, packet_logs, raw_logs]): 353 | raise ValueError("No log files found!") 354 | 355 | if args.skip_val_acc_indices is not None: 356 | log_list = [raw_logs, packet_logs, logpacket_logs] 357 | for idx in args.skip_val_acc_indices: 358 | if log_list[idx]: 359 | skip_every_second_val_acc(log_list[idx]) 360 | 361 | _plot_on_ax( 362 | dataset=args.dataset, 363 | param_str=params_str, 364 | axs=plt.gca(), 365 | logpacket_logs=logpacket_logs, 366 | packet_logs=packet_logs, 367 | raw_logs=raw_logs, 368 | epochs=args.epochs, 369 | batch_size=args.batch_size, 370 | ylabel="accuracy", 371 | ylim=args.ylim, 372 | place_legend=True, 373 | title=f"{args.dataset} {params_str} binary classification", 374 | ) 375 | 376 | export_plots(args, output_prefix=args.dataset.lower()) 377 | 378 | 379 | def _parse_args(): 380 | parser = argparse.ArgumentParser(description="Plot validation accuracy") 381 | 382 | parent_parser = argparse.ArgumentParser(add_help=False) 383 | 384 | parent_parser.add_argument("model", choices=["regression", "cnn"]) 385 | parent_parser.add_argument( 386 | "-p", "--png", action="store_true", help="save the plot as a png" 387 | ) 388 | parent_parser.add_argument( 389 | "-t", "--tikz", action="store_true", help="export a tikz version of the plot" 390 | ) 391 | parent_parser.add_argument( 392 | "--hide", action="store_true", help="do not show the plot" 393 | ) 394 | parent_parser.add_argument( 395 | "--skip-val-acc-indices", 396 | nargs="*", 397 | type=int, 398 | default=None, 399 | help="indices of the logs, for which every second validation accuracy value should be " 400 | "skipped (starting at 0). The order of lists is [raw, packets, logpackets] (and " 401 | "[celeba, lsun] in the shared case), e.g. 
for lsun packets the index would be 1 " 402 | "(or 4 in the shared case).", 403 | ) 404 | parent_parser.add_argument( 405 | "--ylim", type=float, default=None, help="Maximal value of the y axis" 406 | ) 407 | parent_parser.add_argument("--wavelet", type=str, default=None, help="Wavelet used") 408 | parent_parser.add_argument( 409 | "--mode", type=str, default=None, help="Boundary mode used" 410 | ) 411 | parent_parser.add_argument( 412 | "--learning-rate", type=float, default=None, help="Learning rate used" 413 | ) 414 | 415 | subparsers = parser.add_subparsers(required=True) 416 | 417 | # create subparser for plotting a shared plot for LSUN/CelebA 418 | parser_shared = subparsers.add_parser("shared", parents=[parent_parser]) 419 | parser_shared.add_argument( 420 | "--epochs", 421 | nargs=2, 422 | metavar=("LSUN_EPOCHS", "CELEBA_EPOCHS"), 423 | type=int, 424 | default=[None, None], 425 | help="Filter the logs for only these numbers of epochs", 426 | ) 427 | parser_shared.add_argument( 428 | "--batch-size", 429 | nargs=2, 430 | metavar=("LSUN_BATCH_SIZE", "CELEBA_BATCH_SIZE"), 431 | type=int, 432 | default=[None, None], 433 | help="Filter the logs for only these batch sizes", 434 | ) 435 | parser_shared.add_argument( 436 | "--prefix-lsun", 437 | type=str, 438 | default="./log/lsun_bedroom_200k_png", 439 | help="shared file path prefix of the log files (default: ./log/lsun_bedroom_200k_png)", 440 | ) 441 | parser_shared.add_argument( 442 | "--prefix-celeba", 443 | default="./log/celeba_align_png_cropped", 444 | help="shared file path prefix of the log files (default: ./log/celeba_align_png_cropped)", 445 | ) 446 | parser_shared.set_defaults(func=plot_shared) 447 | 448 | # create subparser for plotting either LSUN or CelebA 449 | parser_lsun = subparsers.add_parser("lsun", parents=[parent_parser]) 450 | parser_lsun.add_argument( 451 | "--prefix", 452 | default="./log/lsun_bedroom_200k_png", 453 | help="shared file path prefix of the log files (default: ./log/lsun_bedroom_200k_png)", 454 | ) 455 | parser_lsun.add_argument( 456 | "--epochs", 457 | type=int, 458 | default=None, 459 | help="Filter the logs for only this number of epochs", 460 | ) 461 | parser_lsun.add_argument( 462 | "--batch-size", 463 | type=int, 464 | default=None, 465 | help="Filter the logs for only this batch size", 466 | ) 467 | parser_lsun.set_defaults(func=plot_single) 468 | parser_lsun.set_defaults(dataset="LSUN") 469 | 470 | parser_celeba = subparsers.add_parser("celeba", parents=[parent_parser]) 471 | parser_celeba.add_argument( 472 | "--prefix", 473 | default="./log/celeba_align_png_cropped", 474 | help="shared file path prefix of the log files (default: ./log/celeba_align_png_cropped)", 475 | ) 476 | parser_celeba.add_argument( 477 | "--epochs", 478 | type=int, 479 | default=None, 480 | help="Filter the logs for only this number of epochs", 481 | ) 482 | parser_celeba.add_argument( 483 | "--batch-size", 484 | type=int, 485 | default=None, 486 | help="Filter the logs for only this batch size", 487 | ) 488 | parser_celeba.set_defaults(func=plot_single) 489 | parser_celeba.set_defaults(dataset="CelebA") 490 | 491 | parser_other = subparsers.add_parser("other", parents=[parent_parser]) 492 | parser_other.add_argument( 493 | "--dataset", type=str, default="other_dataset", help="Name of the dataset" 494 | ) 495 | parser_other.add_argument( 496 | "--prefix", 497 | type=str, 498 | default="./log/data", 499 | help="shared file path prefix of the log files", 500 | ) 501 | parser_other.add_argument( 502 | "--epochs", 503 | 
type=int, 504 | default=None, 505 | help="Filter the logs for only this number of epochs", 506 | ) 507 | parser_other.add_argument( 508 | "--batch-size", 509 | type=int, 510 | default=None, 511 | help="Filter the logs for only this batch size", 512 | ) 513 | parser_other.set_defaults(func=plot_single) 514 | 515 | return parser.parse_args() 516 | 517 | 518 | def main(args): 519 | """Plot the accuracy results, as specified in the cmd line args.""" 520 | args.func(args) 521 | 522 | 523 | if __name__ == "__main__": 524 | main(_parse_args()) 525 | --------------------------------------------------------------------------------
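A hypothetical end-to-end call of this plotting CLI (assuming matching log pickles written by train_classifier.py under the default ./log/lsun_bedroom_200k_png prefix) could look like: python -m freqdect.plot_accuracy_results lsun regression --learning-rate 0.001 --epochs 20 --png. This looks for files such as ./log/lsun_bedroom_200k_png_packets_0.001_20e_regression.pkl, prints the per-feature test statistics and saves the accuracy plot as a png.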