├── .github ├── FUNDING.yml └── screenshot.png ├── .gitignore ├── LICENSE ├── README.md ├── Start.bat ├── Start_Portable.bat ├── Update_Portable.bat ├── app ├── helpers │ ├── downloader.py │ ├── integrity_checker.py │ ├── miscellaneous.py │ ├── recording.py │ └── typing_helper.py ├── onnxmodels │ ├── .gitkeep │ ├── dfm_models │ │ └── .keep │ └── place_model_files_here ├── processors │ ├── external │ │ ├── cliplib │ │ │ ├── __init__.py │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── clip.py │ │ │ ├── model.py │ │ │ └── simple_tokenizer.py │ │ ├── clipseg.py │ │ └── resnet.py │ ├── face_detectors.py │ ├── face_editors.py │ ├── face_landmark_detectors.py │ ├── face_masks.py │ ├── face_restorers.py │ ├── face_swappers.py │ ├── frame_enhancers.py │ ├── models_data.py │ ├── models_processor.py │ ├── utils │ │ ├── dfm_model.py │ │ ├── engine_builder.py │ │ ├── faceutil.py │ │ └── tensorrt_predictor.py │ ├── video_processor.py │ └── workers │ │ └── frame_worker.py └── ui │ ├── core │ ├── MainWindow.ui │ ├── convert_ui_to_py.bat │ ├── main_window.py │ ├── media.qrc │ ├── media │ │ ├── OffState.png │ │ ├── OnState.png │ │ ├── add_marker_hover.png │ │ ├── add_marker_off.png │ │ ├── audio_off.png │ │ ├── audio_on.png │ │ ├── fullscreen.png │ │ ├── image.png │ │ ├── marker.png │ │ ├── marker_save.png │ │ ├── next_marker_hover.png │ │ ├── next_marker_off.png │ │ ├── open_file.png │ │ ├── play_hover.png │ │ ├── play_off.png │ │ ├── play_on.png │ │ ├── previous_marker_hover.png │ │ ├── previous_marker_off.png │ │ ├── rec_hover.png │ │ ├── rec_off.png │ │ ├── rec_on.png │ │ ├── remove_marker_hover.png │ │ ├── remove_marker_off.png │ │ ├── repeat.png │ │ ├── reset_default.png │ │ ├── save.png │ │ ├── save_file.png │ │ ├── save_file_as.png │ │ ├── splash.png │ │ ├── splash_next.png │ │ ├── stop_hover.png │ │ ├── stop_off.png │ │ ├── stop_on.png │ │ ├── tl_beg_hover.png │ │ ├── tl_beg_off.png │ │ ├── tl_beg_on.png │ │ ├── tl_left_hover.png │ │ ├── tl_left_off.png │ │ ├── tl_left_on.png │ │ ├── tl_right_hover.png │ │ ├── tl_right_off.png │ │ ├── tl_right_on.png │ │ ├── video.png │ │ ├── visomaster_full.png │ │ ├── visomaster_small.png │ │ └── webcam.png │ ├── media_rc.py │ └── proxy_style.py │ ├── main_ui.py │ ├── styles │ ├── dark_styles.qss │ └── light_styles.qss │ └── widgets │ ├── actions │ ├── card_actions.py │ ├── common_actions.py │ ├── control_actions.py │ ├── filter_actions.py │ ├── graphics_view_actions.py │ ├── layout_actions.py │ ├── list_view_actions.py │ ├── save_load_actions.py │ └── video_control_actions.py │ ├── common_layout_data.py │ ├── event_filters.py │ ├── face_editor_layout_data.py │ ├── settings_layout_data.py │ ├── swapper_layout_data.py │ ├── ui_workers.py │ └── widget_components.py ├── dependencies └── .gitkeep ├── download_models.py ├── main.py ├── model_assets ├── dfm_models │ └── .gitkeep ├── grid_sample_3d_plugin.dll ├── libgrid_sample_3d_plugin.so ├── liveportrait_onnx │ └── lip_array.pkl └── meanshape_68.pkl ├── requirements_cu118.txt ├── requirements_cu124.txt ├── scripts ├── setenv.bat ├── update_cu118.bat └── update_cu124.bat └── tools └── convert_old_rope_embeddings.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single 
Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: ['https://github.com/visomaster/VisoMaster?tab=readme-ov-file#support-the-project'] 16 | -------------------------------------------------------------------------------- /.github/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/.github/screenshot.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | 3 | *.ckpt 4 | *.pth 5 | *.onnx 6 | *.engine 7 | *.profile 8 | *.timing 9 | *.dfm 10 | *.trt 11 | models/liveportrait_onnx/*.onnx 12 | models/liveportrait_onnx/*.trt 13 | saved_parameters*.json 14 | startup_parameters*.json 15 | data.json 16 | merged_embeddings*.txt 17 | .vs 18 | *.sln 19 | *.pyproj 20 | *.json 21 | .vscode/ 22 | tensorrt-engines/ 23 | source_videos/ 24 | source_images/ 25 | output/ 26 | test_frames*/ 27 | test_videos*/ 28 | install.dat 29 | visomaster.ico 30 | dependencies/CUDA/ 31 | dependencies/Python/ 32 | dependencies/git-portable/ 33 | dependencies/TensorRT/ 34 | *.mp4 35 | *.jpg 36 | *.exe 37 | .thumbnails -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # VisoMaster 3 | ### VisoMaster is a powerful yet easy-to-use tool for face swapping and editing in images and videos. It utilizes AI to produce natural-looking results with minimal effort, making it ideal for both casual users and professionals. 4 | 5 | --- 6 | 7 | 8 | ## Features 9 | 10 | ### 🔄 **Face Swap** 11 | - Supports multiple face swapper models 12 | - Compatible with DeepFaceLab trained models (DFM) 13 | - Advanced multi-face swapping with masking options for each facial part 14 | - Occlusion masking support (DFL XSeg Masking) 15 | - Works with all popular face detectors & landmark detectors 16 | - Expression Restorer: Transfers original expressions to the swapped face 17 | - Face Restoration: Supports all popular upscaling & enhancement models 18 | 19 | ### 🎭 **Face Editor (LivePortrait Models)** 20 | - Manually adjust expressions and poses for different face parts 21 | - Fine-tune colors for Face, Hair, Eyebrows, and Lips using RGB adjustments 22 | 23 | ### 🚀 **Other Powerful Features** 24 | - **Live Playback**: See processed video in real-time before saving 25 | - **Face Embeddings**: Use multiple source faces for better accuracy & similarity 26 | - **Live Swapping via Webcam**: Stream to virtual camera for Twitch, YouTube, Zoom, etc. 
27 | - **User-Friendly Interface**: Intuitive and easy to use 28 | - **Video Markers**: Adjust settings per frame for precise results 29 | - **TensorRT Support**: Leverages supported GPUs for ultra-fast processing 30 | - **Many More Advanced Features** 🎉 31 | 32 | ## Automatic Installation (Windows) 33 | - For Windows users with an Nvidia GPU, we provide an automatic installer for easy set up. 34 | - You can get the installer from the [releases](https://github.com/visomaster/VisoMaster/releases/tag/v0.1.1) page or from this [link](https://github.com/visomaster/VisoMaster/releases/download/v0.1.1/VisoMaster_Setup.exe). 35 | - Choose the correct CUDA version inside the installer based on your GPU Compatibility. 36 | - After successful installation, go to your installed directory and run the **Start_Portable.bat** file to launch **VisoMaster** 37 | 38 | ## **Manual Installation Guide (Nvidia)** 39 | 40 | Follow the steps below to install and run **VisoMaster** on your system. 41 | 42 | ## **Prerequisites** 43 | Before proceeding, ensure you have the following installed on your system: 44 | - **Git** ([Download](https://git-scm.com/downloads)) 45 | - **Miniconda** ([Download](https://www.anaconda.com/download)) 46 | 47 | --- 48 | 49 | ## **Installation Steps** 50 | 51 | ### **1. Clone the Repository** 52 | Open a terminal or command prompt and run: 53 | ```sh 54 | git clone https://github.com/visomaster/VisoMaster.git 55 | ``` 56 | ```sh 57 | cd VisoMaster 58 | ``` 59 | 60 | ### **2. Create and Activate a Conda Environment** 61 | ```sh 62 | conda create -n visomaster python=3.10.13 -y 63 | ``` 64 | ```sh 65 | conda activate visomaster 66 | ``` 67 | 68 | ### **3. Install CUDA and cuDNN** 69 | ```sh 70 | conda install -c nvidia/label/cuda-12.4.1 cuda-runtime 71 | ``` 72 | ```sh 73 | conda install -c conda-forge cudnn 74 | ``` 75 | 76 | ### **4. Install Additional Dependencies** 77 | ```sh 78 | conda install scikit-image 79 | ``` 80 | ```sh 81 | pip install -r requirements_cu124.txt 82 | ``` 83 | 84 | ### **5. Download Models and Other Dependencies** 85 | 1. Download all the required models 86 | ```sh 87 | python download_models.py 88 | ``` 89 | 2. Download all the files from this [page](https://github.com/visomaster/visomaster-assets/releases/tag/v0.1.0_dp) and copy it to the ***dependencies/*** folder. 90 | 91 | **Note**: You do not need to download the Source code (zip) and Source code (tar.gz) files 92 | ### **6. Run the Application** 93 | Once everything is set up, start the application by opening the **Start.bat** file. 94 | On Linux just run `python main.py`. 95 | --- 96 | 97 | ## **Troubleshooting** 98 | - If you face CUDA-related issues, ensure your GPU drivers are up to date. 99 | - For missing models, double-check that all models are placed in the correct directories. 100 | 101 | ## [Join Discord](https://discord.gg/5rx4SQuDbp) 102 | 103 | ## Support The Project ## 104 | This project was made possible by the combined efforts of **[@argenspin](https://github.com/argenspin)** and **[@Alucard24](https://github.com/alucard24)** with the support of countless other members in our Discord community. 
If you wish to support us for the continued development of **Visomaster**, you can donate to either of us (or Both if you're double Awesome :smiley: ) 105 | 106 | ### **argenspin** ### 107 | - [BuyMeACoffee](https://buymeacoffee.com/argenspin) 108 | - BTC: bc1qe8y7z0lkjsw6ssnlyzsncw0f4swjgh58j9vrqm84gw2nscgvvs5s4fts8g 109 | - ETH: 0x967a442FBd13617DE8d5fDC75234b2052122156B 110 | ### **Alucard24** ### 111 | - [BuyMeACoffee](https://buymeacoffee.com/alucard_24) 112 | - [PayPal](https://www.paypal.com/donate/?business=XJX2E5ZTMZUSQ&no_recurring=0&item_name=Support+us+with+a+donation!+Your+contribution+helps+us+continue+improving+and+providing+quality+content.+Thank+you!¤cy_code=EUR) 113 | - BTC: 15ny8vV3ChYsEuDta6VG3aKdT6Ra7duRAc 114 | 115 | 116 | ## Disclaimer: ## 117 | **VisoMaster** is a hobby project that we are making available to the community as a thank you to all of the contributors ahead of us. 118 | We've copied the disclaimer from [Swap-Mukham](https://github.com/harisreedhar/Swap-Mukham) here since it is well-written and applies 100% to this repo. 119 | 120 | We would like to emphasize that our swapping software is intended for responsible and ethical use only. We must stress that users are solely responsible for their actions when using our software. 121 | 122 | Intended Usage: This software is designed to assist users in creating realistic and entertaining content, such as movies, visual effects, virtual reality experiences, and other creative applications. We encourage users to explore these possibilities within the boundaries of legality, ethical considerations, and respect for others' privacy. 123 | 124 | Ethical Guidelines: Users are expected to adhere to a set of ethical guidelines when using our software. These guidelines include, but are not limited to: 125 | 126 | Not creating or sharing content that could harm, defame, or harass individuals. Obtaining proper consent and permissions from individuals featured in the content before using their likeness. Avoiding the use of this technology for deceptive purposes, including misinformation or malicious intent. Respecting and abiding by applicable laws, regulations, and copyright restrictions. 127 | 128 | Privacy and Consent: Users are responsible for ensuring that they have the necessary permissions and consents from individuals whose likeness they intend to use in their creations. We strongly discourage the creation of content without explicit consent, particularly if it involves non-consensual or private content. It is essential to respect the privacy and dignity of all individuals involved. 129 | 130 | Legal Considerations: Users must understand and comply with all relevant local, regional, and international laws pertaining to this technology. This includes laws related to privacy, defamation, intellectual property rights, and other relevant legislation. Users should consult legal professionals if they have any doubts regarding the legal implications of their creations. 131 | 132 | Liability and Responsibility: We, as the creators and providers of the deep fake software, cannot be held responsible for the actions or consequences resulting from the usage of our software. Users assume full liability and responsibility for any misuse, unintended effects, or abusive behavior associated with the content they create. 133 | 134 | By using this software, users acknowledge that they have read, understood, and agreed to abide by the above guidelines and disclaimers. 
We strongly encourage users to approach this technology with caution, integrity, and respect for the well-being and rights of others. 135 | 136 | Remember, technology should be used to empower and inspire, not to harm or deceive. Let's strive for ethical and responsible use of deep fake technology for the betterment of society. 137 | -------------------------------------------------------------------------------- /Start.bat: -------------------------------------------------------------------------------- 1 | 2 | call conda activate visomaster 3 | call app/ui/core/convert_ui_to_py.bat 4 | SET APP_ROOT=%~dp0 5 | SET APP_ROOT=%APP_ROOT:~0,-1% 6 | SET DEPENDENCIES=%APP_ROOT%\dependencies 7 | echo %DEPENDENCIES% 8 | SET PATH=%DEPENDENCIES%;%PATH% 9 | python main.py 10 | pause -------------------------------------------------------------------------------- /Start_Portable.bat: -------------------------------------------------------------------------------- 1 | call scripts\setenv.bat 2 | "%PYTHON_EXECUTABLE%" main.py 3 | pause -------------------------------------------------------------------------------- /Update_Portable.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | :: Check if install.dat exists 4 | if not exist install.dat ( 5 | echo install.dat file not found! 6 | pause 7 | exit /b 1 8 | ) 9 | 10 | :: Read the cuda_version from install.dat 11 | for /f "tokens=2 delims==" %%A in ('findstr "cuda_version" install.dat') do set CUDA_VERSION=%%A 12 | 13 | call scripts\update_%CUDA_VERSION%.bat 14 | pause 15 | -------------------------------------------------------------------------------- /app/helpers/downloader.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pathlib import Path 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | from app.helpers.integrity_checker import check_file_integrity 8 | 9 | def download_file(model_name: str, file_path: str, correct_hash: str, url: str) -> bool: 10 | """ 11 | Downloads a file and verifies its integrity. 12 | 13 | Parameters: 14 | - model_name (str): Name of the model being downloaded. 15 | - file_path (str): Path where the file will be saved. 16 | - correct_hash (str): Expected hash value of the file for integrity check. 17 | - url (str): URL to download the file from. 18 | 19 | Returns: 20 | - bool: True if the file is downloaded and verified successfully, False otherwise. 21 | """ 22 | # Remove the file if it already exists and restart download 23 | if Path(file_path).is_file(): 24 | if check_file_integrity(file_path, correct_hash): 25 | print(f"\nSkipping {model_name} as it is already downloaded!") 26 | return True 27 | else: 28 | print(f"\n{file_path} already exists, but its file integrity couldn't be verified. 
Re-downloading it!") 29 | os.remove(file_path) 30 | 31 | print(f"\nDownloading {model_name} from {url}") 32 | 33 | try: 34 | response = requests.get(url, stream=True, timeout=5) 35 | response.raise_for_status() # Raise an error for bad HTTP responses (e.g., 404, 500) 36 | except requests.exceptions.RequestException as e: 37 | print(f"Failed to download {model_name}: {e}") 38 | return False 39 | 40 | total_size = int(response.headers.get("content-length", 0)) # File size in bytes 41 | block_size = 1024 # Size of chunks to download 42 | max_attempts = 3 43 | attempt = 1 44 | 45 | def download_and_save(): 46 | """Handles the file download and saves it to disk.""" 47 | with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: 48 | with open(file_path, "wb") as file: 49 | for data in response.iter_content(block_size): 50 | progress_bar.update(len(data)) 51 | file.write(data) 52 | 53 | while attempt <= max_attempts: 54 | try: 55 | download_and_save() 56 | 57 | # Verify file integrity 58 | if check_file_integrity(file_path, correct_hash): 59 | print("File integrity verified successfully!") 60 | print(f"File saved at: {file_path}") 61 | return True 62 | else: 63 | print(f"Integrity check failed for {file_path}. Retrying download (Attempt {attempt}/{max_attempts})...") 64 | os.remove(file_path) 65 | attempt += 1 66 | except requests.exceptions.Timeout: 67 | print("Connection timed out! Retrying download...") 68 | attempt += 1 69 | except Exception as e: 70 | print(f"An error occurred during download: {e}") 71 | attempt += 1 72 | 73 | print(f"Failed to download {model_name} after {max_attempts} attempts.") 74 | return False -------------------------------------------------------------------------------- /app/helpers/integrity_checker.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | BUF_SIZE = 131072 # read in 128kb chunks! 
4 | 5 | def get_file_hash(file_path: str) -> str: 6 | hash_sha256 = hashlib.sha256() 7 | 8 | with open(file_path, 'rb') as f: 9 | while True: 10 | data = f.read(BUF_SIZE) 11 | if not data: 12 | break 13 | hash_sha256.update(data) 14 | 15 | # print("SHA256: {0}".format(hash_sha256.hexdigest())) 16 | return hash_sha256.hexdigest() 17 | 18 | def write_hash_to_file(hash: str, hash_file_path: str): 19 | with open(hash_file_path, 'w') as hash_file: 20 | hash_file.write(hash) 21 | 22 | def get_hash_from_hash_file(hash_file_path: str) -> str: 23 | with open(hash_file_path, 'r') as hash_file: 24 | hash_sha256 = hash_file.read().strip() 25 | return hash_sha256 26 | 27 | def check_file_integrity(file_path, correct_hash) -> bool: 28 | actual_hash = get_file_hash(file_path) 29 | return actual_hash==correct_hash -------------------------------------------------------------------------------- /app/helpers/miscellaneous.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import cv2 4 | import time 5 | from collections import UserDict 6 | import hashlib 7 | import numpy as np 8 | from functools import wraps 9 | from datetime import datetime 10 | from pathlib import Path 11 | from torchvision.transforms import v2 12 | import threading 13 | lock = threading.Lock() 14 | 15 | image_extensions = ('.jpg', '.jpeg', '.jpe', '.png', '.webp', '.tif', '.tiff', '.jp2', '.exr', '.hdr', '.ras', '.pnm', '.ppm', '.pgm', '.pbm', '.pfm') 16 | video_extensions = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.m4v', '.3gp', '.gif') 17 | 18 | DFM_MODELS_PATH = './model_assets/dfm_models' 19 | 20 | DFM_MODELS_DATA = {} 21 | 22 | # Datatype used for storing parameter values 23 | # Major use case for subclassing this is to fallback to a default value, when trying to access value from a non-existing key 24 | # Helps when saving/importing workspace or parameters from external file after a future update including new Parameter widgets 25 | class ParametersDict(UserDict): 26 | def __init__(self, parameters, default_parameters: dict): 27 | super().__init__(parameters) 28 | self._default_parameters = default_parameters 29 | 30 | def __getitem__(self, key): 31 | try: 32 | return self.data[key] 33 | except KeyError: 34 | self.__setitem__(key, self._default_parameters[key]) 35 | return self._default_parameters[key] 36 | 37 | def get_scaling_transforms(): 38 | t512 = v2.Resize((512, 512), interpolation=v2.InterpolationMode.BILINEAR, antialias=False) 39 | t384 = v2.Resize((384, 384), interpolation=v2.InterpolationMode.BILINEAR, antialias=False) 40 | t256 = v2.Resize((256, 256), interpolation=v2.InterpolationMode.BILINEAR, antialias=False) 41 | t128 = v2.Resize((128, 128), interpolation=v2.InterpolationMode.BILINEAR, antialias=False) 42 | return t512, t384, t256, t128 43 | 44 | t512, t384, t256, t128 = get_scaling_transforms() 45 | 46 | def absoluteFilePaths(directory: str, include_subfolders=False): 47 | if include_subfolders: 48 | for dirpath,_,filenames in os.walk(directory): 49 | for f in filenames: 50 | yield os.path.abspath(os.path.join(dirpath, f)) 51 | else: 52 | for filename in os.listdir(directory): 53 | file_path = os.path.join(directory, filename) 54 | if os.path.isfile(file_path): 55 | yield file_path 56 | 57 | def truncate_text(text): 58 | if len(text) >= 35: 59 | return f'{text[:32]}...' 
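# Names shorter than 35 characters fall through and are returned unchanged.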
60 | return text 61 | 62 | def get_video_files(folder_name, include_subfolders=False): 63 | return [f for f in absoluteFilePaths(folder_name, include_subfolders) if f.lower().endswith(video_extensions)] 64 | 65 | def get_image_files(folder_name, include_subfolders=False): 66 | return [f for f in absoluteFilePaths(folder_name, include_subfolders) if f.lower().endswith(image_extensions)] 67 | 68 | def is_image_file(file_name: str): 69 | return file_name.lower().endswith(image_extensions) 70 | 71 | def is_video_file(file_name: str): 72 | return file_name.lower().endswith(video_extensions) 73 | 74 | def is_file_exists(file_path: str) -> bool: 75 | if not file_path: 76 | return False 77 | return Path(file_path).is_file() 78 | 79 | def get_file_type(file_name): 80 | if is_image_file(file_name): 81 | return 'image' 82 | if is_video_file(file_name): 83 | return 'video' 84 | return None 85 | 86 | def get_hash_from_filename(filename): 87 | """Generate a hash from just the filename (not the full path).""" 88 | # Use just the filename without path 89 | name = os.path.basename(filename) 90 | # Create hash from filename and size for uniqueness 91 | file_size = os.path.getsize(filename) 92 | hash_input = f"{name}_{file_size}" 93 | return hashlib.md5(hash_input.encode('utf-8')).hexdigest() 94 | 95 | def get_thumbnail_path(file_hash): 96 | """Get the full path to a cached thumbnail.""" 97 | thumbnail_dir = os.path.join(os.getcwd(), '.thumbnails') 98 | # Check if PNG version exists first 99 | png_path = os.path.join(thumbnail_dir, f"{file_hash}.png") 100 | if os.path.exists(png_path): 101 | return png_path 102 | # Otherwise use JPEG path 103 | return os.path.join(thumbnail_dir, f"{file_hash}.jpg") 104 | 105 | def ensure_thumbnail_dir(): 106 | """Create the .thumbnails directory if it doesn't exist.""" 107 | thumbnail_dir = os.path.join(os.getcwd(), '.thumbnails') 108 | os.makedirs(thumbnail_dir, exist_ok=True) 109 | return thumbnail_dir 110 | 111 | def save_thumbnail(frame, thumbnail_path): 112 | """Save a frame as an optimized thumbnail.""" 113 | # Handle different color formats 114 | if len(frame.shape) == 2: # Grayscale 115 | frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) 116 | elif frame.shape[2] == 4: # RGBA 117 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR) 118 | 119 | height,width,_ = frame.shape 120 | width, height = get_scaled_resolution(media_width=width, media_height=height, max_height=140, max_width=140) 121 | # Resize to exactly 70x70 pixels with high quality 122 | frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LANCZOS4) 123 | 124 | # First try PNG for best quality 125 | try: 126 | cv2.imwrite(thumbnail_path[:-4] + '.png', frame) 127 | # If PNG file is too large (>30KB), fall back to high-quality JPEG 128 | if os.path.getsize(thumbnail_path[:-4] + '.png') > 30 * 1024: 129 | os.remove(thumbnail_path[:-4] + '.png') 130 | raise Exception("PNG too large") 131 | else: 132 | return 133 | except: 134 | # Define JPEG parameters for high quality 135 | params = [ 136 | cv2.IMWRITE_JPEG_QUALITY, 98, # Maximum quality for JPEG 137 | cv2.IMWRITE_JPEG_OPTIMIZE, 1, # Enable optimization 138 | cv2.IMWRITE_JPEG_PROGRESSIVE, 1 # Enable progressive mode 139 | ] 140 | # Save as high quality JPEG 141 | cv2.imwrite(thumbnail_path, frame, params) 142 | 143 | def get_dfm_models_data(): 144 | DFM_MODELS_DATA.clear() 145 | for dfm_file in os.listdir(DFM_MODELS_PATH): 146 | if dfm_file.endswith(('.dfm','.onnx')): 147 | DFM_MODELS_DATA[dfm_file] = f'{DFM_MODELS_PATH}/{dfm_file}' 148 | return 
DFM_MODELS_DATA 149 | 150 | def get_dfm_models_selection_values(): 151 | return list(get_dfm_models_data().keys()) 152 | def get_dfm_models_default_value(): 153 | dfm_values = list(DFM_MODELS_DATA.keys()) 154 | if dfm_values: 155 | return dfm_values[0] 156 | return '' 157 | 158 | def get_scaled_resolution(media_width=False, media_height=False, max_width=False, max_height=False, media_capture: cv2.VideoCapture = False,): 159 | if not max_width or not max_height: 160 | max_height = 1080 161 | max_width = 1920 162 | 163 | if (not media_width or not media_height) and media_capture: 164 | media_width = media_capture.get(cv2.CAP_PROP_FRAME_WIDTH) 165 | media_height = media_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) 166 | 167 | if media_width > max_width or media_height > max_height: 168 | width_scale = max_width/media_width 169 | height_scale = max_height/media_height 170 | scale = min(width_scale, height_scale) 171 | media_width,media_height = media_width* scale, media_height*scale 172 | return int(media_width), int(media_height) 173 | 174 | def benchmark(func): 175 | @wraps(func) 176 | def wrapper(*args, **kwargs): 177 | start_time = time.perf_counter() # Record the start time 178 | result = func(*args, **kwargs) # Call the original function 179 | end_time = time.perf_counter() # Record the end time 180 | elapsed_time = end_time - start_time # Calculate elapsed time 181 | print(f"Function '{func.__name__}' executed in {elapsed_time:.6f} seconds.") 182 | return result # Return the result of the original function 183 | return wrapper 184 | 185 | def read_frame(capture_obj: cv2.VideoCapture, preview_mode=False): 186 | with lock: 187 | ret, frame = capture_obj.read() 188 | if ret and preview_mode: 189 | pass 190 | # width, height = get_scaled_resolution(capture_obj) 191 | # frame = cv2.resize(fr2ame, dsize=(width, height), interpolation=cv2.INTER_LANCZOS4) 192 | return ret, frame 193 | 194 | def read_image_file(image_path): 195 | try: 196 | img_array = np.fromfile(image_path, np.uint8) 197 | img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) # Always load as BGR 198 | except Exception as e: 199 | print(f"Failed to load {image_path}: {e}") 200 | return None 201 | 202 | if img is None: 203 | print("Failed to decode:", image_path) 204 | return None 205 | 206 | return img # Return BGR format 207 | 208 | def get_output_file_path(original_media_path, output_folder, media_type='video'): 209 | date_and_time = datetime.now().strftime(r'%Y_%m_%d_%H_%M_%S') 210 | input_filename = os.path.basename(original_media_path) 211 | # Create a temp Path object to split and merge the original filename to get the new output filename 212 | temp_path = Path(input_filename) 213 | # output_filename = "{0}_{2}{1}".format(temp_path.stem, temp_path.suffix, date_and_time) 214 | if media_type=='video': 215 | output_filename = f'{temp_path.stem}_{date_and_time}.mp4' 216 | elif media_type=='image': 217 | output_filename = f'{temp_path.stem}_{date_and_time}.png' 218 | output_file_path = os.path.join(output_folder, output_filename) 219 | return output_file_path 220 | 221 | def is_ffmpeg_in_path(): 222 | if not cmd_exist('ffmpeg'): 223 | print("FFMPEG Not found in your system!") 224 | return False 225 | return True 226 | 227 | def cmd_exist(cmd): 228 | try: 229 | return shutil.which(cmd) is not None 230 | except ImportError: 231 | return any( 232 | os.access(os.path.join(path, cmd), os.X_OK) 233 | for path in os.environ["PATH"].split(os.pathsep) 234 | ) 235 | 236 | def get_dir_of_file(file_path): 237 | if file_path: 238 | return 
os.path.dirname(file_path) 239 | return os.path.curdir 240 | -------------------------------------------------------------------------------- /app/helpers/recording.py: -------------------------------------------------------------------------------- 1 | 2 | def write_frame_to_disk(frame): 3 | pass -------------------------------------------------------------------------------- /app/helpers/typing_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, NewType 2 | from app.helpers.miscellaneous import ParametersDict 3 | 4 | LayoutDictTypes = NewType('LayoutDictTypes', Dict[str, Dict[str, Dict[str, int|str|list|float|bool|Callable]]]) 5 | 6 | ParametersTypes = NewType('ParametersTypes', ParametersDict) 7 | FacesParametersTypes = NewType('FacesParametersTypes', dict[int, ParametersTypes]) 8 | 9 | ControlTypes = NewType('ControlTypes', Dict[str, bool|int|float|str]) 10 | 11 | MarkerTypes = NewType('MarkerTypes', Dict[int, Dict[str, FacesParametersTypes|ControlTypes]]) -------------------------------------------------------------------------------- /app/onnxmodels/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/onnxmodels/.gitkeep -------------------------------------------------------------------------------- /app/onnxmodels/dfm_models/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/onnxmodels/dfm_models/.keep -------------------------------------------------------------------------------- /app/onnxmodels/place_model_files_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/processors/external/cliplib/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /app/processors/external/cliplib/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/processors/external/cliplib/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /app/processors/external/cliplib/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | 
_tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 
| jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def _node_get(node: torch._C.Node, key: str): 149 | """Gets attributes of a node which is polymorphic over return type. 
150 | 151 | From https://github.com/pytorch/pytorch/pull/82628 152 | """ 153 | sel = node.kindOf(key) 154 | return getattr(node, sel)(key) 155 | 156 | def patch_device(module): 157 | try: 158 | graphs = [module.graph] if hasattr(module, "graph") else [] 159 | except RuntimeError: 160 | graphs = [] 161 | 162 | if hasattr(module, "forward1"): 163 | graphs.append(module.forward1.graph) 164 | 165 | for graph in graphs: 166 | for node in graph.findAllNodes("prim::Constant"): 167 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): 168 | node.copyAttributes(device_node) 169 | 170 | model.apply(patch_device) 171 | patch_device(model.encode_image) 172 | patch_device(model.encode_text) 173 | 174 | # patch dtype to float32 on CPU 175 | if str(device) == "cpu": 176 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 177 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 178 | float_node = float_input.node() 179 | 180 | def patch_float(module): 181 | try: 182 | graphs = [module.graph] if hasattr(module, "graph") else [] 183 | except RuntimeError: 184 | graphs = [] 185 | 186 | if hasattr(module, "forward1"): 187 | graphs.append(module.forward1.graph) 188 | 189 | for graph in graphs: 190 | for node in graph.findAllNodes("aten::to"): 191 | inputs = list(node.inputs()) 192 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 193 | if _node_get(inputs[i].node(), "value") == 5: 194 | inputs[i].node().copyAttributes(float_node) 195 | 196 | model.apply(patch_float) 197 | patch_float(model.encode_image) 198 | patch_float(model.encode_text) 199 | 200 | model.float() 201 | 202 | return model, _transform(model.input_resolution.item()) 203 | 204 | 205 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 206 | """ 207 | Returns the tokenized representation of given input string(s) 208 | 209 | Parameters 210 | ---------- 211 | texts : Union[str, List[str]] 212 | An input string or a list of input strings to tokenize 213 | 214 | context_length : int 215 | The context length to use; all CLIP models use 77 as the context length 216 | 217 | truncate: bool 218 | Whether to truncate the text in case its encoding is longer than the context length 219 | 220 | Returns 221 | ------- 222 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 223 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
224 | """ 225 | if isinstance(texts, str): 226 | texts = [texts] 227 | 228 | sot_token = _tokenizer.encoder["<|startoftext|>"] 229 | eot_token = _tokenizer.encoder["<|endoftext|>"] 230 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 231 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 233 | else: 234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 235 | 236 | for i, tokens in enumerate(all_tokens): 237 | if len(tokens) > context_length: 238 | if truncate: 239 | tokens = tokens[:context_length] 240 | tokens[-1] = eot_token 241 | else: 242 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 243 | result[i, :len(tokens)] = torch.tensor(tokens) 244 | 245 | return result 246 | -------------------------------------------------------------------------------- /app/processors/external/cliplib/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /app/processors/external/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def 
conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /app/processors/face_restorers.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import TYPE_CHECKING 3 | 4 | import torch 5 | import numpy as np 6 | from torchvision.transforms import v2 7 | from skimage import transform as trans 8 | 9 | if TYPE_CHECKING: 10 
| from app.processors.models_processor import ModelsProcessor 11 | 12 | class FaceRestorers: 13 | def __init__(self, models_processor: 'ModelsProcessor'): 14 | self.models_processor = models_processor 15 | 16 | def apply_facerestorer(self, swapped_face_upscaled, restorer_det_type, restorer_type, restorer_blend, fidelity_weight, detect_score): 17 | temp = swapped_face_upscaled 18 | t512 = v2.Resize((512, 512), antialias=False) 19 | t256 = v2.Resize((256, 256), antialias=False) 20 | t1024 = v2.Resize((1024, 1024), antialias=False) 21 | t2048 = v2.Resize((2048, 2048), antialias=False) 22 | 23 | # If using a separate detection mode 24 | if restorer_det_type == 'Blend' or restorer_det_type == 'Reference': 25 | if restorer_det_type == 'Blend': 26 | # Set up Transformation 27 | dst = self.models_processor.arcface_dst * 4.0 28 | dst[:,0] += 32.0 29 | 30 | elif restorer_det_type == 'Reference': 31 | try: 32 | dst, _, _ = self.models_processor.run_detect_landmark(swapped_face_upscaled, bbox=np.array([0, 0, 512, 512]), det_kpss=[], detect_mode='5', score=detect_score/100.0, from_points=False) 33 | except Exception as e: # pylint: disable=broad-except 34 | print(f"exception: {e}") 35 | return swapped_face_upscaled 36 | 37 | # Return non-enhanced face if keypoints are empty 38 | if not isinstance(dst, np.ndarray) or len(dst)==0: 39 | return swapped_face_upscaled 40 | 41 | tform = trans.SimilarityTransform() 42 | try: 43 | tform.estimate(dst, self.models_processor.FFHQ_kps) 44 | except: 45 | return swapped_face_upscaled 46 | # Transform, scale, and normalize 47 | temp = v2.functional.affine(swapped_face_upscaled, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]) , tform.scale, 0, center = (0,0) ) 48 | temp = v2.functional.crop(temp, 0,0, 512, 512) 49 | 50 | temp = torch.div(temp, 255) 51 | temp = v2.functional.normalize(temp, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=False) 52 | 53 | if restorer_type == 'GPEN-256': 54 | temp = t256(temp) 55 | 56 | temp = torch.unsqueeze(temp, 0).contiguous() 57 | 58 | # Bindings 59 | outpred = torch.empty((1,3,512,512), dtype=torch.float32, device=self.models_processor.device).contiguous() 60 | 61 | if restorer_type == 'GFPGAN-v1.4': 62 | self.run_GFPGAN(temp, outpred) 63 | 64 | elif restorer_type == 'CodeFormer': 65 | self.run_codeformer(temp, outpred, fidelity_weight) 66 | 67 | elif restorer_type == 'GPEN-256': 68 | outpred = torch.empty((1,3,256,256), dtype=torch.float32, device=self.models_processor.device).contiguous() 69 | self.run_GPEN_256(temp, outpred) 70 | 71 | elif restorer_type == 'GPEN-512': 72 | self.run_GPEN_512(temp, outpred) 73 | 74 | elif restorer_type == 'GPEN-1024': 75 | temp = t1024(temp) 76 | outpred = torch.empty((1, 3, 1024, 1024), dtype=torch.float32, device=self.models_processor.device).contiguous() 77 | self.run_GPEN_1024(temp, outpred) 78 | 79 | elif restorer_type == 'GPEN-2048': 80 | temp = t2048(temp) 81 | outpred = torch.empty((1, 3, 2048, 2048), dtype=torch.float32, device=self.models_processor.device).contiguous() 82 | self.run_GPEN_2048(temp, outpred) 83 | 84 | elif restorer_type == 'RestoreFormer++': 85 | self.run_RestoreFormerPlusPlus(temp, outpred) 86 | 87 | elif restorer_type == 'VQFR-v2': 88 | self.run_VQFR_v2(temp, outpred, fidelity_weight) 89 | 90 | # Format back to cxHxW @ 255 91 | outpred = torch.squeeze(outpred) 92 | outpred = torch.clamp(outpred, -1, 1) 93 | outpred = torch.add(outpred, 1) 94 | outpred = torch.div(outpred, 2) 95 | outpred = torch.mul(outpred, 255) 96 | 97 | if restorer_type == 'GPEN-256' or 
restorer_type == 'GPEN-1024' or restorer_type == 'GPEN-2048': 98 | outpred = t512(outpred) 99 | 100 | # Invert Transform 101 | if restorer_det_type == 'Blend' or restorer_det_type == 'Reference': 102 | outpred = v2.functional.affine(outpred, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center = (0,0) ) 103 | 104 | # Blend 105 | alpha = float(restorer_blend)/100.0 106 | outpred = torch.add(torch.mul(outpred, alpha), torch.mul(swapped_face_upscaled, 1-alpha)) 107 | 108 | return outpred 109 | 110 | def run_GFPGAN(self, image, output): 111 | if not self.models_processor.models['GFPGANv1.4']: 112 | self.models_processor.models['GFPGANv1.4'] = self.models_processor.load_model('GFPGANv1.4') 113 | 114 | io_binding = self.models_processor.models['GFPGANv1.4'].io_binding() 115 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr()) 116 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr()) 117 | 118 | if self.models_processor.device == "cuda": 119 | torch.cuda.synchronize() 120 | elif self.models_processor.device != "cpu": 121 | self.models_processor.syncvec.cpu() 122 | self.models_processor.models['GFPGANv1.4'].run_with_iobinding(io_binding) 123 | 124 | def run_GPEN_256(self, image, output): 125 | if not self.models_processor.models['GPENBFR256']: 126 | self.models_processor.models['GPENBFR256'] = self.models_processor.load_model('GPENBFR256') 127 | 128 | io_binding = self.models_processor.models['GPENBFR256'].io_binding() 129 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,256,256), buffer_ptr=image.data_ptr()) 130 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,256,256), buffer_ptr=output.data_ptr()) 131 | 132 | if self.models_processor.device == "cuda": 133 | torch.cuda.synchronize() 134 | elif self.models_processor.device != "cpu": 135 | self.models_processor.syncvec.cpu() 136 | self.models_processor.models['GPENBFR256'].run_with_iobinding(io_binding) 137 | 138 | def run_GPEN_512(self, image, output): 139 | if not self.models_processor.models['GPENBFR512']: 140 | self.models_processor.models['GPENBFR512'] = self.models_processor.load_model('GPENBFR512') 141 | 142 | io_binding = self.models_processor.models['GPENBFR512'].io_binding() 143 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr()) 144 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr()) 145 | 146 | if self.models_processor.device == "cuda": 147 | torch.cuda.synchronize() 148 | elif self.models_processor.device != "cpu": 149 | self.models_processor.syncvec.cpu() 150 | self.models_processor.models['GPENBFR512'].run_with_iobinding(io_binding) 151 | 152 | def run_GPEN_1024(self, image, output): 153 | if not self.models_processor.models['GPENBFR1024']: 154 | self.models_processor.models['GPENBFR1024'] = self.models_processor.load_model('GPENBFR1024') 155 | 156 | io_binding = 
self.models_processor.models['GPENBFR1024'].io_binding() 157 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,1024,1024), buffer_ptr=image.data_ptr()) 158 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,1024,1024), buffer_ptr=output.data_ptr()) 159 | 160 | if self.models_processor.device == "cuda": 161 | torch.cuda.synchronize() 162 | elif self.models_processor.device != "cpu": 163 | self.models_processor.syncvec.cpu() 164 | self.models_processor.models['GPENBFR1024'].run_with_iobinding(io_binding) 165 | 166 | def run_GPEN_2048(self, image, output): 167 | if not self.models_processor.models['GPENBFR2048']: 168 | self.models_processor.models['GPENBFR2048'] = self.models_processor.load_model('GPENBFR2048') 169 | 170 | io_binding = self.models_processor.models['GPENBFR2048'].io_binding() 171 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,2048,2048), buffer_ptr=image.data_ptr()) 172 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,2048,2048), buffer_ptr=output.data_ptr()) 173 | 174 | if self.models_processor.device == "cuda": 175 | torch.cuda.synchronize() 176 | elif self.models_processor.device != "cpu": 177 | self.models_processor.syncvec.cpu() 178 | self.models_processor.models['GPENBFR2048'].run_with_iobinding(io_binding) 179 | 180 | def run_codeformer(self, image, output, fidelity_weight_value=0.9): 181 | if not self.models_processor.models['CodeFormer']: 182 | self.models_processor.models['CodeFormer'] = self.models_processor.load_model('CodeFormer') 183 | 184 | io_binding = self.models_processor.models['CodeFormer'].io_binding() 185 | io_binding.bind_input(name='x', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=image.data_ptr()) 186 | w = np.array([fidelity_weight_value], dtype=np.double) 187 | io_binding.bind_cpu_input('w', w) 188 | io_binding.bind_output(name='y', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr()) 189 | 190 | if self.models_processor.device == "cuda": 191 | torch.cuda.synchronize() 192 | elif self.models_processor.device != "cpu": 193 | self.models_processor.syncvec.cpu() 194 | self.models_processor.models['CodeFormer'].run_with_iobinding(io_binding) 195 | 196 | def run_VQFR_v2(self, image, output, fidelity_ratio_value): 197 | if not self.models_processor.models['VQFRv2']: 198 | self.models_processor.models['VQFRv2'] = self.models_processor.load_model('VQFRv2') 199 | 200 | assert fidelity_ratio_value >= 0.0 and fidelity_ratio_value <= 1.0, 'fidelity_ratio must in range[0,1]' 201 | fidelity_ratio = torch.tensor(fidelity_ratio_value).to(self.models_processor.device) 202 | 203 | io_binding = self.models_processor.models['VQFRv2'].io_binding() 204 | io_binding.bind_input(name='x_lq', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 205 | io_binding.bind_input(name='fidelity_ratio', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=fidelity_ratio.size(), buffer_ptr=fidelity_ratio.data_ptr()) 206 | io_binding.bind_output('enc_feat', self.models_processor.device) 207 | 
io_binding.bind_output('quant_logit', self.models_processor.device) 208 | io_binding.bind_output('texture_dec', self.models_processor.device) 209 | io_binding.bind_output(name='main_dec', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=(1,3,512,512), buffer_ptr=output.data_ptr()) 210 | 211 | if self.models_processor.device == "cuda": 212 | torch.cuda.synchronize() 213 | elif self.models_processor.device != "cpu": 214 | self.models_processor.syncvec.cpu() 215 | self.models_processor.models['VQFRv2'].run_with_iobinding(io_binding) 216 | 217 | def run_RestoreFormerPlusPlus(self, image, output): 218 | if not self.models_processor.models['RestoreFormerPlusPlus']: 219 | self.models_processor.models['RestoreFormerPlusPlus'] = self.models_processor.load_model('RestoreFormerPlusPlus') 220 | 221 | io_binding = self.models_processor.models['RestoreFormerPlusPlus'].io_binding() 222 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 223 | io_binding.bind_output(name='2359', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 224 | io_binding.bind_output('1228', self.models_processor.device) 225 | io_binding.bind_output('1238', self.models_processor.device) 226 | io_binding.bind_output('onnx::MatMul_1198', self.models_processor.device) 227 | io_binding.bind_output('onnx::Shape_1184', self.models_processor.device) 228 | io_binding.bind_output('onnx::ArgMin_1182', self.models_processor.device) 229 | io_binding.bind_output('input.1', self.models_processor.device) 230 | io_binding.bind_output('x', self.models_processor.device) 231 | io_binding.bind_output('x.3', self.models_processor.device) 232 | io_binding.bind_output('x.7', self.models_processor.device) 233 | io_binding.bind_output('x.11', self.models_processor.device) 234 | io_binding.bind_output('x.15', self.models_processor.device) 235 | io_binding.bind_output('input.252', self.models_processor.device) 236 | io_binding.bind_output('input.280', self.models_processor.device) 237 | io_binding.bind_output('input.288', self.models_processor.device) 238 | 239 | if self.models_processor.device == "cuda": 240 | torch.cuda.synchronize() 241 | elif self.models_processor.device != "cpu": 242 | self.models_processor.syncvec.cpu() 243 | self.models_processor.models['RestoreFormerPlusPlus'].run_with_iobinding(io_binding) -------------------------------------------------------------------------------- /app/processors/frame_enhancers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import TYPE_CHECKING 3 | 4 | import torch 5 | import numpy as np 6 | from torchvision.transforms import v2 7 | 8 | if TYPE_CHECKING: 9 | from app.processors.models_processor import ModelsProcessor 10 | 11 | class FrameEnhancers: 12 | def __init__(self, models_processor: 'ModelsProcessor'): 13 | self.models_processor = models_processor 14 | 15 | def run_enhance_frame_tile_process(self, img, enhancer_type, tile_size=256, scale=1): 16 | _, _, height, width = img.shape 17 | 18 | # Compute the number of tiles needed 19 | tiles_x = math.ceil(width / tile_size) 20 | tiles_y = math.ceil(height / tile_size) 21 | 22 | # Compute the padding needed to fit the image to the tile dimensions 23 | pad_right = (tile_size - (width % tile_size)) % tile_size 24 | pad_bottom = (tile_size - (height % 
tile_size)) % tile_size 25 | 26 | # Pad the image if necessary 27 | if pad_right != 0 or pad_bottom != 0: 28 | img = torch.nn.functional.pad(img, (0, pad_right, 0, pad_bottom), 'constant', 0) 29 | 30 | # Create an empty output tensor 31 | b, c, h, w = img.shape 32 | output = torch.empty((b, c, h * scale, w * scale), dtype=torch.float32, device=self.models_processor.device).contiguous() 33 | 34 | # Select the upscaling function based on the type 35 | upscaler_functions = { 36 | 'RealEsrgan-x2-Plus': self.run_realesrganx2, 37 | 'RealEsrgan-x4-Plus': self.run_realesrganx4, 38 | 'BSRGan-x2': self.run_bsrganx2, 39 | 'BSRGan-x4': self.run_bsrganx4, 40 | 'UltraSharp-x4': self.run_ultrasharpx4, 41 | 'UltraMix-x4': self.run_ultramixx4, 42 | 'RealEsr-General-x4v3': self.run_realesrx4v3 43 | } 44 | 45 | fn_upscaler = upscaler_functions.get(enhancer_type) 46 | 47 | if not fn_upscaler: # If the enhancer type is not valid 48 | if pad_right != 0 or pad_bottom != 0: 49 | img = v2.functional.crop(img, 0, 0, height, width) 50 | return img 51 | 52 | with torch.no_grad(): # Disable gradient computation 53 | # Process the tiles 54 | for j in range(tiles_y): 55 | for i in range(tiles_x): 56 | x_start, y_start = i * tile_size, j * tile_size 57 | x_end, y_end = x_start + tile_size, y_start + tile_size 58 | 59 | # Extract the input tile 60 | input_tile = img[:, :, y_start:y_end, x_start:x_end].contiguous() 61 | output_tile = torch.empty((input_tile.shape[0], input_tile.shape[1], input_tile.shape[2] * scale, input_tile.shape[3] * scale), dtype=torch.float32, device=self.models_processor.device).contiguous() 62 | 63 | # Upscale the tile 64 | fn_upscaler(input_tile, output_tile) 65 | 66 | # Insert the upscaled tile into the output tensor 67 | output_y_start, output_x_start = y_start * scale, x_start * scale 68 | output_y_end, output_x_end = output_y_start + output_tile.shape[2], output_x_start + output_tile.shape[3] 69 | output[:, :, output_y_start:output_y_end, output_x_start:output_x_end] = output_tile 70 | 71 | # Crop the output to remove the added padding 72 | if pad_right != 0 or pad_bottom != 0: 73 | output = v2.functional.crop(output, 0, 0, height * scale, width * scale) 74 | 75 | return output 76 | 77 | def run_realesrganx2(self, image, output): 78 | if not self.models_processor.models['RealEsrganx2Plus']: 79 | self.models_processor.models['RealEsrganx2Plus'] = self.models_processor.load_model('RealEsrganx2Plus') 80 | 81 | io_binding = self.models_processor.models['RealEsrganx2Plus'].io_binding() 82 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 83 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 84 | 85 | if self.models_processor.device == "cuda": 86 | torch.cuda.synchronize() 87 | elif self.models_processor.device != "cpu": 88 | self.models_processor.syncvec.cpu() 89 | self.models_processor.models['RealEsrganx2Plus'].run_with_iobinding(io_binding) 90 | 91 | def run_realesrganx4(self, image, output): 92 | if not self.models_processor.models['RealEsrganx4Plus']: 93 | self.models_processor.models['RealEsrganx4Plus'] = self.models_processor.load_model('RealEsrganx4Plus') 94 | 95 | io_binding = self.models_processor.models['RealEsrganx4Plus'].io_binding() 96 | io_binding.bind_input(name='input', 
device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 97 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 98 | 99 | if self.models_processor.device == "cuda": 100 | torch.cuda.synchronize() 101 | elif self.models_processor.device != "cpu": 102 | self.models_processor.syncvec.cpu() 103 | self.models_processor.models['RealEsrganx4Plus'].run_with_iobinding(io_binding) 104 | 105 | def run_realesrx4v3(self, image, output): 106 | if not self.models_processor.models['RealEsrx4v3']: 107 | self.models_processor.models['RealEsrx4v3'] = self.models_processor.load_model('RealEsrx4v3') 108 | 109 | io_binding = self.models_processor.models['RealEsrx4v3'].io_binding() 110 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 111 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 112 | 113 | if self.models_processor.device == "cuda": 114 | torch.cuda.synchronize() 115 | elif self.models_processor.device != "cpu": 116 | self.models_processor.syncvec.cpu() 117 | self.models_processor.models['RealEsrx4v3'].run_with_iobinding(io_binding) 118 | 119 | def run_bsrganx2(self, image, output): 120 | if not self.models_processor.models['BSRGANx2']: 121 | self.models_processor.models['BSRGANx2'] = self.models_processor.load_model('BSRGANx2') 122 | 123 | io_binding = self.models_processor.models['BSRGANx2'].io_binding() 124 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 125 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 126 | 127 | if self.models_processor.device == "cuda": 128 | torch.cuda.synchronize() 129 | elif self.models_processor.device != "cpu": 130 | self.models_processor.syncvec.cpu() 131 | self.models_processor.models['BSRGANx2'].run_with_iobinding(io_binding) 132 | 133 | def run_bsrganx4(self, image, output): 134 | if not self.models_processor.models['BSRGANx4']: 135 | self.models_processor.models['BSRGANx4'] = self.models_processor.load_model('BSRGANx4') 136 | 137 | io_binding = self.models_processor.models['BSRGANx4'].io_binding() 138 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 139 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 140 | 141 | if self.models_processor.device == "cuda": 142 | torch.cuda.synchronize() 143 | elif self.models_processor.device != "cpu": 144 | self.models_processor.syncvec.cpu() 145 | self.models_processor.models['BSRGANx4'].run_with_iobinding(io_binding) 146 | 147 | def run_ultrasharpx4(self, image, output): 148 | if not self.models_processor.models['UltraSharpx4']: 149 | self.models_processor.models['UltraSharpx4'] = self.models_processor.load_model('UltraSharpx4') 150 | 151 | io_binding = self.models_processor.models['UltraSharpx4'].io_binding() 152 | 
io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 153 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 154 | 155 | if self.models_processor.device == "cuda": 156 | torch.cuda.synchronize() 157 | elif self.models_processor.device != "cpu": 158 | self.models_processor.syncvec.cpu() 159 | self.models_processor.models['UltraSharpx4'].run_with_iobinding(io_binding) 160 | 161 | def run_ultramixx4(self, image, output): 162 | if not self.models_processor.models['UltraMixx4']: 163 | self.models_processor.models['UltraMixx4'] = self.models_processor.load_model('UltraMixx4') 164 | 165 | io_binding = self.models_processor.models['UltraMixx4'].io_binding() 166 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 167 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 168 | 169 | if self.models_processor.device == "cuda": 170 | torch.cuda.synchronize() 171 | elif self.models_processor.device != "cpu": 172 | self.models_processor.syncvec.cpu() 173 | self.models_processor.models['UltraMixx4'].run_with_iobinding(io_binding) 174 | 175 | def run_deoldify_artistic(self, image, output): 176 | if not self.models_processor.models['DeoldifyArt']: 177 | self.models_processor.models['DeoldifyArt'] = self.models_processor.load_model('DeoldifyArt') 178 | 179 | io_binding = self.models_processor.models['DeoldifyArt'].io_binding() 180 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 181 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 182 | 183 | if self.models_processor.device == "cuda": 184 | torch.cuda.synchronize() 185 | elif self.models_processor.device != "cpu": 186 | self.models_processor.syncvec.cpu() 187 | self.models_processor.models['DeoldifyArt'].run_with_iobinding(io_binding) 188 | 189 | def run_deoldify_stable(self, image, output): 190 | if not self.models_processor.models['DeoldifyStable']: 191 | self.models_processor.models['DeoldifyStable'] = self.models_processor.load_model('DeoldifyStable') 192 | 193 | io_binding = self.models_processor.models['DeoldifyStable'].io_binding() 194 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 195 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 196 | 197 | if self.models_processor.device == "cuda": 198 | torch.cuda.synchronize() 199 | elif self.models_processor.device != "cpu": 200 | self.models_processor.syncvec.cpu() 201 | self.models_processor.models['DeoldifyStable'].run_with_iobinding(io_binding) 202 | 203 | def run_deoldify_video(self, image, output): 204 | if not self.models_processor.models['DeoldifyVideo']: 205 | self.models_processor.models['DeoldifyVideo'] = self.models_processor.load_model('DeoldifyVideo') 206 | 207 | 
io_binding = self.models_processor.models['DeoldifyVideo'].io_binding() 208 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 209 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 210 | 211 | if self.models_processor.device == "cuda": 212 | torch.cuda.synchronize() 213 | elif self.models_processor.device != "cpu": 214 | self.models_processor.syncvec.cpu() 215 | self.models_processor.models['DeoldifyVideo'].run_with_iobinding(io_binding) 216 | 217 | def run_ddcolor_artistic(self, image, output): 218 | if not self.models_processor.models['DDColorArt']: 219 | self.models_processor.models['DDColorArt'] = self.models_processor.load_model('DDColorArt') 220 | 221 | io_binding = self.models_processor.models['DDColorArt'].io_binding() 222 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 223 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 224 | 225 | if self.models_processor.device == "cuda": 226 | torch.cuda.synchronize() 227 | elif self.models_processor.device != "cpu": 228 | self.models_processor.syncvec.cpu() 229 | self.models_processor.models['DDColorArt'].run_with_iobinding(io_binding) 230 | 231 | def run_ddcolor(self, image, output): 232 | if not self.models_processor.models['DDcolor']: 233 | self.models_processor.models['DDcolor'] = self.models_processor.load_model('DDcolor') 234 | 235 | io_binding = self.models_processor.models['DDcolor'].io_binding() 236 | io_binding.bind_input(name='input', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=image.size(), buffer_ptr=image.data_ptr()) 237 | io_binding.bind_output(name='output', device_type=self.models_processor.device, device_id=0, element_type=np.float32, shape=output.size(), buffer_ptr=output.data_ptr()) 238 | 239 | if self.models_processor.device == "cuda": 240 | torch.cuda.synchronize() 241 | elif self.models_processor.device != "cpu": 242 | self.models_processor.syncvec.cpu() 243 | self.models_processor.models['DDcolor'].run_with_iobinding(io_binding) -------------------------------------------------------------------------------- /app/processors/utils/engine_builder.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-member 2 | 3 | import os 4 | import sys 5 | import logging 6 | import platform 7 | import ctypes 8 | from pathlib import Path 9 | 10 | try: 11 | import tensorrt as trt 12 | except ModuleNotFoundError: 13 | pass 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logging.getLogger("EngineBuilder").setLevel(logging.INFO) 17 | log = logging.getLogger("EngineBuilder") 18 | 19 | if 'trt' in globals(): 20 | # Create a global TensorRT logger instance 21 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) # pylint: disable=no-member 22 | else: 23 | TRT_LOGGER = {} 24 | 25 | # imported from https://github.com/warmshao/FasterLivePortrait/blob/master/scripts/onnx2trt.py 26 | # adjusted to work with TensorRT 10.3.0 27 | class EngineBuilder: 28 | """ 29 | Parses an ONNX graph and builds a TensorRT engine from it. 
30 | """ 31 | 32 | def __init__(self, verbose=False, custom_plugin_path=None, builder_optimization_level=3): 33 | """ 34 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger. 35 | :param custom_plugin_path: Path to the custom plugin library (DLL or SO). 36 | """ 37 | if verbose: 38 | TRT_LOGGER.min_severity = trt.Logger.Severity.VERBOSE 39 | 40 | # Initialize the TensorRT plugins 41 | trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="") 42 | 43 | # Build the TensorRT builder and configuration using the same logger 44 | self.builder = trt.Builder(TRT_LOGGER) 45 | self.config = self.builder.create_builder_config() 46 | # Set the workspace memory pool limit to 3 GB 47 | self.config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 3 * (2 ** 30)) # 3 GB 48 | 49 | # Set the builder optimization level (if provided) 50 | self.config.builder_optimization_level = builder_optimization_level 51 | 52 | # Create an optimization profile, if needed 53 | profile = self.builder.create_optimization_profile() 54 | self.config.add_optimization_profile(profile) 55 | 56 | self.batch_size = None 57 | self.network = None 58 | self.parser = None 59 | 60 | # Load custom plugins if specified 61 | if custom_plugin_path is not None: 62 | if platform.system().lower() == 'linux': 63 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL) 64 | else: 65 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL, winmode=0) 66 | 67 | def create_network(self, onnx_path): 68 | """ 69 | Parse the ONNX graph and create the corresponding TensorRT network definition. 70 | :param onnx_path: The path to the ONNX graph to load. 71 | """ 72 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 73 | 74 | self.network = self.builder.create_network(network_flags) 75 | self.parser = trt.OnnxParser(self.network, TRT_LOGGER) 76 | 77 | onnx_path = os.path.realpath(onnx_path) 78 | with open(onnx_path, "rb") as f: 79 | if not self.parser.parse(f.read()): 80 | log.error("Failed to load ONNX file: %s", onnx_path) 81 | for error in range(self.parser.num_errors): 82 | log.error(self.parser.get_error(error)) 83 | sys.exit(1) 84 | 85 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] 86 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] 87 | 88 | log.info("Network Description") 89 | for net_input in inputs: 90 | self.batch_size = net_input.shape[0] 91 | log.info("Input '%s' with shape %s and dtype %s", net_input.name, net_input.shape, net_input.dtype) 92 | for net_output in outputs: 93 | log.info("Output '%s' with shape %s and dtype %s", net_output.name, net_output.shape, net_output.dtype) 94 | 95 | def create_engine(self, engine_path, precision): 96 | """ 97 | Build the TensorRT engine and serialize it to disk. 98 | :param engine_path: The path where to serialize the engine to. 99 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. 
100 | """ 101 | engine_path = os.path.realpath(engine_path) 102 | engine_dir = os.path.dirname(engine_path) 103 | os.makedirs(engine_dir, exist_ok=True) 104 | log.info("Building %s Engine in %s", precision, engine_path) 105 | 106 | # Force TensorRT to respect the precision constraints 107 | self.config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS) 108 | 109 | if precision == "fp16": 110 | if not self.builder.platform_has_fast_fp16: 111 | log.warning("FP16 is not supported natively on this platform/device") 112 | else: 113 | self.config.set_flag(trt.BuilderFlag.FP16) 114 | 115 | # Build the serialized engine 116 | serialized_engine = self.builder.build_serialized_network(self.network, self.config) 117 | 118 | # Verify that the engine was serialized correctly 119 | if serialized_engine is None: 120 | raise RuntimeError("Error building the TensorRT engine!") 121 | 122 | # Write the serialized engine to disk 123 | with open(engine_path, "wb") as f: 124 | log.info("Serializing engine to file: %s", engine_path) 125 | f.write(serialized_engine) 126 | 127 | def change_extension(file_path, new_extension, version=None): 128 | """ 129 | Change the extension of the file path and optionally prepend a version. 130 | """ 131 | # Remove leading '.' from the new extension if present 132 | new_extension = new_extension.lstrip('.') 133 | 134 | # Create the new file path with the version before the extension, if provided 135 | if version: 136 | new_file_path = Path(file_path).with_suffix(f'.{version}.{new_extension}') 137 | else: 138 | new_file_path = Path(file_path).with_suffix(f'.{new_extension}') 139 | 140 | return str(new_file_path) 141 | 142 | def onnx_to_trt(onnx_model_path, trt_model_path=None, precision="fp16", custom_plugin_path=None, verbose=False): 143 | # The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16' 144 | 145 | if trt_model_path is None: 146 | trt_version = trt.__version__ 147 | trt_model_path = change_extension(onnx_model_path, "trt", version=trt_version) 148 | builder = EngineBuilder(verbose=verbose, custom_plugin_path=custom_plugin_path) 149 | 150 | builder.create_network(onnx_model_path) 151 | builder.create_engine(trt_model_path, precision) 152 | -------------------------------------------------------------------------------- /app/processors/utils/tensorrt_predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from collections import OrderedDict 4 | import platform 5 | from queue import Queue 6 | from threading import Lock 7 | from typing import Dict, Any, OrderedDict as OrderedDictType 8 | 9 | try: 10 | from torch.cuda import nvtx 11 | import tensorrt as trt 12 | import ctypes 13 | except ModuleNotFoundError: 14 | pass 15 | 16 | # Dictionary for converting numpy data types to torch 17 | numpy_to_torch_dtype_dict = { 18 | np.uint8: torch.uint8, 19 | np.int8: torch.int8, 20 | np.int16: torch.int16, 21 | np.int32: torch.int32, 22 | np.int64: torch.int64, 23 | np.float16: torch.float16, 24 | np.float32: torch.float32, 25 | np.float64: torch.float64, 26 | np.complex64: torch.complex64, 27 | np.complex128: torch.complex128, 28 | } 29 | if np.version.full_version >= "1.24.0": 30 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 31 | else: 32 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 33 | 34 | if 'trt' in globals(): 35 | # Create a global TensorRT logger instance 36 | TRT_LOGGER = 
trt.Logger(trt.Logger.ERROR) 37 | else: 38 | TRT_LOGGER = None 39 | 40 | 41 | class TensorRTPredictor: 42 | """ 43 | Implements inference on a TensorRT engine, using a pool of execution contexts, 44 | each of which owns its own buffers to guarantee safety in a multithreaded environment. 45 | """ 46 | 47 | def __init__(self, **kwargs) -> None: 48 | """ 49 | :param model_path: Path to the serialized engine file. 50 | :param pool_size: Number of execution contexts to keep in the pool. 51 | :param custom_plugin_path: (Optional) path to any custom plugins. 52 | :param device: Device on which to allocate the tensors (default 'cuda'). 53 | :param debug: If True, print debug information. 54 | """ 55 | self.device = kwargs.get("device", 'cuda') 56 | self.debug = kwargs.get("debug", False) 57 | self.pool_size = kwargs.get("pool_size", 10) 58 | 59 | # Load the custom plugin (if provided) 60 | custom_plugin_path = kwargs.get("custom_plugin_path", None) 61 | if custom_plugin_path is not None: 62 | try: 63 | if platform.system().lower() == 'linux': 64 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL) 65 | else: 66 | # On Windows, WinDLL or specific parameters may be needed 67 | ctypes.CDLL(custom_plugin_path, mode=ctypes.RTLD_GLOBAL, winmode=0) 68 | except Exception as e: 69 | raise RuntimeError(f"Error loading the custom plugin: {e}") 70 | 71 | # Verify that the model path is provided 72 | engine_path = kwargs.get("model_path", None) 73 | if not engine_path: 74 | raise ValueError("The 'model_path' parameter is required.") 75 | 76 | # Load the TensorRT engine 77 | try: 78 | with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: 79 | engine_data = f.read() 80 | self.engine = runtime.deserialize_cuda_engine(engine_data) 81 | except Exception as e: 82 | raise RuntimeError(f"Error deserializing the engine: {e}") 83 | 84 | if self.engine is None: 85 | raise RuntimeError("Engine deserialization failed.") 86 | 87 | # Set up the I/O specifications (inputs and outputs) 88 | self.inputs = [] 89 | self.outputs = [] 90 | for idx in range(self.engine.num_io_tensors): 91 | name = self.engine.get_tensor_name(idx) 92 | mode = self.engine.get_tensor_mode(name) 93 | shape = list(self.engine.get_tensor_shape(name)) 94 | dtype = trt.nptype(self.engine.get_tensor_dtype(name)) 95 | binding = { 96 | "index": idx, 97 | "name": name, 98 | "dtype": dtype, 99 | "shape": shape, 100 | } 101 | if mode == trt.TensorIOMode.INPUT: 102 | self.inputs.append(binding) 103 | else: 104 | self.outputs.append(binding) 105 | 106 | if len(self.inputs) == 0 or len(self.outputs) == 0: 107 | raise RuntimeError("The engine must have at least one input and one output.") 108 | 109 | # Create the pool of execution contexts 110 | self.context_pool = Queue(maxsize=self.pool_size) 111 | # (Optional) Lock for any critical operations 112 | self.lock = Lock() 113 | for _ in range(self.pool_size): 114 | context = self.engine.create_execution_context() 115 | buffers = self._allocate_buffers() 116 | self.context_pool.put({"context": context, "buffers": buffers}) 117 | 118 | def _allocate_buffers(self) -> OrderedDictType[str, torch.Tensor]: 119 | """ 120 | Allocates a dictionary of tensors for all of the model's I/O, taking any 121 | dynamic dimensions into account. Returns an OrderedDict keyed by tensor name. 
122 | """ 123 | nvtx.range_push("allocate_max_buffers") 124 | buffers = OrderedDict() 125 | # Default batch size 126 | batch_size = 1 127 | for idx in range(self.engine.num_io_tensors): 128 | name = self.engine.get_tensor_name(idx) 129 | shape = list(self.engine.get_tensor_shape(name)) # make sure we have a list 130 | is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT 131 | if -1 in shape: 132 | if is_input: 133 | # Get the maximum shape for profile 0 134 | profile_shape = self.engine.get_tensor_profile_shape(name, 0)[-1] 135 | shape = list(profile_shape) 136 | batch_size = shape[0] 137 | else: 138 | shape[0] = batch_size 139 | dtype = trt.nptype(self.engine.get_tensor_dtype(name)) 140 | if dtype not in numpy_to_torch_dtype_dict: 141 | raise TypeError(f"Unsupported numpy type: {dtype}") 142 | tensor = torch.empty(tuple(shape), 143 | dtype=numpy_to_torch_dtype_dict[dtype], 144 | device=self.device) 145 | buffers[name] = tensor 146 | nvtx.range_pop() 147 | return buffers 148 | 149 | def input_spec(self) -> list: 150 | """ 151 | Returns the input specifications (name, shape, dtype), useful for preparing the arrays. 152 | """ 153 | specs = [] 154 | for i, inp in enumerate(self.inputs): 155 | specs.append((inp["name"], inp["shape"], inp["dtype"])) 156 | if self.debug: 157 | print(f"trt input {i} -> {inp['name']} -> {inp['shape']} -> {inp['dtype']}") 158 | return specs 159 | 160 | def output_spec(self) -> list: 161 | """ 162 | Returns the output specifications (name, shape, dtype), useful for preparing the arrays. 163 | """ 164 | specs = [] 165 | for i, out in enumerate(self.outputs): 166 | specs.append((out["name"], out["shape"], out["dtype"])) 167 | if self.debug: 168 | print(f"trt output {i} -> {out['name']} -> {out['shape']} -> {out['dtype']}") 169 | return specs 170 | 171 | def adjust_buffer(self, feed_dict: Dict[str, Any], context: Any, buffers: OrderedDictType[str, torch.Tensor]) -> None: 172 | """ 173 | Adjusts the input buffer sizes and copies the data from feed_dict into the allocated tensors. 174 | If the input is a NumPy array, it is converted to a Torch tensor (on the correct device). 175 | Also sets the input shape in the execution context. 176 | """ 177 | nvtx.range_push("adjust_buffer") 178 | for name, buf in feed_dict.items(): 179 | if name not in buffers: 180 | raise KeyError(f"Input '{name}' not found in the allocated buffers.") 181 | input_tensor = buffers[name] 182 | # Convert to a tensor if necessary 183 | if isinstance(buf, np.ndarray): 184 | buf_tensor = torch.from_numpy(buf).to(input_tensor.device) 185 | elif isinstance(buf, torch.Tensor): 186 | buf_tensor = buf.to(input_tensor.device) 187 | else: 188 | raise TypeError(f"Unsupported data type for '{name}': {type(buf)}") 189 | current_shape = list(buf_tensor.shape) 190 | # Copy only the portion actually used into the preallocated buffer 191 | slices = tuple(slice(0, dim) for dim in current_shape) 192 | input_tensor[slices].copy_(buf_tensor) 193 | # Set the input shape in the context 194 | context.set_input_shape(name, current_shape) 195 | nvtx.range_pop() 196 | 197 | def predict(self, feed_dict: Dict[str, Any]) -> OrderedDictType[str, torch.Tensor]: 198 | """ 199 | Runs inference synchronously using execute_v2(). 200 | 201 | :param feed_dict: Input dictionary (numpy arrays or Torch tensors). 202 | :return: Dictionary of the updated tensors (inputs and outputs). 
203 | """ 204 | pool_entry = self.context_pool.get() # The Queue is thread-safe 205 | context = pool_entry["context"] 206 | buffers = pool_entry["buffers"] 207 | 208 | try: 209 | nvtx.range_push("set_tensors") 210 | self.adjust_buffer(feed_dict, context, buffers) 211 | # Set the buffer addresses 212 | for name, tensor in buffers.items(): 213 | # If needed, one could check that the tensor type is the expected one 214 | context.set_tensor_address(name, tensor.data_ptr()) 215 | nvtx.range_pop() 216 | 217 | # Prepare the bindings (list of buffer addresses) 218 | bindings = [tensor.data_ptr() for tensor in buffers.values()] 219 | 220 | nvtx.range_push("execute") 221 | noerror = context.execute_v2(bindings) 222 | nvtx.range_pop() 223 | if not noerror: 224 | raise RuntimeError("ERROR: inference failed.") 225 | 226 | # (Optionally, only the outputs could be returned) 227 | return buffers 228 | 229 | finally: 230 | # Synchronize the CUDA stream before returning the context 231 | torch.cuda.synchronize() 232 | self.context_pool.put(pool_entry) 233 | 234 | def predict_async(self, feed_dict: Dict[str, Any], stream: torch.cuda.Stream) -> OrderedDictType[str, torch.Tensor]: 235 | """ 236 | Runs inference asynchronously using execute_async_v3(). 237 | 238 | :param feed_dict: Input dictionary (numpy arrays or Torch tensors). 239 | :param stream: A CUDA stream for asynchronous execution. 240 | :return: Dictionary of the updated tensors (inputs and outputs). 241 | """ 242 | pool_entry = self.context_pool.get() 243 | context = pool_entry["context"] 244 | buffers = pool_entry["buffers"] 245 | 246 | try: 247 | nvtx.range_push("set_tensors") 248 | self.adjust_buffer(feed_dict, context, buffers) 249 | for name, tensor in buffers.items(): 250 | context.set_tensor_address(name, tensor.data_ptr()) 251 | nvtx.range_pop() 252 | 253 | # Create a CUDA event to track input consumption 254 | input_consumed_event = torch.cuda.Event() 255 | context.set_input_consumed_event(input_consumed_event.cuda_event) 256 | 257 | nvtx.range_push("execute_async") 258 | noerror = context.execute_async_v3(stream.cuda_stream) 259 | nvtx.range_pop() 260 | if not noerror: 261 | raise RuntimeError("ERROR: inference failed.") 262 | 263 | input_consumed_event.synchronize() 264 | 265 | return buffers 266 | 267 | finally: 268 | # Synchronize the stream used if it differs from the current one 269 | if stream != torch.cuda.current_stream(): 270 | stream.synchronize() 271 | else: 272 | torch.cuda.synchronize() 273 | self.context_pool.put(pool_entry) 274 | 275 | def cleanup(self) -> None: 276 | """ 277 | Frees all resources associated with the TensorRTPredictor. 278 | This method must be called explicitly before deleting the object. 
279 | """ 280 | # Free the TensorRT engine 281 | if hasattr(self, 'engine') and self.engine is not None: 282 | del self.engine 283 | self.engine = None 284 | 285 | # Free the pool of execution contexts and their buffers 286 | if hasattr(self, 'context_pool') and self.context_pool is not None: 287 | while not self.context_pool.empty(): 288 | pool_entry = self.context_pool.get() 289 | context = pool_entry.get("context", None) 290 | buffers = pool_entry.get("buffers", None) 291 | if context is not None: 292 | del context 293 | if buffers is not None: 294 | for t in buffers.values(): 295 | del t 296 | self.context_pool = None 297 | 298 | self.inputs = None 299 | self.outputs = None 300 | self.pool_size = None 301 | 302 | def __del__(self) -> None: 303 | # For extra safety, call cleanup in the destructor 304 | self.cleanup() 305 | -------------------------------------------------------------------------------- /app/ui/core/convert_ui_to_py.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | 4 | 5 | :: Define relative paths 6 | set "UI_FILE=app\ui\core\MainWindow.ui" 7 | set "PY_FILE=app\ui\core\main_window.py" 8 | set "QRC_FILE=app\ui\core\media.qrc" 9 | set "RCC_PY_FILE=app\ui\core\media_rc.py" 10 | 11 | :: Run PySide6 commands 12 | pyside6-uic "%UI_FILE%" -o "%PY_FILE%" 13 | pyside6-rcc "%QRC_FILE%" -o "%RCC_PY_FILE%" 14 | 15 | :: Define search and replace strings 16 | set "searchString=import media_rc" 17 | set "replaceString=from app.ui.core import media_rc" 18 | 19 | :: Create a temporary file 20 | set "tempFile=%PY_FILE%.tmp" 21 | 22 | :: Process the file 23 | (for /f "usebackq delims=" %%A in ("%PY_FILE%") do ( 24 | set "line=%%A" 25 | if "!line!"=="%searchString%" ( 26 | echo %replaceString% 27 | ) else ( 28 | echo !line! 29 | ) 30 | )) > "%tempFile%" 31 | 32 | :: Replace the original file with the temporary file 33 | move /y "%tempFile%" "%PY_FILE%" 34 | 35 | echo Replacement complete. 
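
Editor's note: the batch file above regenerates main_window.py and media_rc.py with pyside6-uic/pyside6-rcc and then rewrites the generated "import media_rc" line to the package import used by the application. For readers on Linux or macOS, the following is a rough, untested cross-platform sketch of the same steps (it is not part of the repository); it assumes pyside6-uic and pyside6-rcc are on PATH and that it is run from the repository root.

import subprocess
from pathlib import Path

UI_FILE = Path("app/ui/core/MainWindow.ui")
PY_FILE = Path("app/ui/core/main_window.py")
QRC_FILE = Path("app/ui/core/media.qrc")
RCC_PY_FILE = Path("app/ui/core/media_rc.py")

# Regenerate the Python UI module and the compiled Qt resource module.
subprocess.run(["pyside6-uic", str(UI_FILE), "-o", str(PY_FILE)], check=True)
subprocess.run(["pyside6-rcc", str(QRC_FILE), "-o", str(RCC_PY_FILE)], check=True)

# Mirror the batch file's exact-line replacement of the generated import.
lines = PY_FILE.read_text(encoding="utf-8").splitlines()
lines = ["from app.ui.core import media_rc" if line == "import media_rc" else line for line in lines]
PY_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
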
-------------------------------------------------------------------------------- /app/ui/core/media.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | media/image.png 4 | media/webcam.png 5 | media/video.png 6 | media/fullscreen.png 7 | media/open_file.png 8 | media/save_file_as.png 9 | media/save_file.png 10 | media/add_marker_hover.png 11 | media/add_marker_off.png 12 | media/audio_off.png 13 | media/audio_on.png 14 | media/marker.png 15 | media/marker_save.png 16 | media/next_marker_hover.png 17 | media/next_marker_off.png 18 | media/OffState.png 19 | media/OnState.png 20 | media/play_hover.png 21 | media/repeat.png 22 | media/play_off.png 23 | media/play_on.png 24 | media/previous_marker_hover.png 25 | media/previous_marker_off.png 26 | media/rec_hover.png 27 | media/rec_off.png 28 | media/rec_on.png 29 | media/remove_marker_hover.png 30 | media/remove_marker_off.png 31 | media/visomaster_small.png 32 | media/save.png 33 | media/splash.png 34 | media/splash_next.png 35 | media/stop_hover.png 36 | media/stop_off.png 37 | media/stop_on.png 38 | media/tl_beg_hover.png 39 | media/tl_beg_off.png 40 | media/tl_beg_on.png 41 | media/tl_left_hover.png 42 | media/tl_left_off.png 43 | media/tl_left_on.png 44 | media/tl_right_hover.png 45 | media/tl_right_off.png 46 | media/tl_right_on.png 47 | media/reset_default.png 48 | 49 | 50 | -------------------------------------------------------------------------------- /app/ui/core/media/OffState.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/OffState.png -------------------------------------------------------------------------------- /app/ui/core/media/OnState.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/OnState.png -------------------------------------------------------------------------------- /app/ui/core/media/add_marker_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/add_marker_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/add_marker_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/add_marker_off.png -------------------------------------------------------------------------------- /app/ui/core/media/audio_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/audio_off.png -------------------------------------------------------------------------------- /app/ui/core/media/audio_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/audio_on.png -------------------------------------------------------------------------------- /app/ui/core/media/fullscreen.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/fullscreen.png -------------------------------------------------------------------------------- /app/ui/core/media/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/image.png -------------------------------------------------------------------------------- /app/ui/core/media/marker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/marker.png -------------------------------------------------------------------------------- /app/ui/core/media/marker_save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/marker_save.png -------------------------------------------------------------------------------- /app/ui/core/media/next_marker_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/next_marker_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/next_marker_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/next_marker_off.png -------------------------------------------------------------------------------- /app/ui/core/media/open_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/open_file.png -------------------------------------------------------------------------------- /app/ui/core/media/play_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/play_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/play_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/play_off.png -------------------------------------------------------------------------------- /app/ui/core/media/play_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/play_on.png -------------------------------------------------------------------------------- /app/ui/core/media/previous_marker_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/previous_marker_hover.png 
-------------------------------------------------------------------------------- /app/ui/core/media/previous_marker_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/previous_marker_off.png -------------------------------------------------------------------------------- /app/ui/core/media/rec_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/rec_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/rec_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/rec_off.png -------------------------------------------------------------------------------- /app/ui/core/media/rec_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/rec_on.png -------------------------------------------------------------------------------- /app/ui/core/media/remove_marker_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/remove_marker_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/remove_marker_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/remove_marker_off.png -------------------------------------------------------------------------------- /app/ui/core/media/repeat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/repeat.png -------------------------------------------------------------------------------- /app/ui/core/media/reset_default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/reset_default.png -------------------------------------------------------------------------------- /app/ui/core/media/save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/save.png -------------------------------------------------------------------------------- /app/ui/core/media/save_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/save_file.png -------------------------------------------------------------------------------- /app/ui/core/media/save_file_as.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/save_file_as.png -------------------------------------------------------------------------------- /app/ui/core/media/splash.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/splash.png -------------------------------------------------------------------------------- /app/ui/core/media/splash_next.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/splash_next.png -------------------------------------------------------------------------------- /app/ui/core/media/stop_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/stop_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/stop_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/stop_off.png -------------------------------------------------------------------------------- /app/ui/core/media/stop_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/stop_on.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_beg_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_beg_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_beg_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_beg_off.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_beg_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_beg_on.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_left_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_left_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_left_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_left_off.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_left_on.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_left_on.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_right_hover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_right_hover.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_right_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_right_off.png -------------------------------------------------------------------------------- /app/ui/core/media/tl_right_on.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/tl_right_on.png -------------------------------------------------------------------------------- /app/ui/core/media/video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/video.png -------------------------------------------------------------------------------- /app/ui/core/media/visomaster_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/visomaster_full.png -------------------------------------------------------------------------------- /app/ui/core/media/visomaster_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/visomaster_small.png -------------------------------------------------------------------------------- /app/ui/core/media/webcam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/app/ui/core/media/webcam.png -------------------------------------------------------------------------------- /app/ui/core/proxy_style.py: -------------------------------------------------------------------------------- 1 | from PySide6 import QtWidgets 2 | from PySide6.QtCore import Qt 3 | 4 | 5 | class ProxyStyle(QtWidgets.QProxyStyle): 6 | def styleHint(self, hint, opt=None, widget=None, returnData=None) -> int: 7 | res = super().styleHint(hint, opt, widget, returnData) 8 | if hint == self.StyleHint.SH_Slider_AbsoluteSetButtons: 9 | res = Qt.LeftButton.value 10 | return res 11 | -------------------------------------------------------------------------------- /app/ui/styles/dark_styles.qss: -------------------------------------------------------------------------------- 1 | QToolTip { 2 | color: #ffffff; /* White text for clarity */ 3 | background-color: #2c3e50; /* Dark blue/gray background for contrast */ 4 | border: 1px solid #4facc9; /* A turquoise border to match your theme */ 5 | padding: 5px; /* Some padding for readability */ 6 | } 7 | 8 | 
/* QPushButton styling */ 9 | QPushButton { 10 | color: #9dcbd9; 11 | font-weight: bold; 12 | } 13 | 14 | QToolTip { 15 | color: #ffffff; 16 | background-color: #2f3d4b; 17 | padding: 2px; 18 | } 19 | 20 | /* Base button style */ 21 | TargetMediaCardButton, TargetFaceCardButton, InputFaceCardButton { 22 | background-color: #141516; /* Dark blue background */ 23 | border-radius: 5px; /* Rounded corners */ 24 | border: 1px solid #34495e; /* Slightly darker border */ 25 | color: white; /* White text color */ 26 | padding: 10px; /* Padding around the icon */ 27 | font-size: 14px; /* Default font size */ 28 | } 29 | 30 | /* Icon inside the button */ 31 | TargetMediaCardButton::icon, TargetFaceCardButton::icon, InputFaceCardButton:icon { 32 | margin: 5px; /* Spacing around the icon */ 33 | } 34 | 35 | /* Hover state */ 36 | TargetMediaCardButton:hover, TargetFaceCardButton:hover, InputFaceCardButton:hover { 37 | background-color: #34495e; /* Slightly lighter on hover */ 38 | border: 1px solid #4facc9; /* Turquoise border on hover */ 39 | } 40 | 41 | /* Pressed state */ 42 | TargetMediaCardButton:pressed, TargetFaceCardButton:pressed, InputFaceCardButton:pressed, EmbeddingCardButton:pressed { 43 | background-color: #4facc9; /* Turquoise background when pressed */ 44 | border: 1px solid #4facc9; /* Darker turquoise border */ 45 | } 46 | 47 | /* Checked state */ 48 | TargetMediaCardButton:checked, TargetFaceCardButton:checked, InputFaceCardButton:checked, EmbeddingCardButton:checked { 49 | border: 1px solid #4facc9; /* Darker turquoise border */ 50 | } 51 | 52 | /* Disabled state */ 53 | TargetMediaCardButton:disabled, TargetFaceCardButton:disabled, InputFaceCardButton:disabled, EmbeddingCardButton:disabled { 54 | background-color: #95a5a6; /* Gray background when disabled */ 55 | border: 1px solid #7f8c8d; /* Darker gray border */ 56 | color: #bdc3c7; /* Lighter gray text color */ 57 | } 58 | 59 | /* Focused state */ 60 | TargetMediaCardButton:focus { 61 | outline: none; /* Remove the default focus outline */ 62 | border: 2px solid #4facc9; /* Thicker turquoise border on focus */ 63 | } 64 | 65 | /* QSlider -------------------------------------- */ 66 | #videoSeekSlider::groove:horizontal { 67 | border-radius: 10px; 68 | height: 10px; 69 | margin: 0px; 70 | background-color: #34495e; 71 | } 72 | 73 | #videoSeekSlider::groove:horizontal:hover { 74 | background-color: #4facc9; 75 | } 76 | 77 | #videoSeekSlider::sub-page:horizontal { 78 | background-color: #197996; 79 | border-radius: 10px; 80 | height: 8px; 81 | } 82 | 83 | #videoSeekSlider::sub-page:horizontal:hover { 84 | background-color: #4facc9; 85 | } 86 | 87 | #videoSeekSlider::add-page:horizontal { 88 | background-color: #2c3e50; 89 | border-radius: 10px; 90 | height: 8px; 91 | } 92 | 93 | #videoSeekSlider::handle:horizontal { 94 | background-color: #bdc3c7; 95 | border: 1px solid #7f8c8d; 96 | height: 20px; 97 | width: 8px; 98 | margin: -6px 0; 99 | border-radius: 2px; 100 | } 101 | 102 | #videoSeekSlider::handle:horizontal:hover { 103 | background-color: #4facc9; 104 | } 105 | 106 | #videoSeekSlider::handle:horizontal:pressed { 107 | background-color: #4facc9; 108 | } 109 | 110 | #videoSeekSlider::handle:horizontal:disabled { 111 | background-color: #95a5a6; 112 | border: 1px solid #7f8c8d; 113 | } 114 | 115 | #videoSeekSlider::groove:horizontal:disabled { 116 | background-color: #7f8c8d; 117 | } 118 | 119 | #videoSeekSlider::sub-page:horizontal:disabled { 120 | background-color: #bdc3c7; 121 | } 122 | 123 | /* ToggleButton styling */ 124 
| ToggleButton { 125 | border: 0px; 126 | border-radius: 6px; 127 | background-color: #dc2626; 128 | text-align: center; 129 | padding: 2px; 130 | } 131 | ToggleButton:hover { 132 | background-color: #232323; 133 | } 134 | ToggleButton:pressed { 135 | background-color: #38b845; 136 | } 137 | ToggleButton:checked { 138 | background-color: #38b845; 139 | } 140 | ToggleButton:default { 141 | border-color: none; 142 | } 143 | 144 | /* ParameterSlider - Thin Groove and Handle */ 145 | /* -------------------------------------- */ 146 | ParameterSlider::groove:horizontal, ParameterDecimalSlider::groove:horizontal { 147 | border-radius: 1px; 148 | height: 6px; 149 | margin: 0px; 150 | background-color: #34495e; 151 | } 152 | 153 | ParameterSlider::sub-page:horizontal, ParameterDecimalSlider::sub-page:horizontal { 154 | background-color: transparent; 155 | } 156 | 157 | ParameterSlider::add-page:horizontal, ParameterDecimalSlider::add-page:horizontal { 158 | background-color: #2c3e50; 159 | border-radius: 1px; 160 | height: 4px; 161 | } 162 | 163 | ParameterSlider::handle:horizontal, ParameterDecimalSlider::handle:horizontal { 164 | background-color: #ffffff; 165 | border: 1px solid #7f8c8d; 166 | height: 16px; 167 | width: 4px; 168 | margin: -4px 0; 169 | } 170 | 171 | ParameterSlider::handle:horizontal:disabled, ParameterDecimalSlider::handle:horizontal:disabled { 172 | background-color: #95a5a6; 173 | border: 1px solid #7f8c8d; 174 | } 175 | 176 | ParameterSlider::groove:horizontal:disabled, ParameterDecimalSlider::groove:horizontal:disabled { 177 | background-color: #7f8c8d; 178 | } 179 | 180 | ParameterSlider::sub-page:horizontal:disabled, ParameterDecimalSlider::sub-page:horizontal:disabled { 181 | background-color: transparent; 182 | } -------------------------------------------------------------------------------- /app/ui/styles/light_styles.qss: -------------------------------------------------------------------------------- 1 | QPushButton { 2 | color: #379bb9; 3 | font-weight: bold; 4 | } 5 | 6 | QToolTip { 7 | color: #242e38; 8 | background-color: #ccd8dd; 9 | padding: 2px; 10 | } 11 | 12 | /* Base button style */ 13 | TargetMediaCardButton, TargetFaceCardButton, InputFaceCardButton { 14 | background-color: #141516; /* Dark blue background */ 15 | border-radius: 5px; /* Rounded corners */ 16 | border: 1px solid #34495e; /* Slightly darker border */ 17 | color: white; /* White text color */ 18 | padding: 10px; /* Padding around the icon */ 19 | font-size: 14px; /* Default font size */ 20 | } 21 | 22 | /* Icon inside the button */ 23 | TargetMediaCardButton::icon, TargetFaceCardButton::icon, InputFaceCardButton:icon { 24 | margin: 5px; /* Spacing around the icon */ 25 | } 26 | 27 | /* Hover state */ 28 | TargetMediaCardButton:hover, TargetFaceCardButton:hover, InputFaceCardButton:hover { 29 | background-color: #34495e; /* Slightly lighter on hover */ 30 | border: 1px solid #4facc9; /* Turquoise border on hover */ 31 | } 32 | 33 | /* Pressed state */ 34 | TargetMediaCardButton:pressed { 35 | background-color: #4facc9; /* Turquoise background when pressed */ 36 | border: 1px solid #4facc9; /* Darker turquoise border */ 37 | } 38 | 39 | /* Pressed state */ 40 | TargetMediaCardButton:checked, TargetFaceCardButton:checked, InputFaceCardButton:checked, EmbeddingCardButton:checked { 41 | border: 1px solid #4facc9; /* Darker turquoise border */ 42 | } 43 | 44 | /* Disabled state */ 45 | TargetMediaCardButton:disabled, TargetFaceCardButton:disabled, InputFaceCardButton:disabled, 
EmbeddingCardButton:disabled { 46 | background-color: #95a5a6; /* Gray background when disabled */ 47 | border: 1px solid #7f8c8d; /* Darker gray border */ 48 | color: #bdc3c7; /* Lighter gray text color */ 49 | } 50 | 51 | /* Focused state */ 52 | TargetMediaCardButton:focus { 53 | outline: none; /* Remove the default focus outline */ 54 | border: 2px solid #4facc9; /* Thicker turquoise border on focus */ 55 | } 56 | 57 | 58 | /* QSlider -------------------------------------- */ 59 | #videoSeekSlider::groove:horizontal { 60 | border-radius: 10px; /* Rounded corners to match the button */ 61 | height: 10px; /* Increase groove thickness */ 62 | margin: 0px; 63 | background-color: #34495e; /* Background groove color (slightly lighter blue) */ 64 | } 65 | 66 | #videoSeekSlider::groove:horizontal:hover { 67 | background-color: #4facc9; /* Turquoise color on hover */ 68 | } 69 | 70 | #videoSeekSlider::sub-page:horizontal { 71 | background-color: #197996; /* Turquoise for completed part */ 72 | border-radius: 10px; /* Rounded corners */ 73 | height: 8px; 74 | } 75 | 76 | #videoSeekSlider::sub-page:horizontal:hover { 77 | background-color: #4facc9; /* Darker turquoise on hover */ 78 | } 79 | 80 | #videoSeekSlider::add-page:horizontal { 81 | background-color: #2c3e50; /* Dark blue for uncompleted part */ 82 | border-radius: 10px; 83 | height: 8px; 84 | } 85 | 86 | #videoSeekSlider::handle:horizontal { 87 | background-color: #bdc3c7; /* Lighter gray handle color */ 88 | border: 1px solid #7f8c8d; /* Darker gray border */ 89 | height: 20px; /* Handle height */ 90 | width: 8px; /* Handle width */ 91 | margin: -6px 0; /* Adjust handle position */ 92 | border-radius: 2px; /* Rounded handle */ 93 | } 94 | 95 | #videoSeekSlider::handle:horizontal:hover { 96 | background-color: #4facc9; /* Turquoise handle on hover */ 97 | } 98 | 99 | #videoSeekSlider::handle:horizontal:pressed { 100 | background-color: #4facc9; /* Darker turquoise when pressed */ 101 | } 102 | 103 | /* #videoSeekSlider state */ 104 | #videoSeekSlider::handle:horizontal:disabled { 105 | background-color: #95a5a6; /* Gray background when disabled */ 106 | border: 1px solid #7f8c8d; /* Darker gray border */ 107 | } 108 | 109 | #videoSeekSlider::groove:horizontal:disabled { 110 | background-color: #7f8c8d; /* Dark gray groove when disabled */ 111 | } 112 | 113 | #videoSeekSlider::sub-page:horizontal:disabled { 114 | background-color: #bdc3c7; /* Lighter gray completed part when disabled */ 115 | } 116 | 117 | 118 | ToggleButton { 119 | border: 0px; 120 | border-radius: 6px; 121 | background-color: #dc2626; 122 | text-align: center; 123 | padding: 2px; 124 | 125 | } 126 | ToggleButton:hover { 127 | background-color: #232323; 128 | } 129 | ToggleButton:pressed { 130 | background-color: #38b845; 131 | } 132 | ToggleButton:checked { 133 | background-color: #38b845; 134 | } 135 | ToggleButton:default { 136 | border-color: none; /* make the default button prominent */ 137 | } 138 | 139 | 140 | 141 | /* ParameterSlider - Thin Groove and Handle */ 142 | /* -------------------------------------- */ 143 | 144 | ParameterSlider::groove:horizontal,ParameterDecimalSlider::groove:horizontal { 145 | border-radius: 1px; /* Rounded corners */ 146 | height: 6px; /* Thinner groove */ 147 | margin: 0px; 148 | background-color: #34495e; /* Background groove color */ 149 | } 150 | 151 | /* Remove the completed part color by making it transparent */ 152 | ParameterSlider::sub-page:horizontal,ParameterDecimalSlider::sub-page:horizontal { 153 | background-color: 
transparent; /* No color for completed part */ 154 | } 155 | 156 | ParameterSlider::add-page:horizontal,ParameterDecimalSlider::add-page:horizontal { 157 | background-color: #2c3e50; /* Dark blue for uncompleted part */ 158 | border-radius: 1px; /* Thinner radius */ 159 | height: 4px; /* Thinner groove */ 160 | } 161 | 162 | ParameterSlider::handle:horizontal,ParameterDecimalSlider::handle:horizontal { 163 | background-color: #ffffff; /* Lighter gray handle color */ 164 | border: 1px solid #7f8c8d; /* Darker gray border */ 165 | height: 16px; /* Thinner handle */ 166 | width: 4px; /* Thinner handle */ 167 | margin: -4px 0; /* Adjust handle position for thinner groove */ 168 | } 169 | 170 | /* Disabled state */ 171 | ParameterSlider::handle:horizontal:disabled,ParameterDecimalSlider::handle:horizontal:disabled { 172 | background-color: #95a5a6; /* Gray background when disabled */ 173 | border: 1px solid #7f8c8d; /* Darker gray border */ 174 | } 175 | 176 | ParameterSlider::groove:horizontal:disabled,ParameterDecimalSlider::groove:horizontal:disabled { 177 | background-color: #7f8c8d; /* Dark gray groove when disabled */ 178 | } 179 | 180 | ParameterSlider::sub-page:horizontal:disabled,ParameterDecimalSlider::sub-page:horizontal:disabled { 181 | background-color: transparent; /* No color for completed part when disabled */ 182 | } 183 | -------------------------------------------------------------------------------- /app/ui/widgets/actions/card_actions.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import TYPE_CHECKING, Dict 3 | import uuid 4 | 5 | import numpy 6 | import cv2 7 | import torch 8 | from torchvision.transforms import v2 9 | 10 | import app.ui.widgets.actions.common_actions as common_widget_actions 11 | from app.ui.widgets.actions import list_view_actions 12 | import app.helpers.miscellaneous as misc_helpers 13 | from app.ui.widgets.settings_layout_data import SETTINGS_LAYOUT_DATA 14 | 15 | if TYPE_CHECKING: 16 | from app.ui.main_ui import MainWindow 17 | 18 | def clear_target_faces(main_window: 'MainWindow', refresh_frame=True): 19 | if main_window.video_processor.processing: 20 | main_window.video_processor.stop_processing() 21 | main_window.targetFacesList.clear() 22 | for _, target_face in main_window.target_faces.items(): 23 | target_face.deleteLater() 24 | main_window.target_faces = {} 25 | main_window.parameters = {} 26 | 27 | main_window.selected_target_face_id = False 28 | # Set Parameter widget values to default 29 | common_widget_actions.set_widgets_values_using_face_id_parameters(main_window=main_window, face_id=False) 30 | if refresh_frame: 31 | common_widget_actions.refresh_frame(main_window=main_window) 32 | 33 | 34 | def clear_input_faces(main_window: 'MainWindow'): 35 | main_window.inputFacesList.clear() 36 | for _, input_face in main_window.input_faces.items(): 37 | input_face.deleteLater() 38 | main_window.input_faces = {} 39 | 40 | for _, target_face in main_window.target_faces.items(): 41 | target_face.assigned_input_faces = {} 42 | target_face.calculate_assigned_input_embedding() 43 | common_widget_actions.refresh_frame(main_window=main_window) 44 | 45 | def clear_merged_embeddings(main_window: 'MainWindow'): 46 | main_window.inputEmbeddingsList.clear() 47 | for _, embed_button in main_window.merged_embeddings.items(): 48 | embed_button.deleteLater() 49 | main_window.merged_embeddings = {} 50 | 51 | for _, target_face in main_window.target_faces.items(): 52 | target_face.assigned_merged_embeddings = 
{} 53 | target_face.calculate_assigned_input_embedding() 54 | common_widget_actions.refresh_frame(main_window=main_window) 55 | 56 | def uncheck_all_input_faces(main_window: 'MainWindow'): 57 | # Uncheck All other input faces 58 | for _, input_face_button in main_window.input_faces.items(): 59 | input_face_button.setChecked(False) 60 | 61 | def uncheck_all_merged_embeddings(main_window: 'MainWindow'): 62 | for _, embed_button in main_window.merged_embeddings.items(): 63 | embed_button.setChecked(False) 64 | 65 | def find_target_faces(main_window: 'MainWindow'): 66 | control = main_window.control.copy() 67 | video_processor = main_window.video_processor 68 | if video_processor.media_path: 69 | frame = None 70 | media_capture = video_processor.media_capture 71 | 72 | if video_processor.file_type=='image': 73 | frame = misc_helpers.read_image_file(video_processor.media_path) 74 | elif video_processor.file_type=='video' and media_capture: 75 | ret,frame = misc_helpers.read_frame(media_capture) 76 | media_capture.set(cv2.CAP_PROP_POS_FRAMES, video_processor.current_frame_number) 77 | elif video_processor.file_type=='webcam' and media_capture: 78 | ret, frame = misc_helpers.read_frame(media_capture) 79 | media_capture.set(cv2.CAP_PROP_POS_FRAMES, video_processor.current_frame_number) 80 | 81 | if frame is not None: 82 | # Frame must be in RGB format 83 | frame = frame[..., ::-1] # Swap the channels from BGR to RGB 84 | 85 | # print(frame) 86 | img = torch.from_numpy(frame.astype('uint8')).to(main_window.models_processor.device) 87 | img = img.permute(2,0,1) 88 | if control['ManualRotationEnableToggle']: 89 | img = v2.functional.rotate(img, angle=control['ManualRotationAngleSlider'], interpolation=v2.InterpolationMode.BILINEAR, expand=True) 90 | 91 | _, kpss_5, _ = main_window.models_processor.run_detect(img, control['DetectorModelSelection'], max_num=control['MaxFacesToDetectSlider'], score=control['DetectorScoreSlider']/100.0, input_size=(512, 512), use_landmark_detection=control['LandmarkDetectToggle'], landmark_detect_mode=control['LandmarkDetectModelSelection'], landmark_score=control["LandmarkDetectScoreSlider"]/100.0, from_points=control["DetectFromPointsToggle"], rotation_angles=[0] if not control["AutoRotationToggle"] else [0, 90, 180, 270]) 92 | 93 | ret = [] 94 | for face_kps in kpss_5: 95 | face_emb, cropped_img = main_window.models_processor.run_recognize_direct(img, face_kps, control['SimilarityTypeSelection'], control['RecognitionModelSelection']) 96 | ret.append([face_kps, face_emb, cropped_img, img]) 97 | 98 | if ret: 99 | # Loop through all faces in video frame 100 | for face in ret: 101 | found = False 102 | # Check if this face has already been found 103 | for face_id, target_face in main_window.target_faces.items(): 104 | parameters = main_window.parameters[target_face.face_id] 105 | threshhold = parameters['SimilarityThresholdSlider'] 106 | if main_window.models_processor.findCosineDistance(target_face.get_embedding(control['RecognitionModelSelection']), face[1]) >= threshhold: 107 | found = True 108 | break 109 | if not found: 110 | face_img = face[2].cpu().numpy() 111 | face_img = face_img[..., ::-1] # Swap the channels from RGB to BGR 112 | face_img = numpy.ascontiguousarray(face_img) 113 | # crop = cv2.resize(face[2].cpu().numpy(), (82, 82)) 114 | pixmap = common_widget_actions.get_pixmap_from_frame(main_window, face_img) 115 | 116 | embedding_store: Dict[str, numpy.ndarray] = {} 117 | # Ottenere i valori di 'options' 118 | options = SETTINGS_LAYOUT_DATA['Face 
Recognition']['RecognitionModelSelection']['options'] 119 | for option in options: 120 | if option != control['RecognitionModelSelection']: 121 | target_emb, _ = main_window.models_processor.run_recognize_direct(face[3], face[0], control['SimilarityTypeSelection'], option) 122 | embedding_store[option] = target_emb 123 | else: 124 | embedding_store[control['RecognitionModelSelection']] = face[1] 125 | 126 | face_id = str(uuid.uuid1().int) 127 | 128 | list_view_actions.add_media_thumbnail_to_target_faces_list(main_window, face_img, embedding_store, pixmap, face_id) 129 | # Select the first target face if no target face is already selected 130 | if main_window.target_faces and not main_window.selected_target_face_id: 131 | list(main_window.target_faces.values())[0].click() 132 | 133 | if main_window.video_processor.processing: 134 | main_window.video_processor.stop_processing() 135 | common_widget_actions.refresh_frame(main_window) 136 | 137 | common_widget_actions.update_gpu_memory_progressbar(main_window) -------------------------------------------------------------------------------- /app/ui/widgets/actions/control_actions.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | import torch 3 | import qdarkstyle 4 | from PySide6 import QtWidgets 5 | import qdarktheme 6 | 7 | if TYPE_CHECKING: 8 | from app.ui.main_ui import MainWindow 9 | from app.ui.widgets.actions import common_actions as common_widget_actions 10 | 11 | #''' 12 | # Define functions here that has to be executed when value of a control widget (In the settings tab) is changed. 13 | # The first two parameters should be the MainWindow object and the new value of the control 14 | #''' 15 | 16 | def change_execution_provider(main_window: 'MainWindow', new_provider): 17 | main_window.video_processor.stop_processing() 18 | main_window.models_processor.switch_providers_priority(new_provider) 19 | main_window.models_processor.clear_gpu_memory() 20 | common_widget_actions.update_gpu_memory_progressbar(main_window) 21 | 22 | def change_threads_number(main_window: 'MainWindow', new_threads_number): 23 | main_window.video_processor.set_number_of_threads(new_threads_number) 24 | torch.cuda.empty_cache() 25 | common_widget_actions.update_gpu_memory_progressbar(main_window) 26 | 27 | 28 | def change_theme(main_window: 'MainWindow', new_theme): 29 | 30 | def get_style_data(filename, theme='dark', custom_colors=None): 31 | custom_colors = custom_colors or {"primary": "#4facc9"} 32 | with open(f"app/ui/styles/{filename}", "r") as f: # pylint: disable=unspecified-encoding 33 | _style = f.read() 34 | _style = qdarktheme.load_stylesheet(theme=theme, custom_colors=custom_colors)+'\n'+_style 35 | return _style 36 | app = QtWidgets.QApplication.instance() 37 | 38 | _style = '' 39 | if new_theme == "Dark": 40 | _style = get_style_data('dark_styles.qss', 'dark',) 41 | 42 | elif new_theme == "Light": 43 | _style = get_style_data('light_styles.qss', 'light',) 44 | 45 | elif new_theme == "Dark-Blue": 46 | _style = get_style_data('dark_styles.qss', 'dark',) + qdarkstyle.load_stylesheet() # Applica lo stile dark-blue 47 | 48 | app.setStyleSheet(_style) 49 | 50 | main_window.update() # Aggiorna la finestra principale 51 | 52 | def set_video_playback_fps(main_window: 'MainWindow', set_video_fps=False): 53 | # print("Called set_video_playback_fps()") 54 | if set_video_fps and main_window.video_processor.media_capture: 55 | 
main_window.parameter_widgets['VideoPlaybackCustomFpsSlider'].set_value(main_window.video_processor.fps) 56 | 57 | def toggle_virtualcam(main_window: 'MainWindow', toggle_value=False): 58 | video_processor = main_window.video_processor 59 | if toggle_value: 60 | video_processor.enable_virtualcam() 61 | else: 62 | video_processor.disable_virtualcam() 63 | 64 | def enable_virtualcam(main_window: 'MainWindow', backend): 65 | print('backend', backend) 66 | main_window.video_processor.enable_virtualcam(backend=backend) -------------------------------------------------------------------------------- /app/ui/widgets/actions/filter_actions.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from PySide6 import QtWidgets 4 | 5 | if TYPE_CHECKING: 6 | from app.ui.main_ui import MainWindow 7 | 8 | def filter_target_videos(main_window: 'MainWindow', search_text: str = ''): 9 | main_window.target_videos_filter_worker.stop_thread() 10 | main_window.target_videos_filter_worker.search_text = search_text 11 | main_window.target_videos_filter_worker.start() 12 | 13 | def filter_input_faces(main_window: 'MainWindow', search_text: str = ''): 14 | main_window.input_faces_filter_worker.stop_thread() 15 | main_window.input_faces_filter_worker.search_text = search_text 16 | main_window.input_faces_filter_worker.start() 17 | 18 | def filter_merged_embeddings(main_window: 'MainWindow', search_text: str = ''): 19 | main_window.merged_embeddings_filter_worker.stop_thread() 20 | main_window.merged_embeddings_filter_worker.search_text = search_text 21 | main_window.merged_embeddings_filter_worker.start() 22 | 23 | def update_filtered_list(main_window: 'MainWindow', filter_list_widget: QtWidgets.QListWidget, visible_indices: list): 24 | for i in range(filter_list_widget.count()): 25 | filter_list_widget.item(i).setHidden(True) 26 | 27 | # Show only the items in the visible_indices list 28 | for i in visible_indices: 29 | filter_list_widget.item(i).setHidden(False) -------------------------------------------------------------------------------- /app/ui/widgets/actions/graphics_view_actions.py: -------------------------------------------------------------------------------- 1 | from PySide6 import QtWidgets, QtGui, QtCore 2 | from typing import TYPE_CHECKING 3 | if TYPE_CHECKING: 4 | from app.ui.main_ui import MainWindow 5 | 6 | # @misc_helpers.benchmark (Keep this decorator if you have it) 7 | def update_graphics_view(main_window: 'MainWindow', pixmap: QtGui.QPixmap, current_frame_number, reset_fit=False): 8 | # print('(update_graphics_view) current_frame_number', current_frame_number) 9 | 10 | # Update the video seek slider and line edit 11 | if main_window.videoSeekSlider.value() != current_frame_number: 12 | main_window.videoSeekSlider.blockSignals(True) 13 | main_window.videoSeekSlider.setValue(current_frame_number) 14 | main_window.videoSeekSlider.blockSignals(False) 15 | 16 | current_text = main_window.videoSeekLineEdit.text() 17 | if current_text != str(current_frame_number): 18 | main_window.videoSeekLineEdit.setText(str(current_frame_number)) 19 | 20 | # Preserve the current transform (zoom and pan state) - No longer needed if we are not clearing scene every time 21 | # current_transform = main_window.graphicsViewFrame.transform() 22 | 23 | # Get the scene and existing pixmap item 24 | scene = main_window.graphicsViewFrame.scene() 25 | pixmap_item = None 26 | previous_items = scene.items() 27 | if previous_items: 28 | pixmap_item = 
previous_items[0] # Assume pixmap is the first item 29 | 30 | # Resize the pixmap if necessary (only if pixmap_item exists) 31 | if pixmap_item: 32 | bounding_rect = pixmap_item.boundingRect() 33 | # If the old pixmap is smaller than the new pixmap (ie, due to the face compare or mask compare), scale is to the size of the old one 34 | if bounding_rect.width() > pixmap.width() and bounding_rect.height() > pixmap.height(): 35 | pixmap = pixmap.scaled(bounding_rect.width(), bounding_rect.height(), QtCore.Qt.AspectRatioMode.KeepAspectRatio) 36 | 37 | # Update or create pixmap item 38 | if pixmap_item: 39 | pixmap_item.setPixmap(pixmap) # Update the pixmap of the existing item 40 | else: 41 | pixmap_item_new = QtWidgets.QGraphicsPixmapItem(pixmap) # Create a new pixmap item only if it doesn't exist 42 | scene.addItem(pixmap_item_new) 43 | pixmap_item = pixmap_item_new # Use the newly created item for fitting view 44 | 45 | # Set the scene rectangle to the bounding rectangle of the pixmap 46 | scene_rect = pixmap_item.boundingRect() 47 | main_window.graphicsViewFrame.setSceneRect(scene_rect) 48 | 49 | # Reset the view or restore the previous transform 50 | if reset_fit: 51 | fit_image_to_view(main_window, pixmap_item, scene_rect) # Pass pixmap_item here 52 | # else: # No longer need to restore transform if we are not clearing scene 53 | # zoom_andfit_image_to_view_onchange(main_window, current_transform) # No longer needed 54 | 55 | 56 | def zoom_andfit_image_to_view_onchange(main_window: 'MainWindow', new_transform): 57 | """Restore the previous transform (zoom and pan state) and update the view.""" 58 | # print("Called zoom_andfit_image_to_view_onchange()") 59 | main_window.graphicsViewFrame.setTransform(new_transform, combine=False) 60 | 61 | 62 | def fit_image_to_view(main_window: 'MainWindow', pixmap_item: QtWidgets.QGraphicsPixmapItem, scene_rect): 63 | """Reset the view and fit the image to the view, keeping the aspect ratio.""" 64 | # print("Called fit_image_to_view()") 65 | graphicsViewFrame = main_window.graphicsViewFrame 66 | # Reset the transform and set the scene rectangle 67 | graphicsViewFrame.resetTransform() 68 | graphicsViewFrame.setSceneRect(scene_rect) 69 | # Fit the image to the view, keeping the aspect ratio 70 | graphicsViewFrame.fitInView(pixmap_item, QtCore.Qt.AspectRatioMode.KeepAspectRatio) -------------------------------------------------------------------------------- /app/ui/widgets/actions/list_view_actions.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | from typing import TYPE_CHECKING 4 | 5 | from PySide6 import QtWidgets, QtGui, QtCore 6 | 7 | from app.ui.widgets.actions import common_actions as common_widget_actions 8 | from app.ui.widgets.actions import card_actions 9 | from app.ui.widgets import widget_components 10 | import app.helpers.miscellaneous as misc_helpers 11 | from app.ui.widgets import ui_workers 12 | if TYPE_CHECKING: 13 | from app.ui.main_ui import MainWindow 14 | 15 | # Functions to add Buttons with thumbnail for selecting videos/images and faces 16 | @QtCore.Slot(str, QtGui.QPixmap) 17 | def add_media_thumbnail_to_target_videos_list(main_window: 'MainWindow', media_path, pixmap, file_type, media_id): 18 | add_media_thumbnail_button(main_window, widget_components.TargetMediaCardButton, main_window.targetVideosList, main_window.target_videos, pixmap, media_path=media_path, file_type=file_type, media_id=media_id) 19 | 20 | # Functions to add Buttons with 
thumbnail for selecting videos/images and faces 21 | @QtCore.Slot(str, QtGui.QPixmap, str, int, int) 22 | def add_webcam_thumbnail_to_target_videos_list(main_window: 'MainWindow', media_path, pixmap, file_type, media_id, webcam_index, webcam_backend): 23 | add_media_thumbnail_button(main_window, widget_components.TargetMediaCardButton, main_window.targetVideosList, main_window.target_videos, pixmap, media_path=media_path, file_type=file_type, media_id=media_id, is_webcam=True, webcam_index=webcam_index, webcam_backend=webcam_backend) 24 | 25 | @QtCore.Slot() 26 | def add_media_thumbnail_to_target_faces_list(main_window: 'MainWindow', cropped_face, embedding_store, pixmap, face_id): 27 | add_media_thumbnail_button(main_window, widget_components.TargetFaceCardButton, main_window.targetFacesList, main_window.target_faces, pixmap, cropped_face=cropped_face, embedding_store=embedding_store, face_id=face_id ) 28 | 29 | @QtCore.Slot() 30 | def add_media_thumbnail_to_source_faces_list(main_window: 'MainWindow', media_path, cropped_face, embedding_store, pixmap, face_id): 31 | add_media_thumbnail_button(main_window, widget_components.InputFaceCardButton, main_window.inputFacesList, main_window.input_faces, pixmap, media_path=media_path, cropped_face=cropped_face, embedding_store=embedding_store, face_id=face_id ) 32 | 33 | 34 | def add_media_thumbnail_button(main_window: 'MainWindow', buttonClass: 'widget_components.CardButton', listWidget:QtWidgets.QListWidget, buttons_list:list, pixmap, **kwargs): 35 | if buttonClass==widget_components.TargetMediaCardButton: 36 | constructor_args = (kwargs.get('media_path'), kwargs.get('file_type'), kwargs.get('media_id')) 37 | if kwargs.get('is_webcam'): 38 | constructor_args+=(kwargs.get('is_webcam'), kwargs.get('webcam_index'), kwargs.get('webcam_backend')) 39 | elif buttonClass in (widget_components.TargetFaceCardButton, widget_components.InputFaceCardButton): 40 | constructor_args = (kwargs.get('media_path',''), kwargs.get('cropped_face'), kwargs.get('embedding_store'), kwargs.get('face_id')) 41 | if buttonClass==widget_components.TargetMediaCardButton: 42 | button_size = QtCore.QSize(90, 90) # Set a fixed size for the buttons 43 | else: 44 | button_size = QtCore.QSize(70, 70) # Set a fixed size for the buttons 45 | 46 | button: widget_components.CardButton = buttonClass(*constructor_args, main_window=main_window) 47 | button.setIcon(QtGui.QIcon(pixmap)) 48 | button.setIconSize(button_size - QtCore.QSize(8, 8)) # Slightly smaller than the button size to add some margin 49 | button.setFixedSize(button_size) 50 | button.setCheckable(True) 51 | if buttonClass in [widget_components.TargetFaceCardButton, widget_components.InputFaceCardButton]: 52 | buttons_list[button.face_id] = button 53 | elif buttonClass == widget_components.TargetMediaCardButton: 54 | buttons_list[button.media_id] = button 55 | elif buttonClass == widget_components.EmbeddingCardButton: 56 | buttons_list[button.embedding_id] = button 57 | # Create a QListWidgetItem and set the button as its widget 58 | list_item = QtWidgets.QListWidgetItem(listWidget) 59 | list_item.setSizeHint(button_size) 60 | button.list_item = list_item 61 | button.list_widget = listWidget 62 | # Align the item to center 63 | list_item.setTextAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) 64 | listWidget.setItemWidget(list_item, button) 65 | # Adjust the QListWidget properties to handle the grid layout 66 | grid_size_with_padding = button_size + QtCore.QSize(4, 4) # Add padding around the buttons 67 | 
listWidget.setGridSize(grid_size_with_padding) # Set grid size with padding 68 | listWidget.setWrapping(True) # Enable wrapping to have items in rows 69 | listWidget.setFlow(QtWidgets.QListView.LeftToRight) # Set flow direction 70 | listWidget.setResizeMode(QtWidgets.QListView.Adjust) # Adjust layout automatically 71 | 72 | 73 | def create_and_add_embed_button_to_list(main_window: 'MainWindow', embedding_name, embedding_store, embedding_id): 74 | inputEmbeddingsList = main_window.inputEmbeddingsList 75 | # Passa l'intero embedding_store 76 | embed_button = widget_components.EmbeddingCardButton(main_window=main_window, embedding_name=embedding_name, embedding_store=embedding_store, embedding_id=embedding_id) 77 | 78 | button_size = QtCore.QSize(105, 35) # Adjusted width to fit 3 per row with proper spacing 79 | embed_button.setFixedSize(button_size) 80 | 81 | list_item = QtWidgets.QListWidgetItem(inputEmbeddingsList) 82 | list_item.setSizeHint(button_size) 83 | embed_button.list_item = list_item 84 | list_item.setTextAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) 85 | 86 | inputEmbeddingsList.setItemWidget(list_item, embed_button) 87 | 88 | # Configure grid layout for 3x3 minimum grid 89 | grid_size_with_padding = button_size + QtCore.QSize(4, 4) # Add padding around buttons 90 | inputEmbeddingsList.setGridSize(grid_size_with_padding) 91 | inputEmbeddingsList.setWrapping(True) 92 | inputEmbeddingsList.setFlow(QtWidgets.QListView.TopToBottom) 93 | inputEmbeddingsList.setResizeMode(QtWidgets.QListView.Fixed) 94 | inputEmbeddingsList.setSpacing(2) 95 | inputEmbeddingsList.setUniformItemSizes(True) 96 | inputEmbeddingsList.setViewMode(QtWidgets.QListView.IconMode) 97 | inputEmbeddingsList.setMovement(QtWidgets.QListView.Static) 98 | 99 | # Set viewport mode and item size 100 | viewport_height = 180 # Fixed height for 3 rows (35px + padding per row) 101 | inputEmbeddingsList.setFixedHeight(viewport_height) 102 | 103 | # Calculate grid dimensions 104 | row_height = viewport_height // 3 105 | col_width = grid_size_with_padding.width() 106 | 107 | # Set minimum width for 3 columns and adjust spacing 108 | min_width = (3 * col_width) + 16 # Add extra padding for better spacing between columns 109 | inputEmbeddingsList.setMinimumWidth(min_width) 110 | 111 | # Configure scrolling behavior 112 | inputEmbeddingsList.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) 113 | inputEmbeddingsList.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) 114 | inputEmbeddingsList.setVerticalScrollMode(QtWidgets.QAbstractItemView.ScrollPerPixel) 115 | inputEmbeddingsList.setHorizontalScrollMode(QtWidgets.QAbstractItemView.ScrollPerPixel) 116 | 117 | # Set layout direction to ensure proper filling 118 | inputEmbeddingsList.setLayoutDirection(QtCore.Qt.LeftToRight) 119 | inputEmbeddingsList.setLayoutMode(QtWidgets.QListView.Batched) 120 | 121 | main_window.merged_embeddings[embed_button.embedding_id] = embed_button 122 | 123 | def clear_stop_loading_target_media(main_window: 'MainWindow'): 124 | if main_window.video_loader_worker: 125 | main_window.video_loader_worker.stop() 126 | main_window.video_loader_worker.terminate() 127 | main_window.video_loader_worker = False 128 | time.sleep(0.5) 129 | main_window.targetVideosList.clear() 130 | 131 | @QtCore.Slot() 132 | def select_target_medias(main_window: 'MainWindow', source_type='folder', folder_name=False, files_list=None): 133 | files_list = files_list or [] 134 | if source_type=='folder': 135 | folder_name = 
QtWidgets.QFileDialog.getExistingDirectory(dir=main_window.last_target_media_folder_path) 136 | if not folder_name: 137 | return 138 | main_window.labelTargetVideosPath.setText(misc_helpers.truncate_text(folder_name)) 139 | main_window.labelTargetVideosPath.setToolTip(folder_name) 140 | main_window.last_target_media_folder_path = folder_name 141 | 142 | elif source_type=='files': 143 | files_list = QtWidgets.QFileDialog.getOpenFileNames()[0] 144 | if not files_list: 145 | return 146 | # Get Folder name from the first file 147 | file_dir = misc_helpers.get_dir_of_file(files_list[0]) 148 | main_window.labelTargetVideosPath.setText(file_dir) #Just a temp text until i think of something better 149 | main_window.labelTargetVideosPath.setToolTip(file_dir) 150 | main_window.last_target_media_folder_path = file_dir 151 | 152 | clear_stop_loading_target_media(main_window) 153 | card_actions.clear_target_faces(main_window) 154 | 155 | main_window.selected_video_button = False 156 | main_window.target_videos = {} 157 | 158 | main_window.video_loader_worker = ui_workers.TargetMediaLoaderWorker(main_window=main_window, folder_name=folder_name, files_list=files_list) 159 | main_window.video_loader_worker.thumbnail_ready.connect(partial(add_media_thumbnail_to_target_videos_list, main_window)) 160 | main_window.video_loader_worker.start() 161 | 162 | @QtCore.Slot() 163 | def load_target_webcams(main_window: 'MainWindow',): 164 | if main_window.filterWebcamsCheckBox.isChecked(): 165 | main_window.video_loader_worker = ui_workers.TargetMediaLoaderWorker(main_window=main_window, webcam_mode=True) 166 | main_window.video_loader_worker.webcam_thumbnail_ready.connect(partial(add_webcam_thumbnail_to_target_videos_list, main_window)) 167 | main_window.video_loader_worker.start() 168 | else: 169 | main_window.placeholder_update_signal.emit(main_window.targetVideosList, True) 170 | for _, target_video in main_window.target_videos.copy().items(): #Use a copy of the dict to prevent Dictionary changed during iteration exceptions 171 | if target_video.file_type == 'webcam': 172 | target_video.remove_target_media_from_list() 173 | if target_video == main_window.selected_video_button: 174 | main_window.selected_video_button = False 175 | main_window.placeholder_update_signal.emit(main_window.targetVideosList, False) 176 | 177 | def clear_stop_loading_input_media(main_window: 'MainWindow'): 178 | if main_window.input_faces_loader_worker: 179 | main_window.input_faces_loader_worker.stop() 180 | main_window.input_faces_loader_worker.terminate() 181 | main_window.input_faces_loader_worker = False 182 | time.sleep(0.5) 183 | main_window.inputFacesList.clear() 184 | 185 | @QtCore.Slot() 186 | def select_input_face_images(main_window: 'MainWindow', source_type='folder', folder_name=False, files_list=None): 187 | files_list = files_list or [] 188 | if source_type=='folder': 189 | folder_name = QtWidgets.QFileDialog.getExistingDirectory(dir=main_window.last_input_media_folder_path) 190 | if not folder_name: 191 | return 192 | main_window.labelInputFacesPath.setText(misc_helpers.truncate_text(folder_name)) 193 | main_window.labelInputFacesPath.setToolTip(folder_name) 194 | main_window.last_input_media_folder_path = folder_name 195 | 196 | elif source_type=='files': 197 | files_list = QtWidgets.QFileDialog.getOpenFileNames()[0] 198 | if not files_list: 199 | return 200 | file_dir = misc_helpers.get_dir_of_file(files_list[0]) 201 | main_window.labelInputFacesPath.setText(file_dir) #Just a temp text until i think of something better 
202 | main_window.labelInputFacesPath.setToolTip(file_dir) 203 | main_window.last_input_media_folder_path = file_dir 204 | 205 | clear_stop_loading_input_media(main_window) 206 | card_actions.clear_input_faces(main_window) 207 | main_window.input_faces_loader_worker = ui_workers.InputFacesLoaderWorker(main_window=main_window, folder_name=folder_name, files_list=files_list) 208 | main_window.input_faces_loader_worker.thumbnail_ready.connect(partial(add_media_thumbnail_to_source_faces_list, main_window)) 209 | main_window.input_faces_loader_worker.start() 210 | 211 | def set_up_list_widget_placeholder(main_window: 'MainWindow', list_widget: QtWidgets.QListWidget): 212 | # Placeholder label 213 | placeholder_label = QtWidgets.QLabel(list_widget) 214 | placeholder_label.setText( 215 | "" 216 | "

Drop Files" 217 | "or" 218 | "Click here to Select a Folder
" 219 | "" 220 | ) 221 | # placeholder_label.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) 222 | placeholder_label.setStyleSheet("color: gray; font-size: 15px; font-weight: bold;") 223 | 224 | # Center the label inside the QListWidget 225 | # placeholder_label.setGeometry(list_widget.rect()) # Match QListWidget's size 226 | placeholder_label.setAttribute(QtCore.Qt.WidgetAttribute.WA_TransparentForMouseEvents) # Allow interactions to pass through 227 | placeholder_label.setVisible(not list_widget.count()) # Show if the list is empty 228 | 229 | # Use a QVBoxLayout to center the placeholder label 230 | layout = QtWidgets.QVBoxLayout(list_widget) 231 | layout.addWidget(placeholder_label) 232 | layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) # Center the label vertically and horizontally 233 | layout.setContentsMargins(0, 0, 0, 0) # Remove margins to ensure full coverage 234 | 235 | # Keep a reference for toggling visibility later 236 | list_widget.placeholder_label = placeholder_label 237 | # Set default cursor as PointingHand 238 | list_widget.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) 239 | 240 | def select_output_media_folder(main_window: 'MainWindow'): 241 | folder_name = QtWidgets.QFileDialog.getExistingDirectory() 242 | if folder_name: 243 | main_window.outputFolderLineEdit.setText(folder_name) 244 | common_widget_actions.create_control(main_window, 'OutputMediaFolder', folder_name) 245 | -------------------------------------------------------------------------------- /app/ui/widgets/common_layout_data.py: -------------------------------------------------------------------------------- 1 | from app.helpers.typing_helper import LayoutDictTypes 2 | import app.ui.widgets.actions.layout_actions as layout_actions 3 | 4 | COMMON_LAYOUT_DATA: LayoutDictTypes = { 5 | # 'Face Compare':{ 6 | # 'ViewFaceMaskEnableToggle':{ 7 | # 'level': 1, 8 | # 'label': 'View Face Mask', 9 | # 'default': False, 10 | # 'help': 'Show Face Mask', 11 | # 'exec_function': layout_actions.fit_image_to_view_onchange, 12 | # 'exec_function_args': [], 13 | # }, 14 | # 'ViewFaceCompareEnableToggle':{ 15 | # 'level': 1, 16 | # 'label': 'View Face Compare', 17 | # 'default': False, 18 | # 'help': 'Show Face Compare', 19 | # 'exec_function': layout_actions.fit_image_to_view_onchange, 20 | # 'exec_function_args': [], 21 | # }, 22 | # }, 23 | 'Face Restorer': { 24 | 'FaceRestorerEnableToggle': { 25 | 'level': 1, 26 | 'label': 'Enable Face Restorer', 27 | 'default': False, 28 | 'help': 'Enable the use of a face restoration model to improve the quality of the face after swapping.' 29 | }, 30 | 'FaceRestorerTypeSelection': { 31 | 'level': 2, 32 | 'label': 'Restorer Type', 33 | 'options': ['GFPGAN-v1.4', 'CodeFormer', 'GPEN-256', 'GPEN-512', 'GPEN-1024', 'GPEN-2048', 'RestoreFormer++', 'VQFR-v2'], 34 | 'default': 'GFPGAN-v1.4', 35 | 'parentToggle': 'FaceRestorerEnableToggle', 36 | 'requiredToggleValue': True, 37 | 'help': 'Select the model type for face restoration.' 38 | }, 39 | 'FaceRestorerDetTypeSelection': { 40 | 'level': 2, 41 | 'label': 'Alignment', 42 | 'options': ['Original', 'Blend', 'Reference'], 43 | 'default': 'Original', 44 | 'parentToggle': 'FaceRestorerEnableToggle', 45 | 'requiredToggleValue': True, 46 | 'help': 'Select the alignment method for restoring the face to its original or blended position.' 
47 | }, 48 | 'FaceFidelityWeightDecimalSlider': { 49 | 'level': 2, 50 | 'label': 'Fidelity Weight', 51 | 'min_value': '0.0', 52 | 'max_value': '1.0', 53 | 'default': '0.9', 54 | 'decimals': 1, 55 | 'step': 0.1, 56 | 'parentToggle': 'FaceRestorerEnableToggle', 57 | 'requiredToggleValue': True, 58 | 'help': 'Adjust the fidelity weight to control how closely the restoration preserves the original face details.' 59 | }, 60 | 'FaceRestorerBlendSlider': { 61 | 'level': 2, 62 | 'label': 'Blend', 63 | 'min_value': '0', 64 | 'max_value': '100', 65 | 'default': '100', 66 | 'step': 1, 67 | 'parentToggle': 'FaceRestorerEnableToggle', 68 | 'requiredToggleValue': True, 69 | 'help': 'Control the blend ratio between the restored face and the swapped face.' 70 | }, 71 | 'FaceRestorerEnable2Toggle': { 72 | 'level': 1, 73 | 'label': 'Enable Face Restorer 2', 74 | 'default': False, 75 | 'help': 'Enable the use of a face restoration model to improve the quality of the face after swapping.' 76 | }, 77 | 'FaceRestorerType2Selection': { 78 | 'level': 2, 79 | 'label': 'Restorer Type', 80 | 'options': ['GFPGAN-v1.4', 'CodeFormer', 'GPEN-256', 'GPEN-512', 'GPEN-1024', 'GPEN-2048', 'RestoreFormer++', 'VQFR-v2'], 81 | 'default': 'GFPGAN-v1.4', 82 | 'parentToggle': 'FaceRestorerEnable2Toggle', 83 | 'requiredToggleValue': True, 84 | 'help': 'Select the model type for face restoration.' 85 | }, 86 | 'FaceRestorerDetType2Selection': { 87 | 'level': 2, 88 | 'label': 'Alignment', 89 | 'options': ['Original', 'Blend', 'Reference'], 90 | 'default': 'Original', 91 | 'parentToggle': 'FaceRestorerEnable2Toggle', 92 | 'requiredToggleValue': True, 93 | 'help': 'Select the alignment method for restoring the face to its original or blended position.' 94 | }, 95 | 'FaceFidelityWeight2DecimalSlider': { 96 | 'level': 2, 97 | 'label': 'Fidelity Weight', 98 | 'min_value': '0.0', 99 | 'max_value': '1.0', 100 | 'default': '0.9', 101 | 'decimals': 1, 102 | 'step': 0.1, 103 | 'parentToggle': 'FaceRestorerEnable2Toggle', 104 | 'requiredToggleValue': True, 105 | 'help': 'Adjust the fidelity weight to control how closely the restoration preserves the original face details.' 106 | }, 107 | 'FaceRestorerBlend2Slider': { 108 | 'level': 2, 109 | 'label': 'Blend', 110 | 'min_value': '0', 111 | 'max_value': '100', 112 | 'default': '100', 113 | 'step': 1, 114 | 'parentToggle': 'FaceRestorerEnable2Toggle', 115 | 'requiredToggleValue': True, 116 | 'help': 'Control the blend ratio between the restored face and the swapped face.' 117 | }, 118 | 'FaceExpressionEnableToggle': { 119 | 'level': 1, 120 | 'label': 'Enable Face Expression Restorer', 121 | 'default': False, 122 | 'help': 'Enabled the use of the LivePortrait face expression model to restore facial expressions after swapping.' 123 | }, 124 | 'FaceExpressionCropScaleDecimalSlider': { 125 | 'level': 2, 126 | 'label': 'Crop Scale', 127 | 'min_value': '1.80', 128 | 'max_value': '3.00', 129 | 'default': '2.30', 130 | 'step': 0.05, 131 | 'decimals': 2, 132 | 'parentToggle': 'FaceExpressionEnableToggle', 133 | 'requiredToggleValue': True, 134 | 'help': 'Changes swap crop scale. Increase the value to capture the face more distantly.' 135 | }, 136 | 'FaceExpressionVYRatioDecimalSlider': { 137 | 'level': 2, 138 | 'label': 'VY Ratio', 139 | 'min_value': '-0.125', 140 | 'max_value': '-0.100', 141 | 'default': '-0.125', 142 | 'step': 0.001, 143 | 'decimals': 3, 144 | 'parentToggle': 'FaceExpressionEnableToggle', 145 | 'requiredToggleValue': True, 146 | 'help': 'Changes the vy ratio for crop scale. 
Increase the value to capture the face more distantly.' 147 | }, 148 | 'FaceExpressionFriendlyFactorDecimalSlider': { 149 | 'level': 2, 150 | 'label': 'Expression Friendly Factor', 151 | 'min_value': '0.0', 152 | 'max_value': '1.0', 153 | 'default': '1.0', 154 | 'decimals': 1, 155 | 'step': 0.1, 156 | 'parentToggle': 'FaceExpressionEnableToggle', 157 | 'requiredToggleValue': True, 158 | 'help': 'Control the expression similarity between the driving face and the swapped face.' 159 | }, 160 | 'FaceExpressionAnimationRegionSelection': { 161 | 'level': 2, 162 | 'label': 'Animation Region', 163 | 'options': ['all', 'eyes', 'lips'], 164 | 'default': 'all', 165 | 'parentToggle': 'FaceExpressionEnableToggle', 166 | 'requiredToggleValue': True, 167 | 'help': 'The facial region involved in the restoration process.' 168 | }, 169 | 'FaceExpressionNormalizeLipsEnableToggle': { 170 | 'level': 2, 171 | 'label': 'Normalize Lips', 172 | 'default': True, 173 | 'parentToggle': 'FaceExpressionEnableToggle', 174 | 'requiredToggleValue': True, 175 | 'help': 'Normalize the lips during the facial restoration process.' 176 | }, 177 | 'FaceExpressionNormalizeLipsThresholdDecimalSlider': { 178 | 'level': 3, 179 | 'label': 'Normalize Lips Threshold', 180 | 'min_value': '0.00', 181 | 'max_value': '1.00', 182 | 'default': '0.03', 183 | 'decimals': 2, 184 | 'step': 0.01, 185 | 'parentToggle': 'FaceExpressionNormalizeLipsEnableToggle & FaceExpressionEnableToggle', 186 | 'requiredToggleValue': True, 187 | 'help': 'Threshold value for Normalize Lips.' 188 | }, 189 | 'FaceExpressionRetargetingEyesEnableToggle': { 190 | 'level': 2, 191 | 'label': 'Retargeting Eyes', 192 | 'default': False, 193 | 'parentToggle': 'FaceExpressionEnableToggle', 194 | 'requiredToggleValue': True, 195 | 'help': 'Adjusting or redirecting the gaze or movement of the eyes during the facial restoration process. It overrides the Animation Region settings, meaning that the Animation Region will be ignored.' 196 | }, 197 | 'FaceExpressionRetargetingEyesMultiplierDecimalSlider': { 198 | 'level': 3, 199 | 'label': 'Retargeting Eyes Multiplier', 200 | 'min_value': '0.00', 201 | 'max_value': '2.00', 202 | 'default': '1.00', 203 | 'decimals': 2, 204 | 'step': 0.01, 205 | 'parentToggle': 'FaceExpressionRetargetingEyesEnableToggle & FaceExpressionEnableToggle', 206 | 'requiredToggleValue': True, 207 | 'help': 'Multiplier value for Retargeting Eyes.' 208 | }, 209 | 'FaceExpressionRetargetingLipsEnableToggle': { 210 | 'level': 2, 211 | 'label': 'Retargeting Lips', 212 | 'default': False, 213 | 'parentToggle': 'FaceExpressionEnableToggle', 214 | 'requiredToggleValue': True, 215 | 'help': 'Adjusting or modifying the position, shape, or movement of the lips during the facial restoration process. It overrides the Animation Region settings, meaning that the Animation Region will be ignored.' 216 | }, 217 | 'FaceExpressionRetargetingLipsMultiplierDecimalSlider': { 218 | 'level': 3, 219 | 'label': 'Retargeting Lips Multiplier', 220 | 'min_value': '0.00', 221 | 'max_value': '2.00', 222 | 'default': '1.00', 223 | 'decimals': 2, 224 | 'step': 0.01, 225 | 'parentToggle': 'FaceExpressionRetargetingLipsEnableToggle & FaceExpressionEnableToggle', 226 | 'requiredToggleValue': True, 227 | 'help': 'Multiplier value for Retargeting Lips.' 
228 | }, 229 | }, 230 | } -------------------------------------------------------------------------------- /app/ui/widgets/event_filters.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from functools import partial 3 | 4 | from PySide6 import QtWidgets, QtGui, QtCore 5 | from app.ui.widgets.actions import list_view_actions 6 | from app.ui.widgets import ui_workers 7 | import app.helpers.miscellaneous as misc_helpers 8 | 9 | if TYPE_CHECKING: 10 | from app.ui.main_ui import MainWindow 11 | 12 | class GraphicsViewEventFilter(QtCore.QObject): 13 | def __init__(self, main_window: 'MainWindow', parent=None): 14 | super().__init__(parent) 15 | self.main_window = main_window 16 | 17 | def eventFilter(self, graphics_object: QtWidgets.QGraphicsView, event): 18 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 19 | if event.button() == QtCore.Qt.MouseButton.LeftButton: 20 | self.main_window.buttonMediaPlay.click() 21 | # You can emit a signal or call another function here 22 | return True # Mark the event as handled 23 | return False # Pass the event to the original handler 24 | 25 | class videoSeekSliderLineEditEventFilter(QtCore.QObject): 26 | def __init__(self, main_window: 'MainWindow', parent=None): 27 | super().__init__(parent) 28 | self.main_window = main_window 29 | 30 | def eventFilter(self, line_edit: QtWidgets.QLineEdit, event): 31 | if event.type() == QtCore.QEvent.KeyPress: 32 | # Check if the pressed key is Enter/Return 33 | if event.key() in (QtCore.Qt.Key_Enter, QtCore.Qt.Key_Return): 34 | new_value = line_edit.text() 35 | # Reset the line edit value to the slider value if the user input an empty text 36 | if new_value=='': 37 | new_value = str(self.main_window.videoSeekSlider.value()) 38 | else: 39 | new_value = int(new_value) 40 | max_frame_number = self.main_window.video_processor.max_frame_number 41 | # If the value entered by user if greater than the max no of frames in the video, set the new value to the max_frame_number 42 | if new_value > max_frame_number: 43 | new_value = max_frame_number 44 | # Update values of line edit and slider 45 | line_edit.setText(str(new_value)) 46 | self.main_window.videoSeekSlider.setValue(new_value) 47 | self.main_window.video_processor.process_current_frame() # Process the current frame 48 | 49 | return True 50 | return False 51 | 52 | class VideoSeekSliderEventFilter(QtCore.QObject): 53 | def __init__(self, main_window: 'MainWindow', parent=None): 54 | super().__init__(parent) 55 | self.main_window = main_window 56 | 57 | def eventFilter(self, slider, event): 58 | if event.type() == QtCore.QEvent.Type.KeyPress: 59 | if event.key() in {QtCore.Qt.Key_Left, QtCore.Qt.Key_Right}: 60 | # Allow default slider movement 61 | result = super().eventFilter(slider, event) 62 | 63 | # After the slider moves, call the custom processing function 64 | QtCore.QTimer.singleShot(0, self.main_window.video_processor.process_current_frame) 65 | 66 | return result # Return the result of the default handling 67 | elif event.type() == QtCore.QEvent.Type.Wheel: 68 | # Allow default slider movement 69 | result = super().eventFilter(slider, event) 70 | 71 | # After the slider moves, call the custom processing function 72 | QtCore.QTimer.singleShot(0, self.main_window.video_processor.process_current_frame) 73 | return result 74 | 75 | # For other events, use the default behavior 76 | return super().eventFilter(slider, event) 77 | 78 | class ListWidgetEventFilter(QtCore.QObject): 79 | def 
__init__(self, main_window: 'MainWindow', parent=None): 80 | super().__init__(parent) 81 | self.main_window = main_window 82 | 83 | def eventFilter(self, list_widget: QtWidgets.QListWidget, event: QtCore.QEvent|QtGui.QDropEvent|QtGui.QMouseEvent): 84 | 85 | if list_widget == self.main_window.targetVideosList or list_widget == self.main_window.targetVideosList.viewport(): 86 | 87 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 88 | if event.button() == QtCore.Qt.MouseButton.LeftButton and not self.main_window.target_videos: 89 | list_view_actions.select_target_medias(self.main_window, 'folder') 90 | 91 | elif event.type() == QtCore.QEvent.Type.DragEnter: 92 | # Accept drag events with URLs 93 | if event.mimeData().hasUrls(): 94 | 95 | urls = event.mimeData().urls() 96 | print("Drag: URLS", [url.toLocalFile() for url in urls]) 97 | event.acceptProposedAction() 98 | return True 99 | # Handle the drop event 100 | elif event.type() == QtCore.QEvent.Type.Drop: 101 | 102 | if event.mimeData().hasUrls(): 103 | # Extract file paths 104 | file_paths = [] 105 | for url in event.mimeData().urls(): 106 | url = url.toLocalFile() 107 | if misc_helpers.is_image_file(url) or misc_helpers.is_video_file(url): 108 | file_paths.append(url) 109 | else: 110 | print(f'{url} is not an Video or Image file') 111 | # print("Drop: URLS", [url.toLocalFile() for url in urls]) 112 | if file_paths: 113 | self.main_window.video_loader_worker = ui_workers.TargetMediaLoaderWorker(main_window=self.main_window, folder_name=False, files_list=file_paths) 114 | self.main_window.video_loader_worker.thumbnail_ready.connect(partial(list_view_actions.add_media_thumbnail_to_target_videos_list, self.main_window)) 115 | self.main_window.video_loader_worker.start() 116 | event.acceptProposedAction() 117 | return True 118 | 119 | 120 | elif list_widget == self.main_window.inputFacesList or list_widget == self.main_window.inputFacesList.viewport(): 121 | 122 | if event.type() == QtCore.QEvent.Type.MouseButtonPress: 123 | if event.button() == QtCore.Qt.MouseButton.LeftButton and not self.main_window.input_faces: 124 | list_view_actions.select_input_face_images(self.main_window, 'folder') 125 | 126 | elif event.type() == QtCore.QEvent.Type.DragEnter: 127 | # Accept drag events with URLs 128 | if event.mimeData().hasUrls(): 129 | 130 | urls = event.mimeData().urls() 131 | print("Drag: URLS", [url.toLocalFile() for url in urls]) 132 | event.acceptProposedAction() 133 | return True 134 | # Handle the drop event 135 | elif event.type() == QtCore.QEvent.Type.Drop: 136 | 137 | if event.mimeData().hasUrls(): 138 | # Extract file paths 139 | file_paths = [] 140 | for url in event.mimeData().urls(): 141 | url = url.toLocalFile() 142 | if misc_helpers.is_image_file(url): 143 | file_paths.append(url) 144 | else: 145 | print(f'{url} is not an Image file') 146 | # print("Drop: URLS", [url.toLocalFile() for url in urls]) 147 | if file_paths: 148 | self.main_window.input_faces_loader_worker = ui_workers.InputFacesLoaderWorker(main_window=self.main_window, folder_name=False, files_list=file_paths) 149 | self.main_window.input_faces_loader_worker.thumbnail_ready.connect(partial(list_view_actions.add_media_thumbnail_to_source_faces_list, self.main_window)) 150 | self.main_window.input_faces_loader_worker.start() 151 | event.acceptProposedAction() 152 | return True 153 | return super().eventFilter(list_widget, event) -------------------------------------------------------------------------------- /app/ui/widgets/settings_layout_data.py: 
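The layout dictionaries in this module and in common_layout_data.py share one shape: a section name maps to control names, and each control carries a level, label, default, optional options or min/max/step values, optional parentToggle/requiredToggleValue keys tying it to a parent toggle, an optional exec_function (called, per the note in control_actions.py, with the MainWindow object and the control's new value), and a help string. A purely hypothetical entry in that shape, for illustration only (the names below are not from the repository):

# Hypothetical layout-data entries following the documented convention that
# exec_function is invoked as fn(main_window, new_value) and that
# parentToggle/requiredToggleValue gate a control on its parent toggle.
EXAMPLE_LAYOUT_DATA = {
    'Example Section': {
        'ExampleFeatureEnableToggle': {
            'level': 1,
            'label': 'Enable Example Feature',
            'default': False,
            'help': 'Toggle a hypothetical feature on or off.',
        },
        'ExampleFeatureStrengthSlider': {
            'level': 2,
            'label': 'Strength',
            'min_value': '0',
            'max_value': '100',
            'default': '50',
            'step': 1,
            'parentToggle': 'ExampleFeatureEnableToggle',
            'requiredToggleValue': True,
            'help': 'Only meaningful while the toggle above is enabled.',
        },
    },
}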
-------------------------------------------------------------------------------- 1 | from app.ui.widgets.actions import control_actions 2 | import cv2 3 | from app.helpers.typing_helper import LayoutDictTypes 4 | SETTINGS_LAYOUT_DATA: LayoutDictTypes = { 5 | 'Appearance': { 6 | 'ThemeSelection': { 7 | 'level': 1, 8 | 'label': 'Theme', 9 | 'options': ['Dark', 'Dark-Blue', 'Light'], 10 | 'default': 'Dark', 11 | 'help': 'Select the theme to be used', 12 | 'exec_function': control_actions.change_theme, 13 | 'exec_function_args': [], 14 | }, 15 | }, 16 | 'General': { 17 | 'ProvidersPrioritySelection': { 18 | 'level': 1, 19 | 'label': 'Providers Priority', 20 | 'options': ['CUDA', 'TensorRT', 'TensorRT-Engine', 'CPU'], 21 | 'default': 'CUDA', 22 | 'help': 'Select the execution provider priority to be used by the system.', 23 | 'exec_function': control_actions.change_execution_provider, 24 | 'exec_function_args': [], 25 | }, 26 | 'nThreadsSlider': { 27 | 'level': 1, 28 | 'label': 'Number of Threads', 29 | 'min_value': '1', 30 | 'max_value': '30', 31 | 'default': '2', 32 | 'step': 1, 33 | 'help': 'Set the number of execution threads used while playing and recording. Depends strongly on GPU VRAM.', 34 | 'exec_function': control_actions.change_threads_number, 35 | 'exec_function_args': [], 36 | }, 37 | }, 38 | 'Video Settings': { 39 | 'VideoPlaybackCustomFpsToggle': { 40 | 'level': 1, 41 | 'label': 'Set Custom Video Playback FPS', 42 | 'default': False, 43 | 'help': 'Manually set the FPS to be used when playing the video', 44 | 'exec_function': control_actions.set_video_playback_fps, 45 | 'exec_function_args': [], 46 | }, 47 | 'VideoPlaybackCustomFpsSlider': { 48 | 'level': 2, 49 | 'label': 'Video Playback FPS', 50 | 'min_value': '1', 51 | 'max_value': '120', 52 | 'default': '30', 53 | 'parentToggle': 'VideoPlaybackCustomFpsToggle', 54 | 'requiredToggleValue': True, 55 | 'step': 1, 56 | 'help': 'Set the maximum FPS of the video when playing' 57 | }, 58 | }, 59 | 'Auto Swap':{ 60 | 'AutoSwapToggle': { 61 | 'level': 1, 62 | 'label': 'Auto Swap', 63 | 'default': False, 64 | 'help': 'Automatically swap all faces using the selected Source Faces/Embeddings when loading a video/image file' 65 | }, 66 | }, 67 | 'Detectors': { 68 | 'DetectorModelSelection': { 69 | 'level': 1, 70 | 'label': 'Face Detect Model', 71 | 'options': ['RetinaFace', 'Yolov8', 'SCRFD', 'Yunet'], 72 | 'default': 'RetinaFace', 73 | 'help': 'Select the face detection model to use for detecting faces in the input image or video.' 74 | }, 75 | 'DetectorScoreSlider': { 76 | 'level': 1, 77 | 'label': 'Detect Score', 78 | 'min_value': '1', 79 | 'max_value': '100', 80 | 'default': '50', 81 | 'step': 1, 82 | 'help': 'Set the confidence score threshold for face detection. Higher values ensure more confident detections but may miss some faces.' 83 | }, 84 | 'MaxFacesToDetectSlider': { 85 | 'level': 1, 86 | 'label': 'Max No of Faces to Detect', 87 | 'min_value': '1', 88 | 'max_value': '50', 89 | 'default': '20', 90 | 'step': 1, 91 | 'help': 'Set the maximum number of faces to detect in a frame' 92 | 93 | }, 94 | 'AutoRotationToggle': { 95 | 'level': 1, 96 | 'label': 'Auto Rotation', 97 | 'default': False, 98 | 'help': 'Automatically rotate the input to detect faces in various orientations.' 99 | }, 100 | 'ManualRotationEnableToggle': { 101 | 'level': 1, 102 | 'label': 'Manual Rotation', 103 | 'default': False, 104 | 'help': 'Rotate the face detector to better detect faces at different angles.' 
105 | }, 106 | 'ManualRotationAngleSlider': { 107 | 'level': 2, 108 | 'label': 'Rotation Angle', 109 | 'min_value': '0', 110 | 'max_value': '270', 111 | 'default': '0', 112 | 'step': 90, 113 | 'parentToggle': 'ManualRotationEnableToggle', 114 | 'requiredToggleValue': True, 115 | 'help': 'Set this to the angle of the input face to help with faces that are lying down, upside down, etc. Angles are read clockwise.' 116 | }, 117 | 'LandmarkDetectToggle': { 118 | 'level': 1, 119 | 'label': 'Enable Landmark Detection', 120 | 'default': False, 121 | 'help': 'Enable or disable facial landmark detection, which is used to refine face alignment.' 122 | }, 123 | 'LandmarkDetectModelSelection': { 124 | 'level': 2, 125 | 'label': 'Landmark Detect Model', 126 | 'options': ['5', '68', '3d68', '98', '106', '203', '478'], 127 | 'default': '203', 128 | 'parentToggle': 'LandmarkDetectToggle', 129 | 'requiredToggleValue': True, 130 | 'help': 'Select the landmark detection model, where different models detect varying numbers of facial landmarks.' 131 | }, 132 | 'LandmarkDetectScoreSlider': { 133 | 'level': 2, 134 | 'label': 'Landmark Detect Score', 135 | 'min_value': '1', 136 | 'max_value': '100', 137 | 'default': '50', 138 | 'step': 1, 139 | 'parentToggle': 'LandmarkDetectToggle', 140 | 'requiredToggleValue': True, 141 | 'help': 'Set the confidence score threshold for facial landmark detection.' 142 | }, 143 | 'DetectFromPointsToggle': { 144 | 'level': 2, 145 | 'label': 'Detect From Points', 146 | 'default': False, 147 | 'parentToggle': 'LandmarkDetectToggle', 148 | 'requiredToggleValue': True, 149 | 'help': 'Enable detection of faces from specified landmark points.' 150 | }, 151 | 'ShowLandmarksEnableToggle': { 152 | 'level': 1, 153 | 'label': 'Show Landmarks', 154 | 'default': False, 155 | 'help': 'Show landmarks in real time.' 156 | }, 157 | 'ShowAllDetectedFacesBBoxToggle': { 158 | 'level': 1, 159 | 'label': 'Show Bounding Boxes', 160 | 'default': False, 161 | 'help': 'Draw bounding boxes around all detected faces in the frame' 162 | } 163 | }, 164 | 'DFM Settings':{ 165 | 'MaxDFMModelsSlider':{ 166 | 'level': 1, 167 | 'label': 'Maximum DFM Models to use', 168 | 'min_value': '1', 169 | 'max_value': '5', 170 | 'default': '1', 171 | 'step': 1, 172 | 'help': "Set the maximum number of DFM Models to keep in memory at a time. Set this based on your GPU's VRAM", 173 | } 174 | }, 175 | 'Frame Enhancer':{ 176 | 'FrameEnhancerEnableToggle':{ 177 | 'level': 1, 178 | 'label': 'Enable Frame Enhancer', 179 | 'default': False, 180 | 'help': 'Enable frame enhancement for video inputs to improve visual quality.' 181 | }, 182 | 'FrameEnhancerTypeSelection':{ 183 | 'level': 2, 184 | 'label': 'Frame Enhancer Type', 185 | 'options': ['RealEsrgan-x2-Plus', 'RealEsrgan-x4-Plus', 'RealEsr-General-x4v3', 'BSRGan-x2', 'BSRGan-x4', 'UltraSharp-x4', 'UltraMix-x4', 'DDColor-Artistic', 'DDColor', 'DeOldify-Artistic', 'DeOldify-Stable', 'DeOldify-Video'], 186 | 'default': 'RealEsrgan-x2-Plus', 187 | 'parentToggle': 'FrameEnhancerEnableToggle', 188 | 'requiredToggleValue': True, 189 | 'help': 'Select the type of frame enhancement to apply, based on the content and resolution requirements.' 190 | }, 191 | 'FrameEnhancerBlendSlider': { 192 | 'level': 2, 193 | 'label': 'Blend', 194 | 'min_value': '0', 195 | 'max_value': '100', 196 | 'default': '100', 197 | 'step': 1, 198 | 'parentToggle': 'FrameEnhancerEnableToggle', 199 | 'requiredToggleValue': True, 200 | 'help': 'Blends the enhanced results back into the original frame.' 
201 | }, 202 | }, 203 | 'Webcam Settings': { 204 | 'WebcamMaxNoSelection': { 205 | 'level': 2, 206 | 'label': 'Webcam Max No', 207 | 'options': ['1', '2', '3', '4', '5', '6'], 208 | 'default': '1', 209 | 'help': 'Select the maximum number of webcam streams to allow for face swapping.' 210 | }, 211 | 'WebcamBackendSelection': { 212 | 'level': 2, 213 | 'label': 'Webcam Backend', 214 | 'options': ['Default', 'DirectShow', 'MSMF', 'V4L', 'V4L2', 'GSTREAMER'], 215 | 'default': 'Default', 216 | 'help': 'Choose the backend for accessing webcam input.' 217 | }, 218 | 'WebcamMaxResSelection': { 219 | 'level': 2, 220 | 'label': 'Webcam Resolution', 221 | 'options': ['480x360', '640x480', '1280x720', '1920x1080', '2560x1440', '3840x2160'], 222 | 'default': '1280x720', 223 | 'help': 'Select the maximum resolution for webcam input.' 224 | }, 225 | 'WebCamMaxFPSSelection': { 226 | 'level': 2, 227 | 'label': 'Webcam FPS', 228 | 'options': ['23', '30', '60'], 229 | 'default': '30', 230 | 'help': 'Set the maximum frames per second (FPS) for webcam input.' 231 | }, 232 | }, 233 | 'Virtual Camera': { 234 | 'SendVirtCamFramesEnableToggle': { 235 | 'level': 1, 236 | 'label': 'Send Frames to Virtual Camera', 237 | 'default': False, 238 | 'help': 'Send the swapped video/webcam output to a virtual camera for use in external applications', 239 | 'exec_function': control_actions.toggle_virtualcam, 240 | 'exec_function_args': [], 241 | }, 242 | 'VirtCamBackendSelection': { 243 | 'level': 1, 244 | 'label': 'Virtual Camera Backend', 245 | 'options': ['obs', 'unitycapture'], 246 | 'default': 'obs', 247 | 'help': 'Choose the backend based on the Virtual Camera you have set up', 248 | 'parentToggle': 'SendVirtCamFramesEnableToggle', 249 | 'requiredToggleValue': True, 250 | 'exec_function': control_actions.enable_virtualcam, 251 | 'exec_function_args': [], 252 | }, 253 | }, 254 | 'Face Recognition': { 255 | 'RecognitionModelSelection': { 256 | 'level': 1, 257 | 'label': 'Recognition Model', 258 | 'options': ['Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace', 'CSCSArcFace'], 259 | 'default': 'Inswapper128ArcFace', 260 | 'help': 'Choose the ArcFace model to be used for comparing the similarity of faces.' 261 | }, 262 | 'SimilarityTypeSelection': { 263 | 'level': 1, 264 | 'label': 'Swapping Similarity Type', 265 | 'options': ['Opal', 'Pearl', 'Optimal'], 266 | 'default': 'Opal', 267 | 'help': 'Choose the type of similarity calculation for face detection and matching during the face swapping process.' 268 | }, 269 | }, 270 | 'Embedding Merge Method':{ 271 | 'EmbMergeMethodSelection':{ 272 | 'level': 1, 273 | 'label': 'Embedding Merge Method', 274 | 'options': ['Mean','Median'], 275 | 'default': 'Mean', 276 | 'help': 'Select the method to merge facial embeddings. "Mean" averages the embeddings, while "Median" selects the middle value, providing more robustness to outliers.' 
277 | } 278 | }, 279 | 'Media Selection':{ 280 | 'TargetMediaFolderRecursiveToggle':{ 281 | 'level': 1, 282 | 'label': 'Target Media Include Subfolders', 283 | 'default': False, 284 | 'help': 'Include all files from Subfolders when choosing Target Media Folder' 285 | }, 286 | 'InputFacesFolderRecursiveToggle':{ 287 | 'level': 1, 288 | 'label': 'Input Faces Include Subfolders', 289 | 'default': False, 290 | 'help': 'Include all files from Subfolders when choosing Input Faces Folder' 291 | } 292 | } 293 | } 294 | 295 | CAMERA_BACKENDS = { 296 | 'Default': cv2.CAP_ANY, 297 | 'DirectShow': cv2.CAP_DSHOW, 298 | 'MSMF': cv2.CAP_MSMF, 299 | 'V4L': cv2.CAP_V4L, 300 | 'V4L2': cv2.CAP_V4L2, 301 | 'GSTREAMER': cv2.CAP_GSTREAMER, 302 | } -------------------------------------------------------------------------------- /app/ui/widgets/ui_workers.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from functools import partial 3 | from typing import TYPE_CHECKING, Dict 4 | import traceback 5 | import os 6 | 7 | import cv2 8 | import torch 9 | import numpy 10 | from PySide6 import QtCore as qtc 11 | from PySide6.QtGui import QPixmap 12 | 13 | from app.processors.models_data import detection_model_mapping, landmark_model_mapping 14 | from app.helpers import miscellaneous as misc_helpers 15 | from app.ui.widgets.actions import common_actions as common_widget_actions 16 | from app.ui.widgets.actions import filter_actions 17 | from app.ui.widgets.settings_layout_data import SETTINGS_LAYOUT_DATA, CAMERA_BACKENDS 18 | 19 | if TYPE_CHECKING: 20 | from app.ui.main_ui import MainWindow 21 | 22 | class TargetMediaLoaderWorker(qtc.QThread): 23 | # Define signals to emit when loading is done or if there are updates 24 | thumbnail_ready = qtc.Signal(str, QPixmap, str, str) # Signal with media path and QPixmap and file_type, media_id 25 | webcam_thumbnail_ready = qtc.Signal(str, QPixmap, str, str, int, int) 26 | finished = qtc.Signal() # Signal to indicate completion 27 | 28 | def __init__(self, main_window: 'MainWindow', folder_name=False, files_list=None, media_ids=None, webcam_mode=False, parent=None,): 29 | super().__init__(parent) 30 | self.main_window = main_window 31 | self.folder_name = folder_name 32 | self.files_list = files_list or [] 33 | self.media_ids = media_ids or [] 34 | self.webcam_mode = webcam_mode 35 | self._running = True # Flag to control the running state 36 | 37 | # Ensure thumbnail directory exists 38 | misc_helpers.ensure_thumbnail_dir() 39 | 40 | def run(self): 41 | if self.folder_name: 42 | self.load_videos_and_images_from_folder(self.folder_name) 43 | if self.files_list: 44 | self.load_videos_and_images_from_files_list(self.files_list) 45 | if self.webcam_mode: 46 | self.load_webcams() 47 | self.finished.emit() 48 | 49 | def load_videos_and_images_from_folder(self, folder_name): 50 | # Initially hide the placeholder text 51 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 52 | video_files = misc_helpers.get_video_files(folder_name, self.main_window.control['TargetMediaFolderRecursiveToggle']) 53 | image_files = misc_helpers.get_image_files(folder_name, self.main_window.control['TargetMediaFolderRecursiveToggle']) 54 | 55 | i=0 56 | media_files = video_files + image_files 57 | for media_file in media_files: 58 | if not self._running: # Check if the thread is still running 59 | break 60 | media_file_path = os.path.join(folder_name, media_file) 61 | file_type = 
misc_helpers.get_file_type(media_file_path) 62 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path, file_type) 63 | if self.media_ids: 64 | media_id = self.media_ids[i] 65 | else: 66 | media_id = str(uuid.uuid1().int) 67 | if pixmap: 68 | # Emit the signal to update GUI 69 | self.thumbnail_ready.emit(media_file_path, pixmap, file_type, media_id) 70 | i+=1 71 | # Show/Hide the placeholder text based on the number of items in ListWidget 72 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 73 | 74 | def load_videos_and_images_from_files_list(self, files_list): 75 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 76 | media_files = files_list 77 | i=0 78 | for media_file_path in media_files: 79 | if not self._running: # Check if the thread is still running 80 | break 81 | file_type = misc_helpers.get_file_type(media_file_path) 82 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path, file_type=file_type) 83 | if self.media_ids: 84 | media_id = self.media_ids[i] 85 | else: 86 | media_id = str(uuid.uuid1().int) 87 | if pixmap: 88 | # Emit the signal to update GUI 89 | self.thumbnail_ready.emit(media_file_path, pixmap, file_type,media_id) 90 | i+=1 91 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 92 | 93 | def load_webcams(self,): 94 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, True) 95 | camera_backend = CAMERA_BACKENDS[self.main_window.control['WebcamBackendSelection']] 96 | for i in range(int(self.main_window.control['WebcamMaxNoSelection'])): 97 | try: 98 | pixmap = common_widget_actions.extract_frame_as_pixmap(media_file_path=f'Webcam {i}', file_type='webcam', webcam_index=i, webcam_backend=camera_backend) 99 | media_id = str(uuid.uuid1().int) 100 | 101 | if pixmap: 102 | # Emit the signal to update GUI 103 | self.webcam_thumbnail_ready.emit(f'Webcam {i}', pixmap, 'webcam',media_id, i, camera_backend) 104 | except Exception: # pylint: disable=broad-exception-caught 105 | traceback.print_exc() 106 | self.main_window.placeholder_update_signal.emit(self.main_window.targetVideosList, False) 107 | 108 | def stop(self): 109 | """Stop the thread by setting the running flag to False.""" 110 | self._running = False 111 | self.wait() 112 | 113 | class InputFacesLoaderWorker(qtc.QThread): 114 | # Define signals to emit when loading is done or if there are updates 115 | thumbnail_ready = qtc.Signal(str, numpy.ndarray, object, QPixmap, str) 116 | finished = qtc.Signal() # Signal to indicate completion 117 | def __init__(self, main_window: 'MainWindow', media_path=False, folder_name=False, files_list=None, face_ids=None, parent=None): 118 | super().__init__(parent) 119 | self.main_window = main_window 120 | self.folder_name = folder_name 121 | self.files_list = files_list or [] 122 | self.face_ids = face_ids or [] 123 | self._running = True # Flag to control the running state 124 | self.was_playing = True 125 | self.pre_load_detection_recognition_models() 126 | 127 | def pre_load_detection_recognition_models(self): 128 | control = self.main_window.control.copy() 129 | detect_model = detection_model_mapping[control['DetectorModelSelection']] 130 | landmark_detect_model = landmark_model_mapping[control['LandmarkDetectModelSelection']] 131 | models_processor = self.main_window.models_processor 132 | if self.main_window.video_processor.processing: 133 | was_playing = True 134 | 
self.main_window.buttonMediaPlay.click() 135 | else: 136 | was_playing = False 137 | if not models_processor.models[detect_model]: 138 | models_processor.models[detect_model] = models_processor.load_model(detect_model) 139 | if not models_processor.models[landmark_detect_model] and control['LandmarkDetectToggle']: 140 | models_processor.models[landmark_detect_model] = models_processor.load_model(landmark_detect_model) 141 | for recognition_model in ['Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace', 'CSCSArcFace', 'CSCSIDArcFace']: 142 | if not models_processor.models[recognition_model]: 143 | models_processor.models[recognition_model] = models_processor.load_model(recognition_model) 144 | if was_playing: 145 | self.main_window.buttonMediaPlay.click() 146 | 147 | def run(self): 148 | if self.folder_name or self.files_list: 149 | self.main_window.placeholder_update_signal.emit(self.main_window.inputFacesList, True) 150 | self.load_faces(self.folder_name, self.files_list) 151 | self.main_window.placeholder_update_signal.emit(self.main_window.inputFacesList, False) 152 | 153 | def load_faces(self, folder_name=False, files_list=None): 154 | control = self.main_window.control.copy() 155 | files_list = files_list or [] 156 | image_files = [] 157 | if folder_name: 158 | image_files = misc_helpers.get_image_files(self.folder_name, self.main_window.control['InputFacesFolderRecursiveToggle']) 159 | elif files_list: 160 | image_files = files_list 161 | 162 | i=0 163 | image_files.sort() 164 | for image_file_path in image_files: 165 | if not self._running: # Check if the thread is still running 166 | break 167 | if not misc_helpers.is_image_file(image_file_path): 168 | return 169 | if folder_name: 170 | image_file_path = os.path.join(folder_name, image_file_path) 171 | frame = misc_helpers.read_image_file(image_file_path) 172 | if frame is None: 173 | continue 174 | # Frame must be in RGB format 175 | frame = frame[..., ::-1] # Swap the channels from BGR to RGB 176 | 177 | img = torch.from_numpy(frame.astype('uint8')).to(self.main_window.models_processor.device) 178 | img = img.permute(2,0,1) 179 | _, kpss_5, _ = self.main_window.models_processor.run_detect(img, control['DetectorModelSelection'], max_num=1, score=control['DetectorScoreSlider']/100.0, input_size=(512, 512), use_landmark_detection=control['LandmarkDetectToggle'], landmark_detect_mode=control['LandmarkDetectModelSelection'], landmark_score=control["LandmarkDetectScoreSlider"]/100.0, from_points=control["DetectFromPointsToggle"], rotation_angles=[0] if not control["AutoRotationToggle"] else [0, 90, 180, 270]) 180 | 181 | # If at least one face is found 182 | # found_face = [] 183 | face_kps = False 184 | try: 185 | face_kps = kpss_5[0] 186 | except IndexError: 187 | continue 188 | if face_kps.any(): 189 | face_emb, cropped_img = self.main_window.models_processor.run_recognize_direct(img, face_kps, control['SimilarityTypeSelection'], control['RecognitionModelSelection']) 190 | cropped_img = cropped_img.cpu().numpy() 191 | cropped_img = cropped_img[..., ::-1] # Swap the channels from RGB to BGR 192 | face_img = numpy.ascontiguousarray(cropped_img) 193 | # crop = cv2.resize(face[2].cpu().numpy(), (82, 82)) 194 | pixmap = common_widget_actions.get_pixmap_from_frame(self.main_window, face_img) 195 | 196 | embedding_store: Dict[str, numpy.ndarray] = {} 197 | # Get the values of 'options' 198 | options = SETTINGS_LAYOUT_DATA['Face Recognition']['RecognitionModelSelection']['options'] 199 | for option in options: 200 | if option != 
control['RecognitionModelSelection']: 201 | target_emb, _ = self.main_window.models_processor.run_recognize_direct(img, face_kps, control['SimilarityTypeSelection'], option) 202 | embedding_store[option] = target_emb 203 | else: 204 | embedding_store[control['RecognitionModelSelection']] = face_emb 205 | if not self.face_ids: 206 | face_id = str(uuid.uuid1().int) 207 | else: 208 | face_id = self.face_ids[i] 209 | self.thumbnail_ready.emit(image_file_path, face_img, embedding_store, pixmap, face_id) 210 | i+=1 211 | torch.cuda.empty_cache() 212 | self.finished.emit() 213 | 214 | def stop(self): 215 | """Stop the thread by setting the running flag to False.""" 216 | self._running = False 217 | self.wait() 218 | 219 | class FilterWorker(qtc.QThread): 220 | filtered_results = qtc.Signal(list) 221 | 222 | def __init__(self, main_window: 'MainWindow', search_text='', filter_list='target_videos'): 223 | super().__init__() 224 | self.main_window = main_window 225 | self.search_text = search_text 226 | self.filter_list = filter_list 227 | self.filter_list_widget = self.get_list_widget() 228 | self.filtered_results.connect(partial(filter_actions.update_filtered_list, main_window, self.filter_list_widget)) 229 | 230 | def get_list_widget(self,): 231 | list_widget = False 232 | if self.filter_list == 'target_videos': 233 | list_widget = self.main_window.targetVideosList 234 | elif self.filter_list == 'input_faces': 235 | list_widget = self.main_window.inputFacesList 236 | elif self.filter_list == 'merged_embeddings': 237 | list_widget = self.main_window.inputEmbeddingsList 238 | return list_widget 239 | 240 | def run(self,): 241 | if self.filter_list == 'target_videos': 242 | self.filter_target_videos(self.main_window, self.search_text) 243 | elif self.filter_list == 'input_faces': 244 | self.filter_input_faces(self.main_window, self.search_text) 245 | elif self.filter_list == 'merged_embeddings': 246 | self.filter_merged_embeddings(self.main_window, self.search_text) 247 | 248 | 249 | def filter_target_videos(self, main_window: 'MainWindow', search_text: str = ''): 250 | search_text = main_window.targetVideosSearchBox.text().lower() 251 | include_file_types = [] 252 | if main_window.filterImagesCheckBox.isChecked(): 253 | include_file_types.append('image') 254 | if main_window.filterVideosCheckBox.isChecked(): 255 | include_file_types.append('video') 256 | if main_window.filterWebcamsCheckBox.isChecked(): 257 | include_file_types.append('webcam') 258 | 259 | visible_indices = [] 260 | for i in range(main_window.targetVideosList.count()): 261 | item = main_window.targetVideosList.item(i) 262 | item_widget = main_window.targetVideosList.itemWidget(item) 263 | if ((not search_text or search_text in item_widget.media_path.lower()) and 264 | (item_widget.file_type in include_file_types)): 265 | visible_indices.append(i) 266 | 267 | self.filtered_results.emit(visible_indices) 268 | 269 | def filter_input_faces(self, main_window: 'MainWindow', search_text: str): 270 | search_text = search_text.lower() 271 | visible_indices = [] 272 | 273 | for i in range(main_window.inputFacesList.count()): 274 | item = main_window.inputFacesList.item(i) 275 | item_widget = main_window.inputFacesList.itemWidget(item) 276 | if not search_text or search_text in item_widget.media_path.lower(): 277 | visible_indices.append(i) 278 | 279 | self.filtered_results.emit(visible_indices) 280 | 281 | def filter_merged_embeddings(self, main_window: 'MainWindow', search_text: str): 282 | search_text = search_text.lower() 283 | 
visible_indices = [] 284 | 285 | for i in range(main_window.inputEmbeddingsList.count()): 286 | item = main_window.inputEmbeddingsList.item(i) 287 | item_widget = main_window.inputEmbeddingsList.itemWidget(item) 288 | if not search_text or search_text in item_widget.embedding_name.lower(): 289 | visible_indices.append(i) 290 | 291 | self.filtered_results.emit(visible_indices) 292 | 293 | def stop_thread(self): 294 | self.quit() 295 | self.wait() 296 | -------------------------------------------------------------------------------- /dependencies/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/dependencies/.gitkeep -------------------------------------------------------------------------------- /download_models.py: -------------------------------------------------------------------------------- 1 | from app.helpers.downloader import download_file 2 | from app.processors.models_data import models_list 3 | 4 | for model_data in models_list: 5 | download_file(model_data['model_name'], model_data['local_path'], model_data['hash'], model_data['url']) 6 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from app.ui import main_ui 2 | from PySide6 import QtWidgets 3 | import sys 4 | 5 | import qdarktheme 6 | from app.ui.core.proxy_style import ProxyStyle 7 | 8 | if __name__=="__main__": 9 | 10 | app = QtWidgets.QApplication(sys.argv) 11 | app.setStyle(ProxyStyle()) 12 | with open("app/ui/styles/dark_styles.qss", "r") as f: 13 | _style = f.read() 14 | _style = qdarktheme.load_stylesheet(custom_colors={"primary": "#4facc9"})+'\n'+_style 15 | app.setStyleSheet(_style) 16 | window = main_ui.MainWindow() 17 | window.show() 18 | app.exec() -------------------------------------------------------------------------------- /model_assets/dfm_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/dfm_models/.gitkeep -------------------------------------------------------------------------------- /model_assets/grid_sample_3d_plugin.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/grid_sample_3d_plugin.dll -------------------------------------------------------------------------------- /model_assets/libgrid_sample_3d_plugin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/libgrid_sample_3d_plugin.so -------------------------------------------------------------------------------- /model_assets/liveportrait_onnx/lip_array.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/liveportrait_onnx/lip_array.pkl -------------------------------------------------------------------------------- /model_assets/meanshape_68.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/visomaster/VisoMaster/18b5398b35cb7ebaa5f0f28d9025763da29b74c6/model_assets/meanshape_68.pkl -------------------------------------------------------------------------------- /requirements_cu118.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | 3 | numpy==1.26.4 4 | opencv-python==4.10.0.84 5 | scikit-image==0.21.0 6 | pillow==9.5.0 7 | onnx==1.16.1 8 | protobuf==4.23.2 9 | psutil==6.0.0 10 | onnxruntime-gpu==1.18.0 11 | packaging==24.1 12 | PySide6==6.8.2.1 13 | kornia 14 | torch==2.4.1+cu118 15 | torchvision==0.19.1+cu118 16 | torchaudio==2.4.1+cu118 17 | tqdm 18 | ftfy 19 | regex 20 | pyvirtualcam==0.11.1 21 | tensorrt-cu11==10.4.0 22 | numexpr 23 | onnxsim 24 | requests 25 | pyqt-toast-notification==1.3.2 26 | qdarkstyle 27 | pyqtdarktheme -------------------------------------------------------------------------------- /requirements_cu124.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu124 2 | 3 | numpy==1.26.4 4 | opencv-python==4.10.0.84 5 | scikit-image==0.21.0 6 | pillow==9.5.0 7 | onnx==1.16.1 8 | protobuf==4.23.2 9 | psutil==6.0.0 10 | onnxruntime-gpu==1.20.0 11 | packaging==24.1 12 | PySide6==6.8.2.1 13 | kornia 14 | torch==2.4.1+cu124 15 | torchvision==0.19.1+cu124 16 | torchaudio==2.4.1+cu124 17 | tensorrt==10.6.0 --extra-index-url https://pypi.nvidia.com 18 | tensorrt-cu12_libs==10.6.0 19 | tensorrt-cu12_bindings==10.6.0 20 | tqdm 21 | ftfy 22 | regex 23 | pyvirtualcam==0.11.1 24 | numexpr 25 | onnxsim 26 | requests 27 | pyqt-toast-notification==1.3.2 28 | qdarkstyle 29 | pyqtdarktheme 30 | -------------------------------------------------------------------------------- /scripts/setenv.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | :: Get the parent directory of the script location 4 | SET "VISO_ROOT=%~dp0" 5 | SET "VISO_ROOT=%VISO_ROOT:~0,-1%" 6 | FOR %%A IN ("%VISO_ROOT%\..") DO SET "VISO_ROOT=%%~fA" 7 | 8 | :: Define dependencies directory 9 | SET "DEPENDENCIES=%VISO_ROOT%\dependencies" 10 | 11 | SET "GIT_EXECUTABLE=%DEPENDENCIES%\git-portable\bin\git.exe" 12 | 13 | :: Define Python paths 14 | SET "PYTHON_PATH=%DEPENDENCIES%\Python" 15 | SET "PYTHON_SCRIPTS=%PYTHON_PATH%\Scripts" 16 | SET "PYTHON_EXECUTABLE=%PYTHON_PATH%\python.exe" 17 | SET "PYTHONW_EXECUTABLE=%PYTHON_PATH%\pythonw.exe" 18 | 19 | :: Define CUDA and TensorRT paths 20 | SET "CUDA_PATH=%DEPENDENCIES%\CUDA" 21 | SET "CUDA_BIN_PATH=%CUDA_PATH%\bin" 22 | SET "TENSORRT_PATH=%DEPENDENCIES%\TensorRt\lib" 23 | 24 | :: Define FFMPEG path correctly 25 | SET "FFMPEG_PATH=%DEPENDENCIES%" 26 | 27 | :: Add all necessary paths to system PATH 28 | SET "PATH=%FFMPEG_PATH%;%PYTHON_PATH%;%PYTHON_SCRIPTS%;%CUDA_BIN_PATH%;%TENSORRT_PATH%;%PATH%" 29 | -------------------------------------------------------------------------------- /scripts/update_cu118.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | call scripts\setenv.bat 3 | "%GIT_EXECUTABLE%" fetch origin main 4 | "%GIT_EXECUTABLE%" reset --hard origin/main 5 | "%PYTHON_EXECUTABLE%" -m pip install -r requirements_cu118.txt --default-timeout 100 6 | "%PYTHON_EXECUTABLE%" download_models.py -------------------------------------------------------------------------------- /scripts/update_cu124.bat: 
-------------------------------------------------------------------------------- 1 | @echo off 2 | call scripts\setenv.bat 3 | "%GIT_EXECUTABLE%" fetch origin main 4 | "%GIT_EXECUTABLE%" reset --hard origin/main 5 | "%PYTHON_EXECUTABLE%" -m pip install -r requirements_cu124.txt --default-timeout 100 6 | "%PYTHON_EXECUTABLE%" download_models.py -------------------------------------------------------------------------------- /tools/convert_old_rope_embeddings.py: -------------------------------------------------------------------------------- 1 | # Script Usage Example 2 | # 'python3 convert_old_rope_embeddings.py old_merged_embeddings.txt --output_embeddings_file new_merged_embeddings.json' 3 | 4 | import json 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser("Rope Embeddings Converter") 8 | parser.add_argument("old_embeddings_file", help="Old Embeddings File", type=str) 9 | parser.add_argument("--output_embeddings_file", help="New Embeddings File", type=str) 10 | parser.add_argument("--recognizer_model", help="Face Recognizer Model with which the embedding was created", default='Inswapper128ArcFace', choices=('Inswapper128ArcFace', 'SimSwapArcFace', 'GhostArcFace'), type=str) 11 | args = parser.parse_args() 12 | input_filename = args.old_embeddings_file 13 | 14 | output_filename = args.output_embeddings_file or f'{input_filename.split(".")[0]}_converted.json' 15 | 16 | recognizer_model = args.recognizer_model 17 | temp0 = [] 18 | new_embed_list = [] 19 | with open(input_filename, "r") as embedfile: 20 | old_data = embedfile.read().splitlines() 21 | 22 | for i in range(0, len(old_data), 513): # Each record is one name line followed by 512 embedding values 23 | new_embed_data = {'name': old_data[i][6:], 'embedding_store': {recognizer_model: old_data[i+1:i+513]}} 24 | for j, val in enumerate(new_embed_data['embedding_store'][recognizer_model]): 25 | new_embed_data['embedding_store'][recognizer_model][j] = float(val) 26 | new_embed_list.append(new_embed_data) 27 | 28 | with open(output_filename, 'w') as embed_file: 29 | embeds_data = json.dumps(new_embed_list) 30 | embed_file.write(embeds_data) --------------------------------------------------------------------------------
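Note: a minimal sketch of reading back the converted embeddings file, for reference only and not part of the repository. Assumptions: the input was 'old_merged_embeddings.txt', so the script's default output path is 'old_merged_embeddings_converted.json', and the default --recognizer_model of 'Inswapper128ArcFace' was used.

import json

# Load the JSON written by tools/convert_old_rope_embeddings.py
with open('old_merged_embeddings_converted.json', 'r') as embed_file:
    converted = json.load(embed_file)

# Each entry stores a name and an 'embedding_store' dict keyed by the recognizer model,
# containing the 512 float values of one merged embedding.
for entry in converted:
    vector = entry['embedding_store']['Inswapper128ArcFace']
    print(entry['name'], len(vector))  # expected length: 512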