├── lada
├── cli
│ └── __init__.py
├── models
│ ├── __init__.py
│ ├── yolo
│ │ ├── __init__.py
│ │ └── yolo.py
│ ├── bpjdet
│ │ ├── __init__.py
│ │ ├── models
│ │ │ └── __init__.py
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ ├── autoanchor.py
│ │ │ ├── datasets.py
│ │ │ ├── metrics.py
│ │ │ └── augmentations.py
│ │ └── data
│ │ │ └── JointBP_CrowdHuman_head.py
│ ├── centerface
│ │ └── __init__.py
│ ├── deepmosaics
│ │ ├── __init__.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── loadmodel.py
│ │ ├── util
│ │ │ ├── __init__.py
│ │ │ ├── image_processing.py
│ │ │ └── data.py
│ │ └── inference.py
│ ├── dover
│ │ ├── __init__.py
│ │ ├── requirements.txt
│ │ ├── datasets
│ │ │ └── __init__.py
│ │ ├── models
│ │ │ └── __init__.py
│ │ └── dover.yml
│ └── basicvsrpp
│ │ ├── __init__.py
│ │ ├── deformconv.py
│ │ ├── inference.py
│ │ └── mmagic
│ │ │ ├── __init__.py
│ │ │ ├── typing.py
│ │ │ ├── logger.py
│ │ │ ├── log_processor.py
│ │ │ ├── flow_warp.py
│ │ │ ├── setup_env.py
│ │ │ ├── loop_utils.py
│ │ │ ├── unet_disc.py
│ │ │ ├── iter_time_hook.py
│ │ │ └── model_utils.py
├── gui
│ ├── config
│ │ ├── __init__.py
│ │ ├── no_gpu_banner.ui
│ │ └── no_gpu_banner.py
│ ├── export
│ │ ├── __init__.py
│ │ ├── spinner_button.ui
│ │ ├── spinner_button.py
│ │ ├── export_multiple_files_page.ui
│ │ └── shutdown_manager.py
│ ├── preview
│ │ ├── __init__.py
│ │ ├── timeline.ui
│ │ ├── headerbar_files_drop_down.ui
│ │ ├── seek_preview_popover.py
│ │ ├── seek_preview_popover.ui
│ │ └── headerbar_files_drop_down.py
│ ├── fileselection
│ │ ├── __init__.py
│ │ ├── file_selection_view.py
│ │ └── file_selection_view.ui
│ ├── icons
│ │ ├── cafe-symbolic.svg.license
│ │ ├── edit-symbolic.svg.license
│ │ ├── color-scheme-dark.svg.license
│ │ ├── color-scheme-light.svg.license
│ │ ├── color-scheme-system.svg.license
│ │ ├── playback-symbolic.svg.license
│ │ ├── plus-large-symbolic.svg.license
│ │ ├── speaker-0-symbolic.svg.license
│ │ ├── speaker-4-symbolic.svg.license
│ │ ├── cross-large-symbolic.svg.license
│ │ ├── external-link-symbolic.svg.license
│ │ ├── folder-open-symbolic.svg.license
│ │ ├── pause-large-symbolic.svg.license
│ │ ├── check-round-outline2-symbolic.svg.license
│ │ ├── exclamation-mark-symbolic.svg.license
│ │ ├── media-playback-pause-symbolic.svg.license
│ │ ├── media-playback-start-symbolic.svg.license
│ │ ├── sliders-horizontal-symbolic.svg.license
│ │ ├── arrow-pointing-away-from-line-right-symbolic.svg.license
│ │ ├── lada-logo-gray.png
│ │ ├── lada-logo-gray.png.license
│ │ ├── color-scheme-dark.svg
│ │ ├── color-scheme-light.svg
│ │ ├── plus-large-symbolic.svg
│ │ ├── README.md
│ │ ├── color-scheme-system.svg
│ │ ├── pause-large-symbolic.svg
│ │ ├── media-playback-pause-symbolic.svg
│ │ ├── exclamation-mark-symbolic.svg
│ │ ├── media-playback-start-symbolic.svg
│ │ ├── edit-symbolic.svg
│ │ ├── speaker-4-symbolic.svg
│ │ ├── arrow-pointing-away-from-line-right-symbolic.svg
│ │ ├── playback-symbolic.svg
│ │ ├── cross-large-symbolic.svg
│ │ ├── sliders-horizontal-symbolic.svg
│ │ ├── speaker-0-symbolic.svg
│ │ ├── cafe-symbolic.svg
│ │ ├── external-link-symbolic.svg
│ │ ├── check-round-outline2-symbolic.svg
│ │ └── folder-open-symbolic.svg
│ ├── resources.gresource
│ ├── style.css
│ ├── __init__.py
│ ├── resources.gresource.xml
│ ├── shortcuts.py
│ └── window.ui
├── datasetcreation
│ ├── __init__.py
│ └── detectors
│ │ ├── __init__.py
│ │ ├── nsfw_frame_detector.py
│ │ ├── head_detector.py
│ │ ├── face_detector.py
│ │ ├── mosaic_detector.py
│ │ └── watermark_detector.py
├── locale
│ └── README.md
├── utils
│ ├── random_utils.py
│ ├── os_utils.py
│ ├── encoding_presets.csv
│ ├── box_utils.py
│ ├── torch_letterbox.py
│ ├── audio_utils.py
│ ├── __init__.py
│ └── scene_utils.py
└── restorationpipeline
│ ├── deepmosaics_mosaic_restorer.py
│ ├── basicvsrpp_mosaic_restorer.py
│ └── __init__.py
├── .python-version
├── datasets
└── .gitignore
├── model_weights
├── 3rd_party
│ ├── .gitignore
│ ├── 640m.pt.license
│ ├── DOVER.pth.license
│ ├── spynet_20210409-c6c1bd09.pth.license
│ ├── ch_head_s_1536_e150_best_mMR.pt.license
│ ├── centerface.onnx.license
│ └── clean_youknow_video.pth.license
├── lada_nsfw_detection_model.pt.license
├── lada_mosaic_detection_model_v2.pt.license
├── lada_mosaic_detection_model_v3.pt.license
├── lada_mosaic_edge_detection_model.pth.license
├── lada_nsfw_detection_model_v1.1.pt.license
├── lada_nsfw_detection_model_v1.2.pt.license
├── lada_nsfw_detection_model_v1.3.pt.license
├── lada_watermark_detection_model.pt.license
├── lada_mosaic_detection_model_v3.1_fast.pt.license
├── lada_mosaic_restoration_model_bj_pov.pth.license
├── lada_watermark_detection_model_v1.1.pt.license
├── lada_watermark_detection_model_v1.2.pt.license
├── lada_watermark_detection_model_v1.3.pt.license
├── lada_mosaic_detection_model_v3.1_accurate.pt.license
├── lada_mosaic_restoration_model_generic_v1.1.pth.license
├── lada_mosaic_restoration_model_generic_v1.2.pth.license
├── checksums_md5.txt
└── checksums_sha256.txt
├── translations
├── release_ready_translations.txt
├── update_po.sh
├── update_pot.sh
├── extract_csv_strings.py
├── compile_po.sh
├── compile_po.ps1
└── README.md
├── assets
├── screenshot_cli_1.png
├── screenshot_gui_1_dark.png
├── screenshot_gui_2_dark.png
├── screenshot_view_yolo.png
├── screenshot_gui_1_light.png
├── screenshot_gui_2_light.png
├── screenshot_labelme_nsfw.png
└── screenshot_labelme_sfw.png
├── packaging
├── flatpak
│ ├── share
│ │ ├── io.github.ladaapp.lada.png
│ │ └── io.github.ladaapp.lada.desktop
│ ├── README.md
│ └── io.github.ladaapp.lada.yaml
├── windows
│ ├── pyinstaller_runtime_hook_lada.py
│ └── README.md
├── docker
│ ├── README.md
│ └── Dockerfile
└── README.md
├── configs
├── yolo
│ ├── nsfw_detection_dataset_config.yaml
│ ├── watermark_detection_dataset_config.yaml
│ └── mosaic_detection_dataset_config.yaml
└── basicvsrpp
│ ├── _base_
│ │ └── default_runtime.py
│ └── mosaic_restoration_generic_stage1.py
├── MANIFEST.in
├── .mailmap
├── .dockerignore
├── .gitignore
├── scripts
├── training
│ ├── train-nsfw-detection-yolo.py
│ ├── train-watermark-detection-yolo.py
│ ├── train-mosaic-detection-yolo.py
│ ├── export-weights-basicvsrpp-stage2-for-inference.py
│ ├── convert-weights-basicvsrpp-stage1-to-stage2.py
│ ├── train-bj-classifier.py
│ └── train-mosaic-restoration-basicvsrpp.py
├── evaluation
│ ├── validate-yolo.py
│ ├── validate-deepmosaics.py
│ └── validate-basicvsrpp.py
└── dataset_creation
│ └── convert-dataset-labelme-to-yolo.py
├── patches
├── increase_mms_time_limit.patch
├── gvsbuild_ffmpeg.patch
├── fix_loading_mmengine_weights_on_torch26_and_higher.diff
├── adjust_mmengine_resume_dataloader.patch
├── remove_use_of_torch_dist_in_mmengine.patch
└── adw_spinner_to_gtk_spinner.patch
├── LICENSES
└── MIT.txt
└── pyproject.toml
/lada/cli/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 |
--------------------------------------------------------------------------------
/lada/gui/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/gui/export/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/gui/preview/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/yolo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/datasetcreation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/lada/gui/fileselection/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/centerface/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/deepmosaics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/deepmosaics/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/models/deepmosaics/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model_weights/3rd_party/.gitignore:
--------------------------------------------------------------------------------
1 | !.gitignore
--------------------------------------------------------------------------------
/translations/release_ready_translations.txt:
--------------------------------------------------------------------------------
1 | zh_CN zh_TW nl ja es
--------------------------------------------------------------------------------
/lada/gui/icons/cafe-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/edit-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-dark.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-light.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-system.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/playback-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/plus-large-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/speaker-0-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/speaker-4-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/models/dover/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .models import *
3 |
--------------------------------------------------------------------------------
/lada/gui/icons/cross-large-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/external-link-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/folder-open-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/pause-large-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/check-round-outline2-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/exclamation-mark-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/media-playback-pause-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/media-playback-start-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/icons/sliders-horizontal-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/assets/screenshot_cli_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_cli_1.png
--------------------------------------------------------------------------------
/lada/gui/icons/arrow-pointing-away-from-line-right-symbolic.svg.license:
--------------------------------------------------------------------------------
1 | SPDX-License-Identifier: CC0-1.0
2 |
--------------------------------------------------------------------------------
/lada/gui/resources.gresource:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/lada/gui/resources.gresource
--------------------------------------------------------------------------------
/lada/models/dover/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | opencv-python
4 | numpy
5 | timm
6 | einops
7 |
--------------------------------------------------------------------------------
/assets/screenshot_gui_1_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_gui_1_dark.png
--------------------------------------------------------------------------------
/assets/screenshot_gui_2_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_gui_2_dark.png
--------------------------------------------------------------------------------
/assets/screenshot_view_yolo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_view_yolo.png
--------------------------------------------------------------------------------
/assets/screenshot_gui_1_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_gui_1_light.png
--------------------------------------------------------------------------------
/assets/screenshot_gui_2_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_gui_2_light.png
--------------------------------------------------------------------------------
/assets/screenshot_labelme_nsfw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_labelme_nsfw.png
--------------------------------------------------------------------------------
/assets/screenshot_labelme_sfw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/assets/screenshot_labelme_sfw.png
--------------------------------------------------------------------------------
/lada/gui/icons/lada-logo-gray.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/lada/gui/icons/lada-logo-gray.png
--------------------------------------------------------------------------------
/model_weights/lada_nsfw_detection_model.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_detection_model_v2.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_detection_model_v3.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_edge_detection_model.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_nsfw_detection_model_v1.1.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_nsfw_detection_model_v1.2.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_nsfw_detection_model_v1.3.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_watermark_detection_model.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_detection_model_v3.1_fast.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_restoration_model_bj_pov.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_watermark_detection_model_v1.1.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_watermark_detection_model_v1.2.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_watermark_detection_model_v1.3.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_detection_model_v3.1_accurate.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_restoration_model_generic_v1.1.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/model_weights/lada_mosaic_restoration_model_generic_v1.2.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-License-Identifier: AGPL-3.0
3 |
--------------------------------------------------------------------------------
/packaging/flatpak/share/io.github.ladaapp.lada.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ladaapp/lada/HEAD/packaging/flatpak/share/io.github.ladaapp.lada.png
--------------------------------------------------------------------------------
/configs/yolo/nsfw_detection_dataset_config.yaml:
--------------------------------------------------------------------------------
1 | path: nsfw_detection
2 | train: train/images
3 | val: val/images
4 |
5 | names:
6 | 0: 'nsfw'
7 |
--------------------------------------------------------------------------------
/configs/yolo/watermark_detection_dataset_config.yaml:
--------------------------------------------------------------------------------
1 | path: watermark_detection
2 | train: train/images
3 | val: val/images
4 |
5 | names:
6 | 0: logo
7 | 1: text
8 | nc: 2
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # As lada/locale is not a package (it has no __init__.py), package_data in setup.py will not pick it up, so we have to include these files explicitly:
2 | include lada/locale/*/LC_MESSAGES/*.mo
--------------------------------------------------------------------------------
/lada/models/yolo/yolo.py:
--------------------------------------------------------------------------------
1 | from ultralytics.models import YOLO
2 |
3 | class Yolo(YOLO):
4 |
5 | def __init__(self, *args, **kwargs):
6 | super().__init__(*args, **kwargs)
--------------------------------------------------------------------------------
/configs/yolo/mosaic_detection_dataset_config.yaml:
--------------------------------------------------------------------------------
1 | path: mosaic_detection
2 | train: train/images
3 | val: val/images
4 |
5 | names:
6 | 0: 'mosaic_nsfw'
7 | 1: 'mosaic_sfw_head'
8 |
--------------------------------------------------------------------------------
/lada/locale/README.md:
--------------------------------------------------------------------------------
1 | This is the locale directory used by gettext to pick up compiled translations (.mo files).
2 |
3 | Use the scripts in the `translations` directory to compile the files.
--------------------------------------------------------------------------------
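
The compile scripts themselves (`compile_po.sh`, `compile_po.ps1`) live in `translations/`. A minimal sketch of what such a compile step does, assuming the usual gettext layout and the `lada` text domain (paths and domain name are assumptions; the actual scripts may differ):

```shell
# Compile each translations/<lang>.po into the locale tree gettext reads.
for po in translations/*.po; do
    lang="$(basename "$po" .po)"
    mkdir -p "lada/locale/$lang/LC_MESSAGES"
    msgfmt "$po" -o "lada/locale/$lang/LC_MESSAGES/lada.mo"
done
```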
/lada/gui/icons/lada-logo-gray.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Lada Authors
2 | SPDX-FileCopyrightText: Twitter Emoji (Twemoji) Authors
3 | SPDX-License-Identifier: CC-BY-4.0 AND AGPL-3.0
4 |
--------------------------------------------------------------------------------
/model_weights/3rd_party/640m.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: NudeNet Authors
2 | SPDX-License-Identifier: MIT AND AGPL-3.0
3 | Source: https://github.com/notAI-tech/NudeNet/releases/download/v3.4-weights/640m.pt
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-dark.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-light.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.mailmap:
--------------------------------------------------------------------------------
1 | ladaapp lada
2 | ladaapp
3 | Codeberg Translate
--------------------------------------------------------------------------------
/model_weights/3rd_party/DOVER.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2022 NTU Visual Quality Assessment Group
2 | SPDX-License-Identifier: MIT
3 | Source: https://github.com/QualityAssessment/DOVER/releases/download/v0.1.0/DOVER.pth
--------------------------------------------------------------------------------
/model_weights/3rd_party/spynet_20210409-c6c1bd09.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: OpenMMLab
2 | SPDX-License-Identifier: Apache-2.0
3 | Source: https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth
--------------------------------------------------------------------------------
/model_weights/3rd_party/ch_head_s_1536_e150_best_mMR.pt.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: BPJDet Authors
2 | SPDX-License-Identifier: GPL-3.0
3 | Source: https://huggingface.co/HoyerChou/BPJDet/resolve/main/ch_head_s_1536_e150_best_mMR.pt?download=true
--------------------------------------------------------------------------------
/model_weights/3rd_party/centerface.onnx.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: deface Authors
2 | SPDX-FileCopyrightText: CenterFace Authors
3 | SPDX-License-Identifier: MIT
4 | Source: https://github.com/ORB-HD/deface/raw/refs/tags/v1.5.0/deface/centerface.onnx
--------------------------------------------------------------------------------
/model_weights/3rd_party/clean_youknow_video.pth.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: DeepMosaics Authors
2 | SPDX-License-Identifier: GPL-3.0
3 | Source: https://drive.usercontent.google.com/download?id=1ulct4RhRxQp1v5xwEmUH7xz7AK42Oqlw&export=download&confirm=t
--------------------------------------------------------------------------------
/lada/gui/icons/plus-large-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | venv_gtk/
3 | venv_gtk_release/
4 | venv_release_win/
5 | experiments/
6 | build/
7 | dist/
8 | .idea/
9 | .flatpak-builder/
10 | build_flatpak/
11 | build_gtk/
12 | build_gtk_release/
13 | model_weights/
14 | datasets/
15 | lada/locale/*/
16 | *.pt
17 | *.pth
--------------------------------------------------------------------------------
/packaging/flatpak/share/io.github.ladaapp.lada.desktop:
--------------------------------------------------------------------------------
1 | [Desktop Entry]
2 | Version=1.0
3 | Type=Application
4 |
5 | Name=Lada
6 | Comment=Remove and recover pixelated areas in adult videos
7 | Categories=AudioVideo;Adult;
8 |
9 | Icon=io.github.ladaapp.lada
10 | Exec=lada
11 | Terminal=false
--------------------------------------------------------------------------------
/packaging/windows/pyinstaller_runtime_hook_lada.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | base_path = sys._MEIPASS
5 |
6 | os.environ["LADA_MODEL_WEIGHTS_DIR"] = os.path.join(base_path, "model_weights")
7 | os.environ["PATH"] = os.path.join(base_path, "bin")
8 | os.environ["LOCALE_DIR"] = os.path.join(base_path, "lada", "locale")
--------------------------------------------------------------------------------
/lada/models/dover/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022 NTU Visual Quality Assessment Group
2 | # SPDX-License-Identifier: MIT AND AGPL-3.0
3 | # Code vendored from: https://github.com/VQAssessment/DOVER
4 |
5 | ## API for DOVER and its variants
6 | #from .basic_datasets import *
7 | from .dover_datasets import *
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | venv_gtk/
3 | venv_gtk_release/
4 | venv_release_win/
5 | experiments/
6 | build/
7 | dist/
8 | .idea/
9 | .flatpak-builder/
10 | build_flatpak/
11 | build_gtk/
12 | build_gtk_release/
13 | torch_compile_debug/
14 | __pycache__/
15 | *.egg-info/
16 | *.pth
17 | *.pt
18 | *.onnx
19 | *.pyc
20 | *.mo
21 | *.ep
--------------------------------------------------------------------------------
/lada/gui/icons/README.md:
--------------------------------------------------------------------------------
1 | After adding a new icon, you have to include it in `resources.gresource.xml` and compile a new version
2 | of `resources.gresource` with the command `glib-compile-resources resources.gresource.xml`.
3 |
4 | If possible, use GNOME icons. The easiest way to find them is https://flathub.org/en/apps/org.gnome.design.IconLibrary
--------------------------------------------------------------------------------
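
A hedged sketch of that workflow (the `<file>` path below is illustrative; the real entry depends on the layout of `resources.gresource.xml`):

```shell
# From lada/gui: register the icon in the resource manifest, e.g. a line like
#   <file>icons/my-new-icon-symbolic.svg</file>
# inside the existing <gresource> element (path assumed), then rebuild the bundle:
cd lada/gui
glib-compile-resources resources.gresource.xml
```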
/lada/gui/preview/timeline.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | def register_all_modules():
5 | from lada.models.basicvsrpp.mmagic import register_all_modules
6 | register_all_modules()
7 | from lada.models.basicvsrpp.basicvsrpp_gan import BasicVSRPlusPlusGanNet, BasicVSRPlusPlusGan
8 | from lada.models.basicvsrpp.mosaic_video_dataset import MosaicVideoDataset
--------------------------------------------------------------------------------
/scripts/training/train-nsfw-detection-yolo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from lada.utils.ultralytics_utils import set_default_settings
5 | from lada.models.yolo.yolo import Yolo
6 |
7 | set_default_settings()
8 |
9 | model = Yolo('yolo11m-seg.pt')
10 | model.train(data='configs/yolo/nsfw_detection_dataset_config.yaml', epochs=200, imgsz=640, name="train_nsfw_detection_yolo11m")
11 |
--------------------------------------------------------------------------------
/patches/increase_mms_time_limit.patch:
--------------------------------------------------------------------------------
1 | --- old/ultralytics/utils/nms.py 2025-06-19 00:00:00.0.0
2 | +++ new/ultralytics/utils/nms.py 2025-06-19 00:00:00.0.0
3 | @@ -20,7 +20,7 @@
4 | labels=(),
5 | max_det: int = 300,
6 | nc: int = 0, # number of classes (optional)
7 | - max_time_img: float = 0.05,
8 | + max_time_img: float = 0.3,
9 | max_nms: int = 30000,
10 | max_wh: int = 7680,
11 | rotated: bool = False,
12 |
--------------------------------------------------------------------------------
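
How the files in `patches/` are applied is not shown in this dump; one plausible way, assuming they target the copy of the library installed in the active virtualenv, is `patch -p1` against the site-packages root (the `old/`/`new/` prefixes in the diff match strip level 1):

```shell
# Locate site-packages via the installed ultralytics package, then apply the patch.
site_pkgs="$(python -c 'import os, ultralytics; print(os.path.dirname(os.path.dirname(ultralytics.__file__)))')"
patch -d "$site_pkgs" -p1 < patches/increase_mms_time_limit.patch
```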
/scripts/training/train-watermark-detection-yolo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from lada.utils.ultralytics_utils import set_default_settings
5 | from lada.models.yolo.yolo import Yolo
6 |
7 | set_default_settings()
8 |
9 | model = Yolo('yolo11s.pt')
10 | model.train(data='configs/yolo/watermark_detection_dataset_config.yaml', epochs=100, imgsz=512, single_cls=True, name="train_watermark_detection_yolo11s")
11 |
--------------------------------------------------------------------------------
/lada/gui/icons/color-scheme-system.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/patches/gvsbuild_ffmpeg.patch:
--------------------------------------------------------------------------------
1 | --- old/gvsbuild/patches/ffmpeg/build/build.sh 2025-11-12 13:00:0.0 +0000
2 | +++ new/gvsbuild/patches/ffmpeg/build/build.sh 2025-11-12 13:00:0.0 +0000
3 | @@ -41,6 +41,7 @@
4 | configure_cmd[idx++]="--disable-programs"
5 | configure_cmd[idx++]="--disable-avdevice"
6 | configure_cmd[idx++]="--disable-swresample"
7 | +configure_cmd[idx++]="--enable-decoder=mp2float"
8 |
9 | if [ "$build_type" = "debug" ]; then
10 | configure_cmd[idx++]="--enable-debug"
11 |
--------------------------------------------------------------------------------
/lada/gui/icons/pause-large-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/translations/update_po.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | # SPDX-FileCopyrightText: Lada Authors
4 | # SPDX-License-Identifier: AGPL-3.0
5 |
6 | translations_dir=$(dirname -- "$0")
7 | if [ "$(pwd)" != "$translations_dir" ] ; then
8 | cd "$translations_dir"
9 | fi
10 |
11 | find . -mindepth 1 -maxdepth 1 -type f -name "*.po" -printf '%f\n' | while read po_file ; do
12 | lang="${po_file%.po}"
13 | echo "Updating language $lang .po file"
14 | msgmerge --no-fuzzy-matching --no-wrap --previous --update "$po_file" lada.pot
15 | done
--------------------------------------------------------------------------------
/lada/gui/preview/headerbar_files_drop_down.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
15 |
--------------------------------------------------------------------------------
/lada/utils/random_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import random
5 |
6 | import numpy as np
7 |
8 | repeatable_rng_random = random.Random(42)
9 | repeatable_rng_numpy = np.random.RandomState(42)
10 |
11 | def get_rngs(repeatable) -> tuple[random, np.random]:
12 | if repeatable:
13 | rng_random = repeatable_rng_random
14 | rng_numpy = repeatable_rng_numpy
15 | else:
16 | rng_random = random
17 | rng_numpy = np.random
18 | return rng_random, rng_numpy
--------------------------------------------------------------------------------
/patches/fix_loading_mmengine_weights_on_torch26_and_higher.diff:
--------------------------------------------------------------------------------
1 | --- old/mmengine/runner/checkpoint.py 2025-06-19 00:00:00.0.0
2 | +++ new/mmengine/runner/checkpoint.py 2025-06-19 00:00:00.0.0
3 | @@ -344,7 +344,7 @@
4 | filename = osp.expanduser(filename)
5 | if not osp.isfile(filename):
6 | raise FileNotFoundError(f'{filename} can not be found.')
7 | - checkpoint = torch.load(filename, map_location=map_location)
8 | + checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
9 | return checkpoint
10 |
11 |
12 |
--------------------------------------------------------------------------------
/lada/gui/icons/media-playback-pause-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
8 |
--------------------------------------------------------------------------------
/lada/gui/icons/exclamation-mark-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/gui/icons/media-playback-start-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
5 |
--------------------------------------------------------------------------------
/lada/gui/icons/edit-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/packaging/docker/README.md:
--------------------------------------------------------------------------------
1 | ## Building the image
2 | ```shell
3 | docker build . -f packaging/docker/Dockerfile -t ladaapp/lada:<version> --build-arg SOURCE_DATE_EPOCH=0
4 | ```
5 |
6 | ## Publish image on Dockerhub
7 | ```shell
8 | docker login -u ladaapp
9 | docker push ladaapp/lada:<version>
10 | docker tag ladaapp/lada:<version> ladaapp/lada:latest
11 | docker push ladaapp/lada:latest
12 | ```
13 |
14 | ## Update python dependencies
15 |
16 | The `requirements.txt` file is generated by `uv export` from the `uv.lock` file in the project root.
17 |
18 | See packaging [README.md](../README.md) for context.
--------------------------------------------------------------------------------
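
A hedged sketch of that regeneration step (the exact flags and output path are assumptions; `uv export` emits a pinned requirements file from `uv.lock`):

```shell
# Run from the project root, where uv.lock lives.
uv export --format requirements-txt -o packaging/docker/requirements.txt
```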
/lada/gui/style.css:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Lada Authors
3 | * SPDX-License-Identifier: AGPL-3.0
4 | */
5 |
6 | #label_cursor_time {
7 | background-color: color-mix(in srgb, var(--dialog-bg-color) 70%, transparent);
8 | }
9 |
10 | #drop_down_files > button:not(:checked):not(:hover) {
11 | background: transparent;
12 | box-shadow: none;
13 | }
14 |
15 | progressbar.finished > trough > progress {
16 | background-color: var(--success-color);
17 | }
18 |
19 | progressbar.failed > trough > progress {
20 | background-color: var(--error-color);
21 | }
22 |
23 | .fullscreen-preview {
24 | background-color: black;
25 | }
--------------------------------------------------------------------------------
/lada/restorationpipeline/deepmosaics_mosaic_restorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lada.models.deepmosaics.inference import restore_video_frames
4 | from lada.utils import ImageTensor
5 |
6 | class DeepmosaicsMosaicRestorer:
7 | def __init__(self, model, device):
8 | self.model = model
9 | self.device = device
10 | self.dtype = model.dtype
11 |
12 | def restore(self, video: list[ImageTensor]) -> list[ImageTensor]:
13 | frames = [x.contiguous().numpy() for x in video]
14 | restored_frames = restore_video_frames(self.device.index, self.model, frames)
15 | return [torch.from_numpy(x) for x in restored_frames]
16 |
17 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/typing.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | from typing import Callable, Dict, List, Sequence, Tuple, Union
6 |
7 | from mmengine.config import ConfigDict
8 | from mmengine.structures import BaseDataElement
9 | from torch import Tensor
10 |
11 | ForwardInputs = Tuple[Dict[str, Union[Tensor, str, int]], Tensor]
12 | SampleList = Sequence[BaseDataElement]
13 |
14 | NoiseVar = Union[Tensor, Callable, None]
15 | LabelVar = Union[Tensor, Callable, List[int], None]
16 |
17 | ConfigType = Union[ConfigDict, Dict]
18 |
--------------------------------------------------------------------------------
/lada/gui/icons/speaker-4-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/utils/autoanchor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: BPJDet Authors
2 | # SPDX-FileCopyrightText: YOLOv5 🚀 by Ultralytics
3 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
4 | # Code vendored from: https://github.com/hnuzhy/BPJDet
5 |
6 | """
7 | Auto-anchor utils
8 | """
9 |
10 | def check_anchor_order(m):
11 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
12 | a = m.anchor_grid.prod(-1).view(-1) # anchor area
13 | da = a[-1] - a[0] # delta a
14 | ds = m.stride[-1] - m.stride[0] # delta s
15 | if da.sign() != ds.sign(): # same order
16 | print('Reversing anchor order')
17 | m.anchors[:] = m.anchors.flip(0)
18 | m.anchor_grid[:] = m.anchor_grid.flip(0)
--------------------------------------------------------------------------------
/lada/gui/icons/arrow-pointing-away-from-line-right-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/data/JointBP_CrowdHuman_head.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: BPJDet Authors
2 | # SPDX-FileCopyrightText: YOLOv5 🚀 by Ultralytics
3 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
4 | # Code vendored from: https://github.com/hnuzhy/BPJDet
5 |
6 | DATA = dict(
7 | nc=2, # number of classes (two class: human body, human head)
8 | num_offsets=2, # number of coordinates introduced by the body part, e.g., (head_x, head_y)
9 | names=[ 'person', 'head' ], # class names.
10 | conf_thres_part=0.45, # the larger conf threshold for filtering body-part detection proposals
11 | iou_thres_part=0.75, # the smaller iou threshold for filtering body-part detection proposals
12 | match_iou_thres=0.6, # iou threshold for deciding whether a body-part is matched with one body bbox
13 | )
--------------------------------------------------------------------------------
/lada/models/dover/models/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022 NTU Visual Quality Assessment Group
2 | # SPDX-License-Identifier: MIT AND AGPL-3.0
3 | # Code vendored from: https://github.com/VQAssessment/DOVER
4 |
5 | from .conv_backbone import convnext_3d_small, convnext_3d_tiny
6 | from .evaluator import DOVER, BaseEvaluator, BaseImageEvaluator
7 | from .head import IQAHead, VARHead, VQAHead
8 | from .swin_backbone import SwinTransformer2D as IQABackbone
9 | from .swin_backbone import SwinTransformer3D as VQABackbone
10 | from .swin_backbone import swin_3d_small, swin_3d_tiny
11 |
12 | __all__ = [
13 | "VQABackbone",
14 | "IQABackbone",
15 | "VQAHead",
16 | "IQAHead",
17 | "VARHead",
18 | "BaseEvaluator",
19 | "BaseImageEvaluator",
20 | "DOVER",
21 | ]
22 |
--------------------------------------------------------------------------------
/lada/gui/icons/playback-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/models/deepmosaics/models/loadmodel.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: DeepMosaics Authors
2 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/HypoX64/DeepMosaics/
4 |
5 | import torch
6 | from lada.models.deepmosaics.models import model_util
7 | from lada.models.deepmosaics.models.BVDNet import define_G as video_G
8 |
9 | def video(device: torch.device, model_path: str, fp16: bool):
10 | dtype = torch.float16 if fp16 else torch.float32
11 | gpu_id = str(device.index) if device.type == 'cuda' else '-1'
12 | netG = video_G(N=2,n_blocks=4,gpu_id=gpu_id)
13 | netG.load_state_dict(torch.load(model_path))
14 | netG = model_util.todevice(netG,gpu_id)
15 | netG.eval()
16 | netG.to(dtype)
17 | netG.dtype = dtype
18 | return netG
19 |
20 |
--------------------------------------------------------------------------------
/scripts/training/train-mosaic-detection-yolo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from lada.utils.ultralytics_utils import set_default_settings
5 | from lada.models.yolo.yolo import Yolo
6 |
7 | set_default_settings()
8 |
9 | # !! Uninstall albumentations before training: if ultralytics finds it installed, it adds blur and JPEG-compression augmentations, and there is no way to disable this in ultralytics.
10 |
11 | # "accurate" model
12 | model = Yolo('yolo11s-seg.pt')
13 | model.train(data='configs/yolo/mosaic_detection_dataset_config.yaml', epochs=200, imgsz=640, name="train_mosaic_detection_yolo11s")
14 |
15 | # "fast" model
16 | # model = Yolo('yolo11n-seg.pt')
17 | # model.train(data='configs/yolo/mosaic_detection_dataset_config.yaml', epochs=200, imgsz=640, name="train_mosaic_detection_yolo11n")
--------------------------------------------------------------------------------
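
A sketch of the uninstall step the comment in the script calls for:

```shell
# Remove albumentations from the training environment before running the script,
# so ultralytics does not add blur/JPEG-compression augmentations.
python -m pip uninstall -y albumentations
```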
/lada/utils/os_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import subprocess
5 | import sys
6 |
7 | import torch
8 |
9 | def get_subprocess_startup_info():
10 | if sys.platform != "win32":
11 | return None
12 | startup_info = subprocess.STARTUPINFO()
13 | startup_info.dwFlags |= subprocess.STARTF_USESHOWWINDOW
14 | return startup_info
15 |
16 | def gpu_has_tensor_cores(device_index: int = 0) -> bool:
17 | if not torch.cuda.is_available():
18 | return False
19 | major, minor = torch.cuda.get_device_capability(device_index)
20 | if major < 7:
21 | return False
22 | if major > 7:
23 | return True
24 | name = torch.cuda.get_device_name(device_index).lower()
25 | if "gtx 16" in name:
26 | return False
27 | return True
--------------------------------------------------------------------------------
/lada/gui/icons/cross-large-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/logger.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import logging
6 |
7 | from mmengine.logging import print_log
8 | from termcolor import colored
9 |
10 |
11 | def print_colored_log(msg, level=logging.INFO, color='magenta'):
12 | """Print colored log with default logger.
13 |
14 | Args:
15 | msg (str): Message to log.
16 | level (int): The root logger level. Note that only the process of
17 | rank 0 is affected, while other processes will set the level to
18 | "Error" and be silent most of the time.Log level,
19 | default to 'info'.
20 | color (str, optional): Color 'magenta'.
21 | """
22 | print_log(colored(msg, color), 'current', level)
23 |
--------------------------------------------------------------------------------
/patches/adjust_mmengine_resume_dataloader.patch:
--------------------------------------------------------------------------------
1 | --- old/mmengine/runner/loops.py 2025-06-19 00:00:00.0.0
2 | +++ new/mmengine/runner/loops.py 2025-06-19 00:00:00.0.0
3 | @@ -274,14 +274,6 @@
4 | # In iteration-based training loop, we treat the whole training process
5 | # as a big epoch and execute the corresponding hook.
6 | self.runner.call_hook('before_train_epoch')
7 | - if self._iter > 0:
8 | - print_log(
9 | - f'Advance dataloader {self._iter} steps to skip data '
10 | - 'that has already been trained',
11 | - logger='current',
12 | - level=logging.WARNING)
13 | - for _ in range(self._iter):
14 | - next(self.dataloader_iterator)
15 | while self._iter < self._max_iters and not self.stop_training:
16 | self.runner.model.train()
17 |
18 |
--------------------------------------------------------------------------------
/lada/utils/encoding_presets.csv:
--------------------------------------------------------------------------------
1 | preset_name|preset_description(translatable)|encoder_name|encoder_options
2 | h264-cpu-hq|H.264, libx264 software encoder, High Quality, Medium File Size|libx264|-crf 20 -preset medium
3 | h264-cpu-fast|H.264, libx264 software encoder, Fast, Large File Size|libx264|-crf 23 -preset veryfast
4 | hevc-cpu-hq|HEVC, libx265 software encoder, High Quality, Medium File Size|libx265|-crf 23 -preset medium -x265-params log_level=error
5 | hevc-nvidia-gpu-hq|HEVC, Nvidia hardware encoder, High Quality, Medium File Size|hevc_nvenc|-preset p7 -tune hq -rc constqp -qp 20 -spatial_aq 1 -aq-strength 6 -bf 4 -b_ref_mode middle -rc-lookahead 32
6 | hevc-nvidia-gpu-uhq|HEVC, Nvidia hardware encoder, Indistinguishable Quality, Large File Size|hevc_nvenc|-preset p7 -tune hq -rc constqp -qp 18 -spatial_aq 1 -aq-strength 6 -bf 4 -b_ref_mode middle -rc-lookahead 32
7 | h264-nvidia-gpu-fast|H.264, Nvidia hardware encoder, Fast, Large File Size|h264_nvenc|-preset p4 -rc constqp -qp 23
--------------------------------------------------------------------------------
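
For context on how a row is meant to be read: each line maps a preset to an ffmpeg encoder and its options. A hedged sketch of the equivalent manual invocation for the `h264-cpu-hq` row (lada's actual export pipeline is not shown here and may assemble the command differently):

```shell
# h264-cpu-hq: libx264 with '-crf 20 -preset medium'
ffmpeg -i input.mp4 -c:v libx264 -crf 20 -preset medium output.mp4
```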
/lada/models/deepmosaics/util/image_processing.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: DeepMosaics Authors
2 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/HypoX64/DeepMosaics/
4 |
5 | import numpy as np
6 |
7 | def psnr(img1,img2):
8 | mse = np.mean((img1/255.0-img2/255.0)**2)
9 | if mse < 1e-10:
10 | return 100
11 | psnr_v = 20*np.log10(1/np.sqrt(mse))
12 | return psnr_v
13 |
14 | def splice(imgs,splice_shape):
15 | '''Stitching multiple images, all imgs must have the same size
16 | imgs : [img1,img2,img3,img4]
17 | splice_shape: (2,2)
18 | '''
19 | h,w,ch = imgs[0].shape
20 | output = np.zeros((h*splice_shape[0],w*splice_shape[1],ch),np.uint8)
21 | cnt = 0
22 | for i in range(splice_shape[0]):
23 | for j in range(splice_shape[1]):
24 | if cnt < len(imgs):
25 | output[h*i:h*(i+1),w*j:w*(j+1)] = imgs[cnt]
26 | cnt += 1
27 | return output
28 |
29 |
--------------------------------------------------------------------------------
/lada/gui/icons/sliders-horizontal-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/gui/config/no_gpu_banner.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | vertical
15 |
16 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/lada/gui/export/spinner_button.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/lada/gui/icons/speaker-0-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/translations/update_pot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # SPDX-FileCopyrightText: Lada Authors
4 | # SPDX-License-Identifier: AGPL-3.0
5 |
6 | translations_dir=$(dirname -- "$0")
7 | if [ "$(pwd)" != "$translations_dir" ] ; then
8 | cd "$translations_dir"
9 | # go to project root so the file paths in the .pot file will show up nicely
10 | cd ..
11 | fi
12 | export TZ=UTC
13 |
14 | echo "Updating template .pot file"
15 | xgettext \
16 | --package-name=lada \
17 | --msgid-bugs-address=https://codeberg.org/ladaapp/lada/issues \
18 | --from-code=UTF-8 \
19 | --no-wrap \
20 | -f <( find lada/gui lada/cli -name "*.ui" -or -name "*.py" ) \
21 | -o $translations_dir/lada.pot
22 |
23 | python3 $translations_dir/extract_csv_strings.py lada/utils/encoding_presets.csv $translations_dir/csv_strings.pot 'preset_description(translatable)'
24 | echo "Merging extracted strings from python files with strings from encoding_presets.csv"
25 | msgcat --no-wrap $translations_dir/lada.pot $translations_dir/csv_strings.pot -o $translations_dir/lada.pot
26 | rm $translations_dir/csv_strings.pot
27 |
--------------------------------------------------------------------------------
/LICENSES/MIT.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) <year> <copyright holders>
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
6 | associated documentation files (the "Software"), to deal in the Software without restriction, including
7 | without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
9 | following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all copies or substantial
12 | portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
16 | EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
18 | USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
--------------------------------------------------------------------------------
/lada/utils/box_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import random
5 |
6 | from lada.utils import Box, Image
7 |
8 | def box_overlap(box1: Box, box2: Box):
9 | y1min, x1min, y1max, x1max = box1
10 | y2min, x2min, y2max, x2max = box2
11 | return x1min < x2max and x2min < x1max and y1min < y2max and y2min < y1max
12 |
13 | def scale_box(img: Image, box: Box, mask_scale=1.0) -> Box:
14 | img_h, img_w = img.shape[:2]
15 | s = mask_scale - 1.0
16 | t, l, b, r = box
17 | w, h = r - l + 1, b - t + 1
18 | t -= h * s
19 | b += h * s
20 | l -= w * s
21 | r += w * s
22 | t = max(0, t)
23 | b = min(img_h - 1, b)
24 | l = max(0, l)
25 | r = min(img_w - 1, r)
26 | return int(t), int(l), int(b), int(r)
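# Worked example (illustrative values): on a 100x100 image, scaling the
# 10x10 box (20, 20, 29, 29) with mask_scale=1.5 grows it by half of its
# width/height on each side:
#   scale_box(img, (20, 20, 29, 29), 1.5) -> (15, 15, 34, 34)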
27 |
28 | def random_scale_box(img: Image, box: Box, scale_range=(1.0, 1.5)) -> Box:
29 | scale = random.uniform(scale_range[0], scale_range[1])
30 | return scale_box(img, box, scale)
31 |
32 | def convert_from_opencv(opencv_box: tuple[int, int, int, int]) -> Box:
33 | x, y, w, h = opencv_box
34 | t, l, b, r = y, x, y + h, x + w
35 | return t, l, b, r
--------------------------------------------------------------------------------
/lada/gui/icons/cafe-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/translations/extract_csv_strings.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import csv
5 | import sys
6 |
7 | def extract_csv_strings(csv_file, pot_file, column_header):
8 | with open(csv_file, mode='r', newline='', encoding='utf-8') as csvfile, \
9 | open(pot_file, mode='w', encoding='utf-8') as potfile:
10 | reader = csv.DictReader(csvfile, delimiter='|')
11 | for row in reader:
12 | description = row[column_header]
13 | description = description.replace('"', '\\"') # Escape quotes
14 |
15 | potfile.write(f'#: {csv_file}:{reader.line_num}\n')
16 | potfile.write(f'msgid "{description}"\n')
17 | potfile.write('msgstr ""\n\n')
18 |
19 | print(f"{reader.line_num - 1} Strings extracted from {csv_file} and written to {pot_file}")
20 |
21 | if __name__ == "__main__":
22 | if len(sys.argv) != 4:
23 | print("Usage: python extract_csv_strings.py ")
24 | sys.exit(1)
25 |
26 | input_csv = sys.argv[1]
27 | output_pot = sys.argv[2]
28 | column_header = sys.argv[3]
29 | extract_csv_strings(input_csv, output_pot, column_header)
30 |
--------------------------------------------------------------------------------
/configs/basicvsrpp/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'lada.models.basicvsrpp.mmagic'
2 | save_dir = './work_dirs'
3 |
4 | default_hooks = dict(
5 | timer=dict(type='IterTimerHook'),
6 | logger=dict(type='LoggerHook', interval=100),
7 | param_scheduler=dict(type='ParamSchedulerHook'),
8 | checkpoint=dict(
9 | type='CheckpointHook',
10 | interval=5000,
11 | out_dir=save_dir,
12 | by_epoch=False,
13 | max_keep_ckpts=10,
14 | save_best='PSNR',
15 | rule='greater',
16 | ),
17 | sampler_seed=dict(type='DistSamplerSeedHook'),
18 | )
19 |
20 | env_cfg = dict(
21 | cudnn_benchmark=False,
22 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4),
23 | dist_cfg=dict(backend='nccl'),
24 | )
25 |
26 | log_level = 'INFO'
27 | log_processor = dict(type='LogProcessor', window_size=100, by_epoch=False)
28 |
29 | load_from = None
30 | resume = False
31 |
32 | vis_backends = [dict(type='LocalVisBackend')]
33 | visualizer = dict(
34 | type='ConcatImageVisualizer',
35 | vis_backends=vis_backends,
36 | fn_key='gt_path',
37 | img_keys=['gt_img', 'input', 'pred_img'],
38 | bgr2rgb=True)
39 | custom_hooks = [dict(type='BasicVisualizationHook', interval=1)]
40 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/utils/datasets.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: BPJDet Authors
2 | # SPDX-FileCopyrightText: YOLOv5 🚀 by Ultralytics
3 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
4 | # Code vendored from: https://github.com/hnuzhy/BPJDet
5 |
6 | """
7 | Dataloaders and dataset utils
8 | """
9 | from PIL import Image
10 |
11 | def exif_transpose(image):
12 | """
13 | Transpose a PIL image accordingly if it has an EXIF Orientation tag.
14 | From https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
15 |
16 | :param image: The image to transpose.
17 | :return: An image.
18 | """
19 | exif = image.getexif()
20 | orientation = exif.get(0x0112, 1) # default 1
21 | if orientation > 1:
22 | method = {2: Image.FLIP_LEFT_RIGHT,
23 | 3: Image.ROTATE_180,
24 | 4: Image.FLIP_TOP_BOTTOM,
25 | 5: Image.TRANSPOSE,
26 | 6: Image.ROTATE_270,
27 | 7: Image.TRANSVERSE,
28 | 8: Image.ROTATE_90,
29 | }.get(orientation)
30 | if method is not None:
31 | image = image.transpose(method)
32 | del exif[0x0112]
33 | image.info["exif"] = exif.tobytes()
34 | return image
35 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: BPJDet Authors
2 | # SPDX-FileCopyrightText: YOLOv5 🚀 by Ultralytics
3 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
4 | # Code vendored from: https://github.com/hnuzhy/BPJDet
5 |
6 | """
7 | Model validation metrics
8 | """
9 |
10 | import torch
11 |
12 | def box_iou(box1, box2):
13 | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
14 | """
15 | Return intersection-over-union (Jaccard index) of boxes.
16 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
17 | Arguments:
18 | box1 (Tensor[N, 4])
19 | box2 (Tensor[M, 4])
20 | Returns:
21 | iou (Tensor[N, M]): the NxM matrix containing the pairwise
22 | IoU values for every element in boxes1 and boxes2
23 | """
24 |
25 | def box_area(box):
26 | # box = 4xn
27 | return (box[2] - box[0]) * (box[3] - box[1])
28 |
29 | area1 = box_area(box1.T)
30 | area2 = box_area(box2.T)
31 |
32 | # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
33 | inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
34 | return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
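# Worked example (illustrative values):
#   box1 = torch.tensor([[0., 0., 2., 2.]])  # area 4
#   box2 = torch.tensor([[1., 0., 3., 2.]])  # area 4, overlaps half of box1
#   box_iou(box1, box2) -> tensor([[0.3333]])  # inter=2, union=4+4-2=6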
35 |
--------------------------------------------------------------------------------
/lada/gui/icons/external-link-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/gui/config/no_gpu_banner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 | import pathlib
6 |
7 | from gi.repository import Adw, Gtk, GObject
8 |
9 | from lada import LOG_LEVEL
10 | from lada.gui import utils
11 | from lada.gui.config.config import Config
12 |
13 | here = pathlib.Path(__file__).parent.resolve()
14 |
15 | logger = logging.getLogger(__name__)
16 | logging.basicConfig(level=LOG_LEVEL)
17 |
18 | @Gtk.Template(string=utils.translate_ui_xml(here / 'no_gpu_banner.ui'))
19 | class NoGpuBanner(Gtk.Box):
20 | __gtype_name__ = "NoGpuBanner"
21 |
22 | banner: Adw.Banner = Gtk.Template.Child()
23 |
24 | def __init__(self, **kwargs) -> None:
25 | super().__init__(**kwargs)
26 | self._config: Config | None = None
27 |
28 | @GObject.Property(type=Config)
29 | def config(self):
30 | return self._config
31 |
32 | @config.setter
33 | def config(self, value):
34 | self._config = value
35 | if self._config.get_property('device') == 'cpu':
36 | self.banner.set_revealed(True)
37 |
38 | @Gtk.Template.Callback()
39 | def banner_no_gpu_button_clicked(self, button_clicked):
40 | self.banner.set_revealed(False)
41 |
42 | def set_revealed(self, value: bool):
43 | self.banner.set_revealed(value)
--------------------------------------------------------------------------------
/lada/models/dover/dover.yml:
--------------------------------------------------------------------------------
1 | data:
2 | val-l1080p:
3 | type: ViewDecompositionDataset
4 | args:
5 | weight: 0.620
6 | phase: test
7 | anno_file: ./examplar_data_labels/LSVQ/labels_1080p.txt
8 | data_prefix: ../datasets/LSVQ/
9 | sample_types:
10 | technical:
11 | fragments_h: 7
12 | fragments_w: 7
13 | fsize_h: 32
14 | fsize_w: 32
15 | aligned: 32
16 | clip_len: 32
17 | frame_interval: 2
18 | num_clips: 3
19 | aesthetic:
20 | size_h: 224
21 | size_w: 224
22 | clip_len: 32
23 | frame_interval: 2
24 | t_frag: 32
25 | num_clips: 1
26 |
27 | model:
28 | type: DOVER
29 | args:
30 | backbone:
31 | technical:
32 | type: swin_tiny_grpb
33 | checkpoint: true
34 | pretrained:
35 | aesthetic:
36 | type: conv_tiny
37 | backbone_preserve_keys: technical,aesthetic
38 | divide_head: true
39 | vqa_head:
40 | in_channels: 768
41 | hidden_channels: 64
42 |
43 | test_load_path: ./../../model_weights/3rd_party/DOVER.pth
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/lada/utils/torch_letterbox.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import torch
5 | from torchvision.transforms.v2 import Resize, Pad
6 | from torchvision.transforms.v2.functional import InterpolationMode
7 |
8 | class PyTorchLetterBox:
9 | def __init__(self, imgsz: int | tuple[int, int], original_shape: tuple[int, int], stride: int = 32) -> None:
10 | if isinstance(imgsz, int):
11 | new_shape: tuple[int, int] = (imgsz, imgsz)
12 | else:
13 | new_shape = imgsz
14 |
15 | self.original_shape = original_shape
16 | pad_value: float = 114.0/255.0
17 | h, w = original_shape
18 | new_h, new_w = new_shape
19 |
20 | r = min(new_h / h, new_w / w)
21 | new_unpad_w = int(round(w * r))
22 | new_unpad_h = int(round(h * r))
23 |
24 | dw = new_w - new_unpad_w
25 | dh = new_h - new_unpad_h
26 | dw = int(dw % stride)
27 | dh = int(dh % stride)
28 |
29 | resize = None if (h, w) == (new_unpad_h, new_unpad_w) else Resize(size=(new_unpad_h, new_unpad_w), interpolation=InterpolationMode.BILINEAR, antialias=False)
30 | pad = Pad(padding=(dw // 2, dh // 2, dw - (dw // 2), dh - (dh // 2)), fill=pad_value)
31 | self.transform = torch.nn.Sequential(resize, pad) if resize is not None else pad
32 |
33 | def __call__(self, image: torch.Tensor) -> torch.Tensor: # (B,C,H,W)
34 | return self.transform(image)
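# Usage sketch (illustrative shapes): letterboxing 720x1280 frames into 640x640
# with the default stride of 32 resizes to 360x640 and pads the height to 384:
#   lb = PyTorchLetterBox(640, original_shape=(720, 1280))
#   out = lb(torch.zeros(1, 3, 720, 1280))  # -> (1, 3, 384, 640)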
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/log_processor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | from mmengine.runner import LogProcessor as BaseLogProcessor
6 |
7 | from .registry import LOG_PROCESSORS
8 |
9 |
10 | @LOG_PROCESSORS.register_module() # type: ignore
11 | class LogProcessor(BaseLogProcessor):
12 | """LogProcessor inherits from :class:`mmengine.runner.LogProcessor` and
13 | overwrites :meth:`self.get_log_after_iter`.
14 |
15 | This log processor should be used along with
16 | :class:`mmagic.engine.runner.MultiValLoop` and
17 | :class:`mmagic.engine.runner.MultiTestLoop`.
18 | """
19 |
20 | def _get_dataloader_size(self, runner, mode) -> int:
21 | """Get dataloader size of current loop. In `MultiValLoop` and
22 | `MultiTestLoop`, we use `total_length` instead of `len(dataloader)` to
23 | denote the total number of iterations.
24 |
25 | Args:
26 | runner (Runner): The runner of the training/validation/testing
27 | mode (str): Current mode of runner.
28 |
29 | Returns:
30 | int: The dataloader size of current loop.
31 | """
32 | if hasattr(self._get_cur_loop(runner, mode), 'total_length'):
33 | return self._get_cur_loop(runner, mode).total_length
34 | else:
35 | return super()._get_dataloader_size(runner, mode)
36 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | from mmengine import DefaultScope
6 |
7 | SCOPE = 'lada.models.basicvsrpp.mmagic'
8 |
9 | def register_all_modules():
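# importing these modules runs their register_module() decorators as a side
# effect, populating the registries of the 'lada.models.basicvsrpp.mmagic' scope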
10 | from .base_edit_model import BaseEditModel
11 | from .basicvsr_plusplus_net import BasicVSRPlusPlusNet
12 | from .basicvsr import BasicVSR
13 | from .concat_visualizer import ConcatImageVisualizer
14 | from .data_preprocessor import DataPreprocessor
15 | from .ema import ExponentialMovingAverageHook
16 | from .gan_loss import GANLoss
17 | from .iter_time_hook import IterTimerHook
18 | from .log_processor import LogProcessor
19 | from .multi_optimizer_constructor import MultiOptimWrapperConstructor
20 | from .perceptual_loss import PerceptualLoss
21 | from .pixelwise_loss import CharbonnierLoss
22 | from .real_basicvsr import RealBasicVSR
23 | from .unet_disc import UNetDiscriminatorWithSpectralNorm
24 | from .vis_backend import TensorboardVisBackend
25 | from .visualization_hook import VisualizationHook
26 | from .evaluator import Evaluator
27 | from .psnr import PSNR
28 | from .ssim import SSIM
29 | from .multi_loops import MultiValLoop
30 |
31 | never_created = DefaultScope.get_current_instance() is None or not DefaultScope.check_instance_created(SCOPE)
32 | if never_created:
33 | DefaultScope.get_instance(SCOPE, scope_name=SCOPE)
34 | return
35 |
--------------------------------------------------------------------------------
/lada/gui/icons/check-round-outline2-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/lada/gui/export/spinner_button.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 | import pathlib
6 |
7 | from gi.repository import Gtk, Adw, GObject
8 |
9 | from lada import LOG_LEVEL
10 | from lada.gui import utils
11 |
12 | here = pathlib.Path(__file__).parent.resolve()
13 | logger = logging.getLogger(__name__)
14 | logging.basicConfig(level=LOG_LEVEL)
15 |
16 | @Gtk.Template(string=utils.translate_ui_xml(here / 'spinner_button.ui'))
17 | class SpinnerButton(Gtk.Button):
18 | __gtype_name__ = 'SpinnerButton'
19 |
20 | spinner: Adw.Spinner = Gtk.Template.Child()
21 | _label: Gtk.Label = Gtk.Template.Child()
22 |
23 | def __init__(self, **kwargs):
24 | super().__init__(**kwargs)
25 |
26 | @GObject.Property(type=str)
27 | def label(self):
28 | return self._label.get_label()
29 |
30 | @label.setter
31 | def label(self, value):
32 | self._label.set_label(value)
33 |
34 | @GObject.Property(type=bool, default=True)
35 | def spinner_visible(self):
36 | return self.spinner.get_visible()
37 |
38 | @spinner_visible.setter
39 | def spinner_visible(self, value):
40 | self.spinner.set_visible(value)
41 |
42 | def set_label(self, label: str):
43 | self._label.set_label(label)
44 |
45 | def get_label(self) -> str:
46 | return self._label.get_label()
47 |
48 | def set_spinner_visible(self, value: bool):
49 | self.spinner.set_visible(value)
--------------------------------------------------------------------------------
/lada/gui/icons/folder-open-symbolic.svg:
--------------------------------------------------------------------------------
1 |
2 |
5 |
--------------------------------------------------------------------------------
/lada/gui/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import sys
4 |
5 | is_running_pyinstaller_bundle = getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
6 |
7 | def _update_env_var(env_var, paths, separator=";"):
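# e.g. _update_env_var("PATH", [bin_dir]) prepends bin_dir (lower-cased) unless
# an equal, case-insensitive entry is already present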
8 | assert sys.platform == "win32", "_update_env_var() only works on Windows with case-insensitive paths"
9 | paths_to_add = [str(path).lower() for path in paths]
10 | if env_var in os.environ:
11 | existing_paths = os.environ[env_var].lower().split(separator)
12 | paths_to_add = [path for path in paths_to_add if path not in existing_paths]
13 | os.environ[env_var] = separator.join(paths_to_add + existing_paths)
14 | else:
15 | os.environ[env_var] = separator.join(paths_to_add)
16 |
17 | def prepare_windows_gui_environment():
18 | gvsbuild_build_dir = pathlib.Path(__file__).parent.parent.parent.joinpath("build_gtk").absolute()
19 | if not gvsbuild_build_dir.is_dir():
20 | return
21 |
22 | release_dir = gvsbuild_build_dir / "gtk" / "x64" / "release"
23 | if not release_dir.exists():
24 | return
25 |
26 | bin_dir = release_dir / "bin"
27 | lib_dir = release_dir / "lib"
28 | includes = [
29 | release_dir / "include",
30 | release_dir / "include" / "cairo",
31 | release_dir / "include" / "glib-2.0",
32 | release_dir / "include" / "gobject-introspection-1.0",
33 | release_dir / "lib" / "glib-2.0" / "include",
34 | ]
35 |
36 | _update_env_var("PATH", [bin_dir])
37 | _update_env_var("LIB", [lib_dir])
38 | _update_env_var("INCLUDE", includes)
39 |
40 | if sys.platform == "win32" and not is_running_pyinstaller_bundle:
41 | prepare_windows_gui_environment()
--------------------------------------------------------------------------------
/packaging/windows/README.md:
--------------------------------------------------------------------------------
1 | ## Build exe
2 |
3 | ```powershell
4 | powershell -ExecutionPolicy Bypass ./packaging/windows/package_executable.ps1
5 | ```
6 |
7 | Script options:
8 | * `--skip-winget`: Skip installing/upgrading system dependencies via winget
9 | * `--skip-gvsbuild`: Skip installing/upgrading system dependencies via gvsbuild
10 | * `--clean-gvsbuild`: Do a clean build of gvsbuild
11 | * `--cli-only`: Skip gvsbuild and build only `lada-cli.exe`
12 |
13 | > [!TIP]
14 | > If you updated `gvsbuild`, `uv`, or `python`, do a clean build (`--clean-gvsbuild`)
15 |
16 | > [!TIP]
17 | > If you get a build error about *Clock skew*, check your Date & Time settings. If that doesn't help, rewrite the timestamps:
18 | > ```Powershell
19 | > $now = (Get-Date)
20 | > Get-ChildItem -Path ./build_gtk_release/build -Recurse | ForEach-Object { $_.LastWriteTime = $now }
21 | > ```
22 |
23 | > [!TIP]
24 | > After doing major packaging changes that involve system dependencies test the exe on another pristine Windows VM.
25 | >
26 | > This makes sure that PyInstaller picked up all required dependencies which are available on the build machine but may not be on the user's Windows installation.
27 |
28 | ## Publish exe
29 |
30 | * Attach both `lada-.7z.001` and `lada-.7z.002` to the GitHub Release (drag-and-drop onto the Draft Release)
31 | * Upload `lada-.7z` to https://pixeldrain.com
32 | * Add the Pixeldrain download link and the GitHub Release link to both Draft Releases on GitHub and Codeberg
33 |
34 | ## Update python dependencies
35 |
36 | The `requirements.txt` file is generated by `uv export` based on `uv.lock` located in the root of the project.
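
A sketch of regenerating it (the output path here is an assumption; check the packaging scripts for the flags actually used):

```powershell
uv export --format requirements-txt -o packaging/windows/requirements.txt
```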
37 |
38 | See packaging [README.md](../README.md) for context.
--------------------------------------------------------------------------------
/translations/compile_po.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | # SPDX-FileCopyrightText: Lada Authors
4 | # SPDX-License-Identifier: AGPL-3.0
5 |
6 | translations_dir=$(dirname -- "$0")
7 | cwd_back="$(pwd)"
8 | if [ "$(pwd)" != "$translations_dir" ] ; then
9 | cd "$translations_dir"
10 | fi
11 |
12 | lang_filter=""
13 | for arg in "$@"; do
14 | case "$arg" in
15 | --release)
16 | lang_filter=$(cat release_ready_translations.txt | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')
17 | if [ -z "$lang filter" ]; then
18 | echo "No translations in release_ready_translations.txt"
19 | exit 1
20 | fi
21 | ;;
22 | esac
23 | done
24 |
25 | should_compile_po() {
26 | lang="$1"
27 | if [ -z "$lang_filter" ]; then
28 | return 0
29 | fi
30 | echo "$lang_filter" | tr " " "\n" | grep -q "$lang"
31 | return $?
32 | }
33 |
34 | # Clean up compiled translations if there is no corresponding .po file anymore (deleted translations)
35 | find ../lada/locale/ -mindepth 2 -maxdepth 2 -type d -name LC_MESSAGES | while read lang_dir; do
36 | lang=$(basename "$(dirname "$lang_dir")")
37 | po_file="$lang.po"
38 | if [ ! -f "$po_file" ]; then
39 | rm -rf "$(dirname $lang_dir)"
40 | fi
41 | done
42 |
43 | # Compile .po files
44 | find . -mindepth 1 -maxdepth 1 -type f -name "*.po" -printf '%f\n' | while read po_file ; do
45 | lang="${po_file%.po}"
46 | if ! should_compile_po $lang ; then
47 | _lang_dir="../lada/locale/$lang"
48 | [ -d "$_lang_dir" ] && rm -r "$_lang_dir"
49 | continue
50 | fi
51 | lang_dir="../lada/locale/$lang/LC_MESSAGES"
52 | if [ ! -d "$lang_dir" ] ; then
53 | mkdir -p "$lang_dir"
54 | fi
55 | echo "Compiling language $lang .po file into .mo file"
56 | msgfmt "$po_file" -o "$lang_dir/lada.mo"
57 | done
58 |
59 | cd "$cwd_back"
--------------------------------------------------------------------------------
/lada/gui/export/export_multiple_files_page.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | True
19 |
20 |
21 | 1
22 |
23 |
24 | true
25 | 1
26 | 18
27 | 18
28 | 12
29 | 12
30 | 0
31 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/model_weights/checksums_md5.txt:
--------------------------------------------------------------------------------
1 | d5ae43e5e5b1671fc966ecc908ac371a lada_mosaic_detection_model_v2.pt
2 | 5a5d8892cdacc5a6acff09acf9d0b2cc lada_mosaic_detection_model_v3.1_accurate.pt
3 | b98975fcbb8d7dfe9e84837af7e37da5 lada_mosaic_detection_model_v3.1_fast.pt
4 | 5df084b7be279181bbe04799146e659b lada_mosaic_detection_model_v3.pt
5 | 541c4b94fbbfa3016180855ef1577a2f lada_mosaic_detection_model_v4_accurate.pt
6 | bbbf5450bd34e6f9defa0886ee61363a lada_mosaic_detection_model_v4_fast.pt
7 | b7f41174c8d4285cfe03407e9b47951b lada_mosaic_edge_detection_model.pth
8 | eb7f73fb07584b889accf3f7dd865ba8 lada_mosaic_restoration_model_bj_pov.pth
9 | 29b248a032c1ecefc1130345785ff4d9 lada_mosaic_restoration_model_generic_v1.1.pth
10 | d701d761cfe7bae8ebb8ec9fd417439d lada_mosaic_restoration_model_generic_v1.2.pth
11 | 4912f79ad687757c3eb21bf15b80ce9b lada_nsfw_detection_model.pt
12 | db114bfd9b1dcb34df17cfc3e22cf759 lada_nsfw_detection_model_v1.1.pt
13 | 513d443241abfe5d91ea800ec4af11cb lada_nsfw_detection_model_v1.2.pt
14 | 09b22e20575d89d975f51daf314a5c7e lada_nsfw_detection_model_v1.3.pt
15 | 17d3430a66e3cba946a5bdeb7fbc04dc lada_watermark_detection_model.pt
16 | 311a1844e137d5cc17176bfb11865925 lada_watermark_detection_model_v1.1.pt
17 | 1a2e043ffc53d26f8a03fffdc6f582b2 lada_watermark_detection_model_v1.2.pt
18 | 1dc0ab0f60c5bf573281d0f1f66d117e lada_watermark_detection_model_v1.3.pt
19 | fe7f40d374184270dffd1259955e3023 3rd_party/640m.pt
20 | 986074d38e050e75139193058b60ff51 3rd_party/clean_youknow_video.pth
21 | 5595c0a1f378117fe4acb83a417b6377 3rd_party/DOVER.pth
22 | 5185d0677645416acd18928b1db9c1e0 3rd_party/spynet_20210409-c6c1bd09.pth
23 | 92881fe292bd7d2408ecff58a101fd03 3rd_party/vgg19-dcbb9e9d.pth
24 | 57f3d9de607bfb185fbf07d8b42fc125 3rd_party/centerface.onnx
25 | 6a6bbeb282ee76604d1214529bd1f42d 3rd_party/ch_head_s_1536_e150_best_mMR.pt
26 |
--------------------------------------------------------------------------------
/scripts/evaluation/validate-yolo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 |
5 | from lada.models.yolo.yolo import Yolo
6 | from lada.utils import ultralytics_utils
7 |
8 | ultralytics_utils.set_default_settings()
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description='Validate a model on a validation dataset')
12 | parser.add_argument('--model-path')
13 | parser.add_argument('--dataset-config-path', help="Path to YOLO dataset config file")
14 | parser.add_argument('--conf', type=float, default=0.4, help="Detection confidence threshold")
15 | parser.add_argument('--imgsz', type=int, default=640, help="Target Image/Frame resolution. Image/Frame will be scaled up/down accordingly. Needs to be a multiple of 32 to match YOLO stride size")
16 | parser.add_argument('--iou', type=float, default=0.7, help="IoU (Intersection over union) used for NMS (Non-Maximum-Suppression) and to calculate recall metric")
17 | parser.add_argument('--plot', default=True, action=argparse.BooleanOptionalAction, help="Plot results (precision, recall curves, confusion matrix etc.)")
18 | return parser.parse_args()
19 |
20 | if __name__ == '__main__':
21 | args = parse_args()
22 | model = Yolo(args.model_path)
23 | run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
24 |
25 | results = model.val(plots=args.plot, data=args.dataset_config_path, conf=args.conf, iou=args.iou, imgsz=args.imgsz, name=run_name)
26 |
27 | print(f"Results for {args.model_path}:")
28 | print(results.to_df())
29 | print("Confusion Matrix:")
30 | print(results.confusion_matrix.to_df())
31 | if args.plot:
32 | runs_dir = ultralytics_utils.get_settings()["runs_dir"]
33 | plot_dir = os.path.join(runs_dir, model.task, run_name)
34 | print("Plotted validation results to", plot_dir)
--------------------------------------------------------------------------------
/lada/restorationpipeline/basicvsrpp_mosaic_restorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lada.models.basicvsrpp.basicvsrpp_gan import BasicVSRPlusPlusGan
4 | from lada.utils import ImageTensor
5 |
6 | class BasicvsrppMosaicRestorer:
7 | def __init__(self, model: BasicVSRPlusPlusGan, device: torch.device, fp16: bool):
8 | self.model = model
9 | self.device: torch.device = torch.device(device)
10 | self.dtype = torch.float16 if fp16 else torch.float32
11 |
12 | def restore(self, video: list[ImageTensor], max_frames=-1) -> list[ImageTensor]:
13 | input_frame_count = len(video)
14 | input_frame_shape = video[0].shape
15 | with torch.inference_mode():
16 | result = []
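# stack the list of (H, W, C[BGR]) uint8 frames into a (1, T, C, H, W) float batch in [0, 1]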
17 | inference_view = torch.stack([x.permute(2, 0, 1) for x in video], dim=0).to(device=self.device).to(dtype=self.dtype).div_(255.0).unsqueeze(0)
18 |
19 | if max_frames > 0:
20 | for i in range(0, inference_view.shape[1], max_frames):
21 | output = self.model(inputs=inference_view[:, i:i + max_frames])
22 | result.append(output)
23 | result = torch.cat(result, dim=1)
24 | else:
25 | result = self.model(inputs=inference_view)
26 |
27 | # convert (B, T, C, H, W) float in [0, 1] back to a list of (H, W, C[BGR]) uint8 images
28 | result = result.squeeze(0)[:input_frame_count] # -> (T, C, H, W)
29 | result = result.mul_(255.0).round_().clamp_(0, 255).to(dtype=torch.uint8).permute(0, 2, 3, 1) # (T, H, W, C)
30 | result = list(torch.unbind(result, 0)) # (T, H, W, C) to list of (H, W, C)
31 | output_frame_count = len(result)
32 | output_frame_shape = result[0].shape
33 | assert input_frame_count == output_frame_count and input_frame_shape == output_frame_shape
34 |
35 | return result
36 |
--------------------------------------------------------------------------------
/lada/models/bpjdet/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: BPJDet Authors
2 | # SPDX-FileCopyrightText: YOLOv5 🚀 by Ultralytics
3 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
4 | # Code vendored from: https://github.com/hnuzhy/BPJDet
5 |
6 | """
7 | Image augmentation functions
8 | """
9 |
10 | import cv2
11 | import numpy as np
12 |
13 | def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
14 | # Resize and pad image while meeting stride-multiple constraints
15 | shape = im.shape[:2] # current shape [height, width]
16 | if isinstance(new_shape, int):
17 | new_shape = (new_shape, new_shape)
18 |
19 | # Scale ratio (new / old)
20 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
21 | if not scaleup: # only scale down, do not scale up (for better val mAP)
22 | r = min(r, 1.0)
23 |
24 | # Compute padding
25 | ratio = r, r # width, height ratios
26 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
27 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
28 | if auto: # minimum rectangle
29 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
30 | elif scaleFill: # stretch
31 | dw, dh = 0.0, 0.0
32 | new_unpad = (new_shape[1], new_shape[0])
33 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
34 |
35 | dw /= 2 # divide padding into 2 sides
36 | dh /= 2
37 |
38 | if shape[::-1] != new_unpad: # resize
39 | im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
40 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
41 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
42 | im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
43 | return im, ratio, (dw, dh)
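# Worked example (illustrative values): a 720x1280 image with new_shape=640 and
# auto=True is resized to 360x640 (r=0.5); only 280 % 32 = 24 pixels of height
# padding are added to satisfy the stride constraint, giving a 384x640 result.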
44 |
--------------------------------------------------------------------------------
/patches/remove_use_of_torch_dist_in_mmengine.patch:
--------------------------------------------------------------------------------
1 | --- old/mmengine/dist/dist.py 2025-11-18 17:52:31.052601500 +0800
2 | +++ new/mmengine/dist/dist.py 2025-11-19 14:04:52.456769400 +0800
3 | @@ -23,6 +23,12 @@
4 | from mmengine.device import is_npu_available
5 |
6 |
7 | +if not hasattr(torch.distributed, "ReduceOp"):
8 | + class DummyReduceOp:
9 | + SUM = None
10 | + MEAN = None
11 | + torch.distributed.ReduceOp = DummyReduceOp
12 | +
13 | def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
14 | op_mappings = {
15 | 'sum': torch_dist.ReduceOp.SUM,
16 |
17 | --- old/mmengine/model/wrappers/__init__.py 2025-11-18 17:52:31.114603600 +0800
18 | +++ new/mmengine/model/wrappers/__init__.py 2025-11-22 02:19:08.124080097 +0800
19 | @@ -1,4 +1,11 @@
20 | # Copyright (c) OpenMMLab. All rights reserved.
21 | +import torch, types
22 | +if not hasattr(torch, "distributed") or not hasattr(torch.distributed, "fsdp"):
23 | + torch.distributed = types.SimpleNamespace()
24 | + torch.distributed.fsdp = types.SimpleNamespace()
25 | + torch.distributed.fsdp.fully_sharded_data_parallel = types.SimpleNamespace()
26 | +
27 | +
28 | from mmengine.utils.dl_utils import TORCH_VERSION
29 | from mmengine.utils.version_utils import digit_version
30 | from .distributed import MMDistributedDataParallel
31 | @@ -11,6 +18,12 @@
32 | ]
33 |
34 | if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
35 | - from .fully_sharded_distributed import \
36 | - MMFullyShardedDataParallel # noqa:F401
37 | - __all__.append('MMFullyShardedDataParallel')
38 | + try:
39 | + from .fully_sharded_distributed import MMFullyShardedDataParallel # noqa:F401
40 | + except Exception as e:
41 | + import warnings
42 | + warnings.warn(f"FSDP disabled: {e}")
43 | + MMFullyShardedDataParallel = None
44 | +
45 | +
46 | + __all__.append('MMFullyShardedDataParallel')
47 |
--------------------------------------------------------------------------------
/scripts/evaluation/validate-deepmosaics.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import argparse
5 | import glob
6 | import os.path
7 |
8 | import cv2
9 | import torch
10 |
11 | from lada.utils.image_utils import pad_image, resize
12 | from lada.utils.video_utils import read_video_frames, get_video_meta_data, write_frames_to_video_file
13 | from lada.models.deepmosaics.inference import restore_video_frames
14 | from lada.models.deepmosaics.models import loadmodel
15 |
16 | def validate(in_dir, out_dir, device, model_path):
17 | model = loadmodel.video(device, model_path, False)
18 | for video_path in glob.glob(os.path.join(in_dir, '*')):
19 | video_metadata = get_video_meta_data(video_path)
20 | images = read_video_frames(video_path, float32=False)
21 |
22 | if images[0].shape[:2] != (256, 256):
23 | size = 256
24 | for i, _ in enumerate(images):
25 | images[i] = resize(images[i], size, interpolation=cv2.INTER_LINEAR)
26 | images[i], _ = pad_image(images[i], size, size, mode='reflect')
27 |
28 | restored_images = restore_video_frames(device.index, model, images)
29 | filename = os.path.basename(video_path)
30 | out_path = os.path.join(out_dir, filename)
31 | fps = video_metadata.video_fps
32 | write_frames_to_video_file(restored_images, out_path, fps)
33 |
34 | def parse_args():
35 | parser = argparse.ArgumentParser(description='Validate a model on a validation dataset')
36 | parser.add_argument('--out-dir')
37 | parser.add_argument('--in-dir')
38 | parser.add_argument('--model-path')
39 | return parser.parse_args()
40 |
41 | if __name__ == '__main__':
42 | args = parse_args()
43 | if not os.path.exists(args.out_dir):
44 | os.mkdir(args.out_dir)
45 | validate(args.in_dir, args.out_dir, torch.device("cuda:0"), args.model_path)
--------------------------------------------------------------------------------
/patches/adw_spinner_to_gtk_spinner.patch:
--------------------------------------------------------------------------------
1 | --- old/lada/gui/preview/preview_view.ui 2025-10-02 00:00:00.0.0
2 | +++ new/lada/gui/preview/preview_view.ui 2025-10-02 00:00:00.0.0
3 | @@ -91,7 +91,8 @@
4 |
5 |
6 |
7 | -
8 | +
9 | + True
10 | 64
11 | 64
12 | center
13 | @@ -180,7 +181,8 @@
14 |
15 | spinner
16 |
17 | -
18 | +
19 | + True
20 | 64
21 | 64
22 | center
23 |
--------------------------------------------------------------------------------
/lada/gui/preview/seek_preview_popover.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import pathlib
5 |
6 | import cv2
7 | from gi.repository import Gtk, GObject, Gdk, Graphene, Gsk, Adw, GLib, GdkPixbuf
8 |
9 | from lada.gui import utils
10 | from lada.utils import Image
11 |
12 | here = pathlib.Path(__file__).parent.resolve()
13 |
14 | @Gtk.Template(string=utils.translate_ui_xml(here / 'seek_preview_popover.ui'))
15 | class SeekPreviewPopover(Gtk.Popover):
16 | __gtype_name__ = 'SeekPreviewPopover'
17 |
18 | label: Gtk.Label = Gtk.Template.Child()
19 | spinner: Gtk.Spinner = Gtk.Template.Child()
20 | picture: Gtk.Picture = Gtk.Template.Child()
21 |
22 | def __init__(self, **kwargs):
23 | super().__init__(**kwargs)
24 |
25 | def set_text(self, text: str):
26 | self.label.set_text(text)
27 |
28 | def show_spinner(self):
29 | self.spinner.set_visible(True)
30 | self.spinner.start()
31 |
32 | def hide_spinner(self):
33 | self.spinner.stop()
34 | self.spinner.set_visible(False)
35 |
36 | def set_thumbnail(self, thumbnail: Image):
37 | # Convert BGR to RGB for GdkPixbuf
38 | rgb_thumbnail = cv2.cvtColor(thumbnail, cv2.COLOR_BGR2RGB)
39 |
40 | # Create pixbuf from bytes in memory
41 | height, width, channels = rgb_thumbnail.shape
42 | pixbuf = GdkPixbuf.Pixbuf.new_from_bytes(
43 | GLib.Bytes.new(rgb_thumbnail.tobytes()),
44 | GdkPixbuf.Colorspace.RGB,
45 | False, # has_alpha
46 | 8, # bits_per_sample
47 | width,
48 | height,
49 | width * channels
50 | )
51 |
52 | def update_ui():
53 | self.picture.set_pixbuf(pixbuf)
54 | self.hide_spinner()
55 | return False
56 |
57 | GLib.idle_add(update_ui)
58 |
59 | def clear_thumbnail(self):
60 | GLib.idle_add(lambda: self.picture.set_pixbuf(None))
61 |
62 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/deformconv.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import init as init
7 | from torch.nn.modules.utils import _pair, _single
8 | import math
9 |
10 | class ModulatedDeformConv2d(nn.Module):
11 | def __init__(self,
12 | in_channels,
13 | out_channels,
14 | kernel_size,
15 | stride=1,
16 | padding=0,
17 | dilation=1,
18 | groups=1,
19 | deform_groups=1,
20 | bias=True):
21 | super(ModulatedDeformConv2d, self).__init__()
22 |
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.kernel_size = _pair(kernel_size)
26 | self.stride = stride
27 | self.padding = padding
28 | self.dilation = dilation
29 | self.groups = groups
30 | self.deform_groups = deform_groups
31 | self.with_bias = bias
32 | # enable compatibility with nn.Conv2d
33 | self.transposed = False
34 | self.output_padding = _single(0)
35 |
36 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
37 | if bias:
38 | self.bias = nn.Parameter(torch.Tensor(out_channels))
39 | else:
40 | self.register_parameter('bias', None)
41 | self.init_weights()
42 |
43 | def init_weights(self):
44 | n = self.in_channels
45 | for k in self.kernel_size:
46 | n *= k
47 | stdv = 1. / math.sqrt(n)
48 | self.weight.data.uniform_(-stdv, stdv)
49 | if self.bias is not None:
50 | self.bias.data.zero_()
51 |
52 | if hasattr(self, 'conv_offset'):
53 | self.conv_offset.weight.data.zero_()
54 | self.conv_offset.bias.data.zero_()
55 |
56 | def forward(self, x, offset, mask):
57 | # the actual deformable convolution forward pass is implemented by subclasses
58 | raise NotImplementedError
--------------------------------------------------------------------------------
/scripts/evaluation/validate-basicvsrpp.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import argparse
5 | import glob
6 | import os.path
7 |
8 | import torch
9 | import cv2
10 |
11 | from lada.models.basicvsrpp.inference import load_model, inference
12 | from lada.utils.image_utils import pad_image, resize
13 | from lada.utils.video_utils import read_video_frames, get_video_meta_data, write_frames_to_video_file
14 |
15 | def validate(in_dir, out_dir, config_path, model_path, device):
16 | model = load_model(config_path, model_path, device)
17 | with torch.no_grad():
18 | for video_path in glob.glob(os.path.join(in_dir, '*')):
19 | video_metadata = get_video_meta_data(video_path)
20 | orig_images = read_video_frames(video_path, float32=False)
21 |
22 | if orig_images[0].shape[:2] != (256, 256):
23 | size = 256
24 | for i, _ in enumerate(orig_images):
25 | orig_images[i] = resize(orig_images[i], size, interpolation=cv2.INTER_LINEAR)
26 | orig_images[i], _ = pad_image(orig_images[i], size, size, mode='zero')
27 |
28 | restored_images = inference(model, orig_images, device)
29 | filename = os.path.basename(video_path)
30 | out_path = os.path.join(out_dir, filename)
31 | fps = video_metadata.video_fps
32 | write_frames_to_video_file(restored_images, out_path, fps)
33 |
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Validate a model on a validation dataset')
37 | parser.add_argument('--out-dir', help='the dir to save logs and models')
38 | parser.add_argument('--in-dir')
39 | parser.add_argument('--model-path')
40 | parser.add_argument('--config-path')
41 | return parser.parse_args()
42 |
43 | if __name__ == '__main__':
44 | args = parse_args()
45 | if not os.path.exists(args.out_dir):
46 | os.mkdir(args.out_dir)
47 | validate(args.in_dir, args.out_dir, args.config_path, args.model_path, "cuda")
--------------------------------------------------------------------------------
/lada/gui/fileselection/file_selection_view.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 | import pathlib
6 | from gi.repository import Adw, Gtk, Gio, GObject
7 | from lada import LOG_LEVEL
8 | from lada.gui import utils
9 | from lada.gui.shortcuts import ShortcutsManager
10 |
11 | here = pathlib.Path(__file__).parent.resolve()
12 | logger = logging.getLogger(__name__)
13 | logging.basicConfig(level=LOG_LEVEL)
14 |
15 | @Gtk.Template(string=utils.translate_ui_xml(here / 'file_selection_view.ui'))
16 | class FileSelectionView(Gtk.Widget):
17 | __gtype_name__ = 'FileSelectionView'
18 |
19 | button_open_file: Gtk.Button = Gtk.Template.Child()
20 | status_page: Adw.StatusPage = Gtk.Template.Child()
21 |
22 | def __init__(self, **kwargs):
23 | super().__init__(**kwargs)
24 |
25 | self._shortcuts_manager: ShortcutsManager | None = None
26 | self._window_title: str | None = None
27 |
28 | drop_target = utils.create_video_files_drop_target(lambda files: self.emit("files-selected", files))
29 | self.add_controller(drop_target)
30 |
31 | logo_image = Gtk.Image.new_from_resource("/io/github/ladaapp/lada/icons/128x128/lada-logo-gray.png")
32 | self.status_page.set_paintable(logo_image.get_paintable())
33 |
34 | @GObject.Property(type=str)
35 | def window_title(self):
36 | return self._window_title
37 |
38 | @window_title.setter
39 | def window_title(self, value):
40 | self._window_title = value
41 |
42 | @Gtk.Template.Callback()
43 | def button_open_file_callback(self, button_clicked):
44 | self.button_open_file.set_sensitive(False)
45 | callback = lambda files: self.emit("files-selected", files)
46 | dismissed_callback = lambda *args: self.button_open_file.set_sensitive(True)
47 | utils.show_open_files_dialog(callback, dismissed_callback)
48 |
49 | @GObject.Signal(name="files-selected", arg_types=(GObject.TYPE_PYOBJECT,))
50 | def files_opened_signal(self, files: list[Gio.File]):
51 | pass
52 |
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/nsfw_frame_detector.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from ultralytics.engine.results import Results as UltralyticsResults
5 | from lada.utils import Detections, Detection, Image, DETECTION_CLASSES
6 | from lada.utils import mask_utils, ultralytics_utils
7 | from lada.models.yolo.yolo import Yolo
8 |
9 |
10 | def get_nsfw_frames(yolo_results: UltralyticsResults, random_extend_masks: bool) -> Detections | None:
11 | detections = []
12 | if not yolo_results.boxes:
13 | return None
14 | for yolo_box, yolo_mask in zip(yolo_results.boxes, yolo_results.masks):
15 | mask = ultralytics_utils.convert_yolo_mask(yolo_mask, yolo_results.orig_img.shape)
16 | box = ultralytics_utils.convert_yolo_box(yolo_box, yolo_results.orig_img.shape)
17 | conf = ultralytics_utils.convert_yolo_conf(yolo_box)
18 | mask, box = mask_utils.clean_mask(mask, box)
19 | mask = mask_utils.smooth_mask(mask, kernel_size=11)
20 |
21 | if random_extend_masks:
22 | mask = mask_utils.apply_random_mask_extensions(mask)
23 | mask = mask_utils.smooth_mask(mask, kernel_size=15)
24 | box = mask_utils.get_box(mask)
25 |
26 | t, l, b, r = box
27 | width, height = r - l + 1, b - t + 1
28 | if min(width, height) < 10:
29 | # skip tiny detections
30 | continue
31 |
32 | detections.append(Detection(DETECTION_CLASSES["nsfw"]["cls"], box, mask, conf))
33 | return Detections(yolo_results.orig_img, detections)
34 |
35 | class NsfwImageDetector:
36 | def __init__(self, model: Yolo, device=None, random_extend_masks=False, conf=0.25):
37 | self.model = model
38 | self.device = device
39 | self.random_extend_masks = random_extend_masks
40 | self.conf = conf
41 |
42 | def detect(self, source: str | Image) -> Detections | None:
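# predict() with stream=False returns a list of Results, one per input image;
# for a single image source we convert and return the first (and only) one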
43 | for results in self.model.predict(source=source, stream=False, verbose=False, device=self.device, conf=self.conf, iou=0.):
44 | return get_nsfw_frames(results, self.random_extend_masks)
--------------------------------------------------------------------------------
/packaging/flatpak/README.md:
--------------------------------------------------------------------------------
1 | > [!NOTE]
2 | > The manifest `io.github.ladaapp.lada.yaml` in this directory is used only for building the flatpak locally.
3 | >
4 | > The manifest used for building the flatpak available on Flathub is maintained in the [flathub repo on GitHub](https://github.com/flathub/io.github.ladaapp.lada).
5 | >
6 | > Check out the sections below for further details.
7 |
8 | ## Build and publish to Flathub
9 |
10 | The Flatpak on Flathub is built via CI. You just have to open (and merge) a pull request on this repository:
11 |
12 | https://github.com/flathub/io.github.ladaapp.lada
13 |
14 | It contains a very similar manifest file. Adjust it (in most cases just update the git tag) and create a pull request with these changes.
15 |
16 | If the pipeline succeeds it should post a comment to the PR with a link to install the built flatpak.
17 |
18 | Only when the PR is merged does the production pipeline run and push the new flatpak to Flathub.
19 |
20 | Note that it takes a few hours until the new Flatpak is available on Flathub, and even longer until the Flathub website is updated.
21 |
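If the python dependencies changed since the last release, regenerate the flatpak dependency manifest first (see *Update python dependencies* below):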
22 | ```shell
23 | uv run packaging/flatpak/convert-pylock-to-flatpak.py
24 | ```
25 |
26 | ## Build and install locally
27 |
28 | Set up the dependencies:
29 | ```shell
30 | flatpak remote-add --if-not-exists --user flathub https://dl.flathub.org/repo/flathub.flatpakrepo
31 | flatpak install --user -y flathub org.flatpak.Builder
32 | ```
33 |
34 | Build and install:
35 | ```shell
36 | flatpak run org.flatpak.Builder --force-clean --user --install --install-deps-from=flathub build_flatpak packaging/flatpak/io.github.ladaapp.lada.yaml
37 | ```
38 |
39 | Lada is now installed. You should be able to find it in your application launcher as `lada (dev)`.
40 |
41 | Or run it via `flatpak run io.github.ladaapp.lada//main`.
42 |
43 |
44 | ## Update python dependencies
45 |
46 | All python dependencies are specified in and installed via the `lada-python-dependencies.yaml` flatpak module.
47 |
48 | This file is generated by the script `convert-pylock-to-flatpak.py` based on `uv.lock` located in the root of the project.
49 |
50 | See packaging [README.md](../README.md) for context.
--------------------------------------------------------------------------------
/lada/gui/resources.gresource.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 | icons/arrow-pointing-away-from-line-right-symbolic.svg
11 | icons/cafe-symbolic.svg
12 | icons/check-round-outline2-symbolic.svg
13 | icons/color-scheme-dark.svg
14 | icons/color-scheme-light.svg
15 | icons/color-scheme-system.svg
16 | icons/external-link-symbolic.svg
17 | icons/media-playback-pause-symbolic.svg
18 | icons/media-playback-start-symbolic.svg
19 | icons/playback-symbolic.svg
20 | icons/sliders-horizontal-symbolic.svg
21 | icons/speaker-0-symbolic.svg
22 | icons/speaker-4-symbolic.svg
23 | icons/plus-large-symbolic.svg
24 | icons/folder-open-symbolic.svg
25 | icons/exclamation-mark-symbolic.svg
26 | icons/cross-large-symbolic.svg
27 | icons/pause-large-symbolic.svg
28 | icons/edit-symbolic.svg
29 |
30 |
31 | icons/lada-logo-gray.png
32 |
33 |
--------------------------------------------------------------------------------
/lada/restorationpipeline/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 | from lada import LOG_LEVEL, ModelFiles
6 | from lada.models.yolo.yolo11_segmentation_model import Yolo11SegmentationModel
7 |
8 | logger = logging.getLogger(__name__)
9 | logging.basicConfig(level=LOG_LEVEL)
10 |
11 | def load_models(
12 | device: torch.device,
13 | mosaic_restoration_model_name: str,
14 | mosaic_restoration_model_path: str,
15 | mosaic_restoration_config_path: str | None,
16 | mosaic_detection_model_path: str,
17 | fp16: bool,
18 | detect_face_mosaics: bool):
19 | if mosaic_restoration_model_name.startswith("deepmosaics"):
20 | from lada.models.deepmosaics.models import loadmodel
21 | from lada.restorationpipeline.deepmosaics_mosaic_restorer import DeepmosaicsMosaicRestorer
22 | _model = loadmodel.video(device, mosaic_restoration_model_path, fp16)
23 | mosaic_restoration_model = DeepmosaicsMosaicRestorer(_model, device)
24 | pad_mode = 'reflect'
25 | elif mosaic_restoration_model_name.startswith("basicvsrpp"):
26 | from lada.models.basicvsrpp.inference import load_model
27 | from lada.restorationpipeline.basicvsrpp_mosaic_restorer import BasicvsrppMosaicRestorer
28 | _model = load_model(mosaic_restoration_config_path, mosaic_restoration_model_path, device, fp16)
29 | mosaic_restoration_model = BasicvsrppMosaicRestorer(_model, device, fp16)
30 | pad_mode = 'zero'
31 | else:
32 | raise NotImplementedError()
33 | # setting classes=[0] will consider only detections of class id = 0 (nsfw mosaics), therefore filtering out sfw mosaics (heads, faces)
34 | if detect_face_mosaics:
35 | classes = None
36 | detection_model_name = ModelFiles.get_detection_model_by_path(mosaic_detection_model_path)
37 | if detection_model_name and detection_model_name == "v2":
38 | logger.info("Mosaic detection model v2 does not support detecting face mosaics. Use detection models v3 or newer. Ignoring...")
39 | else:
40 | classes = [0]
41 | mosaic_detection_model = Yolo11SegmentationModel(mosaic_detection_model_path, device, classes=classes, conf=0.15, fp16=fp16)
42 | return mosaic_detection_model, mosaic_restoration_model, pad_mode
43 |
--------------------------------------------------------------------------------
/scripts/training/export-weights-basicvsrpp-stage2-for-inference.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import os
5 |
6 | from mmengine.runner import load_checkpoint
7 | import torch
8 |
9 | from lada.models.basicvsrpp.basicvsrpp_gan import BasicVSRPlusPlusGan
10 | from lada.models.basicvsrpp import register_all_modules
11 |
12 | register_all_modules()
13 |
14 | MODEL_WEIGHTS_IN_PATH = 'experiments/basicvsrpp/mosaic_restoration_generic_stage2.6/iter_100000.pth'
15 | MODEL_WEIGHTS_OUT_PATH = 'experiments/basicvsrpp/mosaic_restoration_generic_stage2.6/lada_mosaic_restoration_model_generic_v1.2.pth'
16 | pretrained_models_dir = 'model_weights'
17 |
18 | model = BasicVSRPlusPlusGan(
19 | generator=dict(
20 | type='BasicVSRPlusPlusGanNet',
21 | mid_channels=64,
22 | num_blocks=15,
23 | spynet_pretrained=os.path.join(pretrained_models_dir, "3rd_party", "spynet_20210409-c6c1bd09.pth")),
24 | discriminator=dict(
25 | type='UNetDiscriminatorWithSpectralNorm',
26 | in_channels=3,
27 | mid_channels=64,
28 | skip_connection=True),
29 | pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
30 | perceptual_loss=dict(
31 | type='PerceptualLoss',
32 | layer_weights={
33 | '2': 0.1,
34 | '7': 0.1,
35 | '16': 1.0,
36 | '25': 1.0,
37 | '34': 1.0,
38 | },
39 | vgg_type='vgg19',
40 | perceptual_weight=0.2, # was 1.0
41 | pretrained=os.path.join(pretrained_models_dir, "3rd_party", "vgg19-dcbb9e9d.pth"),
42 | style_weight=0,
43 | norm_img=False),
44 | gan_loss=dict(
45 | type='GANLoss',
46 | gan_type='vanilla',
47 | loss_weight=5e-2,
48 | real_label_val=1.0,
49 | fake_label_val=0),
50 | is_use_ema=True,
51 | data_preprocessor=dict(
52 | type='DataPreprocessor',
53 | mean=[0., 0., 0.],
54 | std=[255., 255., 255.],
55 | ))
56 |
57 | load_checkpoint(model, MODEL_WEIGHTS_IN_PATH, strict=True)
58 |
59 | model.discriminator = None
60 | model.perceptual_loss = None
61 | model.gan_loss = None
62 |
63 |
64 | torch.save(model.state_dict(), MODEL_WEIGHTS_OUT_PATH)
65 |
--------------------------------------------------------------------------------
/translations/compile_po.ps1:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | $cwd_backup = (Get-Location).Path
5 | $translationsDir = Split-Path -Parent $MyInvocation.MyCommand.Path
6 | if ((Get-Location).Path -ne $translationsDir) {
7 | Set-Location $translationsDir
8 | }
9 |
10 | $global:lang_filter = ""
11 | if ($args -contains "--release") {
12 | $global:lang_filter = (Get-Content .\release_ready_translations.txt -Raw).Trim() -replace '\s+', ' '
13 | if ($global:lang_filter -eq "") {
14 | Write-Host "No translations in .\release_ready_translations.txt"
15 | return
16 | }
17 | }
18 |
19 | function should_compile_po {
20 | param (
21 | [string]$lang
22 | )
23 |
24 | if (-not $global:lang_filter) {
25 | return $true
26 | }
27 |
28 | return ($global:lang_filter.split(" ") -contains $lang)
29 | }
30 |
31 | # Clean up compiled translations if there is no corresponding .po file anymore (deleted translations)
32 | Get-ChildItem -Directory -Path ..\lada\locale | ForEach-Object {
33 | $langDir = $_.FullName
34 | $lang = $_.Name
35 |
36 | $poFile = "$lang.po"
37 | if (-not (Test-Path $poFile)) {
38 | $lcMessagesDir = Join-Path $langDir "LC_MESSAGES"
39 | if (Test-Path $lcMessagesDir) {
40 | Write-Host "Removing outdated compiled translations for language '$lang' at '$langDir'"
41 | Remove-Item -Path $langDir -Recurse -Force
42 | }
43 | }
44 | }
45 |
46 | # Compile .po files
47 | Get-ChildItem -File -Filter "*.po" | ForEach-Object {
48 | $poFile = $_.Name
49 | $lang = $poFile -replace "\.po$"
50 |
51 | if (-not (should_compile_po $lang)) {
52 | $_langDir = "..\lada\locale\$lang"
53 | if (Test-Path $_langDir) {
54 | Remove-Item -Path $_langDir -Recurse -Force
55 | }
56 | return # actually a continue in a ForEach-Object loop
57 | }
58 |
59 | $langDir = "..\lada\locale\$lang\LC_MESSAGES"
60 | if (-not (Test-Path -Path $langDir)) {
61 | New-Item -ItemType Directory -Path $langDir -Force | Out-Null
62 | }
63 |
64 | Write-Host "Compiling language '$lang' .po file into .mo file"
65 | & msgfmt $poFile -o "$langDir\lada.mo"
66 | }
67 |
68 | Set-Location $cwd_backup
--------------------------------------------------------------------------------
/lada/gui/preview/seek_preview_popover.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | top
14 | false
15 |
16 |
17 | vertical
18 |
19 |
20 |
21 | center
22 | center
23 |
24 |
25 | 220
26 | 124
27 |
28 |
29 |
30 |
31 | 48
32 | 48
33 | center
34 | center
35 | true
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | Time code
44 | center
45 | center
46 | 2
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/lada/gui/shortcuts.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from gi.repository import GObject, Gtk
5 |
6 | class ShortcutsManager(GObject.Object):
7 | def __init__(self):
8 | GObject.Object.__init__(self)
9 | self.groups = {}
10 | self.group_titles = {}
11 |
12 | def add(self, group_key, action_key, keyboard_trigger, action, action_title):
13 | self.groups[group_key][action_key] = (keyboard_trigger, action, action_title)
14 |
15 | def register_group(self, group_key, group_title):
16 | self.group_titles[group_key] = group_title
17 | if group_key not in self.groups:
18 | self.groups[group_key] = {}
19 |
20 | def init(self, shortcut_controller: Gtk.ShortcutController):
21 | for group_key in self.groups:
22 | shortcuts = self.groups.get(group_key)
23 | if not shortcuts:
24 | continue
25 |
26 | for action_key in shortcuts:
27 | keyboard_trigger, action, _ = shortcuts[action_key]
28 | gtk_trigger = Gtk.ShortcutTrigger.parse_string(keyboard_trigger)
29 | gtk_action = Gtk.CallbackAction.new(action)
30 | shortcut = Gtk.Shortcut.new(gtk_trigger, gtk_action)
31 | shortcut_controller.add_shortcut(shortcut)
32 |
33 | class ShortcutsWindow(Gtk.ShortcutsWindow):
34 | def __init__(self, shortcuts_manager: ShortcutsManager):
35 | Gtk.ShortcutsWindow.__init__(self)
36 | self.shortcuts_manager = shortcuts_manager
37 | self.set_modal(True)
38 | self.populate()
39 |
40 | def populate(self):
41 | section = Gtk.ShortcutsSection()
42 | section.show()
43 | for group_key in self.shortcuts_manager.groups:
44 | shortcuts = self.shortcuts_manager.groups.get(group_key)
45 | if not shortcuts:
46 | continue
47 |
48 |             group = Gtk.ShortcutsGroup(title=self.shortcuts_manager.group_titles.get(group_key, group_key))
49 | group.show()
50 | for action_key in shortcuts:
51 | keyboard_trigger, _, action_title = shortcuts[action_key]
52 | short = Gtk.ShortcutsShortcut(title=action_title, accelerator=keyboard_trigger)
53 | short.show()
54 | group.add_shortcut(short)
55 | section.add_group(group)
56 |
57 | self.add_section(section)
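A minimal usage sketch (hypothetical group and action names; assumes a `Gtk.Window` instance named `window`):
```python
manager = ShortcutsManager()
manager.register_group("playback", "Playback")
manager.add("playback", "toggle", "space", lambda widget, args: True, "Play/Pause")

# Wire the shortcuts into a window via a shortcut controller
controller = Gtk.ShortcutController()
manager.init(controller)
window.add_controller(controller)

# Show the overview dialog built from the same manager
ShortcutsWindow(manager).present()
```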
--------------------------------------------------------------------------------
/lada/models/deepmosaics/inference.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: DeepMosaics Authors
2 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/HypoX64/DeepMosaics/
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from lada.models.deepmosaics.util import data
9 |
10 | def restore_video_frames(gpu_id, netG, frames: list[np.ndarray[np.uint8]]) -> list[np.ndarray[np.uint8]]:
11 |     """
12 |     T is the number of frames processed in a single step (center frame + N previous/next frames that come before/after it):
13 |     T = 2N + 1. The paper authors use N = 2 in their network (T = 5).
14 |     S is the stride that determines which neighboring frames (N) we choose. With 1 we would take the immediate neighboring frames.
15 |     The bigger S, the more change we expect to see, as each frame is further apart.
16 |     The paper authors use 3 in their network.
17 |     """
18 | N,T,S = 2,5,3
19 | LEFT_FRAME = (N*S)
20 | POOL_NUM = LEFT_FRAME*2+1
21 | INPUT_SIZE = 256
22 | FRAME_POS = np.linspace(0, (T-1)*S,T,dtype=np.int64)
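    # A worked example of the constants above: with N=2, S=3, T=5 we get LEFT_FRAME = 6,
    # POOL_NUM = 13 and FRAME_POS = [0, 3, 6, 9, 12], i.e. for each center frame i the
    # network sees frames i-6, i-3, i, i+3, i+6 (clamped at the clip boundaries).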
23 | img_pool = []
24 | previous_frame = None
25 | init_flag = True
26 |
27 | restored_clip_frames = []
28 |
29 | for i in range(len(frames)):
30 | input_stream = []
31 | # image read stream
32 | if i==0 :# init
33 | for j in range(POOL_NUM):
34 | img_pool.append(frames[np.clip(i+j-LEFT_FRAME,0,len(frames)-1)])
35 | else: # load next frame
36 | img_pool.pop(0)
37 | img_pool.append(frames[np.clip(i+LEFT_FRAME,0,len(frames)-1)])
38 |
39 | for pos in FRAME_POS:
40 | input_stream.append(img_pool[pos][:,:,::-1])
41 | if init_flag:
42 | init_flag = False
43 | previous_frame = input_stream[N]
44 | previous_frame = data.im2tensor(previous_frame,bgr2rgb=True,gpu_id=gpu_id,dtype=netG.dtype)
45 |
46 | input_stream = np.array(input_stream).reshape(1,T,INPUT_SIZE,INPUT_SIZE,3).transpose((0,4,1,2,3))
47 | input_stream = data.to_tensor(data.normalize(input_stream),gpu_id=gpu_id,dtype=netG.dtype)
48 |
49 | with torch.inference_mode():
50 | unmosaic_pred = netG(input_stream,previous_frame)
51 | img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = True)
52 | previous_frame = unmosaic_pred
53 | restored_clip_frames.append(img_fake.copy())
54 |
55 | return restored_clip_frames
--------------------------------------------------------------------------------
/translations/README.md:
--------------------------------------------------------------------------------
1 | ## Compile translations
2 |
3 | Use either `compile_po.ps1` or `compile_po.sh` to compile translations.
4 |
5 | If run without additional arguments all .po files will be compiled.
6 |
7 | To only compile translations that are expected to be shipped in the next release add the `--release` argument.
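For example, to compile only the release-ready set:
```bash
bash translations/compile_po.sh --release
```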
8 |
9 |
10 | ## Include / Exclude translations in a release
11 |
12 | Packaging scripts use the `--release` argument, so only translations listed in the file `release_ready_translations.txt` will be included.
13 |
14 | Before doing a release, check translation completeness on Weblate and/or ask recent translators whether they deem the quality of the translation good enough to be shipped.
15 |
16 | The file format is: single line, lang codes separated by spaces.
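For example (hypothetical language codes), the file could contain:
```
de es fr zh_Hans
```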
17 |
18 | ## Update translations
19 |
20 | Updating `lada.pot` will re-sync translation strings found in the codebase and make them available to translators on [Codeberg's Weblate instance](https://translate.codeberg.org/projects/lada/lada/).
21 |
22 | Use this script for updating the .pot file:
23 | ```bash
24 | bash translations/update_pot.sh
25 | ```
26 |
27 | On Weblate the addon [Update PO files to match POT (msgmerge)](https://docs.weblate.org/en/weblate-5.14/admin/addons.html#update-po-files-to-match-pot-msgmerge) is active.
28 |
29 | This means that if you commit and push an updated .pot file Weblate will automatically update all translation files (.po).
30 |
31 | So be thoughtful about when to update the .pot file. If some strings are likely to change before the next release, wait and
32 | update the .pot once the strings are more or less final to avoid causing unnecessary retranslation work.
33 |
34 | > [!NOTE]
35 | > The script `translations/update_po.sh` is currently not used as this is done by the Weblate addon as mentioned above.
36 |
37 | If a translator makes changes on Weblate it will create a Pull Request to push those changes. We also use Weblate's squash commits feature, meaning
38 | that all commits are squashed by author and language. The PR is force-pushed / updated continuously.
39 |
40 | > [!WARNING]
41 | > Make sure to merge this PR before the next release.
42 | >
43 | > But you should also merge this PR before doing updates to the .pot file to avoid merge conflicts!
44 |
45 | On the other hand, you shouldn't merge it for every change, to avoid filling the git history with too many translation commits.
46 |
--------------------------------------------------------------------------------
/scripts/training/convert-weights-basicvsrpp-stage1-to-stage2.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from mmengine.runner import load_checkpoint
5 | import torch
6 |
7 | from lada.models.basicvsrpp.basicvsrpp_gan import BasicVSRPlusPlusGan
8 | from lada.models.basicvsrpp.mmagic.basicvsr import BasicVSR
9 | from lada.models.basicvsrpp import register_all_modules
10 |
11 | register_all_modules()
12 |
13 | BASICVSRPP_WEIGHTS_PATH = 'experiments/basicvsrpp/mosaic_restoration_generic_stage1/iter_10000.pth'
14 | BASICVSRPP_GAN_WEIGHTS_PATH = 'experiments/basicvsrpp/mosaic_restoration_generic_stage1/iter_10000_converted.pth'
15 |
16 | gan_model = BasicVSRPlusPlusGan(
17 | generator=dict(
18 | type='BasicVSRPlusPlusGanNet',
19 | mid_channels=64,
20 | num_blocks=15,
21 | spynet_pretrained='model_weights/3rd_party/spynet_20210409-c6c1bd09.pth'),
22 | discriminator=dict(
23 | type='UNetDiscriminatorWithSpectralNorm',
24 | in_channels=3,
25 | mid_channels=64,
26 | skip_connection=True),
27 | pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
28 |
29 | perceptual_loss=dict(
30 | type='PerceptualLoss',
31 | layer_weights={
32 | '2': 0.1,
33 | '7': 0.1,
34 | '16': 1.0,
35 | '25': 1.0,
36 | '34': 1.0,
37 | },
38 | vgg_type='vgg19',
39 | pretrained='model_weights/3rd_party/vgg19-dcbb9e9d.pth',
40 | perceptual_weight=0.2, # was 1.0
41 | style_weight=0,
42 | norm_img=False),
43 | gan_loss=dict(
44 | type='GANLoss',
45 | gan_type='vanilla',
46 | loss_weight=5e-2,
47 | real_label_val=1.0,
48 | fake_label_val=0),
49 | is_use_ema=True,
50 | data_preprocessor=dict(
51 | type='DataPreprocessor',
52 | mean=[0., 0., 0.],
53 | std=[255., 255., 255.],
54 | )
55 | )
56 |
57 | basicvsr = BasicVSR(dict(
58 | type='BasicVSRPlusPlusGanNet',
59 | mid_channels=64,
60 | num_blocks=15,
61 | spynet_pretrained='model_weights/3rd_party/spynet_20210409-c6c1bd09.pth'),
62 | dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'))
63 |
64 | load_checkpoint(basicvsr, BASICVSRPP_WEIGHTS_PATH, strict=True)
65 |
66 | gan_model.generator = basicvsr.generator
67 | gan_model.generator_ema = basicvsr.generator
68 |
69 | torch.save(gan_model.state_dict(), BASICVSRPP_GAN_WEIGHTS_PATH)
70 |
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/head_detector.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import cv2
5 | import numpy as np
6 | from lada.utils import Mask, Box, Image, Detections, Detection, DETECTION_CLASSES
7 | from lada.utils import box_utils
8 | from lada.models.bpjdet.inference import inference
9 |
10 | def _create_mask(frame: Image, box: Box) -> Mask:
11 | t, l, b, r = box
12 | box_width, box_height = r - l + 1, b - t + 1
13 |
14 | mask = np.zeros(frame.shape[:2], dtype=np.uint8)
15 |
16 | # Set the center of the ellipse at the center of the box
17 | center = (l + (box_width // 2), t + (box_height // 2))
18 |
19 | # Set the axes of the ellipse to half the width and half the height of the box
20 | axes = (box_width // 2, box_height // 2)
21 |
22 | angle = 0
23 | start_angle = 0
24 | end_angle = 360
25 |
26 | color = DETECTION_CLASSES["sfw_head"]["mask_value"]
27 | thickness = -1
28 |
29 | cv2.ellipse(mask, center, axes, angle, start_angle, end_angle, color, thickness)
30 |
31 | mask = np.expand_dims(mask, axis=-1)
32 |
33 | return mask
34 |
35 | def _get_detection(dets: list[Box], frame, random_extend_masks: bool) -> Detections | None:
36 | if len(dets) == 0:
37 | return None
38 | detections = []
39 |     for box in dets:
40 |         if random_extend_masks:
41 |             box = box_utils.random_scale_box(frame, box, scale_range=(0.9, 1.2))
42 |         t, l, b, r = box
43 |         width, height = r - l + 1, b - t + 1
44 |         if min(width, height) < 40:
45 |             # skip tiny detections before building their masks
46 |             continue
47 | 
48 |         mask = _create_mask(frame, box)
49 |         detections.append(Detection(DETECTION_CLASSES["sfw_head"]["cls"], box, mask))
50 | return Detections(frame, detections)
51 |
52 | class HeadDetector:
53 | def __init__(self, model, data, conf_thres, iou_thres, imgz=1536, random_extend_masks=False):
54 | self.model = model
55 | self.data = data
56 | self.random_extend_masks = random_extend_masks
57 | self.conf_thres = conf_thres
58 | self.iou_thres = iou_thres
59 | self.imgz = imgz
60 |
61 | def detect(self, source: str | Image) -> Detections | None:
62 | if isinstance(source, str):
63 | image = cv2.imread(source)
64 | else:
65 | image = source
66 |         dets = inference(self.model, image, data=self.data, conf_thres=self.conf_thres, iou_thres=self.iou_thres, imgz=self.imgz)  # pass the decoded image, not the path
67 | return _get_detection(dets, image, random_extend_masks=self.random_extend_masks)
68 |
--------------------------------------------------------------------------------
/model_weights/checksums_sha256.txt:
--------------------------------------------------------------------------------
1 | 056756fcab250bcdf0833e75aac33e2197b8809b0ab8c16e14722dcec94269b5 lada_mosaic_detection_model_v2.pt
2 | 2b6e5d6cd5a795a4dcc1205b817a7323a4bd3725cef1a7de3a172cb5689f0368 lada_mosaic_detection_model_v3.1_accurate.pt
3 | 25d62894c16bba00468f3bcc160360bb84726b2f92751b5e235578bf2f9b0820 lada_mosaic_detection_model_v3.1_fast.pt
4 | 5707c3af78a13ef8d60d0c3a3ea420e79745ac5c9b7d792e3a13598fcdedfc8f lada_mosaic_detection_model_v3.pt
5 | c244d7e49d8f88e264b8dc15f91fb21f5908ad8fb6f300b7bc88462d0801bc1f lada_mosaic_detection_model_v4_accurate.pt
6 | 9a6b660d1d3e3797d39515e08b0e72fcc59815f38279faa7a4ab374ab2c1e3b4 lada_mosaic_detection_model_v4_fast.pt
7 | b9d75d1c574f93287f6597da3f07be79e67d8ac93f3f3ee52caaffa41d7556ab lada_mosaic_edge_detection_model.pth
8 | b79de6fcb1fcafd3ce7c954f4ac788be448ec2d82c6e688aaf18b8ba48fb5b47 lada_mosaic_restoration_model_bj_pov.pth
9 | 6ec6542dde73fbc2086d252a041b41881e3194eaa0bac964348e6f7e8aad007c lada_mosaic_restoration_model_generic_v1.1.pth
10 | d404152576ce64fb5b2f315c03062709dac4f5f8548934866cd01c823c8104ee lada_mosaic_restoration_model_generic_v1.2.pth
11 | 70836ca7397c7844dfc299c7d512c27361758ff7428ffa2a51a7cb2ae8ee5c93 lada_nsfw_detection_model.pt
12 | 4083f08224d4b24193ddf4ba2405e33ff5429008dd340cd8c7ba787f097a4b8b lada_nsfw_detection_model_v1.1.pt
13 | 40f0ae51bfb744dea2d973e06128323bc279f75c810485d7f0986992b4280a70 lada_nsfw_detection_model_v1.2.pt
14 | 51962d8f1758c383a5ae6e998b4ffc4a309611a40008824fc3919428675d7e7a lada_nsfw_detection_model_v1.3.pt
15 | bf78aa7c32006c8516ea5f771315f8abdd48ebc294951f59a4bc85ee7bcfb073 lada_watermark_detection_model.pt
16 | 8b7db6659c02be18fe4c5b9459e44491a7d6ef03b3e18e947a896b5eed2c421d lada_watermark_detection_model_v1.1.pt
17 | d38da4f838fd5225403b75f5587a122fb91c4908f5c4f1fee0c14022a62c48f9 lada_watermark_detection_model_v1.2.pt
18 | c0bd2bb62820b3ee332463238e6a99a8e4d27742a8331a62b5d4c7ce7a78f599 lada_watermark_detection_model_v1.3.pt
19 | e6d7cddecc4417ff62db5b92c1c9f5d0d7b0f92e6cfd562120aea59a6ec3af3f 3rd_party/640m.pt
20 | 5643ca297c13920b8ffd39a0d85296e494683a69e5e8204d662653d24c582766 3rd_party/clean_youknow_video.pth
21 | f4a42c0bbc94c94dd7409e7f40887d44c5c30314d1d09e7edf03cc35813b4838 3rd_party/DOVER.pth
22 | c6c1bd09b52d05ba17f3e701f549d6faf5e314aabce8ae462c1c171a8d6c4914 3rd_party/spynet_20210409-c6c1bd09.pth
23 | dcbb9e9dad569fff7a846263a77324fc34978fea2bfb039c012d710e1776ae44 3rd_party/vgg19-dcbb9e9d.pth
24 | 09189deaaf8646c5c51a68447e3c744ea1e211798155d4728c20507b9f5aefbc 3rd_party/centerface.onnx
25 | 07fc13937ae8c530ebe45ec78e1646f7edd298796c7b84f33a87f8c3c83affa0 3rd_party/ch_head_s_1536_e150_best_mMR.pt
26 |
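Assuming the standard GNU coreutils checksum format, all weights can be verified in one go (a sketch, run from the repository root):
```bash
(cd model_weights && sha256sum --check checksums_sha256.txt)
```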
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | [project]
5 | name = "lada"
6 | description = "Remove and recover pixelated areas in adult videos"
7 | dynamic = ["version"]
8 | requires-python = ">=3.12.0"
9 | license = "AGPL-3.0"
10 | license-files = ["LICENSE.md", "LICENSES/*.txt"]
11 | readme = "README.md"
12 | dependencies = [
13 | "ultralytics==8.3.203", # Pinned as we apply a patch
14 | "numpy",
15 | "opencv-python",
16 | "av>=15.0.0", # Binary wheels before 15.0.0 had either no nvidia encoders or libx265 was broken/missing.
17 | "mmengine==0.10.7", # Pinned as we apply a patch
18 | "tqdm",
19 | "torch",
20 | "torchvision",
21 | ]
22 |
23 | [project.urls]
24 | homepage = "https://codeberg.org/ladaapp/lada"
25 | source = "https://codeberg.org/ladaapp/lada"
26 | documentation = "https://codeberg.org/ladaapp/lada/blob/main/README.md"
27 | issues = "https://codeberg.org/ladaapp/lada/issues"
28 | changelog = "https://codeberg.org/ladaapp/lada/releases"
29 |
30 | [build-system]
31 | requires = ["setuptools>=77.0.3"]
32 | build-backend = "setuptools.build_meta"
33 |
34 | [project.scripts]
35 | lada = "lada.gui.main:main"
36 | lada-cli = "lada.cli.main:main"
37 |
38 | [tool.setuptools]
39 | packages = { find = { where = ["."], include = ["lada", "lada.*"] } }
40 | [tool.setuptools.package-data]
41 | "*" = ['*.ui']
42 | "lada.gui" = ['style.css', 'resources.gresource']
43 | "lada.locale" = ['*.mo']
44 | "lada.utils" = ['encoding_presets.csv']
45 |
46 | [tool.setuptools.dynamic]
47 | version = { attr = "lada.VERSION" }
48 |
49 | [project.optional-dependencies]
50 | cu128 = ["torch==2.8.0", "torchvision==0.23.0"]
51 | gui = ["pycairo", "PyGObject"]
52 |
53 | [dependency-groups]
54 | dev = [
55 | { include-group = "training" },
56 | { include-group = "dataset-creation" }
57 | ]
58 | gui-dev = [
59 | { include-group = "dev" },
60 | "pygobject-stubs",
61 | ]
62 | training = ["albumentations", "tensorboard", "standard-imghdr"]
63 | dataset-creation = ["lap>=0.5.12", "timm", "einops", "pillow", "scipy", "onnx", "onnxruntime-gpu"]
64 | docker = ["opencv-python-headless"]
65 |
66 | [[tool.uv.index]]
67 | name = "pytorch-cu128"
68 | url = "https://download.pytorch.org/whl/cu128"
69 | explicit = true
70 |
71 | [tool.uv]
72 | conflicts = [
73 | [
74 | { package = "torch", extra = "cu128" },
75 | { package = "torchvision", extra = "cu128" },
76 | { package = "torch" },
77 | { package = "torchvision" },
78 | ],
79 | ]
80 |
81 | [tool.uv.sources]
82 | torch = [
83 | { index = "pytorch-cu128", extra = "cu128" },
84 | ]
85 | torchvision = [
86 | { index = "pytorch-cu128", extra = "cu128" },
87 | ]
88 |
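Two install invocations this configuration is meant to support (a sketch; the `cu128` extra resolves torch/torchvision against the pinned PyTorch index):
```bash
# Default PyTorch wheels plus the GTK GUI
pip install '.[gui]'

# CUDA 12.8 PyTorch wheels resolved through uv
uv sync --extra cu128
```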
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/flow_warp.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def flow_warp(x,
10 | flow,
11 | interpolation='bilinear',
12 | padding_mode='zeros',
13 | align_corners=True):
14 | """Warp an image or a feature map with optical flow.
15 |
16 | Args:
17 | x (Tensor): Tensor with size (n, c, h, w).
18 |         flow (Tensor): Tensor with size (n, h, w, 2). The last dimension holds
19 |             two channels denoting the relative offsets in width and height.
20 |             Note that the values are not normalized to [-1, 1].
21 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
22 | Default: 'bilinear'.
23 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
24 | Default: 'zeros'.
25 | align_corners (bool): Whether align corners. Default: True.
26 |
27 | Returns:
28 | Tensor: Warped image or feature map.
29 | """
30 | if x.size()[-2:] != flow.size()[1:3]:
31 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
32 | f'flow ({flow.size()[1:3]}) are not the same.')
33 | _, _, h, w = x.size()
34 | # create mesh grid
35 | device = flow.device
36 | # torch.meshgrid has been modified in 1.10.0 (compatibility with previous
37 | # versions), and will be further modified in 1.12 (Breaking Change)
38 | if 'indexing' in torch.meshgrid.__code__.co_varnames:
39 | grid_y, grid_x = torch.meshgrid(
40 | torch.arange(0, h, device=device, dtype=x.dtype),
41 | torch.arange(0, w, device=device, dtype=x.dtype),
42 | indexing='ij')
43 | else:
44 | grid_y, grid_x = torch.meshgrid(
45 | torch.arange(0, h, device=device, dtype=x.dtype),
46 | torch.arange(0, w, device=device, dtype=x.dtype))
47 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2
48 | grid.requires_grad_(False)
49 |
50 | grid_flow = grid + flow
51 | # scale grid_flow to [-1,1]
52 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
53 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
54 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
55 | grid_flow = grid_flow.type(x.type())
56 | output = F.grid_sample(
57 | x,
58 | grid_flow,
59 | mode=interpolation,
60 | padding_mode=padding_mode,
61 | align_corners=align_corners)
62 | return output
63 |
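A quick sanity check (a sketch): warping with an all-zero flow should return the input nearly unchanged, since the normalized sampling grid then hits the original pixel centers exactly.
```python
import torch

x = torch.randn(1, 3, 64, 64)
flow = torch.zeros(1, 64, 64, 2)  # zero offsets everywhere
out = flow_warp(x, flow)
assert out.shape == x.shape
assert torch.allclose(out, x, atol=1e-5)
```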
--------------------------------------------------------------------------------
/lada/utils/audio_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 |
6 | import av
7 | import io
8 | import os
9 | import subprocess
10 | import shutil
11 | from typing import Optional
12 | from lada.utils import video_utils, os_utils
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | def combine_audio_video_files(av_video_metadata: video_utils.VideoMetadata, tmp_v_video_input_path, av_video_output_path):
17 | audio_codec = get_audio_codec(av_video_metadata.video_file)
18 | if audio_codec:
19 | needs_audio_reencoding = not is_output_container_compatible_with_input_audio_codec(audio_codec, av_video_output_path)
20 | needs_video_delay = av_video_metadata.start_pts > 0
21 |
22 | cmd = ["ffmpeg", "-y", "-loglevel", "quiet"]
23 | cmd += ["-i", av_video_metadata.video_file]
24 |         if needs_video_delay:
25 | delay_in_seconds = float(av_video_metadata.start_pts * av_video_metadata.time_base)
26 | cmd += ["-itsoffset", str(delay_in_seconds)]
27 | cmd += ["-i", tmp_v_video_input_path]
28 | if needs_audio_reencoding:
29 | cmd += ["-c:v", "copy"]
30 | else:
31 | cmd += ["-c", "copy"]
32 | cmd += ["-map", "1:v:0"]
33 | cmd += ["-map", "0:a:0"]
34 | cmd += [av_video_output_path]
35 | subprocess.run(cmd, stdout=subprocess.PIPE, startupinfo=os_utils.get_subprocess_startup_info())
36 | else:
37 | shutil.copy(tmp_v_video_input_path, av_video_output_path)
38 | os.remove(tmp_v_video_input_path)
39 |
40 | def get_audio_codec(file_path: str) -> Optional[str]:
41 |     cmd = "ffprobe -loglevel error -select_streams a:0 -show_entries stream=codec_name -of default=nw=1:nk=1"
42 | cmd = cmd.split() + [file_path]
43 | cmd_result = subprocess.run(cmd, stdout=subprocess.PIPE, startupinfo=os_utils.get_subprocess_startup_info())
44 | audio_codec = cmd_result.stdout.decode('utf-8').strip().lower()
45 | return audio_codec if len(audio_codec) > 0 else None
46 |
47 | def is_output_container_compatible_with_input_audio_codec(audio_codec: str, output_path: str) -> bool:
48 | file_extension = os.path.splitext(output_path)[1]
49 | file_extension = file_extension.lower()
50 | if file_extension in ('.mp4', '.m4v'):
51 | output_container_format = "mp4"
52 | elif file_extension == '.mkv':
53 | output_container_format = "matroska"
54 | else:
55 | logger.info(f"Couldn't determine video container format based on file extension: {file_extension}")
56 | return False
57 |
58 | buf = io.BytesIO()
59 | with av.open(buf, 'w', output_container_format) as container:
60 | return audio_codec in container.supported_codecs
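For reference, when the input's audio codec is compatible with the output container and there is no start-PTS delay, the command assembled by `combine_audio_video_files` reduces to (hypothetical file names):
```bash
ffmpeg -y -loglevel quiet -i input.mp4 -i restored_video.mp4 -c copy -map 1:v:0 -map 0:a:0 output.mp4
```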
61 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/setup_env.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import datetime
6 | import importlib
7 | import warnings
8 | from types import ModuleType
9 | from typing import Optional
10 |
11 | from mmengine import DefaultScope
12 |
13 |
14 | def register_all_modules(init_default_scope: bool = True) -> None:
15 | """Register all modules in mmagic into the registries.
16 |
17 | Args:
18 | init_default_scope (bool): Whether initialize the mmagic default scope.
19 | When `init_default_scope=True`, the global default scope will be
20 | set to `mmagic`, and all registries will build modules from
21 | mmagic's registry node.
22 | To understand more about the registry, please refer
23 | to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html
24 | Defaults to True.
25 | """ # noqa
26 | # import mmagic.datasets # noqa: F401,F403
27 | # import mmagic.engine # noqa: F401,F403
28 | # import mmagic.evaluation # noqa: F401,F403
29 | # import mmagic.models # noqa: F401,F403
30 | # import mmagic.visualization # noqa: F401,F403
31 |
32 | if init_default_scope:
33 | never_created = DefaultScope.get_current_instance() is None \
34 | or not DefaultScope.check_instance_created('mmagic')
35 | if never_created:
36 | DefaultScope.get_instance('mmagic', scope_name='mmagic')
37 | return
38 | current_scope = DefaultScope.get_current_instance()
39 | if current_scope.scope_name != 'mmagic':
40 | warnings.warn('The current default scope '
41 | f'"{current_scope.scope_name}" is not "mmagic", '
42 |                           '`register_all_modules` will force the current '
43 | 'default scope to be "mmagic". If this is not '
44 | 'expected, please set `init_default_scope=False`.')
45 | # avoid name conflict
46 | new_instance_name = f'mmagic-{datetime.datetime.now()}'
47 | DefaultScope.get_instance(new_instance_name, scope_name='mmagic')
48 |
49 |
50 | def try_import(name: str) -> Optional[ModuleType]:
51 | """Try to import a module.
52 |
53 | Args:
54 | name (str): Specifies what module to import in absolute or relative
55 | terms (e.g. either pkg.mod or ..mod).
56 | Returns:
57 | ModuleType or None: If importing successfully, returns the imported
58 | module, otherwise returns None.
59 | """
60 | try:
61 | return importlib.import_module(name)
62 | except ImportError:
63 | return None
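`try_import` is useful for optional dependencies, e.g.:
```python
onnx = try_import('onnx')
if onnx is None:
    print('onnx is not installed, skipping ONNX export')
```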
64 |
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/face_detector.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import cv2
5 | import numpy as np
6 | from lada.utils import Mask, Box, Image, Detection, Detections, DETECTION_CLASSES
7 | from lada.utils import box_utils
8 | from lada.models.centerface.centerface import CenterFace
9 |
10 | def convert_to_boxes(dets) -> list[Box]:
11 |     boxes = []
12 |     for det in dets:
13 |         box = det[:4]
14 |         x1, y1, x2, y2 = box.astype(int)
15 |         boxes.append((int(y1), int(x1), int(y2), int(x2)))
16 |     return boxes
17 |
18 | def create_mask(frame: Image, box: Box) -> Mask:
19 | t, l, b, r = box
20 | box_width, box_height = r - l + 1, b - t + 1
21 |
22 | mask = np.zeros(frame.shape[:2], dtype=np.uint8)
23 |
24 | # Set the center of the ellipse at the center of the box
25 | center = (l + (box_width // 2), t + (box_height // 2))
26 |
27 | # Set the axes of the ellipse to half the width and half the height of the box
28 | axes = (box_width // 2, box_height // 2)
29 |
30 | angle = 0
31 | start_angle = 0
32 | end_angle = 360
33 |
34 | color = DETECTION_CLASSES["sfw_face"]["mask_value"]
35 | thickness = -1
36 |
37 | cv2.ellipse(mask, center, axes, angle, start_angle, end_angle, color, thickness)
38 |
39 | mask = np.expand_dims(mask, axis=-1)
40 |
41 | return mask
42 |
43 | def get_nsfw_frame(dets: list[Box], frame: Image, random_extend_masks: bool) -> Detections | None:
44 | if len(dets) == 0:
45 | return None
46 | detections = []
47 |     for box in dets:
48 |         if random_extend_masks:
49 |             box = box_utils.random_scale_box(frame, box, scale_range=(1.2, 1.5))
50 | 
51 |         t, l, b, r = box
52 |         width, height = r - l + 1, b - t + 1
53 |         if min(width, height) < 40:
54 |             # skip tiny detections before building their masks
55 |             continue
56 | 
57 |         mask = create_mask(frame, box)
58 | 
59 |         detections.append(Detection(DETECTION_CLASSES["sfw_face"]["cls"], box, mask))
60 | return Detections(frame, detections)
61 |
62 | class FaceDetector:
63 | def __init__(self, model: CenterFace, random_extend_masks=False, conf=0.2):
64 | self.model = model
65 | self.random_extend_masks = random_extend_masks
66 | self.conf = conf
67 |
68 | def detect(self, source: str | Image) -> Detections | None:
69 | if isinstance(source, str):
70 | bgr_image = cv2.imread(source)
71 | rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
72 | else:
73 | bgr_image = source
74 | rgb_image = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
75 | dets, _ = self.model(rgb_image, threshold=self.conf)
76 | dets_boxes = convert_to_boxes(dets)
77 | return get_nsfw_frame(dets_boxes, bgr_image, random_extend_masks=self.random_extend_masks)
78 |
--------------------------------------------------------------------------------
/lada/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from dataclasses import dataclass
5 | from fractions import Fraction
6 |
7 | import numpy as np
8 | import torch
9 |
10 | """
11 | A bounding box of a detected object defined by two points, the top/left and bottom/right pixel.
12 | Represented as X/Y coordinate tuple: top-left (Y), top-left (X), bottom-right (Y), bottom-right (X)
13 | """
14 | type Box = tuple[int, int, int, int]
15 |
16 | """
17 | A segmentation mask of a detected object. Pixel values of 0 indicate that the pixel is not part of the object.
18 | Shape: (H, W, 1), dtype: np.uint8, range: 0-255
19 | """
20 | type Mask = np.ndarray[np.uint8]
21 |
22 | """
23 | A segmentation mask of a detected object. Pixel values of 0 indicate that the pixel is not part of the object.
24 | Shape: (H, W, 1), dtype: torch.uint8, range: 0-255
25 | """
26 | type MaskTensor = torch.Tensor
27 |
28 | """
29 | Color Image
30 | Shape: (H, W, C=3), dtype: np.uint8, range: 0-255
31 | H, W, C stand for image height, width and color channels respectively. C is in BGR instead of RGB order
32 | """
33 | type Image = np.ndarray[np.uint8]
34 |
35 | """
36 | Color Image
37 | Shape: (H, W, C=3), dtype: torch.uint8, range: 0-255
38 | H, W, C stand for image height, width and color channels respectively. C is in BGR instead of RGB order
39 | """
40 | type ImageTensor = torch.Tensor
41 |
42 | """
43 | Padding of an Image or Mask represented as tuple padding values (number of black pixels) added to each image edge:
44 | (padding-top, padding-bottom, padding-left, padding-right)
45 | """
46 | type Pad = tuple[int, int, int, int]
47 |
48 | """
49 | Metadata about a video file
50 | """
51 | @dataclass
52 | class VideoMetadata:
53 | video_file: str
54 | video_height: int
55 | video_width: int
56 | video_fps: float
57 | average_fps: float
58 | video_fps_exact: Fraction
59 | codec_name: str
60 | frames_count: int
61 | duration: float
62 | time_base: Fraction
63 | start_pts: int
64 |
65 | @dataclass
66 | class Detection:
67 | cls: int
68 | box: Box
69 | mask: Mask # Binary segmentation mask. Values can be either 0 (background) or mask_val
70 | confidence: float | None = None # value between 0 and 1 where 1 is completely certain
71 |
72 | """
73 | Detection result containing bounding box and segmentation mask of the detected object within the frame
74 | """
75 | @dataclass
76 | class Detections:
77 | frame: Image
78 | detections: list[Detection]
79 |
80 | """
81 | Mapping for class ids and mask values.
82 | Mask value is a non-zero value used in a binary mask (Mask) to indicate that a pixel belongs to the class
83 | """
84 | DETECTION_CLASSES = {
85 | "nsfw": dict(cls=0, mask_value=255),
86 | "sfw_head": dict(cls=1, mask_value=127),
87 | "sfw_face": dict(cls=2, mask_value=192),
88 | "watermark": dict(cls=3, mask_value=60),
89 | "mosaic": dict(cls=4, mask_value=90),
90 | }
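A small sketch of how these types fit together (hypothetical values; note the (top, left, bottom, right) box order):
```python
import numpy as np

box: Box = (10, 20, 110, 220)  # t, l, b, r in pixel coordinates
mask: Mask = np.zeros((1080, 1920, 1), dtype=np.uint8)
mask[10:111, 20:221, 0] = DETECTION_CLASSES["nsfw"]["mask_value"]  # mark the box region

detection = Detection(cls=DETECTION_CLASSES["nsfw"]["cls"], box=box, mask=mask, confidence=0.9)
detections = Detections(frame=np.zeros((1080, 1920, 3), dtype=np.uint8), detections=[detection])
```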
--------------------------------------------------------------------------------
/lada/gui/preview/headerbar_files_drop_down.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 | import pathlib
6 |
7 | from gi.repository import Gtk, Gio, GObject, Pango
8 |
9 | from lada import LOG_LEVEL
10 | from lada.gui import utils
11 |
12 | here = pathlib.Path(__file__).parent.resolve()
13 |
14 | logger = logging.getLogger(__name__)
15 | logging.basicConfig(level=LOG_LEVEL)
16 |
17 |
18 | @Gtk.Template(string=utils.translate_ui_xml(here / 'headerbar_files_drop_down.ui'))
19 | class HeaderbarFilesDropDown(Gtk.DropDown):
20 | __gtype_name__ = "HeaderbarFilesDropDown"
21 |
22 | def __init__(self, **kwargs) -> None:
23 | super().__init__(**kwargs)
24 |
25 |         # Custom drop down factory so we can limit the dropdown button width and ellipsize the selected item in the header bar if it is too long
26 | button_factory = Gtk.SignalListItemFactory()
27 | button_factory.connect("setup", self.on_button_item_setup)
28 | button_factory.connect("bind", self.on_item_bind)
29 | button_factory.connect("unbind", self.on_item_unbind)
30 |
31 | # Custom drop down factory for popup items shown when dropdown is clicked without width restrictions
32 | popup_factory = Gtk.SignalListItemFactory()
33 | popup_factory.connect("setup", self.on_popup_item_setup)
34 | popup_factory.connect("bind", self.on_item_bind)
35 | popup_factory.connect("unbind", self.on_item_unbind)
36 |
37 | expression = Gtk.ClosureExpression.new(
38 | GObject.TYPE_STRING,
39 | lambda obj: obj.get_string(),
40 | None,
41 | )
42 |
43 | model = Gtk.StringList()
44 | self.props.model = model
45 | self.props.expression = expression
46 | self.props.factory = button_factory
47 | self.props.list_factory = popup_factory
48 |
49 | def on_popup_item_setup(self, factory, list_item):
50 | label = Gtk.Label()
51 | list_item.set_child(label)
52 |
53 | def on_button_item_setup(self, factory, list_item):
54 | label = Gtk.Label()
55 | label.set_ellipsize(Pango.EllipsizeMode.END)
56 | label.set_max_width_chars(20)
57 | list_item.set_child(label)
58 |
59 | def on_item_bind(self, factory, list_item):
60 | label = list_item.get_child()
61 | idx = list_item.get_position()
62 |
63 | label.set_text(self.props.model[idx].get_string())
64 |
65 | def on_item_unbind(self, factory, list_item):
66 | label = list_item.get_child()
67 | if label:
68 | label.set_text("")
69 |
70 | def add_files(self, files: list[Gio.File]):
71 | for file in files:
72 | self.props.model.append(file.get_basename())
73 |
74 | is_multiple_files = len(self.props.model) > 1
75 | self.set_enable_search(is_multiple_files)
76 | self.set_sensitive(is_multiple_files)
77 | self.set_show_arrow(is_multiple_files)
--------------------------------------------------------------------------------
/scripts/training/train-bj-classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import torch
5 | import torchvision
6 | from torchvision import transforms
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import os
10 |
11 | transform = transforms.Compose([
12 | transforms.Resize(224),
13 | transforms.ToTensor(),
14 | transforms.Normalize(
15 | mean=[0.485, 0.456, 0.406],
16 | std=[0.229, 0.224, 0.225]
17 | )
18 | ])
19 |
20 | train_data = torchvision.datasets.ImageFolder(root="datasets/pov_bj_scene_detection/train", transform=transform)
21 | test_data = torchvision.datasets.ImageFolder(root="datasets/pov_bj_scene_detection/val", transform=transform)
22 |
23 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
24 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)
25 |
26 | model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)  # 'pretrained=True' is deprecated in newer torchvision
27 |
28 | # Replace the last layer to match our own classes
29 | num_features = model.fc.in_features
30 | model.fc = nn.Linear(num_features, len(train_data.classes))
31 |
32 | criterion = nn.CrossEntropyLoss()
33 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
34 |
35 | device = torch.device("cuda:0")
36 | model = model.to(device)
37 |
38 | num_epochs = 15
39 | best_test_acc = 0.0
40 | experiment_name = "run1"
41 | experiment_root_dir = "../../experiments/bj_classifier"
42 | experiment_dir = f"{experiment_root_dir}/{experiment_name}"
43 | os.makedirs(experiment_dir)
44 |
45 | for epoch in range(num_epochs):
46 | # Train
47 | model.train()
48 | train_loss = 0.0
49 | for i, (inputs, labels) in enumerate(train_loader):
50 | inputs = inputs.to(device)
51 | labels = labels.to(device)
52 |
53 | optimizer.zero_grad()
54 |
55 | outputs = model(inputs)
56 | loss = criterion(outputs, labels)
57 | loss.backward()
58 | optimizer.step()
59 |
60 | train_loss += loss.item() * inputs.size(0)
61 |
62 | # Evaluate
63 | model.eval()
64 | test_loss = 0.0
65 | test_acc = 0.0
66 | with torch.no_grad():
67 | for i, (inputs, labels) in enumerate(test_loader):
68 | inputs = inputs.to(device)
69 | labels = labels.to(device)
70 |
71 | outputs = model(inputs)
72 | loss = criterion(outputs, labels)
73 |
74 | # Update the test loss and accuracy
75 | test_loss += loss.item() * inputs.size(0)
76 | _, preds = torch.max(outputs, 1)
77 | test_acc += torch.sum(preds == labels.data)
78 |
79 | train_loss /= len(train_data)
80 | test_loss /= len(test_data)
81 | test_acc = test_acc.double() / len(test_data)
82 | print(f"Epoch [{epoch + 1}/{num_epochs}] Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}")
83 |     # Only save checkpoints that improve test accuracy
84 |     if test_acc > best_test_acc:
85 |         best_test_acc = test_acc
86 |         save_path = f"{experiment_dir}/checkpoint_{epoch}.pt"
87 |         torch.save(model.state_dict(), save_path)
87 |
--------------------------------------------------------------------------------
/scripts/dataset_creation/convert-dataset-labelme-to-yolo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import glob
5 | import os
6 | import json
7 | import argparse
8 | import shutil
9 |
10 | def convert_to_yolo_txt_lines(labelme_json, labelme_label_to_yolo_class_mapping={"nsfw": 0}):
11 | image_height = labelme_json["imageHeight"]
12 | image_width = labelme_json["imageWidth"]
13 |
14 | point_txt = []
15 |
16 | for shape in labelme_json["shapes"]:
17 | if shape["label"] not in labelme_label_to_yolo_class_mapping:
18 | continue
19 | if shape["shape_type"] != "polygon":
20 | continue
21 |
22 | yolo_class = labelme_label_to_yolo_class_mapping[shape["label"]]
23 | txt = f"{yolo_class}"
24 |
25 | for w, h in shape["points"]:
26 | txt = f'{txt} {float(w)/image_width} {float(h)/image_height}'
27 |
28 | point_txt.append(txt)
29 | return point_txt
30 |
31 | def main(input_json_dir, output_text_dir, output_images_dir):
32 | labelme_json_file_paths = glob.glob(os.path.join(input_json_dir, "*.json"))
33 | for labelme_json_file_path in labelme_json_file_paths:
34 | with open(labelme_json_file_path) as labelme_json_file:
35 | labelme_json = json.load(labelme_json_file)
36 | image_path = labelme_json['imagePath']
37 | image_filename = os.path.basename(image_path)
38 |
39 | yolo_img_file_path = os.path.join(output_images_dir, image_filename)
40 | shutil.copyfile(os.path.join(input_json_dir, image_path), yolo_img_file_path)
41 |
42 | yolo_txt_lines = convert_to_yolo_txt_lines(labelme_json)
43 | if len(yolo_txt_lines) == 0:
44 | continue
45 |
46 | image_filename_without_ext = os.path.splitext(image_filename)[0]
47 | yolo_txt_file_path = os.path.join(output_text_dir, image_filename_without_ext + '.txt')
48 | with open(yolo_txt_file_path, 'w') as yolo_txt_file:
49 | for line in yolo_txt_lines:
50 | yolo_txt_file.write(line)
51 | yolo_txt_file.write('\n')
52 |
53 | def parse_args():
54 | parser = argparse.ArgumentParser()
55 |     parser.add_argument('--dir-in', type=str, required=True)
56 |     parser.add_argument('--dir-out-images', type=str, required=True)
57 |     parser.add_argument('--dir-out-labels', type=str, required=True)
58 | args = parser.parse_args()
59 | return args
60 |
61 |
62 | if __name__ == '__main__':
63 | args = parse_args()
64 | main(args.dir_in, args.dir_out_labels, args.dir_out_images)
65 |
66 | # mkdir -p yolo/datasets/nsfw_detection/{train,val}/{images,labels}
67 | # python yolo/convert-dataset-labelme-to-yolo.py --dir-in datasets/nsfw_detection/val --dir-out-images yolo/datasets/nsfw_detection/val/images --dir-out-labels yolo/datasets/nsfw_detection/val/labels
68 | # python yolo/convert-dataset-labelme-to-yolo.py --dir-in datasets/nsfw_detection/train --dir-out-images yolo/datasets/nsfw_detection/train/images --dir-out-labels yolo/datasets/nsfw_detection/train/labels
69 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/loop_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | from logging import WARNING
6 | from typing import Any, Dict, List, Union
7 |
8 | from mmengine import is_list_of, print_log
9 | from mmengine.evaluator import Evaluator
10 |
11 | EVALUATOR_TYPE = Union[Evaluator, Dict, List]
12 |
13 |
14 | def update_and_check_evaluator(evaluator: EVALUATOR_TYPE
15 | ) -> Union[Evaluator, dict]:
16 |     """Check whether the evaluator instance or dict config is an Evaluator. If
17 |     the input is a dict config, attempt to set the evaluator type to Evaluator
18 |     and raise a warning if another type is configured. If the input is an
19 |     Evaluator instance, warn unless it is the Evaluator class itself.
20 |
21 | Args:
22 | evaluator (Union[Evaluator, dict, list]): The evaluator instance or
23 | config dict.
24 | """
25 | # check Evaluator instance
26 | warning_template = ('Evaluator type for current config is \'{}\'. '
27 | 'If you want to use MultiValLoop, we strongly '
28 | 'recommend you to use \'Evaluator\' provided by '
29 | '\'MMagic\'. Otherwise, there maybe some potential '
30 | 'bugs.')
31 | if isinstance(evaluator, Evaluator):
32 | cls_name = evaluator.__class__.__name__
33 | if cls_name != 'Evaluator':
34 | print_log(warning_template.format(cls_name), 'current', WARNING)
35 | return evaluator
36 |
37 | # add type for **single evaluator with list of metrics**
38 | if isinstance(evaluator, list):
39 | evaluator = dict(type='Evaluator', metrics=evaluator)
40 | return evaluator
41 |
42 | # check and update dict config
43 | assert isinstance(evaluator, dict), (
44 | 'Can only conduct check and update for list of metrics, a config dict '
45 | f'or a Evaluator object. But receives {type(evaluator)}.')
46 | evaluator.setdefault('type', 'Evaluator')
47 | evaluator.setdefault('metrics', None) # default as 'dummy evaluator'
48 | _type = evaluator['type']
49 | if _type != 'Evaluator':
50 | print_log(warning_template.format(_type), 'current', WARNING)
51 | return evaluator
52 |
53 |
54 | def is_evaluator(evaluator: Any) -> bool:
55 | """Check whether the input is a valid evaluator config or Evaluator object.
56 |
57 | Args:
58 | evaluator (Any): The input to check.
59 |
60 | Returns:
61 | bool: Whether the input is a valid evaluator config or Evaluator
62 | object.
63 | """
64 | # Single evaluator with type
65 | if isinstance(evaluator, dict) and 'metrics' in evaluator:
66 | return True
67 | # Single evaluator without type
68 | elif (is_list_of(evaluator, dict)
69 | and all(['metrics' not in cfg_ for cfg_ in evaluator])):
70 | return True
71 | elif isinstance(evaluator, Evaluator):
72 | return True
73 | else:
74 | return False
75 |
--------------------------------------------------------------------------------
/lada/models/deepmosaics/util/data.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: DeepMosaics Authors
2 | # SPDX-License-Identifier: GPL-3.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/HypoX64/DeepMosaics/
4 |
5 | import numpy as np
6 | import torch
7 | import cv2
8 |
9 | def to_tensor(data,gpu_id,dtype):
10 | data = torch.from_numpy(data)
11 | if gpu_id != '-1':
12 | data = data.to(device=f'cuda:{gpu_id}',dtype=dtype)
13 | return data
14 |
15 | def normalize(data):
16 | '''
17 | normalize to -1 ~ 1
18 | '''
19 | return (data.astype(np.float32)/255.0-0.5)/0.5
20 |
21 | def anti_normalize(data):
22 | return np.clip((data*0.5+0.5)*255,0,255).astype(np.uint8)
23 |
24 | def tensor2im(image_tensor, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0):
25 | image_tensor =image_tensor.data
26 | image_numpy = image_tensor[batch_index].cpu().float().numpy()
27 |
28 | if not is0_1:
29 | image_numpy = (image_numpy + 1)/2.0
30 | image_numpy = np.clip(image_numpy * 255.0,0,255)
31 |
32 | # gray -> output 1ch
33 | if gray:
34 | h, w = image_numpy.shape[1:]
35 | image_numpy = image_numpy.reshape(h,w)
36 | return image_numpy.astype(np.uint8)
37 |
38 | # output 3ch
39 | if image_numpy.shape[0] == 1:
40 | image_numpy = np.tile(image_numpy, (3, 1, 1))
41 | image_numpy = image_numpy.transpose((1, 2, 0))
42 | if rgb2bgr and not gray:
43 | image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy)
44 | return image_numpy.astype(np.uint8)
45 |
46 |
47 | def im2tensor(image_numpy, gray=False,bgr2rgb = True, reshape = True, gpu_id = '-1',is0_1 = False, dtype=torch.float32):
48 | if gray:
49 | h, w = image_numpy.shape
50 | image_numpy = (image_numpy/255.0-0.5)/0.5
51 | image_tensor = torch.from_numpy(image_numpy).float()
52 | if reshape:
53 | image_tensor = image_tensor.reshape(1,1,h,w)
54 | else:
55 | h, w ,ch = image_numpy.shape
56 | if bgr2rgb:
57 | image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy)
58 | if is0_1:
59 | image_numpy = image_numpy/255.0
60 | else:
61 | image_numpy = (image_numpy/255.0-0.5)/0.5
62 | image_numpy = image_numpy.transpose((2, 0, 1))
63 | image_tensor = torch.from_numpy(image_numpy).float()
64 | if reshape:
65 | image_tensor = image_tensor.reshape(1,ch,h,w)
66 | if gpu_id != '-1':
67 | image_tensor = image_tensor.to(device=f'cuda:{gpu_id}',dtype=dtype)
68 | return image_tensor
69 |
70 | def shuffledata(data,target):
71 | state = np.random.get_state()
72 | np.random.shuffle(data)
73 | np.random.set_state(state)
74 | np.random.shuffle(target)
75 |
76 | def showresult(img1,img2,img3,name,is0_1 = False):
77 | size = img1.shape[3]
78 | showimg=np.zeros((size,size*3,3))
79 | showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = is0_1)
80 | showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = is0_1)
81 | showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = is0_1)
82 | cv2.imwrite(name, showimg)
83 |
--------------------------------------------------------------------------------
/packaging/flatpak/io.github.ladaapp.lada.yaml:
--------------------------------------------------------------------------------
1 | id: io.github.ladaapp.lada
2 | runtime: org.gnome.Platform
3 | runtime-version: '49'
4 | desktop-file-name-suffix: " (dev)"
5 | default-branch: "main"
6 | sdk: org.gnome.Sdk
7 | command: lada
8 | finish-args:
9 | - "--socket=wayland"
10 | - "--socket=fallback-x11"
11 | - "--device=dri"
12 | - "--share=ipc"
13 | - "--socket=pulseaudio"
14 | - "--talk-name=org.gnome.SessionManager" # Post-export automatic system shutdown on GNOME desktops
15 | - "--talk-name=org.kde.Shutdown" # Post-export automatic system shutdown on KDE desktops
16 | - "--env=LADA_MODEL_WEIGHTS_DIR=/app/model_weights"
17 | - "--env=LOCALE_DIR=/app/lib/python3.13/site-packages/lada/locale"
18 | - "--env=YOLO_CONFIG_DIR=/var/config/yolo"
19 | modules:
20 | - lada-python-dependencies.yaml
21 | - name: lada-model-weights
22 | buildsystem: simple
23 | build-commands:
24 | - |
25 | mkdir -p /app/model_weights/3rd_party
26 | mv lada_*.pt{,h} /app/model_weights/
27 | mv clean_youknow_video.pth /app/model_weights/3rd_party/
28 | sources:
29 | - type: file
30 | url: "https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_detection_model_v4_accurate.pt?download=true"
31 | sha256: c244d7e49d8f88e264b8dc15f91fb21f5908ad8fb6f300b7bc88462d0801bc1f
32 | - type: file
33 | url: "https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_detection_model_v4_fast.pt?download=true"
34 | sha256: 9a6b660d1d3e3797d39515e08b0e72fcc59815f38279faa7a4ab374ab2c1e3b4
35 | - type: file
36 | url: "https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_restoration_model_generic_v1.2.pth?download=true"
37 | sha256: d404152576ce64fb5b2f315c03062709dac4f5f8548934866cd01c823c8104ee
38 | - type: file
39 | url: "https://drive.usercontent.google.com/download?id=1ulct4RhRxQp1v5xwEmUH7xz7AK42Oqlw&export=download&confirm=t"
40 | sha256: 5643ca297c13920b8ffd39a0d85296e494683a69e5e8204d662653d24c582766
41 | dest-filename: clean_youknow_video.pth
42 | - name: lada
43 | buildsystem: simple
44 | build-commands:
45 | - sh translations/compile_po.sh --release
46 | - python3 -m pip install --prefix=/app --no-deps --no-build-isolation '.[gui]'
47 | - |
48 | mkdir -p /app/share/applications /app/share/metainfo /app/share/icons/hicolor/128x128/apps
49 | install -m 644 packaging/flatpak/share/io.github.ladaapp.lada.desktop /app/share/applications
50 | install -m 644 packaging/flatpak/share/io.github.ladaapp.lada.metainfo.xml /app/share/metainfo
51 | install -m 644 packaging/flatpak/share/io.github.ladaapp.lada.png /app/share/icons/hicolor/128x128/apps
52 | - |
53 | patch -u -p1 -d /app/lib/python3.13/site-packages < $FLATPAK_BUILDER_BUILDDIR/patches/increase_mms_time_limit.patch
54 | patch -u -p1 -d /app/lib/python3.13/site-packages < $FLATPAK_BUILDER_BUILDDIR/patches/remove_ultralytics_telemetry.patch
55 | patch -u -p1 -d /app/lib/python3.13/site-packages < $FLATPAK_BUILDER_BUILDDIR/patches/fix_loading_mmengine_weights_on_torch26_and_higher.diff
56 | sources:
57 | - type: dir
58 | path: "../.."
59 |
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/inference.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import logging
5 |
6 | import numpy as np
7 | import torch
8 | from mmengine.config import Config
9 | from mmengine.runner import load_checkpoint
10 |
11 | from lada.utils import Image
12 | from lada.utils import image_utils
13 | from lada.models.basicvsrpp import register_all_modules
14 | from lada.models.basicvsrpp.basicvsrpp_gan import BasicVSRPlusPlusGan
15 | from lada.models.basicvsrpp.mmagic.basicvsr import BasicVSR
16 | from lada.models.basicvsrpp.mmagic.registry import MODELS
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | def get_default_gan_inference_config() -> dict:
21 | return dict(
22 | type='BasicVSRPlusPlusGan',
23 | generator=dict(
24 | type='BasicVSRPlusPlusGanNet',
25 | mid_channels=64,
26 | num_blocks=15,
27 | spynet_pretrained=None),
28 | pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
29 | is_use_ema=True,
30 | data_preprocessor=dict(
31 | type='DataPreprocessor',
32 | mean=[0., 0., 0.],
33 | std=[255., 255., 255.],
34 | ))
35 |
36 |
37 | def load_model(config: str | dict | None, checkpoint_path, device, fp16=False) -> BasicVSRPlusPlusGan | BasicVSR:
38 | register_all_modules()
39 |     if device and isinstance(device, str):
40 |         device = torch.device(device)
41 |     if config is None:
42 |         config = get_default_gan_inference_config()
43 |     elif isinstance(config, str):
44 |         config = Config.fromfile(config).model
45 |     elif isinstance(config, dict):
46 |         pass
47 |     else:
48 |         raise ValueError("Unsupported value for 'config'. Must be either a file path to a config file or a dict definition of the model.")
49 | model = MODELS.build(config)
50 | assert isinstance(model, BasicVSRPlusPlusGan) or isinstance(model, BasicVSR), "Unknown model config. Must be either stage1 (BasicVSR) or stage2 (BasicVSRPlusPlusGan)"
51 | load_checkpoint(model, checkpoint_path, map_location='cpu', logger=logger)
52 | model.cfg = config
53 | model = model.to(device).eval()
54 | if fp16:
55 | model = model.half()
56 | return model
57 |
58 | def inference(model: BasicVSRPlusPlusGan | BasicVSR, video: list[Image], device) -> list[Image]:
59 | input_frame_count = len(video)
60 | input_frame_shape = video[0].shape
61 |     if device and isinstance(device, str):
62 | device = torch.device(device)
63 | with torch.no_grad():
64 | input = torch.stack(image_utils.img2tensor(video, bgr2rgb=False, float32=True), dim=0)
65 | input = torch.unsqueeze(input, dim=0) # TCHW -> BTCHW
66 | result = model(inputs=input.to(device))
67 | result = torch.squeeze(result, dim=0) # BTCHW -> TCHW
68 | result = list(torch.unbind(result, 0))
69 | output = image_utils.tensor2img(result, rgb2bgr=False, out_type=np.uint8, min_max=(0, 1))
70 | output_frame_count = len(output)
71 | output_frame_shape = output[0].shape
72 | assert input_frame_count == output_frame_count and input_frame_shape == output_frame_shape
73 | return output
74 |
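A minimal usage sketch (assumes a CUDA device and the generic v1.2 restoration weights listed in `model_weights/checksums_sha256.txt`; `load_clip_frames` is a hypothetical helper returning BGR uint8 frames):
```python
frames: list[Image] = load_clip_frames("clip.mp4")
model = load_model(None, "model_weights/lada_mosaic_restoration_model_generic_v1.2.pth", "cuda:0")
restored = inference(model, frames, "cuda:0")
```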
--------------------------------------------------------------------------------
/lada/gui/fileselection/file_selection_view.ui:
--------------------------------------------------------------------------------
(GTK Builder XML; the markup tags were stripped in this dump. Recoverable content: a status-page style file selection view titled "Restore pixelated adult videos" with the description "Drop one or more videos here" and a centered "Open Files…" button.)
--------------------------------------------------------------------------------
/configs/basicvsrpp/mosaic_restoration_generic_stage1.py:
--------------------------------------------------------------------------------
1 | from mmengine.config import read_base
2 |
3 | with read_base():
4 | from ._base_.default_runtime import *
5 |
6 | experiment_name = 'mosaic_restoration_generic_stage1'
7 | work_dir = f'./experiments/basicvsrpp/{experiment_name}'
8 | save_dir = './experiments/basicvsrpp'
9 |
10 | model = dict(
11 | type='BasicVSR',
12 | generator=dict(
13 | type='BasicVSRPlusPlusNet',
14 | mid_channels=64,
15 | num_blocks=15,
16 | spynet_pretrained='model_weights/3rd_party/spynet_20210409-c6c1bd09.pth'),
17 | pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
18 | train_cfg=dict(fix_iter=5000),
19 | data_preprocessor=dict(
20 | type='DataPreprocessor',
21 | mean=[0., 0., 0.],
22 | std=[255., 255., 255.],
23 | ))
24 |
25 | data_root = 'datasets/mosaic_removal_vid'
26 |
27 | train_dataloader = dict(
28 | num_workers=4,
29 | batch_size=2,
30 | persistent_workers=False,
31 | sampler=dict(type='InfiniteSampler', shuffle=True),
32 | dataset=dict(
33 | type='MosaicVideoDataset',
34 | metadata_root_dir=data_root + "/train/crop_unscaled_meta",
35 | num_frame=30,
36 | degrade=True,
37 | use_hflip=True,
38 | repeatable_random=False,
39 | random_mosaic_params=True,
40 | filter_watermark=False,
41 | filter_nudenet_nsfw=False,
42 | filter_video_quality=False,
43 | lq_size=256),
44 | collate_fn=dict(type='default_collate'))
45 |
46 | val_dataloader = dict(
47 | num_workers=1,
48 | batch_size=1,
49 | persistent_workers=False,
50 | sampler=dict(type='DefaultSampler', shuffle=False),
51 | dataset=dict(
52 | type='MosaicVideoDataset',
53 | metadata_root_dir=data_root + "/val/crop_unscaled_meta",
54 | num_frame=30,
55 | degrade=True,
56 | use_hflip=False,
57 | repeatable_random=True,
58 | random_mosaic_params=True,
59 | filter_watermark=False,
60 | filter_nudenet_nsfw=False,
61 | filter_video_quality=False,
62 | lq_size=256),
63 | collate_fn=dict(type='default_collate'))
64 |
65 | val_evaluator = dict(
66 | type='Evaluator', metrics=[
67 | dict(type='PSNR'),
68 | dict(type='SSIM'),
69 | ])
70 |
71 | train_cfg = dict(
72 | type='IterBasedTrainLoop', max_iters=100_000, val_interval=5000)
73 | val_cfg = dict(type='MultiValLoop')
74 |
75 | optim_wrapper = dict(
76 | constructor='DefaultOptimWrapperConstructor',
77 | type='OptimWrapper',
78 | optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)),
79 | paramwise_cfg=dict(custom_keys={'spynet': dict(lr_mult=0.25)}))
80 |
81 |
82 | vis_backends = [dict(type='TensorboardVisBackend')]
83 | visualizer = dict(
84 | name='visualizer',
85 | type='ConcatImageVisualizer',
86 | vis_backends=vis_backends,
87 | fn_key='gt_path',
88 | img_keys=['gt_img', 'input', 'pred_img'],
89 | bgr2rgb=True)
90 | custom_hooks = [dict(type='BasicVisualizationHook', interval=5)]
91 |
92 | default_hooks = dict(
93 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000, out_dir=save_dir),
94 | logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False))
95 |
--------------------------------------------------------------------------------
/packaging/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.13.9-slim-trixie AS build
2 | RUN useradd --create-home lada
3 | RUN apt-get update \
4 | && apt-get install -y --no-install-recommends gettext patch curl xz-utils \
5 | && apt-get clean \
6 | && rm -rf /var/lib/apt/lists/*
7 | USER lada
8 | WORKDIR /home/lada
9 | RUN mkdir -p .local/bin \
10 | && curl -L 'https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-n8.0-latest-linux64-gpl-8.0.tar.xz' \
11 | | tar xJf - -C .local/bin --strip-components 2 ffmpeg-n8.0-latest-linux64-gpl-8.0/bin/ffmpeg ffmpeg-n8.0-latest-linux64-gpl-8.0/bin/ffprobe \
12 | && chmod 555 .local/bin/*
13 | COPY packaging/docker/requirements.txt packaging/docker/requirements.txt
14 | RUN pip install --user --no-deps --no-cache-dir --requirement packaging/docker/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu128
15 | COPY patches patches
16 | RUN patch -u -p1 -d .local/lib/python3.13/site-packages < patches/increase_mms_time_limit.patch \
17 | && patch -u -p1 -d .local/lib/python3.13/site-packages < patches/remove_ultralytics_telemetry.patch \
18 | && patch -u -p1 -d .local/lib/python3.13/site-packages < patches/fix_loading_mmengine_weights_on_torch26_and_higher.diff
19 | COPY --chown=lada:lada . .
20 | RUN sh translations/compile_po.sh --release
21 | RUN pip install --user --no-deps --no-cache-dir '.'
22 |
23 |
24 | FROM python:3.13.9-slim-trixie
25 | RUN useradd --create-home lada
26 | RUN mkdir -p /home/lada/.config/Ultralytics \
27 | && echo '{"settings_version":"0.0.6","datasets_dir":"datasets","weights_dir":"weights","runs_dir":"experiments","uuid":"dummy","sync":false,"api_key":"","openai_api_key":"","clearml":false,"comet":false,"dvc":false,"hub":false,"mlflow":false,"neptune":false,"raytune":false,"tensorboard":false,"wandb":false,"vscode_msg":false,"openvino_msg":false}' > /home/lada/.config/Ultralytics/settings.json \
28 | && chown -R lada:lada /home/lada/.config
29 | RUN mkdir -p model_weights/3rd_party mnt && chmod -R 555 model_weights && chmod 777 mnt
30 | ADD --checksum=sha256:c244d7e49d8f88e264b8dc15f91fb21f5908ad8fb6f300b7bc88462d0801bc1f \
31 | https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_detection_model_v4_accurate.pt?download=true model_weights/lada_mosaic_detection_model_v4_accurate.pt
32 | ADD --checksum=sha256:9a6b660d1d3e3797d39515e08b0e72fcc59815f38279faa7a4ab374ab2c1e3b4 \
33 | https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_detection_model_v4_fast.pt?download=true model_weights/lada_mosaic_detection_model_v4_fast.pt
34 | ADD --checksum=sha256:d404152576ce64fb5b2f315c03062709dac4f5f8548934866cd01c823c8104ee \
35 | https://huggingface.co/ladaapp/lada/resolve/main/lada_mosaic_restoration_model_generic_v1.2.pth?download=true model_weights/lada_mosaic_restoration_model_generic_v1.2.pth
36 | ADD --checksum=sha256:5643ca297c13920b8ffd39a0d85296e494683a69e5e8204d662653d24c582766 \
37 | https://drive.usercontent.google.com/download?id=1ulct4RhRxQp1v5xwEmUH7xz7AK42Oqlw&export=download&confirm=t model_weights/3rd_party/clean_youknow_video.pth
38 | RUN chmod 444 model_weights/*\.pt* model_weights/3rd_party/*\.pt*
39 | COPY --from=build --chown=lada:lada /home/lada/.local /home/lada/.local
40 | USER lada
41 | ENV PATH=/home/lada/.local/bin:$PATH
42 | ENV LOCALE_DIR=/home/lada/.local/lib/python3.13/site-packages/lada/locale
43 |
44 | ENTRYPOINT ["lada-cli"]
45 | CMD ["--help"]
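46 |
47 | # Illustrative build/run commands (image tag and mount path are assumptions, not part of this repo):
48 | #   docker build -t lada -f packaging/docker/Dockerfile .
49 | #   docker run --rm --gpus all -v "$PWD:/mnt" lada --input /mnt/video.mp4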
--------------------------------------------------------------------------------
/lada/gui/window.ui:
--------------------------------------------------------------------------------
[window.ui markup lost in extraction; only text fragments remain. Recoverable structure: an application window titled "Lada" (default size 900×550) holding a crossfade stack with "file-selection" and "main" pages; the main page exposes view pages "preview" ("Watch", icon playback-symbolic) and "export" ("Export", icon arrow-pointing-away-from-line-right-symbolic).]
--------------------------------------------------------------------------------
/packaging/README.md:
--------------------------------------------------------------------------------
1 | ## Update dependencies
2 |
3 | After updating release dependencies by adjusting `uv.lock`, we need to regenerate the pinned dependencies for each release distribution as well.
4 |
5 | ```shell
6 | # No need for gui extra as pycairo and pygobject dependencies are available in flatpak gnome runtime
7 | uv export --no-default-groups --no-emit-local --format pylock.toml --extra cu128 --frozen | uv run packaging/flatpak/convert-pylock-to-flatpak.py
8 | # No need for gui extra as the docker image will only offer Lada CLI
9 | uv export --no-default-groups --no-emit-local --format requirements.txt --extra cu128 --group docker --no-emit-package opencv-python --frozen > packaging/docker/requirements.txt
10 | # No need for gui extra as pycairo and pygobject dependencies will be built locally via gvsbuild
11 | uv export --no-default-groups --no-emit-local --format requirements.txt --extra cu128 --frozen > packaging/windows/requirements.txt
12 | ```
13 |
14 | ## Release a new version
15 |
16 | #### GUI smoke tests
17 | * Open app
18 | * Drop a test file (no longer than a few seconds is fine)
19 | * Open sidebar and click *Reset to factory settings* button
20 | * Open Watch tab
21 | * If video and audio are playing, continue
22 | * Open Export tab
23 | * Click *Restore* button
24 | * If restoration finishes and you can play the file by clicking the *Open in External Program* button, we're good.
25 |
26 | #### CLI smoke tests
27 | * Run `lada-cli --input path/to/short/test/file.mp4`
28 | * If restoration finishes and you can play the restored file in some media player, continue
29 | * Run `lada-cli --input path/to/short/test/file.mp4 --codec hevc_nvenc`
30 | * If restoration finishes and you can play the restored file in some media player, we're good.
31 |
32 |
33 | ### Release Process
34 |
35 | > [!TIP]
36 | > Read the README.md within each packaging method's subfolder for specific steps on how to build and package that variant
37 |
38 | - [ ] Make sure there is no pending translations PR and `release_ready_translations.txt` is up-to-date ([documentation](../translations/README.md)). Also check `Operations | Repository Maintenance` for pending changes.
39 | - [ ] Bump version in `lada/__init__.py` (no push to origin yet)
40 | - [ ] Write Flatpak release notes in `packaging/flatpak/share/io.github.ladaapp.lada.metainfo.xml` (no push to origin yet)
41 | - [ ] Create Draft Release on GitHub and write release notes
42 | - [ ] Create Draft Release on Codeberg and write release notes
43 | - [ ] Build Docker image on Linux build machine
44 | - [ ] Build Flatpak on Linux build machine
45 | - [ ] Build Windows .exe on Windows build machine
46 | - [ ] Do smoke tests. If something looks off stop and revert changes
47 | - [ ] `git tag v ; git push origin ; git push origin tag v`
48 | - [ ] Open Draft Release on GitHub and link it to git tag
49 | - [ ] Open Draft Release on Codeberg and link it to git tag
50 | - [ ] Create a Pull Request for flathub/io.github.ladaapp.lada and adjust *commit* and *tag* accordingly
51 | - [ ] Upload Windows .7z files to GitHub Draft Release and to pixeldrain.com
52 | - [ ] Add links/description for both Windows download options in Codeberg and GitHub Draft releases
53 | - [ ] Publish Codeberg and GitHub Releases (make them non-draft)
54 | - [ ] Merge Flathub Pull Request
55 | - [ ] Push Docker image to Dockerhub (including v and latest tags)
56 | - [ ] Bump version in `lada/__init__.py` by adding `-dev` suffix to new version (and push to origin)
57 |
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/mosaic_detector.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from typing import Optional
5 |
6 | from lada.utils.ultralytics_utils import convert_yolo_boxes
7 | from lada.utils.box_utils import box_overlap
8 | from lada.utils import Image, Box, Detections, ultralytics_utils, Detection, DETECTION_CLASSES
9 | from lada.models.yolo.yolo import Yolo
10 |
11 | class MosaicDetector:
12 | def __init__(self, model: Yolo, device, min_confidence=0.8):
13 | self.model = model
14 | self.device = device
15 | self.batch_size = 4
16 | self.min_confidence = min_confidence
17 | self.min_positive_detections = 4
18 | self.sampling_rate = 0.3
19 |
20 | def detect_batch(self, images:list[Image], boxes:Optional[list[Box]]=None) -> bool:
21 | num_samples = min(len(images), max(1, int(len(images) * self.sampling_rate)))
22 | indices_step_size = len(images) // num_samples
23 | indices = list(range(0, num_samples*indices_step_size, indices_step_size))
24 | samples = [images[i] for i in indices]
25 | samples_boxes = [boxes[i] for i in indices] if boxes else None
26 |
27 | batches = [samples[i:i + self.batch_size] for i in range(0, len(samples), self.batch_size)]
28 | positive_detections = 0
29 | for batch_idx, batch in enumerate(batches):
30 | batch_prediction_results = self.model.predict(source=batch, stream=False, verbose=False, device=self.device, conf=self.min_confidence, imgsz=640)
31 | for result_idx, results in enumerate(batch_prediction_results):
32 | sample_idx = batch_idx * self.batch_size + result_idx
33 |                 num_detections = results.boxes.conf.size(dim=0)
34 |                 if num_detections == 0:
35 |                     continue
36 |                 detection_confidences = results.boxes.conf.tolist()
37 |                 detection_boxes = convert_yolo_boxes(results.boxes, results.orig_shape)
38 |                 single_image_mosaic_detected = any(conf > self.min_confidence and (not samples_boxes or box_overlap(detection_boxes[i], samples_boxes[sample_idx])) for i, conf in enumerate(detection_confidences))
39 |                 if single_image_mosaic_detected:
40 |                     positive_detections += 1
41 | if positive_detections >= self.min_positive_detections:
42 | return True
43 | return False
44 |
45 | def detect(self, source: str | Image) -> Detections | None:
46 | for yolo_results in self.model.predict(source=source, stream=False, verbose=False, device=self.device, conf=self.min_confidence, imgsz=640):
47 | detections = []
48 | if not yolo_results.boxes:
49 | return None
50 | for yolo_box, yolo_mask in zip(yolo_results.boxes, yolo_results.masks):
51 | mask = ultralytics_utils.convert_yolo_mask(yolo_mask, yolo_results.orig_img.shape)
52 | box = ultralytics_utils.convert_yolo_box(yolo_box, yolo_results.orig_img.shape)
53 | t, l, b, r = box
54 | width, height = r - l + 1, b - t + 1
55 | if min(width, height) < 20:
56 | # skip tiny detections
57 | continue
58 |
59 | detections.append(Detection(DETECTION_CLASSES["mosaic"]["cls"], box, mask))
60 | return Detections(yolo_results.orig_img, detections)
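61 |
62 | # Worked example of the sampling in detect_batch (illustrative numbers): with
63 | # len(images)=30 and sampling_rate=0.3, num_samples=9 and indices_step_size=3,
64 | # so frames [0, 3, ..., 24] are checked; the clip counts as mosaic once 4 of
65 | # them (min_positive_detections) show a confident, box-overlapping detection.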
--------------------------------------------------------------------------------
/lada/utils/scene_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | import math
5 |
6 | from lada.utils import Box, Mask, Image, ImageTensor, MaskTensor
7 |
8 | def crop_to_box_v3(box: Box, img: Image | ImageTensor, mask_img: Mask | MaskTensor, target_size: tuple[int, int], max_box_expansion_factor=1.0, border_size=0):
9 | """
10 |     Crops Image and Mask using the given Box. Will try to grow the Box to better fit the target size.
11 | Parameters
12 | ----------
13 | box
14 | img
15 | mask_img
16 | target_size
17 |     max_box_expansion_factor: Limits how much the Box may grow before cropping. Useful for tiny Boxes (compared to the given target size)
18 |     border_size: Adds an area outside the Box to the crop. Useful to include additional context around the detected Box
19 |
20 | Returns
21 | -------
22 | img, mask_img, cropped_box, scale_factor
23 | """
24 | target_width, target_height = target_size
25 | t, l, b, r = box
26 | width, height = r - l + 1, b - t + 1
27 | border_size = max(20, int(max(width, height) * border_size)) if border_size > 0. else 0
28 | t, l, b, r = max(0, t-border_size), max(0, l-border_size), min(img.shape[0]-1, b+border_size), min(img.shape[1]-1, r+border_size)
29 | width, height = r - l + 1, b - t + 1
30 | down_scale_factor = min(target_width / width, target_height / height)
31 | if down_scale_factor > 1.0:
32 | # we ignore upscaling for now as we first want to try expanding the box.
33 | down_scale_factor = 1.0
34 | missing_width, missing_height = int((target_width - (width * down_scale_factor)) / down_scale_factor), int((target_height - (height * down_scale_factor)) / down_scale_factor)
35 |
36 | available_width_l = l
37 | available_width_r = (img.shape[1]-1) - r
38 | available_height_t = t
39 | available_height_b = (img.shape[0]-1) - b
40 |
41 | budget_width = int(max_box_expansion_factor * width)
42 | budget_height = int(max_box_expansion_factor * height)
43 |
44 | expand_width_lr = min(available_width_l, available_width_r, missing_width//2, budget_width)
45 | expand_width_l = min(available_width_l - expand_width_lr, missing_width - expand_width_lr * 2, budget_width - expand_width_lr)
46 | expand_width_r = min(available_width_r - expand_width_lr, missing_width - expand_width_lr * 2 - expand_width_l, budget_width - expand_width_lr - expand_width_l)
47 |
48 | expand_height_tb = min(available_height_t, available_height_b, missing_height//2, budget_height)
49 | expand_height_t = min(available_height_t - expand_height_tb, missing_height - expand_height_tb * 2, budget_height - expand_height_tb)
50 | expand_height_b = min(available_height_b - expand_height_tb, missing_height - expand_height_tb * 2 - expand_height_t, budget_height - expand_height_tb - expand_height_t)
51 |
52 | l, r = (l - math.floor(expand_width_lr/2) - expand_width_l,
53 | r + math.ceil(expand_width_lr/2) + expand_width_r)
54 | t, b = (t - math.floor(expand_height_tb/2) - expand_height_t,
55 | b + math.ceil(expand_height_tb/2) + expand_height_b)
56 | img = img[t:b + 1, l:r + 1]
57 | mask_img = mask_img[t:b + 1, l:r + 1]
58 |
59 | width, height = r - l + 1, b - t + 1
60 | if down_scale_factor <= 1.0:
61 | scale_factor = down_scale_factor
62 | else:
63 | scale_factor = min(target_width / width, target_height / height)
64 |
65 | cropped_box = t, l, b, r
66 | assert img.shape[:2] == mask_img.shape[:2] == (cropped_box[2]-cropped_box[0]+1, cropped_box[3]-cropped_box[1]+1)
67 | return img, mask_img, cropped_box, scale_factor
68 |
69 |
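70 | # Worked example (illustrative numbers): a 100x100 Box in the top-left corner of a
71 | # 1080x1920 frame with target_size=(256, 256) and max_box_expansion_factor=1.0
72 | # cannot expand left/up (no room) and is capped by the expansion budget on the
73 | # right/bottom, yielding a 200x200 crop with cropped_box=(0, 0, 199, 199) and
74 | # scale_factor=1.0.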
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/unet_disc.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import torch.nn as nn
6 | from mmengine.model import BaseModule
7 | from torch.nn.utils import spectral_norm
8 |
9 | from .registry import MODELS
10 |
11 |
12 | @MODELS.register_module()
13 | class UNetDiscriminatorWithSpectralNorm(BaseModule):
14 | """A U-Net discriminator with spectral normalization.
15 |
16 | Args:
17 | in_channels (int): Channel number of the input.
18 | mid_channels (int, optional): Channel number of the intermediate
19 | features. Default: 64.
20 | skip_connection (bool, optional): Whether to use skip connection.
21 | Default: True.
22 | """
23 |
24 | def __init__(self, in_channels, mid_channels=64, skip_connection=True):
25 |
26 | super().__init__()
27 |
28 | self.skip_connection = skip_connection
29 |
30 | self.conv_0 = nn.Conv2d(
31 | in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
32 |
33 | # downsample
34 | self.conv_1 = spectral_norm(
35 | nn.Conv2d(mid_channels, mid_channels * 2, 4, 2, 1, bias=False))
36 | self.conv_2 = spectral_norm(
37 | nn.Conv2d(mid_channels * 2, mid_channels * 4, 4, 2, 1, bias=False))
38 | self.conv_3 = spectral_norm(
39 | nn.Conv2d(mid_channels * 4, mid_channels * 8, 4, 2, 1, bias=False))
40 |
41 | # upsample
42 | self.conv_4 = spectral_norm(
43 | nn.Conv2d(mid_channels * 8, mid_channels * 4, 3, 1, 1, bias=False))
44 | self.conv_5 = spectral_norm(
45 | nn.Conv2d(mid_channels * 4, mid_channels * 2, 3, 1, 1, bias=False))
46 | self.conv_6 = spectral_norm(
47 | nn.Conv2d(mid_channels * 2, mid_channels, 3, 1, 1, bias=False))
48 |
49 | # final layers
50 | self.conv_7 = spectral_norm(
51 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
52 | self.conv_8 = spectral_norm(
53 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
54 | self.conv_9 = nn.Conv2d(mid_channels, 1, 3, 1, 1)
55 |
56 | self.upsample = nn.Upsample(
57 | scale_factor=2, mode='bilinear', align_corners=False)
58 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
59 |
60 | def forward(self, img):
61 | """Forward function.
62 |
63 | Args:
64 | img (Tensor): Input tensor with shape (n, c, h, w).
65 |
66 | Returns:
67 | Tensor: Forward results.
68 | """
69 |
70 | feat_0 = self.lrelu(self.conv_0(img))
71 |
72 | # downsample
73 | feat_1 = self.lrelu(self.conv_1(feat_0))
74 | feat_2 = self.lrelu(self.conv_2(feat_1))
75 | feat_3 = self.lrelu(self.conv_3(feat_2))
76 |
77 | # upsample
78 | feat_3 = self.upsample(feat_3)
79 | feat_4 = self.lrelu(self.conv_4(feat_3))
80 | if self.skip_connection:
81 | feat_4 = feat_4 + feat_2
82 |
83 | feat_4 = self.upsample(feat_4)
84 | feat_5 = self.lrelu(self.conv_5(feat_4))
85 | if self.skip_connection:
86 | feat_5 = feat_5 + feat_1
87 |
88 | feat_5 = self.upsample(feat_5)
89 | feat_6 = self.lrelu(self.conv_6(feat_5))
90 | if self.skip_connection:
91 | feat_6 = feat_6 + feat_0
92 |
93 | # final layers
94 | out = self.lrelu(self.conv_7(feat_6))
95 | out = self.lrelu(self.conv_8(out))
96 |
97 | return self.conv_9(out)
98 |
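99 | # Illustrative shape check (assumes `import torch`): with skip connections enabled the
100 | # input height/width must be divisible by 8; a (1, 3, 256, 256) input yields a
101 | # (1, 1, 256, 256) realness map:
102 | #   disc = UNetDiscriminatorWithSpectralNorm(in_channels=3)
103 | #   out = disc(torch.rand(1, 3, 256, 256))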
--------------------------------------------------------------------------------
/scripts/training/train-mosaic-restoration-basicvsrpp.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import argparse
6 | import os
7 | import os.path as osp
8 | from mmengine.config import Config, DictAction
9 | from mmengine.runner import Runner
10 | from lada.models.basicvsrpp import register_all_modules
11 |
12 | register_all_modules()
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description='Train a model')
16 | parser.add_argument('config', help='train config file path')
17 | parser.add_argument('--work-dir', help='the dir to save logs and models')
18 | parser.add_argument(
19 | '--resume', action='store_true', help='Whether to resume checkpoint.')
20 | parser.add_argument(
21 | '--amp',
22 | action='store_true',
23 | default=False,
24 | help='enable automatic-mixed-precision training')
25 | parser.add_argument(
26 | '--auto-scale-lr',
27 | action='store_true',
28 | help='enable automatically scaling LR.')
29 | parser.add_argument(
30 | '--cfg-options',
31 | nargs='+',
32 | action=DictAction,
33 | help='override some settings in the used config, the key-value pair '
34 | 'in xxx=yyy format will be merged into config file. If the value to '
35 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
36 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
37 | 'Note that the quotation marks are necessary and that no white space '
38 | 'is allowed.')
39 | parser.add_argument(
40 | '--launcher',
41 | choices=['none', 'pytorch', 'slurm', 'mpi'],
42 | default='none',
43 | help='job launcher')
44 | parser.add_argument('--load-from', help='specify checkpoint path to start from')
45 | # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
46 | # will pass the `--local-rank` parameter to `tools/train.py` instead
47 | # of `--local_rank`.
48 | parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
49 | args = parser.parse_args()
50 | if 'LOCAL_RANK' not in os.environ:
51 | os.environ['LOCAL_RANK'] = str(args.local_rank)
52 |
53 | return args
54 |
55 |
56 | def main():
57 | args = parse_args()
58 |
59 | # load config
60 | cfg = Config.fromfile(args.config)
61 | cfg.launcher = args.launcher
62 | if args.cfg_options is not None:
63 | cfg.merge_from_dict(args.cfg_options)
64 |
65 |     # work_dir priority: CLI arg > work_dir set in the config file > config filename
66 |     if args.work_dir:  # covers both None and empty string
67 |         # update config according to CLI args if args.work_dir was given
68 | cfg.work_dir = args.work_dir
69 | elif cfg.get('work_dir', None) is None:
70 | # use config filename as default work_dir if cfg.work_dir is None
71 | cfg.work_dir = osp.join('./work_dirs',
72 | osp.splitext(osp.basename(args.config))[0])
73 |
74 | if args.resume:
75 | cfg.resume = True
76 | if args.load_from:
77 | cfg.load_from = args.load_from
78 |
79 | # build the runner from config
80 | runner = Runner.from_cfg(cfg)
81 |
82 | print(f'Working directory: {cfg.work_dir}')
83 | print(f'Log directory: {runner._log_dir}')
84 |
85 | # start training
86 | runner.train()
87 |
88 | print(f'Log saved under {runner._log_dir}')
89 | print(f'Checkpoint saved under {cfg.work_dir}')
90 |
91 |
92 | if __name__ == '__main__':
93 | main()
94 |
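95 | # Illustrative --cfg-options override (dotted keys taken from the training config
96 | # above; config path assumed):
97 | #   python scripts/training/train-mosaic-restoration-basicvsrpp.py path/to/config.py \
98 | #       --cfg-options train_dataloader.batch_size=4 optim_wrapper.optimizer.lr=5e-5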
--------------------------------------------------------------------------------
/lada/datasetcreation/detectors/watermark_detector.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Lada Authors
2 | # SPDX-License-Identifier: AGPL-3.0
3 |
4 | from typing import Optional
5 |
6 | from lada.utils.ultralytics_utils import convert_yolo_boxes
7 | from lada.utils.box_utils import box_overlap
8 | from lada.utils import Image, Box, Detections, ultralytics_utils, Detection, DETECTION_CLASSES, mask_utils
9 | from lada.models.yolo.yolo import Yolo
10 |
11 | class WatermarkDetector:
12 | def __init__(self, model: Yolo, device, min_confidence=0.4):
13 | self.model = model
14 | self.device = device
15 | self.batch_size = 4
16 | self.min_confidence = min_confidence
17 | self.min_positive_detections = 4
18 | self.sampling_rate = 0.3
19 |
20 | def detect_batch(self, images:list[Image], boxes:Optional[list[Box]]=None) -> bool:
21 | num_samples = min(len(images), max(1, int(len(images) * self.sampling_rate)))
22 | indices_step_size = len(images) // num_samples
23 | indices = list(range(0, num_samples*indices_step_size, indices_step_size))
24 | samples = [images[i] for i in indices]
25 | samples_boxes = [boxes[i] for i in indices] if boxes else None
26 |
27 | batches = [samples[i:i + self.batch_size] for i in range(0, len(samples), self.batch_size)]
28 | positive_detections = 0
29 | for batch_idx, batch in enumerate(batches):
30 |             # Not exactly sure why, but prediction accuracy is horrible unless imgsz is set to 640, even though the model was trained with 512 in train-yolo-watermark-detector.py.
31 | batch_prediction_results = self.model.predict(source=batch, stream=False, verbose=False, device=self.device, conf=self.min_confidence, imgsz=640)
32 | for result_idx, results in enumerate(batch_prediction_results):
33 | sample_idx = batch_idx * self.batch_size + result_idx
34 |                 num_detections = results.boxes.conf.size(dim=0)
35 |                 if num_detections == 0:
36 | continue
37 | detection_confidences = results.boxes.conf.tolist()
38 | detection_boxes = convert_yolo_boxes(results.boxes, results.orig_shape)
39 | single_image_watermark_detected = any(conf > self.min_confidence and (not samples_boxes or box_overlap(detection_boxes[i], samples_boxes[sample_idx])) for i, conf in enumerate(detection_confidences))
40 | if single_image_watermark_detected:
41 | positive_detections += 1
42 | watermark_detected = positive_detections >= self.min_positive_detections
43 | #print(f"watermark detector: watermark {watermark_detected}, detected {positive_detections}/{len(samples)}")
44 | return watermark_detected
45 |
46 | def detect(self, source: str | Image) -> Detections | None:
47 | for yolo_results in self.model.predict(source=source, stream=False, verbose=False, device=self.device, conf=self.min_confidence, imgsz=640):
48 | detections = []
49 | if not yolo_results.boxes:
50 | return None
51 | for yolo_box in yolo_results.boxes:
52 | box = ultralytics_utils.convert_yolo_box(yolo_box, yolo_results.orig_img.shape)
53 | mask = mask_utils.box_to_mask(box, yolo_results.orig_img.shape, DETECTION_CLASSES["watermark"]["mask_value"])
54 | t, l, b, r = box
55 | width, height = r - l + 1, b - t + 1
56 | if min(width, height) < 40:
57 | # skip tiny detections
58 | continue
59 |
60 | detections.append(Detection(DETECTION_CLASSES["watermark"]["cls"], box, mask))
61 | return Detections(yolo_results.orig_img, detections)
--------------------------------------------------------------------------------
/lada/gui/export/shutdown_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import sys
5 |
6 | from gi.repository import Gio, GLib
7 |
8 | class ShutdownError(Exception):
9 | pass
10 |
11 | class ShutdownManager:
12 | def __init__(self):
13 | self.bus = Gio.bus_get_sync(Gio.BusType.SESSION, None)
14 |
15 | def _call_dbus_method(self, service_name, object_path, interface_name, method_name, parameters=None):
16 | dbus_proxy = Gio.DBusProxy.new_sync(
17 | self.bus,
18 | Gio.DBusProxyFlags.NONE,
19 | None,
20 | service_name,
21 | object_path,
22 | interface_name,
23 | None
24 | )
25 |
26 | response = dbus_proxy.call_sync(
27 | method_name,
28 | parameters,
29 | Gio.DBusCallFlags.NONE,
30 | 1_000, # 1 sec timeout
31 | None # Not cancellable
32 | )
33 | return response
34 |
35 | def shutdown_windows(self):
36 | try:
37 | subprocess.run(["shutdown", "/s", "/t", "0"], check=True)
38 | except subprocess.CalledProcessError as e:
39 | raise ShutdownError(e)
40 |
41 | def shutdown_linux_generic(self):
42 | try:
43 | subprocess.run(["shutdown", "now"], check=True)
44 | except subprocess.CalledProcessError as e:
45 | raise ShutdownError(e)
46 |
47 | def shutdown_linux_kde(self):
48 | try:
49 | self._call_dbus_method(
50 | "org.kde.Shutdown",
51 | "/Shutdown",
52 | "org.kde.Shutdown",
53 | "logoutAndShutdown"
54 | )
55 | except GLib.GError as e:
56 | raise ShutdownError(e)
57 |
58 | def shutdown_linux_gnome(self):
59 | try:
60 | self._call_dbus_method(
61 | "org.gnome.SessionManager",
62 | "/org/gnome/SessionManager",
63 | "org.gnome.SessionManager",
64 | "Shutdown"
65 | )
66 | except GLib.GError as e:
67 | if e.code == 24: # user clicked Cancel
68 | return
69 | print("domain", e.domain, "code", e.code, "message", e.message)
70 | raise ShutdownError(e)
71 |
72 | def is_service_registered(self, service_name: str) -> bool:
73 | response = self._call_dbus_method(
74 | "org.freedesktop.DBus",
75 | "/org/freedesktop/DBus",
76 | "org.freedesktop.DBus",
77 | "ListNames"
78 | )
79 |
80 | if response:
81 | names = response.unpack()[0] # The result is a tuple, and the first item is the list of names
82 | return service_name in names
83 | return False
84 |
85 | def shutdown(self):
86 | linux_desktop_env = os.getenv("XDG_CURRENT_DESKTOP", "").upper()
87 | linux_session_desktop_env = os.getenv("XDG_SESSION_DESKTOP", "").upper()
88 | if sys.platform == "win32":
89 | self.shutdown_windows()
90 | elif "KDE" in linux_desktop_env or "KDE" in linux_session_desktop_env:
91 | self.shutdown_linux_kde()
92 | elif ("GNOME" in linux_desktop_env or "GNOME" in linux_session_desktop_env) and self.is_service_registered("org.gnome.SessionManager"):
93 | self.shutdown_linux_gnome()
94 | elif shutil.which("shutdown") is not None:
95 | self.shutdown_linux_generic()
96 | else:
97 | raise ShutdownError("Couldn't find any means to shutdown the system")
98 |
99 | if __name__ == "__main__":
100 | shutdown_manager = ShutdownManager()
101 | shutdown_manager.shutdown()
102 |
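103 | # Illustrative guard (DBus names as used above): check registration before picking a
104 | # desktop-specific shutdown path:
105 | #   mgr = ShutdownManager()
106 | #   if mgr.is_service_registered("org.kde.Shutdown"):
107 | #       mgr.shutdown_linux_kde()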
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/iter_time_hook.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import time
6 | from typing import Optional, Sequence, Union
7 |
8 | from mmengine.hooks import IterTimerHook as BaseIterTimerHook
9 | from mmengine.structures import BaseDataElement
10 |
11 | from .registry import HOOKS
12 |
13 | DATA_BATCH = Optional[Sequence[dict]]
14 |
15 |
16 | @HOOKS.register_module()
17 | class IterTimerHook(BaseIterTimerHook):
18 |     """IterTimerHook inherits from :class:`mmengine.hooks.IterTimerHook` and
19 |     overwrites :meth:`self._after_iter`.
20 |
21 |     This hook should be used along with
22 | :class:`mmagic.engine.runner.MultiValLoop` and
23 | :class:`mmagic.engine.runner.MultiTestLoop`.
24 | """
25 |
26 | def _after_iter(self,
27 | runner,
28 | batch_idx: int,
29 | data_batch: DATA_BATCH = None,
30 | outputs: Optional[Union[dict,
31 | Sequence[BaseDataElement]]] = None,
32 | mode: str = 'train') -> None:
33 | """Calculating time for an iteration and updating "time"
34 | ``HistoryBuffer`` of ``runner.message_hub``. If `mode` is 'train', we
35 | take `runner.max_iters` as the total iterations and calculate the rest
36 |         time. If `mode` is 'val' or 'test', we use
37 |         `runner.val_loop.total_length` or `runner.test_loop.total_length` as
38 |         the total number of iterations. If you want to know how `total_length` is
39 |         calculated, please refer to
40 | :meth:`mmagic.engine.runner.MultiValLoop.run` and
41 | :meth:`mmagic.engine.runner.MultiTestLoop.run`.
42 |
43 | Args:
44 | runner (Runner): The runner of the training validation and
45 | testing process.
46 | batch_idx (int): The index of the current batch in the loop.
47 | data_batch (Sequence[dict], optional): Data from dataloader.
48 | Defaults to None.
49 | outputs (dict or sequence, optional): Outputs from model. Defaults
50 | to None.
51 | mode (str): Current mode of runner. Defaults to 'train'.
52 | """
53 | # Update iteration time in `runner.message_hub`.
54 | message_hub = runner.message_hub
55 | message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
56 | self.t = time.time()
57 | window_size = runner.log_processor.window_size
58 | # Calculate eta every `window_size` iterations. Since test and val
59 | # loop will not update runner.iter, use `every_n_inner_iters`to check
60 | # the interval.
61 | if self.every_n_inner_iters(batch_idx, window_size):
62 | iter_time = message_hub.get_scalar(f'{mode}/time').mean(
63 | window_size)
64 | if mode == 'train':
65 | self.time_sec_tot += iter_time * window_size
66 | # Calculate average iterative time.
67 | time_sec_avg = self.time_sec_tot / (
68 | runner.iter - self.start_iter + 1)
69 | # Calculate eta.
70 | eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
71 | runner.message_hub.update_info('eta', eta_sec)
72 | else:
73 | if mode == 'val':
74 | total_length = runner.val_loop.total_length
75 | else:
76 | total_length = runner.test_loop.total_length
77 |
78 | eta_sec = iter_time * (total_length - batch_idx - 1)
79 | runner.message_hub.update_info('eta', eta_sec)
80 |
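81 | # Worked example (train mode, illustrative numbers): after 20_000 of 100_000
82 | # iterations at an average of 0.5 s/iter, eta_sec = 0.5 * (100_000 - 20_000 - 1),
83 | # roughly 40_000 s or about 11 hours.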
--------------------------------------------------------------------------------
/lada/models/basicvsrpp/mmagic/model_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: OpenMMLab. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 AND AGPL-3.0
3 | # Code vendored from: https://github.com/open-mmlab/mmagic
4 |
5 | import logging
6 | from typing import Any, Dict, List, Optional, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | from mmengine.model.weight_init import (constant_init, kaiming_init,
11 | normal_init, update_init_info,
12 | xavier_init)
13 | from mmengine.registry import Registry
14 | from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
15 |
16 |
17 | def default_init_weights(module, scale=1):
18 | """Initialize network weights.
19 |
20 | Args:
21 |         module (nn.Module): Module whose weights should be initialized.
22 | scale (float): Scale initialized weights, especially for residual
23 | blocks. Default: 1.
24 | """
25 | for m in module.modules():
26 | if isinstance(m, nn.Conv2d):
27 | kaiming_init(m, a=0, mode='fan_in', bias=0)
28 | m.weight.data *= scale
29 | elif isinstance(m, nn.Linear):
30 | kaiming_init(m, a=0, mode='fan_in', bias=0)
31 | m.weight.data *= scale
32 | elif isinstance(m, _BatchNorm):
33 | constant_init(m.weight, val=1, bias=0)
34 |
35 |
36 | def make_layer(block, num_blocks, **kwarg):
37 | """Make layers by stacking the same blocks.
38 |
39 | Args:
40 |         block (nn.Module): nn.Module class for the basic block.
41 | num_blocks (int): number of blocks.
42 |
43 | Returns:
44 | nn.Sequential: Stacked blocks in nn.Sequential.
45 | """
46 | layers = []
47 | for _ in range(num_blocks):
48 | layers.append(block(**kwarg))
49 | return nn.Sequential(*layers)
50 |
51 |
52 | def get_module_device(module):
53 | """Get the device of a module.
54 |
55 | Args:
56 | module (nn.Module): A module contains the parameters.
57 |
58 | Returns:
59 | torch.device: The device of the module.
60 | """
61 | try:
62 | next(module.parameters())
63 | except StopIteration:
64 | raise ValueError('The input module should contain parameters.')
65 |
66 | if next(module.parameters()).is_cuda:
67 | return next(module.parameters()).get_device()
68 | else:
69 | return torch.device('cpu')
70 |
71 |
72 | def set_requires_grad(nets, requires_grad=False):
73 | """Set requires_grad for all the networks.
74 |
75 | Args:
76 | nets (nn.Module | list[nn.Module]): A list of networks or a single
77 | network.
78 | requires_grad (bool): Whether the networks require gradients or not
79 | """
80 | if not isinstance(nets, list):
81 | nets = [nets]
82 | for net in nets:
83 | if net is not None:
84 | for param in net.parameters():
85 | param.requires_grad = requires_grad
86 |
87 |
88 | def build_module(module: Union[dict, nn.Module], builder: Registry, *args,
89 | **kwargs) -> Any:
90 | """Build module from config or return the module itself.
91 |
92 | Args:
93 | module (Union[dict, nn.Module]): The module to build.
94 | builder (Registry): The registry to build module.
95 | *args, **kwargs: Arguments passed to build function.
96 |
97 | Returns:
98 | Any: The built module.
99 | """
100 | if isinstance(module, dict):
101 | return builder.build(module, *args, **kwargs)
102 | elif isinstance(module, nn.Module):
103 | return module
104 | else:
105 | raise TypeError(
106 | f'Only support dict and nn.Module, but got {type(module)}.')
107 |
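108 | # Illustrative make_layer usage (parameters assumed): three identical 3x3
109 | # convolutions preserving 64 channels, stacked into an nn.Sequential:
110 | #   body = make_layer(nn.Conv2d, 3, in_channels=64, out_channels=64,
111 | #                     kernel_size=3, padding=1)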
--------------------------------------------------------------------------------