├── .gitattributes
├── .github
│   └── workflows
│       └── CI.yml
├── .gitignore
├── CITATION.cff
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── eula.pdf
├── examples
│   ├── batfd
│   │   ├── README.md
│   │   ├── batfd.toml
│   │   ├── batfd
│   │   │   ├── __init__.py
│   │   │   ├── inference.py
│   │   │   ├── model
│   │   │   │   ├── __init__.py
│   │   │   │   ├── audio_encoder.py
│   │   │   │   ├── batfd.py
│   │   │   │   ├── batfd_plus.py
│   │   │   │   ├── boundary_module.py
│   │   │   │   ├── boundary_module_plus.py
│   │   │   │   ├── frame_classifier.py
│   │   │   │   ├── fusion_module.py
│   │   │   │   ├── loss.py
│   │   │   │   └── video_encoder.py
│   │   │   ├── post_process.py
│   │   │   └── utils.py
│   │   ├── batfd_plus.toml
│   │   ├── evaluate.py
│   │   ├── infer.py
│   │   └── train.py
│   └── xception
│       ├── README.md
│       ├── evaluate.py
│       ├── infer.py
│       ├── train.py
│       ├── utils.py
│       └── xception.py
├── pixi.lock
├── pyproject.toml
├── python
│   └── avdeepfake1m
│       ├── __init__.py
│       ├── evaluation
│       │   ├── __init__.py
│       │   ├── __init__.pyi
│       │   └── auc.py
│       ├── loader.py
│       └── utils.py
└── src
    ├── lib.rs
    └── loc_1d.rs
/.gitattributes:
--------------------------------------------------------------------------------
1 | # GitHub syntax highlighting
2 | pixi.lock linguist-language=YAML linguist-generated=true
3 |
--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by maturin v1.7.4
2 | # To update, run
3 | #
4 | # maturin generate-ci github
5 | #
6 | name: CI
7 |
8 | on:
9 | push:
10 | branches:
11 | - main
12 | - master
13 | tags:
14 | - '*'
15 | pull_request:
16 | workflow_dispatch:
17 |
18 | permissions:
19 | contents: read
20 |
21 | jobs:
22 | linux:
23 | runs-on: ${{ matrix.platform.runner }}
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | platform:
28 | - runner: ubuntu-latest
29 | target: x86_64
30 | - runner: ubuntu-latest
31 | target: aarch64
32 | - runner: ubuntu-latest
33 | target: armv7
34 | - runner: ubuntu-latest
35 | target: s390x
36 | - runner: ubuntu-latest
37 | target: ppc64le
38 | steps:
39 | - uses: actions/checkout@v4
40 | - uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.11"
43 | - name: Update version if tag
44 | if: startsWith(github.ref, 'refs/tags/')
45 | run: |
46 | VERSION=${GITHUB_REF#refs/tags/}
47 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml
48 | - name: Build wheels
49 | uses: PyO3/maturin-action@v1
50 | with:
51 | rust-toolchain: nightly
52 | target: ${{ matrix.platform.target }}
53 | args: --release --out dist --find-interpreter
54 | sccache: 'true'
55 | manylinux: auto
56 | - name: Upload wheels
57 | uses: actions/upload-artifact@v4
58 | with:
59 | name: wheels-linux-${{ matrix.platform.target }}
60 | path: dist
61 |
62 | musllinux:
63 | runs-on: ${{ matrix.platform.runner }}
64 | strategy:
65 | fail-fast: false
66 | matrix:
67 | platform:
68 | - runner: ubuntu-latest
69 | target: x86_64
70 | - runner: ubuntu-latest
71 | target: aarch64
72 | steps:
73 | - uses: actions/checkout@v4
74 | - uses: actions/setup-python@v5
75 | with:
76 | python-version: "3.11"
77 | - uses: dtolnay/rust-toolchain@nightly
78 | - name: Update version if tag
79 | if: startsWith(github.ref, 'refs/tags/')
80 | run: |
81 | VERSION=${GITHUB_REF#refs/tags/}
82 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml
83 | - name: Build wheels
84 | uses: PyO3/maturin-action@v1
85 | with:
86 | rust-toolchain: nightly
87 | target: ${{ matrix.platform.target }}
88 | args: --release --out dist --find-interpreter
89 | sccache: 'true'
90 | manylinux: musllinux_1_2
91 | - name: Upload wheels
92 | uses: actions/upload-artifact@v4
93 | with:
94 | name: wheels-musllinux-${{ matrix.platform.target }}
95 | path: dist
96 |
97 | windows:
98 | runs-on: ${{ matrix.platform.runner }}
99 | strategy:
100 | fail-fast: false
101 | matrix:
102 | platform:
103 | - runner: windows-latest
104 | target: x64
105 | steps:
106 | - uses: actions/checkout@v4
107 | - uses: actions/setup-python@v5
108 | with:
109 | python-version: "3.11"
110 | architecture: ${{ matrix.platform.target }}
111 | - uses: dtolnay/rust-toolchain@nightly
112 | - name: Update version if tag
113 | if: startsWith(github.ref, 'refs/tags/')
114 | shell: bash
115 | run: |
116 | VERSION=${GITHUB_REF#refs/tags/}
117 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml
118 | - name: Build wheels
119 | uses: PyO3/maturin-action@v1
120 | with:
121 | rust-toolchain: nightly
122 | target: ${{ matrix.platform.target }}
123 | args: --release --out dist --find-interpreter
124 | sccache: 'true'
125 | - name: Upload wheels
126 | uses: actions/upload-artifact@v4
127 | with:
128 | name: wheels-windows-${{ matrix.platform.target }}
129 | path: dist
130 |
131 | macos:
132 | runs-on: ${{ matrix.platform.runner }}
133 | strategy:
134 | fail-fast: false
135 | matrix:
136 | platform:
137 | - runner: macos-13
138 | target: x86_64
139 | - runner: macos-14
140 | target: aarch64
141 | steps:
142 | - uses: actions/checkout@v4
143 | - uses: actions/setup-python@v5
144 | with:
145 | python-version: "3.11"
146 | - uses: dtolnay/rust-toolchain@nightly
147 | - name: Update version if tag
148 | if: startsWith(github.ref, 'refs/tags/')
149 | run: |
150 | VERSION=${GITHUB_REF#refs/tags/}
151 | sed -i "" "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml
152 | - name: Build wheels
153 | uses: PyO3/maturin-action@v1
154 | with:
155 | rust-toolchain: nightly
156 | target: ${{ matrix.platform.target }}
157 | args: --release --out dist --find-interpreter
158 | sccache: 'true'
159 | - name: Upload wheels
160 | uses: actions/upload-artifact@v4
161 | with:
162 | name: wheels-macos-${{ matrix.platform.target }}
163 | path: dist
164 |
165 | sdist:
166 | runs-on: ubuntu-latest
167 | steps:
168 | - uses: actions/checkout@v4
169 | - uses: dtolnay/rust-toolchain@nightly
170 | - name: Update version if tag
171 | if: startsWith(github.ref, 'refs/tags/')
172 | run: |
173 | VERSION=${GITHUB_REF#refs/tags/}
174 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml
175 | - name: Build sdist
176 | uses: PyO3/maturin-action@v1
177 | with:
178 | rust-toolchain: nightly
179 | command: sdist
180 | args: --out dist
181 | - name: Upload sdist
182 | uses: actions/upload-artifact@v4
183 | with:
184 | name: wheels-sdist
185 | path: dist
186 |
187 | release:
188 | name: Release
189 | runs-on: ubuntu-latest
190 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
191 | needs: [linux, musllinux, windows, macos, sdist]
192 | permissions:
193 | # Use to sign the release artifacts
194 | id-token: write
195 | # Used to upload release artifacts
196 | contents: write
197 | # Used to generate artifact attestation
198 | attestations: write
199 | steps:
200 | - uses: actions/download-artifact@v4
201 | - name: Generate artifact attestation
202 | uses: actions/attest-build-provenance@v1
203 | with:
204 | subject-path: 'wheels-*/*'
205 | - name: Publish to PyPI
206 | if: startsWith(github.ref, 'refs/tags/')
207 | uses: PyO3/maturin-action@v1
208 | env:
209 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
210 | with:
211 | command: upload
212 | args: --non-interactive --skip-existing wheels-*/*
213 |
--------------------------------------------------------------------------------
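Note on the workflow above: the "Update version if tag" step stamps the pushed git tag into Cargo.toml before building, so published wheels carry the tag's version rather than the `0.0.0` placeholder. A minimal Python sketch of the same substitution, for illustration only (the workflow itself uses shell and `sed`; the function name here is hypothetical):

```python
from pathlib import Path


def stamp_version(github_ref: str, manifest: str = "Cargo.toml") -> None:
    """Mirror the CI step: copy the tag from GITHUB_REF into Cargo.toml."""
    # GITHUB_REF looks like "refs/tags/1.0.1"; strip the prefix to get the tag.
    version = github_ref.removeprefix("refs/tags/")
    text = Path(manifest).read_text()
    Path(manifest).write_text(
        text.replace('version = "0.0.0"', f'version = "{version}"', 1)
    )


# Example: stamp_version("refs/tags/1.0.1")
```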
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=intellij,python,windows,macos,linux,jupyternotebooks
3 |
4 | ### Intellij ###
5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
7 |
8 | # User-specific stuff
9 | .idea/**/workspace.xml
10 | .idea/**/tasks.xml
11 | .idea/**/usage.statistics.xml
12 | .idea/**/dictionaries
13 | .idea/**/shelf
14 |
15 | # AWS User-specific
16 | .idea/**/aws.xml
17 |
18 | # Generated files
19 | .idea/**/contentModel.xml
20 |
21 | # Sensitive or high-churn files
22 | .idea/**/dataSources/
23 | .idea/**/dataSources.ids
24 | .idea/**/dataSources.local.xml
25 | .idea/**/sqlDataSources.xml
26 | .idea/**/dynamic.xml
27 | .idea/**/uiDesigner.xml
28 | .idea/**/dbnavigator.xml
29 |
30 | # Gradle
31 | .idea/**/gradle.xml
32 | .idea/**/libraries
33 |
34 | # Gradle and Maven with auto-import
35 | # When using Gradle or Maven with auto-import, you should exclude module files,
36 | # since they will be recreated, and may cause churn. Uncomment if using
37 | # auto-import.
38 | # .idea/artifacts
39 | # .idea/compiler.xml
40 | # .idea/jarRepositories.xml
41 | # .idea/modules.xml
42 | # .idea/*.iml
43 | # .idea/modules
44 | # *.iml
45 | # *.ipr
46 |
47 | # CMake
48 | cmake-build-*/
49 |
50 | # Mongo Explorer plugin
51 | .idea/**/mongoSettings.xml
52 |
53 | # File-based project format
54 | *.iws
55 |
56 | # IntelliJ
57 | out/
58 |
59 | # mpeltonen/sbt-idea plugin
60 | .idea_modules/
61 |
62 | # JIRA plugin
63 | atlassian-ide-plugin.xml
64 |
65 | # Cursive Clojure plugin
66 | .idea/replstate.xml
67 |
68 | # SonarLint plugin
69 | .idea/sonarlint/
70 |
71 | # Crashlytics plugin (for Android Studio and IntelliJ)
72 | com_crashlytics_export_strings.xml
73 | crashlytics.properties
74 | crashlytics-build.properties
75 | fabric.properties
76 |
77 | # Editor-based Rest Client
78 | .idea/httpRequests
79 |
80 | # Android studio 3.1+ serialized cache file
81 | .idea/caches/build_file_checksums.ser
82 |
83 | ### Intellij Patch ###
84 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
85 |
86 | # *.iml
87 | # modules.xml
88 | # .idea/misc.xml
89 | # *.ipr
90 | .idea
91 | # Sonarlint plugin
92 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
93 | .idea/**/sonarlint/
94 |
95 | # SonarQube Plugin
96 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
97 | .idea/**/sonarIssues.xml
98 |
99 | # Markdown Navigator plugin
100 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
101 | .idea/**/markdown-navigator.xml
102 | .idea/**/markdown-navigator-enh.xml
103 | .idea/**/markdown-navigator/
104 |
105 | # Cache file creation bug
106 | # See https://youtrack.jetbrains.com/issue/JBR-2257
107 | .idea/$CACHE_FILE$
108 |
109 | # CodeStream plugin
110 | # https://plugins.jetbrains.com/plugin/12206-codestream
111 | .idea/codestream.xml
112 |
113 | # Azure Toolkit for IntelliJ plugin
114 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
115 | .idea/**/azureSettings.xml
116 |
117 | ### JupyterNotebooks ###
118 | # gitignore template for Jupyter Notebooks
119 | # website: http://jupyter.org/
120 |
121 | .ipynb_checkpoints
122 | */.ipynb_checkpoints/*
123 |
124 | # IPython
125 | profile_default/
126 | ipython_config.py
127 |
128 | # Remove previous ipynb_checkpoints
129 | # git rm -r .ipynb_checkpoints/
130 |
131 | ### Linux ###
132 | *~
133 |
134 | # temporary files which can be created if a process still has a handle open of a deleted file
135 | .fuse_hidden*
136 |
137 | # KDE directory preferences
138 | .directory
139 |
140 | # Linux trash folder which might appear on any partition or disk
141 | .Trash-*
142 |
143 | # .nfs files are created when an open file is removed but is still being accessed
144 | .nfs*
145 |
146 | ### macOS ###
147 | # General
148 | .DS_Store
149 | .AppleDouble
150 | .LSOverride
151 |
152 | # Icon must end with two \r
153 | Icon
154 |
155 |
156 | # Thumbnails
157 | ._*
158 |
159 | # Files that might appear in the root of a volume
160 | .DocumentRevisions-V100
161 | .fseventsd
162 | .Spotlight-V100
163 | .TemporaryItems
164 | .Trashes
165 | .VolumeIcon.icns
166 | .com.apple.timemachine.donotpresent
167 |
168 | # Directories potentially created on remote AFP share
169 | .AppleDB
170 | .AppleDesktop
171 | Network Trash Folder
172 | Temporary Items
173 | .apdisk
174 |
175 | ### macOS Patch ###
176 | # iCloud generated files
177 | *.icloud
178 |
179 | ### Python ###
180 | # Byte-compiled / optimized / DLL files
181 | __pycache__/
182 | *.py[cod]
183 | *$py.class
184 |
185 | # C extensions
186 | *.so
187 |
188 | # Distribution / packaging
189 | .Python
190 | build/
191 | develop-eggs/
192 | dist/
193 | downloads/
194 | eggs/
195 | .eggs/
196 | lib/
197 | lib64/
198 | parts/
199 | sdist/
200 | var/
201 | wheels/
202 | share/python-wheels/
203 | *.egg-info/
204 | .installed.cfg
205 | *.egg
206 | MANIFEST
207 |
208 | # PyInstaller
209 | # Usually these files are written by a python script from a template
210 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
211 | *.manifest
212 | *.spec
213 |
214 | # Installer logs
215 | pip-log.txt
216 | pip-delete-this-directory.txt
217 |
218 | # Unit test / coverage reports
219 | htmlcov/
220 | .tox/
221 | .nox/
222 | .coverage
223 | .coverage.*
224 | .cache
225 | nosetests.xml
226 | coverage.xml
227 | *.cover
228 | *.py,cover
229 | .hypothesis/
230 | .pytest_cache/
231 | cover/
232 |
233 | # Translations
234 | *.mo
235 | *.pot
236 |
237 | # Django stuff:
238 | *.log
239 | local_settings.py
240 | db.sqlite3
241 | db.sqlite3-journal
242 |
243 | # Flask stuff:
244 | instance/
245 | .webassets-cache
246 |
247 | # Scrapy stuff:
248 | .scrapy
249 |
250 | # Sphinx documentation
251 | docs/_build/
252 |
253 | # PyBuilder
254 | .pybuilder/
255 | target/
256 |
257 | # Jupyter Notebook
258 |
259 | # IPython
260 |
261 | # pyenv
262 | # For a library or package, you might want to ignore these files since the code is
263 | # intended to run in multiple environments; otherwise, check them in:
264 | # .python-version
265 |
266 | # pipenv
267 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
268 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
269 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
270 | # install all needed dependencies.
271 | #Pipfile.lock
272 |
273 | # poetry
274 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
275 | # This is especially recommended for binary packages to ensure reproducibility, and is more
276 | # commonly ignored for libraries.
277 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
278 | #poetry.lock
279 |
280 | # pdm
281 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
282 | #pdm.lock
283 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
284 | # in version control.
285 | # https://pdm.fming.dev/#use-with-ide
286 | .pdm.toml
287 |
288 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
289 | __pypackages__/
290 |
291 | # Celery stuff
292 | celerybeat-schedule
293 | celerybeat.pid
294 |
295 | # SageMath parsed files
296 | *.sage.py
297 |
298 | # Environments
299 | .env
300 | .venv
301 | env/
302 | venv/
303 | ENV/
304 | env.bak/
305 | venv.bak/
306 |
307 | # Spyder project settings
308 | .spyderproject
309 | .spyproject
310 |
311 | # Rope project settings
312 | .ropeproject
313 |
314 | # mkdocs documentation
315 | /site
316 |
317 | # mypy
318 | .mypy_cache/
319 | .dmypy.json
320 | dmypy.json
321 |
322 | # Pyre type checker
323 | .pyre/
324 |
325 | # pytype static type analyzer
326 | .pytype/
327 |
328 | # Cython debug symbols
329 | cython_debug/
330 |
331 | # PyCharm
332 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
333 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
334 | # and can be added to the global gitignore or merged into this file. For a more nuclear
335 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
336 | #.idea/
337 |
338 | ### Windows ###
339 | # Windows thumbnail cache files
340 | Thumbs.db
341 | Thumbs.db:encryptable
342 | ehthumbs.db
343 | ehthumbs_vista.db
344 |
345 | # Dump file
346 | *.stackdump
347 |
348 | # Folder config file
349 | [Dd]esktop.ini
350 |
351 | # Recycle Bin used on file shares
352 | $RECYCLE.BIN/
353 |
354 | # Windows Installer files
355 | *.cab
356 | *.msi
357 | *.msix
358 | *.msm
359 | *.msp
360 |
361 | # Windows shortcuts
362 | *.lnk
363 |
364 | # End of https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks
365 |
366 | .vscode
367 | # pixi environments
368 | .pixi
369 | *.egg-info
370 |
371 | /target
372 |
373 | # Byte-compiled / optimized / DLL files
374 | __pycache__/
375 | .pytest_cache/
376 | *.py[cod]
377 |
378 | # C extensions
379 | *.so
380 |
381 | # Distribution / packaging
382 | .Python
383 | .venv/
384 | env/
385 | bin/
386 | build/
387 | develop-eggs/
388 | dist/
389 | eggs/
390 | lib/
391 | lib64/
392 | parts/
393 | sdist/
394 | var/
395 | include/
396 | man/
397 | venv/
398 | *.egg-info/
399 | .installed.cfg
400 | *.egg
401 |
402 | # Installer logs
403 | pip-log.txt
404 | pip-delete-this-directory.txt
405 | pip-selfcheck.json
406 |
407 | # Unit test / coverage reports
408 | htmlcov/
409 | .tox/
410 | .coverage
411 | .cache
412 | nosetests.xml
413 | coverage.xml
414 |
415 | # Translations
416 | *.mo
417 |
418 | # Mr Developer
419 | .mr.developer.cfg
420 | .project
421 | .pydevproject
422 |
423 | # Rope
424 | .ropeproject
425 |
426 | # Django stuff:
427 | *.log
428 | *.pot
429 |
430 | .DS_Store
431 |
432 | # Sphinx documentation
433 | docs/_build/
434 |
435 | # PyCharm
436 | .idea/
437 |
438 | # VSCode
439 | .vscode/
440 |
441 | # Pyenv
442 | .python-version
443 |
444 | /examples/batfd/lightning_logs
445 | /examples/batfd/ckpt
446 | /examples/batfd/output
447 | /examples/xception/lightning_logs
448 | /examples/xception/ckpt
449 | /examples/xception/output
450 | wandb
451 | *.ckpt
452 | *.pth
453 | *.pt
454 | *.jit
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you find this work useful in your research, please cite it."
3 | preferred-citation:
4 | type: conference-paper
5 | title: "AV-Deepfake1M: A Large-Scale LLM-Driven Audio-Visual Deepfake Dataset"
6 | authors:
7 | - family-names: "Cai"
8 | given-names: "Zhixi"
9 | - family-names: "Ghosh"
10 | given-names: "Shreya"
11 | - family-names: "Adatia"
12 | given-names: "Aman Pankaj"
13 | - family-names: "Hayat"
14 | given-names: "Munawar"
15 | - family-names: "Dhall"
16 | given-names: "Abhinav"
17 | - family-names: "Stefanov"
18 | given-names: "Kalin"
19 | collection-title: "Proceedings of the 32nd ACM International Conference on Multimedia"
20 | year: 2023
21 | location:
22 | name: "Melbourne, Australia"
23 | start: 7414
24 | end: 7423
25 | doi: "10.1145/3664647.3680795"
26 |
--------------------------------------------------------------------------------
/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 4
4 |
5 | [[package]]
6 | name = "ahash"
7 | version = "0.8.11"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
10 | dependencies = [
11 | "cfg-if",
12 | "once_cell",
13 | "version_check",
14 | "zerocopy",
15 | ]
16 |
17 | [[package]]
18 | name = "allocator-api2"
19 | version = "0.2.18"
20 | source = "registry+https://github.com/rust-lang/crates.io-index"
21 | checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
22 |
23 | [[package]]
24 | name = "autocfg"
25 | version = "1.4.0"
26 | source = "registry+https://github.com/rust-lang/crates.io-index"
27 | checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
28 |
29 | [[package]]
30 | name = "avdeepfake1m"
31 | version = "0.0.0"
32 | dependencies = [
33 | "ndarray",
34 | "pyo3",
35 | "rayon",
36 | "serde",
37 | "serde-ndim",
38 | "serde_json",
39 | "simd-json",
40 | ]
41 |
42 | [[package]]
43 | name = "bumpalo"
44 | version = "3.16.0"
45 | source = "registry+https://github.com/rust-lang/crates.io-index"
46 | checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
47 |
48 | [[package]]
49 | name = "cblas-sys"
50 | version = "0.1.4"
51 | source = "registry+https://github.com/rust-lang/crates.io-index"
52 | checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
53 | dependencies = [
54 | "libc",
55 | ]
56 |
57 | [[package]]
58 | name = "cfg-if"
59 | version = "1.0.0"
60 | source = "registry+https://github.com/rust-lang/crates.io-index"
61 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
62 |
63 | [[package]]
64 | name = "crossbeam-deque"
65 | version = "0.8.5"
66 | source = "registry+https://github.com/rust-lang/crates.io-index"
67 | checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
68 | dependencies = [
69 | "crossbeam-epoch",
70 | "crossbeam-utils",
71 | ]
72 |
73 | [[package]]
74 | name = "crossbeam-epoch"
75 | version = "0.9.18"
76 | source = "registry+https://github.com/rust-lang/crates.io-index"
77 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
78 | dependencies = [
79 | "crossbeam-utils",
80 | ]
81 |
82 | [[package]]
83 | name = "crossbeam-utils"
84 | version = "0.8.20"
85 | source = "registry+https://github.com/rust-lang/crates.io-index"
86 | checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
87 |
88 | [[package]]
89 | name = "either"
90 | version = "1.13.0"
91 | source = "registry+https://github.com/rust-lang/crates.io-index"
92 | checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
93 |
94 | [[package]]
95 | name = "float-cmp"
96 | version = "0.9.0"
97 | source = "registry+https://github.com/rust-lang/crates.io-index"
98 | checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4"
99 | dependencies = [
100 | "num-traits",
101 | ]
102 |
103 | [[package]]
104 | name = "getrandom"
105 | version = "0.2.15"
106 | source = "registry+https://github.com/rust-lang/crates.io-index"
107 | checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
108 | dependencies = [
109 | "cfg-if",
110 | "js-sys",
111 | "libc",
112 | "wasi",
113 | "wasm-bindgen",
114 | ]
115 |
116 | [[package]]
117 | name = "halfbrown"
118 | version = "0.2.5"
119 | source = "registry+https://github.com/rust-lang/crates.io-index"
120 | checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f"
121 | dependencies = [
122 | "hashbrown",
123 | "serde",
124 | ]
125 |
126 | [[package]]
127 | name = "hashbrown"
128 | version = "0.14.5"
129 | source = "registry+https://github.com/rust-lang/crates.io-index"
130 | checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
131 | dependencies = [
132 | "ahash",
133 | "allocator-api2",
134 | ]
135 |
136 | [[package]]
137 | name = "heck"
138 | version = "0.5.0"
139 | source = "registry+https://github.com/rust-lang/crates.io-index"
140 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
141 |
142 | [[package]]
143 | name = "indoc"
144 | version = "2.0.5"
145 | source = "registry+https://github.com/rust-lang/crates.io-index"
146 | checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
147 |
148 | [[package]]
149 | name = "itoa"
150 | version = "1.0.11"
151 | source = "registry+https://github.com/rust-lang/crates.io-index"
152 | checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
153 |
154 | [[package]]
155 | name = "js-sys"
156 | version = "0.3.72"
157 | source = "registry+https://github.com/rust-lang/crates.io-index"
158 | checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9"
159 | dependencies = [
160 | "wasm-bindgen",
161 | ]
162 |
163 | [[package]]
164 | name = "lexical-core"
165 | version = "1.0.2"
166 | source = "registry+https://github.com/rust-lang/crates.io-index"
167 | checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458"
168 | dependencies = [
169 | "lexical-parse-float",
170 | "lexical-parse-integer",
171 | "lexical-util",
172 | "lexical-write-float",
173 | "lexical-write-integer",
174 | ]
175 |
176 | [[package]]
177 | name = "lexical-parse-float"
178 | version = "1.0.2"
179 | source = "registry+https://github.com/rust-lang/crates.io-index"
180 | checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0"
181 | dependencies = [
182 | "lexical-parse-integer",
183 | "lexical-util",
184 | "static_assertions",
185 | ]
186 |
187 | [[package]]
188 | name = "lexical-parse-integer"
189 | version = "1.0.2"
190 | source = "registry+https://github.com/rust-lang/crates.io-index"
191 | checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61"
192 | dependencies = [
193 | "lexical-util",
194 | "static_assertions",
195 | ]
196 |
197 | [[package]]
198 | name = "lexical-util"
199 | version = "1.0.3"
200 | source = "registry+https://github.com/rust-lang/crates.io-index"
201 | checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0"
202 | dependencies = [
203 | "static_assertions",
204 | ]
205 |
206 | [[package]]
207 | name = "lexical-write-float"
208 | version = "1.0.2"
209 | source = "registry+https://github.com/rust-lang/crates.io-index"
210 | checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809"
211 | dependencies = [
212 | "lexical-util",
213 | "lexical-write-integer",
214 | "static_assertions",
215 | ]
216 |
217 | [[package]]
218 | name = "lexical-write-integer"
219 | version = "1.0.2"
220 | source = "registry+https://github.com/rust-lang/crates.io-index"
221 | checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162"
222 | dependencies = [
223 | "lexical-util",
224 | "static_assertions",
225 | ]
226 |
227 | [[package]]
228 | name = "libc"
229 | version = "0.2.159"
230 | source = "registry+https://github.com/rust-lang/crates.io-index"
231 | checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
232 |
233 | [[package]]
234 | name = "log"
235 | version = "0.4.22"
236 | source = "registry+https://github.com/rust-lang/crates.io-index"
237 | checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
238 |
239 | [[package]]
240 | name = "matrixmultiply"
241 | version = "0.3.9"
242 | source = "registry+https://github.com/rust-lang/crates.io-index"
243 | checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
244 | dependencies = [
245 | "autocfg",
246 | "rawpointer",
247 | ]
248 |
249 | [[package]]
250 | name = "memchr"
251 | version = "2.7.4"
252 | source = "registry+https://github.com/rust-lang/crates.io-index"
253 | checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
254 |
255 | [[package]]
256 | name = "memoffset"
257 | version = "0.9.1"
258 | source = "registry+https://github.com/rust-lang/crates.io-index"
259 | checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
260 | dependencies = [
261 | "autocfg",
262 | ]
263 |
264 | [[package]]
265 | name = "ndarray"
266 | version = "0.15.6"
267 | source = "registry+https://github.com/rust-lang/crates.io-index"
268 | checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
269 | dependencies = [
270 | "cblas-sys",
271 | "libc",
272 | "matrixmultiply",
273 | "num-complex",
274 | "num-integer",
275 | "num-traits",
276 | "rawpointer",
277 | "rayon",
278 | "serde",
279 | ]
280 |
281 | [[package]]
282 | name = "num-complex"
283 | version = "0.4.6"
284 | source = "registry+https://github.com/rust-lang/crates.io-index"
285 | checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
286 | dependencies = [
287 | "num-traits",
288 | ]
289 |
290 | [[package]]
291 | name = "num-integer"
292 | version = "0.1.46"
293 | source = "registry+https://github.com/rust-lang/crates.io-index"
294 | checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
295 | dependencies = [
296 | "num-traits",
297 | ]
298 |
299 | [[package]]
300 | name = "num-traits"
301 | version = "0.2.19"
302 | source = "registry+https://github.com/rust-lang/crates.io-index"
303 | checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
304 | dependencies = [
305 | "autocfg",
306 | ]
307 |
308 | [[package]]
309 | name = "once_cell"
310 | version = "1.20.2"
311 | source = "registry+https://github.com/rust-lang/crates.io-index"
312 | checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
313 |
314 | [[package]]
315 | name = "portable-atomic"
316 | version = "1.9.0"
317 | source = "registry+https://github.com/rust-lang/crates.io-index"
318 | checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
319 |
320 | [[package]]
321 | name = "proc-macro2"
322 | version = "1.0.87"
323 | source = "registry+https://github.com/rust-lang/crates.io-index"
324 | checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a"
325 | dependencies = [
326 | "unicode-ident",
327 | ]
328 |
329 | [[package]]
330 | name = "pyo3"
331 | version = "0.24.2"
332 | source = "registry+https://github.com/rust-lang/crates.io-index"
333 | checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219"
334 | dependencies = [
335 | "cfg-if",
336 | "indoc",
337 | "libc",
338 | "memoffset",
339 | "once_cell",
340 | "portable-atomic",
341 | "pyo3-build-config",
342 | "pyo3-ffi",
343 | "pyo3-macros",
344 | "unindent",
345 | ]
346 |
347 | [[package]]
348 | name = "pyo3-build-config"
349 | version = "0.24.2"
350 | source = "registry+https://github.com/rust-lang/crates.io-index"
351 | checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999"
352 | dependencies = [
353 | "once_cell",
354 | "target-lexicon",
355 | ]
356 |
357 | [[package]]
358 | name = "pyo3-ffi"
359 | version = "0.24.2"
360 | source = "registry+https://github.com/rust-lang/crates.io-index"
361 | checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33"
362 | dependencies = [
363 | "libc",
364 | "pyo3-build-config",
365 | ]
366 |
367 | [[package]]
368 | name = "pyo3-macros"
369 | version = "0.24.2"
370 | source = "registry+https://github.com/rust-lang/crates.io-index"
371 | checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9"
372 | dependencies = [
373 | "proc-macro2",
374 | "pyo3-macros-backend",
375 | "quote",
376 | "syn",
377 | ]
378 |
379 | [[package]]
380 | name = "pyo3-macros-backend"
381 | version = "0.24.2"
382 | source = "registry+https://github.com/rust-lang/crates.io-index"
383 | checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a"
384 | dependencies = [
385 | "heck",
386 | "proc-macro2",
387 | "pyo3-build-config",
388 | "quote",
389 | "syn",
390 | ]
391 |
392 | [[package]]
393 | name = "quote"
394 | version = "1.0.37"
395 | source = "registry+https://github.com/rust-lang/crates.io-index"
396 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
397 | dependencies = [
398 | "proc-macro2",
399 | ]
400 |
401 | [[package]]
402 | name = "rawpointer"
403 | version = "0.2.1"
404 | source = "registry+https://github.com/rust-lang/crates.io-index"
405 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
406 |
407 | [[package]]
408 | name = "rayon"
409 | version = "1.10.0"
410 | source = "registry+https://github.com/rust-lang/crates.io-index"
411 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
412 | dependencies = [
413 | "either",
414 | "rayon-core",
415 | ]
416 |
417 | [[package]]
418 | name = "rayon-core"
419 | version = "1.12.1"
420 | source = "registry+https://github.com/rust-lang/crates.io-index"
421 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
422 | dependencies = [
423 | "crossbeam-deque",
424 | "crossbeam-utils",
425 | ]
426 |
427 | [[package]]
428 | name = "ref-cast"
429 | version = "1.0.23"
430 | source = "registry+https://github.com/rust-lang/crates.io-index"
431 | checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931"
432 | dependencies = [
433 | "ref-cast-impl",
434 | ]
435 |
436 | [[package]]
437 | name = "ref-cast-impl"
438 | version = "1.0.23"
439 | source = "registry+https://github.com/rust-lang/crates.io-index"
440 | checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6"
441 | dependencies = [
442 | "proc-macro2",
443 | "quote",
444 | "syn",
445 | ]
446 |
447 | [[package]]
448 | name = "ryu"
449 | version = "1.0.18"
450 | source = "registry+https://github.com/rust-lang/crates.io-index"
451 | checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
452 |
453 | [[package]]
454 | name = "serde"
455 | version = "1.0.210"
456 | source = "registry+https://github.com/rust-lang/crates.io-index"
457 | checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
458 | dependencies = [
459 | "serde_derive",
460 | ]
461 |
462 | [[package]]
463 | name = "serde-ndim"
464 | version = "1.1.0"
465 | source = "registry+https://github.com/rust-lang/crates.io-index"
466 | checksum = "5883c695f4433e428c19938eddf3a8e9d4e040f0cd0c5c11f6ded90f190774c0"
467 | dependencies = [
468 | "ndarray",
469 | "serde",
470 | ]
471 |
472 | [[package]]
473 | name = "serde_derive"
474 | version = "1.0.210"
475 | source = "registry+https://github.com/rust-lang/crates.io-index"
476 | checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
477 | dependencies = [
478 | "proc-macro2",
479 | "quote",
480 | "syn",
481 | ]
482 |
483 | [[package]]
484 | name = "serde_json"
485 | version = "1.0.128"
486 | source = "registry+https://github.com/rust-lang/crates.io-index"
487 | checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
488 | dependencies = [
489 | "itoa",
490 | "memchr",
491 | "ryu",
492 | "serde",
493 | ]
494 |
495 | [[package]]
496 | name = "simd-json"
497 | version = "0.13.11"
498 | source = "registry+https://github.com/rust-lang/crates.io-index"
499 | checksum = "a0228a564470f81724e30996bbc2b171713b37b15254a6440c7e2d5449b95691"
500 | dependencies = [
501 | "getrandom",
502 | "halfbrown",
503 | "lexical-core",
504 | "ref-cast",
505 | "serde",
506 | "serde_json",
507 | "simdutf8",
508 | "value-trait",
509 | ]
510 |
511 | [[package]]
512 | name = "simdutf8"
513 | version = "0.1.5"
514 | source = "registry+https://github.com/rust-lang/crates.io-index"
515 | checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
516 |
517 | [[package]]
518 | name = "static_assertions"
519 | version = "1.1.0"
520 | source = "registry+https://github.com/rust-lang/crates.io-index"
521 | checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
522 |
523 | [[package]]
524 | name = "syn"
525 | version = "2.0.79"
526 | source = "registry+https://github.com/rust-lang/crates.io-index"
527 | checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
528 | dependencies = [
529 | "proc-macro2",
530 | "quote",
531 | "unicode-ident",
532 | ]
533 |
534 | [[package]]
535 | name = "target-lexicon"
536 | version = "0.13.2"
537 | source = "registry+https://github.com/rust-lang/crates.io-index"
538 | checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a"
539 |
540 | [[package]]
541 | name = "unicode-ident"
542 | version = "1.0.13"
543 | source = "registry+https://github.com/rust-lang/crates.io-index"
544 | checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
545 |
546 | [[package]]
547 | name = "unindent"
548 | version = "0.2.3"
549 | source = "registry+https://github.com/rust-lang/crates.io-index"
550 | checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
551 |
552 | [[package]]
553 | name = "value-trait"
554 | version = "0.8.1"
555 | source = "registry+https://github.com/rust-lang/crates.io-index"
556 | checksum = "dad8db98c1e677797df21ba03fca7d3bf9bec3ca38db930954e4fe6e1ea27eb4"
557 | dependencies = [
558 | "float-cmp",
559 | "halfbrown",
560 | "itoa",
561 | "ryu",
562 | ]
563 |
564 | [[package]]
565 | name = "version_check"
566 | version = "0.9.5"
567 | source = "registry+https://github.com/rust-lang/crates.io-index"
568 | checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
569 |
570 | [[package]]
571 | name = "wasi"
572 | version = "0.11.0+wasi-snapshot-preview1"
573 | source = "registry+https://github.com/rust-lang/crates.io-index"
574 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
575 |
576 | [[package]]
577 | name = "wasm-bindgen"
578 | version = "0.2.95"
579 | source = "registry+https://github.com/rust-lang/crates.io-index"
580 | checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e"
581 | dependencies = [
582 | "cfg-if",
583 | "once_cell",
584 | "wasm-bindgen-macro",
585 | ]
586 |
587 | [[package]]
588 | name = "wasm-bindgen-backend"
589 | version = "0.2.95"
590 | source = "registry+https://github.com/rust-lang/crates.io-index"
591 | checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358"
592 | dependencies = [
593 | "bumpalo",
594 | "log",
595 | "once_cell",
596 | "proc-macro2",
597 | "quote",
598 | "syn",
599 | "wasm-bindgen-shared",
600 | ]
601 |
602 | [[package]]
603 | name = "wasm-bindgen-macro"
604 | version = "0.2.95"
605 | source = "registry+https://github.com/rust-lang/crates.io-index"
606 | checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56"
607 | dependencies = [
608 | "quote",
609 | "wasm-bindgen-macro-support",
610 | ]
611 |
612 | [[package]]
613 | name = "wasm-bindgen-macro-support"
614 | version = "0.2.95"
615 | source = "registry+https://github.com/rust-lang/crates.io-index"
616 | checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
617 | dependencies = [
618 | "proc-macro2",
619 | "quote",
620 | "syn",
621 | "wasm-bindgen-backend",
622 | "wasm-bindgen-shared",
623 | ]
624 |
625 | [[package]]
626 | name = "wasm-bindgen-shared"
627 | version = "0.2.95"
628 | source = "registry+https://github.com/rust-lang/crates.io-index"
629 | checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d"
630 |
631 | [[package]]
632 | name = "zerocopy"
633 | version = "0.7.35"
634 | source = "registry+https://github.com/rust-lang/crates.io-index"
635 | checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
636 | dependencies = [
637 | "zerocopy-derive",
638 | ]
639 |
640 | [[package]]
641 | name = "zerocopy-derive"
642 | version = "0.7.35"
643 | source = "registry+https://github.com/rust-lang/crates.io-index"
644 | checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
645 | dependencies = [
646 | "proc-macro2",
647 | "quote",
648 | "syn",
649 | ]
650 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "avdeepfake1m"
3 | version = "0.0.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 | [lib]
8 | name = "avdeepfake1m"
9 | crate-type = ["cdylib"]
10 | path = "src/lib.rs"
11 |
12 | [dependencies]
13 | pyo3 = { version = "0.24.2", features = ["extension-module"] }
14 | ndarray = { version = "0.15.0", features = ["blas", "rayon", "serde"] }
15 | serde = { version = "1.0.197", features = ["derive"] }
16 | serde_json = "1.0.115"
17 | rayon = "1.10.0"
18 | simd-json = "0.13.9"
19 | serde-ndim = { version = "1.1.0", features = ["ndarray"] }
20 |
--------------------------------------------------------------------------------
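Note on the manifest above: the crate is built as a `cdylib` against PyO3's `extension-module` feature, which maturin compiles into the importable `avdeepfake1m` extension; the pure-Python side lives under `python/avdeepfake1m`. A minimal local-development sketch, assuming `maturin` is installed (e.g. `pip install maturin`); release builds use the exact options shown in CI.yml above:

```python
import subprocess

# Compile the Rust cdylib and install it into the active environment in-place.
# `maturin develop` reads Cargo.toml / pyproject.toml from the repository root.
subprocess.run(["maturin", "develop", "--release"], check=True)

# After this, `import avdeepfake1m` resolves to the compiled extension plus the
# pure-Python package under python/avdeepfake1m.
```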
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Attribution-NonCommercial 4.0 International
3 |
4 | =======================================================================
5 |
6 | Creative Commons Corporation ("Creative Commons") is not a law firm and
7 | does not provide legal services or legal advice. Distribution of
8 | Creative Commons public licenses does not create a lawyer-client or
9 | other relationship. Creative Commons makes its licenses and related
10 | information available on an "as-is" basis. Creative Commons gives no
11 | warranties regarding its licenses, any material licensed under their
12 | terms and conditions, or any related information. Creative Commons
13 | disclaims all liability for damages resulting from their use to the
14 | fullest extent possible.
15 |
16 | Using Creative Commons Public Licenses
17 |
18 | Creative Commons public licenses provide a standard set of terms and
19 | conditions that creators and other rights holders may use to share
20 | original works of authorship and other material subject to copyright
21 | and certain other rights specified in the public license below. The
22 | following considerations are for informational purposes only, are not
23 | exhaustive, and do not form part of our licenses.
24 |
25 | Considerations for licensors: Our public licenses are
26 | intended for use by those authorized to give the public
27 | permission to use material in ways otherwise restricted by
28 | copyright and certain other rights. Our licenses are
29 | irrevocable. Licensors should read and understand the terms
30 | and conditions of the license they choose before applying it.
31 | Licensors should also secure all rights necessary before
32 | applying our licenses so that the public can reuse the
33 | material as expected. Licensors should clearly mark any
34 | material not subject to the license. This includes other CC-
35 | licensed material, or material used under an exception or
36 | limitation to copyright. More considerations for licensors:
37 | wiki.creativecommons.org/Considerations_for_licensors
38 |
39 | Considerations for the public: By using one of our public
40 | licenses, a licensor grants the public permission to use the
41 | licensed material under specified terms and conditions. If
42 | the licensor's permission is not necessary for any reason--for
43 | example, because of any applicable exception or limitation to
44 | copyright--then that use is not regulated by the license. Our
45 | licenses grant only permissions under copyright and certain
46 | other rights that a licensor has authority to grant. Use of
47 | the licensed material may still be restricted for other
48 | reasons, including because others have copyright or other
49 | rights in the material. A licensor may make special requests,
50 | such as asking that all changes be marked or described.
51 | Although not required by our licenses, you are encouraged to
52 | respect those requests where reasonable. More_considerations
53 | for the public:
54 | wiki.creativecommons.org/Considerations_for_licensees
55 |
56 | =======================================================================
57 |
58 | Creative Commons Attribution-NonCommercial 4.0 International Public
59 | License
60 |
61 | By exercising the Licensed Rights (defined below), You accept and agree
62 | to be bound by the terms and conditions of this Creative Commons
63 | Attribution-NonCommercial 4.0 International Public License ("Public
64 | License"). To the extent this Public License may be interpreted as a
65 | contract, You are granted the Licensed Rights in consideration of Your
66 | acceptance of these terms and conditions, and the Licensor grants You
67 | such rights in consideration of benefits the Licensor receives from
68 | making the Licensed Material available under these terms and
69 | conditions.
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 | Section 2 -- Scope.
142 |
143 | a. License grant.
144 |
145 | 1. Subject to the terms and conditions of this Public License,
146 | the Licensor hereby grants You a worldwide, royalty-free,
147 | non-sublicensable, non-exclusive, irrevocable license to
148 | exercise the Licensed Rights in the Licensed Material to:
149 |
150 | a. reproduce and Share the Licensed Material, in whole or
151 | in part, for NonCommercial purposes only; and
152 |
153 | b. produce, reproduce, and Share Adapted Material for
154 | NonCommercial purposes only.
155 |
156 | 2. Exceptions and Limitations. For the avoidance of doubt, where
157 | Exceptions and Limitations apply to Your use, this Public
158 | License does not apply, and You do not need to comply with
159 | its terms and conditions.
160 |
161 | 3. Term. The term of this Public License is specified in Section
162 | 6(a).
163 |
164 | 4. Media and formats; technical modifications allowed. The
165 | Licensor authorizes You to exercise the Licensed Rights in
166 | all media and formats whether now known or hereafter created,
167 | and to make technical modifications necessary to do so. The
168 | Licensor waives and/or agrees not to assert any right or
169 | authority to forbid You from making technical modifications
170 | necessary to exercise the Licensed Rights, including
171 | technical modifications necessary to circumvent Effective
172 | Technological Measures. For purposes of this Public License,
173 | simply making modifications authorized by this Section 2(a)
174 | (4) never produces Adapted Material.
175 |
176 | 5. Downstream recipients.
177 |
178 | a. Offer from the Licensor -- Licensed Material. Every
179 | recipient of the Licensed Material automatically
180 | receives an offer from the Licensor to exercise the
181 | Licensed Rights under the terms and conditions of this
182 | Public License.
183 |
184 | b. No downstream restrictions. You may not offer or impose
185 | any additional or different terms or conditions on, or
186 | apply any Effective Technological Measures to, the
187 | Licensed Material if doing so restricts exercise of the
188 | Licensed Rights by any recipient of the Licensed
189 | Material.
190 |
191 | 6. No endorsement. Nothing in this Public License constitutes or
192 | may be construed as permission to assert or imply that You
193 | are, or that Your use of the Licensed Material is, connected
194 | with, or sponsored, endorsed, or granted official status by,
195 | the Licensor or others designated to receive attribution as
196 | provided in Section 3(a)(1)(A)(i).
197 |
198 | b. Other rights.
199 |
200 | 1. Moral rights, such as the right of integrity, are not
201 | licensed under this Public License, nor are publicity,
202 | privacy, and/or other similar personality rights; however, to
203 | the extent possible, the Licensor waives and/or agrees not to
204 | assert any such rights held by the Licensor to the limited
205 | extent necessary to allow You to exercise the Licensed
206 | Rights, but not otherwise.
207 |
208 | 2. Patent and trademark rights are not licensed under this
209 | Public License.
210 |
211 | 3. To the extent possible, the Licensor waives any right to
212 | collect royalties from You for the exercise of the Licensed
213 | Rights, whether directly or through a collecting society
214 | under any voluntary or waivable statutory or compulsory
215 | licensing scheme. In all other cases the Licensor expressly
216 | reserves any right to collect such royalties, including when
217 | the Licensed Material is used other than for NonCommercial
218 | purposes.
219 |
220 | Section 3 -- License Conditions.
221 |
222 | Your exercise of the Licensed Rights is expressly made subject to the
223 | following conditions.
224 |
225 | a. Attribution.
226 |
227 | 1. If You Share the Licensed Material (including in modified
228 | form), You must:
229 |
230 | a. retain the following if it is supplied by the Licensor
231 | with the Licensed Material:
232 |
233 | i. identification of the creator(s) of the Licensed
234 | Material and any others designated to receive
235 | attribution, in any reasonable manner requested by
236 | the Licensor (including by pseudonym if
237 | designated);
238 |
239 | ii. a copyright notice;
240 |
241 | iii. a notice that refers to this Public License;
242 |
243 | iv. a notice that refers to the disclaimer of
244 | warranties;
245 |
246 | v. a URI or hyperlink to the Licensed Material to the
247 | extent reasonably practicable;
248 |
249 | b. indicate if You modified the Licensed Material and
250 | retain an indication of any previous modifications; and
251 |
252 | c. indicate the Licensed Material is licensed under this
253 | Public License, and include the text of, or the URI or
254 | hyperlink to, this Public License.
255 |
256 | 2. You may satisfy the conditions in Section 3(a)(1) in any
257 | reasonable manner based on the medium, means, and context in
258 | which You Share the Licensed Material. For example, it may be
259 | reasonable to satisfy the conditions by providing a URI or
260 | hyperlink to a resource that includes the required
261 | information.
262 |
263 | 3. If requested by the Licensor, You must remove any of the
264 | information required by Section 3(a)(1)(A) to the extent
265 | reasonably practicable.
266 |
267 | 4. If You Share Adapted Material You produce, the Adapter's
268 | License You apply must not prevent recipients of the Adapted
269 | Material from complying with this Public License.
270 |
271 | Section 4 -- Sui Generis Database Rights.
272 |
273 | Where the Licensed Rights include Sui Generis Database Rights that
274 | apply to Your use of the Licensed Material:
275 |
276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277 | to extract, reuse, reproduce, and Share all or a substantial
278 | portion of the contents of the database for NonCommercial purposes
279 | only;
280 |
281 | b. if You include all or a substantial portion of the database
282 | contents in a database in which You have Sui Generis Database
283 | Rights, then the database in which You have Sui Generis Database
284 | Rights (but not its individual contents) is Adapted Material; and
285 |
286 | c. You must comply with the conditions in Section 3(a) if You Share
287 | all or a substantial portion of the contents of the database.
288 |
289 | For the avoidance of doubt, this Section 4 supplements and does not
290 | replace Your obligations under this Public License where the Licensed
291 | Rights include other Copyright and Similar Rights.
292 |
293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294 |
295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305 |
306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315 |
316 | c. The disclaimer of warranties and limitation of liability provided
317 | above shall be interpreted in a manner that, to the extent
318 | possible, most closely approximates an absolute disclaimer and
319 | waiver of all liability.
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 | Section 7 -- Other Terms and Conditions.
350 |
351 | a. The Licensor shall not be bound by any additional or different
352 | terms or conditions communicated by You unless expressly agreed.
353 |
354 | b. Any arrangements, understandings, or agreements regarding the
355 | Licensed Material not stated herein are separate from and
356 | independent of the terms and conditions of this Public License.
357 |
358 | Section 8 -- Interpretation.
359 |
360 | a. For the avoidance of doubt, this Public License does not, and
361 | shall not be interpreted to, reduce, limit, restrict, or impose
362 | conditions on any use of the Licensed Material that could lawfully
363 | be made without permission under this Public License.
364 |
365 | b. To the extent possible, if any provision of this Public License is
366 | deemed unenforceable, it shall be automatically reformed to the
367 | minimum extent necessary to make it enforceable. If the provision
368 | cannot be reformed, it shall be severed from this Public License
369 | without affecting the enforceability of the remaining terms and
370 | conditions.
371 |
372 | c. No term or condition of this Public License will be waived and no
373 | failure to comply consented to unless expressly agreed to by the
374 | Licensor.
375 |
376 | d. Nothing in this Public License constitutes or may be interpreted
377 | as a limitation upon, or waiver of, any privileges and immunities
378 | that apply to the Licensor or You, including from the legal
379 | processes of any jurisdiction or authority.
380 |
381 | =======================================================================
382 |
383 | Creative Commons is not a party to its public
384 | licenses. Notwithstanding, Creative Commons may elect to apply one of
385 | its public licenses to material it publishes and in those instances
386 | will be considered the “Licensor.” The text of the Creative Commons
387 | public licenses is dedicated to the public domain under the CC0 Public
388 | Domain Dedication. Except for the limited purpose of indicating that
389 | material is shared under a Creative Commons public license or as
390 | otherwise permitted by the Creative Commons policies published at
391 | creativecommons.org/policies, Creative Commons does not authorize the
392 | use of the trademark "Creative Commons" or any other trademark or logo
393 | of Creative Commons without its prior written consent including,
394 | without limitation, in connection with any unauthorized modifications
395 | to any of its public licenses or any other arrangements,
396 | understandings, or agreements concerning use of licensed material. For
397 | the avoidance of doubt, this paragraph does not form part of the
398 | public licenses.
399 |
400 | Creative Commons may be contacted at creativecommons.org.
401 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AV-Deepfake1M
2 |
3 |
4 |
5 |
6 |
7 |
8 |
27 |
28 | This is the official repository for the paper
29 | [AV-Deepfake1M: A Large-Scale LLM-Driven Audio-Visual Deepfake Dataset](https://dl.acm.org/doi/abs/10.1145/3664647.3680795) (Best Paper Award).
30 |
31 | ## News
32 |
33 | - [2025/04/08] 🏆 [**2025 1M-Deepfakes Challenge**](https://deepfakes1m.github.io/2025) starts. A new dataset with 2M videos, **AV-Deepfake1M++**, is released.
34 | - [2024/10/13] 🚀 [**PYPI package**](https://pypi.org/project/avdeepfake1m/) is released.
35 | - [2024/07/15] 🔥 [**AV-Deepfake1M** paper](https://dl.acm.org/doi/abs/10.1145/3664647.3680795) is accepted at ACM MM 2024.
36 | - [2024/03/09] 🏆 [**2024 1M-Deepfakes Challenge**](https://deepfakes1m.github.io/2024) starts.
37 |
38 | ## Abstract
39 | The detection and localization of highly realistic deepfake audio-visual content are challenging even for the most
40 | advanced state-of-the-art methods. While most of the research efforts in this domain are focused on detecting
41 | high-quality deepfake images and videos, only a few works address the problem of the localization of small segments of
42 | audio-visual manipulations embedded in real videos. In this research, we emulate the process of such content generation
43 | and propose the AV-Deepfake1M dataset. The dataset contains content-driven (i) video manipulations,
44 | (ii) audio manipulations, and (iii) audio-visual manipulations for more than 2K subjects resulting in a total of more
45 | than 1M videos. The paper provides a thorough description of the proposed data generation pipeline accompanied by a
46 | rigorous analysis of the quality of the generated data. The comprehensive benchmark of the proposed dataset utilizing
47 | state-of-the-art deepfake detection and localization methods indicates a significant drop in performance compared to
48 | previous datasets. The proposed dataset will play a vital role in building the next-generation deepfake localization
49 | methods.
50 |
51 | https://github.com/user-attachments/assets/d91aee8a-0fb5-4dff-ba20-86420332fed5
52 |
53 |
54 | ## Dataset
55 |
56 | ### Download
57 |
58 | We're hosting the [1M-Deepfakes Detection Challenge](https://deepfakes1m.github.io/2024) at ACM MM 2024.
59 |
60 | ### Baseline Benchmark
61 |
62 | | Method | AP@0.5 | AP@0.75 | AP@0.9 | AP@0.95 | AR@50 | AR@20 | AR@10 | AR@5 |
63 | |----------------------------|--------|---------|--------|---------|-------|-------|-------|-------|
64 | | PyAnnote | 00.03 | 00.00 | 00.00 | 00.00 | 00.67 | 00.67 | 00.67 | 00.67 |
65 | | Meso4 | 09.86 | 06.05 | 02.22 | 00.59 | 38.92 | 38.81 | 36.47 | 26.91 |
66 | | MesoInception4 | 08.50 | 05.16 | 01.89 | 00.50 | 39.27 | 39.00 | 35.78 | 24.59 |
67 | | EfficientViT | 14.71 | 02.42 | 00.13 | 00.01 | 27.04 | 26.43 | 23.90 | 20.31 |
68 | | TriDet + VideoMAEv2 | 21.67 | 05.83 | 00.54 | 00.06 | 20.27 | 20.12 | 19.50 | 18.18 |
69 | | TriDet + InternVideo | 29.66 | 09.02 | 00.79 | 00.09 | 24.08 | 23.96 | 23.50 | 22.55 |
70 | | ActionFormer + VideoMAEv2 | 20.24 | 05.73 | 00.57 | 00.07 | 19.97 | 19.81 | 19.11 | 17.80 |
71 | | ActionFormer + InternVideo | 36.08 | 12.01 | 01.23 | 00.16 | 27.11 | 27.00 | 26.60 | 25.80 |
72 | | BA-TFD | 37.37 | 06.34 | 00.19 | 00.02 | 45.55 | 35.95 | 30.66 | 26.82 |
73 | | BA-TFD+ | 44.42 | 13.64 | 00.48 | 00.03 | 48.86 | 40.37 | 34.67 | 29.88 |
74 | | UMMAFormer | 51.64 | 28.07 | 07.65 | 01.58 | 44.07 | 43.45 | 42.09 | 40.27 |
75 |
76 |
77 | ### Metadata Structure
78 |
79 | The metadata for each subset (train, val) is a JSON file containing a list of dictionaries, one per video. The fields in each dictionary are as follows; a minimal illustrative entry is sketched after the list.
80 | - file: the path to the video file.
81 | - original: if the current video is fake, the path to the original video; otherwise, the original path in VoxCeleb2.
82 | - split: the name of the current subset.
83 | - modify_type: the type of modifications in different modalities, which can be ["real", "visual_modified", "audio_modified", "both_modified"]. We evaluate the deepfake detection performance based on this field.
84 | - audio_model: the audio generation model used for generating this video.
85 | - fake_segments: the timestamps of the fake segments. We evaluate the temporal localization performance based on this field.
86 | - audio_fake_segments: the timestamps of the fake segments in audio modality.
87 | - visual_fake_segments: the timestamps of the fake segments in visual modality.
88 | - video_frames: the number of frames in the video.
89 | - audio_frames: the number of frames in the audio.
90 |
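For a quick check of the structure, a single (purely illustrative) entry might look like the sketch below. All paths and values are placeholders, and the real metadata is stored as JSON rather than Python syntax.

```python
# Hypothetical metadata entry, shown as a Python dict for illustration only.
example_entry = {
    "file": "val/000001.mp4",                 # path to the video file
    "original": "val/000001_original.mp4",    # placeholder path to the source video
    "split": "val",
    "modify_type": "both_modified",           # real / visual_modified / audio_modified / both_modified
    "audio_model": "tts_model_name",          # placeholder generator name
    "fake_segments": [[1.2, 2.0]],            # [start, end] timestamps in seconds
    "audio_fake_segments": [[1.2, 2.0]],
    "visual_fake_segments": [[1.2, 2.0]],
    "video_frames": 250,
    "audio_frames": 160000,
}
```
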
91 | ## SDK
92 |
93 | We provide a Python library `avdeepfake1m` for loading the dataset and running the evaluation.
94 |
95 | ### Installation
96 |
97 | ```bash
98 | pip install avdeepfake1m
99 | ```
100 |
101 | ### Usage
102 |
103 | Prepare the dataset as follows.
104 |
105 | ```
106 | |- train_metadata.json
107 | |- train_metadata
108 | | |- ...
109 | |- train
110 | | |- ...
111 | |- val_metadata.json
112 | |- val_metadata
113 | | |- ...
114 | |- val
115 | | |- ...
116 | |- test_files.txt
117 | |- test
118 | ```
119 |
120 | Load the dataset.
121 |
122 | ```python
123 | from avdeepfake1m.loader import AVDeepfake1mDataModule
124 |
125 | # access the dataset through a Lightning DataModule
126 | dm = AVDeepfake1mDataModule("/path/to/dataset")
127 | ```
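
When training with a Lightning `Trainer`, the data module can be passed via `trainer.fit(model, datamodule=dm)` as usual. Outside of a `Trainer`, it can also be used directly to obtain dataloaders. The following is a minimal sketch that assumes `AVDeepfake1mDataModule` follows the standard `LightningDataModule` interface (`setup` / `train_dataloader`); the exact batch contents depend on the loader configuration.

```python
from avdeepfake1m.loader import AVDeepfake1mDataModule

dm = AVDeepfake1mDataModule("/path/to/dataset")
dm.setup("fit")                       # build the train/val datasets
train_loader = dm.train_dataloader()  # a regular PyTorch DataLoader

for batch in train_loader:
    # inspect the first batch; its structure (video, audio, labels, ...)
    # depends on how the data module is configured
    print(type(batch))
    break
```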
128 |
129 | Evaluate the predictions. First, prepare the predictions as described in the [details](https://deepfakes1m.github.io/2024/details), then run the following code.
130 |
131 | ```python
132 | from avdeepfake1m.evaluation import ap_ar_1d, auc
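# NOTE (assumed argument roles): the first two arguments take the paths to the
# prediction file and the metadata file (left as empty placeholders below); the
# remaining arguments select the metadata keys ("file", "fake_segments") and the
# AP/AR thresholds used for scoring.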
133 | print(ap_ar_1d("", "", "file", "fake_segments", 1, [0.5, 0.75, 0.9, 0.95], [50, 30, 20, 10, 5], [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]))
134 | print(auc("", "", "file", "fake_segments"))
135 | ```
136 |
137 | ## License
138 |
139 | The dataset is under the [EULA](eula.pdf). You need to agree to and sign the EULA to access the dataset.
140 |
141 | The baseline Xception code [/examples/xception](/examples/xception) is under the MIT License. The BA-TFD/BA-TFD+ code [/examples/batfd](/examples/batfd) is from [ControlNet/LAV-DF](https://github.com/ControlNet/LAV-DF) under the CC BY-NC 4.0 License.
142 |
143 | The other parts of this project are under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
144 |
145 | ## References
146 |
147 | If you find this work useful in your research, please cite it.
148 |
149 | ```bibtex
150 | @inproceedings{cai2024av,
151 | title={AV-Deepfake1M: A large-scale LLM-driven audio-visual deepfake dataset},
152 | author={Cai, Zhixi and Ghosh, Shreya and Adatia, Aman Pankaj and Hayat, Munawar and Dhall, Abhinav and Gedeon, Tom and Stefanov, Kalin},
153 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia},
154 | pages={7414--7423},
155 | year={2024},
156 | doi={10.1145/3664647.3680795}
157 | }
158 | ```
159 |
160 | The challenge summary paper:
161 | ```bibtex
162 | @inproceedings{cai20241m,
163 | title={1M-Deepfakes Detection Challenge},
164 | author={Cai, Zhixi and Dhall, Abhinav and Ghosh, Shreya and Hayat, Munawar and Kollias, Dimitrios and Stefanov, Kalin and Tariq, Usman},
165 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia},
166 | pages={11355--11359},
167 | year={2024},
168 | doi={10.1145/3664647.3689145}
169 | }
170 | ```
171 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/assets/teaser.png
--------------------------------------------------------------------------------
/eula.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/eula.pdf
--------------------------------------------------------------------------------
/examples/batfd/README.md:
--------------------------------------------------------------------------------
1 | # BA-TFD
2 |
3 | This example trains the BA-TFD or BA-TFD+ model on the AVDeepfake1M/AVDeepfake1M++ dataset for temporal deepfake localization.
4 | ## Requirements
5 |
6 | Ensure you have the necessary environment setup. You can create a Conda environment using the following commands:
7 |
8 | ```bash
9 | # prepare the environment
10 | conda create -n batfd python=3.10 -y
11 | conda activate batfd
12 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia -y
13 | pip install avdeepfake1m toml tensorboard pytorch-lightning pandas "av<14"
14 | ```
15 |
16 | ## Training
17 |
18 | Train the BATFD or BATFD+ model using a TOML configuration file (e.g., `batfd.toml` or `batfd_plus.toml`).
19 |
20 | ```bash
21 | python train.py --config ./batfd.toml --data_root /path/to/AV-Deepfake1M-PlusPlus
22 | ```
23 |
24 | ### Output
25 |
26 | * **Checkpoints:** Model checkpoints are saved under `./ckpt/` in a subdirectory named after the configured model name (e.g. `./ckpt/batfd/` for the default config). The last checkpoint is saved as `last.ckpt`.
27 | * **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`).
28 |
29 | ## Inference
30 |
31 | After training, generate predictions on a dataset subset (e.g., `val`, `test`) using `infer.py`. This script saves the predictions to a JSON file, which is required for evaluation.
32 |
33 | ```bash
34 | python infer.py --config ./batfd.toml --checkpoint /path/to/checkpoint --data_root /path/to/AV-Deepfake1M-PlusPlus --subset val
35 | ```
36 |
37 | ## Evaluation
38 |
39 | ```bash
40 | python evaluate.py /path/to/prediction_file /path/to/metadata_file
41 | ```
42 |
43 |
--------------------------------------------------------------------------------
/examples/batfd/batfd.toml:
--------------------------------------------------------------------------------
1 | name = "batfd"
2 | num_frames = 100 # T
3 | max_duration = 30 # D
4 | model_type = "batfd"
5 | dataset = "avdeepfake1m++"
6 |
7 | [model.video_encoder]
8 | type = "c3d"
9 | hidden_dims = [64, 96, 128, 128]
10 | cla_feature_in = 256 # C_f
11 |
12 | [model.audio_encoder]
13 | type = "cnn"
14 | hidden_dims = [32, 64, 64]
15 | cla_feature_in = 256 # C_f
16 |
17 | [model.frame_classifier]
18 | type = "lr"
19 |
20 | [model.boundary_module]
21 | hidden_dims = [512, 128]
22 | samples = 10 # N
23 |
24 | [optimizer]
25 | learning_rate = 0.00001
26 | frame_loss_weight = 2.0
27 | modal_bm_loss_weight = 1.0
28 | contrastive_loss_weight = 0.1
29 | contrastive_loss_margin = 0.99
30 | weight_decay = 0.0001
31 |
32 | [soft_nms]
33 | alpha = 0.7234
34 | t1 = 0.1968
35 | t2 = 0.4123
--------------------------------------------------------------------------------
/examples/batfd/batfd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/examples/batfd/batfd/__init__.py
--------------------------------------------------------------------------------
/examples/batfd/batfd/inference.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from typing import Any, List, Optional
3 | import torch
4 | from torch import Tensor
5 | import pandas as pd
6 | from pathlib import Path
7 | from lightning.pytorch import LightningModule, Trainer, Callback
8 | from torch.utils.data import DataLoader
9 |
10 | from avdeepfake1m.loader import Metadata
11 |
12 |
13 | def nullable_index(obj, index):
14 | if obj is None:
15 | return None
16 | return obj[index]
17 |
18 |
19 | class SaveToCsvCallback(Callback):
20 |
21 | def __init__(self, max_duration: int, metadata: List[Metadata], model_name: str, model_type: str, temp_dir: str):
22 | super().__init__()
23 | self.max_duration = max_duration
24 | self.metadata = metadata
25 | self.model_name = model_name
26 | self.model_type = model_type
27 | self.temp_dir = temp_dir
28 |
29 | def on_predict_batch_end(
30 | self,
31 | trainer: Trainer,
32 | pl_module: LightningModule,
33 | outputs: Any,
34 | batch: Any,
35 | batch_idx: int,
36 | ) -> None:
37 | if self.model_type == "batfd":
38 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla = outputs
39 | batch_size = fusion_bm_map.shape[0]
40 |
41 | for i in range(batch_size):
42 | temporal_size = torch.tensor(100) # the first value of `Batfd.get_meta_attr`
43 | video_name = self.metadata[batch_idx * batch_size + i].file
44 | n_frames = self.metadata[batch_idx * batch_size + i].video_frames
45 | # if n_frames is not available, the video is likely from the test set, so we read it from the batch
46 | if n_frames == -1:
47 | n_frames = batch[-1][i].cpu().numpy().item()
48 |
49 | assert isinstance(video_name, str)
50 | self.gen_df_for_batfd(fusion_bm_map[i], temporal_size, n_frames, os.path.join(
51 | self.temp_dir, self.model_name, video_name.replace("/", "_").replace(".mp4", ".csv")
52 | ))
53 |
54 | elif self.model_type == "batfd_plus":
55 | fusion_bm_map, fusion_start, fusion_end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end, v_frame_cla, a_frame_cla = outputs
56 | batch_size = fusion_bm_map.shape[0]
57 |
58 | for i in range(batch_size):
59 | temporal_size = torch.tensor(100) # the first value of `BatfdPlus.get_meta_attr`
60 | video_name = self.metadata[batch_idx * batch_size + i].file
61 | n_frames = self.metadata[batch_idx * batch_size + i].video_frames
62 | # if n_frames is not available, the video is likely from the test set, so we read it from the batch
63 | if n_frames == -1:
64 | n_frames = batch[-1][i].cpu().numpy().item()
65 | assert isinstance(video_name, str)
66 |
67 | self.gen_df_for_batfd_plus(fusion_bm_map[i], nullable_index(fusion_start, i),
68 | nullable_index(fusion_end, i), temporal_size,
69 | n_frames, os.path.join(self.temp_dir, self.model_name,
70 | video_name.replace("/", "_").replace(".mp4", ".csv")
71 | ))
72 |
73 | else:
74 | raise ValueError("Invalid model type")
75 |
76 | def gen_df_for_batfd(self, bm_map: Tensor, temporal_size: Tensor, n_frames: int, output_file: str):
77 | bm_map = bm_map.cpu().numpy()
78 | temporal_size = temporal_size.cpu().numpy().item()
79 | # for each boundary proposal in boundary map
80 | df = pd.DataFrame(bm_map)
81 | df = df.stack().reset_index()
82 | df.columns = ["duration", "begin", "score"]
83 | df["end"] = df.duration + df.begin
84 | df = df[(df.duration > 0) & (df.end <= temporal_size)]
85 | df = df.sort_values(["begin", "end"])
86 | df = df.reset_index()[["begin", "end", "score"]]
87 | df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int)
88 | df["end"] = (df["end"] / temporal_size * n_frames).astype(int)
89 | df = df.sort_values(["score"], ascending=False).iloc[:100]
90 | df.to_csv(output_file, index=False)
91 |
92 | def gen_df_for_batfd_plus(self, bm_map: Tensor, start: Optional[Tensor], end: Optional[Tensor],
93 | temporal_size: Tensor, n_frames: int, output_file: str
94 | ):
95 | bm_map = bm_map.cpu().numpy()
96 | temporal_size = temporal_size.cpu().numpy().item()
97 | if start is not None and end is not None:
98 | start = start.cpu().numpy()
99 | end = end.cpu().numpy()
100 |
101 | # for each boundary proposal in boundary map
102 | df = pd.DataFrame(bm_map)
103 | df = df.stack().reset_index()
104 | df.columns = ["duration", "begin", "score"]
105 | df["end"] = df.duration + df.begin
106 | df = df[(df.duration > 0) & (df.end <= temporal_size)]
107 | df = df.sort_values(["begin", "end"])
108 | df = df.reset_index()[["begin", "end", "score"]]
109 | if start is not None and end is not None:
110 | df["score"] = df["score"] * start[df.begin] * end[df.end]
111 |
112 | df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int)
113 | df["end"] = (df["end"] / temporal_size * n_frames).astype(int)
114 | df = df.sort_values(["score"], ascending=False).iloc[:100]
115 | df.to_csv(output_file, index=False)
116 |
117 |
118 | def inference_model(model_name: str, model: LightningModule, dataloader: DataLoader,
119 | metadata: List[Metadata],
120 | max_duration: int, model_type: str,
121 | gpus: int = 1,
122 | temp_dir: str = "output/"
123 | ) -> None:
124 | Path(os.path.join(temp_dir, model_name)).mkdir(parents=True, exist_ok=True)
125 |
126 | model.eval()
127 |
128 | trainer = Trainer(logger=False,
129 | enable_checkpointing=False, devices=1 if gpus > 1 else "auto",
130 | accelerator="auto" if gpus > 0 else "cpu",
131 | callbacks=[SaveToCsvCallback(max_duration, metadata, model_name, model_type, temp_dir)]
132 | )
133 |
134 | trainer.predict(model, dataloader)
135 |
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .batfd import Batfd
2 | from .batfd_plus import BatfdPlus
3 |
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/audio_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from einops import rearrange
4 | from einops.layers.torch import Rearrange
5 | from torch import Tensor
6 | from torch.nn import Module, Sequential, LeakyReLU, MaxPool2d, Linear
7 | from torchvision.models.vision_transformer import Encoder as ViTEncoder
8 |
9 | from ..utils import Conv2d
10 |
11 |
12 | class CNNAudioEncoder(Module):
13 | """
14 | Audio encoder (E_a): Process log mel spectrogram to extract features.
15 | Input:
16 | A': (B, F_m, T_a)
17 | Output:
18 | E_a: (B, C_f, T)
19 | """
20 |
21 | def __init__(self, n_features=(32, 64, 64)):
22 | super().__init__()
23 |
24 | n_dim0, n_dim1, n_dim2 = n_features
25 |
26 | # (B, 64, 2048) -> (B, 1, 64, 2048) -> (B, 32, 32, 1024)
27 | self.block0 = Sequential(
28 | Rearrange("b c t -> b 1 c t"),
29 | Conv2d(1, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
30 | MaxPool2d(2)
31 | )
32 |
33 | # (B, 32, 32, 1024) -> (B, 64, 16, 512)
34 | self.block1 = Sequential(
35 | Conv2d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
36 | Conv2d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
37 | MaxPool2d(2)
38 | )
39 |
40 | # (B, 64, 16, 512) -> (B, 64, 4, 512) -> (B, 256, 512)
41 | self.block2 = Sequential(
42 | Conv2d(n_dim1, n_dim2, kernel_size=(2, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU),
43 | MaxPool2d((2, 1)),
44 | Conv2d(n_dim2, n_dim2, kernel_size=(3, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU),
45 | MaxPool2d((2, 1)),
46 | Rearrange("b f c t -> b (f c) t")
47 | )
48 |
49 | def forward(self, audio: Tensor) -> Tensor:
50 | x = self.block0(audio)
51 | x = self.block1(x)
52 | x = self.block2(x)
53 | return x
54 |
55 |
56 | class SelfAttentionAudioEncoder(Module):
57 |
58 | def __init__(self, block_type: Literal["vit_t", "vit_s", "vit_b"], a_cla_feature_in: int = 256, temporal_size: int = 512):
59 | super().__init__()
60 | # The ViT configurations are from:
61 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
62 | if block_type == "vit_t":
63 | self.n_features = 192
64 | self.block = ViTEncoder(
65 | seq_length=temporal_size,
66 | num_layers=12,
67 | num_heads=3,
68 | hidden_dim=self.n_features,
69 | mlp_dim=self.n_features * 4,
70 | dropout=0.,
71 | attention_dropout=0.
72 | )
73 | elif block_type == "vit_s":
74 | self.n_features = 384
75 | self.block = ViTEncoder(
76 | seq_length=temporal_size,
77 | num_layers=12,
78 | num_heads=6,
79 | hidden_dim=self.n_features,
80 | mlp_dim=self.n_features * 4,
81 | dropout=0.,
82 | attention_dropout=0.
83 | )
84 | elif block_type == "vit_b":
85 | self.n_features = 768
86 | self.block = ViTEncoder(
87 | seq_length=temporal_size,
88 | num_layers=12,
89 | num_heads=12,
90 | hidden_dim=self.n_features,
91 | mlp_dim=self.n_features * 4,
92 | dropout=0.,
93 | attention_dropout=0.
94 | )
95 | else:
96 | raise ValueError(f"Unknown block type: {block_type}")
97 |
98 | self.input_proj = Conv2d(1, self.n_features, kernel_size=(64, 4), stride=(64, 4))
99 | self.output_proj = Linear(self.n_features, a_cla_feature_in)
100 |
101 | def forward(self, audio: Tensor) -> Tensor:
102 | x = audio.unsqueeze(1) # (B, 64, 2048) -> (B, 1, 64, 2048)
103 | x = self.input_proj(x) # (B, 1, 64, 2048) -> (B, feat, 1, 512)
104 | x = rearrange(x, "b f 1 t -> b t f") # (B, feat, 1, 512) -> (B, 512, feat)
105 | x = self.block(x)
106 | x = self.output_proj(x) # (B, 512, feat) -> (B, 512, 256)
107 | x = x.permute(0, 2, 1) # (B, 512, 256) -> (B, 256, 512)
108 | return x
109 |
110 |
111 | class AudioFeatureProjection(Module):
112 |
113 | def __init__(self, input_feature_dim: int, a_cla_feature_in: int = 256):
114 | super().__init__()
115 | self.proj = Linear(input_feature_dim, a_cla_feature_in)
116 |
117 | def forward(self, x: Tensor) -> Tensor:
118 | x = self.proj(x)
119 | return x.permute(0, 2, 1)
120 |
121 |
122 | def get_audio_encoder(a_cla_feature_in, temporal_size, a_encoder, ae_features):
123 | if a_encoder == "cnn":
124 | audio_encoder = CNNAudioEncoder(n_features=ae_features)
125 | elif a_encoder == "vit_t":
126 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_t", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size)
127 | elif a_encoder == "vit_s":
128 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_s", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size)
129 | elif a_encoder == "vit_b":
130 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_b", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size)
131 | elif a_encoder == "wav2vec2":
132 | audio_encoder = AudioFeatureProjection(input_feature_dim=1536, a_cla_feature_in=a_cla_feature_in)
133 | elif a_encoder == "trillsson3":
134 | audio_encoder = AudioFeatureProjection(input_feature_dim=1280, a_cla_feature_in=a_cla_feature_in)
135 | else:
136 | raise ValueError(f"Invalid audio encoder: {a_encoder}")
137 | return audio_encoder
138 |
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/batfd.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union, Sequence, Tuple
2 |
3 | import torch
4 | from lightning.pytorch import LightningModule
5 | from torch import Tensor
6 | from torch.nn import BCEWithLogitsLoss, MSELoss, functional as F
7 | from torch.optim import Adam
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 | from avdeepfake1m.loader import Metadata
10 |
11 | from .loss import ContrastLoss
12 | from .audio_encoder import get_audio_encoder
13 | from .boundary_module import BoundaryModule
14 | from .frame_classifier import FrameLogisticRegression
15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion
16 | from .video_encoder import get_video_encoder
17 |
18 |
19 | class Batfd(LightningModule):
20 |
21 | def __init__(self,
22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr",
23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256,
24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40,
25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99,
26 | weight_decay=0.0001, learning_rate=0.0002, distributed=False
27 | ):
28 | super().__init__()
29 | self.save_hyperparameters()
30 | self.cla_feature_in = v_cla_feature_in
31 | self.temporal_dim = temporal_dim
32 |
33 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features)
34 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, ae_features)
35 |
36 | if frame_classifier == "lr":
37 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in)
38 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in)
39 |
40 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier
41 |
42 | assert v_cla_feature_in == a_cla_feature_in
43 |
44 | v_bm_in = v_cla_feature_in + 1
45 | a_bm_in = a_cla_feature_in + 1
46 |
47 | self.video_boundary_module = BoundaryModule(v_bm_in, boundary_features, boundary_samples, temporal_dim,
48 | max_duration
49 | )
50 | self.audio_boundary_module = BoundaryModule(a_bm_in, boundary_features, boundary_samples, temporal_dim,
51 | max_duration
52 | )
53 |
54 | self.fusion = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration)
55 |
56 | self.frame_loss = BCEWithLogitsLoss()
57 | self.contrast_loss = ContrastLoss(margin=contrast_loss_margin)
58 | self.bm_loss = MSELoss()
59 | self.weight_frame_loss = weight_frame_loss
60 | self.weight_modal_bm_loss = weight_modal_bm_loss
61 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim)
62 | self.weight_decay = weight_decay
63 | self.learning_rate = learning_rate
64 | self.distributed = distributed
65 |
66 | def forward(self, video: Tensor, audio: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
67 | # encoders
68 | v_features = self.video_encoder(video)
69 | a_features = self.audio_encoder(audio)
70 |
71 | # frame classifiers
72 | v_frame_cla = self.video_frame_classifier(v_features)
73 | a_frame_cla = self.audio_frame_classifier(a_features)
74 |
75 | # concat classification result to features
76 | v_bm_in = torch.column_stack([v_features, v_frame_cla])
77 | a_bm_in = torch.column_stack([a_features, a_frame_cla])
78 |
79 | # modal boundary module
80 | v_bm_map = self.video_boundary_module(v_bm_in)
81 | a_bm_map = self.audio_boundary_module(a_bm_in)
82 |
83 | # boundary map modal attention fusion
84 | fusion_bm_map = self.fusion(v_bm_in, a_bm_in, v_bm_map, a_bm_map)
85 |
86 | return fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features
87 |
88 | def loss_fn(self, fusion_bm_map: Tensor, v_bm_map: Tensor, a_bm_map: Tensor,
89 | v_frame_cla: Tensor, a_frame_cla: Tensor, label: Tensor, n_frames: Tensor,
90 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features
91 | ) -> Dict[str, Tensor]:
92 | fusion_bm_loss = self.bm_loss(fusion_bm_map, label)
93 |
94 | v_bm_loss = self.bm_loss(v_bm_map, v_bm_label)
95 | a_bm_loss = self.bm_loss(a_bm_map, a_bm_label)
96 |
97 | v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label)
98 | a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label)
99 |
100 | contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label)
101 | / (self.cla_feature_in * self.temporal_dim), max=1.)
102 |
103 | loss = fusion_bm_loss + \
104 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \
105 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \
106 | self.weight_contrastive_loss * contrast_loss
107 |
108 | return {
109 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss,
110 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss
111 | }
112 |
113 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
114 | ) -> Tensor:
115 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch
116 |
117 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio)
118 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames,
119 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features
120 | )
121 |
122 | self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True,
123 | prog_bar=False, sync_dist=self.distributed)
124 | return loss_dict["loss"]
125 |
126 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
127 | ) -> Tensor:
128 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch
129 |
130 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio)
131 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames,
132 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features
133 | )
134 |
135 | self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True,
136 | prog_bar=False, sync_dist=self.distributed)
137 | return loss_dict["loss"]
138 |
139 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None
140 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
141 | video, audio, *_ = batch
142 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio)
143 | return fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla
144 |
145 | def configure_optimizers(self):
146 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay)
147 | return {
148 | "optimizer": optimizer,
149 | "lr_scheduler": {
150 | "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=3, verbose=True, min_lr=1e-8),
151 | "monitor": "val_loss"
152 | }
153 | }
154 |
155 |
156 | @staticmethod
157 | def get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tuple[Tensor, Optional[Tensor], Optional[Tensor]]):
158 | label, visual_label, audio_label = label
159 | label_real = torch.zeros(label.size(), dtype=label.dtype, device=label.device)
160 |
161 | if visual_label is not None:
162 | v_bm_label = visual_label
163 | elif meta.modify_video:
164 | v_bm_label = label
165 | else:
166 | v_bm_label = label_real
167 |
168 | if audio_label is not None:
169 | a_bm_label = audio_label
170 | elif meta.modify_audio:
171 | a_bm_label = label
172 | else:
173 | a_bm_label = label_real
174 |
175 | frame_label_real = torch.zeros(meta.video_frames)
176 | frame_label_fake = torch.zeros(meta.video_frames)
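# convert the fake periods from seconds to frame indices, assuming 25 fps video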
177 | for begin, end in meta.fake_periods:
178 | begin = int(begin * 25)
179 | end = int(end * 25)
180 | frame_label_fake[begin: end] = 1
181 |
182 | if visual_label is not None:
183 | v_frame_label = torch.zeros(meta.video_frames)
184 | for begin, end in meta.visual_fake_periods:
185 | begin = int(begin * 25)
186 | end = int(end * 25)
187 | v_frame_label[begin: end] = 1
188 | elif meta.modify_video:
189 | v_frame_label = frame_label_fake
190 | else:
191 | v_frame_label = frame_label_real
192 |
193 | v_frame_label = F.interpolate(v_frame_label[None, None], (100,), mode="linear")[0, 0]
194 |
195 | if audio_label is not None:
196 | a_frame_label = torch.zeros(meta.video_frames)
197 | for begin, end in meta.audio_fake_periods:
198 | begin = int(begin * 25)
199 | end = int(end * 25)
200 | a_frame_label[begin: end] = 1
201 | elif meta.modify_audio:
202 | a_frame_label = frame_label_fake
203 | else:
204 | a_frame_label = frame_label_real
205 |
206 | a_frame_label = F.interpolate(a_frame_label[None, None], (100,), mode="linear")[0, 0]
207 |
208 | contrast_label = 0 if meta.modify_audio or meta.modify_video else 1
209 |
210 | return [100, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label]
211 |
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/batfd_plus.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union, Sequence, Tuple
2 |
3 | import torch
4 | from lightning.pytorch import LightningModule
5 | from torch import Tensor
6 | from torch.nn import BCEWithLogitsLoss, functional as F
7 | from torch.optim import Adam
8 | from torch.optim.lr_scheduler import ExponentialLR
9 | from avdeepfake1m.loader import Metadata
10 |
11 | from .loss import ContrastLoss, BsnppLoss
12 | from .audio_encoder import get_audio_encoder
13 | from .boundary_module_plus import BoundaryModulePlus, NestedUNet
14 | from .frame_classifier import FrameLogisticRegression
15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion, ModalFeatureAttnCfgFusion
16 | from .video_encoder import get_video_encoder
17 |
18 |
19 | class BatfdPlus(LightningModule):
20 |
21 | def __init__(self,
22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr",
23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256,
24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40,
25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99,
26 | cbg_feature_weight=0.01, prb_weight_forward=1.,
27 | weight_decay=0.0001, learning_rate=0.0002, distributed=False
28 | ):
29 | super().__init__()
30 | self.save_hyperparameters()
31 |
32 | self.cla_feature_in = v_cla_feature_in
33 | self.temporal_dim = temporal_dim
34 |
35 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features)
36 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, ae_features)
37 |
38 | if frame_classifier == "lr":
39 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in)
40 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in)
41 |
42 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier
43 |
44 | assert v_cla_feature_in == a_cla_feature_in
45 |
46 | v_bm_in = v_cla_feature_in + 1
47 | a_bm_in = a_cla_feature_in + 1
48 |
49 | # Proposal Relation Block in BSN++ mechanism
50 | self.video_boundary_module = BoundaryModulePlus(v_bm_in, boundary_features, boundary_samples, temporal_dim,
51 | max_duration
52 | )
53 | self.audio_boundary_module = BoundaryModulePlus(a_bm_in, boundary_features, boundary_samples, temporal_dim,
54 | max_duration
55 | )
56 |
57 | if cbg_feature_weight > 0:
58 | # Complementary Boundary Generator in BSN++ mechanism
59 | self.video_comp_boundary_generator = NestedUNet(in_ch=v_bm_in, out_ch=2)
60 | self.audio_comp_boundary_generator = NestedUNet(in_ch=a_bm_in, out_ch=2)
61 | self.cbg_fusion_start = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in)
62 | self.cbg_fusion_end = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in)
63 | else:
64 | self.video_comp_boundary_generator = None
65 | self.audio_comp_boundary_generator = None
66 | self.cbg_fusion_start = None
67 | self.cbg_fusion_end = None
68 |
69 | self.prb_fusion_p = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration)
70 | self.prb_fusion_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration)
71 | self.prb_fusion_p_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration)
72 |
73 | self.frame_loss = BCEWithLogitsLoss()
74 | self.contrast_loss = ContrastLoss(margin=contrast_loss_margin)
75 | self.bm_loss = BsnppLoss(cbg_feature_weight, prb_weight_forward)
76 | self.weight_frame_loss = weight_frame_loss
77 | self.weight_modal_bm_loss = weight_modal_bm_loss
78 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim)
79 | self.weight_decay = weight_decay
80 | self.learning_rate = learning_rate
81 | self.distributed = distributed
82 |
83 | def forward(self, video: Tensor, audio: Tensor) -> Sequence[Tensor]:
84 | a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla = self.forward_features(audio, video)
85 |
86 | # modal boundary module
87 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c = self.video_boundary_module(v_bm_in)
88 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c = self.audio_boundary_module(a_bm_in)
89 |
90 | # complementary boundary generator
91 | if self.cbg_fusion_start is not None:
92 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in)
93 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in)
94 | else:
95 | v_cbg_feature, v_cbg_start, v_cbg_end = None, None, None
96 | a_cbg_feature, a_cbg_start, a_cbg_end = None, None, None
97 |
98 | # boundary map modal attention fusion
99 | fusion_bm_map_p = self.prb_fusion_p(v_bm_in, a_bm_in, v_bm_map_p, a_bm_map_p)
100 | fusion_bm_map_c = self.prb_fusion_c(v_bm_in, a_bm_in, v_bm_map_c, a_bm_map_c)
101 | fusion_bm_map_p_c = self.prb_fusion_p_c(v_bm_in, a_bm_in, v_bm_map_p_c, a_bm_map_p_c)
102 |
103 | # complementary boundary generator modal attention fusion
104 | if self.cbg_fusion_start is not None:
105 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start)
106 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end)
107 | else:
108 | fusion_cbg_start = None
109 | fusion_cbg_end = None
110 |
111 | return (
112 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end,
113 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end,
114 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end,
115 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature
116 | )
117 |
118 | def forward_back(self, video: Tensor, audio: Tensor) -> Sequence[Optional[Tensor]]:
119 | if self.cbg_fusion_start is not None:
120 | a_bm_in, _, _, v_bm_in, _, _ = self.forward_features(audio, video)
121 |
122 | # complementary boundary generator
123 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in)
124 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in)
125 |
126 | # complementary boundary generator modal attention fusion
127 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start)
128 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end)
129 |
130 | return (
131 | fusion_cbg_start, fusion_cbg_end, v_cbg_start, v_cbg_end, a_cbg_start, a_cbg_end,
132 | v_cbg_feature, a_cbg_feature
133 | )
134 | else:
135 | return None, None, None, None, None, None, None, None
136 |
137 | def forward_features(self, audio, video):
138 | # encoders
139 | v_features = self.video_encoder(video)
140 | a_features = self.audio_encoder(audio)
141 | # frame classifiers
142 | v_frame_cla = self.video_frame_classifier(v_features)
143 | a_frame_cla = self.audio_frame_classifier(a_features)
144 | # concat classification result to features
145 | v_bm_in = torch.column_stack([v_features, v_frame_cla])
146 | a_bm_in = torch.column_stack([a_features, a_frame_cla])
147 | return a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla
148 |
149 | def forward_video_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
150 | cbg_prob, cbg_feature = self.video_comp_boundary_generator(feature)
151 | start = cbg_prob[:, 0, :].squeeze(1)
152 | end = cbg_prob[:, 1, :].squeeze(1)
153 | return cbg_feature, end, start
154 |
155 | def forward_audio_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
156 | cbg_prob, cbg_feature = self.audio_comp_boundary_generator(feature)
157 | start = cbg_prob[:, 0, :].squeeze(1)
158 | end = cbg_prob[:, 1, :].squeeze(1)
159 | return cbg_feature, end, start
160 |
161 | def loss_fn(self,
162 | fusion_bm_map_p: Tensor, fusion_bm_map_c: Tensor, fusion_bm_map_p_c: Tensor,
163 | fusion_cbg_start: Tensor, fusion_cbg_end: Tensor,
164 | fusion_cbg_start_back: Tensor, fusion_cbg_end_back: Tensor,
165 | v_bm_map_p: Tensor, v_bm_map_c: Tensor, v_bm_map_p_c: Tensor,
166 | v_cbg_start: Tensor, v_cbg_end: Tensor, v_cbg_feature: Tensor,
167 | v_cbg_start_back: Tensor, v_cbg_end_back: Tensor, v_cbg_feature_back: Tensor,
168 | a_bm_map_p: Tensor, a_bm_map_c: Tensor, a_bm_map_p_c: Tensor,
169 | a_cbg_start: Tensor, a_cbg_end: Tensor, a_cbg_feature: Tensor,
170 | a_cbg_start_back: Tensor, a_cbg_end_back: Tensor, a_cbg_feature_back: Tensor,
171 | v_frame_cla: Tensor, a_frame_cla: Tensor, n_frames: Tensor,
172 | fusion_bm_label: Tensor, fusion_start_label: Tensor, fusion_end_label: Tensor,
173 | v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label,
174 | v_frame_label, a_frame_label, contrast_label, v_features, a_features
175 | ) -> Dict[str, Tensor]:
176 | (
177 | fusion_bm_loss, fusion_cbg_loss, fusion_prb_loss, fusion_cbg_loss_forward, fusion_cbg_loss_backward, _
178 | ) = self.bm_loss(
179 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c,
180 | fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back,
181 | fusion_bm_label, fusion_start_label, fusion_end_label
182 | )
183 |
184 | (
185 | v_bm_loss, v_cbg_loss, v_prb_loss, v_cbg_loss_forward, v_cbg_loss_backward, v_cbg_feature_loss
186 | ) = self.bm_loss(
187 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c,
188 | v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back,
189 | v_bm_label, v_start_label, v_end_label,
190 | v_cbg_feature, v_cbg_feature_back
191 | )
192 |
193 | (
194 | a_bm_loss, a_cbg_loss, a_prb_loss, a_cbg_loss_forward, a_cbg_loss_backward, a_cbg_feature_loss
195 | ) = self.bm_loss(
196 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c,
197 | a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back,
198 | a_bm_label, a_start_label, a_end_label,
199 | a_cbg_feature, a_cbg_feature_back
200 | )
201 |
202 | v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label)
203 | a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label)
204 |
205 | contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label)
206 | / (self.cla_feature_in * self.temporal_dim), max=1.)
207 |
208 | loss = fusion_bm_loss + \
209 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \
210 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \
211 | self.weight_contrastive_loss * contrast_loss
212 |
213 | loss_dict = {
214 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss,
215 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss,
216 | "fusion_cbg_loss": fusion_cbg_loss, "v_cbg_loss": v_cbg_loss, "a_cbg_loss": a_cbg_loss,
217 | "fusion_prb_loss": fusion_prb_loss, "v_prb_loss": v_prb_loss, "a_prb_loss": a_prb_loss,
218 | "fusion_cbg_loss_forward": fusion_cbg_loss_forward, "v_cbg_loss_forward": v_cbg_loss_forward,
219 | "a_cbg_loss_forward": a_cbg_loss_forward, "fusion_cbg_loss_backward": fusion_cbg_loss_backward,
220 | "v_cbg_loss_backward": v_cbg_loss_backward, "a_cbg_loss_backward": a_cbg_loss_backward,
221 | "v_cbg_feature_loss": v_cbg_feature_loss, "a_cbg_feature_loss": a_cbg_feature_loss
222 | }
223 | return {k: v for k, v in loss_dict.items() if v is not None}
224 |
225 | def step(self, batch: Sequence[Tensor]) -> Dict[str, Tensor]:
226 | (
227 | video, audio, fusion_bm_label, fusion_start_label, fusion_end_label, n_frames,
228 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label,
229 | a_start_label, v_start_label, a_end_label, v_end_label
230 | ) = batch
231 | # forward
232 | (
233 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end,
234 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end,
235 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end,
236 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature
237 | ) = self(video, audio)
238 | # BSN++ back
239 | video_back = torch.flip(video, dims=(2,))
240 | audio_back = torch.flip(audio, dims=(2,))
241 | (
242 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back,
243 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back
244 | ) = self.forward_back(video_back, audio_back)
245 |
246 | # loss
247 | loss_dict = self.loss_fn(
248 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c,
249 | fusion_cbg_start, fusion_cbg_end,
250 | fusion_cbg_start_back, fusion_cbg_end_back,
251 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c,
252 | v_cbg_start, v_cbg_end, v_cbg_feature,
253 | v_cbg_start_back, v_cbg_end_back, v_cbg_feature_back,
254 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c,
255 | a_cbg_start, a_cbg_end, a_cbg_feature,
256 | a_cbg_start_back, a_cbg_end_back, a_cbg_feature_back,
257 | v_frame_cla, a_frame_cla, n_frames,
258 | fusion_bm_label, fusion_start_label, fusion_end_label,
259 | v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label,
260 | v_frame_label, a_frame_label, contrast_label, v_features, a_features
261 | )
262 | return loss_dict
263 |
264 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
265 | ) -> Tensor:
266 | loss_dict = self.step(batch)
267 |
268 | if torch.isnan(loss_dict["loss"]):
269 | print(f"NaN in loss in {self.global_step}")
270 | exit(1)
271 | return None
272 |
273 | self.log_dict({f"metrics/train_{k}": v for k, v in loss_dict.items() if k != "loss"}, on_step=True, on_epoch=True,
274 | prog_bar=False, sync_dist=self.distributed)
275 | # only log the loss to progress bar
276 | self.log("metrics/train_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True,
277 | sync_dist=self.distributed, rank_zero_only=True)
278 | return loss_dict["loss"]
279 |
280 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
281 | ) -> Tensor:
282 | loss_dict = self.step(batch)
283 |
284 | self.log_dict({f"metrics/val_{k}": v for k, v in loss_dict.items() if k != "loss"}, on_step=True, on_epoch=True,
285 | prog_bar=False, sync_dist=self.distributed)
286 | self.log("metrics/val_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True,
287 | sync_dist=self.distributed, rank_zero_only=True)
288 | return loss_dict["loss"]
289 |
290 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None
291 | ) -> Tuple[
292 | Tensor, Optional[Tensor], Optional[Tensor],
293 | Tensor, Optional[Tensor], Optional[Tensor],
294 | Tensor, Optional[Tensor], Optional[Tensor]
295 | ]:
296 | video, audio, *_ = batch
297 | # forward
298 | (
299 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end,
300 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end,
301 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end,
302 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature
303 | ) = self(video, audio)
304 | # BSN++ back
305 | video_back = torch.flip(video, dims=(2,))
306 | audio_back = torch.flip(audio, dims=(2,))
307 | (
308 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back,
309 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back
310 | ) = self.forward_back(video_back, audio_back)
311 |
312 | fusion_bm_map, start, end = self.post_process_predict(fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c,
313 | fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back
314 | )
315 |
316 | v_bm_map, v_start, v_end = self.post_process_predict(v_bm_map_p, v_bm_map_c, v_bm_map_p_c,
317 | v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back
318 | )
319 |
320 | a_bm_map, a_start, a_end = self.post_process_predict(a_bm_map_p, a_bm_map_c, a_bm_map_p_c,
321 | a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back
322 | )
323 |
324 | return fusion_bm_map, start, end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end, v_frame_cla, a_frame_cla
325 |
326 | def post_process_predict(self,
327 | bm_map_p: Tensor, bm_map_c: Tensor, bm_map_p_c: Tensor,
328 | cbg_start: Tensor, cbg_end: Tensor,
329 | cbg_start_back: Tensor, cbg_end_back: Tensor
330 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
331 |
332 | bm_map = (bm_map_p + bm_map_c + bm_map_p_c) / 3
333 | if self.cbg_fusion_start is not None:
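# BSN++-style fusion of the forward and time-reversed (backward) boundary probabilities
# via a geometric mean; flipping aligns the backward "end" with the forward "start"
# and vice versa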
334 | start = torch.sqrt(cbg_start * torch.flip(cbg_end_back, dims=(1,)))
335 | end = torch.sqrt(cbg_end * torch.flip(cbg_start_back, dims=(1,)))
336 | else:
337 | start = None
338 | end = None
339 |
340 | return bm_map, start, end
341 |
342 | def configure_optimizers(self):
343 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay)
344 | return {
345 | "optimizer": optimizer,
346 | "lr_scheduler": {
347 | # "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=5, verbose=True, min_lr=1e-8),
348 | "scheduler": ExponentialLR(optimizer, gamma=0.992),
349 | "monitor": "val_loss"
350 | }
351 | }
352 |
353 | @classmethod
354 | def gen_audio_video_labels(cls, label_fake: Tensor, meta: Metadata):
355 | label_real = torch.zeros(label_fake.size(), dtype=label_fake.dtype, device=label_fake.device)
356 | v_label = label_fake if meta.modify_video else label_real
357 | a_label = label_fake if meta.modify_audio else label_real
358 | return a_label, v_label
359 |
360 | @staticmethod
361 | def get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tuple[Tensor, Optional[Tensor], Optional[Tensor]]):
362 | label, visual_label, audio_label = label
363 | label_real = torch.zeros(label.size(), dtype=label.dtype, device=label.device)
364 | if visual_label is not None:
365 | v_bm_label = visual_label
366 | elif meta.modify_video:
367 | v_bm_label = label
368 | else:
369 | v_bm_label = label_real
370 |
371 | if audio_label is not None:
372 | a_bm_label = audio_label
373 | elif meta.modify_audio:
374 | a_bm_label = label
375 | else:
376 | a_bm_label = label_real
377 |
378 | frame_label_real = torch.zeros(meta.video_frames)
379 | frame_label_fake = torch.zeros(meta.video_frames)
380 | for begin, end in meta.fake_periods:
381 | begin = int(begin * 25)
382 | end = int(end * 25)
383 | frame_label_fake[begin: end] = 1
384 |
385 | if visual_label is not None:
386 | v_frame_label = torch.zeros(meta.video_frames)
387 | for begin, end in meta.visual_fake_periods:
388 | begin = int(begin * 25)
389 | end = int(end * 25)
390 | v_frame_label[begin: end] = 1
391 | elif meta.modify_video:
392 | v_frame_label = frame_label_fake
393 | else:
394 | v_frame_label = frame_label_real
395 |
396 | v_frame_label = F.interpolate(v_frame_label[None, None], (100,), mode="linear")[0, 0]
397 |
398 | if audio_label is not None:
399 | a_frame_label = torch.zeros(meta.video_frames)
400 | for begin, end in meta.audio_fake_periods:
401 | begin = int(begin * 25)
402 | end = int(end * 25)
403 | a_frame_label[begin: end] = 1
404 | elif meta.modify_audio:
405 | a_frame_label = frame_label_fake
406 | else:
407 | a_frame_label = frame_label_real
408 |
409 | a_frame_label = F.interpolate(a_frame_label[None, None], (100,), mode="linear")[0, 0]
410 |
411 | contrast_label = 0 if meta.modify_audio or meta.modify_video else 1
412 | return [100, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, 0, 0, 0, 0]
413 |
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/boundary_module.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from einops.layers.torch import Rearrange
5 | from torch import Tensor
6 | from torch.nn import Sequential, LeakyReLU, Sigmoid, Module
7 |
8 | from ..utils import Conv3d, Conv2d
9 |
10 |
11 | class BoundaryModule(Module):
12 | """
13 | Boundary matching module for video or audio features.
14 | Input:
15 | F_v or F_a: (B, C_f, T)
16 | Output:
17 | M_v^ or M_a^: (B, D, T)
18 |
19 | """
20 |
21 | def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512,
22 | max_duration: int = 40
23 | ):
24 | super().__init__()
25 |
26 | dim0, dim1 = n_features
27 |
28 | # (B, n_feature_in, temporal_dim) -> (B, n_feature_in, sample, max_duration, temporal_dim)
29 | self.bm_layer = BMLayer(temporal_dim, num_samples, max_duration)
30 |
31 | # (B, n_feature_in, sample, max_duration, temporal_dim) -> (B, dim0, max_duration, temporal_dim)
32 | self.block0 = Sequential(
33 | Conv3d(n_feature_in, dim0, kernel_size=(num_samples, 1, 1), stride=(num_samples, 1, 1),
34 | build_activation=LeakyReLU
35 | ),
36 | Rearrange("b c n d t -> b c (n d) t")
37 | )
38 |
39 | # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim)
40 | self.block1 = Sequential(
41 | Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU),
42 | Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU),
43 | Conv2d(dim1, 1, kernel_size=1, build_activation=Sigmoid),
44 | Rearrange("b c d t -> b (c d) t")
45 | )
46 |
47 | def forward(self, feature: Tensor) -> Tensor:
48 | feature = self.bm_layer(feature)
49 | feature = self.block0(feature)
50 | feature = self.block1(feature)
51 | return feature
52 |
53 |
54 | class BMLayer(Module):
55 | """BM Layer"""
56 |
57 | def __init__(self, temporal_dim: int, num_sample: int, max_duration: int, roi_expand_ratio: float = 0.5):
58 | super().__init__()
59 | self.temporal_dim = temporal_dim
60 | # self.feat_dim = opt['bmn_feat_dim']
61 | self.num_sample = num_sample
62 | self.duration = max_duration
63 | self.roi_expand_ratio = roi_expand_ratio
64 | self.smp_weight = self.get_pem_smp_weight()
65 |
66 | def get_pem_smp_weight(self):
67 | T = self.temporal_dim
68 | N = self.num_sample
69 | D = self.duration
70 | w = torch.zeros([T, N, D, T]) # T * N * D * T
71 | # In each temporal location i, there are D predefined proposals,
72 | # with length ranging between 1 and D
73 |         # the j-th proposal is [i, i+j+1], 0<=j<D
74 |         # Each proposal is expanded by roi_expand_ratio and uniformly sampled with
75 |         # N points; every sample point spreads a linear-interpolation weight over
76 |         # its two neighbouring integer temporal locations.
77 |         # NOTE: the exact expansion/sampling constants below follow the standard
78 |         # BMN-style scheme and are an assumption where the original is unclear.
79 |         for i in range(T):
80 |             for j in range(D):
81 |                 length = j + 1
82 |                 xmin = i - length * self.roi_expand_ratio
83 |                 xmax = i + length * (1 + self.roi_expand_ratio)
84 |                 # uniformly sample N points in the expanded proposal
85 |                 sample_points = [xmin + (xmax - xmin) / (N - 1) * k for k in range(N)]
86 |                 for k, xp in enumerate(sample_points):
87 |                     if xp < 0 or xp > T - 1:
88 |                         continue
89 | left, right = int(np.floor(xp)), int(np.ceil(xp))
90 | left_weight = 1 - (xp - left)
91 | right_weight = 1 - (right - xp)
92 | w[left, k, j, i] += left_weight
93 | w[right, k, j, i] += right_weight
94 | return w.view(T, -1).float()
95 |
96 | def _apply(self, fn):
97 | self.smp_weight = fn(self.smp_weight)
98 |
99 | def forward(self, X):
100 | input_size = X.size()
101 | assert (input_size[-1] == self.temporal_dim)
102 | # assert(len(input_size) == 3 and
103 | X_view = X.view(-1, input_size[-1])
104 | # feature [bs*C, T]
105 | # smp_w [T, N*D*T]
106 | # out [bs*C, N*D*T] --> [bs, C, N, D, T]
107 | result = torch.matmul(X_view, self.smp_weight)
108 | return result.view(-1, input_size[1], self.num_sample, self.duration, self.temporal_dim)
109 |
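110 | 
111 | if __name__ == "__main__":
112 |     # Minimal shape-check sketch (illustrative sizes only, far smaller than the
113 |     # training defaults; run as a module from examples/batfd, e.g.
114 |     # `python -m batfd.model.boundary_module`): per-frame features of shape
115 |     # (B, C_f, T) become a boundary map of shape (B, D, T).
116 |     module = BoundaryModule(n_feature_in=256, n_features=(64, 32), num_samples=4,
117 |                             temporal_dim=100, max_duration=8)
118 |     print(module(torch.rand(2, 256, 100)).shape)  # torch.Size([2, 8, 100])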
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/boundary_module_plus.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops.layers.torch import Rearrange
7 | from torch import Tensor
8 | from torch.nn import Sequential, LeakyReLU
9 |
10 | from .boundary_module import BoundaryModule
11 | from ..utils import Conv2d
12 |
13 |
14 | class ConvUnit(nn.Module):
15 | """
16 | Unit in NestedUNet
17 | """
18 |
19 | def __init__(self, in_ch, out_ch, is_output=False):
20 | super(ConvUnit, self).__init__()
21 | module_list: list[nn.Module] = [nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)]
22 | if is_output is False:
23 | module_list.append(nn.BatchNorm1d(out_ch))
24 | module_list.append(nn.ReLU(inplace=True))
25 | self.conv = nn.Sequential(*module_list)
26 |
27 | def forward(self, x):
28 | x = self.conv(x)
29 | return x
30 |
31 |
32 | class NestedUNet(nn.Module):
33 | """
34 |     Nested 1D U-Net (UNet++-style) used for complementary boundary generation.
35 |     Adapted from U-Net: https://arxiv.org/abs/1505.04597
36 | """
37 | def __init__(self, in_ch=400, out_ch=2):
38 | super(NestedUNet, self).__init__()
39 |
40 | self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
41 | self.up = nn.Upsample(scale_factor=2)
42 |
43 | n1 = 512
44 | filters = [n1, n1 * 2, n1 * 3]
45 | self.conv0_0 = ConvUnit(in_ch, filters[0], is_output=False)
46 | self.conv1_0 = ConvUnit(filters[0], filters[0], is_output=False)
47 | self.conv2_0 = ConvUnit(filters[0], filters[0], is_output=False)
48 |
49 | self.conv0_1 = ConvUnit(filters[1], filters[0], is_output=False)
50 | self.conv1_1 = ConvUnit(filters[1], filters[0], is_output=False)
51 |
52 | self.conv0_2 = ConvUnit(filters[2], filters[0], is_output=False)
53 |
54 | self.final = nn.Conv1d(filters[0] * 3, out_ch, kernel_size=1)
55 | # self.final = ConvUnit(filters[0] * 3, out_ch, is_output=True)
56 | self.out = nn.Sigmoid()
57 |
58 | def forward(self, x):
59 | x0_0 = self.conv0_0(x)
60 | x1_0 = self.conv1_0(self.pool(x0_0))
61 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
62 | x2_0 = self.conv2_0(self.pool(x1_0))
63 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
64 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
65 | out_feature = torch.cat([x0_0, x0_1, x0_2], 1) # for calculating loss
66 | final_feature = self.final(out_feature)
67 | out = self.out(final_feature)
68 |
69 | return out, out_feature
70 |
71 |
72 | class PositionAwareAttentionModule(nn.Module):
73 | def __init__(self, in_channels, inter_channels=None, sub_sample=None, dim=2):
74 | super(PositionAwareAttentionModule, self).__init__()
75 |
76 | self.sub_sample = sub_sample
77 | self.in_channels = in_channels
78 | self.inter_channels = inter_channels
79 | self.dim = dim
80 |
81 | if self.inter_channels is None:
82 | self.inter_channels = in_channels // 2
83 | if self.inter_channels == 0:
84 | self.inter_channels = 1
85 |
86 | if self.dim == 2:
87 | conv_nd = nn.Conv2d
88 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
89 | bn = nn.BatchNorm2d
90 | else:
91 | conv_nd = nn.Conv1d
92 | max_pool_layer = nn.MaxPool1d(kernel_size=(2,))
93 | bn = nn.BatchNorm1d
94 |
95 | self.g = nn.Sequential(
96 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
97 | bn(self.inter_channels),
98 | nn.ReLU(inplace=True)
99 | )
100 | self.theta = nn.Sequential(
101 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
102 | bn(self.inter_channels),
103 | nn.ReLU(inplace=True)
104 | )
105 | self.phi = nn.Sequential(
106 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
107 | bn(self.inter_channels),
108 | nn.ReLU(inplace=True)
109 | )
110 | self.W = nn.Sequential(
111 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
112 | kernel_size=1, stride=1, padding=0),
113 | bn(self.in_channels)
114 | )
115 | if self.sub_sample:
116 | self.g = nn.Sequential(self.g, max_pool_layer)
117 | self.phi = nn.Sequential(self.phi, max_pool_layer)
118 |
119 | def forward(self, x):
120 | batch_size = x.size(0)
121 | # value
122 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
123 | g_x = g_x.permute(0, 2, 1)
124 |
125 | # query
126 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
127 | theta_x = theta_x.permute(0, 2, 1)
128 |
129 | # key
130 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
131 |
132 | f = torch.matmul(theta_x, phi_x)
133 | f = F.softmax(f, dim=2)
134 |
135 | y = torch.matmul(f, g_x)
136 | y = y.permute(0, 2, 1).contiguous()
137 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
138 | y = self.W(y)
139 |
140 | z = y + x
141 | return z
142 |
143 |
144 | class ChannelAwareAttentionModule(nn.Module):
145 | def __init__(self, in_channels, inter_channels=None, dim=2):
146 | super(ChannelAwareAttentionModule, self).__init__()
147 |
148 | self.in_channels = in_channels
149 | self.inter_channels = inter_channels
150 | self.dim = dim
151 |
152 | if self.inter_channels is None:
153 | self.inter_channels = in_channels // 2
154 | if self.inter_channels == 0:
155 | self.inter_channels = 1
156 |
157 | if self.dim == 2:
158 | conv_nd = nn.Conv2d
159 | bn = nn.BatchNorm2d
160 | else:
161 | conv_nd = nn.Conv1d
162 | bn = nn.BatchNorm1d
163 |
164 | self.g = nn.Sequential(
165 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
166 | bn(self.inter_channels),
167 | nn.ReLU(inplace=True)
168 | )
169 | self.theta = nn.Sequential(
170 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
171 | bn(self.inter_channels),
172 | nn.ReLU(inplace=True)
173 | )
174 | self.phi = nn.Sequential(
175 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0),
176 | bn(self.inter_channels),
177 | nn.ReLU(inplace=True)
178 | )
179 | self.W = nn.Sequential(
180 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
181 | kernel_size=1, stride=1, padding=0),
182 | bn(self.in_channels)
183 | )
184 |
185 | def forward(self, x):
186 | batch_size = x.size(0)
187 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
188 |
189 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
190 |
191 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
192 | phi_x = phi_x.permute(0, 2, 1)
193 |
194 | f = torch.matmul(theta_x, phi_x)
195 | f = F.softmax(f, dim=2)
196 |
197 | y = torch.matmul(f, g_x)
198 | y = y.permute(0, 2, 1).contiguous()
199 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
200 | y = self.W(y)
201 |
202 | z = y + x
203 | return z
204 |
205 |
206 | def conv_block(in_ch, out_ch, kernel_size=3, stride=1, bn_layer=False, activate=False):
207 | module_list = [nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=1)]
208 | if bn_layer:
209 | module_list.append(nn.BatchNorm2d(out_ch))
210 | module_list.append(nn.ReLU(inplace=True))
211 | if activate:
212 | module_list.append(nn.Sigmoid())
213 | conv = nn.Sequential(*module_list)
214 | return conv
215 |
216 |
217 | class ProposalRelationBlock(nn.Module):
218 | def __init__(self, in_channels, inter_channels=128, out_channels=2, sub_sample=False):
219 | super(ProposalRelationBlock, self).__init__()
220 | self.p_net = PositionAwareAttentionModule(in_channels, inter_channels=inter_channels, sub_sample=sub_sample, dim=2)
221 | self.c_net = ChannelAwareAttentionModule(in_channels, inter_channels=inter_channels, dim=2)
222 | self.conv0_0 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False)
223 | self.conv0_1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False)
224 |
225 | self.conv1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False)
226 | self.conv2 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True)
227 | self.conv3 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True)
228 | self.conv4 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False)
229 | self.conv5 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True)
230 |
231 | def forward(self, x):
232 | x_p = self.conv0_0(x)
233 | x_c = self.conv0_1(x)
234 |
235 | x_p = self.p_net(x_p)
236 | x_c = self.c_net(x_c)
237 |
238 | x_p_0 = self.conv1(x_p)
239 | x_p_1 = self.conv2(x_p_0)
240 |
241 | x_c_0 = self.conv4(x_c)
242 | x_c_1 = self.conv5(x_c_0)
243 |
244 | x_p_c = self.conv3(x_p_0 + x_c_0)
245 | return x_p_1, x_c_1, x_p_c
246 |
247 |
248 | class BoundaryModulePlus(BoundaryModule):
249 | def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512,
250 | max_duration: int = 40
251 | ):
252 | super().__init__(n_feature_in, n_features, num_samples, temporal_dim, max_duration)
253 | del self.block1
254 | dim0, dim1 = n_features
255 | # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim)
256 | self.block1 = Sequential(
257 | Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU),
258 | Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU)
259 | )
260 | # Proposal Relation Block in BSN++ mechanism
261 | self.proposal_block = ProposalRelationBlock(dim1, dim1, 1, sub_sample=True)
262 | self.out = Rearrange("b c d t -> b (c d) t")
263 |
264 | def forward(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
265 | confidence_map = self.bm_layer(feature)
266 | confidence_map = self.block0(confidence_map)
267 | confidence_map = self.block1(confidence_map)
268 | confidence_map_p, confidence_map_c, confidence_map_p_c = self.proposal_block(confidence_map)
269 |
270 | confidence_map_p = self.out(confidence_map_p)
271 | confidence_map_c = self.out(confidence_map_c)
272 | confidence_map_p_c = self.out(confidence_map_p_c)
273 | return confidence_map_p, confidence_map_c, confidence_map_p_c
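274 | 
275 | 
276 | if __name__ == "__main__":
277 |     # Minimal shape-check sketch (illustrative sizes only): BoundaryModulePlus
278 |     # returns three (B, D, T) confidence maps -- the position-aware map, the
279 |     # channel-aware map, and their fused combination from the proposal block.
280 |     module = BoundaryModulePlus(n_feature_in=256, n_features=(64, 32), num_samples=4,
281 |                                 temporal_dim=100, max_duration=8)
282 |     p, c, p_c = module(torch.rand(2, 256, 100))
283 |     print(p.shape, c.shape, p_c.shape)  # each torch.Size([2, 8, 100])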
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/frame_classifier.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from torch.nn import Module
3 |
4 | from ..utils import Conv1d
5 |
6 |
7 | class FrameLogisticRegression(Module):
8 | """
9 | Frame classifier (FC_v and FC_a) for video feature (F_v) and audio feature (F_a).
10 | Input:
11 | F_v or F_a: (B, C_f, T)
12 | Output:
13 | Y^: (B, 1, T)
14 | """
15 |
16 | def __init__(self, n_features: int):
17 | super().__init__()
18 | self.lr_layer = Conv1d(n_features, 1, kernel_size=1)
19 |
20 | def forward(self, features: Tensor) -> Tensor:
21 | return self.lr_layer(features)
22 |
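23 | 
24 | if __name__ == "__main__":
25 |     # Shape-check sketch: the classifier maps (B, C_f, T) features to
26 |     # per-frame logits of shape (B, 1, T) with a 1x1 convolution.
27 |     import torch
28 | 
29 |     clf = FrameLogisticRegression(n_features=256)
30 |     print(clf(torch.rand(2, 256, 100)).shape)  # torch.Size([2, 1, 100])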
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/fusion_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Sigmoid, Module
4 |
5 | from ..utils import Conv1d
6 |
7 |
8 | class ModalFeatureAttnBoundaryMapFusion(Module):
9 | """
10 | Fusion module for video and audio boundary maps.
11 |
12 | Input:
13 | F_v: (B, C_f, T)
14 | F_a: (B, C_f, T)
15 | M_v^: (B, D, T)
16 | M_a^: (B, D, T)
17 |
18 | Output:
19 | M^: (B, D, T)
20 | """
21 |
22 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257, max_duration: int = 40):
23 | super().__init__()
24 |
25 | self.a_attn_block = ModalMapAttnBlock(n_audio_features, n_video_features, max_duration)
26 | self.v_attn_block = ModalMapAttnBlock(n_video_features, n_audio_features, max_duration)
27 |
28 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_bm: Tensor, audio_bm: Tensor) -> Tensor:
29 | a_attn = self.a_attn_block(audio_bm, audio_feature, video_feature)
30 | v_attn = self.v_attn_block(video_bm, video_feature, audio_feature)
31 |
32 | sum_attn = a_attn + v_attn
33 |
34 | a_w = a_attn / sum_attn
35 | v_w = v_attn / sum_attn
36 |
37 | fusion_bm = video_bm * v_w + audio_bm * a_w
38 | return fusion_bm
39 |
40 |
41 | class ModalMapAttnBlock(Module):
42 |
43 | def __init__(self, n_self_features: int, n_another_features: int, max_duration: int = 40):
44 | super().__init__()
45 | self.attn_from_self_features = Conv1d(n_self_features, max_duration, kernel_size=1)
46 | self.attn_from_another_features = Conv1d(n_another_features, max_duration, kernel_size=1)
47 | self.attn_from_bm = Conv1d(max_duration, max_duration, kernel_size=1)
48 | self.sigmoid = Sigmoid()
49 |
50 | def forward(self, self_bm: Tensor, self_features: Tensor, another_features: Tensor) -> Tensor:
51 | w_bm = self.attn_from_bm(self_bm)
52 | w_self_feat = self.attn_from_self_features(self_features)
53 | w_another_feat = self.attn_from_another_features(another_features)
54 | w_stack = torch.stack((w_bm, w_self_feat, w_another_feat), dim=3)
55 | w = w_stack.mean(dim=3)
56 | return self.sigmoid(w)
57 |
58 |
59 | class ModalFeatureAttnCfgFusion(ModalFeatureAttnBoundaryMapFusion):
60 |
61 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257):
62 | super().__init__()
63 | self.a_attn_block = ModalCbgAttnBlock(n_audio_features, n_video_features)
64 | self.v_attn_block = ModalCbgAttnBlock(n_video_features, n_audio_features)
65 |
66 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_cfg: Tensor, audio_cfg: Tensor) -> Tensor:
67 | video_cfg = video_cfg.unsqueeze(1)
68 | audio_cfg = audio_cfg.unsqueeze(1)
69 | fusion_cfg = super().forward(video_feature, audio_feature, video_cfg, audio_cfg)
70 | return fusion_cfg.squeeze(1)
71 |
72 |
73 | class ModalCbgAttnBlock(ModalMapAttnBlock):
74 |
75 | def __init__(self, n_self_features: int, n_another_features: int):
76 | super().__init__(n_self_features, n_another_features, 1)
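77 | 
78 | 
79 | if __name__ == "__main__":
80 |     # Shape-check sketch with the default feature width (C_f = 257) and D = 40:
81 |     # the fused boundary map keeps the (B, D, T) shape of the modal inputs.
82 |     fusion = ModalFeatureAttnBoundaryMapFusion(n_video_features=257, n_audio_features=257, max_duration=40)
83 |     f_v, f_a = torch.rand(2, 257, 100), torch.rand(2, 257, 100)
84 |     m_v, m_a = torch.rand(2, 40, 100), torch.rand(2, 40, 100)
85 |     print(fusion(f_v, f_a, m_v, m_a).shape)  # torch.Size([2, 40, 100])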
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Module, MSELoss
4 |
5 |
6 | class MaskedBMLoss(Module):
7 |
8 | def __init__(self, loss_fn: Module):
9 | super().__init__()
10 | self.loss_fn = loss_fn
11 |
12 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor):
13 | loss = []
14 | for i, frame in enumerate(n_frames):
15 | loss.append(self.loss_fn(pred[i, :, :frame], true[i, :, :frame]))
16 | return torch.mean(torch.stack(loss))
17 |
18 |
19 | class MaskedFrameLoss(Module):
20 |
21 | def __init__(self, loss_fn: Module):
22 | super().__init__()
23 | self.loss_fn = loss_fn
24 |
25 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor):
26 | # input: (B, T)
27 | loss = []
28 | for i, frame in enumerate(n_frames):
29 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame]))
30 | return torch.mean(torch.stack(loss))
31 |
32 |
33 | class MaskedContrastLoss(Module):
34 |
35 | def __init__(self, margin: float = 0.99):
36 | super().__init__()
37 | self.margin = margin
38 |
39 | def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor, n_frames: Tensor):
40 | # input: (B, C, T)
41 | loss = []
42 | for i, frame in enumerate(n_frames):
43 |             # L2 distance between the two predictions (squared below)
44 | d = torch.dist(pred1[i, :, :frame], pred2[i, :, :frame], 2)
45 | if labels[i]:
46 | # if is positive pair, minimize distance
47 | loss.append(d ** 2)
48 | else:
49 | # if is negative pair, minimize (margin - distance) if distance < margin
50 | loss.append(torch.clip(self.margin - d, min=0.) ** 2)
51 | return torch.mean(torch.stack(loss))
52 |
53 |
54 | class MaskedMSE(Module):
55 |
56 | def __init__(self):
57 | super().__init__()
58 | self.loss_fn = MSELoss()
59 |
60 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor):
61 | loss = []
62 | for i, frame in enumerate(n_frames):
63 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame]))
64 | return torch.mean(torch.stack(loss))
65 |
66 |
67 | class MaskedBsnppLoss(Module):
68 | """Simplified version of BSN++ loss function."""
69 |
70 | def __init__(self, cbg_feature_weight=0.01, prb_weight_forward=1):
71 | super().__init__()
72 | self.cbg_feature_weight = cbg_feature_weight
73 | self.prb_weight_forward = prb_weight_forward
74 |
75 | self.cbg_loss_func = MaskedMSE()
76 | self.cbg_feature_loss = MaskedBMLoss(MSELoss())
77 | self.bsnpp_pem_reg_loss_func = self.cbg_feature_loss
78 |
79 | def forward(self, pred_bm_p, pred_bm_c, pred_bm_p_c, pred_start, pred_end,
80 | pred_start_backward, pred_end_backward, gt_iou_map, gt_start, gt_end, n_frames,
81 | feature_forward=None, feature_backward=None
82 | ):
83 | if self.cbg_feature_weight > 0:
84 | cbg_loss_forward = self.cbg_loss_func(pred_start, gt_start, n_frames) + \
85 | self.cbg_loss_func(pred_end, gt_end, n_frames)
86 | cbg_loss_backward = self.cbg_loss_func(torch.flip(pred_end_backward, dims=(1,)), gt_start, n_frames) + \
87 | self.cbg_loss_func(torch.flip(pred_start_backward, dims=(1,)), gt_end, n_frames)
88 |
89 | cbg_loss = cbg_loss_forward + cbg_loss_backward
90 | if feature_forward is not None and feature_backward is not None:
91 | inter_feature_loss = self.cbg_feature_weight * self.cbg_feature_loss(feature_forward,
92 | torch.flip(feature_backward, dims=(2,)), n_frames)
93 | cbg_loss += inter_feature_loss
94 | else:
95 | inter_feature_loss = None
96 | else:
97 | cbg_loss = None
98 | cbg_loss_forward = None
99 | cbg_loss_backward = None
100 | inter_feature_loss = None
101 |
102 | prb_reg_loss_p = self.bsnpp_pem_reg_loss_func(pred_bm_p, gt_iou_map, n_frames)
103 | prb_reg_loss_c = self.bsnpp_pem_reg_loss_func(pred_bm_c, gt_iou_map, n_frames)
104 | prb_reg_loss_p_c = self.bsnpp_pem_reg_loss_func(pred_bm_p_c, gt_iou_map, n_frames)
105 | prb_loss = prb_reg_loss_p + prb_reg_loss_c + prb_reg_loss_p_c
106 |
107 | loss = cbg_loss + prb_loss if cbg_loss is not None else prb_loss
108 | return loss, cbg_loss, prb_loss, cbg_loss_forward, cbg_loss_backward, inter_feature_loss
109 |
110 |
111 | # Non-masked versions of the loss functions
112 | class BMLoss(Module):
113 | """Non-masked version of MaskedBMLoss."""
114 |
115 | def __init__(self, loss_fn: Module):
116 | super().__init__()
117 | self.loss_fn = loss_fn
118 |
119 | def forward(self, pred: Tensor, true: Tensor):
120 | return self.loss_fn(pred, true)
121 |
122 |
123 | class FrameLoss(Module):
124 | def __init__(self, loss_fn: Module):
125 | super().__init__()
126 | self.loss_fn = loss_fn
127 |
128 | def forward(self, pred: Tensor, true: Tensor):
129 | # input: (B, T)
130 | return self.loss_fn(pred, true)
131 |
132 |
133 | class ContrastLoss(Module):
134 | def __init__(self, margin: float = 0.99):
135 | super().__init__()
136 | self.margin = margin
137 |
138 | def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor):
139 | # input: (B, C, T)
140 | batch_size = pred1.size(0)
141 | loss = []
142 | for i in range(batch_size):
143 |             # L2 distance between the two predictions (squared below)
144 | d = torch.dist(pred1[i], pred2[i], 2)
145 | if labels[i]:
146 | # if is positive pair, minimize distance
147 | loss.append(d ** 2)
148 | else:
149 | # if is negative pair, minimize (margin - distance) if distance < margin
150 | loss.append(torch.clip(self.margin - d, min=0.) ** 2)
151 | return torch.mean(torch.stack(loss))
152 |
153 |
154 | class BsnppLoss(Module):
155 | def __init__(self, cbg_feature_weight=0.01, prb_weight_forward=1):
156 | super().__init__()
157 | self.cbg_feature_weight = cbg_feature_weight
158 | self.prb_weight_forward = prb_weight_forward
159 |
160 | self.cbg_loss_func = MSELoss()
161 | self.cbg_feature_loss = BMLoss(MSELoss())
162 | self.bsnpp_pem_reg_loss_func = self.cbg_feature_loss
163 |
164 | def forward(self, pred_bm_p, pred_bm_c, pred_bm_p_c, pred_start, pred_end,
165 | pred_start_backward, pred_end_backward, gt_iou_map, gt_start, gt_end,
166 | feature_forward=None, feature_backward=None
167 | ):
168 | if self.cbg_feature_weight > 0:
169 | cbg_loss_forward = self.cbg_loss_func(pred_start, gt_start) + \
170 | self.cbg_loss_func(pred_end, gt_end)
171 | cbg_loss_backward = self.cbg_loss_func(torch.flip(pred_end_backward, dims=(1,)), gt_start) + \
172 | self.cbg_loss_func(torch.flip(pred_start_backward, dims=(1,)), gt_end)
173 |
174 | cbg_loss = cbg_loss_forward + cbg_loss_backward
175 | if feature_forward is not None and feature_backward is not None:
176 | inter_feature_loss = self.cbg_feature_weight * self.cbg_feature_loss(feature_forward,
177 | torch.flip(feature_backward, dims=(2,)))
178 | cbg_loss += inter_feature_loss
179 | else:
180 | inter_feature_loss = None
181 | else:
182 | cbg_loss = None
183 | cbg_loss_forward = None
184 | cbg_loss_backward = None
185 | inter_feature_loss = None
186 |
187 | prb_reg_loss_p = self.bsnpp_pem_reg_loss_func(pred_bm_p, gt_iou_map)
188 | prb_reg_loss_c = self.bsnpp_pem_reg_loss_func(pred_bm_c, gt_iou_map)
189 | prb_reg_loss_p_c = self.bsnpp_pem_reg_loss_func(pred_bm_p_c, gt_iou_map)
190 | prb_loss = prb_reg_loss_p + prb_reg_loss_c + prb_reg_loss_p_c
191 |
192 | loss = cbg_loss + prb_loss if cbg_loss is not None else prb_loss
193 | return loss, cbg_loss, prb_loss, cbg_loss_forward, cbg_loss_backward, inter_feature_loss
194 |
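195 | 
196 | if __name__ == "__main__":
197 |     # Illustrative sketch with random tensors: the masked losses only score the
198 |     # first n_frames[i] temporal positions of sample i.
199 |     pred_bm, true_bm = torch.rand(2, 8, 100), torch.rand(2, 8, 100)
200 |     n_frames = torch.tensor([100, 60])
201 |     print(MaskedBMLoss(MSELoss())(pred_bm, true_bm, n_frames))
202 | 
203 |     # Contrastive term between video/audio features: label 1 marks a real
204 |     # (positive) pair, label 0 a modified (negative) pair.
205 |     feat_v, feat_a = torch.rand(2, 256, 100), torch.rand(2, 256, 100)
206 |     print(MaskedContrastLoss()(feat_v, feat_a, torch.tensor([1, 0]), n_frames))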
--------------------------------------------------------------------------------
/examples/batfd/batfd/model/video_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import numpy as np
4 | from einops.layers.torch import Rearrange
5 | from torch import Tensor
6 | from torch.nn import Sequential, LeakyReLU, MaxPool3d, Module, Linear
7 | from torchvision.models.video.mvit import MSBlockConfig, _mvit
8 |
9 | from ..utils import Conv3d, Conv1d
10 |
11 |
12 | class C3DVideoEncoder(Module):
13 | """
14 | Video encoder (E_v): Process video frames to extract features.
15 | Input:
16 | V: (B, C, T, H, W)
17 | Output:
18 | F_v: (B, C_f, T)
19 | """
20 |
21 | def __init__(self, n_features=(64, 96, 128, 128), v_cla_feature_in: int = 256):
22 | super().__init__()
23 |
24 | n_dim0, n_dim1, n_dim2, n_dim3 = n_features
25 |
26 | # (B, 3, 512, 96, 96) -> (B, 64, 512, 32, 32)
27 | self.block0 = Sequential(
28 | Conv3d(3, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
29 | Conv3d(n_dim0, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
30 | MaxPool3d((1, 3, 3))
31 | )
32 |
33 | # (B, 64, 512, 32, 32) -> (B, 96, 512, 16, 16)
34 | self.block1 = Sequential(
35 | Conv3d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
36 | Conv3d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
37 | MaxPool3d((1, 2, 2))
38 | )
39 |
40 | # (B, 96, 512, 16, 16) -> (B, 128, 512, 8, 8)
41 | self.block2 = Sequential(
42 | Conv3d(n_dim1, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
43 | Conv3d(n_dim2, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
44 | MaxPool3d((1, 2, 2))
45 | )
46 |
47 | # (B, 128, 512, 8, 8) -> (B, 128, 512, 2, 2) -> (B, 512, 512) -> (B, 256, 512)
48 | self.block3 = Sequential(
49 | Conv3d(n_dim2, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
50 | MaxPool3d((1, 2, 2)),
51 | Conv3d(n_dim3, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
52 | MaxPool3d((1, 2, 2)),
53 | Rearrange("b c t h w -> b (c h w) t"),
54 | Conv1d(n_dim3 * 4, v_cla_feature_in, kernel_size=1, stride=1, build_activation=LeakyReLU)
55 | )
56 |
57 | def forward(self, video: Tensor) -> Tensor:
58 | x = self.block0(video)
59 | x = self.block1(x)
60 | x = self.block2(x)
61 | x = self.block3(x)
62 | return x
63 |
64 |
65 | class MvitVideoEncoder(Module):
66 |
67 | def __init__(self, v_cla_feature_in: int = 256,
68 | temporal_size: int = 512,
69 | mvit_type: Literal["mvit_v2_t", "mvit_v2_s", "mvit_v2_b"] = "mvit_v2_t"
70 | ):
71 | super().__init__()
72 | if mvit_type == "mvit_v2_t":
73 | self.mvit = mvit_v2_t(v_cla_feature_in, temporal_size)
74 | elif mvit_type == "mvit_v2_s":
75 | self.mvit = mvit_v2_s(v_cla_feature_in, temporal_size)
76 | elif mvit_type == "mvit_v2_b":
77 | self.mvit = mvit_v2_b(v_cla_feature_in, temporal_size)
78 | else:
79 | raise ValueError(f"Invalid mvit_type: {mvit_type}")
80 | del self.mvit.head
81 |
82 | def forward(self, video: Tensor) -> Tensor:
83 | feat = self.mvit.conv_proj(video)
84 | feat = feat.flatten(2).transpose(1, 2)
85 | feat = self.mvit.pos_encoding(feat)
86 | thw = (self.mvit.pos_encoding.temporal_size,) + self.mvit.pos_encoding.spatial_size
87 | for block in self.mvit.blocks:
88 | feat, thw = block(feat, thw)
89 |
90 | feat = self.mvit.norm(feat)
91 | feat = feat[:, 1:]
92 | feat = feat.permute(0, 2, 1)
93 | return feat
94 |
95 |
96 | def generate_config(blocks, heads, channels, out_dim):
97 | num_heads = []
98 | input_channels = []
99 | kernel_qkv = []
100 | stride_q = [[1, 1, 1]] * sum(blocks)
101 | blocks_cum = np.cumsum(blocks)
102 | stride_kv = []
103 |
104 | for i in range(len(blocks)):
105 | num_heads.extend([heads[i]] * blocks[i])
106 | input_channels.extend([channels[i]] * blocks[i])
107 | kernel_qkv.extend([[3, 3, 3]] * blocks[i])
108 |
109 | if i != len(blocks) - 1:
110 | stride_q[blocks_cum[i]] = [1, 2, 2]
111 |
112 | stride_kv_value = 2 ** (len(blocks) - 1 - i)
113 | stride_kv.extend([[1, stride_kv_value, stride_kv_value]] * blocks[i])
114 |
115 | return {
116 | "num_heads": num_heads,
117 | "input_channels": [input_channels[0]] + input_channels[:-1],
118 | "output_channels": input_channels[:-1] + [out_dim],
119 | "kernel_q": kernel_qkv,
120 | "kernel_kv": kernel_qkv,
121 | "stride_q": stride_q,
122 | "stride_kv": stride_kv
123 | }
124 |
125 |
126 | def build_mvit(config, kwargs, temporal_size=512):
127 | block_setting = []
128 | for i in range(len(config["num_heads"])):
129 | block_setting.append(
130 | MSBlockConfig(
131 | num_heads=config["num_heads"][i],
132 | input_channels=config["input_channels"][i],
133 | output_channels=config["output_channels"][i],
134 | kernel_q=config["kernel_q"][i],
135 | kernel_kv=config["kernel_kv"][i],
136 | stride_q=config["stride_q"][i],
137 | stride_kv=config["stride_kv"][i],
138 | )
139 | )
140 | return _mvit(
141 | spatial_size=(96, 96),
142 | temporal_size=temporal_size,
143 | block_setting=block_setting,
144 | residual_pool=True,
145 | residual_with_cls_embed=False,
146 | rel_pos_embed=True,
147 | proj_after_attn=True,
148 | stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
149 | weights=None,
150 | progress=False,
151 | patch_embed_kernel=(3, 15, 15),
152 | patch_embed_stride=(1, 12, 12),
153 | patch_embed_padding=(1, 3, 3),
154 | **kwargs,
155 | )
156 |
157 |
158 | def mvit_v2_b(out_dim: int, temporal_size: int, **kwargs):
159 | config = generate_config([2, 3, 16, 3], [1, 2, 4, 8], [96, 192, 384, 768], out_dim)
160 | return build_mvit(config, kwargs, temporal_size=temporal_size)
161 |
162 |
163 | def mvit_v2_s(out_dim: int, temporal_size: int, **kwargs):
164 | config = generate_config([1, 2, 11, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim)
165 | return build_mvit(config, kwargs, temporal_size=temporal_size)
166 |
167 |
168 | def mvit_v2_t(out_dim: int, temporal_size: int, **kwargs):
169 | config = generate_config([1, 2, 5, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim)
170 | return build_mvit(config, kwargs, temporal_size=temporal_size)
171 |
172 |
173 | class VideoFeatureProjection(Module):
174 |
175 | def __init__(self, input_feature_dim: int, v_cla_feature_in: int = 256):
176 | super().__init__()
177 | self.proj = Linear(input_feature_dim, v_cla_feature_in)
178 |
179 | def forward(self, x: Tensor) -> Tensor:
180 | x = self.proj(x)
181 | return x.permute(0, 2, 1)
182 |
183 |
184 | def get_video_encoder(v_cla_feature_in, temporal_size, v_encoder, ve_features):
185 | if v_encoder == "c3d":
186 | video_encoder = C3DVideoEncoder(n_features=ve_features, v_cla_feature_in=v_cla_feature_in)
187 | elif v_encoder == "mvit_t":
188 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_t")
189 | elif v_encoder == "mvit_s":
190 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_s")
191 | elif v_encoder == "mvit_b":
192 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_b")
193 | elif v_encoder == "marlin_vit_small":
194 | video_encoder = VideoFeatureProjection(input_feature_dim=13824, v_cla_feature_in=v_cla_feature_in)
195 | elif v_encoder == "i3d":
196 | video_encoder = VideoFeatureProjection(input_feature_dim=2048, v_cla_feature_in=v_cla_feature_in)
197 | elif v_encoder == "3dmm":
198 | video_encoder = VideoFeatureProjection(input_feature_dim=393, v_cla_feature_in=v_cla_feature_in)
199 | else:
200 | raise ValueError(f"Invalid video encoder: {v_encoder}")
201 | return video_encoder
202 |
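203 | 
204 | if __name__ == "__main__":
205 |     # Shape-check sketch for the C3D encoder with its default widths: a
206 |     # (B, 3, T, 96, 96) clip becomes a (B, 256, T) feature sequence.
207 |     # T = 8 here only to keep the example cheap; the pipeline uses longer clips.
208 |     import torch
209 | 
210 |     encoder = C3DVideoEncoder()
211 |     print(encoder(torch.rand(1, 3, 8, 96, 96)).shape)  # torch.Size([1, 256, 8])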
--------------------------------------------------------------------------------
/examples/batfd/batfd/post_process.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os.path
3 | from concurrent.futures import ProcessPoolExecutor
4 | from os import cpu_count
5 | from typing import List
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm.auto import tqdm
10 |
11 | from avdeepfake1m.loader import Metadata
12 | from avdeepfake1m.utils import iou_with_anchors
13 |
14 |
15 | def soft_nms(df, alpha, t1, t2, fps):
16 | df = df.sort_values(by="score", ascending=False)
17 | t_start = list(df.begin.values[:] / fps)
18 | t_end = list(df.end.values[:] / fps)
19 | t_score = list(df.score.values[:])
20 |
21 | r_start = []
22 | r_end = []
23 | r_score = []
24 |
25 | while len(t_score) > 1 and len(r_score) < 101:
26 | max_index = t_score.index(max(t_score))
27 | tmp_iou_list = iou_with_anchors(
28 | np.array(t_start),
29 | np.array(t_end), t_start[max_index], t_end[max_index])
30 | for idx in range(0, len(t_score)):
31 | if idx != max_index:
32 | tmp_iou = tmp_iou_list[idx]
33 | tmp_width = t_end[max_index] - t_start[max_index]
34 | if tmp_iou > t1 + (t2 - t1) * tmp_width:
35 | t_score[idx] *= np.exp(-np.square(tmp_iou) / alpha)
36 |
37 | r_start.append(t_start[max_index])
38 | r_end.append(t_end[max_index])
39 | r_score.append(t_score[max_index])
40 | t_start.pop(max_index)
41 | t_end.pop(max_index)
42 | t_score.pop(max_index)
43 |
44 | new_df = pd.DataFrame()
45 | new_df['score'] = r_score
46 | new_df['begin'] = r_start
47 | new_df['end'] = r_end
48 | return new_df
49 |
50 |
51 | def video_post_process(meta, model_name, fps=25, alpha=0.4, t1=0.2, t2=0.9, dataset_name="avdeepfake1m", output_dir="output"):
52 | file = resolve_csv_file_name(meta, dataset_name)
53 | df = pd.read_csv(os.path.join(output_dir, model_name, file))
54 |
55 | if len(df) > 1:
56 | df = soft_nms(df, alpha, t1, t2, fps)
57 |
58 | df = df.sort_values(by="score", ascending=False)
59 |
60 | proposal_list = []
61 |
62 | for j in range(len(df)):
63 | # round the score for saving json size
64 | score = round(df.score.values[j], 4)
65 |
66 | if score > 0:
67 | proposal_list.append([
68 | score,
69 | round(df.begin.values[j].item(), 2),
70 | round(df.end.values[j].item(), 2)
71 | ])
72 |
73 | return [meta.file, proposal_list]
74 |
75 |
76 | def resolve_csv_file_name(meta: Metadata, dataset_name: str = "avdeepfake1m") -> str:
77 | if dataset_name in ("avdeepfake1m", "avdeepfake1m++"):
78 | return meta.file.replace("/", "_").replace(".mp4", ".csv")
79 | else:
80 | raise NotImplementedError
81 |
82 |
83 | def post_process(model_name: str, save_path: str, metadata: List[Metadata], fps=25,
84 | alpha=0.4, t1=0.2, t2=0.9, dataset_name="avdeepfake1m", output_dir="output"
85 | ):
86 |     with ProcessPoolExecutor(max(1, (cpu_count() or 2) // 2 - 1)) as executor:  # keep at least one worker
87 | futures = []
88 | for meta in metadata:
89 | futures.append(executor.submit(video_post_process, meta, model_name, fps,
90 | alpha, t1, t2, dataset_name, output_dir
91 | ))
92 |
93 | results = dict(map(lambda x: x.result(), tqdm(futures)))
94 |
95 | with open(save_path, "w") as f:
96 | json.dump(results, f, indent=4)
97 |
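98 | 
99 | if __name__ == "__main__":
100 |     # Illustrative soft-NMS run on three made-up proposals; `begin`/`end` are in
101 |     # frames (fps = 25) and the returned proposals are converted to seconds.
102 |     demo = pd.DataFrame({
103 |         "begin": [25.0, 30.0, 150.0],
104 |         "end": [75.0, 80.0, 200.0],
105 |         "score": [0.9, 0.8, 0.7],
106 |     })
107 |     print(soft_nms(demo, alpha=0.4, t1=0.2, t2=0.9, fps=25))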
--------------------------------------------------------------------------------
/examples/batfd/batfd/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from lightning.pytorch import LightningModule, Trainer, Callback
3 | import torch
4 | from typing import Callable, Optional
5 | from abc import ABC
6 | from torch import Tensor
7 | from torch.nn import Module
8 |
9 |
10 | class _ConvNd(Module, ABC):
11 |
12 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0,
13 | build_activation: Optional[Callable] = None
14 | ):
15 | super().__init__()
16 | self.conv = self.PtConv(
17 | in_channels, out_channels, kernel_size, stride=stride, padding=padding
18 | )
19 | if build_activation is not None:
20 | self.activation = build_activation()
21 | else:
22 | self.activation = None
23 |
24 | def forward(self, x: Tensor) -> Tensor:
25 | x = self.conv(x)
26 | if self.activation is not None:
27 | x = self.activation(x)
28 | return x
29 |
30 |
31 | class Conv1d(_ConvNd):
32 | PtConv = torch.nn.Conv1d
33 |
34 |
35 | class Conv2d(_ConvNd):
36 | PtConv = torch.nn.Conv2d
37 |
38 |
39 | class Conv3d(_ConvNd):
40 | PtConv = torch.nn.Conv3d
41 |
42 |
43 | class LrLogger(Callback):
44 | """Log learning rate in each epoch start."""
45 |
46 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
47 | for i, optimizer in enumerate(trainer.optimizers):
48 | for j, params in enumerate(optimizer.param_groups):
49 | key = f"opt{i}_lr{j}"
50 | value = params["lr"]
51 | pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
52 | pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
53 |
54 |
55 | class EarlyStoppingLR(Callback):
56 | """Early stop model training when the LR is lower than threshold."""
57 |
58 | def __init__(self, lr_threshold: float, mode="all"):
59 | self.lr_threshold = lr_threshold
60 |
61 | if mode in ("any", "all"):
62 | self.mode = mode
63 | else:
64 |             raise ValueError("mode must be one of ('any', 'all')")
65 |
66 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
67 | self._run_early_stop_checking(trainer)
68 |
69 | def _run_early_stop_checking(self, trainer: Trainer) -> None:
70 | metrics = trainer._logger_connector.callback_metrics
71 | if len(metrics) == 0:
72 | return
73 | all_lr = []
74 | for key, value in metrics.items():
75 | if re.match(r"opt\d+_lr\d+", key):
76 | all_lr.append(value)
77 |
78 | if len(all_lr) == 0:
79 | return
80 |
81 | if self.mode == "all":
82 | if all(lr <= self.lr_threshold for lr in all_lr):
83 | trainer.should_stop = True
84 | elif self.mode == "any":
85 | if any(lr <= self.lr_threshold for lr in all_lr):
86 | trainer.should_stop = True
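87 | 
88 | 
89 | if __name__ == "__main__":
90 |     # Quick check of the thin conv wrappers: `build_activation` expects an
91 |     # activation constructor (e.g. torch.nn.Sigmoid), not an instance, and it is
92 |     # applied right after the convolution.
93 |     layer = Conv1d(64, 1, kernel_size=1, build_activation=torch.nn.Sigmoid)
94 |     print(layer(torch.rand(2, 64, 100)).shape)  # torch.Size([2, 1, 100])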
--------------------------------------------------------------------------------
/examples/batfd/batfd_plus.toml:
--------------------------------------------------------------------------------
1 | name = "batfd_plus"
2 | num_frames = 100 # T
3 | max_duration = 30 # D
4 | model_type = "batfd_plus"
5 | dataset = "avdeepfake1m++"
6 |
7 | [model.video_encoder]
8 | type = "mvit_b"
9 | hidden_dims = [] # handled by model type
10 | cla_feature_in = 256 # C_f
11 |
12 | [model.audio_encoder]
13 | type = "vit_b"
14 | hidden_dims = [] # handled by model type
15 | cla_feature_in = 256 # C_f
16 |
17 | [model.frame_classifier]
18 | type = "lr"
19 |
20 | [model.boundary_module]
21 | hidden_dims = [512, 128]
22 | samples = 10 # N
23 |
24 | [optimizer]
25 | learning_rate = 0.00001
26 | frame_loss_weight = 2.0
27 | modal_bm_loss_weight = 1.0
28 | cbg_feature_weight = 0.0
29 | prb_weight_forward = 1.0
30 | contrastive_loss_weight = 0.1
31 | contrastive_loss_margin = 0.99
32 | weight_decay = 0.0001
33 |
34 | [soft_nms]
35 | alpha = 0.7234
36 | t1 = 0.1968
37 | t2 = 0.4123
--------------------------------------------------------------------------------
/examples/batfd/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 |
5 | from avdeepfake1m.evaluation import ap_ar_1d
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser(description="Evaluation script for BATFD/BATFD+ models on AV-Deepfake1M")
9 | parser.add_argument("prediction_file_path", type=str, help="Path to the prediction JSON file (e.g., output/batfd_test.json)")
10 | parser.add_argument("metadata_file_path", type=str, help="Path to the metadata JSON file (e.g., /path/to/dataset/test_metadata.json or /path/to/dataset/val_metadata.json)")
11 | args = parser.parse_args()
12 |
13 | print(f"Calculating AP/AR for prediction file: {args.prediction_file_path}")
14 | print(f"Using metadata file: {args.metadata_file_path}")
15 |
16 | # Parameters for ap_ar_1d based on README.md
17 | file_key = "file"
18 | value_key = "fake_segments" # For ground truth in metadata
19 | fps = 1.0
20 | ap_iou_thresholds = [0.5, 0.75, 0.9, 0.95]
21 | ar_n_proposals = [50, 30, 20, 10, 5]
22 | ar_iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
23 |
24 | ap_ar_results = ap_ar_1d(
25 | proposals_path=args.prediction_file_path,
26 | labels_path=args.metadata_file_path,
27 | file_key=file_key,
28 | value_key=value_key,
29 | fps=fps,
30 | ap_iou_thresholds=ap_iou_thresholds,
31 | ar_n_proposals=ar_n_proposals,
32 | ar_iou_thresholds=ar_iou_thresholds
33 | )
34 |
35 | print(ap_ar_results)
36 |
37 | score = 0.5 * sum(ap_ar_results["ap"].values()) / len(ap_ar_results["ap"]) \
38 | + 0.5 * sum(ap_ar_results["ar"].values()) / len(ap_ar_results["ar"])
39 |
40 | print(f"Score: {score}")
41 |
--------------------------------------------------------------------------------
/examples/batfd/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import toml
3 | import torch
4 | import os
5 | from pathlib import Path
6 |
7 | from avdeepfake1m.loader import AVDeepfake1mDataModule, Metadata
8 | from batfd.model import Batfd, BatfdPlus
9 | from batfd.inference import inference_model
10 | from batfd.post_process import post_process
11 | from avdeepfake1m.utils import read_json
12 |
13 | def main():
14 | parser = argparse.ArgumentParser(description="BATFD/BATFD+ Inference")
15 | parser.add_argument("--config", type=str, required=True,
16 | help="Path to the TOML configuration file.")
17 | parser.add_argument("--checkpoint", type=str, required=True,
18 | help="Path to the model checkpoint.")
19 | parser.add_argument("--data_root", type=str, required=True,
20 | help="Root directory of the dataset.")
21 | parser.add_argument("--num_workers", type=int, default=8,
22 | help="Number of workers for data loading.")
23 | parser.add_argument("--subset", type=str, choices=["val", "test", "testA", "testB"],
24 | default="test", help="Dataset subset.")
25 | parser.add_argument("--gpus", type=int, default=1,
26 | help="Number of GPUs. Set to 0 for CPU.")
27 |
28 | args = parser.parse_args()
29 |
30 | # Determine device
31 | if args.gpus > 0 and torch.cuda.is_available():
32 | device = f"cuda:{torch.cuda.current_device()}"
33 | else:
34 | device = "cpu"
35 |
36 | print(f"Using device: {device}")
37 |
38 | config = toml.load(args.config)
39 | temp_dir = "output"
40 | output_file = os.path.join(temp_dir, f"{config['name']}_{args.subset}.json")
41 | model_type = config["model_type"]
42 |
43 | if model_type == "batfd_plus":
44 | model = BatfdPlus.load_from_checkpoint(args.checkpoint)
45 | elif model_type == "batfd":
46 | model = Batfd.load_from_checkpoint(args.checkpoint)
47 | else:
48 | raise ValueError(f"Unknown model type: {model_type}")
49 |
50 | model.eval()
51 |
52 | # Setup DataModule
53 | dm_dataset_name = config["dataset"]
54 | is_plusplus = dm_dataset_name == "avdeepfake1m++"
55 |
56 | dm = AVDeepfake1mDataModule(
57 | root=args.data_root,
58 | temporal_size=config["num_frames"],
59 | max_duration=config["max_duration"],
60 | require_match_scores=False,
61 |         batch_size=1,  # due to a Lightning limitation, only batch size 1 is supported here
62 | num_workers=args.num_workers,
63 | get_meta_attr=model.get_meta_attr,
64 | return_file_name=True,
65 | is_plusplus=is_plusplus,
66 | test_subset=args.subset if args.subset in ("test", "testA", "testB") else None
67 | )
68 | dm.setup()
69 |
70 | Path(output_file).parent.mkdir(parents=True, exist_ok=True)
71 | Path(temp_dir).mkdir(parents=True, exist_ok=True)
72 |
73 | if args.subset in ("test", "testA", "testB"):
74 | dataloader = dm.test_dataloader()
75 | metadata_path = os.path.join(dm.root, f"{args.subset}_metadata.json")
76 | elif args.subset == "val":
77 | dataloader = dm.val_dataloader()
78 | metadata_path = os.path.join(dm.root, "val_metadata.json")
79 | else:
80 | raise ValueError("Invalid subset")
81 |
82 | if os.path.exists(metadata_path):
83 | metadata = [Metadata(**each, fps=25) for each in read_json(metadata_path)]
84 | else:
85 | metadata = [
86 | Metadata(file=file_name,
87 | original=None,
88 | split=args.subset,
89 | fake_segments=[],
90 | fps=25,
91 | visual_fake_segments=[],
92 | audio_fake_segments=[],
93 | audio_model="",
94 | modify_type="",
95 | # handle by the predictor in `inference_model`
96 | video_frames=-1,
97 | audio_frames=-1)
98 | for file_name in dataloader.dataset.file_list
99 | ]
100 |
101 | inference_model(
102 | model_name=config["name"],
103 | model=model,
104 | dataloader=dataloader,
105 | metadata=metadata,
106 | max_duration=config["max_duration"],
107 | model_type=config["model_type"],
108 | gpus=args.gpus,
109 | temp_dir=temp_dir
110 | )
111 |
112 | post_process(
113 | model_name=config["name"],
114 | save_path=output_file,
115 | metadata=metadata,
116 | fps=25,
117 | alpha=config["soft_nms"]["alpha"],
118 | t1=config["soft_nms"]["t1"],
119 | t2=config["soft_nms"]["t2"],
120 | dataset_name=dm_dataset_name,
121 | output_dir=temp_dir
122 | )
123 |
124 | print(f"Inference complete. Results saved to {output_file}")
125 |
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
--------------------------------------------------------------------------------
/examples/batfd/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import toml
4 | from lightning.pytorch import Trainer
5 | from lightning.pytorch.callbacks import ModelCheckpoint
6 | import torch
7 | torch.set_float32_matmul_precision('high')
8 |
9 | from avdeepfake1m.loader import AVDeepfake1mDataModule
10 | from batfd.model import Batfd, BatfdPlus
11 | from batfd.utils import LrLogger, EarlyStoppingLR
12 |
13 | parser = argparse.ArgumentParser(description="BATFD training")
14 | parser.add_argument("--config", type=str)
15 | parser.add_argument("--data_root", type=str)
16 | parser.add_argument("--batch_size", type=int, default=4)
17 | parser.add_argument("--num_workers", type=int, default=8)
18 | parser.add_argument("--gpus", type=int, default=1)
19 | parser.add_argument("--precision", default=32)
20 | parser.add_argument("--num_train", type=int, default=None)
21 | parser.add_argument("--num_val", type=int, default=1000)
22 | parser.add_argument("--max_epochs", type=int, default=500)
23 | parser.add_argument("--logger", type=str, choices=["wandb", "tensorboard"], default="tensorboard")
24 | parser.add_argument("--resume", type=str, default=None)
25 |
26 | if __name__ == '__main__':
27 | args = parser.parse_args()
28 | config = toml.load(args.config)
29 |
30 | learning_rate = config["optimizer"]["learning_rate"]
31 | gpus = args.gpus
32 | total_batch_size = args.batch_size * gpus
33 | learning_rate = learning_rate * total_batch_size / 4
34 | dataset = config["dataset"]
35 |
36 | v_encoder_type = config["model"]["video_encoder"]["type"]
37 | a_encoder_type = config["model"]["audio_encoder"]["type"]
38 |
39 | if v_encoder_type in ("marlin_vit_small", "3dmm", "i3d"):
40 | v_feature = v_encoder_type
41 | else:
42 | v_feature = None
43 |
44 | if a_encoder_type in ("deep_speech", "wav2vec2", "trill"):
45 | a_feature = a_encoder_type
46 | else:
47 | a_feature = None
48 |
49 | if config["model_type"] == "batfd_plus":
50 | model = BatfdPlus(
51 | v_encoder=v_encoder_type,
52 | a_encoder=config["model"]["audio_encoder"]["type"],
53 | frame_classifier=config["model"]["frame_classifier"]["type"],
54 | ve_features=config["model"]["video_encoder"]["hidden_dims"],
55 | ae_features=config["model"]["audio_encoder"]["hidden_dims"],
56 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"],
57 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"],
58 | boundary_features=config["model"]["boundary_module"]["hidden_dims"],
59 | boundary_samples=config["model"]["boundary_module"]["samples"],
60 | temporal_dim=config["num_frames"],
61 | max_duration=config["max_duration"],
62 | weight_frame_loss=config["optimizer"]["frame_loss_weight"],
63 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"],
64 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"],
65 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"],
66 | cbg_feature_weight=config["optimizer"]["cbg_feature_weight"],
67 | prb_weight_forward=config["optimizer"]["prb_weight_forward"],
68 | weight_decay=config["optimizer"]["weight_decay"],
69 | learning_rate=learning_rate,
70 | distributed=args.gpus > 1
71 | )
72 | require_match_scores = True
73 | get_meta_attr = BatfdPlus.get_meta_attr
74 | elif config["model_type"] == "batfd":
75 | model = Batfd(
76 | v_encoder=config["model"]["video_encoder"]["type"],
77 | a_encoder=config["model"]["audio_encoder"]["type"],
78 | frame_classifier=config["model"]["frame_classifier"]["type"],
79 | ve_features=config["model"]["video_encoder"]["hidden_dims"],
80 | ae_features=config["model"]["audio_encoder"]["hidden_dims"],
81 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"],
82 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"],
83 | boundary_features=config["model"]["boundary_module"]["hidden_dims"],
84 | boundary_samples=config["model"]["boundary_module"]["samples"],
85 | temporal_dim=config["num_frames"],
86 | max_duration=config["max_duration"],
87 | weight_frame_loss=config["optimizer"]["frame_loss_weight"],
88 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"],
89 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"],
90 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"],
91 | weight_decay=config["optimizer"]["weight_decay"],
92 | learning_rate=learning_rate,
93 | distributed=args.gpus > 1
94 | )
95 | require_match_scores = False
96 | get_meta_attr = Batfd.get_meta_attr
97 | else:
98 | raise ValueError("Invalid model type")
99 |
100 | if dataset == "avdeepfake1m":
101 | dm = AVDeepfake1mDataModule(
102 | root=args.data_root,
103 | temporal_size=config["num_frames"],
104 | max_duration=config["max_duration"],
105 | require_match_scores=require_match_scores,
106 | batch_size=args.batch_size, num_workers=args.num_workers,
107 | take_train=args.num_train, take_val=args.num_val,
108 | get_meta_attr=get_meta_attr,
109 | is_plusplus=False
110 | )
111 | elif dataset == "avdeepfake1m++":
112 | dm = AVDeepfake1mDataModule(
113 | root=args.data_root,
114 | temporal_size=config["num_frames"],
115 | max_duration=config["max_duration"],
116 | require_match_scores=require_match_scores,
117 | batch_size=args.batch_size, num_workers=args.num_workers,
118 | take_train=args.num_train, take_val=args.num_val,
119 | get_meta_attr=get_meta_attr,
120 | is_plusplus=True
121 | )
122 | else:
123 | raise ValueError("Invalid dataset type")
124 |
125 | try:
126 | precision = int(args.precision)
127 | except ValueError:
128 | precision: int | str = args.precision
129 |
130 | monitor = "metrics/val_loss"
131 |
132 | if args.logger == "wandb":
133 | from lightning.pytorch.loggers import WandbLogger
134 | logger = WandbLogger(name=config["name"], project=dataset)
135 | else:
136 | logger = True
137 |
138 | trainer = Trainer(log_every_n_steps=20, precision=precision, max_epochs=args.max_epochs,
139 | callbacks=[
140 | ModelCheckpoint(
141 | dirpath=f"./ckpt/{config['name']}", save_last=True, filename=config["name"] + "-{epoch}-{val_loss:.3f}",
142 | monitor=monitor, mode="min"
143 | ),
144 | LrLogger(),
145 | EarlyStoppingLR(lr_threshold=1e-7)
146 | ], enable_checkpointing=True,
147 | benchmark=True,
148 | accelerator="auto",
149 | devices=args.gpus,
150 | strategy="auto" if args.gpus < 2 else "ddp",
151 | logger=logger
152 | )
153 |
154 | trainer.fit(model, dm, ckpt_path=args.resume)
155 |
--------------------------------------------------------------------------------
/examples/xception/README.md:
--------------------------------------------------------------------------------
1 | # Xception
2 |
3 | This example trains an Xception model on the AVDeepfake1M/AVDeepfake1M++ dataset for classification with video-level labels.
4 |
5 | ## Requirements
6 |
7 | - Python
8 | - PyTorch
9 | - PyTorch Lightning
10 | - TIMM
11 | - AVDeepfake1M SDK
12 |
13 |
14 | ## Training
15 |
16 | ```bash
17 | python train.py --data_root /path/to/avdeepfake1m --model xception
18 | ```
19 | ### Output
20 |
21 | * **Checkpoints:** Model checkpoints are saved under `./ckpt/xception/`. The last checkpoint is saved as `last.ckpt`.
22 | * **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`).
23 |
24 |
25 | ## Inference
26 |
27 | After training, you can generate predictions on a dataset subset (train, val, or test) using `infer.py`. This script will save the predictions to a text file, following the format from the [challenge](https://deepfakes1m.github.io/2025/details).
28 |
29 | ```bash
30 | python infer.py --data_root /path/to/avdeepfake1m --checkpoint /path/to/your/checkpoint.ckpt --model xception --subset val
31 | ```
32 |
33 | The output prediction file will be saved to `output/<model>_<subset>.txt` (e.g., `output/xception_val.txt`).
34 |
35 | ## Evaluation
36 |
37 | ```bash
38 | python evaluate.py
39 | ```
40 |
41 | For example:
42 |
43 | ```bash
44 | python evaluate.py ./output/xception_val.txt /path/to/avdeepfake1m/val_metadata.json
45 | ```
46 |
47 | This will print the AUC score based on your model's predictions.
48 |
--------------------------------------------------------------------------------
/examples/xception/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from avdeepfake1m.evaluation import auc
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser(description="Evaluation script for AV-Deepfake1M")
7 | parser.add_argument("prediction_file_path", type=str, help="Path to the prediction file (e.g., output/results/xception_val.txt)")
8 | parser.add_argument("metadata_file_path", type=str, help="Path to the metadata JSON file (e.g., /path/to/val_metadata.json)")
9 | args = parser.parse_args()
10 |
11 | print(auc(
12 | args.prediction_file_path,
13 | args.metadata_file_path,
14 | "file", # As per README, this is usually "file"
15 | "fake_segments" # As per README, this is usually "fake_segments" for AUC
16 | ))
--------------------------------------------------------------------------------
/examples/xception/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from tqdm.auto import tqdm
5 | from pathlib import Path
6 |
7 | from avdeepfake1m.loader import AVDeepfake1mPlusPlusVideo
8 | from xception import Xception
9 |
10 | parser = argparse.ArgumentParser(description="Xception inference")
11 | parser.add_argument("--data_root", type=str)
12 | parser.add_argument("--checkpoint", type=str)
13 | parser.add_argument("--model", type=str)
14 | parser.add_argument("--batch_size", type=int, default=128)
15 | parser.add_argument("--subset", type=str, choices=["train", "val", "test", "testA", "testB"])
16 | parser.add_argument("--gpus", type=int, default=1)
17 | parser.add_argument("--resume", type=str, default=None)
18 | parser.add_argument("--take_num", type=int, default=None)
19 |
20 | if __name__ == '__main__':
21 | args = parser.parse_args()
22 | use_gpu = args.gpus > 0
23 | device = "cuda" if use_gpu else "cpu"
24 |
25 | if args.model == "xception":
26 | model = Xception.load_from_checkpoint(args.checkpoint, lr=None, distributed=False).eval()
27 | else:
28 | raise ValueError(f"Unknown model: {args.model}")
29 |
30 | model.to(device)
31 | model.train() # not sure why but eval mode will generate nonsense output
32 | test_dataset = AVDeepfake1mPlusPlusVideo(args.subset, args.data_root, take_num=args.take_num, pred_mode=True)
33 |
34 | save_path = f"output/{args.model}_{args.subset}.txt"
35 | Path(save_path).parent.mkdir(parents=True, exist_ok=True)
36 |
37 | processed_files = set()
38 | if args.resume is not None:
39 | with open(args.resume, "r") as f:
40 | for line in f:
41 | processed_files.add(line.split(";")[0])
42 |
43 | with open(save_path, "a") as f:
44 | with torch.inference_mode():
45 | for i in tqdm(range(len(test_dataset))):
46 | file_name = test_dataset.metadata[i].file
47 | if file_name in processed_files:
48 | continue
49 |
50 | video, _, _ = test_dataset[i]
51 | # batch video as frames use batch_size
52 | preds_video = []
53 | for j in range(0, len(video), args.batch_size):
54 | batch = video[j:j + args.batch_size].to(device)
55 | preds_video.append(model(batch))
56 |
57 | preds_video = torch.cat(preds_video, dim=0).flatten()
58 | # choose the max prediction
59 | pred = preds_video.max().item()
60 |
61 | f.write(f"{file_name};{pred}\n")
62 | f.flush()
63 |
--------------------------------------------------------------------------------
/examples/xception/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from torch.utils.data import DataLoader
4 | from lightning.pytorch import Trainer
5 | from lightning.pytorch.callbacks import ModelCheckpoint
6 | from avdeepfake1m.loader import AVDeepfake1mPlusPlusImages
7 |
8 | from xception import Xception
9 | from utils import LrLogger, EarlyStoppingLR
10 |
11 |
12 | parser = argparse.ArgumentParser(description="Classification model training")
13 | parser.add_argument("--data_root", type=str)
14 | parser.add_argument("--batch_size", type=int, default=128)
15 | parser.add_argument("--model", type=str, choices=["xception", "meso4", "meso_inception4"])
16 | parser.add_argument("--gpus", type=int, default=1)
17 | parser.add_argument("--precision", default=32)
18 | parser.add_argument("--num_train", type=int, default=None)
19 | parser.add_argument("--num_val", type=int, default=2000)
20 | parser.add_argument("--max_epochs", type=int, default=500)
21 | parser.add_argument("--resume", type=str, default=None)
22 | args = parser.parse_args()
23 |
24 |
25 | if __name__ == "__main__":
26 |
27 | # You can fix the random seed if you want reproducible subsets each epoch:
28 | # torch.manual_seed(42)
29 | # random.seed(42)
30 |
31 | learning_rate = 1e-4
32 | gpus = args.gpus
33 | total_batch_size = args.batch_size * gpus
34 | learning_rate = learning_rate * total_batch_size / 4
35 |
36 | # Setup model
37 | if args.model == "xception":
38 | model = Xception(learning_rate, distributed=gpus > 1)
39 | else:
40 | raise ValueError(f"Unknown model: {args.model}")
41 |
42 | train_dataset = AVDeepfake1mPlusPlusImages(
43 | subset="train",
44 | data_root=args.data_root,
45 | take_num=args.num_train,
46 | use_video_label=True # For video-level label access, set True
47 | )
48 |
49 | # For validation, the standard dataset can still be used
50 | val_dataset = AVDeepfake1mPlusPlusImages(
51 | subset="val",
52 | data_root=args.data_root,
53 | take_num=args.num_val,
54 | use_video_label=True
55 | )
56 |
57 | # Parse precision properly
58 | try:
59 | precision = int(args.precision)
60 | except ValueError:
61 | precision = args.precision
62 |
63 | monitor = "val_loss"
64 |
65 | trainer = Trainer(
66 | log_every_n_steps=50,
67 | precision=precision,
68 | max_epochs=args.max_epochs,
69 | callbacks=[
70 | ModelCheckpoint(
71 | dirpath=f"./ckpt/{args.model}",
72 | save_last=True,
73 | filename=args.model + "-{epoch}-{val_loss:.3f}",
74 | monitor=monitor,
75 | mode="min"
76 | ),
77 | LrLogger(),
78 | EarlyStoppingLR(lr_threshold=1e-7)
79 | ],
80 | enable_checkpointing=True,
81 | benchmark=True,
82 | accelerator="gpu",
83 | devices=args.gpus,
84 | strategy="ddp" if args.gpus > 1 else "auto",
85 | # ckpt_path=args.resume,  # resuming is handled via trainer.fit(..., ckpt_path=...) below
86 | # Note: on older Lightning versions you may still need strategy="ddp" here; the configuration above is the typical setup.
87 | )
88 |
89 | trainer.fit(
90 | model,
91 | train_dataloaders=DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0),
92 | val_dataloaders=DataLoader(val_dataset, batch_size=args.batch_size, num_workers=0),
93 | ckpt_path=args.resume,
94 | )
95 |
--------------------------------------------------------------------------------
/examples/xception/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from lightning.pytorch import Callback, Trainer, LightningModule
4 |
5 |
6 | class LrLogger(Callback):
7 | """Log learning rate in each epoch start."""
8 |
9 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
10 | for i, optimizer in enumerate(trainer.optimizers):
11 | for j, params in enumerate(optimizer.param_groups):
12 | key = f"opt{i}_lr{j}"
13 | value = params["lr"]
14 | pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
15 | pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
16 |
17 |
18 | class EarlyStoppingLR(Callback):
19 | """Early stop model training when the LR is lower than threshold."""
20 |
21 | def __init__(self, lr_threshold: float, mode="all"):
22 | self.lr_threshold = lr_threshold
23 |
24 | if mode in ("any", "all"):
25 | self.mode = mode
26 | else:
27 | raise ValueError(f"mode must be one of ('any', 'all'), got {mode!r}")
28 |
29 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
30 | self._run_early_stop_checking(trainer)
31 |
32 | def _run_early_stop_checking(self, trainer: Trainer) -> None:
33 | metrics = trainer.callback_metrics
34 | if len(metrics) == 0:
35 | return
36 | all_lr = []
37 | for key, value in metrics.items():
38 | if re.match(r"opt\d+_lr\d+", key):
39 | all_lr.append(value)
40 |
41 | if len(all_lr) == 0:
42 | return
43 |
44 | if self.mode == "all":
45 | if all(lr <= self.lr_threshold for lr in all_lr):
46 | trainer.should_stop = True
47 | elif self.mode == "any":
48 | if any(lr <= self.lr_threshold for lr in all_lr):
49 | trainer.should_stop = True
50 |
--------------------------------------------------------------------------------
/examples/xception/xception.py:
--------------------------------------------------------------------------------
1 | import timm
2 |
3 | from lightning.pytorch import LightningModule
4 | from torch.nn import BCEWithLogitsLoss
5 | from torch.optim import Adam
6 |
7 |
8 | class Xception(LightningModule):
9 | def __init__(self, lr, distributed=False):
10 | super(Xception, self).__init__()
11 | self.lr = lr
12 | self.model = timm.create_model('xception', pretrained=True, num_classes=1)
13 | self.loss_fn = BCEWithLogitsLoss()
14 | self.distributed = distributed
15 |
16 | def forward(self, x):
17 | x = self.model(x)
18 | return x
19 |
20 | def training_step(self, batch, batch_idx):
21 | x, y = batch
22 | y_hat = self(x)
23 | loss = self.loss_fn(y_hat, y.unsqueeze(1))
24 | self.log('train_loss', loss)
25 | return loss
26 |
27 | def validation_step(self, batch, batch_idx):
28 | x, y = batch
29 | y_hat = self(x)
30 | loss = self.loss_fn(y_hat, y.unsqueeze(1))
31 | self.log('val_loss', loss)
32 | return loss
33 |
34 | def configure_optimizers(self):
35 | optimizer = Adam(self.parameters(), lr=self.lr)
36 | return [optimizer]
37 |
--------------------------------------------------------------------------------
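As a quick sanity check of the Xception module above, the following minimal sketch runs a dummy batch through the model (assuming timm can load the pretrained 'xception' weights; the 299x299 resolution is the conventional Xception input size and the batch size is arbitrary):

import torch
from xception import Xception

model = Xception(lr=1e-4).eval()
dummy = torch.randn(2, 3, 299, 299)   # 2 RGB frames at 299x299
with torch.inference_mode():
    logits = model(dummy)             # shape (2, 1); raw logits, since training uses BCEWithLogitsLoss
    probs = torch.sigmoid(logits)     # convert to fake-probabilities if needed
print(logits.shape, probs.shape)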
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.7,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [project]
6 | name = "avdeepfake1m"
7 | requires-python = ">=3.7"
8 | classifiers = [
9 | "Programming Language :: Rust",
10 | "Programming Language :: Python :: Implementation :: CPython",
11 | "Programming Language :: Python :: 3",
12 | "Programming Language :: Python :: 3.7",
13 | "Programming Language :: Python :: 3.8",
14 | "Programming Language :: Python :: 3.9",
15 | "Programming Language :: Python :: 3.10",
16 | "Programming Language :: Python :: 3.11",
17 | "Programming Language :: Python :: 3.12",
18 | "License :: Other/Proprietary License",
19 | "Operating System :: OS Independent",
20 | "Intended Audience :: Developers",
21 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
22 | "Topic :: Multimedia :: Video",
23 | "Topic :: Utilities"
24 | ]
25 | authors = [
26 | {name = "ControlNet", email = "smczx@hotmail.com"}
27 | ]
28 | dynamic = ["version"]
29 | readme = "README.md"
30 | keywords = ["pytorch", "AI"]
31 | dependencies = [
32 | "torch",
33 | "lightning>=2",
34 | "tqdm",
35 | "einops",
36 | "opencv-python",
37 | "numpy>=1,<2",
38 | "torchvision",
39 | "torchaudio",
40 | "av",
41 | "pandas~=2.0",
42 | "torchmetrics~=1.0",
43 | ]
44 |
45 | [project.urls]
46 | "Homepage" = "https://github.com/ControlNet/AV-Deepfake1M"
47 | "Source Code" = "https://github.com/ControlNet/AV-Deepfake1M"
48 | "Bug Tracker" = "https://github.com/ControlNet/AV-Deepfake1M/issues"
49 |
50 | [tool.maturin]
51 | features = ["pyo3/extension-module"]
52 | module-name = "avdeepfake1m._evaluation"
53 | python-source = "python"
54 |
55 | [tool.pixi.project]
56 | channels = ["pytorch", "conda-forge/label/rust_dev", "conda-forge"]
57 | platforms = ["linux-64", "osx-arm64", "osx-64", "win-64"]
58 |
59 | [tool.pixi.dependencies]
60 | pytorch = { version = "~=2.2.2", channel = "pytorch" }
61 | torchvision = { version = "~=0.17.2", channel = "pytorch" }
62 | torchaudio = { version = "~=2.2.2", channel = "pytorch" }
63 | cpuonly = { version = "~=2.0", channel = "pytorch" } # for developing the dataloader and evaluator only
64 | python = "~=3.11.0"
65 | numpy = { version = "~=1.0", channel = "conda-forge" }
66 | ffmpeg = { version = "*", channel = "conda-forge" }
67 | av = { version = "*", channel = "conda-forge" }
68 | pandas = ">=2.2.3,<3"
69 | torchmetrics = ">=1.4.2,<2"
70 |
71 | [tool.pixi.target.linux-64.dependencies]
72 | rust = { version = "*", channel = "conda-forge/label/rust_dev" }
73 |
74 | [tool.pixi.target.osx-64.dependencies]
75 | rust_osx-64 = { version = "*", channel = "conda-forge/label/rust_dev" }
76 |
77 | [tool.pixi.target.win-64.dependencies]
78 | rust_win-64 = { version = "*", channel = "conda-forge/label/rust_dev" }
79 |
80 | [tool.pixi.target.osx-arm64.dependencies]
81 | rust_osx-arm64 = { version = "*", channel = "conda-forge/label/rust_dev" }
82 |
83 | [tool.pixi.pypi-dependencies]
84 | avdeepfake1m = { path = ".", editable = true }
85 |
86 | [tool.pixi.build-dependencies]
87 | maturin = ">=1.7,<2.0"
88 | setuptools = "*"
89 | pip = "*"
90 |
91 | [tool.pixi.tasks]
92 | develop = "maturin develop"
93 | build = "maturin build --release --find-interpreter"
94 |
--------------------------------------------------------------------------------
/python/avdeepfake1m/__init__.py:
--------------------------------------------------------------------------------
1 | from . import evaluation
2 | from . import utils
3 | from . import loader
4 |
5 | __all__ = ["evaluation", "utils", "loader"]
6 |
--------------------------------------------------------------------------------
/python/avdeepfake1m/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .auc import auc
2 | from .._evaluation import ap_1d, ar_1d, ap_ar_1d
3 |
4 | __all__ = ["auc", "ap_1d", "ar_1d", "ap_ar_1d"]
5 |
--------------------------------------------------------------------------------
/python/avdeepfake1m/evaluation/__init__.pyi:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Union
2 |
3 |
4 | def ap_1d(proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float,
5 | iou_thresholds: List[float]) -> Dict[float, float]:
6 | pass
7 |
8 |
9 | def ar_1d(proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float, n_proposals: List[int],
10 | iou_thresholds: List[float]) -> Dict[int, float]:
11 | pass
12 |
13 |
14 | def ap_ar_1d(
15 | proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float,
16 | ap_iou_thresholds: List[float], ar_n_proposals: List[int], ar_iou_thresholds: List[float]
17 | ) -> Dict[str, Dict[Union[float, int], float]]:
18 | pass
19 |
20 |
21 | def auc(prediction_file: str, reference_path: str, file_key: str, value_key: str) -> float:
22 | pass
23 |
--------------------------------------------------------------------------------
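For reference, a hedged usage sketch of the evaluation API declared above; every path, key, frame rate, and threshold below is a placeholder to adapt to your own proposal and metadata files:

from avdeepfake1m.evaluation import ap_ar_1d, auc

results = ap_ar_1d(
    "proposals.json",       # proposals_path (placeholder)
    "val_metadata.json",    # labels_path (placeholder)
    "file",                 # file_key
    "fake_segments",        # value_key (placeholder; match your metadata)
    25.0,                   # fps (placeholder)
    [0.5, 0.75, 0.95],      # ap_iou_thresholds
    [100, 50, 20, 10, 5],   # ar_n_proposals
    [0.5, 0.75, 0.95],      # ar_iou_thresholds
)
print(results)  # nested dict: AP per IoU threshold and AR per proposal count

score = auc("output/xception_val.txt", "val_metadata.json", "file", "fake_segments")  # placeholders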
/python/avdeepfake1m/evaluation/auc.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from torchmetrics import AUROC
4 |
5 | from ..utils import read_json
6 |
7 |
8 | def auc(prediction_file: str, reference_path: str, file_key: str, value_key: str) -> float:
9 |
10 | prediction = pd.read_csv(prediction_file, header=None, sep=";")
11 | # convert to dict
12 | prediction_dict = {}
13 | for i in range(len(prediction)):
14 | prediction_dict[prediction.iloc[i, 0]] = prediction.iloc[i, 1]
15 |
16 | gt = read_json(reference_path)
17 |
18 | # build aligned lists of ground-truth labels and prediction scores
19 | truth = []
20 | pred_scores = []
21 | for gt_item in gt:
22 | key = gt_item[file_key]
23 | truth.append(int(len(gt_item[value_key]) > 0))
24 | pred_scores.append(prediction_dict[key])
25 |
26 | # convert to tensors for torchmetrics
27 | truth = torch.tensor(truth)
28 | pred_scores = torch.tensor(pred_scores)
29 |
30 | # compute AUROC
31 | auroc = AUROC(task="binary")
32 | return auroc(pred_scores, truth).item()
33 |
--------------------------------------------------------------------------------
/python/avdeepfake1m/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from einops import rearrange
9 | from torch import Tensor
10 | from torch.nn import functional as F
11 |
12 |
13 | def read_json(path: str, object_hook=None):
14 | with open(path, 'r') as f:
15 | return json.load(f, object_hook=object_hook)
16 |
17 |
18 | def read_video(path: str):
19 | video, audio, info = torchvision.io.read_video(path, pts_unit="sec")
20 | video = video.permute(0, 3, 1, 2) / 255
21 | audio = audio.permute(1, 0)
22 | if audio.shape[0] == 0:
23 | audio = torch.zeros(1, 1)
24 | return video, audio, info
25 |
26 |
27 | def read_video_fast(path: str):
28 | cap = cv2.VideoCapture(path)
29 | frames = []
30 | while True:
31 | ret, frame = cap.read()
32 | if not ret:
33 | break
34 | else:
35 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
36 | cap.release()
37 | video = np.stack(frames, axis=0)
38 | video = rearrange(video, 'T H W C -> T C H W')
39 | return torch.from_numpy(video) / 255
40 |
41 |
42 | def resize_video(tensor: Tensor, size: Tuple[int, int], resize_method: str = "bicubic") -> Tensor:
43 | return F.interpolate(tensor, size=size, mode=resize_method)
44 |
45 |
46 | def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
47 | """Compute jaccard score between a box and the anchors."""
48 |
49 | len_anchors = anchors_max - anchors_min
50 | int_xmin = np.maximum(anchors_min, box_min)
51 | int_xmax = np.minimum(anchors_max, box_max)
52 | inter_len = np.maximum(int_xmax - int_xmin, 0.)
53 | union_len = len_anchors - inter_len + box_max - box_min
54 | iou = inter_len / (union_len + 1e-8)
55 | return iou
56 |
57 |
58 | def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
59 | # Calculate the overlap proportion (intersection over anchor length) between the anchors and all
60 | # boxes, used as the supervision signal; each anchor has length 0.01.
61 | len_anchors = anchors_max - anchors_min
62 | int_xmin = np.maximum(anchors_min, box_min)
63 | int_xmax = np.minimum(anchors_max, box_max)
64 | inter_len = np.maximum(int_xmax - int_xmin, 0.)
65 | scores = np.divide(inter_len, len_anchors + 1e-8)
66 | return scores
67 |
68 |
69 | def iou_1d(proposal, target) -> Tensor:
70 | """
71 | Calculate 1D IoU between M proposals and N target segments.
72 |
73 | Args:
74 | proposal (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The predicted array with [M, 2]. First column is
75 | beginning, second column is end.
76 | target (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The label array with [N, 2]. First column is
77 | beginning, second column is end.
78 |
79 | Returns:
80 | :class:`~torch.Tensor`: The iou result with [M, N].
81 | """
82 | if type(proposal) is np.ndarray:
83 | proposal = torch.from_numpy(proposal)
84 |
85 | if type(target) is np.ndarray:
86 | target = torch.from_numpy(target)
87 |
88 | proposal_begin = proposal[:, 0].unsqueeze(0).T
89 | proposal_end = proposal[:, 1].unsqueeze(0).T
90 | target_begin = target[:, 0]
91 | target_end = target[:, 1]
92 |
93 | inner_begin = torch.maximum(proposal_begin, target_begin)
94 | inner_end = torch.minimum(proposal_end, target_end)
95 | outer_begin = torch.minimum(proposal_begin, target_begin)
96 | outer_end = torch.maximum(proposal_end, target_end)
97 |
98 | inter = torch.clamp(inner_end - inner_begin, min=0.)
99 | union = outer_end - outer_begin
100 | return inter / union
101 |
--------------------------------------------------------------------------------
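To illustrate the iou_1d helper defined above, a minimal sketch with made-up segments (values are arbitrary):

import torch
from avdeepfake1m.utils import iou_1d

proposals = torch.tensor([[0.0, 1.0], [2.0, 4.0]])  # M = 2 proposals: (begin, end)
targets = torch.tensor([[0.5, 1.5]])                # N = 1 target segment
print(iou_1d(proposals, targets))                   # IoU matrix of shape (M, N) = (2, 1)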
/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![feature(iter_map_windows)]
2 |
3 | use pyo3::prelude::*;
4 |
5 | pub mod loc_1d;
6 |
7 | #[pymodule]
8 | fn _evaluation(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
9 | m.add("__version__", env!("CARGO_PKG_VERSION"))?;
10 | m.add_function(wrap_pyfunction!(loc_1d::ap_1d, m)?)?;
11 | m.add_function(wrap_pyfunction!(loc_1d::ar_1d, m)?)?;
12 | m.add_function(wrap_pyfunction!(loc_1d::ap_ar_1d, m)?)?;
13 |
14 | Ok(())
15 | }
16 |
--------------------------------------------------------------------------------
/src/loc_1d.rs:
--------------------------------------------------------------------------------
1 | extern crate serde_json;
2 | extern crate simd_json;
3 |
4 | use std::collections::HashMap;
5 | use std::fs;
6 |
7 | use ndarray::{concatenate, OwnedRepr, s, stack, Zip};
8 | use ndarray::prelude::*;
9 | use pyo3::prelude::*;
10 | use pyo3::types::{IntoPyDict, PyDict};
11 | use rayon::prelude::*;
12 | use serde::{Deserialize, Serialize};
13 | use serde_json::{Map, Value};
14 |
15 | #[derive(Serialize, Deserialize, Debug)]
16 | struct Metadata {
17 | file: String,
18 | segments: Vec<Vec<f32>>,
19 | }
20 |
21 | fn convert_metadata_info_to_metadata(
22 | metadata_info: Map<String, Value>,
23 | file_key: &str,
24 | value_key: &str,
25 | ) -> Metadata {
26 | Metadata {
27 | file: metadata_info.get(file_key).unwrap().as_str().unwrap().to_string(),
28 | segments: metadata_info.get(value_key).unwrap().as_array().unwrap().iter().map(|x| {
29 | x.as_array().unwrap().iter().map(|x| x.as_f64().unwrap() as f32).collect()
30 | }).collect(),
31 | }
32 | }
33 |
34 | fn iou_1d(proposal: Array2<f32>, target: &Array2<f32>) -> Array2<f32> {
35 | let m = proposal.nrows();
36 | let n = target.nrows();
37 |
38 | let mut ious = Array2::<f32>::zeros((m, n));
39 |
40 | for i in 0..m {
41 | for j in 0..n {
42 | let proposal_start = proposal[[i, 0]];
43 | let proposal_end = proposal[[i, 1]];
44 | let target_start = target[[j, 0]];
45 | let target_end = target[[j, 1]];
46 |
47 | let inner_begin = proposal_start.max(target_start);
48 | let inner_end = proposal_end.min(target_end);
49 | let outer_begin = proposal_start.min(target_start);
50 | let outer_end = proposal_end.max(target_end);
51 |
52 | let intersection = (inner_end - inner_begin).max(0.0);
53 | let union = outer_end - outer_begin;
54 | ious[[i, j]] = intersection / union;
55 | }
56 | }
57 |
58 | ious
59 | }
60 |
61 | fn calc_ap_curve(is_tp: Array1<bool>, n_labels: f32) -> Array2<f32> {
62 | let acc_tp = Array1::from_vec(
63 | is_tp.iter().scan(0.0, |state, &x| {
64 | if x { *state += 1.0 }
65 | Some(*state)
66 | }).collect()
67 | );
68 |
69 | let precision: Array1<f32> = acc_tp.iter().enumerate().map(|(i, &x)| x / (i as f32 + 1.0)).collect();
70 | let recall: Array1<f32> = acc_tp / n_labels;
71 | let binding = stack!(Axis(0), recall.view(), precision.view());
72 | let binding = binding.t();
73 |
74 | concatenate![
75 | Axis(0),
76 | arr2(&[[1., 0.]]).view(),
77 | binding.slice(s![..;-1, ..])
78 | ]
79 | }
80 |
81 | fn calculate_ap(curve: &Array2<f32>) -> f32 {
82 | let x = curve.column(0).to_owned();
83 | let y = curve.column(1).to_owned();
84 |
85 | let y_max = Array1::from(y.iter().scan(None, |state, &x| {
86 | if state.is_none() || x > state.unwrap() {
87 | *state = Some(x);
88 | }
89 |
90 | *state
91 | }).collect::<Vec<f32>>());
92 |
93 | let x_diff: Array1<f32> = x
94 | .into_iter()
95 | .map_windows(|[x, y]| (y - x).abs())
96 | .collect();
97 |
98 | (x_diff * y_max.slice(s![..-1])).sum()
99 | }
100 |
101 | fn get_ap_values(
102 | iou_threshold: f32,
103 | proposals: &Array2<f32>,
104 | labels: &Array2<f32>,
105 | fps: f32,
106 | ) -> (Array1<f32>, Array1<bool>) {
107 | let n_labels = labels.len_of(Axis(0));
108 | let n_proposals = proposals.len_of(Axis(0));
109 | let local_proposals = if proposals.shape() != [0] {
110 | proposals.clone()
111 | } else {
112 | proposals.clone()
113 | .into_shape((0, 3))
114 | .unwrap()
115 | };
116 |
117 | let ious = if n_labels > 0 {
118 | iou_1d(local_proposals.slice(s![.., 1..]).mapv(|x| x / fps), labels)
119 | } else {
120 | Array::zeros((n_proposals, 0))
121 | };
122 |
123 | let confidence = local_proposals.column(0).to_owned();
124 | let potential_tp = ious.mapv(|x| x > iou_threshold);
125 |
126 | let mut is_tp = Array1::from_elem((n_proposals, ), false);
127 |
128 | let mut tp_indexes = Vec::new();
129 | for i in 0..n_labels {
130 | let potential_tp_col = potential_tp.column(i);
131 | let potential_tp_index = potential_tp_col.iter().enumerate().filter(|(_, &x)| x).map(|(j, _)| j);
132 | for j in potential_tp_index {
133 | if !tp_indexes.contains(&j) {
134 | tp_indexes.push(j);
135 | break;
136 | }
137 | }
138 | }
139 |
140 | if !tp_indexes.is_empty() {
141 | for &j in &tp_indexes {
142 | is_tp[j] = true;
143 | }
144 | }
145 |
146 | (confidence, is_tp)
147 | }
148 |
149 | fn calc_ap_scores(
150 | iou_thresholds: &Vec<f32>,
151 | metadata: &Vec<Metadata>,
152 | proposals_map: &Proposals,
153 | fps: f32,
154 | ) -> Vec<(f32, f32)> {
155 | iou_thresholds.par_iter().map(|iou| {
156 | let (values, labels): (Vec<_>, Vec<isize>) = metadata
157 | .par_iter()
158 | .map(|meta| {
159 | let proposals = &proposals_map.content[&meta.file];
160 | let rows = meta.segments.len();
161 | let x: Vec<f32> = meta.segments.iter().flatten().copied().collect();
162 | let labels = Array2::from_shape_vec((rows, 2), x).unwrap().to_owned();
163 | let meta_value = get_ap_values(*iou, &proposals.row, &labels, fps);
164 |
165 | (meta_value, labels.len_of(Axis(0)) as isize)
166 | })
167 | .unzip();
168 |
169 | let n_labels = labels.iter().sum::<isize>() as f32;
170 |
171 | let (r, n): (Vec<_>, Vec<_>) = values.into_iter().unzip();
172 | let confidence = concatenate(
173 | Axis(0),
174 | &r.iter()
175 | .map(|x| x.view())
176 | .collect::<Vec<_>>(),
177 | ).unwrap();
178 | let is_tp = concatenate(
179 | Axis(0),
180 | &n.iter()
181 | .map(|x| x.view())
182 | .collect::<Vec<_>>(),
183 | ).unwrap();
184 |
185 | let mut indices: Vec<usize> = (0..confidence.len()).collect();
186 | indices.sort_by(|&a, &b| confidence[b].partial_cmp(&confidence[a]).unwrap());
187 | let is_tp = is_tp.select(Axis(0), &indices);
188 | let curve = calc_ap_curve(is_tp, n_labels);
189 | let ap = calculate_ap(&curve);
190 |
191 | (*iou, ap)
192 | }).collect::<Vec<_>>()
193 | }
194 |
195 |
196 | fn cummax_2d(array: &Array2<f32>) -> Array2<f32> {
197 | let mut result = array.clone();
198 |
199 | for mut column in result.axis_iter_mut(Axis(1)) {
200 | let mut cummax = column[0];
201 |
202 | for row in column.iter_mut().skip(1) {
203 | cummax = cummax.max(*row);
204 | *row = cummax;
205 | }
206 | }
207 |
208 | result
209 | }
210 |
211 | fn calc_ar_values(
212 | n_proposals: &Vec<usize>,
213 | iou_thresholds: &Vec<f32>,
214 | proposals: &Array2<f32>,
215 | labels: &Array2<f32>,
216 | fps: f32,
217 | ) -> ArrayBase<OwnedRepr<usize>, Ix3> {
218 | let max_proposals = *n_proposals.iter().max().unwrap();
219 | let max_proposals = max_proposals.min(proposals.nrows());
220 |
221 | let mut proposals = proposals.slice(s![..max_proposals, ..]).to_owned();
222 | if proposals.is_empty() {
223 | proposals = Array2::zeros((0, 3)).into();
224 | }
225 |
226 | let n_proposals_clamped = n_proposals.iter().map(|&n| n.min(proposals.nrows())).collect::<Vec<_>>();
227 | let n_labels = labels.nrows();
228 |
229 | let ious = if n_labels > 0 {
230 | iou_1d(proposals.slice(s![.., 1..]).mapv(|x| x / fps), labels)
231 | } else {
232 | Array::zeros((max_proposals, 0)) // TODO: maybe short-circuit
233 | };
234 |
235 | let mut values = Array3::zeros((iou_thresholds.len(), n_proposals_clamped.len(), 2));
236 | if !proposals.is_empty() {
237 | let iou_max = cummax_2d(&ious); // shape (max_proposals, n_labels): row k-1 holds the best IoU per label among the top k proposals
238 | for (threshold_idx, &threshold) in iou_thresholds.iter().enumerate() {
239 | for (n_proposals_idx, &n_proposal) in n_proposals_clamped.iter().enumerate() {
240 | let tp = iou_max.row(n_proposal - 1).iter().filter(|&&iou| iou > threshold).count();
241 | values[[threshold_idx, n_proposals_idx, 0]] = tp;
242 | values[[threshold_idx, n_proposals_idx, 1]] = n_labels - tp;
243 | }
244 | }
245 | }
246 | values
247 | }
248 |
249 | fn calc_ar_scores(
250 | n_proposals: &Vec<usize>,
251 | iou_thresholds: &Vec<f32>,
252 | metadata: &Vec<Metadata>,
253 | proposals_map: &Proposals,
254 | fps: f32,
255 | ) -> Vec<(usize, f32)> {
256 | let values = metadata.par_iter().map(|meta| {
257 | let proposals = &proposals_map.content[&meta.file];
258 |
259 | let rows = meta.segments.len();
260 | let x: Vec<f32> = meta.segments.iter().flatten().copied().collect();
261 | let labels = Array2::from_shape_vec((rows, 2), x).unwrap().to_owned();
262 |
263 | calc_ar_values(&n_proposals, iou_thresholds, &proposals.row, &labels, fps)
264 | }).collect::<Vec<_>>();
265 |
266 | let values = stack(
267 | Axis(0),
268 | &values
269 | .iter()
270 | .map(|x| x.view())
271 | .collect::<Vec<_>>(),
272 | ).unwrap();
273 |
274 | let values_sum = values.sum_axis(Axis(0));
275 | let tp = values_sum.slice(s![.., .., 0]);
276 | let f_n = values_sum.slice(s![.., .., 1]);
277 |
278 | let recall = Zip::from(&tp).and(&f_n).map_collect(|&x, &y| {
279 | let div = x as f32 + y as f32;
280 | if div == 0. {
281 | 0.
282 | } else {
283 | x as f32 / div
284 | }
285 | });
286 |
287 | n_proposals.iter().enumerate().map(|(ix, &prop)| {
288 | (prop, recall.column(ix).mean().unwrap())
289 | }).collect::<Vec<_>>()
290 | }
291 |
292 |
293 | #[derive(Deserialize, Debug)]
294 | #[serde(transparent)]
295 | struct ProposalRow {
296 | #[serde(with = "serde_ndim")]
297 | pub row: Array2<f32>,
298 | }
299 |
300 | #[derive(Deserialize, Debug)]
301 | #[serde(transparent)]
302 | struct Proposals {
303 | pub content: HashMap<String, ProposalRow>,
304 | }
305 |
306 | fn load_json(proposals_path: &str, labels_path: &str, file_key: &str, value_key: &str) -> (Vec<Metadata>, Proposals) {
307 | let mut proposals_raw = fs::read_to_string(proposals_path).expect("Unable to read proposal file");
308 | let mut labels_raw = fs::read_to_string(labels_path).expect("Unable to read labels file");
309 | let labels_infos: Vec