├── !del_txt.bat ├── 2tag ├── may.jpg └── march.jpg ├── update_pip.cmd ├── convert_to_bnb_nf4.bat ├── version_checker.bat ├── fix_minicpmo.bat ├── reinstall.bat ├── requirements_ver.txt ├── .gitignore ├── update.cmd ├── install.bat ├── fix_minicpmo.py ├── utils.py ├── convert_to_bnb_nf4.py ├── version_checker.py ├── arg_parser.py ├── batch_processing.bat ├── requirements.txt ├── ide-cap-chan.py ├── README.md ├── image_processor.py ├── LICENSE └── model_handlers.py /!del_txt.bat: -------------------------------------------------------------------------------- 1 | del .\2tag\*.txt -------------------------------------------------------------------------------- /2tag/may.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2dameneko/ide-cap-chan/HEAD/2tag/may.jpg -------------------------------------------------------------------------------- /2tag/march.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2dameneko/ide-cap-chan/HEAD/2tag/march.jpg -------------------------------------------------------------------------------- /update_pip.cmd: -------------------------------------------------------------------------------- 1 | call .\venv\Scripts\activate.bat 2 | python.exe -m pip install --upgrade pip -------------------------------------------------------------------------------- /convert_to_bnb_nf4.bat: -------------------------------------------------------------------------------- 1 | call .\venv\Scripts\activate.bat 2 | rem pip install --upgrade bitsandbytes 3 | python "convert_to_bnb_nf4.py" "ToriiGate-v0.3" -------------------------------------------------------------------------------- /version_checker.bat: -------------------------------------------------------------------------------- 1 | python "version_checker.py" 2 | rem dual call for in/out venv dependencies check 3 | call .\venv\Scripts\activate.bat 4 | rem pip show transformers 5 | rem pip list 6 | python "version_checker.py" -------------------------------------------------------------------------------- /fix_minicpmo.bat: -------------------------------------------------------------------------------- 1 | REM Run if got an error: 2 | REM No such file or directory: 'C:\\Users\\username\\.cache\\huggingface\\modules\\transformers_modules\\MiniCPM-o-2_6\\image_processing_minicpmv.py' 3 | python "fix_minicpmo.py" -------------------------------------------------------------------------------- /reinstall.bat: -------------------------------------------------------------------------------- 1 | python -m venv --system-site-packages venv 2 | call venv\Scripts\activate 3 | rem pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 4 | pip install -r requirements.txt 5 | pause 6 | -------------------------------------------------------------------------------- /requirements_ver.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub==0.28.1 2 | einops==0.8.1 3 | accelerate==1.4.0 4 | bitsandbytes==0.45.2 5 | qwen-vl-utils==0.0.10 6 | vector-quantize-pytorch==1.21.8 7 | vocos==0.1.0 8 | soundfile==0.13.1 9 | librosa==0.10.2.post1 10 | exllamav2==0.2.7+cu121.torch2.5.0 11 | flash_attn==2.7.0.post2 12 | typing_extensions==4.12.2 13 | transformers==4.48.3 14 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Ignore all folders 2 | */ 3 | 4 | # Except these 5 | !2tag/march.jpg 6 | !2tag/may.jpg 7 | !2tag/march.txt 8 | !2tag/may.txt 9 | 10 | # Ignore Python bytecode files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # Ignore virtual environment directories 16 | venv/ 17 | .venv/ 18 | env/ 19 | .env/ 20 | ENV/ 21 | env.bak/ 22 | venv.bak/ 23 | 24 | # Ignore IDE and editor files 25 | .vscode/ 26 | .idea/ 27 | *.swp 28 | *.swo 29 | 30 | # Ignore system files 31 | .DS_Store 32 | Thumbs.db 33 | -------------------------------------------------------------------------------- /update.cmd: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Pull the latest changes from the repository 4 | git pull 5 | 6 | REM Set default values for environment variables if not already defined 7 | if not defined PYTHON (set PYTHON=python) 8 | if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%") 9 | if not defined VENV_DIR (set "VENV_DIR=%~dp0venv") 10 | 11 | REM Check if the virtual environment exists 12 | if exist "%VENV_DIR%\Scripts\Python.exe" ( 13 |     REM Activate the virtual environment 14 |     call "%VENV_DIR%\Scripts\activate.bat" 15 |     set PYTHON="%VENV_DIR%\Scripts\Python.exe" 16 |     echo Virtual environment activated: %PYTHON% 17 | ) else ( 18 |     echo Virtual environment not found at %VENV_DIR%. Please create it first. 19 |     exit /b 1 20 | ) 21 | 22 | REM Install project requirements 23 | pip install -r requirements.txt 24 | 25 | REM End of script 26 | echo Update complete. -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Set the directory for the virtual environment 4 | set "VENV_DIR=%~dp0venv" 5 | 6 | REM Check if the virtual environment already exists 7 | if exist "%VENV_DIR%\Scripts\Python.exe" ( 8 |     echo Virtual environment already found at %VENV_DIR%. 9 |     exit /b 1 10 | ) else ( 11 |     REM Create the virtual environment 12 |     python -m venv --system-site-packages venv 13 |     goto :activate 14 | ) 15 | 16 | :activate 17 | REM Activate the virtual environment 18 | call "%VENV_DIR%\Scripts\activate.bat" 19 | 20 | REM Upgrade pip 21 | python -m pip install --upgrade pip 22 | 23 | REM Install PyTorch and related packages 24 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 25 | 26 | REM Install project requirements 27 | pip install -r requirements.txt 28 | 29 | REM End of script 30 | echo Virtual environment setup complete.
-------------------------------------------------------------------------------- /fix_minicpmo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def copy_image_processing_file(): 5 |     # Get current Windows username 6 |     username = os.getenv('USERNAME') 7 | 8 |     # Construct paths 9 |     source_file = os.path.join('MiniCPM-o-2_6', 'image_processing_minicpmv.py') 10 |     destination_dir = os.path.join( 11 |         'C:\\', 12 |         'Users',  # user profiles live under C:\Users, matching the path in the error message in fix_minicpmo.bat 13 |         username, 14 |         '.cache', 15 |         'huggingface', 16 |         'modules', 17 |         'transformers_modules', 18 |         'MiniCPM-o-2_6' 19 |     ) 20 | 21 |     # Check if source file exists 22 |     if not os.path.exists(source_file): 23 |         raise FileNotFoundError(f"Source file '{source_file}' not found") 24 | 25 |     # Create destination directory if it doesn't exist 26 |     os.makedirs(destination_dir, exist_ok=True) 27 | 28 |     # Perform file copy 29 |     shutil.copy(source_file, destination_dir) 30 |     print(f"File successfully copied to: {destination_dir}") 31 | 32 | if __name__ == "__main__": 33 |     copy_image_processing_file() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from PIL import Image 4 | 5 | GPU_TEST_ITERATIONS = 4000 6 | GPU_TEST_SIZE = 1000 7 | 8 | def measure_gpu_speed(device): 9 |     """ 10 |     Measure the speed of a GPU by performing matrix operations. 11 | 12 |     Args: 13 |         device: The CUDA device to measure 14 | 15 |     Returns: 16 |         float: A score representing the relative speed of the GPU 17 |     """ 18 |     dummy_tensor = torch.randn(GPU_TEST_SIZE, GPU_TEST_SIZE).to(device) 19 |     start_time = time.time() 20 |     for _ in range(GPU_TEST_ITERATIONS): 21 |         _ = dummy_tensor @ dummy_tensor 22 |     torch.cuda.synchronize(device)  # matmuls are queued asynchronously; wait for them before stopping the clock 23 |     return 1 / (time.time() - start_time) 24 | 25 | def resize_image_proportionally(image, max_width=None, max_height=None): 26 |     """ 27 |     Resize an image proportionally to fit within the specified dimensions.
28 | 29 | Args: 30 | image: PIL Image to resize 31 | max_width: Maximum width 32 | max_height: Maximum height 33 | 34 | Returns: 35 | PIL Image: Resized image 36 | """ 37 | if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0): 38 | return image 39 | 40 | original_width, original_height = image.size 41 | 42 | if ((max_width is None or original_width <= max_width) and 43 | (max_height is None or original_height <= max_height)): 44 | return image 45 | 46 | if max_width and max_height: 47 | width_ratio = max_width / original_width 48 | height_ratio = max_height / original_height 49 | ratio = min(width_ratio, height_ratio) 50 | elif max_width: 51 | ratio = max_width / original_width 52 | else: 53 | ratio = max_height / original_height 54 | 55 | new_width = int(original_width * ratio) 56 | new_height = int(original_height * ratio) 57 | 58 | resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 59 | return resized_image 60 | -------------------------------------------------------------------------------- /convert_to_bnb_nf4.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | from transformers import AutoModel, AutoModelForVision2Seq, AutoTokenizer, BitsAndBytesConfig 5 | import torch 6 | import sys 7 | 8 | # Define the quantization configuration 9 | quantization_config = BitsAndBytesConfig( 10 | load_in_4bit=True, 11 | bnb_4bit_quant_type="nf4", 12 | bnb_4bit_use_double_quant=True, 13 | bnb_4bit_compute_dtype=torch.bfloat16, 14 | ) 15 | 16 | # Get model name from command line arguments 17 | model_name = sys.argv[1] 18 | 19 | # Load the model and tokenizer with the quantization configuration 20 | model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, trust_remote_code = True, low_cpu_mem_usage = True) 21 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 22 | 23 | # Print the original model size 24 | original_size = sum(p.numel() for p in model.parameters()) 25 | print(f"Original model size: {original_size / 1e6:.2f} million parameters") 26 | 27 | # Create new model name with suffix 28 | new_model_name = model_name + '-nf4' 29 | 30 | # Save the quantized model with new name 31 | model.save_pretrained(new_model_name) 32 | tokenizer.save_pretrained(new_model_name) 33 | 34 | # Modify config.json to remove the specified lines 35 | with open(os.path.join(new_model_name, 'config.json'), 'r') as file: 36 | config_dict = json.load(file) 37 | 38 | # Remove the specified fields 39 | config_dict["quantization_config"].pop("_load_in_4bit", None) 40 | config_dict["quantization_config"].pop("_load_in_8bit", None) 41 | config_dict["quantization_config"].pop("quant_method", None) 42 | 43 | # Write the modified config back to the file 44 | with open(os.path.join(new_model_name, 'config.json'), 'w') as file: 45 | json.dump(config_dict, file, indent=4) 46 | 47 | # Copy preprocessor_config.json and chat_template.json from input model to new model 48 | shutil.copyfile(os.path.join(model_name, 'preprocessor_config.json'), os.path.join(new_model_name, 'preprocessor_config.json')) 49 | shutil.copyfile(os.path.join(model_name, 'chat_template.json'), os.path.join(new_model_name, 'chat_template.json')) 50 | 51 | print(f"Model has been quantized and saved as {new_model_name}. 
Config.json has been modified.") -------------------------------------------------------------------------------- /version_checker.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pkg_resources 3 | from collections import OrderedDict 4 | 5 | def get_installed_versions(): 6 | # Regular expression to extract package names from requirements.txt lines 7 | package_pattern = re.compile(r'^([a-zA-Z0-9_-]+)') 8 | 9 | installed_versions = OrderedDict() 10 | missing_packages = [] 11 | 12 | try: 13 | # Read existing requirements_ver.txt if it exists 14 | try: 15 | with open('requirements_ver.txt', 'r') as f: 16 | for line in f: 17 | line = line.strip() 18 | if line and '==' in line: 19 | package_name, version = line.split('==') 20 | installed_versions[package_name] = version 21 | except FileNotFoundError: 22 | pass # If file doesn't exist, proceed without loading 23 | 24 | with open('requirements.txt', 'r') as f: 25 | for line in f: 26 | line = line.strip() 27 | 28 | # Skip empty lines and comments 29 | if not line or line.startswith('#'): 30 | continue 31 | 32 | # Extract package name from the line 33 | match = package_pattern.match(line) 34 | if not match: 35 | continue # Skip lines that don't match package pattern 36 | 37 | package_name = match.group(1) 38 | 39 | try: 40 | # Get installed version 41 | version = pkg_resources.get_distribution(package_name).version 42 | installed_versions[package_name] = version 43 | except pkg_resources.DistributionNotFound: 44 | if package_name not in installed_versions: 45 | missing_packages.append(package_name) 46 | 47 | except FileNotFoundError: 48 | print("Error: requirements.txt file not found") 49 | return 50 | 51 | # Write results to requirements_ver.txt 52 | with open('requirements_ver.txt', 'w') as f: 53 | for package_name, version in installed_versions.items(): 54 | f.write(f"{package_name}=={version}\n") 55 | 56 | # Print summary 57 | print(f"Successfully wrote {len(installed_versions)} packages to requirements_ver.txt") 58 | if missing_packages: 59 | print(f"Missing packages: {', '.join(missing_packages)}") 60 | 61 | if __name__ == "__main__": 62 | get_installed_versions() 63 | -------------------------------------------------------------------------------- /arg_parser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import sys 3 | 4 | def parse_arguments(): 5 | parser = ArgumentParser(description='Generate captions for images') 6 | 7 | parser.add_argument('--model_path', type=str, default="", help='Path to the used model') 8 | parser.add_argument('--model_type', type=str, default="exllama2", 9 | help='Model type (supported architectures: idefics3, llava, joy-caption, molmo, qwen2vl, molmo72b, pixtral, exllama2, minicpmo, generic (You can try this option if your model is not listed as supported. No warranties.))') 10 | parser.add_argument('--input_dir', type=str, default="./2tag", help='Path to the folder containing images') 11 | parser.add_argument('--CUDA_VISIBLE_DEVICES', type=str, default="0", 12 | help='Comma-separated list of CUDA devices. WARNING: multi-GPU captioning can overload your power supply unit. 
Model molmo72b ignores this arg and requires 2x24GB GPU') 13 | parser.add_argument('--caption_suffix', type=str, default=".txt", help='File extension for generated caption files') 14 | parser.add_argument('--tags_suffix', type=str, default=".ttxt", help='File extension for existing image info file (like *booru tags, traits, characters, etc)') 15 | parser.add_argument('--caption_format', type=str, choices=['json', 'markdown', 'short', 'long', 'bbox'], default='long', 16 | help='Format of the generated captions (supported formats: json, markdown, short, long, bbox), (req. ToriiGate-family models)') 17 | 18 | bool_args = [ 19 | ('--add_tags', 'Use an additional file as existing *booru tags to enhance captioning, (req. ToriiGate-family models)'), 20 | ('--add_chars', 'Use an additional file as information about represented characters, (req. ToriiGate >= 0.4 model)'), 21 | ('--add_char_traits', 'Use an additional file as information about character traits, (req. ToriiGate >= 0.4 model)'), 22 | ('--add_info', 'Use an additional file as misc information about image, (req. ToriiGate >= 0.4 model)'), 23 | ('--no_chars', 'Do not add any characters to the output, (req. ToriiGate >= 0.4 model)'), 24 | ] 25 | 26 | for arg, help_text in bool_args: 27 | parser.add_argument(arg, default=False, action='store_true', help=help_text) 28 | 29 | args = parser.parse_args() 30 | 31 | check_mutually_exclusive( 32 | args, ['--add_tags', '--add_chars', '--add_char_traits', '--add_info'] 33 | ) 34 | 35 | check_mutually_exclusive( 36 | args, ['--add_chars', '--no_chars'] 37 | ) 38 | 39 | return args 40 | 41 | def check_mutually_exclusive(args, arg_names): 42 | args_list = [getattr(args, arg_name.replace('--', '')) for arg_name in arg_names] 43 | if sum(args_list) > 1: 44 | print(f"Error: Only one of the following arguments can be True at a time: {', '.join(arg_names)}") 45 | sys.exit(1) 46 | -------------------------------------------------------------------------------- /batch_processing.bat: -------------------------------------------------------------------------------- 1 | call .\venv\Scripts\activate.bat 2 | python "ide-cap-chan.py" 3 | 4 | rem Full command line args example: 5 | rem python "ide-cap-chan3.py" --model_path "model_name" --model_type "model_type" --input_dir "folder_with_images_to_tag" --CUDA_VISIBLE_DEVICES "0, 1" --caption_suffix ".txt" --dont_use_tags --tags_suffix ".ttxt" 6 | 7 | rem Local models 8 | rem python "ide-cap-chan.py" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw" 9 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3" --CUDA_VISIBLE_DEVICES "1" 10 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3-nf4" --CUDA_VISIBLE_DEVICES "0" 11 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3" --CUDA_VISIBLE_DEVICES "0" --add_tags --caption_suffix ".text" --tags_suffix ".txt" 12 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Minthy_ToriiGate-v0.4-2B" --CUDA_VISIBLE_DEVICES "0" 13 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Vikhr-2-VL-2b-Instruct-experimental" --CUDA_VISIBLE_DEVICES "0" 14 | rem python "ide-cap-chan.py" --model_type "minicpmo" --model_path "MiniCPM-o-2_6" --CUDA_VISIBLE_DEVICES "0" 15 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Minthy_ToriiGate-v0.4-7B" --CUDA_VISIBLE_DEVICES "0" --no_chars --add_tags --caption_format "markdown" --caption_suffix ".text" --tags_suffix ".txt" 16 | rem python "ide-cap-chan.py" 
--model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0" --no_chars --add_tags --caption_format "bbox" --caption_suffix ".text" --tags_suffix ".txt" 17 | rem python "ide-cap-chan.py" --model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0, 1" 18 | rem python "ide-cap-chan.py" --model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-7B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0" 19 | rem python "ide-cap-chan.py" --model_type "pixtral" --model_path "Ertugrul_Pixtral-12B-Captioner-Relaxed" --CUDA_VISIBLE_DEVICES "0" 20 | rem python "ide-cap-chan.py" --model_type "molmo" --model_path "ctranslate2-4you_molmo-7B-O-bnb-4bit" --CUDA_VISIBLE_DEVICES "0" 21 | rem python "ide-cap-chan.py" --model_type "molmo72b" --model_path "SeanScripts_Molmo-72B-0924-nf4" 22 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Ertugrul_Qwen2-VL-7B-Captioner-Relaxed" --CUDA_VISIBLE_DEVICES "0" 23 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "Idefics3-8B-Llama3" 24 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llava-v1.6-mistral-7b-hf-nf4" --CUDA_VISIBLE_DEVICES "0" 25 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llava-v1.6-mistral-7b-hf" --CUDA_VISIBLE_DEVICES "0" 26 | rem python "ide-cap-chan.py" --model_type "joy-caption" --model_path "llama-joycaption-alpha-two-hf-llava" --CUDA_VISIBLE_DEVICES "0" 27 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llama-joycaption-alpha-two-hf-llava" --CUDA_VISIBLE_DEVICES "0" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #git+https://github.com/huggingface/transformers #latest, minicpm_o broken (due Loggits deprecated) 2 | transformers==4.48.3 #last verion minicpmo compatible 3 | #transformers==4.44.2 #recommended for minicpm_o, but doesn't have support for idefics3 4 | accelerate==1.4.0 5 | huggingface_hub==0.28.1 6 | bitsandbytes==0.45.5 7 | typing_extensions==4.12.2 8 | einops==0.8.1 9 | qwen-vl-utils==0.0.10 10 | vector-quantize-pytorch==1.21.8 11 | vocos==0.1.0 12 | soundfile==0.13.1 13 | librosa==0.10.2.post1 14 | 15 | http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl 16 | 17 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" 18 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" 19 | 20 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" 21 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" 22 | 23 | exllamav2@ 
https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.12" 24 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.12" 25 | 26 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" 27 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" 28 | 29 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" 30 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" 31 | 32 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" 33 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" 34 | 35 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-win_amd64.whl; platform_system == "Windows" and python_version == "3.12" 36 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl; platform_system == "Windows" and python_version == "3.12" 37 | 38 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" 39 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" 40 | -------------------------------------------------------------------------------- /ide-cap-chan.py: -------------------------------------------------------------------------------- 1 | # ide-cap-chan v0.96 2 | from arg_parser import parse_arguments 3 | from utils import measure_gpu_speed 4 | from image_processor import process_image_worker 5 | import torch.multiprocessing as mp 6 | from os import walk 7 | from os.path import splitext as split_extension 8 | from pathlib import Path 9 | import queue 10 | import time 11 | 12 | def main(): 13 | args = parse_arguments() 14 | 15 | device_ids = list(map(int, args.CUDA_VISIBLE_DEVICES.split(','))) 16 | 17 | supported_model_types = ["idefics3", "llava", "joy-caption", "molmo", "qwen2vl", "molmo72b", "pixtral", "exllama2", "minicpmo", "generic"] 18 | input_model_type 
= args.model_type.lower() 19 | if input_model_type not in supported_model_types: 20 | print(f"Model type '{input_model_type}' not supported. Supported loaders: {', '.join(supported_model_types)}.") 21 | return 22 | 23 | model_name_or_path = args.model_path or { 24 | 'idefics3': "2dameneko/ToriiGate-v0.3-nf4", 25 | 'llava': "2dameneko/llava-v1.6-mistral-7b-hf-nf4", 26 | 'joy-caption': "fancyfeast/llama-joycaption-alpha-two-hf-llava", 27 | #'qwen2vl': "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed", 28 | #'qwen2vl': "Minthy/ToriiGate-v0.4-2B", 29 | 'qwen2vl': "Vikhrmodels/Vikhr-2-VL-2b-Instruct-experimental", 30 | 'molmo': "cyan2k/molmo-7B-O-bnb-4bit", 31 | 'molmo72b': "SeanScripts/Molmo-72B-0924-nf4", 32 | 'pixtral': "Ertugrul/Pixtral-12B-Captioner-Relaxed", 33 | 'exllama2': "Minthy/ToriiGate-v0.4-2B-exl2-8bpw", 34 | #'exllama2': "Minthy/ToriiGate-v0.4-7B-exl2-8bpw", 35 | 'minicpmo': "openbmb/MiniCPM-o-2_6", 36 | 'generic': None, 37 | }[input_model_type] 38 | 39 | quant_suffixes = ["nf4", "bnb", "4bit"] 40 | use_nf4 = any(suffix in model_name_or_path for suffix in quant_suffixes) 41 | 42 | if input_model_type == "joy-caption" and use_nf4: 43 | print(f"Model type '{input_model_type}' not supported with -nf4 quantization. Set to false.") 44 | use_nf4 = False 45 | 46 | args_dict = { 47 | 'use_nf4' : use_nf4, 48 | 'caption_suffix' : args.caption_suffix, 49 | 'tags_suffix' : args.tags_suffix, 50 | 'add_tags' : args.add_tags, 51 | 'add_chars': args.add_chars, 52 | 'add_char_traits' : args.add_char_traits, 53 | 'add_info' : args.add_info, 54 | 'no_chars' : args.no_chars, 55 | 'caption_format' : args.caption_format, 56 | } 57 | 58 | input_dir = args.input_dir 59 | image_extensions = [".jpg", ".png", ".webp", ".jpeg"] 60 | 61 | # Measure GPU speeds for informational purposes 62 | gpu_speeds = [(i, measure_gpu_speed(f"cuda:{i}")) for i in device_ids] 63 | 64 | print(f'Using GPU ids: {device_ids}') 65 | print("GPUs speeds:") 66 | for gpu_id, gpu_speed in gpu_speeds: 67 | print(f" {gpu_id} | {gpu_speed:.2f}") 68 | print(f'Using model: {model_name_or_path} (type: {input_model_type})') 69 | print(f'Use quantization: {args_dict.get("use_nf4")}') 70 | 71 | # Find existing captions to avoid reprocessing 72 | existing_captions = [] 73 | for root, _, files in walk(input_dir): 74 | for file in files: 75 | file_path = Path(root) / file 76 | if file_path.suffix.lower() == args_dict.get('caption_suffix'): 77 | path, _ = split_extension(str(file_path)) 78 | existing_captions.append(path) 79 | 80 | # Create a list of files to process 81 | filelist = [] 82 | for root, _, files in walk(input_dir): 83 | for file in files: 84 | file_path = Path(root) / file 85 | if file_path.suffix.lower() in image_extensions: 86 | path, _ = split_extension(str(file_path)) 87 | if path not in existing_captions: 88 | filelist.append(file_path) 89 | 90 | if not filelist: 91 | print('There are no files to process.') 92 | return 93 | 94 | # Create a shared job queue 95 | job_queue = mp.Queue() 96 | result_queue = mp.Queue() 97 | 98 | # Put all files in the job queue 99 | for file_path in filelist: 100 | job_queue.put(file_path) 101 | 102 | # Add termination signals (one for each worker) 103 | for _ in range(len(device_ids)): 104 | job_queue.put(None) 105 | 106 | # Create and start worker processes 107 | processes = [] 108 | for i, gpu_id in enumerate(device_ids): 109 | p = mp.Process( 110 | target=process_image_worker, 111 | args=( 112 | i, # worker_id 113 | gpu_id, # gpu_id 114 | job_queue, 115 | result_queue, 116 | model_name_or_path, 117 | 
input_model_type, 118 | args_dict, 119 | len(filelist) # total_files 120 | ) 121 | ) 122 | p.start() 123 | processes.append(p) 124 | 125 | # Monitor progress 126 | completed_files = 0 127 | total_files = len(filelist) 128 | start_time = time.time() 129 | 130 | while completed_files < total_files: 131 | try: 132 | # Get result from the result queue 133 | result = result_queue.get(timeout=1.0) 134 | if result is not None: 135 | worker_id, gpu_id, file_name, processing_time = result 136 | completed_files += 1 137 | 138 | # Calculate ETA 139 | elapsed = time.time() - start_time 140 | avg_time_per_file = elapsed / completed_files 141 | remaining_time = avg_time_per_file * (total_files - completed_files) 142 | 143 | print(f"Worker {worker_id} (GPU {gpu_id}): {completed_files}/{total_files} - {file_name} - {processing_time:.2f}s - ETA: {remaining_time:.2f}s") 144 | except queue.Empty: 145 | # No result available, just continue 146 | continue 147 | 148 | # Wait for all processes to finish 149 | for p in processes: 150 | p.join() 151 | 152 | print(f"All {total_files} files processed successfully.") 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ide-cap-chan 2 | 3 |
4 | Visitor count
5 |
6 | 7 | ide-cap-chan is a utility for batch captioning images with natural language using various Vision-Language (VL) models. 8 | 9 | ## Features 10 | * **High-speed processing**: Optimized for rapid batch caption generation with ExLlama2, Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct (Vikhr-family included), 11 | Idefics3-8B-Llama3, LLaVa-NeXT (LLaVa-1.6), Llama JoyCaption Alpha Two, Molmo-7B-O, Molmo-72B, MiniCPM-o-2_6 and Pixtral models 12 | * **Multi-GPU support**: Distribute workloads across multiple GPUs 13 | * **Efficient quantization**: Supports ExLlama2 (exl2), int8, and nf4 quantization for reduced VRAM usage 14 | * **Autoload strategies**: VRAM-optimized model loading, with automatic multi-GPU splitting where supported 15 | * **Model flexibility**: Use default or custom models via CLI arguments 16 | * **Input flexibility**: Supports Hugging Face, local, and external models 17 | * **Tag integration**: Enhance captions with existing tags/captions 18 | * **Process control**: Interrupt and resume captioning tasks 19 | * **Batch processing**: Recursively process subfolders in input directories 20 | 21 | ## Requirements 22 | * NVIDIA GPU with CUDA support (8GB VRAM minimum for llava, 12GB recommended for Qwen2-VL-7B in exl2, 48GB VRAM total for Molmo-72B) 23 | 24 | ## Installation 25 | 1. Clone the repository: 26 | `git clone https://github.com/2dameneko/ide-cap-chan` 27 | 2. Install dependencies: 28 | - **Windows**: Run `install.bat` 29 | - **Linux**: Create a virtual environment and install requirements: 30 | ```bash 31 | python -m venv venv 32 | source venv/bin/activate 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Usage 37 | 1. Place images and corresponding tag files in the input folder (default: `2tag`) 38 | 2. Start processing: 39 | - **Windows**: Run `batch_processing.bat` 40 | - **Linux**: Execute `python ide-cap-chan.py` 41 | 3. Specify alternative models using CLI arguments 42 | 4. Customize prompts by editing `system_prompt` and `user_prompt` in `image_processor.py` 43 | 44 | ## Updating 45 | - **Windows**: Run `update.cmd` 46 | 47 | ## Options 48 | Run without arguments for default behavior (see the Example invocation section below). Available CLI options (`python ide-cap-chan.py -h`): 49 | | Argument | Description | 50 | |----------|-------------| 51 | | `--model_path` | Path to model (Hugging Face, local, or external) | 52 | | `--model_type` | Model architecture/loader: idefics3, llava, joy-caption, molmo, qwen2vl, molmo72b, pixtral, exllama2, minicpmo, generic (default: `exllama2`) | 53 | | `--input_dir` | Input directory path (default: `2tag`) | 54 | | `--CUDA_VISIBLE_DEVICES` | Comma-separated GPU IDs (default: `0`). **Note**:
- Multi-GPU may strain your PSU
- `molmo72b` ignores this argument and auto-splits across GPUs | 55 | | `--caption_suffix` | Caption file extension (default: `.txt`) | 56 | | `--caption_format` | Output format: `json`, `markdown`, `short`, `long`, `bbox` (requires ToriiGate ≥0.4) | 57 | | `--add_tags` | Enhance captions with existing tag files (ToriiGate-family models); the tag file extension is set by `--tags_suffix` (default: `.ttxt`) | 58 | | `--add_chars` | Enhance captions with character information (requires ToriiGate ≥0.4) | 59 | | `--add_char_traits` | Enhance captions with character traits (requires ToriiGate ≥0.4) | 60 | | `--add_info` | Enhance captions with miscellaneous image info (requires ToriiGate ≥0.4) | 61 | | `--no_chars` | Do not add character names to the output (requires ToriiGate ≥0.4) | 62 | 63 | ## Supported File Formats 64 | `.jpg`, `.png`, `.webp`, `.jpeg` 65 | 66 | ## Version History 67 | * **0.96**: Moved to CUDA 12.8 and PyTorch 2.7, added support for Blackwell GPUs 68 | * **0.95**: Dynamic multi-GPU task queuing instead of splitting based on approximate GPU speed 69 | * **0.9**: Added MiniCPM-o-2_6 loader support, rewritten to a modular design, pinned dependency versions 70 | * **0.8**: Added ExLlama2 loader support (default), ToriiGate-v0.4 features, Molmo-72B auto-split 71 | * **0.7**: Added Molmo/Qwen2VL/Pixtral support, improved multi-GPU quant processing, code refactor 72 | * **0.6**: Internal code improvements 73 | * **0.5**: Added JoyCaption support, code refactor 74 | * **0.4**: Added LLaVA support, updated to PyTorch 2.5.1 75 | * **0.3**: Improved argument handling, fixed extension case sensitivity 76 | * **0.2**: 77 | - Multi-GPU support with load balancing 78 | - nf4 quantization 79 | - Fixed duplicate file filtering 80 | - Updated environment scripts 81 | * **0.1**: Initial release 82 | 83 | ## Note 84 | This project is a proof of concept and not production-ready.
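## Example invocation

A typical single-GPU run combining the options above might look like the following sketch; the model path shown is the repository default for the `exllama2` loader, and the folder and suffixes are the defaults from the Options table, so adjust them to your setup:

```bash
python ide-cap-chan.py --model_type "exllama2" --model_path "Minthy/ToriiGate-v0.4-2B-exl2-8bpw" --input_dir "./2tag" --CUDA_VISIBLE_DEVICES "0" --add_tags --tags_suffix ".ttxt" --caption_format "json" --caption_suffix ".txt"
```

On Windows, activate the virtual environment first (as `batch_processing.bat` does) and then run the same command.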
85 | 86 | ## License 87 | [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) 88 | 89 | ## Credits 90 | - Idefics3 Architecture: [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) 91 | - LLaVA Architecture: [Transformers Documentation](https://huggingface.co/docs/transformers/main/model_doc/llava) 92 | - JoyCaption Code: [fpgaminer/joycaption](https://github.com/fpgaminer/joycaption) 93 | - Qwen2-VL Architecture: [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) 94 | - Qwen2-VL Implementation: [MNeMoNiCuZ/qwen2-vl-7b-captioner-relaxed-batch](https://github.com/MNeMoNiCuZ/qwen2-vl-7b-captioner-relaxed-batch) 95 | - Molmo Architecture: [AllenAI Collection](https://huggingface.co/collections/allenai/molmo-66f379e6fe3b8ef090a8ca19) 96 | - Pixtral Architecture: [Pixtral Documentation](https://huggingface.co/docs/transformers/model_doc/pixtral) 97 | - MiniCPM-o-2_6 Architecture: [MiniCPM-o-2_6 Documentation](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9) 98 | - Vikhr-2-VL: [Vikhr-2-VL Documentation](https://huggingface.co/Vikhrmodels) 99 | - ExLlamaV2: [ExLlamaV2 Documentation](https://github.com/turboderp-org/exllamav2) 100 | 101 | 102 | **Model Credits** 103 | [ToriiGate](https://huggingface.co/Minthy) · [LLaVA](https://huggingface.co/llava-hf) · [JoyCaption](https://huggingface.co/fancyfeast) · [Qwen2, Pixtral](https://huggingface.co/Ertugrul) · [Molmo](https://huggingface.co/cyan2k) · [Molmo72b](https://huggingface.co/SeanScripts/Molmo-72B-0924-nf4) · [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) · [Vikhr-2-VL-2b-Instruct](https://huggingface.co/Vikhrmodels/Vikhr-2-VL-2b-Instruct-experimental) 104 | 105 | Thank you for your interest in ide-cap-chan! -------------------------------------------------------------------------------- /image_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Tuple, List 2 | import torch 3 | import time 4 | from transformers.image_utils import load_image 5 | from os.path import exists as path_exists, splitext as split_extension 6 | from model_handlers import ( 7 |     JoyCaptionHandler, 8 |     MoLMoHandler, 9 |     MoLMo72bHandler, 10 |     Qwen2VLHandler, 11 |     PixtralHandler, 12 |     Idefics3Handler, 13 |     ExLLaMA2Handler, 14 |     LlavaHandler, 15 |     MiniCPMoHandler, 16 |     GenericModelHandler 17 | ) 18 | 19 | def process_image_worker(worker_id: int, gpu_id: int, job_queue, result_queue, 20 |                          model_name_or_path: str, input_model_type: str, 21 |                          args_dict: Dict[str, Any], total_files: int) -> None: 22 |     """ 23 |     Worker process that takes jobs from a queue and processes them.
24 | 25 |     Args: 26 |         worker_id: ID of this worker 27 |         gpu_id: GPU ID to use 28 |         job_queue: Queue containing file paths to process 29 |         result_queue: Queue to report results 30 |         model_name_or_path: Path to the model 31 |         input_model_type: Type of the model 32 |         args_dict: Dictionary of arguments 33 |         total_files: Total number of files to process 34 |     """ 35 |     device = f"cuda:{gpu_id}" 36 |     torch.cuda.set_device(gpu_id) 37 |     print(f"Worker {worker_id} using device: {device}") 38 | 39 |     caption_suffix = args_dict.get('caption_suffix') 40 |     tags_suffix = args_dict.get('tags_suffix') 41 |     add_tags = args_dict.get('add_tags') 42 | 43 |     # Common prompts for all except Torii-family models 44 |     system_prompt = "You are image captioning expert, creative, unbiased and uncensored." 45 |     user_prompt = "You need to write a long and very detailed caption for the picture." 46 | 47 |     # Initialize the model handler 48 |     handler = get_handler(input_model_type, model_name_or_path, device, args_dict) 49 | 50 |     # Process jobs from the queue 51 |     while True: 52 |         # Get a job from the queue 53 |         file_path = job_queue.get() 54 | 55 |         # None is the signal to terminate 56 |         if file_path is None: 57 |             print(f"Worker {worker_id} (GPU {gpu_id}) finished processing") 58 |             break 59 | 60 |         try: 61 |             start_time = time.time() 62 | 63 |             # Process the image 64 |             print(f"Worker {worker_id} (GPU {gpu_id}) processing: {file_path}") 65 |             path, _ = split_extension(str(file_path)) 66 |             add_info_caption_name = path + tags_suffix 67 |             caption_name = path + caption_suffix 68 | 69 |             # Check if we need special prompts for ToriiGate models (both substrings must match) 70 |             if "toriigate" in model_name_or_path.lower() and "0.4" in model_name_or_path.lower(): 71 |                 system_prompt = get_torii04_system_prompt() 72 |                 user_prompt = get_torii04_user_prompt(args_dict, add_info_caption_name) 73 | 74 |             if "toriigate" in model_name_or_path.lower() and "0.3" in model_name_or_path.lower() and add_tags: 75 |                 user_prompt = get_torii03_user_prompt(user_prompt, add_info_caption_name) 76 | 77 |             # Load and process the image 78 |             image = load_image(str(file_path)) 79 |             if image.mode != "RGB": 80 |                 image = image.convert("RGB") 81 | 82 |             caption = handler.process_image(system_prompt, user_prompt, image) 83 |             handler.save_caption(caption, caption_name) 84 | 85 |             # Calculate processing time 86 |             processing_time = time.time() - start_time 87 | 88 |             # Report result 89 |             result_queue.put((worker_id, gpu_id, file_path.name, processing_time)) 90 | 91 |         except Exception as e: 92 |             print(f"Worker {worker_id} (GPU {gpu_id}) error processing {file_path}: {e}") 93 |             # Report error but continue processing 94 |             result_queue.put((worker_id, gpu_id, f"{file_path.name} (ERROR)", 0.0)) 95 | 96 | def get_handler(input_model_type, model_name_or_path, device, args_dict): 97 |     handlers = { 98 |         "exllama2": ExLLaMA2Handler, 99 |         "joy-caption": JoyCaptionHandler, 100 |         "molmo": MoLMoHandler, 101 |         "molmo72b": MoLMo72bHandler, 102 |         "qwen2vl": Qwen2VLHandler, 103 |         "pixtral": PixtralHandler, 104 |         "idefics3": Idefics3Handler, 105 |         "llava": LlavaHandler, 106 |         "minicpmo": MiniCPMoHandler, 107 |         "generic": GenericModelHandler 108 |     } 109 |     # A dict literal cannot raise, so validate the requested type explicitly instead of using try/except 110 |     if input_model_type not in handlers: 111 |         raise ValueError(f"Unsupported model type: {input_model_type}") 112 |     return handlers[input_model_type](model_name_or_path, device, args_dict) 113 | 114 | def get_torii04_user_prompt(args_dict, add_info_caption_name): 115 |     add_tags = args_dict.get('add_tags') 116 |     add_chars = args_dict.get('add_chars') 117 |     add_char_traits = args_dict.get('add_char_traits') 118 |     add_info = args_dict.get('add_info') 119 | 
no_chars = args_dict.get('no_chars') 120 | caption_format = args_dict.get('caption_format') 121 | 122 | image_info={} 123 | 124 | if path_exists(add_info_caption_name): 125 | if add_tags: 126 | tags = open(add_info_caption_name).read().strip() 127 | image_info["booru_tags"] = tags 128 | 129 | if add_chars: 130 | chars = open(add_info_caption_name).read().strip() 131 | image_info["chars"] = chars 132 | 133 | if add_char_traits: 134 | traits = open(add_info_caption_name).read().strip() 135 | image_info["characters_traits"] = traits 136 | 137 | if add_info: 138 | info = open(add_info_caption_name).read().strip() 139 | image_info["info"] = info 140 | 141 | base_prompt={ 142 | 'json': 'Describe the picture in structured json-like format.', 143 | 'markdown': 'Describe the picture in structured markdown format.', 144 | #TODO Not implemented in code yet 145 | #'caption_vars': 'Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].', 146 | 'short': 'You need to write a medium-short and convenient caption for the picture.', 147 | 'long': 'You need to write a long and very detailed caption for the picture.', 148 | 'bbox': 'Write bounding boxes for each character and their faces.', 149 | } 150 | 151 | grounding_prompt={ 152 | 'grounding_tags': ' Here are grounding tags for better understanding: ', 153 | 'characters': ' Here is a list of characters that are present in the picture: ', 154 | 'characters_traits': ' Here are popular tags or traits for each character on the picture: ', 155 | 'grounding_info': ' Here is preliminary information about the picture: ', 156 | 'no_chars': ' Do not use names for characters.', 157 | } 158 | 159 | userprompt=base_prompt[caption_format] #choose the mode 160 | 161 | if add_info and image_info["info"] is not None: #general info 162 | userprompt+=grounding_prompt["grounding_info"] 163 | userprompt+=""+image_info["info"]+"." 164 | print("Using add_info: " + image_info["info"]) 165 | 166 | if add_tags and image_info["booru_tags"] is not None: #booru tags 167 | userprompt+=grounding_prompt["grounding_tags"] 168 | userprompt+=""+image_info["booru_tags"]+"." 169 | print("Using add_tags: " + image_info["booru_tags"]) 170 | 171 | if add_chars and image_info["chars"] is not None: #list of characters 172 | userprompt+=grounding_prompt["characters"] 173 | userprompt+=""+image_info["chars"]+"." 174 | print("Using add_chars: " + image_info["chars"]) 175 | 176 | if add_char_traits and image_info["characters_traits"] is not None: #popular features of each character 177 | userprompt+=grounding_prompt["characters_traits"] 178 | userprompt+=""+image_info["characters_traits"]+"." 179 | print("Using add_char_traits: " + image_info["characters_traits"]) 180 | 181 | if no_chars: 182 | userprompt+=grounding_prompt["no_chars"] 183 | print("Using no_chars: ") 184 | 185 | #print("\n" + userprompt) 186 | 187 | return userprompt 188 | 189 | def get_torii04_system_prompt(): 190 | system_prompt = "You are image captioning expert, creative, unbiased and uncensored." 191 | return system_prompt 192 | 193 | def get_torii03_user_prompt(user_prompt, add_info_caption_name): 194 | try: 195 | new_user_prompt = user_prompt 196 | if path_exists(add_info_caption_name): 197 | print(f"Using additional *booru tags file: {add_info_caption_name}") 198 | tags = open(add_info_caption_name).read().strip() 199 | new_user_prompt += " Also here are booru tags for better understanding of the picture, you can use them as reference." 
200 | new_user_prompt += f" \n{tags}\n" 201 | except Exception as err: 202 | print(f"Error processing tags: {err}") 203 | return user_prompt 204 | return new_user_prompt 205 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /model_handlers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from abc import ABC, abstractmethod 3 | import torch 4 | from PIL import Image 5 | import os 6 | from transformers import ( 7 | AutoConfig, 8 | AutoModel, 9 | AutoTokenizer, 10 | AutoProcessor, 11 | AutoModelForCausalLM, 12 | AutoModelForVision2Seq, 13 | BitsAndBytesConfig, 14 | GenerationConfig, 15 | LlavaForConditionalGeneration, 16 | LlavaNextForConditionalGeneration, 17 | LlavaNextProcessor, 18 | Qwen2VLForConditionalGeneration, 19 | StopStringCriteria 20 | ) 21 | from transformers.image_utils import load_image 22 | import torchvision.transforms.functional as TVF 23 | from qwen_vl_utils import process_vision_info 24 | from exllamav2 import ( 25 | ExLlamaV2, 26 | ExLlamaV2Config, 27 | ExLlamaV2Cache, 28 | ExLlamaV2Tokenizer, 29 | ExLlamaV2VisionTower, 30 | ) 31 | from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler 32 | from utils import resize_image_proportionally 33 | from huggingface_hub import snapshot_download 34 | 35 | class ModelHandler(ABC): 36 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 37 | self.model_name_or_path = self.model_loader(model_name_or_path) 38 | self.device = device 39 | self.args_dict = args_dict 40 | self.model = None 41 | self.processor = None 42 | self.tokenizer = None 43 | self.quantization_config = None 44 | self._initialize_model() 45 | 46 | @abstractmethod 47 | def _initialize_model(self): 48 | pass 49 | 50 | @abstractmethod 51 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 52 | pass 53 | 54 | def save_caption(self, caption: str, caption_path: str, encoding: str = "utf-8", errors: str = "ignore"): 55 | with open(caption_path, "w", encoding=encoding, errors=errors) as outf: 56 | outf.write(caption) 57 | 58 | def model_loader(self, model_name_or_path: str) -> str: 59 | local_model_dir = model_name_or_path.split('/')[-1] 60 | if os.path.exists(local_model_dir): 61 | print(f"Model directory '{local_model_dir}' already exists. 
Using local version.") 62 | return local_model_dir 63 | else: 64 | print(f"Downloading model '{model_name_or_path}'...") 65 | snapshot_download( 66 | repo_id=model_name_or_path, 67 | local_dir=local_model_dir, 68 | ) 69 | print(f"Model successfully saved to '{local_model_dir}' directory.") 70 | return local_model_dir 71 | 72 | def _get_quantization_config(self): 73 | use_nf4 = self.args_dict.get('use_nf4') 74 | if use_nf4: 75 | return BitsAndBytesConfig( 76 | load_in_4bit=True, 77 | bnb_4bit_quant_type="nf4", 78 | bnb_4bit_use_double_quant=True, 79 | bnb_4bit_compute_dtype=torch.bfloat16 80 | ) 81 | return None 82 | 83 | class ExLLaMA2Handler(ModelHandler): 84 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 85 | super().__init__(model_name_or_path, device, args_dict) 86 | 87 | def _initialize_model(self): 88 | self.config = ExLlamaV2Config(self.model_name_or_path) 89 | self.config.max_seq_len = 32768 90 | 91 | self.vision_model = ExLlamaV2VisionTower(self.config) 92 | self.vision_model.load(progress=True) 93 | 94 | self.model = ExLlamaV2(self.config) 95 | self.tokenizer = ExLlamaV2Tokenizer(self.config) 96 | 97 | device_count = torch.cuda.device_count() 98 | free_mem, total_mem = torch.cuda.mem_get_info(self.device) 99 | free_gb = free_mem / (1024 ** 3) 100 | total_gb = total_mem / (1024 ** 3) 101 | 102 | if len(self.args_dict.get('filelist_chunks', [])) == 1 and self.device == "cuda:0" and device_count > 1: 103 | autosplit = True 104 | else: 105 | autosplit = False 106 | 107 | if autosplit: 108 | print(f"VRAM allocation strategy: Autosplit on {device_count} GPUs") 109 | cache = ExLlamaV2Cache(self.model, lazy=True, max_seq_len=self.config.max_seq_len) 110 | self.model.load_autosplit(cache, progress=True) 111 | else: 112 | split = [0.0] * device_count 113 | gpu_id = int(self.device.split(":")[1]) 114 | split[gpu_id] = free_gb 115 | print(f"VRAM allocation strategy: allocated {free_gb:.2f} GB on GPU:{gpu_id}") 116 | self.model.load(split, progress=True) 117 | cache = ExLlamaV2Cache(self.model, lazy=False, max_seq_len=self.config.max_seq_len) 118 | 119 | self.generator = ExLlamaV2DynamicGenerator( 120 | model=self.model, 121 | cache=cache, 122 | tokenizer=self.tokenizer, 123 | ) 124 | 125 | def model_loader(self, model_name_or_path: str) -> str: 126 | local_model_dir = model_name_or_path.split('/')[-1] 127 | if os.path.exists(local_model_dir): 128 | print(f"Model directory '{local_model_dir}' already exists. 
Using local version.") 129 | return local_model_dir 130 | else: 131 | print(f"Downloading model '{model_name_or_path}'...") 132 | from huggingface_hub import snapshot_download 133 | snapshot_download( 134 | repo_id=model_name_or_path, 135 | local_dir=local_model_dir, 136 | local_dir_use_symlinks=False 137 | ) 138 | print(f"Model successfully saved to '{local_model_dir}' directory.") 139 | return local_model_dir 140 | 141 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 142 | image = resize_image_proportionally(image, 2048, 2048) 143 | image_embeddings = [self.vision_model.get_image_embeddings( 144 | model=self.model, 145 | tokenizer=self.tokenizer, 146 | image=image, 147 | )] 148 | placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n" 149 | 150 | msg_text = ( 151 | "<|im_start|>system\n" + 152 | system_prompt + 153 | "<|im_end|>\n" + 154 | "<|im_start|>user\n" + 155 | placeholders + 156 | user_prompt + 157 | "<|im_end|>\n" + 158 | "<|im_start|>assistant\n" 159 | ) 160 | 161 | gen_settings = ExLlamaV2Sampler.Settings() 162 | gen_settings.temperature = 0.6 163 | gen_settings.top_p = 0.9 164 | gen_settings.top_k = 40 165 | 166 | output = self.generator.generate( 167 | prompt=msg_text, 168 | max_new_tokens=512, 169 | add_bos=True, 170 | encode_special_tokens=True, 171 | decode_special_tokens=True, 172 | stop_conditions=[self.tokenizer.eos_token_id], 173 | gen_settings=gen_settings, 174 | embeddings=image_embeddings, 175 | ) 176 | 177 | return output.split('<|im_start|>assistant\n')[-1].strip() 178 | 179 | class JoyCaptionHandler(ModelHandler): 180 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 181 | super().__init__(model_name_or_path, device, args_dict) 182 | 183 | def _initialize_model(self): 184 | self.quantization_config = self._get_quantization_config() 185 | self.model = LlavaForConditionalGeneration.from_pretrained( 186 | self.model_name_or_path, 187 | torch_dtype=torch.bfloat16, 188 | quantization_config=self.quantization_config, 189 | device_map=self.device 190 | ) 191 | self.processor = AutoProcessor.from_pretrained( 192 | self.model_name_or_path, 193 | trust_remote_code=True, 194 | torch_dtype=torch.bfloat16, 195 | quantization_config=self.quantization_config, 196 | use_fast=False 197 | ) 198 | self.model.eval() 199 | 200 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 201 | if image.size != (384, 384): 202 | image = image.resize((384, 384), Image.LANCZOS) 203 | pixel_values = TVF.pil_to_tensor(image) 204 | pixel_values = (pixel_values / 255.0) 205 | pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]) 206 | pixel_values = pixel_values.unsqueeze(0).to(self.device) 207 | 208 | convo = [ 209 | {"role": "system", "content": system_prompt}, 210 | {"role": "user", "content": user_prompt}, 211 | ] 212 | 213 | convo_string = self.processor.apply_chat_template(convo, tokenize = False, add_generation_prompt = True) 214 | 215 | inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to(self.device) 216 | inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16) 217 | 218 | generate_ids = self.model.generate( 219 | **inputs, 220 | max_new_tokens=300, 221 | do_sample=True, 222 | suppress_tokens=None, 223 | use_cache=True, 224 | temperature=0.6, 225 | top_k=None, 226 | top_p=0.9, 227 | )[0] 228 | 229 | generate_ids = generate_ids[inputs['input_ids'].shape[1]:] 230 | 231 | caption = 
self.processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 232 | caption = caption.strip() 233 | 234 | return caption 235 | 236 | class MoLMoHandler(ModelHandler): 237 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 238 | super().__init__(model_name_or_path, device, args_dict) 239 | 240 | def _initialize_model(self): 241 | self.processor = AutoProcessor.from_pretrained( 242 | self.model_name_or_path, 243 | trust_remote_code=True, 244 | torch_dtype='auto', 245 | use_fast=False 246 | ) 247 | self.model = AutoModelForCausalLM.from_pretrained( 248 | self.model_name_or_path, 249 | trust_remote_code=True, 250 | torch_dtype='auto', 251 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager" 252 | device_map=self.device 253 | ) 254 | self.model.eval() 255 | 256 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 257 | image = resize_image_proportionally(image, 1024, 1024) 258 | 259 | user_only_prompt = system_prompt + " " + user_prompt 260 | 261 | inputs = self.processor.process(images=image, text=user_only_prompt) 262 | inputs = {k: v.to(self.device).unsqueeze(0) for k, v in inputs.items()} 263 | prompt_tokens = inputs["input_ids"].size(1) 264 | 265 | output = self.model.generate_from_batch( 266 | inputs, 267 | generation_config=GenerationConfig( 268 | max_new_tokens=512, 269 | ), 270 | stopping_criteria=[StopStringCriteria(tokenizer=self.processor.tokenizer, stop_strings=["<|endoftext|>"])], 271 | tokenizer=self.processor.tokenizer, 272 | ) 273 | 274 | generated_tokens = output[0, prompt_tokens:] 275 | caption = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) 276 | return caption 277 | 278 | class MoLMo72bHandler(ModelHandler): 279 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 280 | super().__init__(model_name_or_path, device, args_dict) 281 | 282 | def _initialize_model(self): 283 | gpus = [] 284 | if torch.cuda.is_available(): 285 | for i in range(torch.cuda.device_count()): 286 | device_props = torch.cuda.get_device_properties(i) 287 | compute_cap = device_props.major * 10 + device_props.minor 288 | vram = device_props.total_memory 289 | gpus.append((i, compute_cap, vram)) 290 | 291 | sorted_gpus = sorted(gpus, key=lambda x: (-x[1], -x[2])) 292 | 293 | config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=True) 294 | n_layers = config.num_hidden_layers 295 | 296 | fixed_vram_main = 1 * 1024**3 297 | PER_LAYER_VRAM = 0.75 * 1024**3 298 | SAFETY_MARGIN = 1 299 | 300 | device_map = {"model.vision_backbone": "cpu"} 301 | if sorted_gpus: 302 | layer_allocations = [] 303 | remaining_layers = n_layers 304 | 305 | for i, (dev_id, _, vram) in enumerate(sorted_gpus): 306 | if remaining_layers <= 0: 307 | break 308 | 309 | available_vram = vram * SAFETY_MARGIN 310 | if i == 0: 311 | available_vram -= fixed_vram_main 312 | 313 | max_possible_layers = int(available_vram // PER_LAYER_VRAM) 314 | allocate_layers = min(max_possible_layers, remaining_layers) 315 | layer_allocations.append((dev_id, allocate_layers)) 316 | remaining_layers -= allocate_layers 317 | 318 | if remaining_layers > 0: 319 | layer_allocations[-1] = (layer_allocations[-1][0], layer_allocations[-1][1] + remaining_layers) 320 | 321 | current_layer = 0 322 | for dev_id, layers in layer_allocations: 323 | end_layer = current_layer + layers 324 | for layer_idx in range(current_layer, end_layer): 325 | 
device_map[f"model.transformer.blocks.{layer_idx}"] = dev_id 326 | current_layer = end_layer 327 | 328 | main_gpu = sorted_gpus[0][0] 329 | secondary_gpu = sorted_gpus[1][0] if len(sorted_gpus) > 1 else main_gpu 330 | device_map.update({ 331 | "model.transformer.wte": main_gpu, 332 | "model.transformer.ln_f": main_gpu, 333 | "model.transformer.ff_out": secondary_gpu, 334 | }) 335 | else: 336 | device_map = "auto" 337 | 338 | self.model = AutoModelForCausalLM.from_pretrained( 339 | self.model_name_or_path, 340 | device_map=device_map, 341 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager" 342 | torch_dtype=torch.bfloat16, 343 | use_safetensors=True, 344 | trust_remote_code=True 345 | ) 346 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path, trust_remote_code=True, use_fast=False) 347 | self.model.model.vision_backbone.float() 348 | 349 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 350 | image = resize_image_proportionally(image, 1024, 1024) 351 | 352 | user_only_prompt = system_prompt + " " + user_prompt 353 | 354 | inputs = self.processor.process(images=image, text=user_only_prompt) 355 | inputs = {k: v.to(self.device).unsqueeze(0) for k, v in inputs.items()} 356 | prompt_tokens = inputs["input_ids"].size(1) 357 | 358 | output = self.model.generate_from_batch( 359 | inputs, 360 | generation_config=GenerationConfig( 361 | max_new_tokens=512, 362 | ), 363 | stopping_criteria=[StopStringCriteria(tokenizer=self.processor.tokenizer, stop_strings=["<|qqqq|>"])], 364 | tokenizer=self.processor.tokenizer, 365 | ) 366 | 367 | generated_tokens = output[0, prompt_tokens:] 368 | caption = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) 369 | return caption 370 | 371 | class Qwen2VLHandler(ModelHandler): 372 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 373 | super().__init__(model_name_or_path, device, args_dict) 374 | 375 | def _initialize_model(self): 376 | self.model = Qwen2VLForConditionalGeneration.from_pretrained( 377 | self.model_name_or_path, 378 | torch_dtype='auto', 379 | quantization_config=self._get_quantization_config(), 380 | device_map=self.device 381 | ) 382 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path) 383 | 384 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 385 | image = resize_image_proportionally(image, 1024, 1024) 386 | messages = [ 387 | { 388 | "role": "system", 389 | "content": [{"type": "text", "text": system_prompt}] 390 | }, 391 | { 392 | "role": "user", 393 | "content": [ 394 | { 395 | "type": "image", 396 | 'image': image 397 | }, 398 | { 399 | "type": "text", 400 | "text": user_prompt 401 | } 402 | ] 403 | } 404 | ] 405 | 406 | text = self.processor.apply_chat_template( 407 | messages, tokenize=False, add_generation_prompt=True 408 | ) 409 | image_inputs, _ = process_vision_info(messages) 410 | inputs = self.processor( 411 | text=[text], 412 | images=image_inputs, 413 | padding=True, 414 | return_tensors="pt", 415 | ) 416 | inputs = inputs.to(self.device) 417 | 418 | generated_ids = self.model.generate(**inputs, max_new_tokens=512) 419 | generated_ids_trimmed = [ 420 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 421 | ] 422 | caption = self.processor.batch_decode( 423 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 424 | )[0] 425 | return caption 426 | 427 | class 
PixtralHandler(ModelHandler): 428 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 429 | super().__init__(model_name_or_path, device, args_dict) 430 | 431 | def _initialize_model(self): 432 | self.quantization_config = self._get_quantization_config() 433 | 434 | self.model = LlavaForConditionalGeneration.from_pretrained( 435 | self.model_name_or_path, 436 | torch_dtype=torch.bfloat16, 437 | quantization_config=self.quantization_config, 438 | device_map=self.device 439 | ) 440 | self.processor = AutoProcessor.from_pretrained( 441 | self.model_name_or_path, 442 | trust_remote_code=True, 443 | torch_dtype=torch.bfloat16, 444 | use_fast = False, 445 | quantization_config=self.quantization_config 446 | ) 447 | self.model.eval() 448 | 449 | def _get_quantization_config(self): 450 | return BitsAndBytesConfig( 451 | load_in_4bit=True, 452 | bnb_4bit_quant_type="nf4", 453 | bnb_4bit_use_double_quant=True, 454 | bnb_4bit_compute_dtype=torch.bfloat16 455 | ) 456 | 457 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 458 | image = resize_image_proportionally(image, 768, 768) 459 | 460 | user_only_prompt = system_prompt + " " + user_prompt 461 | 462 | conversation = [ 463 | { 464 | "role": "user", 465 | "content": [ 466 | {"type": "text", "text": user_only_prompt}, 467 | {"type": "image"}, 468 | ], 469 | } 470 | ] 471 | 472 | prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) 473 | 474 | inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device) 475 | 476 | with torch.no_grad(): 477 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16): 478 | generate_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, use_cache=True, top_k=20) 479 | caption = self.processor.batch_decode(generate_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True)[0] 480 | return caption 481 | 482 | class Idefics3Handler(ModelHandler): 483 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 484 | super().__init__(model_name_or_path, device, args_dict) 485 | 486 | def _initialize_model(self): 487 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path) 488 | self.model = AutoModelForVision2Seq.from_pretrained( 489 | self.model_name_or_path, 490 | torch_dtype=torch.bfloat16, 491 | device_map=self.device 492 | ) 493 | self.model.eval() 494 | 495 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 496 | image = resize_image_proportionally(image, 2048, 2048) 497 | messages = [ 498 | { 499 | "role": "system", 500 | "content": [ 501 | {"type": "text", "text": system_prompt} 502 | ] 503 | }, 504 | { 505 | "role": "user", 506 | "content": [ 507 | {"type": "image"}, 508 | {"type": "text", "text": user_prompt} 509 | ] 510 | } 511 | ] 512 | 513 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) 514 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt") 515 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 516 | 517 | with torch.no_grad(): 518 | generated_ids = self.model.generate(**inputs, max_new_tokens=512) 519 | generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True) 520 | caption = generated_texts[0].split("Assistant: ")[1] 521 | return caption 522 | 523 | class LlavaHandler(ModelHandler): 524 | def __init__(self, 
model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 525 | super().__init__(model_name_or_path, device, args_dict) 526 | 527 | def _initialize_model(self): 528 | self.model = LlavaNextForConditionalGeneration.from_pretrained( 529 | self.model_name_or_path, 530 | torch_dtype=torch.float16, 531 | #low_cpu_mem_usage=True, 532 | vision_feature_select_strategy="default", 533 | attn_implementation='flash_attention_2', # sdpa or flash_attention_2 or "eager" 534 | device_map=self.device 535 | ) 536 | self.processor = LlavaNextProcessor.from_pretrained( 537 | self.model_name_or_path, 538 | #padding_side="left", 539 | #vision_feature_select_strategy="default", 540 | #patch_size=32, 541 | use_fast=False 542 | ) 543 | self.processor.patch_size = self.model.config.vision_config.patch_size 544 | self.processor.vision_feature_select_strategy = self.model.config.vision_feature_select_strategy 545 | 546 | self.model.eval() 547 | 548 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 549 | image = resize_image_proportionally(image, 768, 768) 550 | messages = [ 551 | { 552 | "role": "system", 553 | "content": [ 554 | {"type": "text", "text": system_prompt} 555 | ] 556 | }, 557 | { 558 | "role": "user", 559 | "content": [ 560 | {"type": "image"}, 561 | {"type": "text", "text": user_prompt} 562 | ] 563 | } 564 | ] 565 | 566 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) 567 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt") 568 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 569 | 570 | with torch.no_grad(): 571 | generated_ids = self.model.generate(**inputs, max_new_tokens=512) 572 | generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True) 573 | caption = generated_texts[0].split("[/INST] ")[1] 574 | return caption 575 | 576 | class MiniCPMoHandler(ModelHandler): 577 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 578 | super().__init__(model_name_or_path, device, args_dict) 579 | 580 | def _initialize_model(self): 581 | self.quantization_config = self._get_quantization_config() 582 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True) 583 | self.model = AutoModel.from_pretrained( 584 | self.model_name_or_path, 585 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager" 586 | torch_dtype=torch.bfloat16, 587 | trust_remote_code=True, 588 | local_files_only=True, 589 | init_vision=True, 590 | init_audio=True, 591 | init_tts=True, 592 | quantization_config=self.quantization_config, 593 | device_map=self.device 594 | ) 595 | 596 | self.model.eval() 597 | self.model.init_tts() 598 | self.model.tts.float() 599 | 600 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 601 | image = resize_image_proportionally(image, 2048, 2048) 602 | 603 | user_only_prompt = system_prompt + " " + user_prompt 604 | 605 | msgs = [{'role': 'user', 'content': [image, user_only_prompt]}] 606 | caption = self.model.chat( 607 | image=None, 608 | msgs=msgs, 609 | tokenizer=self.tokenizer 610 | ) 611 | 612 | return caption 613 | 614 | class GenericModelHandler(ModelHandler): 615 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]): 616 | super().__init__(model_name_or_path, device, args_dict) 617 | 618 | def _initialize_model(self): 619 | self.quantization_config = self._get_quantization_config() 620 | self.tokenizer = 
AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True) 621 | self.model = AutoModel.from_pretrained( 622 | self.model_name_or_path, 623 | torch_dtype=torch.bfloat16, 624 | trust_remote_code=True, 625 | quantization_config=self.quantization_config, 626 | device_map=self.device 627 | ) 628 | 629 | self.model.eval() 630 | 631 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str: 632 | image = resize_image_proportionally(image, 1024, 1024) 633 | 634 | user_only_prompt = system_prompt + " " + user_prompt 635 | 636 | msgs = [{'role': 'user', 'content': [image, user_only_prompt]}] 637 | caption = self.model.chat( 638 | image=None, 639 | msgs=msgs, 640 | tokenizer=self.tokenizer 641 | ) 642 | 643 | return caption 644 | --------------------------------------------------------------------------------
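A note on MoLMo72bHandler's layer placement: the handler pins the vision backbone to the CPU, reserves roughly 1 GiB (fixed_vram_main) on the strongest GPU for the embedding and output modules, budgets about 0.75 GiB per transformer block (PER_LAYER_VRAM), treats the full reported VRAM as usable (SAFETY_MARGIN = 1), and folds any leftover blocks into the last GPU's share. The sketch below works through that arithmetic under assumed numbers only (two 24 GiB GPUs, n_layers = 80); in the handler itself n_layers comes from AutoConfig.num_hidden_layers.

# Worked sketch of MoLMo72bHandler's allocation arithmetic.
# Assumptions (illustrative only): two 24 GiB GPUs, n_layers = 80.
GIB = 1024 ** 3
PER_LAYER_VRAM = 0.75 * GIB          # same per-block estimate as the handler
FIXED_VRAM_MAIN = 1 * GIB            # reserved on the first (strongest) GPU
gpu_vram = [24 * GIB, 24 * GIB]      # strongest GPU first

n_layers = 80
allocations = []
remaining = n_layers
for i, vram in enumerate(gpu_vram):
    if remaining <= 0:
        break
    available = vram - (FIXED_VRAM_MAIN if i == 0 else 0)
    take = min(int(available // PER_LAYER_VRAM), remaining)
    allocations.append(take)
    remaining -= take
if remaining > 0:
    allocations[-1] += remaining     # leftover blocks land on the last GPU

print(allocations)                   # [30, 50] for these assumed numbers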
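All handlers expose the same small surface: the constructor takes (model_name_or_path, device, args_dict), process_image(system_prompt, user_prompt, image) returns the caption text, and save_caption() writes it to the given path. The repository's own driver (ide-cap-chan.py / image_processor.py) is not shown in this section, so the sketch below is an illustrative assumption only: the handler class, model repo id, prompts, and output layout are placeholders, not the project's actual wiring.

# Hypothetical driver sketch -- not the repository's real entry point.
# The model id, prompts, and input folder below are assumptions.
from pathlib import Path
from PIL import Image
from model_handlers import Qwen2VLHandler   # any concrete ModelHandler subclass

handler = Qwen2VLHandler(
    model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",  # assumed example repo id
    device="cuda:0",
    args_dict={"use_nf4": True},   # enables the NF4 BitsAndBytesConfig path
)

system_prompt = "You are an image captioning assistant."   # placeholder
user_prompt = "Describe the image in detail."              # placeholder

for image_path in sorted(Path("2tag").glob("*.jpg")):
    image = Image.open(image_path).convert("RGB")
    caption = handler.process_image(system_prompt, user_prompt, image)
    # store the caption as <image name>.txt alongside the image
    handler.save_caption(caption, str(image_path.with_suffix(".txt")))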