├── .gitattributes ├── .github └── workflows │ └── CI.yml ├── .gitignore ├── CITATION.cff ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── assets └── teaser.png ├── eula.pdf ├── examples ├── batfd │ ├── README.md │ ├── batfd.toml │ ├── batfd │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── audio_encoder.py │ │ │ ├── batfd.py │ │ │ ├── batfd_plus.py │ │ │ ├── boundary_module.py │ │ │ ├── boundary_module_plus.py │ │ │ ├── frame_classifier.py │ │ │ ├── fusion_module.py │ │ │ ├── loss.py │ │ │ └── video_encoder.py │ │ ├── post_process.py │ │ └── utils.py │ ├── batfd_plus.toml │ ├── evaluate.py │ ├── infer.py │ └── train.py └── xception │ ├── README.md │ ├── evaluate.py │ ├── infer.py │ ├── train.py │ ├── utils.py │ └── xception.py ├── pixi.lock ├── pyproject.toml ├── python └── avdeepfake1m │ ├── __init__.py │ ├── evaluation │ ├── __init__.py │ ├── __init__.pyi │ └── auc.py │ ├── loader.py │ └── utils.py └── src ├── lib.rs └── loc_1d.rs /.gitattributes: -------------------------------------------------------------------------------- 1 | # GitHub syntax highlighting 2 | pixi.lock linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by maturin v1.7.4 2 | # To update, run 3 | # 4 | # maturin generate-ci github 5 | # 6 | name: CI 7 | 8 | on: 9 | push: 10 | branches: 11 | - main 12 | - master 13 | tags: 14 | - '*' 15 | pull_request: 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | linux: 23 | runs-on: ${{ matrix.platform.runner }} 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | platform: 28 | - runner: ubuntu-latest 29 | target: x86_64 30 | - runner: ubuntu-latest 31 | target: aarch64 32 | - runner: ubuntu-latest 33 | target: armv7 34 | - runner: ubuntu-latest 35 | target: s390x 36 | - runner: ubuntu-latest 37 | target: ppc64le 38 | steps: 39 | - uses: actions/checkout@v4 40 | - uses: actions/setup-python@v5 41 | with: 42 | python-version: "3.11" 43 | - name: Update version if tag 44 | if: startsWith(github.ref, 'refs/tags/') 45 | run: | 46 | VERSION=${GITHUB_REF#refs/tags/} 47 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml 48 | - name: Build wheels 49 | uses: PyO3/maturin-action@v1 50 | with: 51 | rust-toolchain: nightly 52 | target: ${{ matrix.platform.target }} 53 | args: --release --out dist --find-interpreter 54 | sccache: 'true' 55 | manylinux: auto 56 | - name: Upload wheels 57 | uses: actions/upload-artifact@v4 58 | with: 59 | name: wheels-linux-${{ matrix.platform.target }} 60 | path: dist 61 | 62 | musllinux: 63 | runs-on: ${{ matrix.platform.runner }} 64 | strategy: 65 | fail-fast: false 66 | matrix: 67 | platform: 68 | - runner: ubuntu-latest 69 | target: x86_64 70 | - runner: ubuntu-latest 71 | target: aarch64 72 | steps: 73 | - uses: actions/checkout@v4 74 | - uses: actions/setup-python@v5 75 | with: 76 | python-version: "3.11" 77 | - uses: dtolnay/rust-toolchain@nightly 78 | - name: Update version if tag 79 | if: startsWith(github.ref, 'refs/tags/') 80 | run: | 81 | VERSION=${GITHUB_REF#refs/tags/} 82 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml 83 | - name: Build wheels 84 | uses: PyO3/maturin-action@v1 85 | with: 86 | rust-toolchain: nightly 87 | target: ${{ matrix.platform.target }} 88 | args: --release --out dist 
--find-interpreter 89 | sccache: 'true' 90 | manylinux: musllinux_1_2 91 | - name: Upload wheels 92 | uses: actions/upload-artifact@v4 93 | with: 94 | name: wheels-musllinux-${{ matrix.platform.target }} 95 | path: dist 96 | 97 | windows: 98 | runs-on: ${{ matrix.platform.runner }} 99 | strategy: 100 | fail-fast: false 101 | matrix: 102 | platform: 103 | - runner: windows-latest 104 | target: x64 105 | steps: 106 | - uses: actions/checkout@v4 107 | - uses: actions/setup-python@v5 108 | with: 109 | python-version: "3.11" 110 | architecture: ${{ matrix.platform.target }} 111 | - uses: dtolnay/rust-toolchain@nightly 112 | - name: Update version if tag 113 | if: startsWith(github.ref, 'refs/tags/') 114 | shell: bash 115 | run: | 116 | VERSION=${GITHUB_REF#refs/tags/} 117 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml 118 | - name: Build wheels 119 | uses: PyO3/maturin-action@v1 120 | with: 121 | rust-toolchain: nightly 122 | target: ${{ matrix.platform.target }} 123 | args: --release --out dist --find-interpreter 124 | sccache: 'true' 125 | - name: Upload wheels 126 | uses: actions/upload-artifact@v4 127 | with: 128 | name: wheels-windows-${{ matrix.platform.target }} 129 | path: dist 130 | 131 | macos: 132 | runs-on: ${{ matrix.platform.runner }} 133 | strategy: 134 | fail-fast: false 135 | matrix: 136 | platform: 137 | - runner: macos-13 138 | target: x86_64 139 | - runner: macos-14 140 | target: aarch64 141 | steps: 142 | - uses: actions/checkout@v4 143 | - uses: actions/setup-python@v5 144 | with: 145 | python-version: "3.11" 146 | - uses: dtolnay/rust-toolchain@nightly 147 | - name: Update version if tag 148 | if: startsWith(github.ref, 'refs/tags/') 149 | run: | 150 | VERSION=${GITHUB_REF#refs/tags/} 151 | sed -i "" "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml 152 | - name: Build wheels 153 | uses: PyO3/maturin-action@v1 154 | with: 155 | rust-toolchain: nightly 156 | target: ${{ matrix.platform.target }} 157 | args: --release --out dist --find-interpreter 158 | sccache: 'true' 159 | - name: Upload wheels 160 | uses: actions/upload-artifact@v4 161 | with: 162 | name: wheels-macos-${{ matrix.platform.target }} 163 | path: dist 164 | 165 | sdist: 166 | runs-on: ubuntu-latest 167 | steps: 168 | - uses: actions/checkout@v4 169 | - uses: dtolnay/rust-toolchain@nightly 170 | - name: Update version if tag 171 | if: startsWith(github.ref, 'refs/tags/') 172 | run: | 173 | VERSION=${GITHUB_REF#refs/tags/} 174 | sed -i "s/version = \"0.0.0\"/version = \"$VERSION\"/" Cargo.toml 175 | - name: Build sdist 176 | uses: PyO3/maturin-action@v1 177 | with: 178 | rust-toolchain: nightly 179 | command: sdist 180 | args: --out dist 181 | - name: Upload sdist 182 | uses: actions/upload-artifact@v4 183 | with: 184 | name: wheels-sdist 185 | path: dist 186 | 187 | release: 188 | name: Release 189 | runs-on: ubuntu-latest 190 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} 191 | needs: [linux, musllinux, windows, macos, sdist] 192 | permissions: 193 | # Use to sign the release artifacts 194 | id-token: write 195 | # Used to upload release artifacts 196 | contents: write 197 | # Used to generate artifact attestation 198 | attestations: write 199 | steps: 200 | - uses: actions/download-artifact@v4 201 | - name: Generate artifact attestation 202 | uses: actions/attest-build-provenance@v1 203 | with: 204 | subject-path: 'wheels-*/*' 205 | - name: Publish to PyPI 206 | if: startsWith(github.ref, 'refs/tags/') 207 | uses: 
PyO3/maturin-action@v1 208 | env: 209 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 210 | with: 211 | command: upload 212 | args: --non-interactive --skip-existing wheels-*/* 213 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=intellij,python,windows,macos,linux,jupyternotebooks 3 | 4 | ### Intellij ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # AWS User-specific 16 | .idea/**/aws.xml 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 38 | # .idea/artifacts 39 | # .idea/compiler.xml 40 | # .idea/jarRepositories.xml 41 | # .idea/modules.xml 42 | # .idea/*.iml 43 | # .idea/modules 44 | # *.iml 45 | # *.ipr 46 | 47 | # CMake 48 | cmake-build-*/ 49 | 50 | # Mongo Explorer plugin 51 | .idea/**/mongoSettings.xml 52 | 53 | # File-based project format 54 | *.iws 55 | 56 | # IntelliJ 57 | out/ 58 | 59 | # mpeltonen/sbt-idea plugin 60 | .idea_modules/ 61 | 62 | # JIRA plugin 63 | atlassian-ide-plugin.xml 64 | 65 | # Cursive Clojure plugin 66 | .idea/replstate.xml 67 | 68 | # SonarLint plugin 69 | .idea/sonarlint/ 70 | 71 | # Crashlytics plugin (for Android Studio and IntelliJ) 72 | com_crashlytics_export_strings.xml 73 | crashlytics.properties 74 | crashlytics-build.properties 75 | fabric.properties 76 | 77 | # Editor-based Rest Client 78 | .idea/httpRequests 79 | 80 | # Android studio 3.1+ serialized cache file 81 | .idea/caches/build_file_checksums.ser 82 | 83 | ### Intellij Patch ### 84 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 85 | 86 | # *.iml 87 | # modules.xml 88 | # .idea/misc.xml 89 | # *.ipr 90 | .idea 91 | # Sonarlint plugin 92 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 93 | .idea/**/sonarlint/ 94 | 95 | # SonarQube Plugin 96 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 97 | .idea/**/sonarIssues.xml 98 | 99 | # Markdown Navigator plugin 100 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 101 | .idea/**/markdown-navigator.xml 102 | .idea/**/markdown-navigator-enh.xml 103 | .idea/**/markdown-navigator/ 104 | 105 | # Cache file creation bug 106 | # See https://youtrack.jetbrains.com/issue/JBR-2257 107 | .idea/$CACHE_FILE$ 108 | 109 | # CodeStream plugin 110 | # https://plugins.jetbrains.com/plugin/12206-codestream 111 | .idea/codestream.xml 112 | 113 | # Azure Toolkit for IntelliJ 
plugin 114 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 115 | .idea/**/azureSettings.xml 116 | 117 | ### JupyterNotebooks ### 118 | # gitignore template for Jupyter Notebooks 119 | # website: http://jupyter.org/ 120 | 121 | .ipynb_checkpoints 122 | */.ipynb_checkpoints/* 123 | 124 | # IPython 125 | profile_default/ 126 | ipython_config.py 127 | 128 | # Remove previous ipynb_checkpoints 129 | # git rm -r .ipynb_checkpoints/ 130 | 131 | ### Linux ### 132 | *~ 133 | 134 | # temporary files which can be created if a process still has a handle open of a deleted file 135 | .fuse_hidden* 136 | 137 | # KDE directory preferences 138 | .directory 139 | 140 | # Linux trash folder which might appear on any partition or disk 141 | .Trash-* 142 | 143 | # .nfs files are created when an open file is removed but is still being accessed 144 | .nfs* 145 | 146 | ### macOS ### 147 | # General 148 | .DS_Store 149 | .AppleDouble 150 | .LSOverride 151 | 152 | # Icon must end with two \r 153 | Icon 154 | 155 | 156 | # Thumbnails 157 | ._* 158 | 159 | # Files that might appear in the root of a volume 160 | .DocumentRevisions-V100 161 | .fseventsd 162 | .Spotlight-V100 163 | .TemporaryItems 164 | .Trashes 165 | .VolumeIcon.icns 166 | .com.apple.timemachine.donotpresent 167 | 168 | # Directories potentially created on remote AFP share 169 | .AppleDB 170 | .AppleDesktop 171 | Network Trash Folder 172 | Temporary Items 173 | .apdisk 174 | 175 | ### macOS Patch ### 176 | # iCloud generated files 177 | *.icloud 178 | 179 | ### Python ### 180 | # Byte-compiled / optimized / DLL files 181 | __pycache__/ 182 | *.py[cod] 183 | *$py.class 184 | 185 | # C extensions 186 | *.so 187 | 188 | # Distribution / packaging 189 | .Python 190 | build/ 191 | develop-eggs/ 192 | dist/ 193 | downloads/ 194 | eggs/ 195 | .eggs/ 196 | lib/ 197 | lib64/ 198 | parts/ 199 | sdist/ 200 | var/ 201 | wheels/ 202 | share/python-wheels/ 203 | *.egg-info/ 204 | .installed.cfg 205 | *.egg 206 | MANIFEST 207 | 208 | # PyInstaller 209 | # Usually these files are written by a python script from a template 210 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 211 | *.manifest 212 | *.spec 213 | 214 | # Installer logs 215 | pip-log.txt 216 | pip-delete-this-directory.txt 217 | 218 | # Unit test / coverage reports 219 | htmlcov/ 220 | .tox/ 221 | .nox/ 222 | .coverage 223 | .coverage.* 224 | .cache 225 | nosetests.xml 226 | coverage.xml 227 | *.cover 228 | *.py,cover 229 | .hypothesis/ 230 | .pytest_cache/ 231 | cover/ 232 | 233 | # Translations 234 | *.mo 235 | *.pot 236 | 237 | # Django stuff: 238 | *.log 239 | local_settings.py 240 | db.sqlite3 241 | db.sqlite3-journal 242 | 243 | # Flask stuff: 244 | instance/ 245 | .webassets-cache 246 | 247 | # Scrapy stuff: 248 | .scrapy 249 | 250 | # Sphinx documentation 251 | docs/_build/ 252 | 253 | # PyBuilder 254 | .pybuilder/ 255 | target/ 256 | 257 | # Jupyter Notebook 258 | 259 | # IPython 260 | 261 | # pyenv 262 | # For a library or package, you might want to ignore these files since the code is 263 | # intended to run in multiple environments; otherwise, check them in: 264 | # .python-version 265 | 266 | # pipenv 267 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 268 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 269 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 270 | # install all needed dependencies. 
271 | #Pipfile.lock 272 | 273 | # poetry 274 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 275 | # This is especially recommended for binary packages to ensure reproducibility, and is more 276 | # commonly ignored for libraries. 277 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 278 | #poetry.lock 279 | 280 | # pdm 281 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 282 | #pdm.lock 283 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 284 | # in version control. 285 | # https://pdm.fming.dev/#use-with-ide 286 | .pdm.toml 287 | 288 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 289 | __pypackages__/ 290 | 291 | # Celery stuff 292 | celerybeat-schedule 293 | celerybeat.pid 294 | 295 | # SageMath parsed files 296 | *.sage.py 297 | 298 | # Environments 299 | .env 300 | .venv 301 | env/ 302 | venv/ 303 | ENV/ 304 | env.bak/ 305 | venv.bak/ 306 | 307 | # Spyder project settings 308 | .spyderproject 309 | .spyproject 310 | 311 | # Rope project settings 312 | .ropeproject 313 | 314 | # mkdocs documentation 315 | /site 316 | 317 | # mypy 318 | .mypy_cache/ 319 | .dmypy.json 320 | dmypy.json 321 | 322 | # Pyre type checker 323 | .pyre/ 324 | 325 | # pytype static type analyzer 326 | .pytype/ 327 | 328 | # Cython debug symbols 329 | cython_debug/ 330 | 331 | # PyCharm 332 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 333 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 334 | # and can be added to the global gitignore or merged into this file. For a more nuclear 335 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
336 | #.idea/ 337 | 338 | ### Windows ### 339 | # Windows thumbnail cache files 340 | Thumbs.db 341 | Thumbs.db:encryptable 342 | ehthumbs.db 343 | ehthumbs_vista.db 344 | 345 | # Dump file 346 | *.stackdump 347 | 348 | # Folder config file 349 | [Dd]esktop.ini 350 | 351 | # Recycle Bin used on file shares 352 | $RECYCLE.BIN/ 353 | 354 | # Windows Installer files 355 | *.cab 356 | *.msi 357 | *.msix 358 | *.msm 359 | *.msp 360 | 361 | # Windows shortcuts 362 | *.lnk 363 | 364 | # End of https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks 365 | 366 | .vscode 367 | # pixi environments 368 | .pixi 369 | *.egg-info 370 | 371 | /target 372 | 373 | # Byte-compiled / optimized / DLL files 374 | __pycache__/ 375 | .pytest_cache/ 376 | *.py[cod] 377 | 378 | # C extensions 379 | *.so 380 | 381 | # Distribution / packaging 382 | .Python 383 | .venv/ 384 | env/ 385 | bin/ 386 | build/ 387 | develop-eggs/ 388 | dist/ 389 | eggs/ 390 | lib/ 391 | lib64/ 392 | parts/ 393 | sdist/ 394 | var/ 395 | include/ 396 | man/ 397 | venv/ 398 | *.egg-info/ 399 | .installed.cfg 400 | *.egg 401 | 402 | # Installer logs 403 | pip-log.txt 404 | pip-delete-this-directory.txt 405 | pip-selfcheck.json 406 | 407 | # Unit test / coverage reports 408 | htmlcov/ 409 | .tox/ 410 | .coverage 411 | .cache 412 | nosetests.xml 413 | coverage.xml 414 | 415 | # Translations 416 | *.mo 417 | 418 | # Mr Developer 419 | .mr.developer.cfg 420 | .project 421 | .pydevproject 422 | 423 | # Rope 424 | .ropeproject 425 | 426 | # Django stuff: 427 | *.log 428 | *.pot 429 | 430 | .DS_Store 431 | 432 | # Sphinx documentation 433 | docs/_build/ 434 | 435 | # PyCharm 436 | .idea/ 437 | 438 | # VSCode 439 | .vscode/ 440 | 441 | # Pyenv 442 | .python-version 443 | 444 | /examples/batfd/lightning_logs 445 | /examples/batfd/ckpt 446 | /examples/batfd/output 447 | /examples/xception/lightning_logs 448 | /examples/xception/ckpt 449 | /examples/xception/output 450 | wandb 451 | *.ckpt 452 | *.pth 453 | *.pt 454 | *.jit -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find this work useful in your research, please cite it." 3 | preferred-citation: 4 | type: conference-paper 5 | title: "AV-Deepfake1M: A Large-Scale LLM-Driven Audio-Visual Deepfake Dataset" 6 | authors: 7 | - family-names: "Cai" 8 | given-names: "Zhixi" 9 | - family-names: "Ghosh" 10 | given-names: "Shreya" 11 | - family-names: "Adatia" 12 | given-names: "Aman Pankaj" 13 | - family-names: "Hayat" 14 | given-names: "Munawar" 15 | - family-names: "Dhall" 16 | given-names: "Abhinav" 17 | - family-names: "Stefanov" 18 | given-names: "Kalin" 19 | collection-title: "Proceedings of the 32nd ACM International Conference on Multimedia" 20 | year: 2024 21 | location: 22 | name: "Melbourne, Australia" 23 | start: 7414 24 | end: 7423 25 | doi: "10.1145/3664647.3680795" 26 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing.
3 | version = 4 4 | 5 | [[package]] 6 | name = "ahash" 7 | version = "0.8.11" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" 10 | dependencies = [ 11 | "cfg-if", 12 | "once_cell", 13 | "version_check", 14 | "zerocopy", 15 | ] 16 | 17 | [[package]] 18 | name = "allocator-api2" 19 | version = "0.2.18" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" 22 | 23 | [[package]] 24 | name = "autocfg" 25 | version = "1.4.0" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" 28 | 29 | [[package]] 30 | name = "avdeepfake1m" 31 | version = "0.0.0" 32 | dependencies = [ 33 | "ndarray", 34 | "pyo3", 35 | "rayon", 36 | "serde", 37 | "serde-ndim", 38 | "serde_json", 39 | "simd-json", 40 | ] 41 | 42 | [[package]] 43 | name = "bumpalo" 44 | version = "3.16.0" 45 | source = "registry+https://github.com/rust-lang/crates.io-index" 46 | checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" 47 | 48 | [[package]] 49 | name = "cblas-sys" 50 | version = "0.1.4" 51 | source = "registry+https://github.com/rust-lang/crates.io-index" 52 | checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" 53 | dependencies = [ 54 | "libc", 55 | ] 56 | 57 | [[package]] 58 | name = "cfg-if" 59 | version = "1.0.0" 60 | source = "registry+https://github.com/rust-lang/crates.io-index" 61 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 62 | 63 | [[package]] 64 | name = "crossbeam-deque" 65 | version = "0.8.5" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" 68 | dependencies = [ 69 | "crossbeam-epoch", 70 | "crossbeam-utils", 71 | ] 72 | 73 | [[package]] 74 | name = "crossbeam-epoch" 75 | version = "0.9.18" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 78 | dependencies = [ 79 | "crossbeam-utils", 80 | ] 81 | 82 | [[package]] 83 | name = "crossbeam-utils" 84 | version = "0.8.20" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" 87 | 88 | [[package]] 89 | name = "either" 90 | version = "1.13.0" 91 | source = "registry+https://github.com/rust-lang/crates.io-index" 92 | checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" 93 | 94 | [[package]] 95 | name = "float-cmp" 96 | version = "0.9.0" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" 99 | dependencies = [ 100 | "num-traits", 101 | ] 102 | 103 | [[package]] 104 | name = "getrandom" 105 | version = "0.2.15" 106 | source = "registry+https://github.com/rust-lang/crates.io-index" 107 | checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" 108 | dependencies = [ 109 | "cfg-if", 110 | "js-sys", 111 | "libc", 112 | "wasi", 113 | "wasm-bindgen", 114 | ] 115 | 116 | [[package]] 117 | name = "halfbrown" 118 | version = "0.2.5" 119 | source = "registry+https://github.com/rust-lang/crates.io-index" 120 | checksum = 
"8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" 121 | dependencies = [ 122 | "hashbrown", 123 | "serde", 124 | ] 125 | 126 | [[package]] 127 | name = "hashbrown" 128 | version = "0.14.5" 129 | source = "registry+https://github.com/rust-lang/crates.io-index" 130 | checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" 131 | dependencies = [ 132 | "ahash", 133 | "allocator-api2", 134 | ] 135 | 136 | [[package]] 137 | name = "heck" 138 | version = "0.5.0" 139 | source = "registry+https://github.com/rust-lang/crates.io-index" 140 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" 141 | 142 | [[package]] 143 | name = "indoc" 144 | version = "2.0.5" 145 | source = "registry+https://github.com/rust-lang/crates.io-index" 146 | checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" 147 | 148 | [[package]] 149 | name = "itoa" 150 | version = "1.0.11" 151 | source = "registry+https://github.com/rust-lang/crates.io-index" 152 | checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" 153 | 154 | [[package]] 155 | name = "js-sys" 156 | version = "0.3.72" 157 | source = "registry+https://github.com/rust-lang/crates.io-index" 158 | checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" 159 | dependencies = [ 160 | "wasm-bindgen", 161 | ] 162 | 163 | [[package]] 164 | name = "lexical-core" 165 | version = "1.0.2" 166 | source = "registry+https://github.com/rust-lang/crates.io-index" 167 | checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" 168 | dependencies = [ 169 | "lexical-parse-float", 170 | "lexical-parse-integer", 171 | "lexical-util", 172 | "lexical-write-float", 173 | "lexical-write-integer", 174 | ] 175 | 176 | [[package]] 177 | name = "lexical-parse-float" 178 | version = "1.0.2" 179 | source = "registry+https://github.com/rust-lang/crates.io-index" 180 | checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" 181 | dependencies = [ 182 | "lexical-parse-integer", 183 | "lexical-util", 184 | "static_assertions", 185 | ] 186 | 187 | [[package]] 188 | name = "lexical-parse-integer" 189 | version = "1.0.2" 190 | source = "registry+https://github.com/rust-lang/crates.io-index" 191 | checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" 192 | dependencies = [ 193 | "lexical-util", 194 | "static_assertions", 195 | ] 196 | 197 | [[package]] 198 | name = "lexical-util" 199 | version = "1.0.3" 200 | source = "registry+https://github.com/rust-lang/crates.io-index" 201 | checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" 202 | dependencies = [ 203 | "static_assertions", 204 | ] 205 | 206 | [[package]] 207 | name = "lexical-write-float" 208 | version = "1.0.2" 209 | source = "registry+https://github.com/rust-lang/crates.io-index" 210 | checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" 211 | dependencies = [ 212 | "lexical-util", 213 | "lexical-write-integer", 214 | "static_assertions", 215 | ] 216 | 217 | [[package]] 218 | name = "lexical-write-integer" 219 | version = "1.0.2" 220 | source = "registry+https://github.com/rust-lang/crates.io-index" 221 | checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" 222 | dependencies = [ 223 | "lexical-util", 224 | "static_assertions", 225 | ] 226 | 227 | [[package]] 228 | name = "libc" 229 | version = "0.2.159" 230 | source = "registry+https://github.com/rust-lang/crates.io-index" 
231 | checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" 232 | 233 | [[package]] 234 | name = "log" 235 | version = "0.4.22" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" 238 | 239 | [[package]] 240 | name = "matrixmultiply" 241 | version = "0.3.9" 242 | source = "registry+https://github.com/rust-lang/crates.io-index" 243 | checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" 244 | dependencies = [ 245 | "autocfg", 246 | "rawpointer", 247 | ] 248 | 249 | [[package]] 250 | name = "memchr" 251 | version = "2.7.4" 252 | source = "registry+https://github.com/rust-lang/crates.io-index" 253 | checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" 254 | 255 | [[package]] 256 | name = "memoffset" 257 | version = "0.9.1" 258 | source = "registry+https://github.com/rust-lang/crates.io-index" 259 | checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" 260 | dependencies = [ 261 | "autocfg", 262 | ] 263 | 264 | [[package]] 265 | name = "ndarray" 266 | version = "0.15.6" 267 | source = "registry+https://github.com/rust-lang/crates.io-index" 268 | checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" 269 | dependencies = [ 270 | "cblas-sys", 271 | "libc", 272 | "matrixmultiply", 273 | "num-complex", 274 | "num-integer", 275 | "num-traits", 276 | "rawpointer", 277 | "rayon", 278 | "serde", 279 | ] 280 | 281 | [[package]] 282 | name = "num-complex" 283 | version = "0.4.6" 284 | source = "registry+https://github.com/rust-lang/crates.io-index" 285 | checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" 286 | dependencies = [ 287 | "num-traits", 288 | ] 289 | 290 | [[package]] 291 | name = "num-integer" 292 | version = "0.1.46" 293 | source = "registry+https://github.com/rust-lang/crates.io-index" 294 | checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" 295 | dependencies = [ 296 | "num-traits", 297 | ] 298 | 299 | [[package]] 300 | name = "num-traits" 301 | version = "0.2.19" 302 | source = "registry+https://github.com/rust-lang/crates.io-index" 303 | checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" 304 | dependencies = [ 305 | "autocfg", 306 | ] 307 | 308 | [[package]] 309 | name = "once_cell" 310 | version = "1.20.2" 311 | source = "registry+https://github.com/rust-lang/crates.io-index" 312 | checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" 313 | 314 | [[package]] 315 | name = "portable-atomic" 316 | version = "1.9.0" 317 | source = "registry+https://github.com/rust-lang/crates.io-index" 318 | checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" 319 | 320 | [[package]] 321 | name = "proc-macro2" 322 | version = "1.0.87" 323 | source = "registry+https://github.com/rust-lang/crates.io-index" 324 | checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" 325 | dependencies = [ 326 | "unicode-ident", 327 | ] 328 | 329 | [[package]] 330 | name = "pyo3" 331 | version = "0.24.2" 332 | source = "registry+https://github.com/rust-lang/crates.io-index" 333 | checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" 334 | dependencies = [ 335 | "cfg-if", 336 | "indoc", 337 | "libc", 338 | "memoffset", 339 | "once_cell", 340 | "portable-atomic", 341 | "pyo3-build-config", 342 | "pyo3-ffi", 343 | "pyo3-macros", 344 | 
"unindent", 345 | ] 346 | 347 | [[package]] 348 | name = "pyo3-build-config" 349 | version = "0.24.2" 350 | source = "registry+https://github.com/rust-lang/crates.io-index" 351 | checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" 352 | dependencies = [ 353 | "once_cell", 354 | "target-lexicon", 355 | ] 356 | 357 | [[package]] 358 | name = "pyo3-ffi" 359 | version = "0.24.2" 360 | source = "registry+https://github.com/rust-lang/crates.io-index" 361 | checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" 362 | dependencies = [ 363 | "libc", 364 | "pyo3-build-config", 365 | ] 366 | 367 | [[package]] 368 | name = "pyo3-macros" 369 | version = "0.24.2" 370 | source = "registry+https://github.com/rust-lang/crates.io-index" 371 | checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" 372 | dependencies = [ 373 | "proc-macro2", 374 | "pyo3-macros-backend", 375 | "quote", 376 | "syn", 377 | ] 378 | 379 | [[package]] 380 | name = "pyo3-macros-backend" 381 | version = "0.24.2" 382 | source = "registry+https://github.com/rust-lang/crates.io-index" 383 | checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" 384 | dependencies = [ 385 | "heck", 386 | "proc-macro2", 387 | "pyo3-build-config", 388 | "quote", 389 | "syn", 390 | ] 391 | 392 | [[package]] 393 | name = "quote" 394 | version = "1.0.37" 395 | source = "registry+https://github.com/rust-lang/crates.io-index" 396 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" 397 | dependencies = [ 398 | "proc-macro2", 399 | ] 400 | 401 | [[package]] 402 | name = "rawpointer" 403 | version = "0.2.1" 404 | source = "registry+https://github.com/rust-lang/crates.io-index" 405 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 406 | 407 | [[package]] 408 | name = "rayon" 409 | version = "1.10.0" 410 | source = "registry+https://github.com/rust-lang/crates.io-index" 411 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" 412 | dependencies = [ 413 | "either", 414 | "rayon-core", 415 | ] 416 | 417 | [[package]] 418 | name = "rayon-core" 419 | version = "1.12.1" 420 | source = "registry+https://github.com/rust-lang/crates.io-index" 421 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 422 | dependencies = [ 423 | "crossbeam-deque", 424 | "crossbeam-utils", 425 | ] 426 | 427 | [[package]] 428 | name = "ref-cast" 429 | version = "1.0.23" 430 | source = "registry+https://github.com/rust-lang/crates.io-index" 431 | checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" 432 | dependencies = [ 433 | "ref-cast-impl", 434 | ] 435 | 436 | [[package]] 437 | name = "ref-cast-impl" 438 | version = "1.0.23" 439 | source = "registry+https://github.com/rust-lang/crates.io-index" 440 | checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" 441 | dependencies = [ 442 | "proc-macro2", 443 | "quote", 444 | "syn", 445 | ] 446 | 447 | [[package]] 448 | name = "ryu" 449 | version = "1.0.18" 450 | source = "registry+https://github.com/rust-lang/crates.io-index" 451 | checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" 452 | 453 | [[package]] 454 | name = "serde" 455 | version = "1.0.210" 456 | source = "registry+https://github.com/rust-lang/crates.io-index" 457 | checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" 458 | dependencies = [ 459 | "serde_derive", 460 | ] 461 | 462 | 
[[package]] 463 | name = "serde-ndim" 464 | version = "1.1.0" 465 | source = "registry+https://github.com/rust-lang/crates.io-index" 466 | checksum = "5883c695f4433e428c19938eddf3a8e9d4e040f0cd0c5c11f6ded90f190774c0" 467 | dependencies = [ 468 | "ndarray", 469 | "serde", 470 | ] 471 | 472 | [[package]] 473 | name = "serde_derive" 474 | version = "1.0.210" 475 | source = "registry+https://github.com/rust-lang/crates.io-index" 476 | checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" 477 | dependencies = [ 478 | "proc-macro2", 479 | "quote", 480 | "syn", 481 | ] 482 | 483 | [[package]] 484 | name = "serde_json" 485 | version = "1.0.128" 486 | source = "registry+https://github.com/rust-lang/crates.io-index" 487 | checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" 488 | dependencies = [ 489 | "itoa", 490 | "memchr", 491 | "ryu", 492 | "serde", 493 | ] 494 | 495 | [[package]] 496 | name = "simd-json" 497 | version = "0.13.11" 498 | source = "registry+https://github.com/rust-lang/crates.io-index" 499 | checksum = "a0228a564470f81724e30996bbc2b171713b37b15254a6440c7e2d5449b95691" 500 | dependencies = [ 501 | "getrandom", 502 | "halfbrown", 503 | "lexical-core", 504 | "ref-cast", 505 | "serde", 506 | "serde_json", 507 | "simdutf8", 508 | "value-trait", 509 | ] 510 | 511 | [[package]] 512 | name = "simdutf8" 513 | version = "0.1.5" 514 | source = "registry+https://github.com/rust-lang/crates.io-index" 515 | checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" 516 | 517 | [[package]] 518 | name = "static_assertions" 519 | version = "1.1.0" 520 | source = "registry+https://github.com/rust-lang/crates.io-index" 521 | checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 522 | 523 | [[package]] 524 | name = "syn" 525 | version = "2.0.79" 526 | source = "registry+https://github.com/rust-lang/crates.io-index" 527 | checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" 528 | dependencies = [ 529 | "proc-macro2", 530 | "quote", 531 | "unicode-ident", 532 | ] 533 | 534 | [[package]] 535 | name = "target-lexicon" 536 | version = "0.13.2" 537 | source = "registry+https://github.com/rust-lang/crates.io-index" 538 | checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" 539 | 540 | [[package]] 541 | name = "unicode-ident" 542 | version = "1.0.13" 543 | source = "registry+https://github.com/rust-lang/crates.io-index" 544 | checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" 545 | 546 | [[package]] 547 | name = "unindent" 548 | version = "0.2.3" 549 | source = "registry+https://github.com/rust-lang/crates.io-index" 550 | checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" 551 | 552 | [[package]] 553 | name = "value-trait" 554 | version = "0.8.1" 555 | source = "registry+https://github.com/rust-lang/crates.io-index" 556 | checksum = "dad8db98c1e677797df21ba03fca7d3bf9bec3ca38db930954e4fe6e1ea27eb4" 557 | dependencies = [ 558 | "float-cmp", 559 | "halfbrown", 560 | "itoa", 561 | "ryu", 562 | ] 563 | 564 | [[package]] 565 | name = "version_check" 566 | version = "0.9.5" 567 | source = "registry+https://github.com/rust-lang/crates.io-index" 568 | checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" 569 | 570 | [[package]] 571 | name = "wasi" 572 | version = "0.11.0+wasi-snapshot-preview1" 573 | source = "registry+https://github.com/rust-lang/crates.io-index" 574 | checksum = 
"9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 575 | 576 | [[package]] 577 | name = "wasm-bindgen" 578 | version = "0.2.95" 579 | source = "registry+https://github.com/rust-lang/crates.io-index" 580 | checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" 581 | dependencies = [ 582 | "cfg-if", 583 | "once_cell", 584 | "wasm-bindgen-macro", 585 | ] 586 | 587 | [[package]] 588 | name = "wasm-bindgen-backend" 589 | version = "0.2.95" 590 | source = "registry+https://github.com/rust-lang/crates.io-index" 591 | checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" 592 | dependencies = [ 593 | "bumpalo", 594 | "log", 595 | "once_cell", 596 | "proc-macro2", 597 | "quote", 598 | "syn", 599 | "wasm-bindgen-shared", 600 | ] 601 | 602 | [[package]] 603 | name = "wasm-bindgen-macro" 604 | version = "0.2.95" 605 | source = "registry+https://github.com/rust-lang/crates.io-index" 606 | checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" 607 | dependencies = [ 608 | "quote", 609 | "wasm-bindgen-macro-support", 610 | ] 611 | 612 | [[package]] 613 | name = "wasm-bindgen-macro-support" 614 | version = "0.2.95" 615 | source = "registry+https://github.com/rust-lang/crates.io-index" 616 | checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" 617 | dependencies = [ 618 | "proc-macro2", 619 | "quote", 620 | "syn", 621 | "wasm-bindgen-backend", 622 | "wasm-bindgen-shared", 623 | ] 624 | 625 | [[package]] 626 | name = "wasm-bindgen-shared" 627 | version = "0.2.95" 628 | source = "registry+https://github.com/rust-lang/crates.io-index" 629 | checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" 630 | 631 | [[package]] 632 | name = "zerocopy" 633 | version = "0.7.35" 634 | source = "registry+https://github.com/rust-lang/crates.io-index" 635 | checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" 636 | dependencies = [ 637 | "zerocopy-derive", 638 | ] 639 | 640 | [[package]] 641 | name = "zerocopy-derive" 642 | version = "0.7.35" 643 | source = "registry+https://github.com/rust-lang/crates.io-index" 644 | checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" 645 | dependencies = [ 646 | "proc-macro2", 647 | "quote", 648 | "syn", 649 | ] 650 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "avdeepfake1m" 3 | version = "0.0.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [lib] 8 | name = "avdeepfake1m" 9 | crate-type = ["cdylib"] 10 | path = "src/lib.rs" 11 | 12 | [dependencies] 13 | pyo3 = { version = "0.24.2", features = ["extension-module"] } 14 | ndarray = { version = "0.15.0", features = ["blas", "rayon", "serde"] } 15 | serde = { version = "1.0.197", features = ["derive"] } 16 | serde_json = "1.0.115" 17 | rayon = "1.10.0" 18 | simd-json = "0.13.9" 19 | serde-ndim = { version = "1.1.0", features = ["ndarray"] } 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | 
does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). 
To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. 
Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. 
if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. 
The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 401 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AV-Deepfake1M 2 | 3 |
27 | 28 | This is the official repository for the paper 29 | [AV-Deepfake1M: A Large-Scale LLM-Driven Audio-Visual Deepfake Dataset](https://dl.acm.org/doi/abs/10.1145/3664647.3680795) (Best Award). 30 | 31 | ## News 32 | 33 | - [2025/04/08] 🏆 [**2025 1M-Deepfakes Challenge**](https://deepfakes1m.github.io/2025) starts. New database (2M videos!) - **AV-Deepfake1M++** is released 34 | - [2024/10/13] 🚀 [**PYPI package**](https://pypi.org/project/avdeepfake1m/) is released. 35 | - [2024/07/15] 🔥 [**AV-Deepfake1M** paper](https://dl.acm.org/doi/abs/10.1145/3664647.3680795) is accepted in MM 2024. 36 | - [2024/03/09] 🏆 [**2024 1M-Deepfakes Challenge**](https://deepfakes1m.github.io/2024) starts. 37 | 38 | ## Abstract 39 | The detection and localization of highly realistic deepfake audio-visual content are challenging even for the most 40 | advanced state-of-the-art methods. While most of the research efforts in this domain are focused on detecting 41 | high-quality deepfake images and videos, only a few works address the problem of the localization of small segments of 42 | audio-visual manipulations embedded in real videos. In this research, we emulate the process of such content generation 43 | and propose the AV-Deepfake1M dataset. The dataset contains content-driven (i) video manipulations, 44 | (ii) audio manipulations, and (iii) audio-visual manipulations for more than 2K subjects resulting in a total of more 45 | than 1M videos. The paper provides a thorough description of the proposed data generation pipeline accompanied by a 46 | rigorous analysis of the quality of the generated data. The comprehensive benchmark of the proposed dataset utilizing 47 | state-of-the-art deepfake detection and localization methods indicates a significant drop in performance compared to 48 | previous datasets. The proposed dataset will play a vital role in building the next-generation deepfake localization 49 | methods. 50 | 51 | https://github.com/user-attachments/assets/d91aee8a-0fb5-4dff-ba20-86420332fed5 52 | 53 | 54 | ## Dataset 55 | 56 | ### Download 57 | 58 | We're hosting [1M-Deepfakes Detection Challenge](https://deepfakes1m.github.io/2024) at ACM MM 2024. 59 | 60 | ### Baseline Benchmark 61 | 62 | | Method | AP@0.5 | AP@0.75 | AP@0.9 | AP@0.95 | AR@50 | AR@20 | AR@10 | AR@5 | 63 | |----------------------------|--------|---------|--------|---------|-------|-------|-------|-------| 64 | | PyAnnote | 00.03 | 00.00 | 00.00 | 00.00 | 00.67 | 00.67 | 00.67 | 00.67 | 65 | | Meso4 | 09.86 | 06.05 | 02.22 | 00.59 | 38.92 | 38.81 | 36.47 | 26.91 | 66 | | MesoInception4 | 08.50 | 05.16 | 01.89 | 00.50 | 39.27 | 39.00 | 35.78 | 24.59 | 67 | | EfficientViT | 14.71 | 02.42 | 00.13 | 00.01 | 27.04 | 26.43 | 23.90 | 20.31 | 68 | | TriDet + VideoMAEv2 | 21.67 | 05.83 | 00.54 | 00.06 | 20.27 | 20.12 | 19.50 | 18.18 | 69 | | TriDet + InternVideo | 29.66 | 09.02 | 00.79 | 00.09 | 24.08 | 23.96 | 23.50 | 22.55 | 70 | | ActionFormer + VideoMAEv2 | 20.24 | 05.73 | 00.57 | 00.07 | 19.97 | 19.81 | 19.11 | 17.80 | 71 | | ActionFormer + InternVideo | 36.08 | 12.01 | 01.23 | 00.16 | 27.11 | 27.00 | 26.60 | 25.80 | 72 | | BA-TFD | 37.37 | 06.34 | 00.19 | 00.02 | 45.55 | 35.95 | 30.66 | 26.82 | 73 | | BA-TFD+ | 44.42 | 13.64 | 00.48 | 00.03 | 48.86 | 40.37 | 34.67 | 29.88 | 74 | | UMMAFormer | 51.64 | 28.07 | 07.65 | 01.58 | 44.07 | 43.45 | 42.09 | 40.27 | 75 | 76 | 77 | ### Metadata Structure 78 | 79 | The metadata is a json file for each subset (train, val), which is a list of dictionaries. 
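For illustration, a single entry might look like the following sketch (all values below are hypothetical placeholders, not taken from the dataset); the fields are described right after.

```python
# One hypothetical metadata entry (placeholder values for illustration only).
entry = {
    "file": "train/example_fake.mp4",        # path to the video file
    "original": "train/example_real.mp4",    # source video if this one is fake
    "split": "train",                        # subset name
    "modify_type": "both_modified",          # "real", "visual_modified", "audio_modified" or "both_modified"
    "audio_model": "tts_model_name",         # audio generation model (placeholder)
    "fake_segments": [[1.2, 1.8]],           # timestamps of the fake segments
    "audio_fake_segments": [[1.2, 1.8]],     # fake segments in the audio modality
    "visual_fake_segments": [[1.2, 1.8]],    # fake segments in the visual modality
    "video_frames": 250,                     # number of frames in the video
    "audio_frames": 160000,                  # number of frames in the audio
}
```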
The fields in the dictionary are as follows. 80 | - file: the path to the video file. 81 | - original: if the current video is fake, the path to the original video; otherwise, the original path in VoxCeleb2. 82 | - split: the name of the current subset. 83 | - modify_type: the type of modifications in different modalities, which can be ["real", "visual_modified", "audio_modified", "both_modified"]. We evaluate the deepfake detection performance based on this field. 84 | - audio_model: the audio generation model used for generating this video. 85 | - fake_segments: the timestamps of the fake segments. We evaluate the temporal localization performance based on this field. 86 | - audio_fake_segments: the timestamps of the fake segments in audio modality. 87 | - visual_fake_segments: the timestamps of the fake segments in visual modality. 88 | - video_frames: the number of frames in the video. 89 | - audio_frames: the number of frames in the audio. 90 | 91 | ## SDK 92 | 93 | We provide a Python library `avdeepfake1m` to load the dataset and evaluation. 94 | 95 | ### Installation 96 | 97 | ```bash 98 | pip install avdeepfake1m 99 | ``` 100 | 101 | ### Usage 102 | 103 | Prepare the dataset as follows. 104 | 105 | ``` 106 | |- train_metadata.json 107 | |- train_metadata 108 | | |- ... 109 | |- train 110 | | |- ... 111 | |- val_metadata.json 112 | |- val_metadata 113 | | |- ... 114 | |- val 115 | | |- ... 116 | |- test_files.txt 117 | |- test 118 | ``` 119 | 120 | Load the dataset. 121 | 122 | ```python 123 | from avdeepfake1m.loader import AVDeepfake1mDataModule 124 | 125 | # access to Lightning DataModule 126 | dm = AVDeepfake1mDataModule("/path/to/dataset") 127 | ``` 128 | 129 | Evaluate the predictions. Firstly prepare the predictions as described in the [details](https://deepfakes1m.github.io/2024/details). Then run the following code. 130 | 131 | ```python 132 | from avdeepfake1m.evaluation import ap_ar_1d, auc 133 | print(ap_ar_1d("", "", "file", "fake_segments", 1, [0.5, 0.75, 0.9, 0.95], [50, 30, 20, 10, 5], [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])) 134 | print(auc("", "", "file", "fake_segments")) 135 | ``` 136 | 137 | ## License 138 | 139 | The dataset is under the [EULA](eula.pdf). You need to agree and sign the EULA to access the dataset. 140 | 141 | The baseline Xception code [/examples/xception](/examples/xception) is under MIT Licence. The BA-TFD/BA-TFD+ code [/examples/batfd](/examples/batfd) is from [ControlNet/LAV-DF](https://github.com/ControlNet/LAV-DF) under CC BY-NC 4.0 Licence. 142 | 143 | The other parts of this project is under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 144 | 145 | ## References 146 | 147 | If you find this work useful in your research, please cite it. 
148 | 149 | ```bibtex 150 | @inproceedings{cai2024av, 151 | title={AV-Deepfake1M: A large-scale LLM-driven audio-visual deepfake dataset}, 152 | author={Cai, Zhixi and Ghosh, Shreya and Adatia, Aman Pankaj and Hayat, Munawar and Dhall, Abhinav and Gedeon, Tom and Stefanov, Kalin}, 153 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia}, 154 | pages={7414--7423}, 155 | year={2024}, 156 | doi={10.1145/3664647.3680795} 157 | } 158 | ``` 159 | 160 | The challenge summary paper: 161 | ```bibtex 162 | @inproceedings{cai20241m, 163 | title={1M-Deepfakes Detection Challenge}, 164 | author={Cai, Zhixi and Dhall, Abhinav and Ghosh, Shreya and Hayat, Munawar and Kollias, Dimitrios and Stefanov, Kalin and Tariq, Usman}, 165 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia}, 166 | pages={11355--11359}, 167 | year={2024}, 168 | doi={10.1145/3664647.3689145} 169 | } 170 | ``` 171 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/assets/teaser.png -------------------------------------------------------------------------------- /eula.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/eula.pdf -------------------------------------------------------------------------------- /examples/batfd/README.md: -------------------------------------------------------------------------------- 1 | # BA-TFD 2 | 3 | This example trains the BA-TFD or BA-TFD+ model on the AVDeepfake1M/AVDeepfake1M++ dataset for temporal deepfake localization, i.e., predicting the fake segments in each video. 4 | ## Requirements 5 | 6 | Ensure you have the necessary environment set up. You can create a Conda environment using the following commands: 7 | 8 | ```bash 9 | # prepare the environment 10 | conda create -n batfd python=3.10 -y 11 | conda activate batfd 12 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia -y 13 | pip install avdeepfake1m toml tensorboard pytorch-lightning pandas "av<14" 14 | ``` 15 | 16 | ## Training 17 | 18 | Train the BATFD or BATFD+ model using a TOML configuration file (e.g., `batfd.toml` or `batfd_plus.toml`). 19 | 20 | ```bash 21 | python train.py --config ./batfd.toml --data_root /path/to/AV-Deepfake1M-PlusPlus 22 | ``` 23 | 24 | ### Output 25 | 26 | * **Checkpoints:** Model checkpoints are saved under `./ckpt/`. The last checkpoint is saved as `last.ckpt`. 27 | * **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`). 28 | 29 | ## Inference 30 | 31 | After training, generate predictions on a dataset subset (e.g., `val`, `test`) using `infer.py`. This script saves the predictions to a JSON file, which is required for evaluation.
32 | 33 | ```bash 34 | python infer.py --config ./batfd.toml --checkpoint /path/to/checkpoint --data_root /path/to/AV-Deepfake1M-PlusPlus --subset val 35 | ``` 36 | 37 | ## Evaluation 38 | 39 | ```bash 40 | python evaluate.py /path/to/prediction_file /path/to/metadata_file 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /examples/batfd/batfd.toml: -------------------------------------------------------------------------------- 1 | name = "batfd" 2 | num_frames = 100 # T 3 | max_duration = 30 # D 4 | model_type = "batfd" 5 | dataset = "avdeepfake1m++" 6 | 7 | [model.video_encoder] 8 | type = "c3d" 9 | hidden_dims = [64, 96, 128, 128] 10 | cla_feature_in = 256 # C_f 11 | 12 | [model.audio_encoder] 13 | type = "cnn" 14 | hidden_dims = [32, 64, 64] 15 | cla_feature_in = 256 # C_f 16 | 17 | [model.frame_classifier] 18 | type = "lr" 19 | 20 | [model.boundary_module] 21 | hidden_dims = [512, 128] 22 | samples = 10 # N 23 | 24 | [optimizer] 25 | learning_rate = 0.00001 26 | frame_loss_weight = 2.0 27 | modal_bm_loss_weight = 1.0 28 | contrastive_loss_weight = 0.1 29 | contrastive_loss_margin = 0.99 30 | weight_decay = 0.0001 31 | 32 | [soft_nms] 33 | alpha = 0.7234 34 | t1 = 0.1968 35 | t2 = 0.4123 -------------------------------------------------------------------------------- /examples/batfd/batfd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/AV-Deepfake1M/d6cbf3221134f14257d1a85936dfd805b6e00aa2/examples/batfd/batfd/__init__.py -------------------------------------------------------------------------------- /examples/batfd/batfd/inference.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Any, List, Optional 3 | import torch 4 | from torch import Tensor 5 | import pandas as pd 6 | from pathlib import Path 7 | from lightning.pytorch import LightningModule, Trainer, Callback 8 | from torch.utils.data import DataLoader 9 | 10 | from avdeepfake1m.loader import Metadata 11 | 12 | 13 | def nullable_index(obj, index): 14 | if obj is None: 15 | return None 16 | return obj[index] 17 | 18 | 19 | class SaveToCsvCallback(Callback): 20 | 21 | def __init__(self, max_duration: int, metadata: List[Metadata], model_name: str, model_type: str, temp_dir: str): 22 | super().__init__() 23 | self.max_duration = max_duration 24 | self.metadata = metadata 25 | self.model_name = model_name 26 | self.model_type = model_type 27 | self.temp_dir = temp_dir 28 | 29 | def on_predict_batch_end( 30 | self, 31 | trainer: Trainer, 32 | pl_module: LightningModule, 33 | outputs: Any, 34 | batch: Any, 35 | batch_idx: int, 36 | ) -> None: 37 | if self.model_type == "batfd": 38 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla = outputs 39 | batch_size = fusion_bm_map.shape[0] 40 | 41 | for i in range(batch_size): 42 | temporal_size = torch.tensor(100) # the first value of `Batfd.get_meta_attr` 43 | video_name = self.metadata[batch_idx * batch_size + i].file 44 | n_frames = self.metadata[batch_idx * batch_size + i].video_frames 45 | # if n_frames is not available, it should be in test set, and we can get it from the batch 46 | if n_frames == -1: 47 | n_frames = batch[-1][i].cpu().numpy().item() 48 | 49 | assert isinstance(video_name, str) 50 | self.gen_df_for_batfd(fusion_bm_map[i], temporal_size, n_frames, os.path.join( 51 | self.temp_dir, self.model_name, video_name.replace("/", "_").replace(".mp4", 
".csv") 52 | )) 53 | 54 | elif self.model_type == "batfd_plus": 55 | fusion_bm_map, fusion_start, fusion_end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end, v_frame_cla, a_frame_cla = outputs 56 | batch_size = fusion_bm_map.shape[0] 57 | 58 | for i in range(batch_size): 59 | temporal_size = torch.tensor(100) # the first value of `BatfdPlus.get_meta_attr` 60 | video_name = self.metadata[batch_idx * batch_size + i].file 61 | n_frames = self.metadata[batch_idx * batch_size + i].video_frames 62 | # if n_frames is not available, it should be in test set, and we can get it from the batch 63 | if n_frames == -1: 64 | n_frames = batch[-1][i].cpu().numpy().item() 65 | assert isinstance(video_name, str) 66 | 67 | self.gen_df_for_batfd_plus(fusion_bm_map[i], nullable_index(fusion_start, i), 68 | nullable_index(fusion_end, i), temporal_size, 69 | n_frames, os.path.join(self.temp_dir, self.model_name, 70 | video_name.replace("/", "_").replace(".mp4", ".csv") 71 | )) 72 | 73 | else: 74 | raise ValueError("Invalid model type") 75 | 76 | def gen_df_for_batfd(self, bm_map: Tensor, temporal_size: Tensor, n_frames: int, output_file: str): 77 | bm_map = bm_map.cpu().numpy() 78 | temporal_size = temporal_size.cpu().numpy().item() 79 | # for each boundary proposal in boundary map 80 | df = pd.DataFrame(bm_map) 81 | df = df.stack().reset_index() 82 | df.columns = ["duration", "begin", "score"] 83 | df["end"] = df.duration + df.begin 84 | df = df[(df.duration > 0) & (df.end <= temporal_size)] 85 | df = df.sort_values(["begin", "end"]) 86 | df = df.reset_index()[["begin", "end", "score"]] 87 | df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int) 88 | df["end"] = (df["end"] / temporal_size * n_frames).astype(int) 89 | df = df.sort_values(["score"], ascending=False).iloc[:100] 90 | df.to_csv(output_file, index=False) 91 | 92 | def gen_df_for_batfd_plus(self, bm_map: Tensor, start: Optional[Tensor], end: Optional[Tensor], 93 | temporal_size: Tensor, n_frames: int, output_file: str 94 | ): 95 | bm_map = bm_map.cpu().numpy() 96 | temporal_size = temporal_size.cpu().numpy().item() 97 | if start is not None and end is not None: 98 | start = start.cpu().numpy() 99 | end = end.cpu().numpy() 100 | 101 | # for each boundary proposal in boundary map 102 | df = pd.DataFrame(bm_map) 103 | df = df.stack().reset_index() 104 | df.columns = ["duration", "begin", "score"] 105 | df["end"] = df.duration + df.begin 106 | df = df[(df.duration > 0) & (df.end <= temporal_size)] 107 | df = df.sort_values(["begin", "end"]) 108 | df = df.reset_index()[["begin", "end", "score"]] 109 | if start is not None and end is not None: 110 | df["score"] = df["score"] * start[df.begin] * end[df.end] 111 | 112 | df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int) 113 | df["end"] = (df["end"] / temporal_size * n_frames).astype(int) 114 | df = df.sort_values(["score"], ascending=False).iloc[:100] 115 | df.to_csv(output_file, index=False) 116 | 117 | 118 | def inference_model(model_name: str, model: LightningModule, dataloader: DataLoader, 119 | metadata: List[Metadata], 120 | max_duration: int, model_type: str, 121 | gpus: int = 1, 122 | temp_dir: str = "output/" 123 | ) -> List[Metadata]: 124 | Path(os.path.join(temp_dir, model_name)).mkdir(parents=True, exist_ok=True) 125 | 126 | model.eval() 127 | 128 | trainer = Trainer(logger=False, 129 | enable_checkpointing=False, devices=1 if gpus > 1 else "auto", 130 | accelerator="auto" if gpus > 0 else "cpu", 131 | callbacks=[SaveToCsvCallback(max_duration, metadata, model_name, 
model_type, temp_dir)] 132 | ) 133 | 134 | trainer.predict(model, dataloader) 135 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .batfd import Batfd 2 | from .batfd_plus import BatfdPlus 3 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/audio_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from einops import rearrange 4 | from einops.layers.torch import Rearrange 5 | from torch import Tensor 6 | from torch.nn import Module, Sequential, LeakyReLU, MaxPool2d, Linear 7 | from torchvision.models.vision_transformer import Encoder as ViTEncoder 8 | 9 | from ..utils import Conv2d 10 | 11 | 12 | class CNNAudioEncoder(Module): 13 | """ 14 | Audio encoder (E_a): Process log mel spectrogram to extract features. 15 | Input: 16 | A': (B, F_m, T_a) 17 | Output: 18 | E_a: (B, C_f, T) 19 | """ 20 | 21 | def __init__(self, n_features=(32, 64, 64)): 22 | super().__init__() 23 | 24 | n_dim0, n_dim1, n_dim2 = n_features 25 | 26 | # (B, 64, 2048) -> (B, 1, 64, 2048) -> (B, 32, 32, 1024) 27 | self.block0 = Sequential( 28 | Rearrange("b c t -> b 1 c t"), 29 | Conv2d(1, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 30 | MaxPool2d(2) 31 | ) 32 | 33 | # (B, 32, 32, 1024) -> (B, 64, 16, 512) 34 | self.block1 = Sequential( 35 | Conv2d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 36 | Conv2d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 37 | MaxPool2d(2) 38 | ) 39 | 40 | # (B, 64, 16, 512) -> (B, 64, 4, 512) -> (B, 256, 512) 41 | self.block2 = Sequential( 42 | Conv2d(n_dim1, n_dim2, kernel_size=(2, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU), 43 | MaxPool2d((2, 1)), 44 | Conv2d(n_dim2, n_dim2, kernel_size=(3, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU), 45 | MaxPool2d((2, 1)), 46 | Rearrange("b f c t -> b (f c) t") 47 | ) 48 | 49 | def forward(self, audio: Tensor) -> Tensor: 50 | x = self.block0(audio) 51 | x = self.block1(x) 52 | x = self.block2(x) 53 | return x 54 | 55 | 56 | class SelfAttentionAudioEncoder(Module): 57 | 58 | def __init__(self, block_type: Literal["vit_t", "vit_s", "vit_b"], a_cla_feature_in: int = 256, temporal_size: int = 512): 59 | super().__init__() 60 | # The ViT configurations are from: 61 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 62 | if block_type == "vit_t": 63 | self.n_features = 192 64 | self.block = ViTEncoder( 65 | seq_length=temporal_size, 66 | num_layers=12, 67 | num_heads=3, 68 | hidden_dim=self.n_features, 69 | mlp_dim=self.n_features * 4, 70 | dropout=0., 71 | attention_dropout=0. 72 | ) 73 | elif block_type == "vit_s": 74 | self.n_features = 384 75 | self.block = ViTEncoder( 76 | seq_length=temporal_size, 77 | num_layers=12, 78 | num_heads=6, 79 | hidden_dim=self.n_features, 80 | mlp_dim=self.n_features * 4, 81 | dropout=0., 82 | attention_dropout=0. 83 | ) 84 | elif block_type == "vit_b": 85 | self.n_features = 768 86 | self.block = ViTEncoder( 87 | seq_length=temporal_size, 88 | num_layers=12, 89 | num_heads=12, 90 | hidden_dim=self.n_features, 91 | mlp_dim=self.n_features * 4, 92 | dropout=0., 93 | attention_dropout=0. 
94 | ) 95 | else: 96 | raise ValueError(f"Unknown block type: {block_type}") 97 | 98 | self.input_proj = Conv2d(1, self.n_features, kernel_size=(64, 4), stride=(64, 4)) 99 | self.output_proj = Linear(self.n_features, a_cla_feature_in) 100 | 101 | def forward(self, audio: Tensor) -> Tensor: 102 | x = audio.unsqueeze(1) # (B, 64, 2048) -> (B, 1, 64, 2048) 103 | x = self.input_proj(x) # (B, 1, 64, 2048) -> (B, feat, 1, 512) 104 | x = rearrange(x, "b f 1 t -> b t f") # (B, feat, 1, 512) -> (B, 512, feat) 105 | x = self.block(x) 106 | x = self.output_proj(x) # (B, 512, feat) -> (B, 512, 256) 107 | x = x.permute(0, 2, 1) # (B, 512, 256) -> (B, 256, 512) 108 | return x 109 | 110 | 111 | class AudioFeatureProjection(Module): 112 | 113 | def __init__(self, input_feature_dim: int, a_cla_feature_in: int = 256): 114 | super().__init__() 115 | self.proj = Linear(input_feature_dim, a_cla_feature_in) 116 | 117 | def forward(self, x: Tensor) -> Tensor: 118 | x = self.proj(x) 119 | return x.permute(0, 2, 1) 120 | 121 | 122 | def get_audio_encoder(a_cla_feature_in, temporal_size, a_encoder, ae_features): 123 | if a_encoder == "cnn": 124 | audio_encoder = CNNAudioEncoder(n_features=ae_features) 125 | elif a_encoder == "vit_t": 126 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_t", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 127 | elif a_encoder == "vit_s": 128 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_s", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 129 | elif a_encoder == "vit_b": 130 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_b", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 131 | elif a_encoder == "wav2vec2": 132 | audio_encoder = AudioFeatureProjection(input_feature_dim=1536, a_cla_feature_in=a_cla_feature_in) 133 | elif a_encoder == "trillsson3": 134 | audio_encoder = AudioFeatureProjection(input_feature_dim=1280, a_cla_feature_in=a_cla_feature_in) 135 | else: 136 | raise ValueError(f"Invalid audio encoder: {a_encoder}") 137 | return audio_encoder 138 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/batfd.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union, Sequence, Tuple 2 | 3 | import torch 4 | from lightning.pytorch import LightningModule 5 | from torch import Tensor 6 | from torch.nn import BCEWithLogitsLoss, MSELoss, functional as F 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | from avdeepfake1m.loader import Metadata 10 | 11 | from .loss import ContrastLoss 12 | from .audio_encoder import get_audio_encoder 13 | from .boundary_module import BoundaryModule 14 | from .frame_classifier import FrameLogisticRegression 15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion 16 | from .video_encoder import get_video_encoder 17 | 18 | 19 | class Batfd(LightningModule): 20 | 21 | def __init__(self, 22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr", 23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256, 24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40, 25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99, 26 | weight_decay=0.0001, learning_rate=0.0002, distributed=False 27 | ): 28 | super().__init__() 29 | 
self.save_hyperparameters() 30 | self.cla_feature_in = v_cla_feature_in 31 | self.temporal_dim = temporal_dim 32 | 33 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features) 34 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, ae_features) 35 | 36 | if frame_classifier == "lr": 37 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in) 38 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in) 39 | 40 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier 41 | 42 | assert v_cla_feature_in == a_cla_feature_in 43 | 44 | v_bm_in = v_cla_feature_in + 1 45 | a_bm_in = a_cla_feature_in + 1 46 | 47 | self.video_boundary_module = BoundaryModule(v_bm_in, boundary_features, boundary_samples, temporal_dim, 48 | max_duration 49 | ) 50 | self.audio_boundary_module = BoundaryModule(a_bm_in, boundary_features, boundary_samples, temporal_dim, 51 | max_duration 52 | ) 53 | 54 | self.fusion = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 55 | 56 | self.frame_loss = BCEWithLogitsLoss() 57 | self.contrast_loss = ContrastLoss(margin=contrast_loss_margin) 58 | self.bm_loss = MSELoss() 59 | self.weight_frame_loss = weight_frame_loss 60 | self.weight_modal_bm_loss = weight_modal_bm_loss 61 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim) 62 | self.weight_decay = weight_decay 63 | self.learning_rate = learning_rate 64 | self.distributed = distributed 65 | 66 | def forward(self, video: Tensor, audio: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: 67 | # encoders 68 | v_features = self.video_encoder(video) 69 | a_features = self.audio_encoder(audio) 70 | 71 | # frame classifiers 72 | v_frame_cla = self.video_frame_classifier(v_features) 73 | a_frame_cla = self.audio_frame_classifier(a_features) 74 | 75 | # concat classification result to features 76 | v_bm_in = torch.column_stack([v_features, v_frame_cla]) 77 | a_bm_in = torch.column_stack([a_features, a_frame_cla]) 78 | 79 | # modal boundary module 80 | v_bm_map = self.video_boundary_module(v_bm_in) 81 | a_bm_map = self.audio_boundary_module(a_bm_in) 82 | 83 | # boundary map modal attention fusion 84 | fusion_bm_map = self.fusion(v_bm_in, a_bm_in, v_bm_map, a_bm_map) 85 | 86 | return fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features 87 | 88 | def loss_fn(self, fusion_bm_map: Tensor, v_bm_map: Tensor, a_bm_map: Tensor, 89 | v_frame_cla: Tensor, a_frame_cla: Tensor, label: Tensor, n_frames: Tensor, 90 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 91 | ) -> Dict[str, Tensor]: 92 | fusion_bm_loss = self.bm_loss(fusion_bm_map, label) 93 | 94 | v_bm_loss = self.bm_loss(v_bm_map, v_bm_label) 95 | a_bm_loss = self.bm_loss(a_bm_map, a_bm_label) 96 | 97 | v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label) 98 | a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label) 99 | 100 | contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label) 101 | / (self.cla_feature_in * self.temporal_dim), max=1.) 
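        # Total objective (computed below): fused boundary-map loss, plus the weighted
        # average of the per-modality boundary-map losses, plus the weighted average of
        # the per-modality frame-classification losses, plus the weighted (clipped)
        # audio-visual contrastive loss.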
102 | 103 | loss = fusion_bm_loss + \ 104 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \ 105 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \ 106 | self.weight_contrastive_loss * contrast_loss 107 | 108 | return { 109 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss, 110 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss 111 | } 112 | 113 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 114 | ) -> Tensor: 115 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch 116 | 117 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio) 118 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames, 119 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 120 | ) 121 | 122 | self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 123 | prog_bar=False, sync_dist=self.distributed) 124 | return loss_dict["loss"] 125 | 126 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 127 | ) -> Tensor: 128 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch 129 | 130 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio) 131 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames, 132 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 133 | ) 134 | 135 | self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 136 | prog_bar=False, sync_dist=self.distributed) 137 | return loss_dict["loss"] 138 | 139 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None 140 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 141 | video, audio, *_ = batch 142 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio) 143 | return fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla 144 | 145 | def configure_optimizers(self): 146 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay) 147 | return { 148 | "optimizer": optimizer, 149 | "lr_scheduler": { 150 | "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=3, verbose=True, min_lr=1e-8), 151 | "monitor": "val_loss" 152 | } 153 | } 154 | 155 | 156 | @staticmethod 157 | def get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tuple[Tensor, Optional[Tensor], Optional[Tensor]]): 158 | label, visual_label, audio_label = label 159 | label_real = torch.zeros(label.size(), dtype=label.dtype, device=label.device) 160 | 161 | if visual_label is not None: 162 | v_bm_label = visual_label 163 | elif meta.modify_video: 164 | v_bm_label = label 165 | else: 166 | v_bm_label = label_real 167 | 168 | if audio_label is not None: 169 | a_bm_label = audio_label 170 | elif meta.modify_audio: 171 | a_bm_label = label 172 | else: 173 | a_bm_label = label_real 174 | 175 | frame_label_real = torch.zeros(meta.video_frames) 176 | frame_label_fake = torch.zeros(meta.video_frames) 177 | for begin, end in meta.fake_periods: 178 | 
begin = int(begin * 25) 179 | end = int(end * 25) 180 | frame_label_fake[begin: end] = 1 181 | 182 | if visual_label is not None: 183 | v_frame_label = torch.zeros(meta.video_frames) 184 | for begin, end in meta.visual_fake_periods: 185 | begin = int(begin * 25) 186 | end = int(end * 25) 187 | v_frame_label[begin: end] = 1 188 | elif meta.modify_video: 189 | v_frame_label = frame_label_fake 190 | else: 191 | v_frame_label = frame_label_real 192 | 193 | v_frame_label = F.interpolate(v_frame_label[None, None], (100,), mode="linear")[0, 0] 194 | 195 | if audio_label is not None: 196 | a_frame_label = torch.zeros(meta.video_frames) 197 | for begin, end in meta.audio_fake_periods: 198 | begin = int(begin * 25) 199 | end = int(end * 25) 200 | a_frame_label[begin: end] = 1 201 | elif meta.modify_audio: 202 | a_frame_label = frame_label_fake 203 | else: 204 | a_frame_label = frame_label_real 205 | 206 | a_frame_label = F.interpolate(a_frame_label[None, None], (100,), mode="linear")[0, 0] 207 | 208 | contrast_label = 0 if meta.modify_audio or meta.modify_video else 1 209 | 210 | return [100, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label] 211 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/batfd_plus.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union, Sequence, Tuple 2 | 3 | import torch 4 | from lightning.pytorch import LightningModule 5 | from torch import Tensor 6 | from torch.nn import BCEWithLogitsLoss, functional as F 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import ExponentialLR 9 | from avdeepfake1m.loader import Metadata 10 | 11 | from .loss import ContrastLoss, BsnppLoss 12 | from .audio_encoder import get_audio_encoder 13 | from .boundary_module_plus import BoundaryModulePlus, NestedUNet 14 | from .frame_classifier import FrameLogisticRegression 15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion, ModalFeatureAttnCfgFusion 16 | from .video_encoder import get_video_encoder 17 | 18 | 19 | class BatfdPlus(LightningModule): 20 | 21 | def __init__(self, 22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr", 23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256, 24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40, 25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99, 26 | cbg_feature_weight=0.01, prb_weight_forward=1., 27 | weight_decay=0.0001, learning_rate=0.0002, distributed=False 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | 32 | self.cla_feature_in = v_cla_feature_in 33 | self.temporal_dim = temporal_dim 34 | 35 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features) 36 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, ae_features) 37 | 38 | if frame_classifier == "lr": 39 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in) 40 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in) 41 | 42 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier 43 | 44 | assert v_cla_feature_in == a_cla_feature_in 45 | 46 | v_bm_in = v_cla_feature_in + 1 47 | a_bm_in = a_cla_feature_in + 1 48 | 49 | # Proposal Relation 
Block in BSN++ mechanism 50 | self.video_boundary_module = BoundaryModulePlus(v_bm_in, boundary_features, boundary_samples, temporal_dim, 51 | max_duration 52 | ) 53 | self.audio_boundary_module = BoundaryModulePlus(a_bm_in, boundary_features, boundary_samples, temporal_dim, 54 | max_duration 55 | ) 56 | 57 | if cbg_feature_weight > 0: 58 | # Complementary Boundary Generator in BSN++ mechanism 59 | self.video_comp_boundary_generator = NestedUNet(in_ch=v_bm_in, out_ch=2) 60 | self.audio_comp_boundary_generator = NestedUNet(in_ch=a_bm_in, out_ch=2) 61 | self.cbg_fusion_start = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in) 62 | self.cbg_fusion_end = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in) 63 | else: 64 | self.video_comp_boundary_generator = None 65 | self.audio_comp_boundary_generator = None 66 | self.cbg_fusion_start = None 67 | self.cbg_fusion_end = None 68 | 69 | self.prb_fusion_p = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 70 | self.prb_fusion_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 71 | self.prb_fusion_p_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 72 | 73 | self.frame_loss = BCEWithLogitsLoss() 74 | self.contrast_loss = ContrastLoss(margin=contrast_loss_margin) 75 | self.bm_loss = BsnppLoss(cbg_feature_weight, prb_weight_forward) 76 | self.weight_frame_loss = weight_frame_loss 77 | self.weight_modal_bm_loss = weight_modal_bm_loss 78 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim) 79 | self.weight_decay = weight_decay 80 | self.learning_rate = learning_rate 81 | self.distributed = distributed 82 | 83 | def forward(self, video: Tensor, audio: Tensor) -> Sequence[Tensor]: 84 | a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla = self.forward_features(audio, video) 85 | 86 | # modal boundary module 87 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c = self.video_boundary_module(v_bm_in) 88 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c = self.audio_boundary_module(a_bm_in) 89 | 90 | # complementary boundary generator 91 | if self.cbg_fusion_start is not None: 92 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in) 93 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in) 94 | else: 95 | v_cbg_feature, v_cbg_start, v_cbg_end = None, None, None 96 | a_cbg_feature, a_cbg_start, a_cbg_end = None, None, None 97 | 98 | # boundary map modal attention fusion 99 | fusion_bm_map_p = self.prb_fusion_p(v_bm_in, a_bm_in, v_bm_map_p, a_bm_map_p) 100 | fusion_bm_map_c = self.prb_fusion_c(v_bm_in, a_bm_in, v_bm_map_c, a_bm_map_c) 101 | fusion_bm_map_p_c = self.prb_fusion_p_c(v_bm_in, a_bm_in, v_bm_map_p_c, a_bm_map_p_c) 102 | 103 | # complementary boundary generator modal attention fusion 104 | if self.cbg_fusion_start is not None: 105 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start) 106 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end) 107 | else: 108 | fusion_cbg_start = None 109 | fusion_cbg_end = None 110 | 111 | return ( 112 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 113 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 114 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 115 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 116 | ) 117 | 118 | def forward_back(self, video: Tensor, audio: Tensor) -> Sequence[Optional[Tensor]]: 119 | if self.cbg_fusion_start is not None: 
120 | a_bm_in, _, _, v_bm_in, _, _ = self.forward_features(audio, video) 121 | 122 | # complementary boundary generator 123 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in) 124 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in) 125 | 126 | # complementary boundary generator modal attention fusion 127 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start) 128 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end) 129 | 130 | return ( 131 | fusion_cbg_start, fusion_cbg_end, v_cbg_start, v_cbg_end, a_cbg_start, a_cbg_end, 132 | v_cbg_feature, a_cbg_feature 133 | ) 134 | else: 135 | return None, None, None, None, None, None, None, None 136 | 137 | def forward_features(self, audio, video): 138 | # encoders 139 | v_features = self.video_encoder(video) 140 | a_features = self.audio_encoder(audio) 141 | # frame classifiers 142 | v_frame_cla = self.video_frame_classifier(v_features) 143 | a_frame_cla = self.audio_frame_classifier(a_features) 144 | # concat classification result to features 145 | v_bm_in = torch.column_stack([v_features, v_frame_cla]) 146 | a_bm_in = torch.column_stack([a_features, a_frame_cla]) 147 | return a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla 148 | 149 | def forward_video_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 150 | cbg_prob, cbg_feature = self.video_comp_boundary_generator(feature) 151 | start = cbg_prob[:, 0, :].squeeze(1) 152 | end = cbg_prob[:, 1, :].squeeze(1) 153 | return cbg_feature, end, start 154 | 155 | def forward_audio_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 156 | cbg_prob, cbg_feature = self.audio_comp_boundary_generator(feature) 157 | start = cbg_prob[:, 0, :].squeeze(1) 158 | end = cbg_prob[:, 1, :].squeeze(1) 159 | return cbg_feature, end, start 160 | 161 | def loss_fn(self, 162 | fusion_bm_map_p: Tensor, fusion_bm_map_c: Tensor, fusion_bm_map_p_c: Tensor, 163 | fusion_cbg_start: Tensor, fusion_cbg_end: Tensor, 164 | fusion_cbg_start_back: Tensor, fusion_cbg_end_back: Tensor, 165 | v_bm_map_p: Tensor, v_bm_map_c: Tensor, v_bm_map_p_c: Tensor, 166 | v_cbg_start: Tensor, v_cbg_end: Tensor, v_cbg_feature: Tensor, 167 | v_cbg_start_back: Tensor, v_cbg_end_back: Tensor, v_cbg_feature_back: Tensor, 168 | a_bm_map_p: Tensor, a_bm_map_c: Tensor, a_bm_map_p_c: Tensor, 169 | a_cbg_start: Tensor, a_cbg_end: Tensor, a_cbg_feature: Tensor, 170 | a_cbg_start_back: Tensor, a_cbg_end_back: Tensor, a_cbg_feature_back: Tensor, 171 | v_frame_cla: Tensor, a_frame_cla: Tensor, n_frames: Tensor, 172 | fusion_bm_label: Tensor, fusion_start_label: Tensor, fusion_end_label: Tensor, 173 | v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label, 174 | v_frame_label, a_frame_label, contrast_label, v_features, a_features 175 | ) -> Dict[str, Tensor]: 176 | ( 177 | fusion_bm_loss, fusion_cbg_loss, fusion_prb_loss, fusion_cbg_loss_forward, fusion_cbg_loss_backward, _ 178 | ) = self.bm_loss( 179 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, 180 | fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back, 181 | fusion_bm_label, fusion_start_label, fusion_end_label 182 | ) 183 | 184 | ( 185 | v_bm_loss, v_cbg_loss, v_prb_loss, v_cbg_loss_forward, v_cbg_loss_backward, v_cbg_feature_loss 186 | ) = self.bm_loss( 187 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, 188 | v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back, 189 | v_bm_label, v_start_label, v_end_label, 190 | 
v_cbg_feature, v_cbg_feature_back 191 | ) 192 | 193 | ( 194 | a_bm_loss, a_cbg_loss, a_prb_loss, a_cbg_loss_forward, a_cbg_loss_backward, a_cbg_feature_loss 195 | ) = self.bm_loss( 196 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, 197 | a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back, 198 | a_bm_label, a_start_label, a_end_label, 199 | a_cbg_feature, a_cbg_feature_back 200 | ) 201 | 202 | v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label) 203 | a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label) 204 | 205 | contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label) 206 | / (self.cla_feature_in * self.temporal_dim), max=1.) 207 | 208 | loss = fusion_bm_loss + \ 209 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \ 210 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \ 211 | self.weight_contrastive_loss * contrast_loss 212 | 213 | loss_dict = { 214 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss, 215 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss, 216 | "fusion_cbg_loss": fusion_cbg_loss, "v_cbg_loss": v_cbg_loss, "a_cbg_loss": a_cbg_loss, 217 | "fusion_prb_loss": fusion_prb_loss, "v_prb_loss": v_prb_loss, "a_prb_loss": a_prb_loss, 218 | "fusion_cbg_loss_forward": fusion_cbg_loss_forward, "v_cbg_loss_forward": v_cbg_loss_forward, 219 | "a_cbg_loss_forward": a_cbg_loss_forward, "fusion_cbg_loss_backward": fusion_cbg_loss_backward, 220 | "v_cbg_loss_backward": v_cbg_loss_backward, "a_cbg_loss_backward": a_cbg_loss_backward, 221 | "v_cbg_feature_loss": v_cbg_feature_loss, "a_cbg_feature_loss": a_cbg_feature_loss 222 | } 223 | return {k: v for k, v in loss_dict.items() if v is not None} 224 | 225 | def step(self, batch: Sequence[Tensor]) -> Dict[str, Tensor]: 226 | ( 227 | video, audio, fusion_bm_label, fusion_start_label, fusion_end_label, n_frames, 228 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, 229 | a_start_label, v_start_label, a_end_label, v_end_label 230 | ) = batch 231 | # forward 232 | ( 233 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 234 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 235 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 236 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 237 | ) = self(video, audio) 238 | # BSN++ back 239 | video_back = torch.flip(video, dims=(2,)) 240 | audio_back = torch.flip(audio, dims=(2,)) 241 | ( 242 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back, 243 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back 244 | ) = self.forward_back(video_back, audio_back) 245 | 246 | # loss 247 | loss_dict = self.loss_fn( 248 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, 249 | fusion_cbg_start, fusion_cbg_end, 250 | fusion_cbg_start_back, fusion_cbg_end_back, 251 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, 252 | v_cbg_start, v_cbg_end, v_cbg_feature, 253 | v_cbg_start_back, v_cbg_end_back, v_cbg_feature_back, 254 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, 255 | a_cbg_start, a_cbg_end, a_cbg_feature, 256 | a_cbg_start_back, a_cbg_end_back, a_cbg_feature_back, 257 | v_frame_cla, a_frame_cla, n_frames, 258 | fusion_bm_label, fusion_start_label, fusion_end_label, 259 | v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label, 260 | v_frame_label, a_frame_label, 
contrast_label, v_features, a_features 261 | ) 262 | return loss_dict 263 | 264 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 265 | ) -> Tensor: 266 | loss_dict = self.step(batch) 267 | 268 | if torch.isnan(loss_dict["loss"]): 269 | print(f"NaN in loss in {self.global_step}") 270 | exit(1) 271 | return None 272 | 273 | self.log_dict({f"metrics/train_{k}": v for k, v in loss_dict.items() if k != "loss"}, on_step=True, on_epoch=True, 274 | prog_bar=False, sync_dist=self.distributed) 275 | # only log the loss to progress bar 276 | self.log("metrics/train_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True, 277 | sync_dist=self.distributed, rank_zero_only=True) 278 | return loss_dict["loss"] 279 | 280 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 281 | ) -> Tensor: 282 | loss_dict = self.step(batch) 283 | 284 | self.log_dict({f"metrics/val_{k}": v for k, v in loss_dict.items() if k != "loss"}, on_step=True, on_epoch=True, 285 | prog_bar=False, sync_dist=self.distributed) 286 | self.log("metrics/val_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True, 287 | sync_dist=self.distributed, rank_zero_only=True) 288 | return loss_dict["loss"] 289 | 290 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None 291 | ) -> Tuple[ 292 | Tensor, Optional[Tensor], Optional[Tensor], 293 | Tensor, Optional[Tensor], Optional[Tensor], 294 | Tensor, Optional[Tensor], Optional[Tensor] 295 | ]: 296 | video, audio, *_ = batch 297 | # forward 298 | ( 299 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 300 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 301 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 302 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 303 | ) = self(video, audio) 304 | # BSN++ back 305 | video_back = torch.flip(video, dims=(2,)) 306 | audio_back = torch.flip(audio, dims=(2,)) 307 | ( 308 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back, 309 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back 310 | ) = self.forward_back(video_back, audio_back) 311 | 312 | fusion_bm_map, start, end = self.post_process_predict(fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, 313 | fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back 314 | ) 315 | 316 | v_bm_map, v_start, v_end = self.post_process_predict(v_bm_map_p, v_bm_map_c, v_bm_map_p_c, 317 | v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back 318 | ) 319 | 320 | a_bm_map, a_start, a_end = self.post_process_predict(a_bm_map_p, a_bm_map_c, a_bm_map_p_c, 321 | a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back 322 | ) 323 | 324 | return fusion_bm_map, start, end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end, v_frame_cla, a_frame_cla 325 | 326 | def post_process_predict(self, 327 | bm_map_p: Tensor, bm_map_c: Tensor, bm_map_p_c: Tensor, 328 | cbg_start: Tensor, cbg_end: Tensor, 329 | cbg_start_back: Tensor, cbg_end_back: Tensor 330 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 331 | 332 | bm_map = (bm_map_p + bm_map_c + bm_map_p_c) / 3 333 | if self.cbg_fusion_start is not None: 334 | start = torch.sqrt(cbg_start * torch.flip(cbg_end_back, dims=(1,))) 335 | end = torch.sqrt(cbg_end * torch.flip(cbg_start_back, dims=(1,))) 336 | else: 337 | start 
= None 338 | end = None 339 | 340 | return bm_map, start, end 341 | 342 | def configure_optimizers(self): 343 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay) 344 | return { 345 | "optimizer": optimizer, 346 | "lr_scheduler": { 347 | # "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=5, verbose=True, min_lr=1e-8), 348 | "scheduler": ExponentialLR(optimizer, gamma=0.992), 349 | "monitor": "val_loss" 350 | } 351 | } 352 | 353 | @classmethod 354 | def gen_audio_video_labels(cls, label_fake: Tensor, meta: Metadata): 355 | label_real = torch.zeros(label_fake.size(), dtype=label_fake.dtype, device=label_fake.device) 356 | v_label = label_fake if meta.modify_video else label_real 357 | a_label = label_fake if meta.modify_audio else label_real 358 | return a_label, v_label 359 | 360 | @staticmethod 361 | def get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tuple[Tensor, Optional[Tensor], Optional[Tensor]]): 362 | label, visual_label, audio_label = label 363 | label_real = torch.zeros(label.size(), dtype=label.dtype, device=label.device) 364 | if visual_label is not None: 365 | v_bm_label = visual_label 366 | elif meta.modify_video: 367 | v_bm_label = label 368 | else: 369 | v_bm_label = label_real 370 | 371 | if audio_label is not None: 372 | a_bm_label = audio_label 373 | elif meta.modify_audio: 374 | a_bm_label = label 375 | else: 376 | a_bm_label = label_real 377 | 378 | frame_label_real = torch.zeros(meta.video_frames) 379 | frame_label_fake = torch.zeros(meta.video_frames) 380 | for begin, end in meta.fake_periods: 381 | begin = int(begin * 25) 382 | end = int(end * 25) 383 | frame_label_fake[begin: end] = 1 384 | 385 | if visual_label is not None: 386 | v_frame_label = torch.zeros(meta.video_frames) 387 | for begin, end in meta.visual_fake_periods: 388 | begin = int(begin * 25) 389 | end = int(end * 25) 390 | v_frame_label[begin: end] = 1 391 | elif meta.modify_video: 392 | v_frame_label = frame_label_fake 393 | else: 394 | v_frame_label = frame_label_real 395 | 396 | v_frame_label = F.interpolate(v_frame_label[None, None], (100,), mode="linear")[0, 0] 397 | 398 | if audio_label is not None: 399 | a_frame_label = torch.zeros(meta.video_frames) 400 | for begin, end in meta.audio_fake_periods: 401 | begin = int(begin * 25) 402 | end = int(end * 25) 403 | a_frame_label[begin: end] = 1 404 | elif meta.modify_audio: 405 | a_frame_label = frame_label_fake 406 | else: 407 | a_frame_label = frame_label_real 408 | 409 | a_frame_label = F.interpolate(a_frame_label[None, None], (100,), mode="linear")[0, 0] 410 | 411 | contrast_label = 0 if meta.modify_audio or meta.modify_video else 1 412 | return [100, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, 0, 0, 0, 0] 413 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/boundary_module.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from einops.layers.torch import Rearrange 5 | from torch import Tensor 6 | from torch.nn import Sequential, LeakyReLU, Sigmoid, Module 7 | 8 | from ..utils import Conv3d, Conv2d 9 | 10 | 11 | class BoundaryModule(Module): 12 | """ 13 | Boundary matching module for video or audio features. 
14 | Input: 15 | F_v or F_a: (B, C_f, T) 16 | Output: 17 | M_v^ or M_a^: (B, D, T) 18 | 19 | """ 20 | 21 | def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512, 22 | max_duration: int = 40 23 | ): 24 | super().__init__() 25 | 26 | dim0, dim1 = n_features 27 | 28 | # (B, n_feature_in, temporal_dim) -> (B, n_feature_in, sample, max_duration, temporal_dim) 29 | self.bm_layer = BMLayer(temporal_dim, num_samples, max_duration) 30 | 31 | # (B, n_feature_in, sample, max_duration, temporal_dim) -> (B, dim0, max_duration, temporal_dim) 32 | self.block0 = Sequential( 33 | Conv3d(n_feature_in, dim0, kernel_size=(num_samples, 1, 1), stride=(num_samples, 1, 1), 34 | build_activation=LeakyReLU 35 | ), 36 | Rearrange("b c n d t -> b c (n d) t") 37 | ) 38 | 39 | # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim) 40 | self.block1 = Sequential( 41 | Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU), 42 | Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU), 43 | Conv2d(dim1, 1, kernel_size=1, build_activation=Sigmoid), 44 | Rearrange("b c d t -> b (c d) t") 45 | ) 46 | 47 | def forward(self, feature: Tensor) -> Tensor: 48 | feature = self.bm_layer(feature) 49 | feature = self.block0(feature) 50 | feature = self.block1(feature) 51 | return feature 52 | 53 | 54 | class BMLayer(Module): 55 | """BM Layer""" 56 | 57 | def __init__(self, temporal_dim: int, num_sample: int, max_duration: int, roi_expand_ratio: float = 0.5): 58 | super().__init__() 59 | self.temporal_dim = temporal_dim 60 | # self.feat_dim = opt['bmn_feat_dim'] 61 | self.num_sample = num_sample 62 | self.duration = max_duration 63 | self.roi_expand_ratio = roi_expand_ratio 64 | self.smp_weight = self.get_pem_smp_weight() 65 | 66 | def get_pem_smp_weight(self): 67 | T = self.temporal_dim 68 | N = self.num_sample 69 | D = self.duration 70 | w = torch.zeros([T, N, D, T]) # T * N * D * T 71 | # In each temporal location i, there are D predefined proposals, 72 | # with length ranging between 1 and D 73 | # the j-th proposal is [i, i+j+1], 0<=j T - 1: 88 | continue 89 | left, right = int(np.floor(xp)), int(np.ceil(xp)) 90 | left_weight = 1 - (xp - left) 91 | right_weight = 1 - (right - xp) 92 | w[left, k, j, i] += left_weight 93 | w[right, k, j, i] += right_weight 94 | return w.view(T, -1).float() 95 | 96 | def _apply(self, fn): 97 | self.smp_weight = fn(self.smp_weight) 98 | 99 | def forward(self, X): 100 | input_size = X.size() 101 | assert (input_size[-1] == self.temporal_dim) 102 | # assert(len(input_size) == 3 and 103 | X_view = X.view(-1, input_size[-1]) 104 | # feature [bs*C, T] 105 | # smp_w [T, N*D*T] 106 | # out [bs*C, N*D*T] --> [bs, C, N, D, T] 107 | result = torch.matmul(X_view, self.smp_weight) 108 | return result.view(-1, input_size[1], self.num_sample, self.duration, self.temporal_dim) 109 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/boundary_module_plus.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops.layers.torch import Rearrange 7 | from torch import Tensor 8 | from torch.nn import Sequential, LeakyReLU 9 | 10 | from .boundary_module import BoundaryModule 11 | from ..utils import Conv2d 12 | 13 | 14 | class ConvUnit(nn.Module): 15 | """ 16 | Unit in NestedUNet 17 | """ 18 | 19 | def __init__(self, in_ch, 
out_ch, is_output=False): 20 | super(ConvUnit, self).__init__() 21 | module_list: list[nn.Module] = [nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)] 22 | if is_output is False: 23 | module_list.append(nn.BatchNorm1d(out_ch)) 24 | module_list.append(nn.ReLU(inplace=True)) 25 | self.conv = nn.Sequential(*module_list) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | return x 30 | 31 | 32 | class NestedUNet(nn.Module): 33 | """ 34 | UNet - Basic Implementation 35 | Paper : https://arxiv.org/abs/1505.04597 36 | """ 37 | def __init__(self, in_ch=400, out_ch=2): 38 | super(NestedUNet, self).__init__() 39 | 40 | self.pool = nn.MaxPool1d(kernel_size=2, stride=2) 41 | self.up = nn.Upsample(scale_factor=2) 42 | 43 | n1 = 512 44 | filters = [n1, n1 * 2, n1 * 3] 45 | self.conv0_0 = ConvUnit(in_ch, filters[0], is_output=False) 46 | self.conv1_0 = ConvUnit(filters[0], filters[0], is_output=False) 47 | self.conv2_0 = ConvUnit(filters[0], filters[0], is_output=False) 48 | 49 | self.conv0_1 = ConvUnit(filters[1], filters[0], is_output=False) 50 | self.conv1_1 = ConvUnit(filters[1], filters[0], is_output=False) 51 | 52 | self.conv0_2 = ConvUnit(filters[2], filters[0], is_output=False) 53 | 54 | self.final = nn.Conv1d(filters[0] * 3, out_ch, kernel_size=1) 55 | # self.final = ConvUnit(filters[0] * 3, out_ch, is_output=True) 56 | self.out = nn.Sigmoid() 57 | 58 | def forward(self, x): 59 | x0_0 = self.conv0_0(x) 60 | x1_0 = self.conv1_0(self.pool(x0_0)) 61 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 62 | x2_0 = self.conv2_0(self.pool(x1_0)) 63 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 64 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 65 | out_feature = torch.cat([x0_0, x0_1, x0_2], 1) # for calculating loss 66 | final_feature = self.final(out_feature) 67 | out = self.out(final_feature) 68 | 69 | return out, out_feature 70 | 71 | 72 | class PositionAwareAttentionModule(nn.Module): 73 | def __init__(self, in_channels, inter_channels=None, sub_sample=None, dim=2): 74 | super(PositionAwareAttentionModule, self).__init__() 75 | 76 | self.sub_sample = sub_sample 77 | self.in_channels = in_channels 78 | self.inter_channels = inter_channels 79 | self.dim = dim 80 | 81 | if self.inter_channels is None: 82 | self.inter_channels = in_channels // 2 83 | if self.inter_channels == 0: 84 | self.inter_channels = 1 85 | 86 | if self.dim == 2: 87 | conv_nd = nn.Conv2d 88 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 89 | bn = nn.BatchNorm2d 90 | else: 91 | conv_nd = nn.Conv1d 92 | max_pool_layer = nn.MaxPool1d(kernel_size=(2,)) 93 | bn = nn.BatchNorm1d 94 | 95 | self.g = nn.Sequential( 96 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 97 | bn(self.inter_channels), 98 | nn.ReLU(inplace=True) 99 | ) 100 | self.theta = nn.Sequential( 101 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 102 | bn(self.inter_channels), 103 | nn.ReLU(inplace=True) 104 | ) 105 | self.phi = nn.Sequential( 106 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 107 | bn(self.inter_channels), 108 | nn.ReLU(inplace=True) 109 | ) 110 | self.W = nn.Sequential( 111 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 112 | kernel_size=1, stride=1, padding=0), 113 | bn(self.in_channels) 114 | ) 115 | if self.sub_sample: 116 | self.g = nn.Sequential(self.g, 
max_pool_layer) 117 | self.phi = nn.Sequential(self.phi, max_pool_layer) 118 | 119 | def forward(self, x): 120 | batch_size = x.size(0) 121 | # value 122 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 123 | g_x = g_x.permute(0, 2, 1) 124 | 125 | # query 126 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 127 | theta_x = theta_x.permute(0, 2, 1) 128 | 129 | # key 130 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 131 | 132 | f = torch.matmul(theta_x, phi_x) 133 | f = F.softmax(f, dim=2) 134 | 135 | y = torch.matmul(f, g_x) 136 | y = y.permute(0, 2, 1).contiguous() 137 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 138 | y = self.W(y) 139 | 140 | z = y + x 141 | return z 142 | 143 | 144 | class ChannelAwareAttentionModule(nn.Module): 145 | def __init__(self, in_channels, inter_channels=None, dim=2): 146 | super(ChannelAwareAttentionModule, self).__init__() 147 | 148 | self.in_channels = in_channels 149 | self.inter_channels = inter_channels 150 | self.dim = dim 151 | 152 | if self.inter_channels is None: 153 | self.inter_channels = in_channels // 2 154 | if self.inter_channels == 0: 155 | self.inter_channels = 1 156 | 157 | if self.dim == 2: 158 | conv_nd = nn.Conv2d 159 | bn = nn.BatchNorm2d 160 | else: 161 | conv_nd = nn.Conv1d 162 | bn = nn.BatchNorm1d 163 | 164 | self.g = nn.Sequential( 165 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 166 | bn(self.inter_channels), 167 | nn.ReLU(inplace=True) 168 | ) 169 | self.theta = nn.Sequential( 170 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 171 | bn(self.inter_channels), 172 | nn.ReLU(inplace=True) 173 | ) 174 | self.phi = nn.Sequential( 175 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 176 | bn(self.inter_channels), 177 | nn.ReLU(inplace=True) 178 | ) 179 | self.W = nn.Sequential( 180 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 181 | kernel_size=1, stride=1, padding=0), 182 | bn(self.in_channels) 183 | ) 184 | 185 | def forward(self, x): 186 | batch_size = x.size(0) 187 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 188 | 189 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 190 | 191 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 192 | phi_x = phi_x.permute(0, 2, 1) 193 | 194 | f = torch.matmul(theta_x, phi_x) 195 | f = F.softmax(f, dim=2) 196 | 197 | y = torch.matmul(f, g_x) 198 | y = y.permute(0, 2, 1).contiguous() 199 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 200 | y = self.W(y) 201 | 202 | z = y + x 203 | return z 204 | 205 | 206 | def conv_block(in_ch, out_ch, kernel_size=3, stride=1, bn_layer=False, activate=False): 207 | module_list = [nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=1)] 208 | if bn_layer: 209 | module_list.append(nn.BatchNorm2d(out_ch)) 210 | module_list.append(nn.ReLU(inplace=True)) 211 | if activate: 212 | module_list.append(nn.Sigmoid()) 213 | conv = nn.Sequential(*module_list) 214 | return conv 215 | 216 | 217 | class ProposalRelationBlock(nn.Module): 218 | def __init__(self, in_channels, inter_channels=128, out_channels=2, sub_sample=False): 219 | super(ProposalRelationBlock, self).__init__() 220 | self.p_net = PositionAwareAttentionModule(in_channels, inter_channels=inter_channels, sub_sample=sub_sample, dim=2) 221 | self.c_net = 
ChannelAwareAttentionModule(in_channels, inter_channels=inter_channels, dim=2) 222 | self.conv0_0 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 223 | self.conv0_1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 224 | 225 | self.conv1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 226 | self.conv2 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 227 | self.conv3 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 228 | self.conv4 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 229 | self.conv5 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 230 | 231 | def forward(self, x): 232 | x_p = self.conv0_0(x) 233 | x_c = self.conv0_1(x) 234 | 235 | x_p = self.p_net(x_p) 236 | x_c = self.c_net(x_c) 237 | 238 | x_p_0 = self.conv1(x_p) 239 | x_p_1 = self.conv2(x_p_0) 240 | 241 | x_c_0 = self.conv4(x_c) 242 | x_c_1 = self.conv5(x_c_0) 243 | 244 | x_p_c = self.conv3(x_p_0 + x_c_0) 245 | return x_p_1, x_c_1, x_p_c 246 | 247 | 248 | class BoundaryModulePlus(BoundaryModule): 249 | def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512, 250 | max_duration: int = 40 251 | ): 252 | super().__init__(n_feature_in, n_features, num_samples, temporal_dim, max_duration) 253 | del self.block1 254 | dim0, dim1 = n_features 255 | # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim) 256 | self.block1 = Sequential( 257 | Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU), 258 | Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU) 259 | ) 260 | # Proposal Relation Block in BSN++ mechanism 261 | self.proposal_block = ProposalRelationBlock(dim1, dim1, 1, sub_sample=True) 262 | self.out = Rearrange("b c d t -> b (c d) t") 263 | 264 | def forward(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 265 | confidence_map = self.bm_layer(feature) 266 | confidence_map = self.block0(confidence_map) 267 | confidence_map = self.block1(confidence_map) 268 | confidence_map_p, confidence_map_c, confidence_map_p_c = self.proposal_block(confidence_map) 269 | 270 | confidence_map_p = self.out(confidence_map_p) 271 | confidence_map_c = self.out(confidence_map_c) 272 | confidence_map_p_c = self.out(confidence_map_p_c) 273 | return confidence_map_p, confidence_map_c, confidence_map_p_c -------------------------------------------------------------------------------- /examples/batfd/batfd/model/frame_classifier.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import Module 3 | 4 | from ..utils import Conv1d 5 | 6 | 7 | class FrameLogisticRegression(Module): 8 | """ 9 | Frame classifier (FC_v and FC_a) for video feature (F_v) and audio feature (F_a). 
10 | Input: 11 | F_v or F_a: (B, C_f, T) 12 | Output: 13 | Y^: (B, 1, T) 14 | """ 15 | 16 | def __init__(self, n_features: int): 17 | super().__init__() 18 | self.lr_layer = Conv1d(n_features, 1, kernel_size=1) 19 | 20 | def forward(self, features: Tensor) -> Tensor: 21 | return self.lr_layer(features) 22 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/fusion_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Sigmoid, Module 4 | 5 | from ..utils import Conv1d 6 | 7 | 8 | class ModalFeatureAttnBoundaryMapFusion(Module): 9 | """ 10 | Fusion module for video and audio boundary maps. 11 | 12 | Input: 13 | F_v: (B, C_f, T) 14 | F_a: (B, C_f, T) 15 | M_v^: (B, D, T) 16 | M_a^: (B, D, T) 17 | 18 | Output: 19 | M^: (B, D, T) 20 | """ 21 | 22 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257, max_duration: int = 40): 23 | super().__init__() 24 | 25 | self.a_attn_block = ModalMapAttnBlock(n_audio_features, n_video_features, max_duration) 26 | self.v_attn_block = ModalMapAttnBlock(n_video_features, n_audio_features, max_duration) 27 | 28 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_bm: Tensor, audio_bm: Tensor) -> Tensor: 29 | a_attn = self.a_attn_block(audio_bm, audio_feature, video_feature) 30 | v_attn = self.v_attn_block(video_bm, video_feature, audio_feature) 31 | 32 | sum_attn = a_attn + v_attn 33 | 34 | a_w = a_attn / sum_attn 35 | v_w = v_attn / sum_attn 36 | 37 | fusion_bm = video_bm * v_w + audio_bm * a_w 38 | return fusion_bm 39 | 40 | 41 | class ModalMapAttnBlock(Module): 42 | 43 | def __init__(self, n_self_features: int, n_another_features: int, max_duration: int = 40): 44 | super().__init__() 45 | self.attn_from_self_features = Conv1d(n_self_features, max_duration, kernel_size=1) 46 | self.attn_from_another_features = Conv1d(n_another_features, max_duration, kernel_size=1) 47 | self.attn_from_bm = Conv1d(max_duration, max_duration, kernel_size=1) 48 | self.sigmoid = Sigmoid() 49 | 50 | def forward(self, self_bm: Tensor, self_features: Tensor, another_features: Tensor) -> Tensor: 51 | w_bm = self.attn_from_bm(self_bm) 52 | w_self_feat = self.attn_from_self_features(self_features) 53 | w_another_feat = self.attn_from_another_features(another_features) 54 | w_stack = torch.stack((w_bm, w_self_feat, w_another_feat), dim=3) 55 | w = w_stack.mean(dim=3) 56 | return self.sigmoid(w) 57 | 58 | 59 | class ModalFeatureAttnCfgFusion(ModalFeatureAttnBoundaryMapFusion): 60 | 61 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257): 62 | super().__init__() 63 | self.a_attn_block = ModalCbgAttnBlock(n_audio_features, n_video_features) 64 | self.v_attn_block = ModalCbgAttnBlock(n_video_features, n_audio_features) 65 | 66 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_cfg: Tensor, audio_cfg: Tensor) -> Tensor: 67 | video_cfg = video_cfg.unsqueeze(1) 68 | audio_cfg = audio_cfg.unsqueeze(1) 69 | fusion_cfg = super().forward(video_feature, audio_feature, video_cfg, audio_cfg) 70 | return fusion_cfg.squeeze(1) 71 | 72 | 73 | class ModalCbgAttnBlock(ModalMapAttnBlock): 74 | 75 | def __init__(self, n_self_features: int, n_another_features: int): 76 | super().__init__(n_self_features, n_another_features, 1) -------------------------------------------------------------------------------- /examples/batfd/batfd/model/loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Module, MSELoss 4 | 5 | 6 | class MaskedBMLoss(Module): 7 | 8 | def __init__(self, loss_fn: Module): 9 | super().__init__() 10 | self.loss_fn = loss_fn 11 | 12 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor): 13 | loss = [] 14 | for i, frame in enumerate(n_frames): 15 | loss.append(self.loss_fn(pred[i, :, :frame], true[i, :, :frame])) 16 | return torch.mean(torch.stack(loss)) 17 | 18 | 19 | class MaskedFrameLoss(Module): 20 | 21 | def __init__(self, loss_fn: Module): 22 | super().__init__() 23 | self.loss_fn = loss_fn 24 | 25 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor): 26 | # input: (B, T) 27 | loss = [] 28 | for i, frame in enumerate(n_frames): 29 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame])) 30 | return torch.mean(torch.stack(loss)) 31 | 32 | 33 | class MaskedContrastLoss(Module): 34 | 35 | def __init__(self, margin: float = 0.99): 36 | super().__init__() 37 | self.margin = margin 38 | 39 | def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor, n_frames: Tensor): 40 | # input: (B, C, T) 41 | loss = [] 42 | for i, frame in enumerate(n_frames): 43 | # mean L2 distance squared 44 | d = torch.dist(pred1[i, :, :frame], pred2[i, :, :frame], 2) 45 | if labels[i]: 46 | # if is positive pair, minimize distance 47 | loss.append(d ** 2) 48 | else: 49 | # if is negative pair, minimize (margin - distance) if distance < margin 50 | loss.append(torch.clip(self.margin - d, min=0.) ** 2) 51 | return torch.mean(torch.stack(loss)) 52 | 53 | 54 | class MaskedMSE(Module): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | self.loss_fn = MSELoss() 59 | 60 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor): 61 | loss = [] 62 | for i, frame in enumerate(n_frames): 63 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame])) 64 | return torch.mean(torch.stack(loss)) 65 | 66 | 67 | class MaskedBsnppLoss(Module): 68 | """Simplified version of BSN++ loss function.""" 69 | 70 | def __init__(self, cbg_feature_weight=0.01, prb_weight_forward=1): 71 | super().__init__() 72 | self.cbg_feature_weight = cbg_feature_weight 73 | self.prb_weight_forward = prb_weight_forward 74 | 75 | self.cbg_loss_func = MaskedMSE() 76 | self.cbg_feature_loss = MaskedBMLoss(MSELoss()) 77 | self.bsnpp_pem_reg_loss_func = self.cbg_feature_loss 78 | 79 | def forward(self, pred_bm_p, pred_bm_c, pred_bm_p_c, pred_start, pred_end, 80 | pred_start_backward, pred_end_backward, gt_iou_map, gt_start, gt_end, n_frames, 81 | feature_forward=None, feature_backward=None 82 | ): 83 | if self.cbg_feature_weight > 0: 84 | cbg_loss_forward = self.cbg_loss_func(pred_start, gt_start, n_frames) + \ 85 | self.cbg_loss_func(pred_end, gt_end, n_frames) 86 | cbg_loss_backward = self.cbg_loss_func(torch.flip(pred_end_backward, dims=(1,)), gt_start, n_frames) + \ 87 | self.cbg_loss_func(torch.flip(pred_start_backward, dims=(1,)), gt_end, n_frames) 88 | 89 | cbg_loss = cbg_loss_forward + cbg_loss_backward 90 | if feature_forward is not None and feature_backward is not None: 91 | inter_feature_loss = self.cbg_feature_weight * self.cbg_feature_loss(feature_forward, 92 | torch.flip(feature_backward, dims=(2,)), n_frames) 93 | cbg_loss += inter_feature_loss 94 | else: 95 | inter_feature_loss = None 96 | else: 97 | cbg_loss = None 98 | cbg_loss_forward = None 99 | cbg_loss_backward = None 100 | inter_feature_loss = None 101 | 102 | 
prb_reg_loss_p = self.bsnpp_pem_reg_loss_func(pred_bm_p, gt_iou_map, n_frames) 103 | prb_reg_loss_c = self.bsnpp_pem_reg_loss_func(pred_bm_c, gt_iou_map, n_frames) 104 | prb_reg_loss_p_c = self.bsnpp_pem_reg_loss_func(pred_bm_p_c, gt_iou_map, n_frames) 105 | prb_loss = prb_reg_loss_p + prb_reg_loss_c + prb_reg_loss_p_c 106 | 107 | loss = cbg_loss + prb_loss if cbg_loss is not None else prb_loss 108 | return loss, cbg_loss, prb_loss, cbg_loss_forward, cbg_loss_backward, inter_feature_loss 109 | 110 | 111 | # Non-masked versions of the loss functions 112 | class BMLoss(Module): 113 | """Non-masked version of MaskedBMLoss.""" 114 | 115 | def __init__(self, loss_fn: Module): 116 | super().__init__() 117 | self.loss_fn = loss_fn 118 | 119 | def forward(self, pred: Tensor, true: Tensor): 120 | return self.loss_fn(pred, true) 121 | 122 | 123 | class FrameLoss(Module): 124 | def __init__(self, loss_fn: Module): 125 | super().__init__() 126 | self.loss_fn = loss_fn 127 | 128 | def forward(self, pred: Tensor, true: Tensor): 129 | # input: (B, T) 130 | return self.loss_fn(pred, true) 131 | 132 | 133 | class ContrastLoss(Module): 134 | def __init__(self, margin: float = 0.99): 135 | super().__init__() 136 | self.margin = margin 137 | 138 | def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor): 139 | # input: (B, C, T) 140 | batch_size = pred1.size(0) 141 | loss = [] 142 | for i in range(batch_size): 143 | # mean L2 distance squared 144 | d = torch.dist(pred1[i], pred2[i], 2) 145 | if labels[i]: 146 | # if is positive pair, minimize distance 147 | loss.append(d ** 2) 148 | else: 149 | # if is negative pair, minimize (margin - distance) if distance < margin 150 | loss.append(torch.clip(self.margin - d, min=0.) ** 2) 151 | return torch.mean(torch.stack(loss)) 152 | 153 | 154 | class BsnppLoss(Module): 155 | def __init__(self, cbg_feature_weight=0.01, prb_weight_forward=1): 156 | super().__init__() 157 | self.cbg_feature_weight = cbg_feature_weight 158 | self.prb_weight_forward = prb_weight_forward 159 | 160 | self.cbg_loss_func = MSELoss() 161 | self.cbg_feature_loss = BMLoss(MSELoss()) 162 | self.bsnpp_pem_reg_loss_func = self.cbg_feature_loss 163 | 164 | def forward(self, pred_bm_p, pred_bm_c, pred_bm_p_c, pred_start, pred_end, 165 | pred_start_backward, pred_end_backward, gt_iou_map, gt_start, gt_end, 166 | feature_forward=None, feature_backward=None 167 | ): 168 | if self.cbg_feature_weight > 0: 169 | cbg_loss_forward = self.cbg_loss_func(pred_start, gt_start) + \ 170 | self.cbg_loss_func(pred_end, gt_end) 171 | cbg_loss_backward = self.cbg_loss_func(torch.flip(pred_end_backward, dims=(1,)), gt_start) + \ 172 | self.cbg_loss_func(torch.flip(pred_start_backward, dims=(1,)), gt_end) 173 | 174 | cbg_loss = cbg_loss_forward + cbg_loss_backward 175 | if feature_forward is not None and feature_backward is not None: 176 | inter_feature_loss = self.cbg_feature_weight * self.cbg_feature_loss(feature_forward, 177 | torch.flip(feature_backward, dims=(2,))) 178 | cbg_loss += inter_feature_loss 179 | else: 180 | inter_feature_loss = None 181 | else: 182 | cbg_loss = None 183 | cbg_loss_forward = None 184 | cbg_loss_backward = None 185 | inter_feature_loss = None 186 | 187 | prb_reg_loss_p = self.bsnpp_pem_reg_loss_func(pred_bm_p, gt_iou_map) 188 | prb_reg_loss_c = self.bsnpp_pem_reg_loss_func(pred_bm_c, gt_iou_map) 189 | prb_reg_loss_p_c = self.bsnpp_pem_reg_loss_func(pred_bm_p_c, gt_iou_map) 190 | prb_loss = prb_reg_loss_p + prb_reg_loss_c + prb_reg_loss_p_c 191 | 192 | loss = cbg_loss + prb_loss if 
cbg_loss is not None else prb_loss 193 | return loss, cbg_loss, prb_loss, cbg_loss_forward, cbg_loss_backward, inter_feature_loss 194 | -------------------------------------------------------------------------------- /examples/batfd/batfd/model/video_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import numpy as np 4 | from einops.layers.torch import Rearrange 5 | from torch import Tensor 6 | from torch.nn import Sequential, LeakyReLU, MaxPool3d, Module, Linear 7 | from torchvision.models.video.mvit import MSBlockConfig, _mvit 8 | 9 | from ..utils import Conv3d, Conv1d 10 | 11 | 12 | class C3DVideoEncoder(Module): 13 | """ 14 | Video encoder (E_v): Process video frames to extract features. 15 | Input: 16 | V: (B, C, T, H, W) 17 | Output: 18 | F_v: (B, C_f, T) 19 | """ 20 | 21 | def __init__(self, n_features=(64, 96, 128, 128), v_cla_feature_in: int = 256): 22 | super().__init__() 23 | 24 | n_dim0, n_dim1, n_dim2, n_dim3 = n_features 25 | 26 | # (B, 3, 512, 96, 96) -> (B, 64, 512, 32, 32) 27 | self.block0 = Sequential( 28 | Conv3d(3, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 29 | Conv3d(n_dim0, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 30 | MaxPool3d((1, 3, 3)) 31 | ) 32 | 33 | # (B, 64, 512, 32, 32) -> (B, 96, 512, 16, 16) 34 | self.block1 = Sequential( 35 | Conv3d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 36 | Conv3d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 37 | MaxPool3d((1, 2, 2)) 38 | ) 39 | 40 | # (B, 96, 512, 16, 16) -> (B, 128, 512, 8, 8) 41 | self.block2 = Sequential( 42 | Conv3d(n_dim1, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 43 | Conv3d(n_dim2, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 44 | MaxPool3d((1, 2, 2)) 45 | ) 46 | 47 | # (B, 128, 512, 8, 8) -> (B, 128, 512, 2, 2) -> (B, 512, 512) -> (B, 256, 512) 48 | self.block3 = Sequential( 49 | Conv3d(n_dim2, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 50 | MaxPool3d((1, 2, 2)), 51 | Conv3d(n_dim3, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 52 | MaxPool3d((1, 2, 2)), 53 | Rearrange("b c t h w -> b (c h w) t"), 54 | Conv1d(n_dim3 * 4, v_cla_feature_in, kernel_size=1, stride=1, build_activation=LeakyReLU) 55 | ) 56 | 57 | def forward(self, video: Tensor) -> Tensor: 58 | x = self.block0(video) 59 | x = self.block1(x) 60 | x = self.block2(x) 61 | x = self.block3(x) 62 | return x 63 | 64 | 65 | class MvitVideoEncoder(Module): 66 | 67 | def __init__(self, v_cla_feature_in: int = 256, 68 | temporal_size: int = 512, 69 | mvit_type: Literal["mvit_v2_t", "mvit_v2_s", "mvit_v2_b"] = "mvit_v2_t" 70 | ): 71 | super().__init__() 72 | if mvit_type == "mvit_v2_t": 73 | self.mvit = mvit_v2_t(v_cla_feature_in, temporal_size) 74 | elif mvit_type == "mvit_v2_s": 75 | self.mvit = mvit_v2_s(v_cla_feature_in, temporal_size) 76 | elif mvit_type == "mvit_v2_b": 77 | self.mvit = mvit_v2_b(v_cla_feature_in, temporal_size) 78 | else: 79 | raise ValueError(f"Invalid mvit_type: {mvit_type}") 80 | del self.mvit.head 81 | 82 | def forward(self, video: Tensor) -> Tensor: 83 | feat = self.mvit.conv_proj(video) 84 | feat = feat.flatten(2).transpose(1, 2) 85 | feat = self.mvit.pos_encoding(feat) 86 | thw = (self.mvit.pos_encoding.temporal_size,) + self.mvit.pos_encoding.spatial_size 87 | for block in self.mvit.blocks: 88 | 
feat, thw = block(feat, thw) 89 | 90 | feat = self.mvit.norm(feat) 91 | feat = feat[:, 1:] 92 | feat = feat.permute(0, 2, 1) 93 | return feat 94 | 95 | 96 | def generate_config(blocks, heads, channels, out_dim): 97 | num_heads = [] 98 | input_channels = [] 99 | kernel_qkv = [] 100 | stride_q = [[1, 1, 1]] * sum(blocks) 101 | blocks_cum = np.cumsum(blocks) 102 | stride_kv = [] 103 | 104 | for i in range(len(blocks)): 105 | num_heads.extend([heads[i]] * blocks[i]) 106 | input_channels.extend([channels[i]] * blocks[i]) 107 | kernel_qkv.extend([[3, 3, 3]] * blocks[i]) 108 | 109 | if i != len(blocks) - 1: 110 | stride_q[blocks_cum[i]] = [1, 2, 2] 111 | 112 | stride_kv_value = 2 ** (len(blocks) - 1 - i) 113 | stride_kv.extend([[1, stride_kv_value, stride_kv_value]] * blocks[i]) 114 | 115 | return { 116 | "num_heads": num_heads, 117 | "input_channels": [input_channels[0]] + input_channels[:-1], 118 | "output_channels": input_channels[:-1] + [out_dim], 119 | "kernel_q": kernel_qkv, 120 | "kernel_kv": kernel_qkv, 121 | "stride_q": stride_q, 122 | "stride_kv": stride_kv 123 | } 124 | 125 | 126 | def build_mvit(config, kwargs, temporal_size=512): 127 | block_setting = [] 128 | for i in range(len(config["num_heads"])): 129 | block_setting.append( 130 | MSBlockConfig( 131 | num_heads=config["num_heads"][i], 132 | input_channels=config["input_channels"][i], 133 | output_channels=config["output_channels"][i], 134 | kernel_q=config["kernel_q"][i], 135 | kernel_kv=config["kernel_kv"][i], 136 | stride_q=config["stride_q"][i], 137 | stride_kv=config["stride_kv"][i], 138 | ) 139 | ) 140 | return _mvit( 141 | spatial_size=(96, 96), 142 | temporal_size=temporal_size, 143 | block_setting=block_setting, 144 | residual_pool=True, 145 | residual_with_cls_embed=False, 146 | rel_pos_embed=True, 147 | proj_after_attn=True, 148 | stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), 149 | weights=None, 150 | progress=False, 151 | patch_embed_kernel=(3, 15, 15), 152 | patch_embed_stride=(1, 12, 12), 153 | patch_embed_padding=(1, 3, 3), 154 | **kwargs, 155 | ) 156 | 157 | 158 | def mvit_v2_b(out_dim: int, temporal_size: int, **kwargs): 159 | config = generate_config([2, 3, 16, 3], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 160 | return build_mvit(config, kwargs, temporal_size=temporal_size) 161 | 162 | 163 | def mvit_v2_s(out_dim: int, temporal_size: int, **kwargs): 164 | config = generate_config([1, 2, 11, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 165 | return build_mvit(config, kwargs, temporal_size=temporal_size) 166 | 167 | 168 | def mvit_v2_t(out_dim: int, temporal_size: int, **kwargs): 169 | config = generate_config([1, 2, 5, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 170 | return build_mvit(config, kwargs, temporal_size=temporal_size) 171 | 172 | 173 | class VideoFeatureProjection(Module): 174 | 175 | def __init__(self, input_feature_dim: int, v_cla_feature_in: int = 256): 176 | super().__init__() 177 | self.proj = Linear(input_feature_dim, v_cla_feature_in) 178 | 179 | def forward(self, x: Tensor) -> Tensor: 180 | x = self.proj(x) 181 | return x.permute(0, 2, 1) 182 | 183 | 184 | def get_video_encoder(v_cla_feature_in, temporal_size, v_encoder, ve_features): 185 | if v_encoder == "c3d": 186 | video_encoder = C3DVideoEncoder(n_features=ve_features, v_cla_feature_in=v_cla_feature_in) 187 | elif v_encoder == "mvit_t": 188 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_t") 189 | elif v_encoder == "mvit_s": 190 | video_encoder 
= MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_s") 191 | elif v_encoder == "mvit_b": 192 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_b") 193 | elif v_encoder == "marlin_vit_small": 194 | video_encoder = VideoFeatureProjection(input_feature_dim=13824, v_cla_feature_in=v_cla_feature_in) 195 | elif v_encoder == "i3d": 196 | video_encoder = VideoFeatureProjection(input_feature_dim=2048, v_cla_feature_in=v_cla_feature_in) 197 | elif v_encoder == "3dmm": 198 | video_encoder = VideoFeatureProjection(input_feature_dim=393, v_cla_feature_in=v_cla_feature_in) 199 | else: 200 | raise ValueError(f"Invalid video encoder: {v_encoder}") 201 | return video_encoder 202 | -------------------------------------------------------------------------------- /examples/batfd/batfd/post_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | from concurrent.futures import ProcessPoolExecutor 4 | from os import cpu_count 5 | from typing import List 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm.auto import tqdm 10 | 11 | from avdeepfake1m.loader import Metadata 12 | from avdeepfake1m.utils import iou_with_anchors 13 | 14 | 15 | def soft_nms(df, alpha, t1, t2, fps): 16 | df = df.sort_values(by="score", ascending=False) 17 | t_start = list(df.begin.values[:] / fps) 18 | t_end = list(df.end.values[:] / fps) 19 | t_score = list(df.score.values[:]) 20 | 21 | r_start = [] 22 | r_end = [] 23 | r_score = [] 24 | 25 | while len(t_score) > 1 and len(r_score) < 101: 26 | max_index = t_score.index(max(t_score)) 27 | tmp_iou_list = iou_with_anchors( 28 | np.array(t_start), 29 | np.array(t_end), t_start[max_index], t_end[max_index]) 30 | for idx in range(0, len(t_score)): 31 | if idx != max_index: 32 | tmp_iou = tmp_iou_list[idx] 33 | tmp_width = t_end[max_index] - t_start[max_index] 34 | if tmp_iou > t1 + (t2 - t1) * tmp_width: 35 | t_score[idx] *= np.exp(-np.square(tmp_iou) / alpha) 36 | 37 | r_start.append(t_start[max_index]) 38 | r_end.append(t_end[max_index]) 39 | r_score.append(t_score[max_index]) 40 | t_start.pop(max_index) 41 | t_end.pop(max_index) 42 | t_score.pop(max_index) 43 | 44 | new_df = pd.DataFrame() 45 | new_df['score'] = r_score 46 | new_df['begin'] = r_start 47 | new_df['end'] = r_end 48 | return new_df 49 | 50 | 51 | def video_post_process(meta, model_name, fps=25, alpha=0.4, t1=0.2, t2=0.9, dataset_name="avdeepfake1m", output_dir="output"): 52 | file = resolve_csv_file_name(meta, dataset_name) 53 | df = pd.read_csv(os.path.join(output_dir, model_name, file)) 54 | 55 | if len(df) > 1: 56 | df = soft_nms(df, alpha, t1, t2, fps) 57 | 58 | df = df.sort_values(by="score", ascending=False) 59 | 60 | proposal_list = [] 61 | 62 | for j in range(len(df)): 63 | # round the score for saving json size 64 | score = round(df.score.values[j], 4) 65 | 66 | if score > 0: 67 | proposal_list.append([ 68 | score, 69 | round(df.begin.values[j].item(), 2), 70 | round(df.end.values[j].item(), 2) 71 | ]) 72 | 73 | return [meta.file, proposal_list] 74 | 75 | 76 | def resolve_csv_file_name(meta: Metadata, dataset_name: str = "avdeepfake1m") -> str: 77 | if dataset_name in ("avdeepfake1m", "avdeepfake1m++"): 78 | return meta.file.replace("/", "_").replace(".mp4", ".csv") 79 | else: 80 | raise NotImplementedError 81 | 82 | 83 | def post_process(model_name: str, save_path: str, metadata: List[Metadata], fps=25, 84 | 
alpha=0.4, t1=0.2, t2=0.9, dataset_name="avdeepfake1m", output_dir="output" 85 | ): 86 | with ProcessPoolExecutor(cpu_count() // 2 - 1) as executor: 87 | futures = [] 88 | for meta in metadata: 89 | futures.append(executor.submit(video_post_process, meta, model_name, fps, 90 | alpha, t1, t2, dataset_name, output_dir 91 | )) 92 | 93 | results = dict(map(lambda x: x.result(), tqdm(futures))) 94 | 95 | with open(save_path, "w") as f: 96 | json.dump(results, f, indent=4) 97 | -------------------------------------------------------------------------------- /examples/batfd/batfd/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from lightning.pytorch import LightningModule, Trainer, Callback 3 | import torch 4 | from typing import Callable, Optional 5 | from abc import ABC 6 | from torch import Tensor 7 | from torch.nn import Module 8 | 9 | 10 | class _ConvNd(Module, ABC): 11 | 12 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, 13 | build_activation: Optional[Callable] = None 14 | ): 15 | super().__init__() 16 | self.conv = self.PtConv( 17 | in_channels, out_channels, kernel_size, stride=stride, padding=padding 18 | ) 19 | if build_activation is not None: 20 | self.activation = build_activation() 21 | else: 22 | self.activation = None 23 | 24 | def forward(self, x: Tensor) -> Tensor: 25 | x = self.conv(x) 26 | if self.activation is not None: 27 | x = self.activation(x) 28 | return x 29 | 30 | 31 | class Conv1d(_ConvNd): 32 | PtConv = torch.nn.Conv1d 33 | 34 | 35 | class Conv2d(_ConvNd): 36 | PtConv = torch.nn.Conv2d 37 | 38 | 39 | class Conv3d(_ConvNd): 40 | PtConv = torch.nn.Conv3d 41 | 42 | 43 | class LrLogger(Callback): 44 | """Log learning rate in each epoch start.""" 45 | 46 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 47 | for i, optimizer in enumerate(trainer.optimizers): 48 | for j, params in enumerate(optimizer.param_groups): 49 | key = f"opt{i}_lr{j}" 50 | value = params["lr"] 51 | pl_module.logger.log_metrics({key: value}, step=trainer.global_step) 52 | pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed) 53 | 54 | 55 | class EarlyStoppingLR(Callback): 56 | """Early stop model training when the LR is lower than threshold.""" 57 | 58 | def __init__(self, lr_threshold: float, mode="all"): 59 | self.lr_threshold = lr_threshold 60 | 61 | if mode in ("any", "all"): 62 | self.mode = mode 63 | else: 64 | raise ValueError(f"mode must be one of ('any', 'all')") 65 | 66 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 67 | self._run_early_stop_checking(trainer) 68 | 69 | def _run_early_stop_checking(self, trainer: Trainer) -> None: 70 | metrics = trainer._logger_connector.callback_metrics 71 | if len(metrics) == 0: 72 | return 73 | all_lr = [] 74 | for key, value in metrics.items(): 75 | if re.match(r"opt\d+_lr\d+", key): 76 | all_lr.append(value) 77 | 78 | if len(all_lr) == 0: 79 | return 80 | 81 | if self.mode == "all": 82 | if all(lr <= self.lr_threshold for lr in all_lr): 83 | trainer.should_stop = True 84 | elif self.mode == "any": 85 | if any(lr <= self.lr_threshold for lr in all_lr): 86 | trainer.should_stop = True -------------------------------------------------------------------------------- /examples/batfd/batfd_plus.toml: -------------------------------------------------------------------------------- 1 | name = "batfd_plus" 2 | num_frames = 100 # T 3 | 
max_duration = 30 # D 4 | model_type = "batfd_plus" 5 | dataset = "avdeepfake1m++" 6 | 7 | [model.video_encoder] 8 | type = "mvit_b" 9 | hidden_dims = [] # handled by model type 10 | cla_feature_in = 256 # C_f 11 | 12 | [model.audio_encoder] 13 | type = "vit_b" 14 | hidden_dims = [] # handled by model type 15 | cla_feature_in = 256 # C_f 16 | 17 | [model.frame_classifier] 18 | type = "lr" 19 | 20 | [model.boundary_module] 21 | hidden_dims = [512, 128] 22 | samples = 10 # N 23 | 24 | [optimizer] 25 | learning_rate = 0.00001 26 | frame_loss_weight = 2.0 27 | modal_bm_loss_weight = 1.0 28 | cbg_feature_weight = 0.0 29 | prb_weight_forward = 1.0 30 | contrastive_loss_weight = 0.1 31 | contrastive_loss_margin = 0.99 32 | weight_decay = 0.0001 33 | 34 | [soft_nms] 35 | alpha = 0.7234 36 | t1 = 0.1968 37 | t2 = 0.4123 -------------------------------------------------------------------------------- /examples/batfd/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | 5 | from avdeepfake1m.evaluation import ap_ar_1d 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser(description="Evaluation script for BATFD/BATFD+ models on AV-Deepfake1M") 9 | parser.add_argument("prediction_file_path", type=str, help="Path to the prediction JSON file (e.g., output/batfd_test.json)") 10 | parser.add_argument("metadata_file_path", type=str, help="Path to the metadata JSON file (e.g., /path/to/dataset/test_metadata.json or /path/to/dataset/val_metadata.json)") 11 | args = parser.parse_args() 12 | 13 | print(f"Calculating AP/AR for prediction file: {args.prediction_file_path}") 14 | print(f"Using metadata file: {args.metadata_file_path}") 15 | 16 | # Parameters for ap_ar_1d based on README.md 17 | file_key = "file" 18 | value_key = "fake_segments" # For ground truth in metadata 19 | fps = 1.0 20 | ap_iou_thresholds = [0.5, 0.75, 0.9, 0.95] 21 | ar_n_proposals = [50, 30, 20, 10, 5] 22 | ar_iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 23 | 24 | ap_ar_results = ap_ar_1d( 25 | proposals_path=args.prediction_file_path, 26 | labels_path=args.metadata_file_path, 27 | file_key=file_key, 28 | value_key=value_key, 29 | fps=fps, 30 | ap_iou_thresholds=ap_iou_thresholds, 31 | ar_n_proposals=ar_n_proposals, 32 | ar_iou_thresholds=ar_iou_thresholds 33 | ) 34 | 35 | print(ap_ar_results) 36 | 37 | score = 0.5 * sum(ap_ar_results["ap"].values()) / len(ap_ar_results["ap"]) \ 38 | + 0.5 * sum(ap_ar_results["ar"].values()) / len(ap_ar_results["ar"]) 39 | 40 | print(f"Score: {score}") 41 | -------------------------------------------------------------------------------- /examples/batfd/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import toml 3 | import torch 4 | import os 5 | from pathlib import Path 6 | 7 | from avdeepfake1m.loader import AVDeepfake1mDataModule, Metadata 8 | from batfd.model import Batfd, BatfdPlus 9 | from batfd.inference import inference_model 10 | from batfd.post_process import post_process 11 | from avdeepfake1m.utils import read_json 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description="BATFD/BATFD+ Inference") 15 | parser.add_argument("--config", type=str, required=True, 16 | help="Path to the TOML configuration file.") 17 | parser.add_argument("--checkpoint", type=str, required=True, 18 | help="Path to the model checkpoint.") 19 | parser.add_argument("--data_root", type=str, 
required=True, 20 | help="Root directory of the dataset.") 21 | parser.add_argument("--num_workers", type=int, default=8, 22 | help="Number of workers for data loading.") 23 | parser.add_argument("--subset", type=str, choices=["val", "test", "testA", "testB"], 24 | default="test", help="Dataset subset.") 25 | parser.add_argument("--gpus", type=int, default=1, 26 | help="Number of GPUs. Set to 0 for CPU.") 27 | 28 | args = parser.parse_args() 29 | 30 | # Determine device 31 | if args.gpus > 0 and torch.cuda.is_available(): 32 | device = f"cuda:{torch.cuda.current_device()}" 33 | else: 34 | device = "cpu" 35 | 36 | print(f"Using device: {device}") 37 | 38 | config = toml.load(args.config) 39 | temp_dir = "output" 40 | output_file = os.path.join(temp_dir, f"{config['name']}_{args.subset}.json") 41 | model_type = config["model_type"] 42 | 43 | if model_type == "batfd_plus": 44 | model = BatfdPlus.load_from_checkpoint(args.checkpoint) 45 | elif model_type == "batfd": 46 | model = Batfd.load_from_checkpoint(args.checkpoint) 47 | else: 48 | raise ValueError(f"Unknown model type: {model_type}") 49 | 50 | model.eval() 51 | 52 | # Setup DataModule 53 | dm_dataset_name = config["dataset"] 54 | is_plusplus = dm_dataset_name == "avdeepfake1m++" 55 | 56 | dm = AVDeepfake1mDataModule( 57 | root=args.data_root, 58 | temporal_size=config["num_frames"], 59 | max_duration=config["max_duration"], 60 | require_match_scores=False, 61 | batch_size=1, # due to the problem from lightning, only 1 is supported 62 | num_workers=args.num_workers, 63 | get_meta_attr=model.get_meta_attr, 64 | return_file_name=True, 65 | is_plusplus=is_plusplus, 66 | test_subset=args.subset if args.subset in ("test", "testA", "testB") else None 67 | ) 68 | dm.setup() 69 | 70 | Path(output_file).parent.mkdir(parents=True, exist_ok=True) 71 | Path(temp_dir).mkdir(parents=True, exist_ok=True) 72 | 73 | if args.subset in ("test", "testA", "testB"): 74 | dataloader = dm.test_dataloader() 75 | metadata_path = os.path.join(dm.root, f"{args.subset}_metadata.json") 76 | elif args.subset == "val": 77 | dataloader = dm.val_dataloader() 78 | metadata_path = os.path.join(dm.root, "val_metadata.json") 79 | else: 80 | raise ValueError("Invalid subset") 81 | 82 | if os.path.exists(metadata_path): 83 | metadata = [Metadata(**each, fps=25) for each in read_json(metadata_path)] 84 | else: 85 | metadata = [ 86 | Metadata(file=file_name, 87 | original=None, 88 | split=args.subset, 89 | fake_segments=[], 90 | fps=25, 91 | visual_fake_segments=[], 92 | audio_fake_segments=[], 93 | audio_model="", 94 | modify_type="", 95 | # handle by the predictor in `inference_model` 96 | video_frames=-1, 97 | audio_frames=-1) 98 | for file_name in dataloader.dataset.file_list 99 | ] 100 | 101 | inference_model( 102 | model_name=config["name"], 103 | model=model, 104 | dataloader=dataloader, 105 | metadata=metadata, 106 | max_duration=config["max_duration"], 107 | model_type=config["model_type"], 108 | gpus=args.gpus, 109 | temp_dir=temp_dir 110 | ) 111 | 112 | post_process( 113 | model_name=config["name"], 114 | save_path=output_file, 115 | metadata=metadata, 116 | fps=25, 117 | alpha=config["soft_nms"]["alpha"], 118 | t1=config["soft_nms"]["t1"], 119 | t2=config["soft_nms"]["t2"], 120 | dataset_name=dm_dataset_name, 121 | output_dir=temp_dir 122 | ) 123 | 124 | print(f"Inference complete. 
Results saved to {output_file}") 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /examples/batfd/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import toml 4 | from lightning.pytorch import Trainer 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | import torch 7 | torch.set_float32_matmul_precision('high') 8 | 9 | from avdeepfake1m.loader import AVDeepfake1mDataModule 10 | from batfd.model import Batfd, BatfdPlus 11 | from batfd.utils import LrLogger, EarlyStoppingLR 12 | 13 | parser = argparse.ArgumentParser(description="BATFD training") 14 | parser.add_argument("--config", type=str) 15 | parser.add_argument("--data_root", type=str) 16 | parser.add_argument("--batch_size", type=int, default=4) 17 | parser.add_argument("--num_workers", type=int, default=8) 18 | parser.add_argument("--gpus", type=int, default=1) 19 | parser.add_argument("--precision", default=32) 20 | parser.add_argument("--num_train", type=int, default=None) 21 | parser.add_argument("--num_val", type=int, default=1000) 22 | parser.add_argument("--max_epochs", type=int, default=500) 23 | parser.add_argument("--logger", type=str, choices=["wandb", "tensorboard"], default="tensorboard") 24 | parser.add_argument("--resume", type=str, default=None) 25 | 26 | if __name__ == '__main__': 27 | args = parser.parse_args() 28 | config = toml.load(args.config) 29 | 30 | learning_rate = config["optimizer"]["learning_rate"] 31 | gpus = args.gpus 32 | total_batch_size = args.batch_size * gpus 33 | learning_rate = learning_rate * total_batch_size / 4 34 | dataset = config["dataset"] 35 | 36 | v_encoder_type = config["model"]["video_encoder"]["type"] 37 | a_encoder_type = config["model"]["audio_encoder"]["type"] 38 | 39 | if v_encoder_type in ("marlin_vit_small", "3dmm", "i3d"): 40 | v_feature = v_encoder_type 41 | else: 42 | v_feature = None 43 | 44 | if a_encoder_type in ("deep_speech", "wav2vec2", "trill"): 45 | a_feature = a_encoder_type 46 | else: 47 | a_feature = None 48 | 49 | if config["model_type"] == "batfd_plus": 50 | model = BatfdPlus( 51 | v_encoder=v_encoder_type, 52 | a_encoder=config["model"]["audio_encoder"]["type"], 53 | frame_classifier=config["model"]["frame_classifier"]["type"], 54 | ve_features=config["model"]["video_encoder"]["hidden_dims"], 55 | ae_features=config["model"]["audio_encoder"]["hidden_dims"], 56 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"], 57 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"], 58 | boundary_features=config["model"]["boundary_module"]["hidden_dims"], 59 | boundary_samples=config["model"]["boundary_module"]["samples"], 60 | temporal_dim=config["num_frames"], 61 | max_duration=config["max_duration"], 62 | weight_frame_loss=config["optimizer"]["frame_loss_weight"], 63 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"], 64 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"], 65 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"], 66 | cbg_feature_weight=config["optimizer"]["cbg_feature_weight"], 67 | prb_weight_forward=config["optimizer"]["prb_weight_forward"], 68 | weight_decay=config["optimizer"]["weight_decay"], 69 | learning_rate=learning_rate, 70 | distributed=args.gpus > 1 71 | ) 72 | require_match_scores = True 73 | get_meta_attr = BatfdPlus.get_meta_attr 74 | elif config["model_type"] == "batfd": 75 | model 
= Batfd( 76 | v_encoder=config["model"]["video_encoder"]["type"], 77 | a_encoder=config["model"]["audio_encoder"]["type"], 78 | frame_classifier=config["model"]["frame_classifier"]["type"], 79 | ve_features=config["model"]["video_encoder"]["hidden_dims"], 80 | ae_features=config["model"]["audio_encoder"]["hidden_dims"], 81 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"], 82 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"], 83 | boundary_features=config["model"]["boundary_module"]["hidden_dims"], 84 | boundary_samples=config["model"]["boundary_module"]["samples"], 85 | temporal_dim=config["num_frames"], 86 | max_duration=config["max_duration"], 87 | weight_frame_loss=config["optimizer"]["frame_loss_weight"], 88 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"], 89 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"], 90 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"], 91 | weight_decay=config["optimizer"]["weight_decay"], 92 | learning_rate=learning_rate, 93 | distributed=args.gpus > 1 94 | ) 95 | require_match_scores = False 96 | get_meta_attr = Batfd.get_meta_attr 97 | else: 98 | raise ValueError("Invalid model type") 99 | 100 | if dataset == "avdeepfake1m": 101 | dm = AVDeepfake1mDataModule( 102 | root=args.data_root, 103 | temporal_size=config["num_frames"], 104 | max_duration=config["max_duration"], 105 | require_match_scores=require_match_scores, 106 | batch_size=args.batch_size, num_workers=args.num_workers, 107 | take_train=args.num_train, take_val=args.num_val, 108 | get_meta_attr=get_meta_attr, 109 | is_plusplus=False 110 | ) 111 | elif dataset == "avdeepfake1m++": 112 | dm = AVDeepfake1mDataModule( 113 | root=args.data_root, 114 | temporal_size=config["num_frames"], 115 | max_duration=config["max_duration"], 116 | require_match_scores=require_match_scores, 117 | batch_size=args.batch_size, num_workers=args.num_workers, 118 | take_train=args.num_train, take_val=args.num_val, 119 | get_meta_attr=get_meta_attr, 120 | is_plusplus=True 121 | ) 122 | else: 123 | raise ValueError("Invalid dataset type") 124 | 125 | try: 126 | precision = int(args.precision) 127 | except ValueError: 128 | precision: int | str = args.precision 129 | 130 | monitor = "metrics/val_loss" 131 | 132 | if args.logger == "wandb": 133 | from lightning.pytorch.loggers import WandbLogger 134 | logger = WandbLogger(name=config["name"], project=dataset) 135 | else: 136 | logger = True 137 | 138 | trainer = Trainer(log_every_n_steps=20, precision=precision, max_epochs=args.max_epochs, 139 | callbacks=[ 140 | ModelCheckpoint( 141 | dirpath=f"./ckpt/{config['name']}", save_last=True, filename=config["name"] + "-{epoch}-{val_loss:.3f}", 142 | monitor=monitor, mode="min" 143 | ), 144 | LrLogger(), 145 | EarlyStoppingLR(lr_threshold=1e-7) 146 | ], enable_checkpointing=True, 147 | benchmark=True, 148 | accelerator="auto", 149 | devices=args.gpus, 150 | strategy="auto" if args.gpus < 2 else "ddp", 151 | logger=logger 152 | ) 153 | 154 | trainer.fit(model, dm, ckpt_path=args.resume) 155 | -------------------------------------------------------------------------------- /examples/xception/README.md: -------------------------------------------------------------------------------- 1 | # Xception 2 | 3 | This example trains a Xception model on the AVDeepfake1M/AVDeepfake1M++ dataset for classification with video-level labels. 
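At inference time the per-frame classifier outputs are reduced to a single video-level score before evaluation (see `infer.py`, which takes the maximum over frames). A minimal sketch of that aggregation, assuming `frame_logits` already holds the raw per-frame model outputs and using the illustrative helper name `video_level_score`:

```python
import torch

def video_level_score(frame_logits: torch.Tensor) -> float:
    """Reduce per-frame logits of shape (num_frames, 1) to one video-level score."""
    # score the video by its most suspicious frame (max over time);
    # sigmoid is monotonic, so the max over raw logits keeps the AUC ranking unchanged
    return frame_logits.flatten().max().item()

# usage sketch with random stand-in logits for a 32-frame clip
print(video_level_score(torch.randn(32, 1)))
```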
4 |
5 | ## Requirements
6 |
7 | - Python
8 | - PyTorch
9 | - PyTorch Lightning
10 | - TIMM
11 | - AVDeepfake1M SDK
12 |
13 |
14 | ## Training
15 |
16 | ```bash
17 | python train.py --data_root /path/to/avdeepfake1m --model xception
18 | ```
19 | ### Output
20 |
21 | * **Checkpoints:** Model checkpoints are saved under `./ckpt/xception/`. The last checkpoint is saved as `last.ckpt`.
22 | * **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`).
23 |
24 |
25 | ## Inference
26 |
27 | After training, you can generate predictions on a dataset subset (train, val, or test) using `infer.py`. This script will save the predictions to a text file, following the format from the [challenge](https://deepfakes1m.github.io/2025/details).
28 |
29 | ```bash
30 | python infer.py --data_root /path/to/avdeepfake1m --checkpoint /path/to/your/checkpoint.ckpt --model xception --subset val
31 | ```
32 |
33 | The output prediction file will be saved to `output/<model>_<subset>.txt` (e.g., `output/xception_val.txt`).
34 |
35 | ## Evaluation
36 |
37 | ```bash
38 | python evaluate.py <prediction_file_path> <metadata_file_path>
39 | ```
40 |
41 | For example:
42 |
43 | ```bash
44 | python evaluate.py ./output/xception_val.txt /path/to/avdeepfake1m/val_metadata.json
45 | ```
46 |
47 | This will print the AUC score based on your model's predictions.
48 |
-------------------------------------------------------------------------------- /examples/xception/evaluate.py: --------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from avdeepfake1m.evaluation import auc
4 |
5 | if __name__ == "__main__":
6 |     parser = argparse.ArgumentParser(description="Evaluation script for AV-Deepfake1M")
7 |     parser.add_argument("prediction_file_path", type=str, help="Path to the prediction file (e.g., output/results/xception_val.txt)")
8 |     parser.add_argument("metadata_file_path", type=str, help="Path to the metadata JSON file (e.g., /path/to/val_metadata.json)")
9 |     args = parser.parse_args()
10 |
11 |     print(auc(
12 |         args.prediction_file_path,
13 |         args.metadata_file_path,
14 |         "file", # As per README, this is usually "file"
15 |         "fake_segments" # As per README, this is usually "fake_segments" for AUC
16 |     ))
-------------------------------------------------------------------------------- /examples/xception/infer.py: --------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from tqdm.auto import tqdm
5 | from pathlib import Path
6 |
7 | from avdeepfake1m.loader import AVDeepfake1mPlusPlusVideo
8 | from xception import Xception
9 |
10 | parser = argparse.ArgumentParser(description="Xception inference")
11 | parser.add_argument("--data_root", type=str)
12 | parser.add_argument("--checkpoint", type=str)
13 | parser.add_argument("--model", type=str)
14 | parser.add_argument("--batch_size", type=int, default=128)
15 | parser.add_argument("--subset", type=str, choices=["train", "val", "test", "testA", "testB"])
16 | parser.add_argument("--gpus", type=int, default=1)
17 | parser.add_argument("--resume", type=str, default=None)
18 | parser.add_argument("--take_num", type=int, default=None)
19 |
20 | if __name__ == '__main__':
21 |     args = parser.parse_args()
22 |     use_gpu = args.gpus > 0
23 |     device = "cuda" if use_gpu else "cpu"
24 |
25 |     if args.model == "xception":
26 |         model =
Xception.load_from_checkpoint(args.checkpoint, lr=None, distributed=False).eval() 27 | else: 28 | raise ValueError(f"Unknown model: {args.model}") 29 | 30 | model.to(device) 31 | model.train() # not sure why but eval mode will generate nonsense output 32 | test_dataset = AVDeepfake1mPlusPlusVideo(args.subset, args.data_root, take_num=args.take_num, pred_mode=True) 33 | 34 | save_path = f"output/{args.model}_{args.subset}.txt" 35 | Path(save_path).parent.mkdir(parents=True, exist_ok=True) 36 | 37 | processed_files = set() 38 | if args.resume is not None: 39 | with open(args.resume, "r") as f: 40 | for line in f: 41 | processed_files.add(line.split(";")[0]) 42 | 43 | with open(save_path, "a") as f: 44 | with torch.inference_mode(): 45 | for i in tqdm(range(len(test_dataset))): 46 | file_name = test_dataset.metadata[i].file 47 | if file_name in processed_files: 48 | continue 49 | 50 | video, _, _ = test_dataset[i] 51 | # batch video as frames use batch_size 52 | preds_video = [] 53 | for j in range(0, len(video), args.batch_size): 54 | batch = video[j:j + args.batch_size].to(device) 55 | preds_video.append(model(batch)) 56 | 57 | preds_video = torch.cat(preds_video, dim=0).flatten() 58 | # choose the max prediction 59 | pred = preds_video.max().item() 60 | 61 | f.write(f"{file_name};{pred}\n") 62 | f.flush() 63 | -------------------------------------------------------------------------------- /examples/xception/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.utils.data import DataLoader 4 | from lightning.pytorch import Trainer 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | from avdeepfake1m.loader import AVDeepfake1mPlusPlusImages 7 | 8 | from xception import Xception 9 | from utils import LrLogger, EarlyStoppingLR 10 | 11 | 12 | parser = argparse.ArgumentParser(description="Classification model training") 13 | parser.add_argument("--data_root", type=str) 14 | parser.add_argument("--batch_size", type=int, default=128) 15 | parser.add_argument("--model", type=str, choices=["xception", "meso4", "meso_inception4"]) 16 | parser.add_argument("--gpus", type=int, default=1) 17 | parser.add_argument("--precision", default=32) 18 | parser.add_argument("--num_train", type=int, default=None) 19 | parser.add_argument("--num_val", type=int, default=2000) 20 | parser.add_argument("--max_epochs", type=int, default=500) 21 | parser.add_argument("--resume", type=str, default=None) 22 | args = parser.parse_args() 23 | 24 | 25 | if __name__ == "__main__": 26 | 27 | # You can fix the random seed if you want reproducible subsets each epoch: 28 | # torch.manual_seed(42) 29 | # random.seed(42) 30 | 31 | learning_rate = 1e-4 32 | gpus = args.gpus 33 | total_batch_size = args.batch_size * gpus 34 | learning_rate = learning_rate * total_batch_size / 4 35 | 36 | # Setup model 37 | if args.model == "xception": 38 | model = Xception(learning_rate, distributed=gpus > 1) 39 | else: 40 | raise ValueError(f"Unknown model: {args.model}") 41 | 42 | train_dataset = AVDeepfake1mPlusPlusImages( 43 | subset="train", 44 | data_root=args.data_root, 45 | take_num=args.num_train, 46 | use_video_label=True # For video-level label access, set True 47 | ) 48 | 49 | # For validation, you can still do the normal dataset 50 | val_dataset = AVDeepfake1mPlusPlusImages( 51 | subset="val", 52 | data_root=args.data_root, 53 | take_num=args.num_val, 54 | use_video_label=True 55 | ) 56 | 57 | # Parse precision properly 58 | try: 59 | precision = 
int(args.precision) 60 | except ValueError: 61 | precision = args.precision 62 | 63 | monitor = "val_loss" 64 | 65 | trainer = Trainer( 66 | log_every_n_steps=50, 67 | precision=precision, 68 | max_epochs=args.max_epochs, 69 | callbacks=[ 70 | ModelCheckpoint( 71 | dirpath=f"./ckpt/{args.model}", 72 | save_last=True, 73 | filename=args.model + "-{epoch}-{val_loss:.3f}", 74 | monitor=monitor, 75 | mode="min" 76 | ), 77 | LrLogger(), 78 | EarlyStoppingLR(lr_threshold=1e-7) 79 | ], 80 | enable_checkpointing=True, 81 | benchmark=True, 82 | accelerator="gpu", 83 | devices=args.gpus, 84 | strategy="ddp" if args.gpus > 1 else "auto", 85 | # ckpt_path=args.resume, 86 | # If you're on an older version of Lightning, you may need `strategy='ddp'` just the same, but this is typical. 87 | ) 88 | 89 | trainer.fit( 90 | model, 91 | train_dataloaders=DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0), 92 | val_dataloaders=DataLoader(val_dataset, batch_size=args.batch_size, num_workers=0), 93 | ckpt_path=args.resume, 94 | ) 95 | -------------------------------------------------------------------------------- /examples/xception/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from lightning.pytorch import Callback, Trainer, LightningModule 4 | 5 | 6 | class LrLogger(Callback): 7 | """Log learning rate in each epoch start.""" 8 | 9 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 10 | for i, optimizer in enumerate(trainer.optimizers): 11 | for j, params in enumerate(optimizer.param_groups): 12 | key = f"opt{i}_lr{j}" 13 | value = params["lr"] 14 | pl_module.logger.log_metrics({key: value}, step=trainer.global_step) 15 | pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed) 16 | 17 | 18 | class EarlyStoppingLR(Callback): 19 | """Early stop model training when the LR is lower than threshold.""" 20 | 21 | def __init__(self, lr_threshold: float, mode="all"): 22 | self.lr_threshold = lr_threshold 23 | 24 | if mode in ("any", "all"): 25 | self.mode = mode 26 | else: 27 | raise ValueError(f"mode must be one of ('any', 'all')") 28 | 29 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 30 | self._run_early_stop_checking(trainer) 31 | 32 | def _run_early_stop_checking(self, trainer: Trainer) -> None: 33 | metrics = trainer._logger_connector.callback_metrics 34 | if len(metrics) == 0: 35 | return 36 | all_lr = [] 37 | for key, value in metrics.items(): 38 | if re.match(r"opt\d+_lr\d+", key): 39 | all_lr.append(value) 40 | 41 | if len(all_lr) == 0: 42 | return 43 | 44 | if self.mode == "all": 45 | if all(lr <= self.lr_threshold for lr in all_lr): 46 | trainer.should_stop = True 47 | elif self.mode == "any": 48 | if any(lr <= self.lr_threshold for lr in all_lr): 49 | trainer.should_stop = True 50 | -------------------------------------------------------------------------------- /examples/xception/xception.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from lightning.pytorch import LightningModule 4 | from torch.nn import BCEWithLogitsLoss 5 | from torch.optim import Adam 6 | 7 | 8 | class Xception(LightningModule): 9 | def __init__(self, lr, distributed=False): 10 | super(Xception, self).__init__() 11 | self.lr = lr 12 | self.model = timm.create_model('xception', pretrained=True, num_classes=1) 13 | self.loss_fn = BCEWithLogitsLoss() 14 | self.distributed = distributed 15 | 16 | def 
forward(self, x): 17 | x = self.model(x) 18 | return x 19 | 20 | def training_step(self, batch, batch_idx): 21 | x, y = batch 22 | y_hat = self(x) 23 | loss = self.loss_fn(y_hat, y.unsqueeze(1)) 24 | self.log('train_loss', loss) 25 | return loss 26 | 27 | def validation_step(self, batch, batch_idx): 28 | x, y = batch 29 | y_hat = self(x) 30 | loss = self.loss_fn(y_hat, y.unsqueeze(1)) 31 | self.log('val_loss', loss) 32 | return loss 33 | 34 | def configure_optimizers(self): 35 | optimizer = Adam(self.parameters(), lr=self.lr) 36 | return [optimizer] 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.7,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "avdeepfake1m" 7 | requires-python = ">=3.7" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: 3", 12 | "Programming Language :: Python :: 3.7", 13 | "Programming Language :: Python :: 3.8", 14 | "Programming Language :: Python :: 3.9", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "License :: Other/Proprietary License", 19 | "Operating System :: OS Independent", 20 | "Intended Audience :: Developers", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Multimedia :: Video", 23 | "Topic :: Utilities" 24 | ] 25 | authors = [ 26 | {name = "ControlNet", email = "smczx@hotmail.com"} 27 | ] 28 | dynamic = ["version"] 29 | readme = "README.md" 30 | keywords = ["pytorch", "AI"] 31 | dependencies = [ 32 | "torch", 33 | "lightning>=2", 34 | "tqdm", 35 | "einops", 36 | "opencv-python", 37 | "numpy>=1,<2", 38 | "torchvision", 39 | "torchaudio", 40 | "av", 41 | "pandas~=2.0", 42 | "torchmetrics~=1.0", 43 | ] 44 | 45 | [project.urls] 46 | "Homepage" = "https://github.com/ControlNet/AV-Deepfake1M" 47 | "Source Code" = "https://github.com/ControlNet/AV-Deepfake1M" 48 | "Bug Tracker" = "https://github.com/ControlNet/AV-Deepfake1M/issues" 49 | 50 | [tool.maturin] 51 | features = ["pyo3/extension-module"] 52 | module-name = "avdeepfake1m._evaluation" 53 | python-source = "python" 54 | 55 | [tool.pixi.project] 56 | channels = ["pytorch", "conda-forge/label/rust_dev", "conda-forge"] 57 | platforms = ["linux-64", "osx-arm64", "osx-64", "win-64"] 58 | 59 | [tool.pixi.dependencies] 60 | pytorch = { version = "~=2.2.2", channel = "pytorch" } 61 | torchvision = { version = "~=0.17.2", channel = "pytorch" } 62 | torchaudio = { version = "~=2.2.2", channel = "pytorch" } 63 | cpuonly = { version = "~=2.0", channel = "pytorch" } # for developing the dataloader and evaluator only 64 | python = "~=3.11.0" 65 | numpy = { version = "~=1.0", channel = "conda-forge" } 66 | ffmpeg = { version = "*", channel = "conda-forge" } 67 | av = { version = "*", channel = "conda-forge" } 68 | pandas = ">=2.2.3,<3" 69 | torchmetrics = ">=1.4.2,<2" 70 | 71 | [tool.pixi.target.linux-64.dependencies] 72 | rust = { version = "*", channel = "conda-forge/label/rust_dev" } 73 | 74 | [tool.pixi.target.osx-64.dependencies] 75 | rust_osx-64 = { version = "*", channel = "conda-forge/label/rust_dev" } 76 | 77 | [tool.pixi.target.win-64.dependencies] 78 | rust_win-64 = { version = "*", channel = "conda-forge/label/rust_dev" } 79 | 80 | 
[tool.pixi.target.osx-arm64.dependencies] 81 | rust_osx-arm64 = { version = "*", channel = "conda-forge/label/rust_dev" } 82 | 83 | [tool.pixi.pypi-dependencies] 84 | avdeepfake1m = { path = ".", editable = true } 85 | 86 | [tool.pixi.build-dependencies] 87 | maturin = ">=1.7,<2.0" 88 | setuptools = "*" 89 | pip = "*" 90 | 91 | [tool.pixi.tasks] 92 | develop = "maturin develop" 93 | build = "maturin build --release --find-interpreter" 94 | -------------------------------------------------------------------------------- /python/avdeepfake1m/__init__.py: -------------------------------------------------------------------------------- 1 | from . import evaluation 2 | from . import utils 3 | from . import loader 4 | 5 | __all__ = ["evaluation", "utils", "loader"] 6 | -------------------------------------------------------------------------------- /python/avdeepfake1m/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .auc import auc 2 | from .._evaluation import ap_1d, ar_1d, ap_ar_1d 3 | 4 | __all__ = ["auc", "ap_1d", "ar_1d", "ap_ar_1d"] 5 | -------------------------------------------------------------------------------- /python/avdeepfake1m/evaluation/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | 3 | 4 | def ap_1d(proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float, 5 | iou_thresholds: List[float]) -> Dict[float, float]: 6 | pass 7 | 8 | 9 | def ar_1d(proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float, n_proposals: List[int], 10 | iou_thresholds: List[float]) -> Dict[int, float]: 11 | pass 12 | 13 | 14 | def ap_ar_1d( 15 | proposals_path: str, labels_path: str, file_key: str, value_key: str, fps: float, 16 | ap_iou_thresholds: List[float], ar_n_proposals: List[int], ar_iou_thresholds: List[float] 17 | ) -> Dict[str, Dict[Union[float, int], float]]: 18 | pass 19 | 20 | 21 | def auc(prediction_file: str, reference_path: str, file_key: str, value_key: str) -> float: 22 | pass 23 | -------------------------------------------------------------------------------- /python/avdeepfake1m/evaluation/auc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torchmetrics import AUROC 4 | 5 | from ..utils import read_json 6 | 7 | 8 | def auc(prediction_file: str, reference_path: str, file_key: str, value_key: str) -> float: 9 | 10 | prediction = pd.read_csv(prediction_file, header=None, sep=";") 11 | # convert to dict 12 | prediction_dict = {} 13 | for i in range(len(prediction)): 14 | prediction_dict[prediction.iloc[i, 0]] = prediction.iloc[i, 1] 15 | 16 | gt = read_json(reference_path) 17 | 18 | # make it as list 19 | truth = [] 20 | prediction = [] 21 | for gt_item in gt: 22 | key = gt_item[file_key] 23 | truth.append(int(len(gt_item[value_key]) > 0)) 24 | prediction.append(prediction_dict[key]) 25 | 26 | # to tensor for torchmetrics 27 | truth = torch.tensor(truth) 28 | prediction = torch.tensor(prediction) 29 | 30 | # compute auc 31 | auroc = AUROC(task="binary") 32 | return auroc(prediction, truth).item() 33 | -------------------------------------------------------------------------------- /python/avdeepfake1m/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 
import torchvision 8 | from einops import rearrange 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | 12 | 13 | def read_json(path: str, object_hook=None): 14 | with open(path, 'r') as f: 15 | return json.load(f, object_hook=object_hook) 16 | 17 | 18 | def read_video(path: str): 19 | video, audio, info = torchvision.io.read_video(path, pts_unit="sec") 20 | video = video.permute(0, 3, 1, 2) / 255 21 | audio = audio.permute(1, 0) 22 | if audio.shape[0] == 0: 23 | audio = torch.zeros(1, 1) 24 | return video, audio, info 25 | 26 | 27 | def read_video_fast(path: str): 28 | cap = cv2.VideoCapture(path) 29 | frames = [] 30 | while True: 31 | ret, frame = cap.read() 32 | if not ret: 33 | break 34 | else: 35 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 36 | cap.release() 37 | video = np.stack(frames, axis=0) 38 | video = rearrange(video, 'T H W C -> T C H W') 39 | return torch.from_numpy(video) / 255 40 | 41 | 42 | def resize_video(tensor: Tensor, size: Tuple[int, int], resize_method: str = "bicubic") -> Tensor: 43 | return F.interpolate(tensor, size=size, mode=resize_method) 44 | 45 | 46 | def iou_with_anchors(anchors_min, anchors_max, box_min, box_max): 47 | """Compute jaccard score between a box and the anchors.""" 48 | 49 | len_anchors = anchors_max - anchors_min 50 | int_xmin = np.maximum(anchors_min, box_min) 51 | int_xmax = np.minimum(anchors_max, box_max) 52 | inter_len = np.maximum(int_xmax - int_xmin, 0.) 53 | union_len = len_anchors - inter_len + box_max - box_min 54 | iou = inter_len / (union_len + 1e-8) 55 | return iou 56 | 57 | 58 | def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max): 59 | # Calculate the overlap proportion (intersection over anchor length) between the anchors and a box, 60 | # used as the supervision signal; the anchor length is 0.01. 61 | len_anchors = anchors_max - anchors_min 62 | int_xmin = np.maximum(anchors_min, box_min) 63 | int_xmax = np.minimum(anchors_max, box_max) 64 | inter_len = np.maximum(int_xmax - int_xmin, 0.) 65 | scores = np.divide(inter_len, len_anchors + 1e-8) 66 | return scores 67 | 68 | 69 | def iou_1d(proposal, target) -> Tensor: 70 | """ 71 | Calculate 1D IoU between M proposals and N targets. 72 | 73 | Args: 74 | proposal (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The predicted array with [M, 2]. First column is 75 | beginning, second column is end. 76 | target (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The label array with [N, 2]. First column is 77 | beginning, second column is end. 78 | 79 | Returns: 80 | :class:`~torch.Tensor`: The iou result with [M, N]. 81 | """ 82 | if type(proposal) is np.ndarray: 83 | proposal = torch.from_numpy(proposal) 84 | 85 | if type(target) is np.ndarray: 86 | target = torch.from_numpy(target) 87 | 88 | proposal_begin = proposal[:, 0].unsqueeze(0).T 89 | proposal_end = proposal[:, 1].unsqueeze(0).T 90 | target_begin = target[:, 0] 91 | target_end = target[:, 1] 92 | 93 | inner_begin = torch.maximum(proposal_begin, target_begin) 94 | inner_end = torch.minimum(proposal_end, target_end) 95 | outer_begin = torch.minimum(proposal_begin, target_begin) 96 | outer_end = torch.maximum(proposal_end, target_end) 97 | 98 | inter = torch.clamp(inner_end - inner_begin, min=0.)
99 | union = outer_end - outer_begin 100 | return inter / union 101 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(iter_map_windows)] 2 | 3 | use pyo3::prelude::*; 4 | 5 | pub mod loc_1d; 6 | 7 | #[pymodule] 8 | fn _evaluation(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { 9 | m.add("__version__", env!("CARGO_PKG_VERSION"))?; 10 | m.add_function(wrap_pyfunction!(loc_1d::ap_1d, m)?)?; 11 | m.add_function(wrap_pyfunction!(loc_1d::ar_1d, m)?)?; 12 | m.add_function(wrap_pyfunction!(loc_1d::ap_ar_1d, m)?)?; 13 | 14 | Ok(()) 15 | } 16 | -------------------------------------------------------------------------------- /src/loc_1d.rs: -------------------------------------------------------------------------------- 1 | extern crate serde_json; 2 | extern crate simd_json; 3 | 4 | use std::collections::HashMap; 5 | use std::fs; 6 | 7 | use ndarray::{concatenate, OwnedRepr, s, stack, Zip}; 8 | use ndarray::prelude::*; 9 | use pyo3::prelude::*; 10 | use pyo3::types::{IntoPyDict, PyDict}; 11 | use rayon::prelude::*; 12 | use serde::{Deserialize, Serialize}; 13 | use serde_json::{Map, Value}; 14 | 15 | #[derive(Serialize, Deserialize, Debug)] 16 | struct Metadata { 17 | file: String, 18 | segments: Vec<Vec<f32>>, 19 | } 20 | 21 | fn convert_metadata_info_to_metadata( 22 | metadata_info: Map<String, Value>, 23 | file_key: &str, 24 | value_key: &str, 25 | ) -> Metadata { 26 | Metadata { 27 | file: metadata_info.get(file_key).unwrap().as_str().unwrap().to_string(), 28 | segments: metadata_info.get(value_key).unwrap().as_array().unwrap().iter().map(|x| { 29 | x.as_array().unwrap().iter().map(|x| x.as_f64().unwrap() as f32).collect() 30 | }).collect(), 31 | } 32 | } 33 | 34 | fn iou_1d(proposal: Array2<f32>, target: &Array2<f32>) -> Array2<f32> { 35 | let m = proposal.nrows(); 36 | let n = target.nrows(); 37 | 38 | let mut ious = Array2::<f32>::zeros((m, n)); 39 | 40 | for i in 0..m { 41 | for j in 0..n { 42 | let proposal_start = proposal[[i, 0]]; 43 | let proposal_end = proposal[[i, 1]]; 44 | let target_start = target[[j, 0]]; 45 | let target_end = target[[j, 1]]; 46 | 47 | let inner_begin = proposal_start.max(target_start); 48 | let inner_end = proposal_end.min(target_end); 49 | let outer_begin = proposal_start.min(target_start); 50 | let outer_end = proposal_end.max(target_end); 51 | 52 | let intersection = (inner_end - inner_begin).max(0.0); 53 | let union = outer_end - outer_begin; 54 | ious[[i, j]] = intersection / union; 55 | } 56 | } 57 | 58 | ious 59 | } 60 | 61 | fn calc_ap_curve(is_tp: Array1<bool>, n_labels: f32) -> Array2<f32> { 62 | let acc_tp = Array1::from_vec( 63 | is_tp.iter().scan(0.0, |state, &x| { 64 | if x { *state += 1.0 } 65 | Some(*state) 66 | }).collect() 67 | ); 68 | 69 | let precision: Array1<f32> = acc_tp.iter().enumerate().map(|(i, &x)| x / (i as f32 + 1.0)).collect(); 70 | let recall: Array1<f32> = acc_tp / n_labels; 71 | let binding = stack!(Axis(0), recall.view(), precision.view()); 72 | let binding = binding.t(); 73 | 74 | concatenate![ 75 | Axis(0), 76 | arr2(&[[1., 0.]]).view(), 77 | binding.slice(s![..;-1, ..]) 78 | ] 79 | } 80 | 81 | fn calculate_ap(curve: &Array2<f32>) -> f32 { 82 | let x = curve.column(0).to_owned(); 83 | let y = curve.column(1).to_owned(); 84 | 85 | let y_max = Array1::from(y.iter().scan(None, |state, &x| { 86 | if state.is_none() || x > state.unwrap() { 87 | *state = Some(x); 88 | } 89 | 90 | *state 91 | }).collect::<Vec<f32>>()); 92 | 93 | let x_diff: Array1<f32> = x 94 |
.into_iter() 95 | .map_windows(|[x, y]| (y - x).abs()) 96 | .collect(); 97 | 98 | (x_diff * y_max.slice(s![..-1])).sum() 99 | } 100 | 101 | fn get_ap_values( 102 | iou_threshold: f32, 103 | proposals: &Array2<f32>, 104 | labels: &Array2<f32>, 105 | fps: f32, 106 | ) -> (Array1<f32>, Array1<bool>) { 107 | let n_labels = labels.len_of(Axis(0)); 108 | let n_proposals = proposals.len_of(Axis(0)); 109 | let local_proposals = if proposals.shape() != [0] { 110 | proposals.clone() 111 | } else { 112 | proposals.clone() 113 | .into_shape((0, 3)) 114 | .unwrap() 115 | }; 116 | 117 | let ious = if n_labels > 0 { 118 | iou_1d(local_proposals.slice(s![.., 1..]).mapv(|x| x / fps), labels) 119 | } else { 120 | Array::zeros((n_proposals, 0)) 121 | }; 122 | 123 | let confidence = local_proposals.column(0).to_owned(); 124 | let potential_tp = ious.mapv(|x| x > iou_threshold); 125 | 126 | let mut is_tp = Array1::from_elem((n_proposals, ), false); 127 | 128 | let mut tp_indexes = Vec::new(); 129 | for i in 0..n_labels { 130 | let potential_tp_col = potential_tp.column(i); 131 | let potential_tp_index = potential_tp_col.iter().enumerate().filter(|(_, &x)| x).map(|(j, _)| j); 132 | for j in potential_tp_index { 133 | if !tp_indexes.contains(&j) { 134 | tp_indexes.push(j); 135 | break; 136 | } 137 | } 138 | } 139 | 140 | if !tp_indexes.is_empty() { 141 | for &j in &tp_indexes { 142 | is_tp[j] = true; 143 | } 144 | } 145 | 146 | (confidence, is_tp) 147 | } 148 | 149 | fn calc_ap_scores( 150 | iou_thresholds: &Vec<f32>, 151 | metadata: &Vec<Metadata>, 152 | proposals_map: &Proposals, 153 | fps: f32, 154 | ) -> Vec<(f32, f32)> { 155 | iou_thresholds.par_iter().map(|iou| { 156 | let (values, labels): (Vec<_>, Vec<isize>) = metadata 157 | .par_iter() 158 | .map(|meta| { 159 | let proposals = &proposals_map.content[&meta.file]; 160 | let rows = meta.segments.len(); 161 | let x: Vec<f32> = meta.segments.iter().flatten().copied().collect(); 162 | let labels = Array2::from_shape_vec((rows, 2), x).unwrap().to_owned(); 163 | let meta_value = get_ap_values(*iou, &proposals.row, &labels, fps); 164 | 165 | (meta_value, labels.len_of(Axis(0)) as isize) 166 | }) 167 | .unzip(); 168 | 169 | let n_labels = labels.iter().sum::<isize>() as f32; 170 | 171 | let (r, n): (Vec<_>, Vec<_>) = values.into_iter().unzip(); 172 | let confidence = concatenate( 173 | Axis(0), 174 | &r.iter() 175 | .map(|x| x.view()) 176 | .collect::<Vec<_>>(), 177 | ).unwrap(); 178 | let is_tp = concatenate( 179 | Axis(0), 180 | &n.iter() 181 | .map(|x| x.view()) 182 | .collect::<Vec<_>>(), 183 | ).unwrap(); 184 | 185 | let mut indices: Vec<usize> = (0..confidence.len()).collect(); 186 | indices.sort_by(|&a, &b| confidence[b].partial_cmp(&confidence[a]).unwrap()); 187 | let is_tp = is_tp.select(Axis(0), &indices); 188 | let curve = calc_ap_curve(is_tp, n_labels); 189 | let ap = calculate_ap(&curve); 190 | 191 | (*iou, ap) 192 | }).collect::<Vec<_>>() 193 | } 194 | 195 | 196 | fn cummax_2d(array: &Array2<f32>) -> Array2<f32> { 197 | let mut result = array.clone(); 198 | 199 | for mut column in result.axis_iter_mut(Axis(1)) { 200 | let mut cummax = column[0]; 201 | 202 | for row in column.iter_mut().skip(1) { 203 | cummax = cummax.max(*row); 204 | *row = cummax; 205 | } 206 | } 207 | 208 | result 209 | } 210 | 211 | fn calc_ar_values( 212 | n_proposals: &Vec<usize>, 213 | iou_thresholds: &Vec<f32>, 214 | proposals: &Array2<f32>, 215 | labels: &Array2<f32>, 216 | fps: f32, 217 | ) -> ArrayBase<OwnedRepr<usize>, Ix3> { 218 | let max_proposals = *n_proposals.iter().max().unwrap(); 219 | let max_proposals = max_proposals.min(proposals.nrows()); 220 | 221 | let mut proposals =
proposals.slice(s![..max_proposals, ..]).to_owned(); 222 | if proposals.is_empty() { 223 | proposals = Array2::zeros((0, 3)).into(); 224 | } 225 | 226 | let n_proposals_clamped = n_proposals.iter().map(|&n| n.min(proposals.nrows())).collect::<Vec<_>>(); 227 | let n_labels = labels.nrows(); 228 | 229 | let ious = if n_labels > 0 { 230 | iou_1d(proposals.slice(s![.., 1..]).mapv(|x| x / fps), labels) 231 | } else { 232 | Array::zeros((max_proposals, 0)) // TODO: maybe short-circuit 233 | }; 234 | 235 | let mut values = Array3::zeros((iou_thresholds.len(), n_proposals_clamped.len(), 2)); 236 | if !proposals.is_empty() { 237 | let iou_max = cummax_2d(&ious); // (n_iou, n_labels) 238 | for (threshold_idx, &threshold) in iou_thresholds.iter().enumerate() { 239 | for (n_proposals_idx, &n_proposal) in n_proposals_clamped.iter().enumerate() { 240 | let tp = iou_max.row(n_proposal - 1).iter().filter(|&&iou| iou > threshold).count(); 241 | values[[threshold_idx, n_proposals_idx, 0]] = tp; 242 | values[[threshold_idx, n_proposals_idx, 1]] = n_labels - tp; 243 | } 244 | } 245 | } 246 | values 247 | } 248 | 249 | fn calc_ar_scores( 250 | n_proposals: &Vec<usize>, 251 | iou_thresholds: &Vec<f32>, 252 | metadata: &Vec<Metadata>, 253 | proposals_map: &Proposals, 254 | fps: f32, 255 | ) -> Vec<(usize, f32)> { 256 | let values = metadata.par_iter().map(|meta| { 257 | let proposals = &proposals_map.content[&meta.file]; 258 | 259 | let rows = meta.segments.len(); 260 | let x: Vec<f32> = meta.segments.iter().flatten().copied().collect(); 261 | let labels = Array2::from_shape_vec((rows, 2), x).unwrap().to_owned(); 262 | 263 | calc_ar_values(&n_proposals, iou_thresholds, &proposals.row, &labels, fps) 264 | }).collect::<Vec<_>>(); 265 | 266 | let values = stack( 267 | Axis(0), 268 | &values 269 | .iter() 270 | .map(|x| x.view()) 271 | .collect::<Vec<_>>(), 272 | ).unwrap(); 273 | 274 | let values_sum = values.sum_axis(Axis(0)); 275 | let tp = values_sum.slice(s![.., .., 0]); 276 | let f_n = values_sum.slice(s![.., .., 1]); 277 | 278 | let recall = Zip::from(&tp).and(&f_n).map_collect(|&x, &y| { 279 | let div = x as f32 + y as f32; 280 | if div == 0. { 281 | 0.
282 | } else { 283 | x as f32 / div 284 | } 285 | }); 286 | 287 | n_proposals.iter().enumerate().map(|(ix, &prop)| { 288 | (prop, recall.column(ix).mean().unwrap()) 289 | }).collect::<Vec<_>>() 290 | } 291 | 292 | 293 | #[derive(Deserialize, Debug)] 294 | #[serde(transparent)] 295 | struct ProposalRow { 296 | #[serde(with = "serde_ndim")] 297 | pub row: Array2<f32>, 298 | } 299 | 300 | #[derive(Deserialize, Debug)] 301 | #[serde(transparent)] 302 | struct Proposals { 303 | pub content: HashMap<String, ProposalRow>, 304 | } 305 | 306 | fn load_json(proposals_path: &str, labels_path: &str, file_key: &str, value_key: &str) -> (Vec<Metadata>, Proposals) { 307 | let mut proposals_raw = fs::read_to_string(proposals_path).expect("Unable to read proposal file"); 308 | let mut labels_raw = fs::read_to_string(labels_path).expect("Unable to read labels file"); 309 | let labels_infos: Vec<Map<String, Value>> = unsafe { simd_json::serde::from_str(labels_raw.as_mut_str()) }.unwrap(); 310 | let labels_infos: Vec<_> = labels_infos.into_par_iter().map(|x| convert_metadata_info_to_metadata(x, file_key, value_key)).collect(); 311 | let proposals: Proposals = unsafe { simd_json::serde::from_str(proposals_raw.as_mut_str()) }.unwrap(); 312 | (labels_infos, proposals) 313 | } 314 | 315 | #[pyfunction] 316 | pub fn ap_1d<'py>( 317 | proposals_path: &str, 318 | labels_path: &str, 319 | file_key: &str, value_key: &str, 320 | fps: f32, 321 | iou_thresholds: Vec<f32>, 322 | py: Python<'py>, 323 | ) -> Bound<'py, PyDict> { 324 | let (labels_infos, proposals) = load_json(proposals_path, labels_path, file_key, value_key); 325 | 326 | let ap_score = calc_ap_scores( 327 | &iou_thresholds, 328 | &labels_infos, 329 | &proposals, 330 | fps, 331 | ); 332 | 333 | ap_score.into_py_dict_bound(py) 334 | } 335 | 336 | #[pyfunction] 337 | pub fn ar_1d<'py>( 338 | proposals_path: &str, 339 | labels_path: &str, 340 | file_key: &str, value_key: &str, 341 | fps: f32, 342 | n_proposals: Vec<usize>, 343 | iou_thresholds: Vec<f32>, 344 | py: Python<'py>, 345 | ) -> Bound<'py, PyDict> { 346 | let (labels_infos, proposals) = load_json(proposals_path, labels_path, file_key, value_key); 347 | 348 | let ar_score = calc_ar_scores( 349 | &n_proposals, 350 | &iou_thresholds, 351 | &labels_infos, 352 | &proposals, 353 | fps, 354 | ); 355 | 356 | ar_score.into_py_dict_bound(py) 357 | } 358 | 359 | #[pyfunction] 360 | pub fn ap_ar_1d<'py>( 361 | proposals_path: &str, 362 | labels_path: &str, 363 | file_key: &str, value_key: &str, 364 | fps: f32, 365 | ap_iou_thresholds: Vec<f32>, 366 | ar_n_proposals: Vec<usize>, 367 | ar_iou_thresholds: Vec<f32>, 368 | py: Python<'py>, 369 | ) -> Bound<'py, PyDict> { 370 | let (labels_infos, proposals) = load_json(proposals_path, labels_path, file_key, value_key); 371 | 372 | let ap_score = calc_ap_scores( 373 | &ap_iou_thresholds, 374 | &labels_infos, 375 | &proposals, 376 | fps, 377 | ); 378 | 379 | let ar_score = calc_ar_scores( 380 | &ar_n_proposals, 381 | &ar_iou_thresholds, 382 | &labels_infos, 383 | &proposals, 384 | fps, 385 | ); 386 | 387 | let ap_dict = ap_score.into_py_dict_bound(py); 388 | let ar_dict = ar_score.into_py_dict_bound(py); 389 | 390 | // {"ap": ap_dict, "ar": ar_dict} 391 | let dict = PyDict::new_bound(py); 392 | dict.set_item("ap", ap_dict).unwrap(); 393 | dict.set_item("ar", ar_dict).unwrap(); 394 | dict 395 | } 396 | --------------------------------------------------------------------------------
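Note on input formats: the JSON layouts consumed by the Rust evaluator in src/loc_1d.rs can be inferred from the `Metadata` and `Proposals` structs and from `get_ap_values`. The labels file is a list of objects, each carrying the video file name and its fake segments in seconds; the proposals file maps each file name to rows of `[score, start, end]`, where the start/end columns are divided by `fps` before IoU is computed (i.e. they are frame indices). The snippet below is a minimal sketch that writes such a pair of files; the file names, key names, and values are hypothetical and for illustration only.

import json

# Hypothetical labels: one object per video, fake segments given in seconds.
labels = [
    {"file": "example_0.mp4", "fake_segments": [[1.2, 2.0]]},
    {"file": "example_1.mp4", "fake_segments": []},  # a real video has no segments
]

# Hypothetical proposals: file name -> rows of [score, start_frame, end_frame].
# With fps=25, the first row of example_0 spans 1.2 s to 2.0 s after division by fps.
proposals = {
    "example_0.mp4": [[0.9, 30.0, 50.0], [0.4, 250.0, 300.0]],
    "example_1.mp4": [[0.1, 0.0, 25.0]],
}

with open("labels.json", "w") as f:
    json.dump(labels, f)
with open("proposals.json", "w") as f:
    json.dump(proposals, f)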
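Given files in those formats, the bindings re-exported by `avdeepfake1m.evaluation` (see the `__init__.pyi` stubs above) can be driven from Python as sketched below. The key names, fps, thresholds, and file paths are assumptions for illustration; the prediction file passed to `auc` follows the `file;score` lines written by `examples/xception/infer.py`.

from avdeepfake1m.evaluation import auc, ap_ar_1d

# Temporal localization: AP over IoU thresholds and AR over proposal budgets.
scores = ap_ar_1d(
    "proposals.json", "labels.json",
    "file", "fake_segments",  # file_key and value_key (assumed metadata keys)
    25.0,                     # fps used to convert proposal frame indices to seconds
    [0.5, 0.75, 0.95],        # AP IoU thresholds
    [5, 10, 20, 50, 100],     # AR proposal counts
    [0.5, 0.75, 0.95],        # AR IoU thresholds
)
print(scores["ap"])  # dict: IoU threshold -> AP
print(scores["ar"])  # dict: proposal count -> AR

# Video-level classification: AUROC over "file;score" predictions.
print(auc("output/xception_test.txt", "test_metadata.json", "file", "fake_segments"))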