├── .github
│   ├── FUNDING.yml
│   └── screenshot.png
├── .gitignore
├── LICENSE
├── README.md
├── Start.bat
├── Start_Portable.bat
├── Update_Portable.bat
├── app
│   ├── helpers
│   │   ├── downloader.py
│   │   ├── integrity_checker.py
│   │   ├── miscellaneous.py
│   │   ├── recording.py
│   │   └── typing_helper.py
│   ├── onnxmodels
│   │   ├── .gitkeep
│   │   ├── dfm_models
│   │   │   └── .keep
│   │   └── place_model_files_here
│   ├── processors
│   │   ├── external
│   │   │   ├── cliplib
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   │   ├── clip.py
│   │   │   │   ├── model.py
│   │   │   │   └── simple_tokenizer.py
│   │   │   ├── clipseg.py
│   │   │   └── resnet.py
│   │   ├── face_detectors.py
│   │   ├── face_editors.py
│   │   ├── face_landmark_detectors.py
│   │   ├── face_masks.py
│   │   ├── face_restorers.py
│   │   ├── face_swappers.py
│   │   ├── frame_enhancers.py
│   │   ├── models_data.py
│   │   ├── models_processor.py
│   │   ├── utils
│   │   │   ├── dfm_model.py
│   │   │   ├── engine_builder.py
│   │   │   ├── faceutil.py
│   │   │   └── tensorrt_predictor.py
│   │   ├── video_processor.py
│   │   └── workers
│   │       └── frame_worker.py
│   └── ui
│       ├── core
│       │   ├── MainWindow.ui
│       │   ├── convert_ui_to_py.bat
│       │   ├── main_window.py
│       │   ├── media.qrc
│       │   ├── media
│       │   │   ├── OffState.png
│       │   │   ├── OnState.png
│       │   │   ├── add_marker_hover.png
│       │   │   ├── add_marker_off.png
│       │   │   ├── audio_off.png
│       │   │   ├── audio_on.png
│       │   │   ├── fullscreen.png
│       │   │   ├── image.png
│       │   │   ├── marker.png
│       │   │   ├── marker_save.png
│       │   │   ├── next_marker_hover.png
│       │   │   ├── next_marker_off.png
│       │   │   ├── open_file.png
│       │   │   ├── play_hover.png
│       │   │   ├── play_off.png
│       │   │   ├── play_on.png
│       │   │   ├── previous_marker_hover.png
│       │   │   ├── previous_marker_off.png
│       │   │   ├── rec_hover.png
│       │   │   ├── rec_off.png
│       │   │   ├── rec_on.png
│       │   │   ├── remove_marker_hover.png
│       │   │   ├── remove_marker_off.png
│       │   │   ├── repeat.png
│       │   │   ├── reset_default.png
│       │   │   ├── save.png
│       │   │   ├── save_file.png
│       │   │   ├── save_file_as.png
│       │   │   ├── splash.png
│       │   │   ├── splash_next.png
│       │   │   ├── stop_hover.png
│       │   │   ├── stop_off.png
│       │   │   ├── stop_on.png
│       │   │   ├── tl_beg_hover.png
│       │   │   ├── tl_beg_off.png
│       │   │   ├── tl_beg_on.png
│       │   │   ├── tl_left_hover.png
│       │   │   ├── tl_left_off.png
│       │   │   ├── tl_left_on.png
│       │   │   ├── tl_right_hover.png
│       │   │   ├── tl_right_off.png
│       │   │   ├── tl_right_on.png
│       │   │   ├── video.png
│       │   │   ├── visomaster_full.png
│       │   │   ├── visomaster_small.png
│       │   │   └── webcam.png
│       │   ├── media_rc.py
│       │   └── proxy_style.py
│       ├── main_ui.py
│       ├── styles
│       │   ├── dark_styles.qss
│       │   └── light_styles.qss
│       └── widgets
│           ├── actions
│           │   ├── card_actions.py
│           │   ├── common_actions.py
│           │   ├── control_actions.py
│           │   ├── filter_actions.py
│           │   ├── graphics_view_actions.py
│           │   ├── layout_actions.py
│           │   ├── list_view_actions.py
│           │   ├── save_load_actions.py
│           │   └── video_control_actions.py
│           ├── common_layout_data.py
│           ├── event_filters.py
│           ├── face_editor_layout_data.py
│           ├── settings_layout_data.py
│           ├── swapper_layout_data.py
│           ├── ui_workers.py
│           └── widget_components.py
├── dependencies
│   └── .gitkeep
├── download_models.py
├── main.py
├── model_assets
│   ├── dfm_models
│   │   └── .gitkeep
│   ├── grid_sample_3d_plugin.dll
│   ├── libgrid_sample_3d_plugin.so
│   ├── liveportrait_onnx
│   │   └── lip_array.pkl
│   └── meanshape_68.pkl
├── requirements_cu118.txt
├── requirements_cu124.txt
├── scripts
│   ├── setenv.bat
│   ├── update_cu118.bat
│   └── update_cu124.bat
└── tools
    └── convert_old_rope_embeddings.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: ['https://github.com/visomaster/VisoMaster?tab=readme-ov-file#support-the-project']
16 |
--------------------------------------------------------------------------------
/.github/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/.github/screenshot.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/**
2 |
3 | *.ckpt
4 | *.pth
5 | *.onnx
6 | *.engine
7 | *.profile
8 | *.timing
9 | *.dfm
10 | *.trt
11 | models/liveportrait_onnx/*.onnx
12 | models/liveportrait_onnx/*.trt
13 | saved_parameters*.json
14 | startup_parameters*.json
15 | data.json
16 | merged_embeddings*.txt
17 | .vs
18 | *.sln
19 | *.pyproj
20 | *.json
21 | .vscode/
22 | tensorrt-engines/
23 | source_videos/
24 | source_images/
25 | output/
26 | test_frames*/
27 | test_videos*/
28 | install.dat
29 | visomaster.ico
30 | dependencies/CUDA/
31 | dependencies/Python/
32 | dependencies/git-portable/
33 | dependencies/TensorRT/
34 | *.mp4
35 | *.jpg
36 | *.exe
37 | .thumbnails
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # VisoMaster
3 | ### VisoMaster is a powerful yet easy-to-use tool for face swapping and editing in images and videos. It utilizes AI to produce natural-looking results with minimal effort, making it ideal for both casual users and professionals.
4 |
5 | ---
6 |
7 |
8 | ## Features
9 |
10 | ### 🔄 **Face Swap**
11 | - Supports multiple face swapper models
12 | - Compatible with DeepFaceLab trained models (DFM)
13 | - Advanced multi-face swapping with masking options for each facial part
14 | - Occlusion masking support (DFL XSeg Masking)
15 | - Works with all popular face detectors & landmark detectors
16 | - Expression Restorer: Transfers original expressions to the swapped face
17 | - Face Restoration: Supports all popular upscaling & enhancement models
18 |
19 | ### 🎭 **Face Editor (LivePortrait Models)**
20 | - Manually adjust expressions and poses for different face parts
21 | - Fine-tune colors for Face, Hair, Eyebrows, and Lips using RGB adjustments
22 |
23 | ### 🚀 **Other Powerful Features**
24 | - **Live Playback**: See processed video in real-time before saving
25 | - **Face Embeddings**: Use multiple source faces for better accuracy & similarity
26 | - **Live Swapping via Webcam**: Stream to virtual camera for Twitch, YouTube, Zoom, etc.
27 | - **User-Friendly Interface**: Intuitive and easy to use
28 | - **Video Markers**: Adjust settings per frame for precise results
29 | - **TensorRT Support**: Leverages supported GPUs for ultra-fast processing
30 | - **Many More Advanced Features** 🎉
31 |
32 | ## Automatic Installation (Windows)
33 | - For Windows users with an Nvidia GPU, we provide an automatic installer for easy setup.
34 | - You can get the installer from the [releases](https://github.com/visomaster/VisoMaster/releases/tag/v0.1.1) page or from this [link](https://github.com/visomaster/VisoMaster/releases/download/v0.1.1/VisoMaster_Setup.exe).
35 | - Choose the correct CUDA version inside the installer based on your GPU's compatibility.
36 | - After successful installation, go to your installed directory and run the **Start_Portable.bat** file to launch **VisoMaster**.
37 |
38 | ## **Manual Installation Guide (Nvidia)**
39 |
40 | Follow the steps below to install and run **VisoMaster** on your system.
41 |
42 | ## **Prerequisites**
43 | Before proceeding, ensure you have the following installed on your system:
44 | - **Git** ([Download](https://git-scm.com/downloads))
45 | - **Miniconda** ([Download](https://www.anaconda.com/download))
46 |
47 | ---
48 |
49 | ## **Installation Steps**
50 |
51 | ### **1. Clone the Repository**
52 | Open a terminal or command prompt and run:
53 | ```sh
54 | git clone https://github.com/visomaster/VisoMaster.git
55 | ```
56 | ```sh
57 | cd VisoMaster
58 | ```
59 |
60 | ### **2. Create and Activate a Conda Environment**
61 | ```sh
62 | conda create -n visomaster python=3.10.13 -y
63 | ```
64 | ```sh
65 | conda activate visomaster
66 | ```
67 |
68 | ### **3. Install CUDA and cuDNN**
69 | ```sh
70 | conda install -c nvidia/label/cuda-12.4.1 cuda-runtime
71 | ```
72 | ```sh
73 | conda install -c conda-forge cudnn
74 | ```
75 |
76 | ### **4. Install Additional Dependencies**
77 | ```sh
78 | conda install scikit-image
79 | ```
80 | ```sh
81 | pip install -r requirements_cu124.txt
82 | ```
83 |
84 | ### **5. Download Models and Other Dependencies**
85 | 1. Download all the required models
86 | ```sh
87 | python download_models.py
88 | ```
89 | 2. Download all the files from this [page](https://github.com/visomaster/visomaster-assets/releases/tag/v0.1.0_dp) and copy them to the ***dependencies/*** folder.
90 |
91 | **Note**: You do not need to download the Source code (zip) and Source code (tar.gz) files
92 | ### **6. Run the Application**
93 | Once everything is set up, start the application by opening the **Start.bat** file.
94 | On Linux, just run `python main.py`.
95 | ---
96 |
97 | ## **Troubleshooting**
98 | - If you face CUDA-related issues, ensure your GPU drivers are up to date.
99 | - For missing models, double-check that all models are placed in the correct directories.
100 |
101 | ## [Join Discord](https://discord.gg/5rx4SQuDbp)
102 |
103 | ## Support The Project ##
104 | This project was made possible by the combined efforts of **[@argenspin](https://github.com/argenspin)** and **[@Alucard24](https://github.com/alucard24)** with the support of countless other members in our Discord community. If you wish to support us for the continued development of **VisoMaster**, you can donate to either of us (or both, if you're double awesome :smiley:).
105 |
106 | ### **argenspin** ###
107 | - [BuyMeACoffee](https://buymeacoffee.com/argenspin)
108 | - BTC: bc1qe8y7z0lkjsw6ssnlyzsncw0f4swjgh58j9vrqm84gw2nscgvvs5s4fts8g
109 | - ETH: 0x967a442FBd13617DE8d5fDC75234b2052122156B
110 | ### **Alucard24** ###
111 | - [BuyMeACoffee](https://buymeacoffee.com/alucard_24)
112 | - [PayPal](https://www.paypal.com/donate/?business=XJX2E5ZTMZUSQ&no_recurring=0&item_name=Support+us+with+a+donation!+Your+contribution+helps+us+continue+improving+and+providing+quality+content.+Thank+you!¤cy_code=EUR)
113 | - BTC: 15ny8vV3ChYsEuDta6VG3aKdT6Ra7duRAc
114 |
115 |
116 | ## Disclaimer: ##
117 | **VisoMaster** is a hobby project that we are making available to the community as a thank you to all of the contributors ahead of us.
118 | We've copied the disclaimer from [Swap-Mukham](https://github.com/harisreedhar/Swap-Mukham) here since it is well-written and applies 100% to this repo.
119 |
120 | We would like to emphasize that our swapping software is intended for responsible and ethical use only. We must stress that users are solely responsible for their actions when using our software.
121 |
122 | Intended Usage: This software is designed to assist users in creating realistic and entertaining content, such as movies, visual effects, virtual reality experiences, and other creative applications. We encourage users to explore these possibilities within the boundaries of legality, ethical considerations, and respect for others' privacy.
123 |
124 | Ethical Guidelines: Users are expected to adhere to a set of ethical guidelines when using our software. These guidelines include, but are not limited to:
125 |
126 | Not creating or sharing content that could harm, defame, or harass individuals. Obtaining proper consent and permissions from individuals featured in the content before using their likeness. Avoiding the use of this technology for deceptive purposes, including misinformation or malicious intent. Respecting and abiding by applicable laws, regulations, and copyright restrictions.
127 |
128 | Privacy and Consent: Users are responsible for ensuring that they have the necessary permissions and consents from individuals whose likeness they intend to use in their creations. We strongly discourage the creation of content without explicit consent, particularly if it involves non-consensual or private content. It is essential to respect the privacy and dignity of all individuals involved.
129 |
130 | Legal Considerations: Users must understand and comply with all relevant local, regional, and international laws pertaining to this technology. This includes laws related to privacy, defamation, intellectual property rights, and other relevant legislation. Users should consult legal professionals if they have any doubts regarding the legal implications of their creations.
131 |
132 | Liability and Responsibility: We, as the creators and providers of the deep fake software, cannot be held responsible for the actions or consequences resulting from the usage of our software. Users assume full liability and responsibility for any misuse, unintended effects, or abusive behavior associated with the content they create.
133 |
134 | By using this software, users acknowledge that they have read, understood, and agreed to abide by the above guidelines and disclaimers. We strongly encourage users to approach this technology with caution, integrity, and respect for the well-being and rights of others.
135 |
136 | Remember, technology should be used to empower and inspire, not to harm or deceive. Let's strive for ethical and responsible use of deep fake technology for the betterment of society.
137 |
--------------------------------------------------------------------------------
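
A quick way to confirm the CUDA-related troubleshooting steps above actually took effect is to check the installed PyTorch build from inside the activated `visomaster` environment. This is a minimal sanity-check sketch, not part of the project itself; exact version strings and device names will vary per machine:

```python
# Run inside the activated "visomaster" conda environment.
import torch

print(torch.__version__)           # a CUDA-enabled build usually reports a "+cu118"/"+cu124"-style suffix
print(torch.cuda.is_available())   # should be True when the GPU driver and CUDA runtime are visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```
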
/Start.bat:
--------------------------------------------------------------------------------
1 |
2 | call conda activate visomaster
3 | call app/ui/core/convert_ui_to_py.bat
4 | SET APP_ROOT=%~dp0
5 | SET APP_ROOT=%APP_ROOT:~0,-1%
6 | SET DEPENDENCIES=%APP_ROOT%\dependencies
7 | echo %DEPENDENCIES%
8 | SET PATH=%DEPENDENCIES%;%PATH%
9 | python main.py
10 | pause
--------------------------------------------------------------------------------
/Start_Portable.bat:
--------------------------------------------------------------------------------
1 | call scripts\setenv.bat
2 | "%PYTHON_EXECUTABLE%" main.py
3 | pause
--------------------------------------------------------------------------------
/Update_Portable.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | :: Check if install.dat exists
4 | if not exist install.dat (
5 | echo install.dat file not found!
6 | pause
7 | exit /b 1
8 | )
9 |
10 | :: Read the cuda_version from install.dat
11 | for /f "tokens=2 delims==" %%A in ('findstr "cuda_version" install.dat') do set CUDA_VERSION=%%A
12 |
13 | call scripts\update_%CUDA_VERSION%.bat
14 | pause
15 |
--------------------------------------------------------------------------------
/app/helpers/downloader.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from pathlib import Path
3 | import os
4 |
5 | from tqdm import tqdm
6 |
7 | from app.helpers.integrity_checker import check_file_integrity
8 |
9 | def download_file(model_name: str, file_path: str, correct_hash: str, url: str) -> bool:
10 | """
11 | Downloads a file and verifies its integrity.
12 |
13 | Parameters:
14 | - model_name (str): Name of the model being downloaded.
15 | - file_path (str): Path where the file will be saved.
16 | - correct_hash (str): Expected hash value of the file for integrity check.
17 | - url (str): URL to download the file from.
18 |
19 | Returns:
20 | - bool: True if the file is downloaded and verified successfully, False otherwise.
21 | """
22 |     # If the file already exists, skip the download when its integrity checks out; otherwise remove it and re-download
23 | if Path(file_path).is_file():
24 | if check_file_integrity(file_path, correct_hash):
25 | print(f"\nSkipping {model_name} as it is already downloaded!")
26 | return True
27 | else:
28 | print(f"\n{file_path} already exists, but its file integrity couldn't be verified. Re-downloading it!")
29 | os.remove(file_path)
30 |
31 | print(f"\nDownloading {model_name} from {url}")
32 |
33 | try:
34 | response = requests.get(url, stream=True, timeout=5)
35 | response.raise_for_status() # Raise an error for bad HTTP responses (e.g., 404, 500)
36 | except requests.exceptions.RequestException as e:
37 | print(f"Failed to download {model_name}: {e}")
38 | return False
39 |
40 | total_size = int(response.headers.get("content-length", 0)) # File size in bytes
41 | block_size = 1024 # Size of chunks to download
42 | max_attempts = 3
43 | attempt = 1
44 |
45 | def download_and_save():
46 | """Handles the file download and saves it to disk."""
47 | with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
48 | with open(file_path, "wb") as file:
49 | for data in response.iter_content(block_size):
50 | progress_bar.update(len(data))
51 | file.write(data)
52 |
53 | while attempt <= max_attempts:
54 | try:
55 | download_and_save()
56 |
57 | # Verify file integrity
58 | if check_file_integrity(file_path, correct_hash):
59 | print("File integrity verified successfully!")
60 | print(f"File saved at: {file_path}")
61 | return True
62 | else:
63 | print(f"Integrity check failed for {file_path}. Retrying download (Attempt {attempt}/{max_attempts})...")
64 | os.remove(file_path)
65 | attempt += 1
66 | except requests.exceptions.Timeout:
67 | print("Connection timed out! Retrying download...")
68 | attempt += 1
69 | except Exception as e:
70 | print(f"An error occurred during download: {e}")
71 | attempt += 1
72 |
73 | print(f"Failed to download {model_name} after {max_attempts} attempts.")
74 | return False
--------------------------------------------------------------------------------
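
A minimal usage sketch for `download_file`; the model name, file path, hash, and URL below are placeholders invented for illustration, not real project data (the real values are defined elsewhere in the project, e.g. for `download_models.py`):

```python
from app.helpers.downloader import download_file

# All four values below are illustrative placeholders.
ok = download_file(
    model_name="example_model",
    file_path="model_assets/example_model.onnx",
    correct_hash="0123456789abcdef...",            # expected SHA-256 hex digest of the file
    url="https://example.com/example_model.onnx",
)
if not ok:
    print("Download failed, or the integrity check never passed within 3 attempts.")
```
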
/app/helpers/integrity_checker.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 |
3 | BUF_SIZE = 131072 # read in 128kb chunks!
4 |
5 | def get_file_hash(file_path: str) -> str:
6 | hash_sha256 = hashlib.sha256()
7 |
8 | with open(file_path, 'rb') as f:
9 | while True:
10 | data = f.read(BUF_SIZE)
11 | if not data:
12 | break
13 | hash_sha256.update(data)
14 |
15 | # print("SHA256: {0}".format(hash_sha256.hexdigest()))
16 | return hash_sha256.hexdigest()
17 |
18 | def write_hash_to_file(hash: str, hash_file_path: str):
19 | with open(hash_file_path, 'w') as hash_file:
20 | hash_file.write(hash)
21 |
22 | def get_hash_from_hash_file(hash_file_path: str) -> str:
23 | with open(hash_file_path, 'r') as hash_file:
24 | hash_sha256 = hash_file.read().strip()
25 | return hash_sha256
26 |
27 | def check_file_integrity(file_path, correct_hash) -> bool:
28 | actual_hash = get_file_hash(file_path)
29 | return actual_hash==correct_hash
--------------------------------------------------------------------------------
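
A short round trip with the helpers above: hash a file, persist the digest, and verify it later. The file paths are illustrative; any local file works:

```python
from app.helpers.integrity_checker import (
    get_file_hash,
    write_hash_to_file,
    get_hash_from_hash_file,
    check_file_integrity,
)

file_path = "model_assets/meanshape_68.pkl"   # illustrative: any existing local file
hash_path = file_path + ".sha256"             # illustrative sidecar file for the digest

digest = get_file_hash(file_path)             # SHA-256 hex digest, read in 128 KB chunks
write_hash_to_file(digest, hash_path)

stored = get_hash_from_hash_file(hash_path)
print(check_file_integrity(file_path, stored))  # True unless the file has changed since hashing
```
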
/app/helpers/miscellaneous.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import cv2
4 | import time
5 | from collections import UserDict
6 | import hashlib
7 | import numpy as np
8 | from functools import wraps
9 | from datetime import datetime
10 | from pathlib import Path
11 | from torchvision.transforms import v2
12 | import threading
13 | lock = threading.Lock()
14 |
15 | image_extensions = ('.jpg', '.jpeg', '.jpe', '.png', '.webp', '.tif', '.tiff', '.jp2', '.exr', '.hdr', '.ras', '.pnm', '.ppm', '.pgm', '.pbm', '.pfm')
16 | video_extensions = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.m4v', '.3gp', '.gif')
17 |
18 | DFM_MODELS_PATH = './model_assets/dfm_models'
19 |
20 | DFM_MODELS_DATA = {}
21 |
22 | # Datatype used for storing parameter values
23 | # The main use case for subclassing UserDict is to fall back to a default value when a non-existing key is accessed
24 | # This helps when saving/importing a workspace or parameters from an external file after a future update adds new Parameter widgets
25 | class ParametersDict(UserDict):
26 | def __init__(self, parameters, default_parameters: dict):
27 | super().__init__(parameters)
28 | self._default_parameters = default_parameters
29 |
30 | def __getitem__(self, key):
31 | try:
32 | return self.data[key]
33 | except KeyError:
34 | self.__setitem__(key, self._default_parameters[key])
35 | return self._default_parameters[key]
36 |
37 | def get_scaling_transforms():
38 | t512 = v2.Resize((512, 512), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
39 | t384 = v2.Resize((384, 384), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
40 | t256 = v2.Resize((256, 256), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
41 | t128 = v2.Resize((128, 128), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
42 | return t512, t384, t256, t128
43 |
44 | t512, t384, t256, t128 = get_scaling_transforms()
45 |
46 | def absoluteFilePaths(directory: str, include_subfolders=False):
47 | if include_subfolders:
48 | for dirpath,_,filenames in os.walk(directory):
49 | for f in filenames:
50 | yield os.path.abspath(os.path.join(dirpath, f))
51 | else:
52 | for filename in os.listdir(directory):
53 | file_path = os.path.join(directory, filename)
54 | if os.path.isfile(file_path):
55 | yield file_path
56 |
57 | def truncate_text(text):
58 | if len(text) >= 35:
59 | return f'{text[:32]}...'
60 | return text
61 |
62 | def get_video_files(folder_name, include_subfolders=False):
63 | return [f for f in absoluteFilePaths(folder_name, include_subfolders) if f.lower().endswith(video_extensions)]
64 |
65 | def get_image_files(folder_name, include_subfolders=False):
66 | return [f for f in absoluteFilePaths(folder_name, include_subfolders) if f.lower().endswith(image_extensions)]
67 |
68 | def is_image_file(file_name: str):
69 | return file_name.lower().endswith(image_extensions)
70 |
71 | def is_video_file(file_name: str):
72 | return file_name.lower().endswith(video_extensions)
73 |
74 | def is_file_exists(file_path: str) -> bool:
75 | if not file_path:
76 | return False
77 | return Path(file_path).is_file()
78 |
79 | def get_file_type(file_name):
80 | if is_image_file(file_name):
81 | return 'image'
82 | if is_video_file(file_name):
83 | return 'video'
84 | return None
85 |
86 | def get_hash_from_filename(filename):
87 | """Generate a hash from just the filename (not the full path)."""
88 | # Use just the filename without path
89 | name = os.path.basename(filename)
90 | # Create hash from filename and size for uniqueness
91 | file_size = os.path.getsize(filename)
92 | hash_input = f"{name}_{file_size}"
93 | return hashlib.md5(hash_input.encode('utf-8')).hexdigest()
94 |
95 | def get_thumbnail_path(file_hash):
96 | """Get the full path to a cached thumbnail."""
97 | thumbnail_dir = os.path.join(os.getcwd(), '.thumbnails')
98 | # Check if PNG version exists first
99 | png_path = os.path.join(thumbnail_dir, f"{file_hash}.png")
100 | if os.path.exists(png_path):
101 | return png_path
102 | # Otherwise use JPEG path
103 | return os.path.join(thumbnail_dir, f"{file_hash}.jpg")
104 |
105 | def ensure_thumbnail_dir():
106 | """Create the .thumbnails directory if it doesn't exist."""
107 | thumbnail_dir = os.path.join(os.getcwd(), '.thumbnails')
108 | os.makedirs(thumbnail_dir, exist_ok=True)
109 | return thumbnail_dir
110 |
111 | def save_thumbnail(frame, thumbnail_path):
112 | """Save a frame as an optimized thumbnail."""
113 | # Handle different color formats
114 | if len(frame.shape) == 2: # Grayscale
115 | frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
116 | elif frame.shape[2] == 4: # RGBA
117 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
118 |
119 | height,width,_ = frame.shape
120 | width, height = get_scaled_resolution(media_width=width, media_height=height, max_height=140, max_width=140)
121 |     # Resize down to fit within the 140x140 thumbnail bounds with high quality
122 | frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LANCZOS4)
123 |
124 | # First try PNG for best quality
125 | try:
126 | cv2.imwrite(thumbnail_path[:-4] + '.png', frame)
127 | # If PNG file is too large (>30KB), fall back to high-quality JPEG
128 | if os.path.getsize(thumbnail_path[:-4] + '.png') > 30 * 1024:
129 | os.remove(thumbnail_path[:-4] + '.png')
130 | raise Exception("PNG too large")
131 | else:
132 | return
133 | except:
134 | # Define JPEG parameters for high quality
135 | params = [
136 | cv2.IMWRITE_JPEG_QUALITY, 98, # Maximum quality for JPEG
137 | cv2.IMWRITE_JPEG_OPTIMIZE, 1, # Enable optimization
138 | cv2.IMWRITE_JPEG_PROGRESSIVE, 1 # Enable progressive mode
139 | ]
140 | # Save as high quality JPEG
141 | cv2.imwrite(thumbnail_path, frame, params)
142 |
143 | def get_dfm_models_data():
144 | DFM_MODELS_DATA.clear()
145 | for dfm_file in os.listdir(DFM_MODELS_PATH):
146 | if dfm_file.endswith(('.dfm','.onnx')):
147 | DFM_MODELS_DATA[dfm_file] = f'{DFM_MODELS_PATH}/{dfm_file}'
148 | return DFM_MODELS_DATA
149 |
150 | def get_dfm_models_selection_values():
151 | return list(get_dfm_models_data().keys())
152 | def get_dfm_models_default_value():
153 | dfm_values = list(DFM_MODELS_DATA.keys())
154 | if dfm_values:
155 | return dfm_values[0]
156 | return ''
157 |
158 | def get_scaled_resolution(media_width=False, media_height=False, max_width=False, max_height=False, media_capture: cv2.VideoCapture = False,):
159 | if not max_width or not max_height:
160 | max_height = 1080
161 | max_width = 1920
162 |
163 | if (not media_width or not media_height) and media_capture:
164 | media_width = media_capture.get(cv2.CAP_PROP_FRAME_WIDTH)
165 | media_height = media_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)
166 |
167 | if media_width > max_width or media_height > max_height:
168 | width_scale = max_width/media_width
169 | height_scale = max_height/media_height
170 | scale = min(width_scale, height_scale)
171 | media_width,media_height = media_width* scale, media_height*scale
172 | return int(media_width), int(media_height)
173 |
174 | def benchmark(func):
175 | @wraps(func)
176 | def wrapper(*args, **kwargs):
177 | start_time = time.perf_counter() # Record the start time
178 | result = func(*args, **kwargs) # Call the original function
179 | end_time = time.perf_counter() # Record the end time
180 | elapsed_time = end_time - start_time # Calculate elapsed time
181 | print(f"Function '{func.__name__}' executed in {elapsed_time:.6f} seconds.")
182 | return result # Return the result of the original function
183 | return wrapper
184 |
185 | def read_frame(capture_obj: cv2.VideoCapture, preview_mode=False):
186 | with lock:
187 | ret, frame = capture_obj.read()
188 | if ret and preview_mode:
189 | pass
190 | # width, height = get_scaled_resolution(capture_obj)
191 | # frame = cv2.resize(fr2ame, dsize=(width, height), interpolation=cv2.INTER_LANCZOS4)
192 | return ret, frame
193 |
194 | def read_image_file(image_path):
195 | try:
196 | img_array = np.fromfile(image_path, np.uint8)
197 | img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) # Always load as BGR
198 | except Exception as e:
199 | print(f"Failed to load {image_path}: {e}")
200 | return None
201 |
202 | if img is None:
203 | print("Failed to decode:", image_path)
204 | return None
205 |
206 | return img # Return BGR format
207 |
208 | def get_output_file_path(original_media_path, output_folder, media_type='video'):
209 | date_and_time = datetime.now().strftime(r'%Y_%m_%d_%H_%M_%S')
210 | input_filename = os.path.basename(original_media_path)
211 | # Create a temp Path object to split and merge the original filename to get the new output filename
212 | temp_path = Path(input_filename)
213 | # output_filename = "{0}_{2}{1}".format(temp_path.stem, temp_path.suffix, date_and_time)
214 | if media_type=='video':
215 | output_filename = f'{temp_path.stem}_{date_and_time}.mp4'
216 | elif media_type=='image':
217 | output_filename = f'{temp_path.stem}_{date_and_time}.png'
218 | output_file_path = os.path.join(output_folder, output_filename)
219 | return output_file_path
220 |
221 | def is_ffmpeg_in_path():
222 | if not cmd_exist('ffmpeg'):
223 | print("FFMPEG Not found in your system!")
224 | return False
225 | return True
226 |
227 | def cmd_exist(cmd):
228 | try:
229 | return shutil.which(cmd) is not None
230 | except ImportError:
231 | return any(
232 | os.access(os.path.join(path, cmd), os.X_OK)
233 | for path in os.environ["PATH"].split(os.pathsep)
234 | )
235 |
236 | def get_dir_of_file(file_path):
237 | if file_path:
238 | return os.path.dirname(file_path)
239 | return os.path.curdir
240 |
--------------------------------------------------------------------------------
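
The fallback behaviour described in the `ParametersDict` comments can be seen in a small sketch; the parameter names and default values here are made up for illustration and are not taken from the project's real layout data:

```python
from app.helpers.miscellaneous import ParametersDict

# Illustrative parameter names and defaults.
defaults = {"SwapFacesEnableToggle": False, "BlendAmountSlider": 100}
params = ParametersDict({"SwapFacesEnableToggle": True}, defaults)

print(params["SwapFacesEnableToggle"])  # True  - key exists in the stored parameters
print(params["BlendAmountSlider"])      # 100   - missing key falls back to the default...
print(params.data)                      # ...and is written back into the underlying dict
```
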
/app/helpers/recording.py:
--------------------------------------------------------------------------------
1 |
2 | def write_frame_to_disk(frame):
3 | pass
--------------------------------------------------------------------------------
/app/helpers/typing_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Callable, NewType
2 | from app.helpers.miscellaneous import ParametersDict
3 |
4 | LayoutDictTypes = NewType('LayoutDictTypes', Dict[str, Dict[str, Dict[str, int|str|list|float|bool|Callable]]])
5 |
6 | ParametersTypes = NewType('ParametersTypes', ParametersDict)
7 | FacesParametersTypes = NewType('FacesParametersTypes', dict[int, ParametersTypes])
8 |
9 | ControlTypes = NewType('ControlTypes', Dict[str, bool|int|float|str])
10 |
11 | MarkerTypes = NewType('MarkerTypes', Dict[int, Dict[str, FacesParametersTypes|ControlTypes]])
--------------------------------------------------------------------------------
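
To make the nesting of these aliases concrete, here is a sketch of a value shaped like `MarkerTypes`: a frame number mapping to per-face parameters plus control state. The inner key names ("parameters", "control") and all values are assumptions made for illustration, not taken from the code above:

```python
from app.helpers.miscellaneous import ParametersDict
from app.helpers.typing_helper import MarkerTypes

defaults = {"BlendAmountSlider": 100}                       # illustrative defaults
face_parameters = ParametersDict({"BlendAmountSlider": 80}, defaults)

# One marker stored at frame 120; "parameters"/"control" are hypothetical key names.
markers: MarkerTypes = {
    120: {
        "parameters": {0: face_parameters},                 # FacesParametersTypes: face id -> ParametersDict
        "control": {"AutoSwapToggle": True},                # ControlTypes: control name -> value
    }
}
print(markers[120]["control"])
```
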
/app/onnxmodels/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/onnxmodels/.gitkeep
--------------------------------------------------------------------------------
/app/onnxmodels/dfm_models/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/onnxmodels/dfm_models/.keep
--------------------------------------------------------------------------------
/app/onnxmodels/place_model_files_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/processors/external/cliplib/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/app/processors/external/cliplib/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/processors/external/cliplib/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/app/processors/external/cliplib/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def _node_get(node: torch._C.Node, key: str):
149 | """Gets attributes of a node which is polymorphic over return type.
150 |
151 | From https://github.com/pytorch/pytorch/pull/82628
152 | """
153 | sel = node.kindOf(key)
154 | return getattr(node, sel)(key)
155 |
156 | def patch_device(module):
157 | try:
158 | graphs = [module.graph] if hasattr(module, "graph") else []
159 | except RuntimeError:
160 | graphs = []
161 |
162 | if hasattr(module, "forward1"):
163 | graphs.append(module.forward1.graph)
164 |
165 | for graph in graphs:
166 | for node in graph.findAllNodes("prim::Constant"):
167 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
168 | node.copyAttributes(device_node)
169 |
170 | model.apply(patch_device)
171 | patch_device(model.encode_image)
172 | patch_device(model.encode_text)
173 |
174 | # patch dtype to float32 on CPU
175 | if str(device) == "cpu":
176 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
177 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
178 | float_node = float_input.node()
179 |
180 | def patch_float(module):
181 | try:
182 | graphs = [module.graph] if hasattr(module, "graph") else []
183 | except RuntimeError:
184 | graphs = []
185 |
186 | if hasattr(module, "forward1"):
187 | graphs.append(module.forward1.graph)
188 |
189 | for graph in graphs:
190 | for node in graph.findAllNodes("aten::to"):
191 | inputs = list(node.inputs())
192 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
193 | if _node_get(inputs[i].node(), "value") == 5:
194 | inputs[i].node().copyAttributes(float_node)
195 |
196 | model.apply(patch_float)
197 | patch_float(model.encode_image)
198 | patch_float(model.encode_text)
199 |
200 | model.float()
201 |
202 | return model, _transform(model.input_resolution.item())
203 |
204 |
205 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
206 | """
207 | Returns the tokenized representation of given input string(s)
208 |
209 | Parameters
210 | ----------
211 | texts : Union[str, List[str]]
212 | An input string or a list of input strings to tokenize
213 |
214 | context_length : int
215 | The context length to use; all CLIP models use 77 as the context length
216 |
217 | truncate: bool
218 | Whether to truncate the text in case its encoding is longer than the context length
219 |
220 | Returns
221 | -------
222 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
223 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
224 | """
225 | if isinstance(texts, str):
226 | texts = [texts]
227 |
228 | sot_token = _tokenizer.encoder["<|startoftext|>"]
229 | eot_token = _tokenizer.encoder["<|endoftext|>"]
230 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
231 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
233 | else:
234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
235 |
236 | for i, tokens in enumerate(all_tokens):
237 | if len(tokens) > context_length:
238 | if truncate:
239 | tokens = tokens[:context_length]
240 | tokens[-1] = eot_token
241 | else:
242 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
243 | result[i, :len(tokens)] = torch.tensor(tokens)
244 |
245 | return result
246 |
--------------------------------------------------------------------------------
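
A hedged sketch of how this vendored CLIP module can be driven, mirroring the upstream OpenAI CLIP API (`load`, `tokenize`, `encode_image`, `encode_text`); the image path and text prompts are illustrative, and this is not necessarily how VisoMaster itself invokes it:

```python
import torch
from PIL import Image

from app.processors.external import cliplib as clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)   # weights cached under ~/.cache/clip on first use

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # illustrative image path
text = clip.tokenize(["a photo of a face", "a photo of a landscape"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Similarity logits from the upstream CLIP forward pass
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1)

print(probs)
```
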
/app/processors/external/cliplib/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 |         pairs = get_pairs(word)
85 | 
86 |         if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
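
A small standalone round trip with `SimpleTokenizer` (the input string is arbitrary; note that decoding lower-cases and re-spaces the text rather than restoring it exactly):

```python
from app.processors.external.cliplib.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                     # loads bpe_simple_vocab_16e6.txt.gz shipped in this package
tokens = tokenizer.encode("A photo of a face.")   # list of BPE token ids
print(tokens)
print(tokenizer.decode(tokens))                   # roughly "a photo of a face ." (cleaned and lower-cased)
```
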
/app/processors/external/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/app/processors/face_restorers.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import TYPE_CHECKING
3 |
4 | import torch
5 | import numpy as np
6 | from torchvision.transforms import v2
7 | from skimage import transform as trans
8 |
9 | if TYPE_CHECKING:
10 | from app.processors.models_processor import ModelsProcessor
11 |
12 | class FaceRestorers:
13 | def __init__(self, models_processor: 'ModelsProcessor'):
14 | self.models_processor = models_processor
15 |
16 | def apply_facerestorer(self, swapped_face_upscaled, restorer_det_type, restorer_type, restorer_blend, fidelity_weight, detect_score):
17 | temp = swapped_face_upscaled
18 | t512 = v2.Resize((512, 512), antialias=False)
19 | t256 = v2.Resize((256, 256), antialias=False)
20 | t1024 = v2.Resize((1024, 1024), antialias=False)
21 | t2048 = v2.Resize((2048, 2048), antialias=False)
22 |
23 | # If using a separate detection mode
24 | if restorer_det_type == 'Blend' or restorer_det_type == 'Reference':
25 | if restorer_det_type == 'Blend':
26 | # Set up Transformation
27 | dst = self.models_processor.arcface_dst * 4.0
28 | dst[:,0] += 32.0
29 |
30 | elif restorer_det_type == 'Reference':
31 | try:
32 | dst, _, _ = self.models_processor.run_detect_landmark(swapped_face_upscaled, bbox=np.array([0, 0, 512, 512]), det_kpss=[], detect_mode='5', score=detect_score/100.0, from_points=False)
33 | except Exception as e: # pylint: disable=broad-except
34 | print(f"exception: {e}")
35 | return swapped_face_upscaled
36 |
37 | # Return non-enhanced face if keypoints are empty
38 | if not isinstance(dst, np.ndarray) or len(dst)==0:
39 | return swapped_face_upscaled
40 |
41 | tform = trans.SimilarityTransform()
42 | try:
43 | tform.estimate(dst, self.models_processor.FFHQ_kps)
44 | except:
45 | return swapped_face_upscaled
46 | # Transform, scale, and normalize
47 | temp = v2.functional.affine(swapped_face_upscaled, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]) , tform.scale, 0, center = (0,0) )
48 | temp = v2.functional.crop(temp, 0,0, 512, 512)
49 |
50 | temp = torch.div(temp, 255)
51 | temp = v2.functional.normalize(temp, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=False)
52 |
53 | if restorer_type == 'GPEN-256':
54 | temp = t256(temp)
55 |
56 | temp = torch.unsqueeze(temp, 0).contiguous()
57 |
58 | # Bindings
59 | outpred = torch.empty((1,3,512,512), dtype=torch.float32, device=self.models_processor.device).contiguous()
60 |
61 | if restorer_type == 'GFPGAN-v1.4':
62 | self.run_GFPGAN(temp, outpred)
63 |
64 | elif restorer_type == 'CodeFormer':
65 | self.run_codeformer(temp, outpred, fidelity_weight)
66 |
67 | elif restorer_type == 'GPEN-256':
68 | outpred = torch.empty((1,3,256,256), dtype=torch.float32, device=self.models_processor.device).contiguous()
69 | self.run_GPEN_256(temp, outpred)
70 |
71 | elif restorer_type == 'GPEN-512':
72 | self.run_GPEN_512(temp, outpred)
73 |
74 | elif restorer_type == 'GPEN-1024':
75 | temp = t1024(temp)
76 | outpred = torch.empty((1, 3, 1024, 1024), dtype=torch.float32, device=self.models_processor.device).contiguous()
77 | self.run_GPEN_1024(temp, outpred)
78 |
79 | elif restorer_type == 'GPEN-2048':
80 | temp = t2048(temp)
81 | outpred = torch.empty((1, 3, 2048, 2048), dtype=torch.float32, device=self.models_processor.device).contiguous()
82 | self.run_GPEN_2048(temp, outpred)
83 |
84 | elif restorer_type == 'RestoreFormer++':
85 | self.run_RestoreFormerPlusPlus(temp, outpred)
86 |
87 | elif restorer_type == 'VQFR-v2':
88 | self.run_VQFR_v2(temp, outpred, fidelity_weight)
89 |
90 | # Format back to cxHxW @ 255
91 | outpred = torch.squeeze(outpred)
92 | outpred = torch.clamp(outpred, -1, 1)
93 | outpred = torch.add(outpred, 1)
94 | outpred = torch.div(outpred, 2)
95 | outpred = torch.mul(outpred, 255)
96 |
97 | if restorer_type == 'GPEN-256' or restorer_type == 'GPEN-1024' or restorer_type == 'GPEN-2048':
98 | outpred = t512(outpred)
99 |
100 | # Invert Transform
101 | if restorer_det_type == 'Blend' or restorer_det_type == 'Reference':
102 | outpred = v2.functional.affine(outpred, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center = (0,0) )
103 |
104 | # Blend
105 | alpha = float(restorer_blend)/100.0
106 | outpred = torch.add(torch.mul(outpred, alpha), torch.mul(swapped_face_upscaled, 1-alpha))
107 |
108 | return outpred
109 |
110 | def run_GFPGAN(self, image, output):
111 | if not self.models_processor.models['GFPGANv1.4']:
112 | self.models_processor.models['GFPGANv1.4'] = self.models_processor.load_model('GFPGANv1.4')
113 |
114 | io_binding = self.models_processor.models['GFPGANv1.4'].io_binding()
115 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr())
116 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr())
117 |
118 | if self.models_processor.device == "cuda":
119 | torch.cuda.synchronize()
120 | elif self.models_processor.device != "cpu":
121 | self.models_processor.syncvec.cpu()
122 | self.models_processor.models['GFPGANv1.4'].run_with_iobinding(io_binding)
123 |
124 | def run_GPEN_256(self, image, output):
125 | if not self.models_processor.models['GPENBFR256']:
126 | self.models_processor.models['GPENBFR256'] = self.models_processor.load_model('GPENBFR256')
127 |
128 | io_binding = self.models_processor.models['GPENBFR256'].io_binding()
129 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,256,256), buffer_ptr=image.data_ptr())
130 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,256,256), buffer_ptr=output.data_ptr())
131 |
132 | if self.models_processor.device == "cuda":
133 | torch.cuda.synchronize()
134 | elif self.models_processor.device != "cpu":
135 | self.models_processor.syncvec.cpu()
136 | self.models_processor.models['GPENBFR256'].run_with_iobinding(io_binding)
137 |
138 | def run_GPEN_512(self, image, output):
139 | if not self.models_processor.models['GPENBFR512']:
140 | self.models_processor.models['GPENBFR512'] = self.models_processor.load_model('GPENBFR512')
141 |
142 | io_binding = self.models_processor.models['GPENBFR512'].io_binding()
143 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr())
144 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr())
145 |
146 | if self.models_processor.device == "cuda":
147 | torch.cuda.synchronize()
148 | elif self.models_processor.device != "cpu":
149 | self.models_processor.syncvec.cpu()
150 | self.models_processor.models['GPENBFR512'].run_with_iobinding(io_binding)
151 |
152 | def run_GPEN_1024(self, image, output):
153 | if not self.models_processor.models['GPENBFR1024']:
154 | self.models_processor.models['GPENBFR1024'] = self.models_processor.load_model('GPENBFR1024')
155 |
156 | io_binding = self.models_processor.models['GPENBFR1024'].io_binding()
157 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,1024,1024), buffer_ptr=image.data_ptr())
158 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,1024,1024), buffer_ptr=output.data_ptr())
159 |
160 | if self.models_processor.device == "cuda":
161 | torch.cuda.synchronize()
162 | elif self.models_processor.device != "cpu":
163 | self.models_processor.syncvec.cpu()
164 | self.models_processor.models['GPENBFR1024'].run_with_iobinding(io_binding)
165 |
166 | def run_GPEN_2048(self, image, output):
167 | if not self.models_processor.models['GPENBFR2048']:
168 | self.models_processor.models['GPENBFR2048'] = self.models_processor.load_model('GPENBFR2048')
169 |
170 | io_binding = self.models_processor.models['GPENBFR2048'].io_binding()
171 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,2048,2048), buffer_ptr=image.data_ptr())
172 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,2048,2048), buffer_ptr=output.data_ptr())
173 |
174 | if self.models_processor.device == "cuda":
175 | torch.cuda.synchronize()
176 | elif self.models_processor.device != "cpu":
177 | self.models_processor.syncvec.cpu()
178 | self.models_processor.models['GPENBFR2048'].run_with_iobinding(io_binding)
179 |
180 | def run_codeformer(self, image, output, fidelity_weight_value=0.9):
181 | if not self.models_processor.models['CodeFormer']:
182 | self.models_processor.models['CodeFormer'] = self.models_processor.load_model('CodeFormer')
183 |
184 | io_binding = self.models_processor.models['CodeFormer'].io_binding()
185 | io_binding.bind_input(name='x', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr())
186 | w = np.array([fidelity_weight_value], dtype=np.double)
187 | io_binding.bind_cpu_input('w', w)
188 | io_binding.bind_output(name='y', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr())
189 |
190 | if self.models_processor.device == "cuda":
191 | torch.cuda.synchronize()
192 | elif self.models_processor.device != "cpu":
193 | self.models_processor.syncvec.cpu()
194 | self.models_processor.models['CodeFormer'].run_with_iobinding(io_binding)
195 |
196 | def run_VQFR_v2(self, image, output, fidelity_ratio_value):
197 | if not self.models_processor.models['VQFRv2']:
198 | self.models_processor.models['VQFRv2'] = self.models_processor.load_model('VQFRv2')
199 |
200 |         assert 0.0 <= fidelity_ratio_value <= 1.0, 'fidelity_ratio must be in range [0, 1]'
201 | fidelity_ratio = torch.tensor(fidelity_ratio_value).to(self.models_processor.device)
202 |
203 | io_binding = self.models_processor.models['VQFRv2'].io_binding()
204 | io_binding.bind_input(name='x_lq', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
205 | io_binding.bind_input(name='fidelity_ratio', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=fidelity_ratio.size(), buffer_ptr=fidelity_ratio.data_ptr())
206 | io_binding.bind_output('enc_feat', self.models_processor.device)
207 | io_binding.bind_output('quant_logit', self.models_processor.device)
208 | io_binding.bind_output('texture_dec', self.models_processor.device)
209 | io_binding.bind_output(name='main_dec', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr())
210 |
211 | if self.models_processor.device == "cuda":
212 | torch.cuda.synchronize()
213 | elif self.models_processor.device != "cpu":
214 | self.models_processor.syncvec.cpu()
215 | self.models_processor.models['VQFRv2'].run_with_iobinding(io_binding)
216 |
217 | def run_RestoreFormerPlusPlus(self, image, output):
218 | if not self.models_processor.models['RestoreFormerPlusPlus']:
219 | self.models_processor.models['RestoreFormerPlusPlus'] = self.models_processor.load_model('RestoreFormerPlusPlus')
220 |
221 | io_binding = self.models_processor.models['RestoreFormerPlusPlus'].io_binding()
222 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
223 | io_binding.bind_output(name='2359', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
224 | io_binding.bind_output('1228', self.models_processor.device)
225 | io_binding.bind_output('1238', self.models_processor.device)
226 | io_binding.bind_output('onnx::MatMul_1198', self.models_processor.device)
227 | io_binding.bind_output('onnx::Shape_1184', self.models_processor.device)
228 | io_binding.bind_output('onnx::ArgMin_1182', self.models_processor.device)
229 | io_binding.bind_output('input.1', self.models_processor.device)
230 | io_binding.bind_output('x', self.models_processor.device)
231 | io_binding.bind_output('x.3', self.models_processor.device)
232 | io_binding.bind_output('x.7', self.models_processor.device)
233 | io_binding.bind_output('x.11', self.models_processor.device)
234 | io_binding.bind_output('x.15', self.models_processor.device)
235 | io_binding.bind_output('input.252', self.models_processor.device)
236 | io_binding.bind_output('input.280', self.models_processor.device)
237 | io_binding.bind_output('input.288', self.models_processor.device)
238 |
239 | if self.models_processor.device == "cuda":
240 | torch.cuda.synchronize()
241 | elif self.models_processor.device != "cpu":
242 | self.models_processor.syncvec.cpu()
243 | self.models_processor.models['RestoreFormerPlusPlus'].run_with_iobinding(io_binding)
--------------------------------------------------------------------------------
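# A minimal usage sketch (not part of the repository): the restorer entry points above are
# meant to be called with pre-allocated, contiguous tensors, because io_binding binds raw
# data_ptr() buffers of a fixed shape. `face_restorers` and `models_processor` are assumed to
# be already-constructed FaceRestorers / ModelsProcessor instances, and `face_512` an already
# pre-processed 1x3x512x512 float32 face crop (value range as expected by the chosen model).
import torch

def restore_face_sketch(face_restorers, models_processor, face_512: torch.Tensor) -> torch.Tensor:
    image = face_512.to(models_processor.device, dtype=torch.float32).contiguous()
    output = torch.empty_like(image)  # same shape/dtype/device; receives the restored face
    face_restorers.run_GFPGAN(image, output)
    return output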
/app/processors/frame_enhancers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import TYPE_CHECKING
3 |
4 | import torch
5 | import numpy as np
6 | from torchvision.transforms import v2
7 |
8 | if TYPE_CHECKING:
9 | from app.processors.models_processor import ModelsProcessor
10 |
11 | class FrameEnhancers:
12 | def __init__(self, models_processor: 'ModelsProcessor'):
13 | self.models_processor = models_processor
14 |
15 | def run_enhance_frame_tile_process(self, img, enhancer_type, tile_size=256, scale=1):
16 | _, _, height, width = img.shape
17 |
18 |         # Compute the number of tiles needed
19 | tiles_x = math.ceil(width / tile_size)
20 | tiles_y = math.ceil(height / tile_size)
21 |
22 |         # Compute the padding needed to fit the image to the tile grid
23 | pad_right = (tile_size - (width % tile_size)) % tile_size
24 | pad_bottom = (tile_size - (height % tile_size)) % tile_size
25 |
26 |         # Pad the image if necessary
27 | if pad_right != 0 or pad_bottom != 0:
28 | img = torch.nn.functional.pad(img, (0, pad_right, 0, pad_bottom), 'constant', 0)
29 |
30 |         # Create an empty output tensor
31 | b, c, h, w = img.shape
32 | output = torch.empty((b, c, h * scale, w * scale), dtype=torch.float32, device=self.models_processor.device).contiguous()
33 |
34 |         # Select the upscaling function based on the enhancer type
35 | upscaler_functions = {
36 | 'RealEsrgan-x2-Plus': self.run_realesrganx2,
37 | 'RealEsrgan-x4-Plus': self.run_realesrganx4,
38 | 'BSRGan-x2': self.run_bsrganx2,
39 | 'BSRGan-x4': self.run_bsrganx4,
40 | 'UltraSharp-x4': self.run_ultrasharpx4,
41 | 'UltraMix-x4': self.run_ultramixx4,
42 | 'RealEsr-General-x4v3': self.run_realesrx4v3
43 | }
44 |
45 | fn_upscaler = upscaler_functions.get(enhancer_type)
46 |
47 |         if not fn_upscaler: # If the enhancer type is not valid
48 | if pad_right != 0 or pad_bottom != 0:
49 | img = v2.functional.crop(img, 0, 0, height, width)
50 | return img
51 |
52 |         with torch.no_grad(): # Disable gradient computation
53 |             # Process the tiles
54 | for j in range(tiles_y):
55 | for i in range(tiles_x):
56 | x_start, y_start = i * tile_size, j * tile_size
57 | x_end, y_end = x_start + tile_size, y_start + tile_size
58 |
59 |                     # Extract the input tile
60 | input_tile = img[:, :, y_start:y_end, x_start:x_end].contiguous()
61 | output_tile = torch.empty((input_tile.shape[0], input_tile.shape[1], input_tile.shape[2] * scale, input_tile.shape[3] * scale), dtype=torch.float32, device=self.models_processor.device).contiguous()
62 |
63 |                     # Upscale the tile
64 | fn_upscaler(input_tile, output_tile)
65 |
66 |                     # Insert the upscaled tile into the output tensor
67 | output_y_start, output_x_start = y_start * scale, x_start * scale
68 | output_y_end, output_x_end = output_y_start + output_tile.shape[2], output_x_start + output_tile.shape[3]
69 | output[:, :, output_y_start:output_y_end, output_x_start:output_x_end] = output_tile
70 |
71 |             # Crop the output to remove the added padding
72 | if pad_right != 0 or pad_bottom != 0:
73 | output = v2.functional.crop(output, 0, 0, height * scale, width * scale)
74 |
75 | return output
76 |
77 | def run_realesrganx2(self, image, output):
78 | if not self.models_processor.models['RealEsrganx2Plus']:
79 | self.models_processor.models['RealEsrganx2Plus'] = self.models_processor.load_model('RealEsrganx2Plus')
80 |
81 | io_binding = self.models_processor.models['RealEsrganx2Plus'].io_binding()
82 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
83 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
84 |
85 | if self.models_processor.device == "cuda":
86 | torch.cuda.synchronize()
87 | elif self.models_processor.device != "cpu":
88 | self.models_processor.syncvec.cpu()
89 | self.models_processor.models['RealEsrganx2Plus'].run_with_iobinding(io_binding)
90 |
91 | def run_realesrganx4(self, image, output):
92 | if not self.models_processor.models['RealEsrganx4Plus']:
93 | self.models_processor.models['RealEsrganx4Plus'] = self.models_processor.load_model('RealEsrganx4Plus')
94 |
95 | io_binding = self.models_processor.models['RealEsrganx4Plus'].io_binding()
96 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
97 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
98 |
99 | if self.models_processor.device == "cuda":
100 | torch.cuda.synchronize()
101 | elif self.models_processor.device != "cpu":
102 | self.models_processor.syncvec.cpu()
103 | self.models_processor.models['RealEsrganx4Plus'].run_with_iobinding(io_binding)
104 |
105 | def run_realesrx4v3(self, image, output):
106 | if not self.models_processor.models['RealEsrx4v3']:
107 | self.models_processor.models['RealEsrx4v3'] = self.models_processor.load_model('RealEsrx4v3')
108 |
109 | io_binding = self.models_processor.models['RealEsrx4v3'].io_binding()
110 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
111 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
112 |
113 | if self.models_processor.device == "cuda":
114 | torch.cuda.synchronize()
115 | elif self.models_processor.device != "cpu":
116 | self.models_processor.syncvec.cpu()
117 | self.models_processor.models['RealEsrx4v3'].run_with_iobinding(io_binding)
118 |
119 | def run_bsrganx2(self, image, output):
120 | if not self.models_processor.models['BSRGANx2']:
121 | self.models_processor.models['BSRGANx2'] = self.models_processor.load_model('BSRGANx2')
122 |
123 | io_binding = self.models_processor.models['BSRGANx2'].io_binding()
124 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
125 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
126 |
127 | if self.models_processor.device == "cuda":
128 | torch.cuda.synchronize()
129 | elif self.models_processor.device != "cpu":
130 | self.models_processor.syncvec.cpu()
131 | self.models_processor.models['BSRGANx2'].run_with_iobinding(io_binding)
132 |
133 | def run_bsrganx4(self, image, output):
134 | if not self.models_processor.models['BSRGANx4']:
135 | self.models_processor.models['BSRGANx4'] = self.models_processor.load_model('BSRGANx4')
136 |
137 | io_binding = self.models_processor.models['BSRGANx4'].io_binding()
138 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
139 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
140 |
141 | if self.models_processor.device == "cuda":
142 | torch.cuda.synchronize()
143 | elif self.models_processor.device != "cpu":
144 | self.models_processor.syncvec.cpu()
145 | self.models_processor.models['BSRGANx4'].run_with_iobinding(io_binding)
146 |
147 | def run_ultrasharpx4(self, image, output):
148 | if not self.models_processor.models['UltraSharpx4']:
149 | self.models_processor.models['UltraSharpx4'] = self.models_processor.load_model('UltraSharpx4')
150 |
151 | io_binding = self.models_processor.models['UltraSharpx4'].io_binding()
152 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
153 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
154 |
155 | if self.models_processor.device == "cuda":
156 | torch.cuda.synchronize()
157 | elif self.models_processor.device != "cpu":
158 | self.models_processor.syncvec.cpu()
159 | self.models_processor.models['UltraSharpx4'].run_with_iobinding(io_binding)
160 |
161 | def run_ultramixx4(self, image, output):
162 | if not self.models_processor.models['UltraMixx4']:
163 | self.models_processor.models['UltraMixx4'] = self.models_processor.load_model('UltraMixx4')
164 |
165 | io_binding = self.models_processor.models['UltraMixx4'].io_binding()
166 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
167 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
168 |
169 | if self.models_processor.device == "cuda":
170 | torch.cuda.synchronize()
171 | elif self.models_processor.device != "cpu":
172 | self.models_processor.syncvec.cpu()
173 | self.models_processor.models['UltraMixx4'].run_with_iobinding(io_binding)
174 |
175 | def run_deoldify_artistic(self, image, output):
176 | if not self.models_processor.models['DeoldifyArt']:
177 | self.models_processor.models['DeoldifyArt'] = self.models_processor.load_model('DeoldifyArt')
178 |
179 | io_binding = self.models_processor.models['DeoldifyArt'].io_binding()
180 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
181 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
182 |
183 | if self.models_processor.device == "cuda":
184 | torch.cuda.synchronize()
185 | elif self.models_processor.device != "cpu":
186 | self.models_processor.syncvec.cpu()
187 | self.models_processor.models['DeoldifyArt'].run_with_iobinding(io_binding)
188 |
189 | def run_deoldify_stable(self, image, output):
190 | if not self.models_processor.models['DeoldifyStable']:
191 | self.models_processor.models['DeoldifyStable'] = self.models_processor.load_model('DeoldifyStable')
192 |
193 | io_binding = self.models_processor.models['DeoldifyStable'].io_binding()
194 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
195 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
196 |
197 | if self.models_processor.device == "cuda":
198 | torch.cuda.synchronize()
199 | elif self.models_processor.device != "cpu":
200 | self.models_processor.syncvec.cpu()
201 | self.models_processor.models['DeoldifyStable'].run_with_iobinding(io_binding)
202 |
203 | def run_deoldify_video(self, image, output):
204 | if not self.models_processor.models['DeoldifyVideo']:
205 | self.models_processor.models['DeoldifyVideo'] = self.models_processor.load_model('DeoldifyVideo')
206 |
207 | io_binding = self.models_processor.models['DeoldifyVideo'].io_binding()
208 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
209 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
210 |
211 | if self.models_processor.device == "cuda":
212 | torch.cuda.synchronize()
213 | elif self.models_processor.device != "cpu":
214 | self.models_processor.syncvec.cpu()
215 | self.models_processor.models['DeoldifyVideo'].run_with_iobinding(io_binding)
216 |
217 | def run_ddcolor_artistic(self, image, output):
218 | if not self.models_processor.models['DDColorArt']:
219 | self.models_processor.models['DDColorArt'] = self.models_processor.load_model('DDColorArt')
220 |
221 | io_binding = self.models_processor.models['DDColorArt'].io_binding()
222 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
223 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
224 |
225 | if self.models_processor.device == "cuda":
226 | torch.cuda.synchronize()
227 | elif self.models_processor.device != "cpu":
228 | self.models_processor.syncvec.cpu()
229 | self.models_processor.models['DDColorArt'].run_with_iobinding(io_binding)
230 |
231 | def run_ddcolor(self, image, output):
232 | if not self.models_processor.models['DDcolor']:
233 | self.models_processor.models['DDcolor'] = self.models_processor.load_model('DDcolor')
234 |
235 | io_binding = self.models_processor.models['DDcolor'].io_binding()
236 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr())
237 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr())
238 |
239 | if self.models_processor.device == "cuda":
240 | torch.cuda.synchronize()
241 | elif self.models_processor.device != "cpu":
242 | self.models_processor.syncvec.cpu()
243 | self.models_processor.models['DDcolor'].run_with_iobinding(io_binding)
--------------------------------------------------------------------------------
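# A minimal usage sketch (not part of the repository): driving the tiled enhancer above. The
# `scale` argument must match the upscaling factor of the selected enhancer type (e.g. 2 for
# 'RealEsrgan-x2-Plus', 4 for 'BSRGan-x4'), since the output tensor is pre-allocated as
# (B, C, H*scale, W*scale). `frame_enhancers` and `models_processor` are assumed to be
# already-constructed FrameEnhancers / ModelsProcessor instances.
import torch

def enhance_frame_sketch(frame_enhancers, models_processor, frame_bchw: torch.Tensor) -> torch.Tensor:
    # frame_bchw: float32 tensor of shape (1, 3, H, W) on models_processor.device
    frame_bchw = frame_bchw.to(models_processor.device, dtype=torch.float32).contiguous()
    return frame_enhancers.run_enhance_frame_tile_process(
        frame_bchw, 'RealEsrgan-x2-Plus', tile_size=512, scale=2
    )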
/app/processors/utils/engine_builder.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-member
2 |
3 | import os
4 | import sys
5 | import logging
6 | import platform
7 | import ctypes
8 | from pathlib import Path
9 |
10 | try:
11 | import tensorrt as trt
12 | except ModuleNotFoundError:
13 | pass
14 |
15 | logging.basicConfig(level=logging.INFO)
16 | logging.getLogger("EngineBuilder").setLevel(logging.INFO)
17 | log = logging.getLogger("EngineBuilder")
18 |
19 | if 'trt' in globals():
20 |     # Create a global TensorRT logger instance
21 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) # pylint: disable=no-member
22 | else:
23 | TRT_LOGGER = {}
24 |
25 | # imported from https://github.com/warmshao/FasterLivePortrait/blob/master/scripts/onnx2trt.py
26 | # adjusted to work with TensorRT 10.3.0
27 | class EngineBuilder:
28 | """
29 | Parses an ONNX graph and builds a TensorRT engine from it.
30 | """
31 |
32 | def __init__(self, verbose=False, custom_plugin_path=None, builder_optimization_level=3):
33 | """
34 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
35 | :param custom_plugin_path: Path to the custom plugin library (DLL or SO).
36 | """
37 | if verbose:
38 | TRT_LOGGER.min_severity = trt.Logger.Severity.VERBOSE
39 |
40 |         # Initialize the TensorRT plugins
41 | trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
42 |
43 |         # Build the TensorRT builder and configuration using the same logger
44 | self.builder = trt.Builder(TRT_LOGGER)
45 | self.config = self.builder.create_builder_config()
46 |         # Set the workspace memory pool limit to 3 GB
47 | self.config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 3 * (2 ** 30)) # 3 GB
48 |
49 |         # Set the builder optimization level (if provided)
50 | self.config.builder_optimization_level = builder_optimization_level
51 |
52 |         # Create an optimization profile, if needed
53 | profile = self.builder.create_optimization_profile()
54 | self.config.add_optimization_profile(profile)
55 |
56 | self.batch_size = None
57 | self.network = None
58 | self.parser = None
59 |
60 |         # Load custom plugins if specified
61 | if custom_plugin_path is not None:
62 | if platform.system().lower() == 'linux':
63 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL)
64 | else:
65 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL, winmode=0)
66 |
67 | def create_network(self, onnx_path):
68 | """
69 | Parse the ONNX graph and create the corresponding TensorRT network definition.
70 | :param onnx_path: The path to the ONNX graph to load.
71 | """
72 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
73 |
74 | self.network = self.builder.create_network(network_flags)
75 | self.parser = trt.OnnxParser(self.network, TRT_LOGGER)
76 |
77 | onnx_path = os.path.realpath(onnx_path)
78 | with open(onnx_path, "rb") as f:
79 | if not self.parser.parse(f.read()):
80 | log.error("Failed to load ONNX file: %s", onnx_path)
81 | for error in range(self.parser.num_errors):
82 | log.error(self.parser.get_error(error))
83 | sys.exit(1)
84 |
85 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
86 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
87 |
88 | log.info("Network Description")
89 | for net_input in inputs:
90 | self.batch_size = net_input.shape[0]
91 | log.info("Input '%s' with shape %s and dtype %s", net_input.name, net_input.shape, net_input.dtype)
92 | for net_output in outputs:
93 | log.info("Output %s' with shape %s and dtype %s", net_output.name, net_output.shape, net_output.dtype)
94 |
95 | def create_engine(self, engine_path, precision):
96 | """
97 | Build the TensorRT engine and serialize it to disk.
98 | :param engine_path: The path where to serialize the engine to.
99 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
100 | """
101 | engine_path = os.path.realpath(engine_path)
102 | engine_dir = os.path.dirname(engine_path)
103 | os.makedirs(engine_dir, exist_ok=True)
104 | log.info("Building %s Engine in %s", precision, engine_path)
105 |
106 |         # Force TensorRT to respect precision constraints
107 | self.config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
108 |
109 | if precision == "fp16":
110 | if not self.builder.platform_has_fast_fp16:
111 | log.warning("FP16 is not supported natively on this platform/device")
112 | else:
113 | self.config.set_flag(trt.BuilderFlag.FP16)
114 |
115 |         # Build the serialized engine
116 | serialized_engine = self.builder.build_serialized_network(self.network, self.config)
117 |
118 |         # Check that the engine was serialized correctly
119 | if serialized_engine is None:
120 | raise RuntimeError("Errore nella costruzione del motore TensorRT!")
121 |
122 |         # Write the serialized engine to disk
123 | with open(engine_path, "wb") as f:
124 | log.info("Serializing engine to file: %s", engine_path)
125 | f.write(serialized_engine)
126 |
127 | def change_extension(file_path, new_extension, version=None):
128 | """
129 | Change the extension of the file path and optionally prepend a version.
130 | """
131 | # Remove leading '.' from the new extension if present
132 | new_extension = new_extension.lstrip('.')
133 |
134 | # Create the new file path with the version before the extension, if provided
135 | if version:
136 | new_file_path = Path(file_path).with_suffix(f'.{version}.{new_extension}')
137 | else:
138 | new_file_path = Path(file_path).with_suffix(f'.{new_extension}')
139 |
140 | return str(new_file_path)
141 |
142 | def onnx_to_trt(onnx_model_path, trt_model_path=None, precision="fp16", custom_plugin_path=None, verbose=False):
143 |     # The precision mode to build in, either 'fp32', 'fp16' or 'int8'; default: 'fp16'
144 |
145 | if trt_model_path is None:
146 | trt_version = trt.__version__
147 | trt_model_path = change_extension(onnx_model_path, "trt", version=trt_version)
148 | builder = EngineBuilder(verbose=verbose, custom_plugin_path=custom_plugin_path)
149 |
150 | builder.create_network(onnx_model_path)
151 | builder.create_engine(trt_model_path, precision)
152 |
--------------------------------------------------------------------------------
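# A minimal usage sketch (not part of the repository): building an engine from an ONNX model
# with the helper above. The ONNX path is hypothetical; when trt_model_path is None the output
# name embeds the installed TensorRT version (e.g. model.10.3.0.trt), which is useful because
# serialized engines are not portable across TensorRT versions or GPU architectures.
from app.processors.utils.engine_builder import onnx_to_trt

onnx_to_trt(
    onnx_model_path="model_assets/some_model.onnx",  # hypothetical path
    trt_model_path=None,         # derive "<name>.<trt_version>.trt" automatically
    precision="fp16",            # fp16 is skipped with a warning if the device has no fast fp16
    custom_plugin_path=None,     # e.g. the bundled grid_sample_3d plugin, if the model needs it
    verbose=False,
)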
/app/processors/utils/tensorrt_predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from collections import OrderedDict
4 | import platform
5 | from queue import Queue
6 | from threading import Lock
7 | from typing import Dict, Any, OrderedDict as OrderedDictType
8 |
9 | try:
10 | from torch.cuda import nvtx
11 | import tensorrt as trt
12 | import ctypes
13 | except ModuleNotFoundError:
14 | pass
15 |
16 | # Mapping from numpy dtypes to torch dtypes
17 | numpy_to_torch_dtype_dict = {
18 | np.uint8: torch.uint8,
19 | np.int8: torch.int8,
20 | np.int16: torch.int16,
21 | np.int32: torch.int32,
22 | np.int64: torch.int64,
23 | np.float16: torch.float16,
24 | np.float32: torch.float32,
25 | np.float64: torch.float64,
26 | np.complex64: torch.complex64,
27 | np.complex128: torch.complex128,
28 | }
29 | if tuple(int(x) for x in np.version.full_version.split(".")[:2]) >= (1, 24):  # compare numerically; np.bool was removed in NumPy 1.24
30 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool
31 | else:
32 | numpy_to_torch_dtype_dict[np.bool] = torch.bool
33 |
34 | if 'trt' in globals():
35 |     # Create a global TensorRT logger instance
36 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
37 | else:
38 | TRT_LOGGER = None
39 |
40 |
41 | class TensorRTPredictor:
42 | """
43 |     Implements inference on a TensorRT engine using a pool of execution contexts,
44 |     each of which owns its own buffers, to guarantee thread safety in multithreaded use.
45 | """
46 |
47 | def __init__(self, **kwargs) -> None:
48 | """
49 |         :param model_path: Path to the serialized engine file.
50 |         :param pool_size: Number of execution contexts to keep in the pool.
51 |         :param custom_plugin_path: (Optional) path to any custom plugin library.
52 |         :param device: Device on which to allocate the tensors (default 'cuda').
53 |         :param debug: If True, print debug information.
54 | """
55 | self.device = kwargs.get("device", 'cuda')
56 | self.debug = kwargs.get("debug", False)
57 | self.pool_size = kwargs.get("pool_size", 10)
58 |
59 |         # Load the custom plugin (if provided)
60 | custom_plugin_path = kwargs.get("custom_plugin_path", None)
61 | if custom_plugin_path is not None:
62 | try:
63 | if platform.system().lower() == 'linux':
64 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL)
65 | else:
66 |                     # On Windows, WinDLL or platform-specific parameters may be needed
67 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL, winmode=0)
68 | except Exception as e:
69 | raise RuntimeError(f"Errore nel caricamento del plugin personalizzato: {e}")
70 |
71 |         # Check that the model path was provided
72 | engine_path = kwargs.get("model_path", None)
73 | if not engine_path:
74 | raise ValueError("Il parametro 'model_path' è obbligatorio.")
75 |
76 |         # Load the TensorRT engine
77 | try:
78 | with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
79 | engine_data = f.read()
80 | self.engine = runtime.deserialize_cuda_engine(engine_data)
81 | except Exception as e:
82 | raise RuntimeError(f"Errore nella deserializzazione dell'engine: {e}")
83 |
84 | if self.engine is None:
85 | raise RuntimeError("La deserializzazione dell'engine è fallita.")
86 |
87 |         # Set up the I/O specifications (inputs and outputs)
88 | self.inputs = []
89 | self.outputs = []
90 | for idx in range(self.engine.num_io_tensors):
91 | name = self.engine.get_tensor_name(idx)
92 | mode = self.engine.get_tensor_mode(name)
93 | shape = list(self.engine.get_tensor_shape(name))
94 | dtype = trt.nptype(self.engine.get_tensor_dtype(name))
95 | binding = {
96 | "index": idx,
97 | "name": name,
98 | "dtype": dtype,
99 | "shape": shape,
100 | }
101 | if mode == trt.TensorIOMode.INPUT:
102 | self.inputs.append(binding)
103 | else:
104 | self.outputs.append(binding)
105 |
106 | if len(self.inputs) == 0 or len(self.outputs) == 0:
107 | raise RuntimeError("L'engine deve avere almeno un input e un output.")
108 |
109 |         # Create the pool of execution contexts
110 | self.context_pool = Queue(maxsize=self.pool_size)
111 |         # (Optional) Lock for any critical operations
112 | self.lock = Lock()
113 | for _ in range(self.pool_size):
114 | context = self.engine.create_execution_context()
115 | buffers = self._allocate_buffers()
116 | self.context_pool.put({"context": context, "buffers": buffers})
117 |
118 | def _allocate_buffers(self) -> OrderedDictType[str, torch.Tensor]:
119 | """
120 |         Allocates a dictionary of tensors for all of the model's I/O, taking any dynamic
121 |         dimensions into account. Returns an OrderedDict keyed by the tensor name.
122 | """
123 | nvtx.range_push("allocate_max_buffers")
124 | buffers = OrderedDict()
125 |         # Default batch size
126 | batch_size = 1
127 | for idx in range(self.engine.num_io_tensors):
128 | name = self.engine.get_tensor_name(idx)
129 |             shape = list(self.engine.get_tensor_shape(name))  # make sure we have a list
130 | is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
131 | if -1 in shape:
132 | if is_input:
133 |                     # Get the maximum shape for optimization profile 0
134 | profile_shape = self.engine.get_tensor_profile_shape(name, 0)[-1]
135 | shape = list(profile_shape)
136 | batch_size = shape[0]
137 | else:
138 | shape[0] = batch_size
139 | dtype = trt.nptype(self.engine.get_tensor_dtype(name))
140 | if dtype not in numpy_to_torch_dtype_dict:
141 | raise TypeError(f"Tipo numpy non supportato: {dtype}")
142 | tensor = torch.empty(tuple(shape),
143 | dtype=numpy_to_torch_dtype_dict[dtype],
144 | device=self.device)
145 | buffers[name] = tensor
146 | nvtx.range_pop()
147 | return buffers
148 |
149 | def input_spec(self) -> list:
150 | """
151 |         Returns the input specifications (name, shape, dtype), useful for preparing the arrays.
152 | """
153 | specs = []
154 | for i, inp in enumerate(self.inputs):
155 | specs.append((inp["name"], inp["shape"], inp["dtype"]))
156 | if self.debug:
157 | print(f"trt input {i} -> {inp['name']} -> {inp['shape']} -> {inp['dtype']}")
158 | return specs
159 |
160 | def output_spec(self) -> list:
161 | """
162 |         Returns the output specifications (name, shape, dtype), useful for preparing the arrays.
163 | """
164 | specs = []
165 | for i, out in enumerate(self.outputs):
166 | specs.append((out["name"], out["shape"], out["dtype"]))
167 | if self.debug:
168 | print(f"trt output {i} -> {out['name']} -> {out['shape']} -> {out['dtype']}")
169 | return specs
170 |
171 | def adjust_buffer(self, feed_dict: Dict[str, Any], context: Any, buffers: OrderedDictType[str, torch.Tensor]) -> None:
172 | """
173 |         Adjusts the input buffer sizes and copies the data from feed_dict into the allocated tensors.
174 |         If an input is a NumPy array, it is converted to a Torch tensor (on the correct device).
175 |         It also sets the input shape on the execution context.
176 | """
177 | nvtx.range_push("adjust_buffer")
178 | for name, buf in feed_dict.items():
179 | if name not in buffers:
180 | raise KeyError(f"Input '{name}' non trovato nei buffer allocati.")
181 | input_tensor = buffers[name]
182 |             # Convert to a tensor if necessary
183 | if isinstance(buf, np.ndarray):
184 | buf_tensor = torch.from_numpy(buf).to(input_tensor.device)
185 | elif isinstance(buf, torch.Tensor):
186 | buf_tensor = buf.to(input_tensor.device)
187 | else:
188 | raise TypeError(f"Tipo di dato per '{name}' non supportato: {type(buf)}")
189 | current_shape = list(buf_tensor.shape)
190 |             # Copy only the portion actually used into the pre-allocated buffer
191 | slices = tuple(slice(0, dim) for dim in current_shape)
192 | input_tensor[slices].copy_(buf_tensor)
193 |             # Set the input shape on the execution context
194 | context.set_input_shape(name, current_shape)
195 | nvtx.range_pop()
196 |
197 | def predict(self, feed_dict: Dict[str, Any]) -> OrderedDictType[str, torch.Tensor]:
198 | """
199 |         Runs inference synchronously using execute_v2().
200 |
201 |         :param feed_dict: Dictionary of inputs (numpy arrays or Torch tensors).
202 |         :return: Dictionary of the updated tensors (inputs and outputs).
203 | """
204 |         pool_entry = self.context_pool.get()  # the Queue is thread-safe
205 | context = pool_entry["context"]
206 | buffers = pool_entry["buffers"]
207 |
208 | try:
209 | nvtx.range_push("set_tensors")
210 | self.adjust_buffer(feed_dict, context, buffers)
211 |             # Set the buffer addresses
212 | for name, tensor in buffers.items():
213 |                 # If needed, the tensor type could be checked against the expected one here
214 | context.set_tensor_address(name, tensor.data_ptr())
215 | nvtx.range_pop()
216 |
217 |             # Prepare the bindings (list of buffer addresses)
218 | bindings = [tensor.data_ptr() for tensor in buffers.values()]
219 |
220 | nvtx.range_push("execute")
221 | noerror = context.execute_v2(bindings)
222 | nvtx.range_pop()
223 | if not noerror:
224 | raise RuntimeError("ERROR: inference failed.")
225 |
226 |             # (Optionally, only the outputs could be returned)
227 | return buffers
228 |
229 | finally:
230 |             # Synchronize the CUDA stream before returning the context to the pool
231 | torch.cuda.synchronize()
232 | self.context_pool.put(pool_entry)
233 |
234 | def predict_async(self, feed_dict: Dict[str, Any], stream: torch.cuda.Stream) -> OrderedDictType[str, torch.Tensor]:
235 | """
236 |         Runs inference asynchronously using execute_async_v3().
237 |
238 |         :param feed_dict: Dictionary of inputs (numpy arrays or Torch tensors).
239 |         :param stream: A CUDA stream for asynchronous execution.
240 |         :return: Dictionary of the updated tensors (inputs and outputs).
241 | """
242 | pool_entry = self.context_pool.get()
243 | context = pool_entry["context"]
244 | buffers = pool_entry["buffers"]
245 |
246 | try:
247 | nvtx.range_push("set_tensors")
248 | self.adjust_buffer(feed_dict, context, buffers)
249 | for name, tensor in buffers.items():
250 | context.set_tensor_address(name, tensor.data_ptr())
251 | nvtx.range_pop()
252 |
253 |             # Create a CUDA event to track when the input has been consumed
254 | input_consumed_event = torch.cuda.Event()
255 | context.set_input_consumed_event(input_consumed_event.cuda_event)
256 |
257 | nvtx.range_push("execute_async")
258 | noerror = context.execute_async_v3(stream.cuda_stream)
259 | nvtx.range_pop()
260 | if not noerror:
261 | raise RuntimeError("ERROR: inference failed.")
262 |
263 | input_consumed_event.synchronize()
264 |
265 | return buffers
266 |
267 | finally:
268 |             # Synchronize the stream used, if different from the current one
269 | if stream != torch.cuda.current_stream():
270 | stream.synchronize()
271 | else:
272 | torch.cuda.synchronize()
273 | self.context_pool.put(pool_entry)
274 |
275 | def cleanup(self) -> None:
276 | """
277 |         Releases all resources associated with the TensorRTPredictor.
278 |         This method must be called explicitly before discarding the object.
279 | """
280 |         # Release the TensorRT engine
281 | if hasattr(self, 'engine') and self.engine is not None:
282 | del self.engine
283 | self.engine = None
284 |
285 |         # Release the execution context pool and its buffers
286 | if hasattr(self, 'context_pool') and self.context_pool is not None:
287 | while not self.context_pool.empty():
288 | pool_entry = self.context_pool.get()
289 | context = pool_entry.get("context", None)
290 | buffers = pool_entry.get("buffers", None)
291 | if context is not None:
292 | del context
293 | if buffers is not None:
294 | for t in buffers.values():
295 | del t
296 | self.context_pool = None
297 |
298 | self.inputs = None
299 | self.outputs = None
300 | self.pool_size = None
301 |
302 | def __del__(self) -> None:
303 |         # For extra safety, call cleanup in the destructor
304 | self.cleanup()
305 |
--------------------------------------------------------------------------------
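# A minimal usage sketch (not part of the repository): running inference with the predictor
# above. The engine path and the input tensor name/shape are hypothetical and must match the
# actual serialized engine. predict() returns the whole pooled buffer dict (inputs and outputs),
# so callers should clone the outputs they need before the buffers are reused by another call.
import torch
from app.processors.utils.tensorrt_predictor import TensorRTPredictor

predictor = TensorRTPredictor(model_path="model_assets/some_model.trt", pool_size=2, device="cuda")
feed = {"input": torch.randn(1, 3, 256, 256, device="cuda", dtype=torch.float32)}
buffers = predictor.predict(feed)
outputs = {name: buffers[name].clone() for name, _, _ in predictor.output_spec()}
predictor.cleanup()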
/app/ui/core/convert_ui_to_py.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal enabledelayedexpansion
3 |
4 |
5 | :: Define relative paths
6 | set "UI_FILE=app\ui\core\MainWindow.ui"
7 | set "PY_FILE=app\ui\core\main_window.py"
8 | set "QRC_FILE=app\ui\core\media.qrc"
9 | set "RCC_PY_FILE=app\ui\core\media_rc.py"
10 |
11 | :: Run PySide6 commands
12 | pyside6-uic "%UI_FILE%" -o "%PY_FILE%"
13 | pyside6-rcc "%QRC_FILE%" -o "%RCC_PY_FILE%"
14 |
15 | :: Define search and replace strings
16 | set "searchString=import media_rc"
17 | set "replaceString=from app.ui.core import media_rc"
18 |
19 | :: Create a temporary file
20 | set "tempFile=%PY_FILE%.tmp"
21 |
22 | :: Process the file
23 | (for /f "usebackq delims=" %%A in ("%PY_FILE%") do (
24 | set "line=%%A"
25 | if "!line!"=="%searchString%" (
26 | echo %replaceString%
27 | ) else (
28 | echo !line!
29 | )
30 | )) > "%tempFile%"
31 |
32 | :: Replace the original file with the temporary file
33 | move /y "%tempFile%" "%PY_FILE%"
34 |
35 | echo Replacement complete.
--------------------------------------------------------------------------------
/app/ui/core/media.qrc:
--------------------------------------------------------------------------------
1 |
Drop Files
" 217 | "or
" 218 | "Click here to Select a Folder
" 219 | "" 220 | ) 221 | # placeholder_label.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) 222 | placeholder_label.setStyleSheet("color: gray; font-size: 15px; font-weight: bold;") 223 | 224 | # Center the label inside the QListWidget 225 | # placeholder_label.setGeometry(list_widget.rect()) # Match QListWidget's size 226 | placeholder_label.setAttribute(QtCore.Qt.WidgetAttribute.WA_TransparentForMouseEvents) # Allow interactions to pass through 227 | placeholder_label.setVisible(not list_widget.count()) # Show if the list is empty 228 | 229 | # Use a QVBoxLayout to center the placeholder label 230 | layout = QtWidgets.QVBoxLayout(list_widget) 231 | layout.addWidget(placeholder_label) 232 | layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) # Center the label vertically and horizontally 233 | layout.setContentsMargins(0, 0, 0, 0) # Remove margins to ensure full coverage 234 | 235 | # Keep a reference for toggling visibility later 236 | list_widget.placeholder_label = placeholder_label 237 | # Set default cursor as PointingHand 238 | list_widget.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) 239 | 240 | def select_output_media_folder(main_window: 'MainWindow'): 241 | folder_name = QtWidgets.QFileDialog.getExistingDirectory() 242 | if folder_name: 243 | main_window.outputFolderLineEdit.setText(folder_name) 244 | common_widget_actions.create_control(main_window, 'OutputMediaFolder', folder_name) 245 | -------------------------------------------------------------------------------- /app/ui/widgets/common_layout_data.py: -------------------------------------------------------------------------------- 1 | from app.helpers.typing_helper import LayoutDictTypes 2 | import app.ui.widgets.actions.layout_actions as layout_actions 3 | 4 | COMMON_LAYOUT_DATA: LayoutDictTypes = { 5 | # 'Face Compare':{ 6 | # 'ViewFaceMaskEnableToggle':{ 7 | # 'level': 1, 8 | # 'label': 'View Face Mask', 9 | # 'default': False, 10 | # 'help': 'Show Face Mask', 11 | # 'exec_function': layout_actions.fit_image_to_view_onchange, 12 | # 'exec_function_args': [], 13 | # }, 14 | # 'ViewFaceCompareEnableToggle':{ 15 | # 'level': 1, 16 | # 'label': 'View Face Compare', 17 | # 'default': False, 18 | # 'help': 'Show Face Compare', 19 | # 'exec_function': layout_actions.fit_image_to_view_onchange, 20 | # 'exec_function_args': [], 21 | # }, 22 | # }, 23 | 'Face Restorer': { 24 | 'FaceRestorerEnableToggle': { 25 | 'level': 1, 26 | 'label': 'Enable Face Restorer', 27 | 'default': False, 28 | 'help': 'Enable the use of a face restoration model to improve the quality of the face after swapping.' 29 | }, 30 | 'FaceRestorerTypeSelection': { 31 | 'level': 2, 32 | 'label': 'Restorer Type', 33 | 'options': ['GFPGAN-v1.4', 'CodeFormer', 'GPEN-256', 'GPEN-512', 'GPEN-1024', 'GPEN-2048', 'RestoreFormer++', 'VQFR-v2'], 34 | 'default': 'GFPGAN-v1.4', 35 | 'parentToggle': 'FaceRestorerEnableToggle', 36 | 'requiredToggleValue': True, 37 | 'help': 'Select the model type for face restoration.' 38 | }, 39 | 'FaceRestorerDetTypeSelection': { 40 | 'level': 2, 41 | 'label': 'Alignment', 42 | 'options': ['Original', 'Blend', 'Reference'], 43 | 'default': 'Original', 44 | 'parentToggle': 'FaceRestorerEnableToggle', 45 | 'requiredToggleValue': True, 46 | 'help': 'Select the alignment method for restoring the face to its original or blended position.' 
47 | }, 48 | 'FaceFidelityWeightDecimalSlider': { 49 | 'level': 2, 50 | 'label': 'Fidelity Weight', 51 | 'min_value': '0.0', 52 | 'max_value': '1.0', 53 | 'default': '0.9', 54 | 'decimals': 1, 55 | 'step': 0.1, 56 | 'parentToggle': 'FaceRestorerEnableToggle', 57 | 'requiredToggleValue': True, 58 | 'help': 'Adjust the fidelity weight to control how closely the restoration preserves the original face details.' 59 | }, 60 | 'FaceRestorerBlendSlider': { 61 | 'level': 2, 62 | 'label': 'Blend', 63 | 'min_value': '0', 64 | 'max_value': '100', 65 | 'default': '100', 66 | 'step': 1, 67 | 'parentToggle': 'FaceRestorerEnableToggle', 68 | 'requiredToggleValue': True, 69 | 'help': 'Control the blend ratio between the restored face and the swapped face.' 70 | }, 71 | 'FaceRestorerEnable2Toggle': { 72 | 'level': 1, 73 | 'label': 'Enable Face Restorer 2', 74 | 'default': False, 75 | 'help': 'Enable the use of a face restoration model to improve the quality of the face after swapping.' 76 | }, 77 | 'FaceRestorerType2Selection': { 78 | 'level': 2, 79 | 'label': 'Restorer Type', 80 | 'options': ['GFPGAN-v1.4', 'CodeFormer', 'GPEN-256', 'GPEN-512', 'GPEN-1024', 'GPEN-2048', 'RestoreFormer++', 'VQFR-v2'], 81 | 'default': 'GFPGAN-v1.4', 82 | 'parentToggle': 'FaceRestorerEnable2Toggle', 83 | 'requiredToggleValue': True, 84 | 'help': 'Select the model type for face restoration.' 85 | }, 86 | 'FaceRestorerDetType2Selection': { 87 | 'level': 2, 88 | 'label': 'Alignment', 89 | 'options': ['Original', 'Blend', 'Reference'], 90 | 'default': 'Original', 91 | 'parentToggle': 'FaceRestorerEnable2Toggle', 92 | 'requiredToggleValue': True, 93 | 'help': 'Select the alignment method for restoring the face to its original or blended position.' 94 | }, 95 | 'FaceFidelityWeight2DecimalSlider': { 96 | 'level': 2, 97 | 'label': 'Fidelity Weight', 98 | 'min_value': '0.0', 99 | 'max_value': '1.0', 100 | 'default': '0.9', 101 | 'decimals': 1, 102 | 'step': 0.1, 103 | 'parentToggle': 'FaceRestorerEnable2Toggle', 104 | 'requiredToggleValue': True, 105 | 'help': 'Adjust the fidelity weight to control how closely the restoration preserves the original face details.' 106 | }, 107 | 'FaceRestorerBlend2Slider': { 108 | 'level': 2, 109 | 'label': 'Blend', 110 | 'min_value': '0', 111 | 'max_value': '100', 112 | 'default': '100', 113 | 'step': 1, 114 | 'parentToggle': 'FaceRestorerEnable2Toggle', 115 | 'requiredToggleValue': True, 116 | 'help': 'Control the blend ratio between the restored face and the swapped face.' 117 | }, 118 | 'FaceExpressionEnableToggle': { 119 | 'level': 1, 120 | 'label': 'Enable Face Expression Restorer', 121 | 'default': False, 122 | 'help': 'Enabled the use of the LivePortrait face expression model to restore facial expressions after swapping.' 123 | }, 124 | 'FaceExpressionCropScaleDecimalSlider': { 125 | 'level': 2, 126 | 'label': 'Crop Scale', 127 | 'min_value': '1.80', 128 | 'max_value': '3.00', 129 | 'default': '2.30', 130 | 'step': 0.05, 131 | 'decimals': 2, 132 | 'parentToggle': 'FaceExpressionEnableToggle', 133 | 'requiredToggleValue': True, 134 | 'help': 'Changes swap crop scale. Increase the value to capture the face more distantly.' 135 | }, 136 | 'FaceExpressionVYRatioDecimalSlider': { 137 | 'level': 2, 138 | 'label': 'VY Ratio', 139 | 'min_value': '-0.125', 140 | 'max_value': '-0.100', 141 | 'default': '-0.125', 142 | 'step': 0.001, 143 | 'decimals': 3, 144 | 'parentToggle': 'FaceExpressionEnableToggle', 145 | 'requiredToggleValue': True, 146 | 'help': 'Changes the vy ratio for crop scale. 
Increase the value to capture the face more distantly.' 147 | }, 148 | 'FaceExpressionFriendlyFactorDecimalSlider': { 149 | 'level': 2, 150 | 'label': 'Expression Friendly Factor', 151 | 'min_value': '0.0', 152 | 'max_value': '1.0', 153 | 'default': '1.0', 154 | 'decimals': 1, 155 | 'step': 0.1, 156 | 'parentToggle': 'FaceExpressionEnableToggle', 157 | 'requiredToggleValue': True, 158 | 'help': 'Control the expression similarity between the driving face and the swapped face.' 159 | }, 160 | 'FaceExpressionAnimationRegionSelection': { 161 | 'level': 2, 162 | 'label': 'Animation Region', 163 | 'options': ['all', 'eyes', 'lips'], 164 | 'default': 'all', 165 | 'parentToggle': 'FaceExpressionEnableToggle', 166 | 'requiredToggleValue': True, 167 | 'help': 'The facial region involved in the restoration process.' 168 | }, 169 | 'FaceExpressionNormalizeLipsEnableToggle': { 170 | 'level': 2, 171 | 'label': 'Normalize Lips', 172 | 'default': True, 173 | 'parentToggle': 'FaceExpressionEnableToggle', 174 | 'requiredToggleValue': True, 175 | 'help': 'Normalize the lips during the facial restoration process.' 176 | }, 177 | 'FaceExpressionNormalizeLipsThresholdDecimalSlider': { 178 | 'level': 3, 179 | 'label': 'Normalize Lips Threshold', 180 | 'min_value': '0.00', 181 | 'max_value': '1.00', 182 | 'default': '0.03', 183 | 'decimals': 2, 184 | 'step': 0.01, 185 | 'parentToggle': 'FaceExpressionNormalizeLipsEnableToggle & FaceExpressionEnableToggle', 186 | 'requiredToggleValue': True, 187 | 'help': 'Threshold value for Normalize Lips.' 188 | }, 189 | 'FaceExpressionRetargetingEyesEnableToggle': { 190 | 'level': 2, 191 | 'label': 'Retargeting Eyes', 192 | 'default': False, 193 | 'parentToggle': 'FaceExpressionEnableToggle', 194 | 'requiredToggleValue': True, 195 | 'help': 'Adjusting or redirecting the gaze or movement of the eyes during the facial restoration process. It overrides the Animation Region settings, meaning that the Animation Region will be ignored.' 196 | }, 197 | 'FaceExpressionRetargetingEyesMultiplierDecimalSlider': { 198 | 'level': 3, 199 | 'label': 'Retargeting Eyes Multiplier', 200 | 'min_value': '0.00', 201 | 'max_value': '2.00', 202 | 'default': '1.00', 203 | 'decimals': 2, 204 | 'step': 0.01, 205 | 'parentToggle': 'FaceExpressionRetargetingEyesEnableToggle & FaceExpressionEnableToggle', 206 | 'requiredToggleValue': True, 207 | 'help': 'Multiplier value for Retargeting Eyes.' 208 | }, 209 | 'FaceExpressionRetargetingLipsEnableToggle': { 210 | 'level': 2, 211 | 'label': 'Retargeting Lips', 212 | 'default': False, 213 | 'parentToggle': 'FaceExpressionEnableToggle', 214 | 'requiredToggleValue': True, 215 | 'help': 'Adjusting or modifying the position, shape, or movement of the lips during the facial restoration process. It overrides the Animation Region settings, meaning that the Animation Region will be ignored.' 216 | }, 217 | 'FaceExpressionRetargetingLipsMultiplierDecimalSlider': { 218 | 'level': 3, 219 | 'label': 'Retargeting Lips Multiplier', 220 | 'min_value': '0.00', 221 | 'max_value': '2.00', 222 | 'default': '1.00', 223 | 'decimals': 2, 224 | 'step': 0.01, 225 | 'parentToggle': 'FaceExpressionRetargetingLipsEnableToggle & FaceExpressionEnableToggle', 226 | 'requiredToggleValue': True, 227 | 'help': 'Multiplier value for Retargeting Lips.' 
228 | }, 229 | }, 230 | } -------------------------------------------------------------------------------- /app/ui/widgets/event_filters.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from functools import partial 3 | 4 | from PySide6 import QtWidgets, QtGui, QtCore 5 | from app.ui.widgets.actions import list_view_actions 6 | from app.ui.widgets import ui_workers 7 | import app.helpers.miscellaneous as misc_helpers 8 | 9 | if TYPE_CHECKING: 10 | from app.ui.main_ui import MainWindow 11 | 12 | class GraphicsViewEventFilter(QtCore.QObject): 13 | def __init__(self, main_window: 'MainWindow', parent=None): 14 | super().__init__(parent) 15 | self.main_window = main_window 16 | 17 | def eventFilter(self, graphics_object: QtWidgets.QGraphicsView, event): 18 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 19 | if event.button() == QtCore.Qt.MouseButton.LeftButton: 20 | self.main_window.buttonMediaPlay.click() 21 | # You can emit a signal or call another function here 22 | return True # Mark the event as handled 23 | return False # Pass the event to the original handler 24 | 25 | class videoSeekSliderLineEditEventFilter(QtCore.QObject): 26 | def __init__(self, main_window: 'MainWindow', parent=None): 27 | super().__init__(parent) 28 | self.main_window = main_window 29 | 30 | def eventFilter(self, line_edit: QtWidgets.QLineEdit, event): 31 | if event.type() == QtCore.QEvent.KeyPress: 32 | # Check if the pressed key is Enter/Return 33 | if event.key() in (QtCore.Qt.Key_Enter, QtCore.Qt.Key_Return): 34 | new_value = line_edit.text() 35 | # Reset the line edit value to the slider value if the user input an empty text 36 | if new_value=='': 37 | new_value = str(self.main_window.videoSeekSlider.value()) 38 | else: 39 | new_value = int(new_value) 40 | max_frame_number = self.main_window.video_processor.max_frame_number 41 | # If the value entered by user if greater than the max no of frames in the video, set the new value to the max_frame_number 42 | if new_value > max_frame_number: 43 | new_value = max_frame_number 44 | # Update values of line edit and slider 45 | line_edit.setText(str(new_value)) 46 | self.main_window.videoSeekSlider.setValue(new_value) 47 | self.main_window.video_processor.process_current_frame() # Process the current frame 48 | 49 | return True 50 | return False 51 | 52 | class VideoSeekSliderEventFilter(QtCore.QObject): 53 | def __init__(self, main_window: 'MainWindow', parent=None): 54 | super().__init__(parent) 55 | self.main_window = main_window 56 | 57 | def eventFilter(self, slider, event): 58 | if event.type() == QtCore.QEvent.Type.KeyPress: 59 | if event.key() in {QtCore.Qt.Key_Left, QtCore.Qt.Key_Right}: 60 | # Allow default slider movement 61 | result = super().eventFilter(slider, event) 62 | 63 | # After the slider moves, call the custom processing function 64 | QtCore.QTimer.singleShot(0, self.main_window.video_processor.process_current_frame) 65 | 66 | return result # Return the result of the default handling 67 | elif event.type() == QtCore.QEvent.Type.Wheel: 68 | # Allow default slider movement 69 | result = super().eventFilter(slider, event) 70 | 71 | # After the slider moves, call the custom processing function 72 | QtCore.QTimer.singleShot(0, self.main_window.video_processor.process_current_frame) 73 | return result 74 | 75 | # For other events, use the default behavior 76 | return super().eventFilter(slider, event) 77 | 78 | class ListWidgetEventFilter(QtCore.QObject): 79 | def 
__init__(self, main_window: 'MainWindow', parent=None): 80 | super().__init__(parent) 81 | self.main_window = main_window 82 | 83 | def eventFilter(self, list_widget: QtWidgets.QListWidget, event: QtCore.QEvent|QtGui.QDropEvent|QtGui.QMouseEvent): 84 | 85 | if list_widget == self.main_window.targetVideosList or list_widget == self.main_window.targetVideosList.viewport(): 86 | 87 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 88 | if event.button() == QtCore.Qt.MouseButton.LeftButton and not self.main_window.target_videos: 89 | list_view_actions.select_target_medias(self.main_window, 'folder') 90 | 91 | elif event.type() == QtCore.QEvent.Type.DragEnter: 92 | # Accept drag events with URLs 93 | if event.mimeData().hasUrls(): 94 | 95 | urls = event.mimeData().urls() 96 | print("Drag: URLS", [url.toLocalFile() for url in urls]) 97 | event.acceptProposedAction() 98 | return True 99 | # Handle the drop event 100 | elif event.type() == QtCore.QEvent.Type.Drop: 101 | 102 | if event.mimeData().hasUrls(): 103 | # Extract file paths 104 | file_paths = [] 105 | for url in event.mimeData().urls(): 106 | url = url.toLocalFile() 107 | if misc_helpers.is_image_file(url) or misc_helpers.is_video_file(url): 108 | file_paths.append(url) 109 | else: 110 | print(f'{url} is not an Video or Image file') 111 | # print("Drop: URLS", [url.toLocalFile() for url in urls]) 112 | if file_paths: 113 | self.main_window.video_loader_worker = ui_workers.TargetMediaLoaderWorker(main_window=self.main_window, folder_name=False, files_list=file_paths) 114 | self.main_window.video_loader_worker.thumbnail_ready.connect(partial(list_view_actions.add_media_thumbnail_to_target_videos_list, self.main_window)) 115 | self.main_window.video_loader_worker.start() 116 | event.acceptProposedAction() 117 | return True 118 | 119 | 120 | elif list_widget == self.main_window.inputFacesList or list_widget == self.main_window.inputFacesList.viewport(): 121 | 122 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 123 | if event.button() == QtCore.Qt.MouseButton.LeftButton and not self.main_window.input_faces: 124 | list_view_actions.select_input_face_images(self.main_window, 'folder') 125 | 126 | elif event.type() == QtCore.QEvent.Type.DragEnter: 127 | # Accept drag events with URLs 128 | if event.mimeData().hasUrls(): 129 | 130 | urls = event.mimeData().urls() 131 | print("Drag: URLS", [url.toLocalFile() for url in urls]) 132 | event.acceptProposedAction() 133 | return True 134 | # Handle the drop event 135 | elif event.type() == QtCore.QEvent.Type.Drop: 136 | 137 | if event.mimeData().hasUrls(): 138 | # Extract file paths 139 | file_paths = [] 140 | for url in event.mimeData().urls(): 141 | url = url.toLocalFile() 142 | if misc_helpers.is_image_file(url): 143 | file_paths.append(url) 144 | else: 145 | print(f'{url} is not an Image file') 146 | # print("Drop: URLS", [url.toLocalFile() for url in urls]) 147 | if file_paths: 148 | self.main_window.input_faces_loader_worker = ui_workers.InputFacesLoaderWorker(main_window=self.main_window, folder_name=False, files_list=file_paths) 149 | self.main_window.input_faces_loader_worker.thumbnail_ready.connect(partial(list_view_actions.add_media_thumbnail_to_source_faces_list, self.main_window)) 150 | self.main_window.input_faces_loader_worker.start() 151 | event.acceptProposedAction() 152 | return True 153 | return super().eventFilter(list_widget, event) -------------------------------------------------------------------------------- /app/ui/widgets/settings_layout_data.py: 
--------------------------------------------------------------------------------
1 | from app.ui.widgets.actions import control_actions
2 | import cv2
3 | from app.helpers.typing_helper import LayoutDictTypes
4 | SETTINGS_LAYOUT_DATA: LayoutDictTypes = {
5 |     'Appearance': {
6 |         'ThemeSelection': {
7 |             'level': 1,
8 |             'label': 'Theme',
9 |             'options': ['Dark', 'Dark-Blue', 'Light'],
10 |             'default': 'Dark',
11 |             'help': 'Select the theme to be used',
12 |             'exec_function': control_actions.change_theme,
13 |             'exec_function_args': [],
14 |         },
15 |     },
16 |     'General': {
17 |         'ProvidersPrioritySelection': {
18 |             'level': 1,
19 |             'label': 'Providers Priority',
20 |             'options': ['CUDA', 'TensorRT', 'TensorRT-Engine', 'CPU'],
21 |             'default': 'CUDA',
22 |             'help': 'Select the providers priority to be used by the system.',
23 |             'exec_function': control_actions.change_execution_provider,
24 |             'exec_function_args': [],
25 |         },
26 |         'nThreadsSlider': {
27 |             'level': 1,
28 |             'label': 'Number of Threads',
29 |             'min_value': '1',
30 |             'max_value': '30',
31 |             'default': '2',
32 |             'step': 1,
33 |             'help': 'Set the number of execution threads to use while playing and recording. Depends strongly on GPU VRAM.',
34 |             'exec_function': control_actions.change_threads_number,
35 |             'exec_function_args': [],
36 |         },
37 |     },
38 |     'Video Settings': {
39 |         'VideoPlaybackCustomFpsToggle': {
40 |             'level': 1,
41 |             'label': 'Set Custom Video Playback FPS',
42 |             'default': False,
43 |             'help': 'Manually set the FPS to be used when playing the video',
44 |             'exec_function': control_actions.set_video_playback_fps,
45 |             'exec_function_args': [],
46 |         },
47 |         'VideoPlaybackCustomFpsSlider': {
48 |             'level': 2,
49 |             'label': 'Video Playback FPS',
50 |             'min_value': '1',
51 |             'max_value': '120',
52 |             'default': '30',
53 |             'parentToggle': 'VideoPlaybackCustomFpsToggle',
54 |             'requiredToggleValue': True,
55 |             'step': 1,
56 |             'help': 'Set the maximum FPS of the video when playing'
57 |         },
58 |     },
59 |     'Auto Swap': {
60 |         'AutoSwapToggle': {
61 |             'level': 1,
62 |             'label': 'Auto Swap',
63 |             'default': False,
64 |             'help': 'Automatically swap all faces using the selected Source Faces/Embeddings when loading a video/image file'
65 |         },
66 |     },
67 |     'Detectors': {
68 |         'DetectorModelSelection': {
69 |             'level': 1,
70 |             'label': 'Face Detect Model',
71 |             'options': ['RetinaFace', 'Yolov8', 'SCRFD', 'Yunet'],
72 |             'default': 'RetinaFace',
73 |             'help': 'Select the face detection model to use for detecting faces in the input image or video.'
74 |         },
75 |         'DetectorScoreSlider': {
76 |             'level': 1,
77 |             'label': 'Detect Score',
78 |             'min_value': '1',
79 |             'max_value': '100',
80 |             'default': '50',
81 |             'step': 1,
82 |             'help': 'Set the confidence score threshold for face detection. Higher values ensure more confident detections but may miss some faces.'
83 |         },
84 |         'MaxFacesToDetectSlider': {
85 |             'level': 1,
86 |             'label': 'Max No of Faces to Detect',
87 |             'min_value': '1',
88 |             'max_value': '50',
89 |             'default': '20',
90 |             'step': 1,
91 |             'help': 'Set the maximum number of faces to detect in a frame'
92 | 
93 |         },
94 |         'AutoRotationToggle': {
95 |             'level': 1,
96 |             'label': 'Auto Rotation',
97 |             'default': False,
98 |             'help': 'Automatically rotate the input to detect faces in various orientations.'
99 |         },
100 |         'ManualRotationEnableToggle': {
101 |             'level': 1,
102 |             'label': 'Manual Rotation',
103 |             'default': False,
104 |             'help': 'Rotate the face detector to better detect faces at different angles.'
105 |         },
106 |         'ManualRotationAngleSlider': {
107 |             'level': 2,
108 |             'label': 'Rotation Angle',
109 |             'min_value': '0',
110 |             'max_value': '270',
111 |             'default': '0',
112 |             'step': 90,
113 |             'parentToggle': 'ManualRotationEnableToggle',
114 |             'requiredToggleValue': True,
115 |             'help': 'Set this to the angle of the input face to help with faces that are lying down, upside down, etc. Angles are read clockwise.'
116 |         },
117 |         'LandmarkDetectToggle': {
118 |             'level': 1,
119 |             'label': 'Enable Landmark Detection',
120 |             'default': False,
121 |             'help': 'Enable or disable facial landmark detection, which is used to refine face alignment.'
122 |         },
123 |         'LandmarkDetectModelSelection': {
124 |             'level': 2,
125 |             'label': 'Landmark Detect Model',
126 |             'options': ['5', '68', '3d68', '98', '106', '203', '478'],
127 |             'default': '203',
128 |             'parentToggle': 'LandmarkDetectToggle',
129 |             'requiredToggleValue': True,
130 |             'help': 'Select the landmark detection model, where different models detect varying numbers of facial landmarks.'
131 |         },
132 |         'LandmarkDetectScoreSlider': {
133 |             'level': 2,
134 |             'label': 'Landmark Detect Score',
135 |             'min_value': '1',
136 |             'max_value': '100',
137 |             'default': '50',
138 |             'step': 1,
139 |             'parentToggle': 'LandmarkDetectToggle',
140 |             'requiredToggleValue': True,
141 |             'help': 'Set the confidence score threshold for facial landmark detection.'
142 |         },
143 |         'DetectFromPointsToggle': {
144 |             'level': 2,
145 |             'label': 'Detect From Points',
146 |             'default': False,
147 |             'parentToggle': 'LandmarkDetectToggle',
148 |             'requiredToggleValue': True,
149 |             'help': 'Enable detection of faces from specified landmark points.'
150 |         },
151 |         'ShowLandmarksEnableToggle': {
152 |             'level': 1,
153 |             'label': 'Show Landmarks',
154 |             'default': False,
155 |             'help': 'Show landmarks in real time.'
156 |         },
157 |         'ShowAllDetectedFacesBBoxToggle': {
158 |             'level': 1,
159 |             'label': 'Show Bounding Boxes',
160 |             'default': False,
161 |             'help': 'Draw bounding boxes around all detected faces in the frame'
162 |         }
163 |     },
164 |     'DFM Settings': {
165 |         'MaxDFMModelsSlider': {
166 |             'level': 1,
167 |             'label': 'Maximum DFM Models to use',
168 |             'min_value': '1',
169 |             'max_value': '5',
170 |             'default': '1',
171 |             'step': 1,
172 |             'help': "Set the maximum number of DFM Models to keep in memory at a time. Set this based on your GPU's VRAM",
173 |         }
174 |     },
175 |     'Frame Enhancer': {
176 |         'FrameEnhancerEnableToggle': {
177 |             'level': 1,
178 |             'label': 'Enable Frame Enhancer',
179 |             'default': False,
180 |             'help': 'Enable frame enhancement for video inputs to improve visual quality.'
181 |         },
182 |         'FrameEnhancerTypeSelection': {
183 |             'level': 2,
184 |             'label': 'Frame Enhancer Type',
185 |             'options': ['RealEsrgan-x2-Plus', 'RealEsrgan-x4-Plus', 'RealEsr-General-x4v3', 'BSRGan-x2', 'BSRGan-x4', 'UltraSharp-x4', 'UltraMix-x4', 'DDColor-Artistic', 'DDColor', 'DeOldify-Artistic', 'DeOldify-Stable', 'DeOldify-Video'],
186 |             'default': 'RealEsrgan-x2-Plus',
187 |             'parentToggle': 'FrameEnhancerEnableToggle',
188 |             'requiredToggleValue': True,
189 |             'help': 'Select the type of frame enhancement to apply, based on the content and resolution requirements.'
190 |         },
191 |         'FrameEnhancerBlendSlider': {
192 |             'level': 2,
193 |             'label': 'Blend',
194 |             'min_value': '0',
195 |             'max_value': '100',
196 |             'default': '100',
197 |             'step': 1,
198 |             'parentToggle': 'FrameEnhancerEnableToggle',
199 |             'requiredToggleValue': True,
200 |             'help': 'Blends the enhanced results back into the original frame.'
201 |         },
202 |     },
203 |     'Webcam Settings': {
204 |         'WebcamMaxNoSelection': {
205 |             'level': 2,
206 |             'label': 'Webcam Max No',
207 |             'options': ['1', '2', '3', '4', '5', '6'],
208 |             'default': '1',
209 |             'help': 'Select the maximum number of webcam streams to allow for face swapping.'
210 |         },
211 |         'WebcamBackendSelection': {
212 |             'level': 2,
213 |             'label': 'Webcam Backend',
214 |             'options': ['Default', 'DirectShow', 'MSMF', 'V4L', 'V4L2', 'GSTREAMER'],
215 |             'default': 'Default',
216 |             'help': 'Choose the backend for accessing webcam input.'
217 |         },
218 |         'WebcamMaxResSelection': {
219 |             'level': 2,
220 |             'label': 'Webcam Resolution',
221 |             'options': ['480x360', '640x480', '1280x720', '1920x1080', '2560x1440', '3840x2160'],
222 |             'default': '1280x720',
223 |             'help': 'Select the maximum resolution for webcam input.'
224 |         },
225 |         'WebCamMaxFPSSelection': {
226 |             'level': 2,
227 |             'label': 'Webcam FPS',
228 |             'options': ['23', '30', '60'],
229 |             'default': '30',
230 |             'help': 'Set the maximum frames per second (FPS) for webcam input.'
231 |         },
232 |     },
233 |     'Virtual Camera': {
234 |         'SendVirtCamFramesEnableToggle': {
235 |             'level': 1,
236 |             'label': 'Send Frames to Virtual Camera',
237 |             'default': False,
238 |             'help': 'Send the swapped video/webcam output to a virtual camera for use in external applications',
239 |             'exec_function': control_actions.toggle_virtualcam,
240 |             'exec_function_args': [],
241 |         },
242 |         'VirtCamBackendSelection': {
243 |             'level': 1,
244 |             'label': 'Virtual Camera Backend',
245 |             'options': ['obs', 'unitycapture'],
246 |             'default': 'obs',
247 |             'help': 'Choose the backend based on the Virtual Camera you have set up',
248 |             'parentToggle': 'SendVirtCamFramesEnableToggle',
249 |             'requiredToggleValue': True,
250 |             'exec_function': control_actions.enable_virtualcam,
251 |             'exec_function_args': [],
252 |         },
253 |     },
254 |     'Face Recognition': {
255 |         'RecognitionModelSelection': {
256 |             'level': 1,
257 |             'label': 'Recognition Model',
258 |             'options': ['Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace', 'CSCSArcFace'],
259 |             'default': 'Inswapper128ArcFace',
260 |             'help': 'Choose the ArcFace model to be used for comparing the similarity of faces.'
261 |         },
262 |         'SimilarityTypeSelection': {
263 |             'level': 1,
264 |             'label': 'Swapping Similarity Type',
265 |             'options': ['Opal', 'Pearl', 'Optimal'],
266 |             'default': 'Opal',
267 |             'help': 'Choose the type of similarity calculation for face detection and matching during the face swapping process.'
268 |         },
269 |     },
270 |     'Embedding Merge Method': {
271 |         'EmbMergeMethodSelection': {
272 |             'level': 1,
273 |             'label': 'Embedding Merge Method',
274 |             'options': ['Mean', 'Median'],
275 |             'default': 'Mean',
276 |             'help': 'Select the method to merge facial embeddings. "Mean" averages the embeddings, while "Median" selects the middle value, providing more robustness to outliers.'
277 | } 278 | }, 279 | 'Media Selection':{ 280 | 'TargetMediaFolderRecursiveToggle':{ 281 | 'level': 1, 282 | 'label': 'Target Media Include Subfolders', 283 | 'default': False, 284 | 'help': 'Include all files from Subfolders when choosing Target Media Folder' 285 | }, 286 | 'InputFacesFolderRecursiveToggle':{ 287 | 'level': 1, 288 | 'label': 'Input Faces Include Subfolders', 289 | 'default': False, 290 | 'help': 'Include all files from Subfolders when choosing Input Faces Folder' 291 | } 292 | } 293 | } 294 | 295 | CAMERA_BACKENDS = { 296 | 'Default': cv2.CAP_ANY, 297 | 'DirectShow': cv2.CAP_DSHOW, 298 | 'MSMF': cv2.CAP_MSMF, 299 | 'V4L': cv2.CAP_V4L, 300 | 'V4L2': cv2.CAP_V4L2, 301 | 'GSTREAMER': cv2.CAP_GSTREAMER, 302 | } -------------------------------------------------------------------------------- /app/ui/widgets/ui_workers.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from functools import partial 3 | from typing import TYPE_CHECKING, Dict 4 | import traceback 5 | import os 6 | 7 | import cv2 8 | import torch 9 | import numpy 10 | from PySide6 import QtCore as qtc 11 | from PySide6.QtGui import QPixmap 12 | 13 | from app.processors.models_data import detection_model_mapping, landmark_model_mapping 14 | from app.helpers import miscellaneous as misc_helpers 15 | from app.ui.widgets.actions import common_actions as common_widget_actions 16 | from app.ui.widgets.actions import filter_actions 17 | from app.ui.widgets.settings_layout_data import SETTINGS_LAYOUT_DATA, CAMERA_BACKENDS 18 | 19 | if TYPE_CHECKING: 20 | from app.ui.main_ui import MainWindow 21 | 22 | class TargetMediaLoaderWorker(qtc.QThread): 23 | # Define signals to emit when loading is done or if there are updates 24 | thumbnail_ready = qtc.Signal(str, QPixmap, str, str) # Signal with media path and QPixmap and file_type, media_id 25 | webcam_thumbnail_ready = qtc.Signal(str, QPixmap, str, str, int, int) 26 | finished = qtc.Signal() # Signal to indicate completion 27 | 28 | def __init__(self, main_window: 'MainWindow', folder_name=False, files_list=None, media_ids=None, webcam_mode=False, parent=None,): 29 | super().__init__(parent) 30 | self.main_window = main_window 31 | self.folder_name = folder_name 32 | self.files_list = files_list or [] 33 | self.media_ids = media_ids or [] 34 | self.webcam_mode = webcam_mode 35 | self._running = True # Flag to control the running state 36 | 37 | # Ensure thumbnail directory exists 38 | misc_helpers.ensure_thumbnail_dir() 39 | 40 | def run(self): 41 | if self.folder_name: 42 | self.load_videos_and_images_from_folder(self.folder_name) 43 | if self.files_list: 44 | self.load_videos_and_images_from_files_list(self.files_list) 45 | if self.webcam_mode: 46 | self.load_webcams() 47 | self.finished.emit() 48 | 49 | def load_videos_and_images_from_folder(self, folder_name): 50 | # Initially hide the placeholder text 51 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 52 | video_files = misc_helpers.get_video_files(folder_name, self.main_window.control['TargetMediaFolderRecursiveToggle']) 53 | image_files = misc_helpers.get_image_files(folder_name, self.main_window.control['TargetMediaFolderRecursiveToggle']) 54 | 55 | i=0 56 | media_files = video_files + image_files 57 | for media_file in media_files: 58 | if not self._running: # Check if the thread is still running 59 | break 60 | media_file_path = os.path.join(folder_name, media_file) 61 | file_type = 
misc_helpers.get_file_type(media_file_path) 62 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path, file_type) 63 | if self.media_ids: 64 | media_id = self.media_ids[i] 65 | else: 66 | media_id = str(uuid.uuid1().int) 67 | if pixmap: 68 | # Emit the signal to update GUI 69 | self.thumbnail_ready.emit(media_file_path, pixmap, file_type, media_id) 70 | i+=1 71 | # Show/Hide the placeholder text based on the number of items in ListWidget 72 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 73 | 74 | def load_videos_and_images_from_files_list(self, files_list): 75 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 76 | media_files = files_list 77 | i=0 78 | for media_file_path in media_files: 79 | if not self._running: # Check if the thread is still running 80 | break 81 | file_type = misc_helpers.get_file_type(media_file_path) 82 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path, file_type=file_type) 83 | if self.media_ids: 84 | media_id = self.media_ids[i] 85 | else: 86 | media_id = str(uuid.uuid1().int) 87 | if pixmap: 88 | # Emit the signal to update GUI 89 | self.thumbnail_ready.emit(media_file_path, pixmap, file_type,media_id) 90 | i+=1 91 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 92 | 93 | def load_webcams(self,): 94 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 95 | camera_backend = CAMERA_BACKENDS[self.main_window.control['WebcamBackendSelection']] 96 | for i in range(int(self.main_window.control['WebcamMaxNoSelection'])): 97 | try: 98 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path=f'Webcam {i}', file_type='webcam', webcam_index=i, webcam_backend=camera_backend) 99 | media_id = str(uuid.uuid1().int) 100 | 101 | if pixmap: 102 | # Emit the signal to update GUI 103 | self.webcam_thumbnail_ready.emit(f'Webcam {i}', pixmap, 'webcam',media_id, i, camera_backend) 104 | except Exception: # pylint: disable=broad-exception-caught 105 | traceback.print_exc() 106 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 107 | 108 | def stop(self): 109 | """Stop the thread by setting the running flag to False.""" 110 | self._running = False 111 | self.wait() 112 | 113 | class InputFacesLoaderWorker(qtc.QThread): 114 | # Define signals to emit when loading is done or if there are updates 115 | thumbnail_ready = qtc.Signal(str, numpy.ndarray, object, QPixmap, str) 116 | finished = qtc.Signal() # Signal to indicate completion 117 | def __init__(self, main_window: 'MainWindow', media_path=False, folder_name=False, files_list=None, face_ids=None, parent=None): 118 | super().__init__(parent) 119 | self.main_window = main_window 120 | self.folder_name = folder_name 121 | self.files_list = files_list or [] 122 | self.face_ids = face_ids or [] 123 | self._running = True # Flag to control the running state 124 | self.was_playing = True 125 | self.pre_load_detection_recognition_models() 126 | 127 | def pre_load_detection_recognition_models(self): 128 | control = self.main_window.control.copy() 129 | detect_model = detection_model_mapping[control['DetectorModelSelection']] 130 | landmark_detect_model = landmark_model_mapping[control['LandmarkDetectModelSelection']] 131 | models_processor = self.main_window.models_processor 132 | if self.main_window.video_processor.processing: 133 | was_playing = True 134 | 
self.main_window.buttonMediaPlay.click()
135 |         else:
136 |             was_playing = False
137 |         if not models_processor.models[detect_model]:
138 |             models_processor.models[detect_model] = models_processor.load_model(detect_model)
139 |         if not models_processor.models[landmark_detect_model] and control['LandmarkDetectToggle']:
140 |             models_processor.models[landmark_detect_model] = models_processor.load_model(landmark_detect_model)
141 |         for recognition_model in ['Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace', 'CSCSArcFace', 'CSCSIDArcFace']:
142 |             if not models_processor.models[recognition_model]:
143 |                 models_processor.models[recognition_model] = models_processor.load_model(recognition_model)
144 |         if was_playing:
145 |             self.main_window.buttonMediaPlay.click()
146 | 
147 |     def run(self):
148 |         if self.folder_name or self.files_list:
149 |             self.main_window.placeholder_update_signal.emit(self.main_window.inputFacesList, True)
150 |             self.load_faces(self.folder_name, self.files_list)
151 |             self.main_window.placeholder_update_signal.emit(self.main_window.inputFacesList, False)
152 | 
153 |     def load_faces(self, folder_name=False, files_list=None):
154 |         control = self.main_window.control.copy()
155 |         files_list = files_list or []
156 |         image_files = []
157 |         if folder_name:
158 |             image_files = misc_helpers.get_image_files(self.folder_name, self.main_window.control['InputFacesFolderRecursiveToggle'])
159 |         elif files_list:
160 |             image_files = files_list
161 | 
162 |         i=0
163 |         image_files.sort()
164 |         for image_file_path in image_files:
165 |             if not self._running:  # Check if the thread is still running
166 |                 break
167 |             if not misc_helpers.is_image_file(image_file_path):
168 |                 return
169 |             if folder_name:
170 |                 image_file_path = os.path.join(folder_name, image_file_path)
171 |             frame = misc_helpers.read_image_file(image_file_path)
172 |             if frame is None:
173 |                 continue
174 |             # Frame must be in RGB format
175 |             frame = frame[..., ::-1]  # Swap the channels from BGR to RGB
176 | 
177 |             img = torch.from_numpy(frame.astype('uint8')).to(self.main_window.models_processor.device)
178 |             img = img.permute(2,0,1)
179 |             _, kpss_5, _ = self.main_window.models_processor.run_detect(img, control['DetectorModelSelection'], max_num=1, score=control['DetectorScoreSlider']/100.0, input_size=(512, 512), use_landmark_detection=control['LandmarkDetectToggle'], landmark_detect_mode=control['LandmarkDetectModelSelection'], landmark_score=control["LandmarkDetectScoreSlider"]/100.0, from_points=control["DetectFromPointsToggle"], rotation_angles=[0] if not control["AutoRotationToggle"] else [0, 90, 180, 270])
180 | 
181 |             # If at least one face is found
182 |             # found_face = []
183 |             face_kps = False
184 |             try:
185 |                 face_kps = kpss_5[0]
186 |             except IndexError:
187 |                 continue
188 |             if face_kps.any():
189 |                 face_emb, cropped_img = self.main_window.models_processor.run_recognize_direct(img, face_kps, control['SimilarityTypeSelection'], control['RecognitionModelSelection'])
190 |                 cropped_img = cropped_img.cpu().numpy()
191 |                 cropped_img = cropped_img[..., ::-1]  # Swap the channels from RGB to BGR
192 |                 face_img = numpy.ascontiguousarray(cropped_img)
193 |                 # crop = cv2.resize(face[2].cpu().numpy(), (82, 82))
194 |                 pixmap = common_widget_actions.get_pixmap_from_frame(self.main_window, face_img)
195 | 
196 |                 embedding_store: Dict[str, numpy.ndarray] = {}
197 |                 # Get the values of 'options'
198 |                 options = SETTINGS_LAYOUT_DATA['Face Recognition']['RecognitionModelSelection']['options']
199 |                 for option in options:
200 |                     if option != 
control['RecognitionModelSelection']: 201 | target_emb, _ = self.main_window.models_processor.run_recognize_direct(img, face_kps, control['SimilarityTypeSelection'], option) 202 | embedding_store[option] = target_emb 203 | else: 204 | embedding_store[control['RecognitionModelSelection']] = face_emb 205 | if not self.face_ids: 206 | face_id = str(uuid.uuid1().int) 207 | else: 208 | face_id = self.face_ids[i] 209 | self.thumbnail_ready.emit(image_file_path, face_img, embedding_store, pixmap, face_id) 210 | i+=1 211 | torch.cuda.empty_cache() 212 | self.finished.emit() 213 | 214 | def stop(self): 215 | """Stop the thread by setting the running flag to False.""" 216 | self._running = False 217 | self.wait() 218 | 219 | class FilterWorker(qtc.QThread): 220 | filtered_results = qtc.Signal(list) 221 | 222 | def __init__(self, main_window: 'MainWindow', search_text='', filter_list='target_videos'): 223 | super().__init__() 224 | self.main_window = main_window 225 | self.search_text = search_text 226 | self.filter_list = filter_list 227 | self.filter_list_widget = self.get_list_widget() 228 | self.filtered_results.connect(partial(filter_actions.update_filtered_list, main_window, self.filter_list_widget)) 229 | 230 | def get_list_widget(self,): 231 | list_widget = False 232 | if self.filter_list == 'target_videos': 233 | list_widget = self.main_window.targetVideosList 234 | elif self.filter_list == 'input_faces': 235 | list_widget = self.main_window.inputFacesList 236 | elif self.filter_list == 'merged_embeddings': 237 | list_widget = self.main_window.inputEmbeddingsList 238 | return list_widget 239 | 240 | def run(self,): 241 | if self.filter_list == 'target_videos': 242 | self.filter_target_videos(self.main_window, self.search_text) 243 | elif self.filter_list == 'input_faces': 244 | self.filter_input_faces(self.main_window, self.search_text) 245 | elif self.filter_list == 'merged_embeddings': 246 | self.filter_merged_embeddings(self.main_window, self.search_text) 247 | 248 | 249 | def filter_target_videos(self, main_window: 'MainWindow', search_text: str = ''): 250 | search_text = main_window.targetVideosSearchBox.text().lower() 251 | include_file_types = [] 252 | if main_window.filterImagesCheckBox.isChecked(): 253 | include_file_types.append('image') 254 | if main_window.filterVideosCheckBox.isChecked(): 255 | include_file_types.append('video') 256 | if main_window.filterWebcamsCheckBox.isChecked(): 257 | include_file_types.append('webcam') 258 | 259 | visible_indices = [] 260 | for i in range(main_window.targetVideosList.count()): 261 | item = main_window.targetVideosList.item(i) 262 | item_widget = main_window.targetVideosList.itemWidget(item) 263 | if ((not search_text or search_text in item_widget.media_path.lower()) and 264 | (item_widget.file_type in include_file_types)): 265 | visible_indices.append(i) 266 | 267 | self.filtered_results.emit(visible_indices) 268 | 269 | def filter_input_faces(self, main_window: 'MainWindow', search_text: str): 270 | search_text = search_text.lower() 271 | visible_indices = [] 272 | 273 | for i in range(main_window.inputFacesList.count()): 274 | item = main_window.inputFacesList.item(i) 275 | item_widget = main_window.inputFacesList.itemWidget(item) 276 | if not search_text or search_text in item_widget.media_path.lower(): 277 | visible_indices.append(i) 278 | 279 | self.filtered_results.emit(visible_indices) 280 | 281 | def filter_merged_embeddings(self, main_window: 'MainWindow', search_text: str): 282 | search_text = search_text.lower() 283 | 
visible_indices = [] 284 | 285 | for i in range(main_window.inputEmbeddingsList.count()): 286 | item = main_window.inputEmbeddingsList.item(i) 287 | item_widget = main_window.inputEmbeddingsList.itemWidget(item) 288 | if not search_text or search_text in item_widget.embedding_name.lower(): 289 | visible_indices.append(i) 290 | 291 | self.filtered_results.emit(visible_indices) 292 | 293 | def stop_thread(self): 294 | self.quit() 295 | self.wait() 296 | -------------------------------------------------------------------------------- /dependencies/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/dependencies/.gitkeep -------------------------------------------------------------------------------- /download_models.py: -------------------------------------------------------------------------------- 1 | from app.helpers.downloader import download_file 2 | from app.processors.models_data import models_list 3 | 4 | for model_data in models_list: 5 | download_file(model_data['model_name'], model_data['local_path'], model_data['hash'], model_data['url']) 6 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from app.ui import main_ui 2 | from PySide6 import QtWidgets 3 | import sys 4 | 5 | import qdarktheme 6 | from app.ui.core.proxy_style import ProxyStyle 7 | 8 | if __name__=="__main__": 9 | 10 | app = QtWidgets.QApplication(sys.argv) 11 | app.setStyle(ProxyStyle()) 12 | with open("app/ui/styles/dark_styles.qss", "r") as f: 13 | _style = f.read() 14 | _style = qdarktheme.load_stylesheet(custom_colors={"primary": "#4facc9"})+'\n'+_style 15 | app.setStyleSheet(_style) 16 | window = main_ui.MainWindow() 17 | window.show() 18 | app.exec() -------------------------------------------------------------------------------- /model_assets/dfm_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/dfm_models/.gitkeep -------------------------------------------------------------------------------- /model_assets/grid_sample_3d_plugin.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/grid_sample_3d_plugin.dll -------------------------------------------------------------------------------- /model_assets/libgrid_sample_3d_plugin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/libgrid_sample_3d_plugin.so -------------------------------------------------------------------------------- /model_assets/liveportrait_onnx/lip_array.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/liveportrait_onnx/lip_array.pkl -------------------------------------------------------------------------------- /model_assets/meanshape_68.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/meanshape_68.pkl -------------------------------------------------------------------------------- /requirements_cu118.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | 3 | numpy==1.26.4 4 | opencv-python==4.10.0.84 5 | scikit-image==0.21.0 6 | pillow==9.5.0 7 | onnx==1.16.1 8 | protobuf==4.23.2 9 | psutil==6.0.0 10 | onnxruntime-gpu==1.18.0 11 | packaging==24.1 12 | PySide6==6.8.2.1 13 | kornia 14 | torch==2.4.1+cu118 15 | torchvision==0.19.1+cu118 16 | torchaudio==2.4.1+cu118 17 | tqdm 18 | ftfy 19 | regex 20 | pyvirtualcam==0.11.1 21 | tensorrt-cu11==10.4.0 22 | numexpr 23 | onnxsim 24 | requests 25 | pyqt-toast-notification==1.3.2 26 | qdarkstyle 27 | pyqtdarktheme -------------------------------------------------------------------------------- /requirements_cu124.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu124 2 | 3 | numpy==1.26.4 4 | opencv-python==4.10.0.84 5 | scikit-image==0.21.0 6 | pillow==9.5.0 7 | onnx==1.16.1 8 | protobuf==4.23.2 9 | psutil==6.0.0 10 | onnxruntime-gpu==1.20.0 11 | packaging==24.1 12 | PySide6==6.8.2.1 13 | kornia 14 | torch==2.4.1+cu124 15 | torchvision==0.19.1+cu124 16 | torchaudio==2.4.1+cu124 17 | tensorrt==10.6.0 --extra-index-url https://pypi.nvidia.com 18 | tensorrt-cu12_libs==10.6.0 19 | tensorrt-cu12_bindings==10.6.0 20 | tqdm 21 | ftfy 22 | regex 23 | pyvirtualcam==0.11.1 24 | numexpr 25 | onnxsim 26 | requests 27 | pyqt-toast-notification==1.3.2 28 | qdarkstyle 29 | pyqtdarktheme 30 | -------------------------------------------------------------------------------- /scripts/setenv.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | :: Get the parent directory of the script location 4 | SET "VISO_ROOT=%~dp0" 5 | SET "VISO_ROOT=%VISO_ROOT:~0,-1%" 6 | FOR %%A IN ("%VISO_ROOT%\..") DO SET "VISO_ROOT=%%~fA" 7 | 8 | :: Define dependencies directory 9 | SET "DEPENDENCIES=%VISO_ROOT%\dependencies" 10 | 11 | SET "GIT_EXECUTABLE=%DEPENDENCIES%\git-portable\bin\git.exe" 12 | 13 | :: Define Python paths 14 | SET "PYTHON_PATH=%DEPENDENCIES%\Python" 15 | SET "PYTHON_SCRIPTS=%PYTHON_PATH%\Scripts" 16 | SET "PYTHON_EXECUTABLE=%PYTHON_PATH%\python.exe" 17 | SET "PYTHONW_EXECUTABLE=%PYTHON_PATH%\pythonw.exe" 18 | 19 | :: Define CUDA and TensorRT paths 20 | SET "CUDA_PATH=%DEPENDENCIES%\CUDA" 21 | SET "CUDA_BIN_PATH=%CUDA_PATH%\bin" 22 | SET "TENSORRT_PATH=%DEPENDENCIES%\TensorRt\lib" 23 | 24 | :: Define FFMPEG path correctly 25 | SET "FFMPEG_PATH=%DEPENDENCIES%" 26 | 27 | :: Add all necessary paths to system PATH 28 | SET "PATH=%FFMPEG_PATH%;%PYTHON_PATH%;%PYTHON_SCRIPTS%;%CUDA_BIN_PATH%;%TENSORRT_PATH%;%PATH%" 29 | -------------------------------------------------------------------------------- /scripts/update_cu118.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | call scripts\setenv.bat 3 | "%GIT_EXECUTABLE%" fetch origin main 4 | "%GIT_EXECUTABLE%" reset --hard origin/main 5 | "%PYTHON_EXECUTABLE%" -m pip install -r requirements_cu118.txt --default-timeout 100 6 | "%PYTHON_EXECUTABLE%" download_models.py -------------------------------------------------------------------------------- /scripts/update_cu124.bat: 
--------------------------------------------------------------------------------
1 | @echo off
2 | call scripts\setenv.bat
3 | "%GIT_EXECUTABLE%" fetch origin main
4 | "%GIT_EXECUTABLE%" reset --hard origin/main
5 | "%PYTHON_EXECUTABLE%" -m pip install -r requirements_cu124.txt --default-timeout 100
6 | "%PYTHON_EXECUTABLE%" download_models.py
--------------------------------------------------------------------------------
/tools/convert_old_rope_embeddings.py:
--------------------------------------------------------------------------------
1 | # Script Usage Example
2 | # 'python3 convert_old_rope_embeddings.py old_merged_embeddings.txt --output_embeddings_file new_merged_embeddings.json'
3 | 
4 | import json
5 | import argparse
6 | 
7 | parser = argparse.ArgumentParser("Rope Embeddings Converter")
8 | parser.add_argument("old_embeddings_file", help="Old Embeddings File", type=str)
9 | parser.add_argument("--output_embeddings_file", help="New Embeddings File", type=str)
10 | parser.add_argument("--recognizer_model", help="Face recognizer model that was used to create the embeddings", default='Inswapper128ArcFace', choices=('Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace', 'CSCSArcFace'), type=str)
11 | args = parser.parse_args()
12 | input_filename = args.old_embeddings_file
13 | 
14 | output_filename = args.output_embeddings_file or f'{input_filename.split(".")[0]}_converted.json'
15 | 
16 | recognizer_model = args.recognizer_model
17 | temp0 = []
18 | new_embed_list = []
19 | with open(input_filename, "r") as embedfile:
20 |     old_data = embedfile.read().splitlines()
21 | 
22 | for i in range(0, len(old_data), 513):
23 |     new_embed_data = {'name': old_data[i][6:], 'embedding_store': {recognizer_model: old_data[i+1:i+513]}}
24 |     for j, val in enumerate(new_embed_data['embedding_store'][recognizer_model]):
25 |         new_embed_data['embedding_store'][recognizer_model][j] = float(val)
26 |     new_embed_list.append(new_embed_data)
27 | 
28 | with open(output_filename, 'w') as embed_file:
29 |     embeds_data = json.dumps(new_embed_list)
30 |     embed_file.write(embeds_data)
--------------------------------------------------------------------------------