├── .gitignore ├── docs ├── banner.jpg ├── llama_cpp_installation.md └── ollama_installation.md ├── .gitattributes ├── requirements.txt ├── run.py ├── help_text.md ├── setup.py ├── services ├── openrouter_api.py ├── create_dataset.py ├── template_manager.py ├── model_converter.py └── train_model.py ├── README.md └── app.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | __pycache__ -------------------------------------------------------------------------------- /docs/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuslanKoroy/QTune/HEAD/docs/banner.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | torch 3 | transformers 4 | peft 5 | trl 6 | bitsandbytes 7 | accelerate 8 | datasets 9 | huggingface-hub 10 | openai 11 | icecream 12 | sentencepiece 13 | protobuf 14 | jinja2 15 | tokenizers 16 | psutil 17 | gputil 18 | requests 19 | tensorboard -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | def main(): 6 | print("🚀 Starting QTune...") 7 | print("Please wait while the application initializes...\n") 8 | 9 | # Run the main application 10 | try: 11 | subprocess.run([sys.executable, "app.py"], check=True) 12 | except subprocess.CalledProcessError as e: 13 | print(f"❌ Error running QTune Studio: {e}") 14 | sys.exit(1) 15 | except KeyboardInterrupt: 16 | print("\n👋 QTune Studio stopped by user.") 17 | sys.exit(0) 18 | 19 | if __name__ == "__main__": 20 | main() -------------------------------------------------------------------------------- /help_text.md: -------------------------------------------------------------------------------- 1 | ## Help & Documentation 2 | 3 | ### Getting Started 4 | 1. Set your API keys in the Settings tab 5 | 2. Select a model from the Model Selection tab 6 | 3. Prepare your dataset in the Dataset Preparation tab 7 | 4. Configure training parameters in the Training Configuration tab 8 | 5. Start training in the Training tab 9 | 6. 
Convert your model in the Model Conversion tab 10 | 11 | ### API Keys 12 | - **OpenRouter API Key**: Required for dataset generation using large models 13 | - Get one at https://openrouter.ai/ 14 | 15 | ### System Requirements 16 | - Python 3.8 or higher 17 | - CUDA-compatible GPU with at least 8GB VRAM (recommended) 18 | - At least 16GB system RAM 19 | 20 | ### Troubleshooting 21 | - If you encounter memory errors, reduce batch size or enable gradient checkpointing 22 | - For slow training, ensure you're using a CUDA-compatible GPU 23 | - If models fail to load, check your internet connection 24 | 25 | ### Tips for Best Results 26 | - Use high-quality, diverse datasets for training 27 | - Start with smaller models for testing 28 | - Monitor VRAM usage during training 29 | - Experiment with different QLoRA parameters -------------------------------------------------------------------------------- /docs/llama_cpp_installation.md: -------------------------------------------------------------------------------- 1 | # Installing llama.cpp 2 | 3 | To install llama.cpp for GGUF conversion, follow these steps: 4 | 5 | ## Prerequisites 6 | - Git 7 | - CMake (for building from source) 8 | - A C++ compiler (GCC or Clang) 9 | 10 | ## Installation Steps 11 | 12 | ### 1. Clone the llama.cpp repository: 13 | ```bash 14 | git clone https://github.com/ggerganov/llama.cpp.git 15 | cd llama.cpp 16 | ``` 17 | 18 | ### 2. Compile the tools: 19 | ```bash 20 | make 21 | ``` 22 | 23 | ### 3. Add the tools to your PATH: 24 | ```bash 25 | export PATH=$(pwd):$PATH 26 | ``` 27 | 28 | ### 4. Verify installation: 29 | ```bash 30 | llama-quantize --help 31 | ``` 32 | 33 | ## Alternative Installation Methods 34 | 35 | ### Using Package Managers 36 | 37 | #### On Ubuntu/Debian: 38 | ```bash 39 | sudo apt update 40 | sudo apt install llama.cpp 41 | ``` 42 | 43 | #### On macOS with Homebrew: 44 | ```bash 45 | brew install llama.cpp 46 | ``` 47 | 48 | ### Using Pre-compiled Binaries 49 | 50 | You can download pre-compiled binaries from the [llama.cpp releases page](https://github.com/ggerganov/llama.cpp/releases). 51 | 52 | ## Required Tools 53 | 54 | The following tools from llama.cpp are used by this converter: 55 | - `convert-hf-to-gguf.py` - For converting Hugging Face models to GGUF format 56 | - `llama-quantize` - For quantizing GGUF models 57 | 58 | Make sure these tools are available in your PATH after installation. 
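
## Example: Converting and Quantizing a Model

A minimal sketch of how the two tools fit together, assuming a merged model directory at `./merged_model` (the default output directory of the merge step). The paths, output file names, and the `--outfile` flag are illustrative; depending on your llama.cpp version the conversion script may be named `convert_hf_to_gguf.py` and accept slightly different options.

```bash
# 1. Convert the merged Hugging Face model directory to an unquantized GGUF file
python convert-hf-to-gguf.py ./merged_model --outfile ./merged_model/model-f16.gguf

# 2. Quantize the GGUF file (Q4_K_M here; Q5_K_M, Q8_0 and F16 are also supported)
llama-quantize ./merged_model/model-f16.gguf ./merged_model/model-q4_k_m.gguf Q4_K_M
```

The quantization type passed to `llama-quantize` corresponds to the options exposed in the application's Model Conversion tab.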
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt", "r", encoding="utf-8") as fh: 7 | requirements = fh.read().splitlines() 8 | 9 | setup( 10 | name="qtune", 11 | version="1.0.0", 12 | author="QTune Team", 13 | author_email="qtune@example.com", 14 | description="A comprehensive web application for fine-tuning language models on consumer GPUs", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/RuslanKoroy/qtune", 18 | packages=find_packages(), 19 | classifiers=[ 20 | "Development Status :: 4 - Beta", 21 | "Intended Audience :: Developers", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ], 32 | python_requires=">=3.8", 33 | install_requires=requirements, 34 | entry_points={ 35 | "console_scripts": [ 36 | "qtune=run:main", 37 | ], 38 | }, 39 | include_package_data=True, 40 | package_data={ 41 | "": ["README.md", "requirements.txt"], 42 | }, 43 | ) -------------------------------------------------------------------------------- /docs/ollama_installation.md: -------------------------------------------------------------------------------- 1 | # Installing Ollama 2 | 3 | To install Ollama, follow these steps: 4 | 5 | ## Installation Steps 6 | 7 | ### 1. For Linux/Mac: 8 | ```bash 9 | curl -fsSL https://ollama.com/install.sh | sh 10 | ``` 11 | 12 | ### 2. For Windows: 13 | Download the installer from [Ollama Setup.exe](https://ollama.com/download/OllamaSetup.exe) 14 | 15 | ### 3. Start Ollama: 16 | ```bash 17 | ollama serve 18 | ``` 19 | 20 | ### 4. 
Verify installation: 21 | ```bash 22 | ollama --version 23 | ``` 24 | 25 | ## System Requirements 26 | 27 | - **Linux**: Most modern Linux distributions 28 | - **macOS**: macOS 11+ (Big Sur) 29 | - **Windows**: Windows 10 or later 30 | 31 | ## GPU Support 32 | 33 | Ollama supports GPU acceleration on the following platforms: 34 | 35 | ### NVIDIA GPUs 36 | - Install NVIDIA drivers 37 | - CUDA support is included with Ollama 38 | 39 | ### AMD GPUs 40 | - Install ROCm drivers 41 | - AMD GPU support is experimental 42 | 43 | ### Apple Silicon 44 | - Native support for M1/M2/M3 chips 45 | - No additional drivers required 46 | 47 | ## Common Issues and Troubleshooting 48 | 49 | ### Permission Denied Errors 50 | If you encounter permission errors, try: 51 | ```bash 52 | sudo curl -fsSL https://ollama.com/install.sh | sh 53 | ``` 54 | 55 | ### Service Not Starting 56 | If Ollama service doesn't start automatically: 57 | ```bash 58 | sudo systemctl start ollama 59 | sudo systemctl enable ollama 60 | ``` 61 | 62 | ### Model Pull Issues 63 | If you have issues pulling models: 64 | ```bash 65 | # Check if the service is running 66 | ollama list 67 | 68 | # Try pulling a model manually 69 | ollama pull llama3 70 | ``` 71 | 72 | ## Useful Commands 73 | 74 | - `ollama list` - List downloaded models 75 | - `ollama run ` - Run a model 76 | - `ollama ps` - Show running models 77 | - `ollama rm ` - Remove a model -------------------------------------------------------------------------------- /services/openrouter_api.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import os 3 | 4 | openrouter_client = None 5 | OPENROUTER_KEY = "" 6 | 7 | def update_openrouter_client(api_key=None): 8 | """Update the OpenRouter client with a new API key""" 9 | global openrouter_client, OPENROUTER_KEY 10 | if api_key: 11 | OPENROUTER_KEY = api_key 12 | if OPENROUTER_KEY: 13 | openrouter_client = OpenAI( 14 | api_key=OPENROUTER_KEY, 15 | base_url="https://openrouter.ai/api/v1", 16 | ) 17 | else: 18 | openrouter_client = None 19 | 20 | def generate(messages: list[dict], prompt_name: str, replace_dict: dict = None, model: str = None) -> str: 21 | if model is None: 22 | return "Error: No model specified for generation" 23 | 24 | prompt_file_path = f'{prompt_name}.md' 25 | try: 26 | with open(prompt_file_path, 'r', encoding='utf8') as f: 27 | system_message = f.read() 28 | except FileNotFoundError: 29 | print(f"Error: Prompt file '{prompt_file_path}' not found.") 30 | system_message = "You are a useful assistant." 31 | except Exception as e: 32 | print(f"Error reading prompt file '{prompt_file_path}': {e}") 33 | system_message = "You are a useful assistant." 34 | 35 | if replace_dict: 36 | for key, value in replace_dict.items(): 37 | if value is not None: 38 | system_message = system_message.replace(key, str(value)) 39 | 40 | full_messages = [{'role': 'system', 'content': system_message}] + list(messages) 41 | 42 | if openrouter_client is None: 43 | print("OpenRouter API key not set. Please set OPENROUTER_KEY in Settings.") 44 | return "API key not set. Please configure your OpenRouter API key in Settings." 
45 | 46 | try: 47 | chat_completion = openrouter_client.chat.completions.create( 48 | model=model, 49 | messages=full_messages 50 | ) 51 | generated_text = chat_completion.choices[0].message.content 52 | return generated_text 53 | 54 | except Exception as exception: 55 | print(f'Error in openrouter_generate with model {model}: {exception}') 56 | return f"Error generating response: {exception}" 57 | -------------------------------------------------------------------------------- /services/create_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from .openrouter_api import generate 4 | from icecream import ic 5 | 6 | def create_dataset( 7 | samples_file: str = 'samples_request.md', 8 | system_prompt_file: str = 'system_prompt.md', 9 | output_dataset_file: str = 'dataset.json', 10 | model_name: str = 'google/gemma-3-27b-it:free' 11 | ): 12 | dataset = [] 13 | 14 | # Load system prompt 15 | try: 16 | with open(system_prompt_file, 'r', encoding='utf8') as f: 17 | system_prompt_content = f.read() 18 | except FileNotFoundError: 19 | ic(f"Error: System prompt file '{system_prompt_file}' not found.") 20 | return False, f"Error: System prompt file '{system_prompt_file}' not found." 21 | except Exception as e: 22 | ic(f"Error reading system prompt file '{system_prompt_file}': {e}") 23 | return False, f"Error reading system prompt file '{system_prompt_file}': {e}" 24 | 25 | # Load user samples 26 | try: 27 | with open(samples_file, 'r', encoding='utf8') as f: 28 | user_samples = f.readlines() 29 | except FileNotFoundError: 30 | ic(f"Error: User samples file '{samples_file}' not found.") 31 | return False, f"Error: Sample query file '{samples_file}' not found." 32 | except Exception as e: 33 | ic(f"Error reading user samples file '{samples_file}': {e}") 34 | return False, f"Error reading sample queries file '{samples_file}': {e}" 35 | 36 | ic(f"Starting dataset generation from {len(user_samples)} samples...") 37 | 38 | for i, sample_line in enumerate(user_samples): 39 | user_message = sample_line.strip() 40 | if not user_message: 41 | continue 42 | 43 | ic(f"Processing sample {i+1}/{len(user_samples)}: '{user_message}'") 44 | 45 | # Generate messages for LLM 46 | messages = [ 47 | {'role': 'user', 'content': user_message} 48 | ] 49 | 50 | # Generate response 51 | assistant_response = generate( 52 | messages=messages, 53 | prompt_name=os.path.splitext(os.path.basename(system_prompt_file))[0], 54 | model=model_name 55 | ) 56 | 57 | if assistant_response: 58 | dataset_entry = { 59 | "messages": [ 60 | {"role": "system", "content": system_prompt_content}, 61 | {"role": "user", "content": user_message}, 62 | {"role": "assistant", "content": assistant_response} 63 | ] 64 | } 65 | dataset.append(dataset_entry) 66 | ic(f"Added entry to dataset for sample {i+1}.") 67 | else: 68 | ic(f"Failed to get response for sample {i+1}: '{user_message}'. Skipping.") 69 | 70 | # Save the dataset 71 | try: 72 | with open(output_dataset_file, 'w', encoding='utf8') as f: 73 | json.dump(dataset, f, ensure_ascii=False, indent=2) 74 | ic(f"Dataset successfully saved to '{output_dataset_file}'. Total records: {len(dataset)}") 75 | except Exception as e: 76 | ic(f"Error saving dataset to '{output_dataset_file}': {e}") 77 | return False, f"Error saving dataset to '{output_dataset_file}': {e}" 78 | 79 | return True, f"Dataset successfully created and saved to '{output_dataset_file}'. 
Total entries: {len(dataset)}" 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QTune Documentation 2 | 3 | ![Banner](docs/banner.jpg) 4 | 5 | ## Overview 6 | 7 | QTune is a comprehensive web application for fine-tuning language models on consumer GPUs with as little as 8GB of VRAM. Built with Gradio, it provides an intuitive interface for the entire fine-tuning workflow. 8 | 9 | ## Features 10 | 11 | ### 1. Model Selection 12 | - Automatically fetch models from Hugging Face 13 | - Detailed model information including VRAM requirements 14 | - Support for popular models like Gemma, Llama, Mistral, and more 15 | - Manual model entry for models not in the list 16 | - Automatic validation of manually entered models on HuggingFace 17 | - Links to model repositories (Hugging Face and OpenRouter) 18 | 19 | ### 2. Dataset Preparation 20 | - Create datasets using larger models via OpenRouter API 21 | - Upload your own datasets in JSON format 22 | - Dataset preview and validation 23 | - Manual model entry for generation models not in the list 24 | 25 | ### 3. Training Configuration 26 | - Fine-tune QLoRA parameters (rank, alpha, dropout, target modules) 27 | - Configure training parameters (epochs, batch size, learning rate) 28 | - Memory optimization settings for 8GB VRAM GPUs 29 | 30 | ### 4. Training Execution 31 | - Start and monitor training process 32 | - Real-time training logs 33 | - Progress tracking 34 | 35 | ### 5. Model Conversion 36 | - Convert models to GGUF format 37 | - Multiple quantization options (Q4_K_M, Q5_K_M, Q8_0, F16) 38 | - Integration with Ollama for easy deployment 39 | 40 | ### 6. Template Management 41 | - Automatically fetch chat templates from models 42 | - Create and save custom templates 43 | - Apply templates to datasets 44 | 45 | ## Installation 46 | 47 | ### Prerequisites 48 | - Python 3.8 or higher 49 | - CUDA-compatible GPU with at least 8GB VRAM (recommended) 50 | - At least 16GB system RAM 51 | 52 | ### Quick Installation 53 | ```bash 54 | # Clone the repository 55 | git clone https://github.com/RuslanKoroy/QTune 56 | cd qtune 57 | 58 | # Install dependencies 59 | pip install -r requirements.txt 60 | 61 | # For GGUF conversion and Ollama integration (optional) 62 | # Install llama.cpp 63 | git clone https://github.com/ggerganov/llama.cpp.git 64 | cd llama.cpp 65 | make 66 | export PATH=$(pwd):$PATH 67 | 68 | # Install Ollama 69 | curl -fsSL https://ollama.com/install.sh | sh 70 | ``` 71 | 72 | ## Usage 73 | 74 | ### Starting the Application 75 | ```bash 76 | python app.py 77 | ``` 78 | 79 | The application will start and provide a local URL to access the web interface. 80 | 81 | ### Workflow 82 | 83 | 1. **Select a Model**: Choose from popular models on Hugging Face 84 | 2. **Prepare Dataset**: Either generate a dataset using larger models or upload your own 85 | 3. **Configure Training**: Adjust QLoRA and training parameters 86 | 4. **Start Training**: Begin the fine-tuning process 87 | 5. **Convert Model**: Export to GGUF format for deployment 88 | 6. **Deploy**: Push to Ollama for easy inference 89 | 90 | ## API Keys 91 | 92 | ### OpenRouter API Key 93 | To use the dataset generation features, you need an OpenRouter API key: 94 | 1. Sign up at https://openrouter.ai/ 95 | 2. Get your API key from the dashboard 96 | 3. 
Enter the key in the Settings tab of the application 97 | - The key will be saved automatically to api_keys.json 98 | - It will be loaded automatically when the application starts 99 | 100 | ## Troubleshooting 101 | 102 | ### Common Issues 103 | 104 | 1. **Import Errors**: Make sure all dependencies are installed correctly 105 | 2. **CUDA Issues**: Ensure you have the correct CUDA version for your PyTorch installation 106 | 3. **Memory Errors**: Reduce batch size or enable gradient checkpointing 107 | 4. **Model Loading Issues**: Check internet connection and Hugging Face credentials 108 | 109 | ### Compatibility Notes 110 | - The application is optimized for 8GB VRAM consumer GPUs 111 | - Training on CPU is possible but will be very slow 112 | - Some features require additional tools (llama.cpp, Ollama) to be installed separately 113 | -------------------------------------------------------------------------------- /services/template_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from transformers import AutoTokenizer 4 | 5 | def fetch_template_from_model(model_name): 6 | try: 7 | # Load the tokenizer 8 | tokenizer = AutoTokenizer.from_pretrained(model_name) 9 | 10 | # Get the chat template if it exists 11 | if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template: 12 | template = tokenizer.chat_template 13 | return True, template 14 | else: 15 | # Return a default template if none exists 16 | default_template = """{% for message in messages %} 17 | {% if message['role'] == 'user' %} 18 | User: {{ message['content'] }} 19 | {% elif message['role'] == 'assistant' %} 20 | Assistant: {{ message['content'] }} 21 | {% elif message['role'] == 'system' %} 22 | System: {{ message['content'] }} 23 | {% endif %} 24 | {% endfor %} 25 | {% if add_generation_prompt %} 26 | Assistant: 27 | {% endif %}""" 28 | return True, default_template 29 | 30 | except Exception as e: 31 | # Return a default template if we can't fetch from the model 32 | default_template = """{% for message in messages %} 33 | {% if message['role'] == 'user' %} 34 | User: {{ message['content'] }} 35 | {% elif message['role'] == 'assistant' %} 36 | Assistant: {{ message['content'] }} 37 | {% elif message['role'] == 'system' %} 38 | System: {{ message['content'] }} 39 | {% endif %} 40 | {% endfor %} 41 | {% if add_generation_prompt %} 42 | Assistant: 43 | {% endif %}""" 44 | return False, f"Could not fetch template from model: {str(e)}. Using default template." 
45 | 46 | def save_custom_template(template_content, template_name="custom_template"): 47 | try: 48 | # Create templates directory if it doesn't exist 49 | templates_dir = "templates" 50 | if not os.path.exists(templates_dir): 51 | os.makedirs(templates_dir) 52 | 53 | # Save the template 54 | template_path = os.path.join(templates_dir, f"{template_name}.txt") 55 | with open(template_path, "w", encoding="utf-8") as f: 56 | f.write(template_content) 57 | 58 | return True, f"Custom template saved successfully to {template_path}" 59 | 60 | except Exception as e: 61 | return False, f"Error saving template: {str(e)}" 62 | 63 | def load_custom_template(template_name="custom_template"): 64 | try: 65 | template_path = os.path.join("templates", f"{template_name}.txt") 66 | 67 | if not os.path.exists(template_path): 68 | return False, f"Template file not found: {template_path}" 69 | 70 | with open(template_path, "r", encoding="utf-8") as f: 71 | template_content = f.read() 72 | 73 | return True, template_content 74 | 75 | except Exception as e: 76 | return False, f"Error loading template: {str(e)}" 77 | 78 | def list_available_templates(): 79 | templates_dir = "templates" 80 | if not os.path.exists(templates_dir): 81 | return [] 82 | 83 | templates = [] 84 | for file in os.listdir(templates_dir): 85 | if file.endswith(".txt"): 86 | templates.append(file[:-4]) # Remove .txt extension 87 | 88 | return templates 89 | 90 | def apply_template_to_dataset(dataset_path, template_name="custom_template"): 91 | try: 92 | # Load the dataset 93 | with open(dataset_path, "r", encoding="utf-8") as f: 94 | dataset = json.load(f) 95 | 96 | # Load the template 97 | success, template = load_custom_template(template_name) 98 | if not success: 99 | return False, template # template contains the error message 100 | 101 | # In a real implementation, we would apply the template to the dataset 102 | # For now, we'll just return a success message 103 | message = f"Template '{template_name}' would be applied to dataset '{dataset_path}'.\n" 104 | message += "In a full implementation, this would format the dataset according to the template." 105 | 106 | return True, message 107 | 108 | except Exception as e: 109 | return False, f"Error applying template to dataset: {str(e)}" 110 | -------------------------------------------------------------------------------- /services/model_converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from pathlib import Path 5 | 6 | def check_llama_cpp_installed(): 7 | try: 8 | # Try to run llama.cpp help command 9 | result = subprocess.run(["llama-quantize", "--help"], 10 | capture_output=True, text=True, timeout=10) 11 | return True 12 | except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): 13 | return False 14 | 15 | def convert_to_gguf(model_path, output_path, quantization_type="Q4_K_M"): 16 | try: 17 | # Check if model path exists 18 | if not os.path.exists(model_path): 19 | return False, f"Model path does not exist: {model_path}" 20 | 21 | # Check if llama.cpp tools are available 22 | if not check_llama_cpp_installed(): 23 | return False, "llama.cpp tools not found. Please install llama.cpp first." 
24 | 25 | # Create output directory if it doesn't exist 26 | output_dir = os.path.dirname(output_path) 27 | if output_dir and not os.path.exists(output_dir): 28 | os.makedirs(output_dir) 29 | 30 | # Convert model to GGUF format 31 | 32 | # Check if the model is in a format that needs conversion to ggml first 33 | model_extension = os.path.splitext(model_path)[1].lower() 34 | 35 | if model_extension in ['.bin', '.pt', '.pth']: 36 | # Convert to ggml first 37 | cmd = ["python", "convert-hf-to-gguf.py", model_path] 38 | result = subprocess.run(cmd, capture_output=True, text=True) 39 | if result.returncode != 0: 40 | return False, f"Error converting model to ggml: {result.stderr}" 41 | 42 | # The conversion script typically outputs to the same directory 43 | # with a .gguf extension 44 | ggml_path = os.path.splitext(model_path)[0] + ".gguf" 45 | else: 46 | ggml_path = model_path 47 | 48 | # Quantize the model 49 | cmd = ["llama-quantize", ggml_path, output_path, quantization_type] 50 | result = subprocess.run(cmd, capture_output=True, text=True) 51 | 52 | if result.returncode != 0: 53 | return False, f"Error during quantization: {result.stderr}" 54 | 55 | message = f"Successfully converted {model_path} to GGUF format with {quantization_type} quantization.\n" 56 | message += f"Output saved to: {output_path}" 57 | 58 | return True, message 59 | 60 | except Exception as e: 61 | return False, f"Error during conversion: {str(e)}" 62 | 63 | def push_to_ollama(model_path, model_name): 64 | try: 65 | # Check if Ollama is installed and running 66 | try: 67 | result = subprocess.run(["ollama", "--version"], 68 | capture_output=True, text=True, timeout=10) 69 | ollama_available = True 70 | except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): 71 | ollama_available = False 72 | 73 | if not ollama_available: 74 | return False, "Ollama not found. Please install and start Ollama first." 
75 | 76 | # Check if model file exists 77 | if not os.path.exists(model_path): 78 | return False, f"Model file not found: {model_path}" 79 | 80 | # Create a temporary model file for Ollama 81 | model_file_content = f"""FROM {model_path} 82 | PARAMETER temperature 0.7 83 | PARAMETER stop Result 84 | PARAMETER stop Human 85 | PARAMETER stop ### 86 | """ 87 | 88 | # Create a temporary model file 89 | temp_model_file = f"/tmp/{model_name}.modelfile" 90 | with open(temp_model_file, 'w') as f: 91 | f.write(model_file_content) 92 | 93 | # Push the model to Ollama 94 | cmd = ["ollama", "create", model_name, "-f", temp_model_file] 95 | result = subprocess.run(cmd, capture_output=True, text=True) 96 | 97 | if result.returncode != 0: 98 | return False, f"Error creating Ollama model: {result.stderr}" 99 | 100 | message = f"Successfully pushed model to Ollama as '{model_name}'.\n" 101 | message += "You can now use the model with: ollama run " + model_name 102 | 103 | return True, message 104 | 105 | except Exception as e: 106 | return False, f"Error during Ollama push: {str(e)}" 107 | 108 | -------------------------------------------------------------------------------- /services/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 4 | from peft import LoraConfig, PeftModel 5 | from trl import SFTConfig, SFTTrainer 6 | import os 7 | import json 8 | 9 | def get_model_class(model_id): 10 | """Determine the appropriate model class based on model ID""" 11 | return AutoModelForCausalLM 12 | 13 | def get_torch_dtype(): 14 | """Get appropriate torch dtype based on GPU capabilities""" 15 | if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: 16 | return torch.bfloat16 17 | else: 18 | return torch.float16 19 | 20 | def load_model_and_tokenizer(model_id, use_4bit=True): 21 | """Load model and tokenizer with quantization for low VRAM usage""" 22 | # Determine model class 23 | model_class = get_model_class(model_id) 24 | 25 | # Get appropriate dtype 26 | torch_dtype = get_torch_dtype() 27 | 28 | # Model kwargs 29 | model_kwargs = dict( 30 | attn_implementation="eager", # Safer for older GPUs 31 | torch_dtype=torch_dtype, 32 | device_map="auto", 33 | ) 34 | 35 | # Add quantization config for 4-bit if requested 36 | if use_4bit: 37 | model_kwargs["quantization_config"] = BitsAndBytesConfig( 38 | load_in_4bit=True, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4', 41 | bnb_4bit_compute_dtype=torch_dtype, 42 | ) 43 | 44 | # Load model and tokenizer 45 | model = model_class.from_pretrained(model_id, **model_kwargs) 46 | tokenizer = AutoTokenizer.from_pretrained(model_id) 47 | 48 | # Set padding side for Gemma models 49 | tokenizer.padding_side = 'right' 50 | 51 | return model, tokenizer 52 | 53 | def create_lora_config(r=16, lora_alpha=32, lora_dropout=0.05, target_modules="all-linear"): 54 | """Create LoRA configuration""" 55 | return LoraConfig( 56 | r=r, 57 | lora_alpha=lora_alpha, 58 | lora_dropout=lora_dropout, 59 | target_modules=target_modules, 60 | bias="none", 61 | task_type="CAUSAL_LM", 62 | ) 63 | 64 | def create_training_args(output_dir="tuned", num_train_epochs=8, per_device_train_batch_size=1, 65 | gradient_accumulation_steps=8, learning_rate=1e-4, 66 | gradient_checkpointing=True, fp16=None, bf16=None): 67 | """Create training arguments""" 68 | torch_dtype = get_torch_dtype() 69 | 
70 | # Set precision flags if not provided 71 | if fp16 is None: 72 | fp16 = (torch_dtype == torch.float16) 73 | if bf16 is None: 74 | bf16 = (torch_dtype == torch.bfloat16) 75 | 76 | return SFTConfig( 77 | output_dir=output_dir, 78 | num_train_epochs=num_train_epochs, 79 | per_device_train_batch_size=per_device_train_batch_size, 80 | gradient_accumulation_steps=gradient_accumulation_steps, 81 | gradient_checkpointing=gradient_checkpointing, 82 | optim="paged_adamw_8bit", 83 | logging_steps=20, 84 | save_strategy="steps", 85 | save_steps=200, 86 | learning_rate=learning_rate, 87 | fp16=fp16, 88 | bf16=bf16, 89 | max_grad_norm=0.3, 90 | warmup_ratio=0.03, 91 | lr_scheduler_type="cosine", 92 | weight_decay=0.01, 93 | report_to="tensorboard", 94 | dataset_kwargs={ 95 | "add_special_tokens": True, 96 | "append_concat_token": True, 97 | }, 98 | ) 99 | 100 | def load_and_prepare_dataset(dataset_path, sample_size=5000, test_size=0.2): 101 | """Load and prepare dataset for training""" 102 | # Load dataset 103 | dataset = load_dataset("json", data_files=dataset_path, split="train") 104 | 105 | # Reduce dataset size if needed 106 | if len(dataset) > sample_size: 107 | dataset = dataset.select(range(sample_size)) 108 | 109 | # Split into train/test 110 | dataset = dataset.train_test_split(test_size=test_size) 111 | 112 | return dataset 113 | 114 | def train_model(model_id, dataset_path, output_dir="tuned", 115 | lora_r=16, lora_alpha=32, lora_dropout=0.05, target_modules="all-linear", 116 | num_train_epochs=8, per_device_train_batch_size=1, 117 | gradient_accumulation_steps=8, learning_rate=1e-4, 118 | gradient_checkpointing=True, fp16=None, bf16=None): 119 | """Main training function""" 120 | try: 121 | # Load model and tokenizer 122 | model, tokenizer = load_model_and_tokenizer(model_id) 123 | 124 | # Create LoRA config 125 | peft_config = create_lora_config( 126 | r=lora_r, 127 | lora_alpha=lora_alpha, 128 | lora_dropout=lora_dropout, 129 | target_modules=target_modules 130 | ) 131 | 132 | # Create training args 133 | training_args = create_training_args( 134 | output_dir=output_dir, 135 | num_train_epochs=num_train_epochs, 136 | per_device_train_batch_size=per_device_train_batch_size, 137 | gradient_accumulation_steps=gradient_accumulation_steps, 138 | learning_rate=learning_rate, 139 | gradient_checkpointing=gradient_checkpointing, 140 | fp16=fp16, 141 | bf16=bf16 142 | ) 143 | 144 | # Load and prepare dataset 145 | dataset = load_and_prepare_dataset(dataset_path) 146 | 147 | # Create trainer 148 | trainer = SFTTrainer( 149 | model=model, 150 | args=training_args, 151 | train_dataset=dataset["train"], 152 | peft_config=peft_config 153 | ) 154 | 155 | # Start training 156 | trainer.train() 157 | 158 | # Save model 159 | trainer.save_model() 160 | 161 | # Clean up memory 162 | del model 163 | del trainer 164 | torch.cuda.empty_cache() 165 | 166 | return True, f"Training completed successfully. 
Model saved to {output_dir}" 167 | 168 | except Exception as e: 169 | # Clean up memory in case of error 170 | torch.cuda.empty_cache() 171 | return False, f"Training failed with error: {str(e)}" 172 | 173 | def merge_lora_model(base_model_id, lora_model_path, output_path="merged_model"): 174 | """Merge LoRA weights with base model""" 175 | try: 176 | # Get appropriate dtype 177 | torch_dtype = get_torch_dtype() 178 | 179 | # Determine model class 180 | model_class = get_model_class(base_model_id) 181 | 182 | # Load base model 183 | base_model = model_class.from_pretrained( 184 | base_model_id, 185 | torch_dtype=torch_dtype, 186 | device_map="auto" 187 | ) 188 | 189 | # Load LoRA model 190 | lora_model = PeftModel.from_pretrained(base_model, lora_model_path) 191 | 192 | # Merge models 193 | merged_model = lora_model.merge_and_unload() 194 | 195 | # Save merged model 196 | merged_model.save_pretrained(output_path, safe_serialization=True, max_shard_size="2GB") 197 | 198 | # Save tokenizer 199 | tokenizer = AutoTokenizer.from_pretrained(lora_model_path) 200 | tokenizer.save_pretrained(output_path) 201 | 202 | # Clean up memory 203 | del base_model 204 | del lora_model 205 | del merged_model 206 | torch.cuda.empty_cache() 207 | 208 | return True, f"Merged model saved to {output_path}" 209 | 210 | except Exception as e: 211 | # Clean up memory in case of error 212 | torch.cuda.empty_cache() 213 | return False, f"Model merging failed with error: {str(e)}" -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | import json 4 | import torch 5 | from huggingface_hub import list_models, model_info 6 | import psutil 7 | import GPUtil 8 | import requests 9 | 10 | # Import our existing modules 11 | from services.openrouter_api import generate 12 | 13 | from services.create_dataset import create_dataset 14 | from services.train_model import train_model, merge_lora_model 15 | from services.model_converter import convert_to_gguf, push_to_ollama 16 | from services.template_manager import fetch_template_from_model, save_custom_template 17 | 18 | # Check if CUDA is available 19 | CUDA_AVAILABLE = torch.cuda.is_available() 20 | DEVICE = "cuda" if CUDA_AVAILABLE else "cpu" 21 | 22 | # Global variables to store state 23 | current_model = None 24 | current_tokenizer = None 25 | training_status = "Not started" 26 | 27 | # Global variables for API keys 28 | openrouter_key = "" 29 | 30 | # File to store API keys 31 | API_KEYS_FILE = "api_keys.json" 32 | 33 | # Load API keys from file at startup 34 | def load_api_keys(): 35 | """Load API keys from file""" 36 | global openrouter_key 37 | try: 38 | if os.path.exists(API_KEYS_FILE): 39 | with open(API_KEYS_FILE, 'r') as f: 40 | keys = json.load(f) 41 | openrouter_key = keys.get("openrouter_key", "") 42 | # Update the OpenRouter client with the loaded key 43 | if openrouter_key: 44 | from services.openrouter_api import update_openrouter_client 45 | update_openrouter_client(openrouter_key) 46 | except Exception as e: 47 | print(f"Error loading API keys: {e}") 48 | 49 | # Model selection history (last 5 models) 50 | model_history = [] 51 | 52 | # Function to fetch models from Hugging Face 53 | def fetch_hf_models(): 54 | try: 55 | # Fetch popular language models 56 | models = list_models(filter="text-generation", sort="downloads", direction=-1, limit=20) 57 | model_names = [model.id for model in models] 58 | return 
model_names 59 | except Exception as e: 60 | print(f"Error fetching models from Hugging Face: {e}") 61 | # Return default models if API fails 62 | return [ 63 | "google/gemma-3-4b-it", 64 | "google/gemma-3-2b-it" 65 | ] 66 | 67 | # Function to fetch models from OpenRouter 68 | def fetch_openrouter_models(): 69 | try: 70 | url = "https://openrouter.ai/api/v1/models" 71 | headers = {} 72 | # Add authentication header if API key is available 73 | if openrouter_key: 74 | headers["Authorization"] = f"Bearer {openrouter_key}" 75 | response = requests.get(url, headers=headers) 76 | if response.status_code == 200: 77 | data = response.json() 78 | # Extract model IDs from the response 79 | model_names = [model["id"] for model in data["data"]] 80 | # Filter for text generation models and sort by popularity 81 | text_models = [model for model in model_names if "embedding" not in model.lower()] 82 | return text_models[:50] # Return top 50 models 83 | else: 84 | print(f"Error fetching models from OpenRouter: {response.status_code}") 85 | return [] 86 | except Exception as e: 87 | print(f"Error fetching models from OpenRouter: {e}") 88 | return [] 89 | 90 | # Function to search models 91 | def search_models(query, all_models): 92 | """Filter models based on search query""" 93 | if not query: 94 | return all_models[:50] # Return top 50 models if no query 95 | filtered_models = [model for model in all_models if query.lower() in model.lower()] 96 | return filtered_models[:50] # Return top 50 matching models 97 | 98 | # Function to search across all available models (HF and OpenRouter) 99 | def search_all_models(query): 100 | """Search across all available models from both HuggingFace and OpenRouter""" 101 | all_models = [] 102 | 103 | # Get HuggingFace models 104 | try: 105 | hf_models = fetch_hf_models() 106 | all_models.extend(hf_models) 107 | except Exception as e: 108 | print(f"Error fetching HuggingFace models: {e}") 109 | 110 | # Get OpenRouter models 111 | try: 112 | or_models = fetch_openrouter_models() 113 | all_models.extend(or_models) 114 | except Exception as e: 115 | print(f"Error fetching OpenRouter models: {e}") 116 | 117 | # Search through all models 118 | return search_models(query, all_models) 119 | 120 | # Function to add model to history 121 | def add_to_model_history(model_name): 122 | """Add model to history, keeping only the last 5""" 123 | global model_history 124 | if model_name in model_history: 125 | model_history.remove(model_name) 126 | model_history.insert(0, model_name) 127 | model_history = model_history[:5] # Keep only last 5 128 | return model_history 129 | 130 | # Function to validate if a model exists on HuggingFace 131 | def validate_hf_model(model_name): 132 | """Check if a model exists on HuggingFace""" 133 | try: 134 | # Try to fetch model info 135 | from huggingface_hub import model_info as hf_model_info 136 | hf_model_info(model_name) 137 | return True 138 | except Exception as e: 139 | print(f"Model {model_name} not found on HuggingFace: {e}") 140 | return False 141 | 142 | # Function to validate if a model exists on OpenRouter 143 | def validate_or_model(model_name): 144 | """Check if a model exists on OpenRouter""" 145 | try: 146 | # Fetch OpenRouter models and check if the model is in the list 147 | or_models = fetch_openrouter_models() 148 | return model_name in or_models 149 | except Exception as e: 150 | print(f"Error checking model {model_name} on OpenRouter: {e}") 151 | return False 152 | 153 | # Function to validate model existence on both platforms 154 | 
def validate_model(model_name): 155 | """Check if a model exists on either HuggingFace or OpenRouter""" 156 | if not model_name or model_name == "": 157 | return "Please enter a model name", "red" 158 | 159 | # Check HuggingFace first 160 | if validate_hf_model(model_name): 161 | return f"✅ Model '{model_name}' found on HuggingFace", "green" 162 | 163 | # Check OpenRouter if not found on HuggingFace 164 | if validate_or_model(model_name): 165 | return f"✅ Model '{model_name}' found on OpenRouter", "green" 166 | 167 | # If not found on either platform 168 | return f"❌ Model '{model_name}' not found on HuggingFace", "red" 169 | 170 | # Function to load model info 171 | def load_model_info(model_name): 172 | info = f"Model: {model_name}\n" 173 | info += f"Device: {DEVICE}\n" 174 | if CUDA_AVAILABLE: 175 | info += f"CUDA Version: {torch.version.cuda}\n" 176 | info += f"GPU: {torch.cuda.get_device_name(0)}\n" 177 | info += f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\n" 178 | else: 179 | info += "Running on CPU (training will be slow)\n" 180 | return info 181 | 182 | # Function to create dataset from conversation builder 183 | def create_dataset_gradio(conversation_history, system_prompt, model_name, output_file, num_examples, append_to_existing=False): 184 | # Save system prompt to temporary file 185 | with open("temp_system_prompt.md", "w", encoding="utf-8") as f: 186 | f.write(system_prompt) 187 | 188 | try: 189 | dataset = [] 190 | 191 | # If appending to existing dataset, load existing data 192 | if append_to_existing and os.path.exists(output_file): 193 | try: 194 | with open(output_file, 'r', encoding='utf8') as f: 195 | dataset = json.load(f) 196 | except Exception as e: 197 | print(f"Error loading existing dataset: {e}") 198 | dataset = [] # Start fresh if there's an error loading existing data 199 | 200 | # Generate multiple examples based on conversation history 201 | for i in range(num_examples): 202 | # Convert conversation history to messages format 203 | messages = [] 204 | for user_msg, assistant_msg in conversation_history: 205 | if user_msg: 206 | messages.append({"role": "user", "content": user_msg}) 207 | if assistant_msg: 208 | messages.append({"role": "assistant", "content": assistant_msg}) 209 | 210 | # Generate response based on the full conversation history 211 | from services.openrouter_api import generate 212 | assistant_response = generate( 213 | messages=messages, 214 | prompt_name="temp_system_prompt", 215 | model=model_name 216 | ) 217 | 218 | if assistant_response: 219 | # Create dataset entry with full conversation + generated response 220 | dataset_entry = { 221 | "messages": [ 222 | {"role": "system", "content": system_prompt} 223 | ] 224 | } 225 | 226 | # Add all messages from conversation history 227 | for user_msg, assistant_msg in conversation_history: 228 | if user_msg: 229 | dataset_entry["messages"].append({"role": "user", "content": user_msg}) 230 | if assistant_msg: 231 | dataset_entry["messages"].append({"role": "assistant", "content": assistant_msg}) 232 | 233 | # Add the generated response as the final assistant message 234 | dataset_entry["messages"].append({"role": "assistant", "content": assistant_response}) 235 | 236 | dataset.append(dataset_entry) 237 | 238 | # Save dataset to file 239 | with open(output_file, 'w', encoding='utf8') as f: 240 | json.dump(dataset, f, ensure_ascii=False, indent=2) 241 | 242 | # Read the created dataset 243 | if os.path.exists(output_file): 244 | with open(output_file, "r", 
encoding="utf-8") as f: 245 | dataset_content = f.read() 246 | action = "appended to" if append_to_existing and os.path.exists(output_file) else "saved to" 247 | # Return status message and parsed dataset for preview 248 | parsed_dataset = preview_dataset_content(dataset_content) 249 | return f"Dataset {action} successfully! Generated {num_examples} examples. Total entries: {len(dataset)}. Saved to {output_file}", parsed_dataset 250 | else: 251 | return "Error: Dataset file was not created", "" 252 | except Exception as e: 253 | return f"Error creating dataset: {str(e)}", "" 254 | 255 | # Function to load dataset 256 | def load_dataset_preview(dataset_file): 257 | try: 258 | if dataset_file is None: 259 | return "Please upload a dataset file" 260 | 261 | if dataset_file.name.endswith('.json'): 262 | with open(dataset_file.name, 'r', encoding='utf-8') as f: 263 | dataset = json.load(f) 264 | 265 | # Prepare data for DataFrame 266 | data = [] 267 | headers = ["Entry", "Role", "Content"] 268 | 269 | # Show first few entries 270 | for i, entry in enumerate(dataset[:5]): # Show first 5 entries 271 | for message in entry['messages']: 272 | data.append([f"Entry {i+1}", message['role'], message['content'][:200] + "..." if len(message['content']) > 200 else message['content']]) 273 | 274 | return { "headers": headers, "data": data } 275 | else: 276 | return "Unsupported file format. Please upload a JSON file." 277 | except Exception as e: 278 | return f"Error loading dataset: {str(e)}" 279 | # Function to preview dataset content (for use with string content) 280 | def preview_dataset_content(dataset_content): 281 | try: 282 | if not dataset_content: 283 | return "No dataset content to preview" 284 | 285 | # Parse the JSON content 286 | dataset = json.loads(dataset_content) 287 | 288 | # Prepare data for DataFrame 289 | data = [] 290 | headers = ["Entry", "Role", "Content"] 291 | 292 | # Show first few entries 293 | for i, entry in enumerate(dataset[:5]): # Show first 5 entries 294 | for message in entry['messages']: 295 | data.append([f"Entry {i+1}", message['role'], message['content'][:200] + "..." 
if len(message['content']) > 200 else message['content']]) 296 | 297 | return { "headers": headers, "data": data } 298 | except Exception as e: 299 | return f"Error previewing dataset: {str(e)}" 300 | 301 | # Training functions 302 | def start_training_wrapper(model_name, dataset_file, lora_r, lora_alpha, lora_dropout, 303 | target_modules, num_epochs, batch_size, grad_accum, learning_rate, 304 | gradient_checkpointing, fp16): 305 | global training_status 306 | 307 | # Handle dataset path 308 | if dataset_file is None: 309 | return "Please select a dataset file" 310 | 311 | dataset_file_path = dataset_file.name if hasattr(dataset_file, 'name') else dataset_file 312 | 313 | # Convert target_modules to proper format 314 | if target_modules == "all-linear": 315 | target_modules = "all-linear" 316 | elif target_modules == "q_proj,v_proj": 317 | target_modules = ["q_proj", "v_proj"] 318 | elif target_modules == "k_proj,o_proj": 319 | target_modules = ["k_proj", "o_proj"] 320 | 321 | # Start training 322 | success, message = train_model( 323 | model_id=model_name, 324 | dataset_path=dataset_file_path, 325 | output_dir="tuned_model", 326 | lora_r=lora_r, 327 | lora_alpha=lora_alpha, 328 | lora_dropout=lora_dropout, 329 | target_modules=target_modules, 330 | num_train_epochs=num_epochs, 331 | per_device_train_batch_size=batch_size, 332 | gradient_accumulation_steps=grad_accum, 333 | learning_rate=learning_rate, 334 | gradient_checkpointing=gradient_checkpointing, 335 | fp16=fp16 336 | ) 337 | 338 | training_status = message 339 | return message 340 | 341 | def stop_training(): 342 | return "Training stopped." 343 | 344 | # Model conversion functions 345 | def convert_to_gguf_wrapper(model_path, quantization_type, output_format): 346 | # Convert output_format to file extension 347 | ext = ".gguf" if output_format == "gguf" else ".bin" 348 | 349 | # Generate output path 350 | model_name = os.path.basename(model_path) 351 | output_path = f"converted_models/{model_name}{ext}" 352 | 353 | # Call the actual conversion function 354 | success, message = convert_to_gguf(model_path, output_path, quantization_type) 355 | return message 356 | 357 | def push_to_ollama_wrapper(model_path, ollama_model_name): 358 | # Call the actual Ollama push function 359 | success, message = push_to_ollama(model_path, ollama_model_name) 360 | return message 361 | 362 | # Template management functions 363 | def fetch_template_from_model_wrapper(model_name): 364 | success, template = fetch_template_from_model(model_name) 365 | if success: 366 | return template 367 | else: 368 | return f"Error: {template}" 369 | 370 | def save_custom_template_wrapper(template_content): 371 | success, message = save_custom_template(template_content, "custom_template") 372 | return message 373 | 374 | # Conversation builder functions 375 | def add_user_message(history, message): 376 | """Add user message to conversation history""" 377 | if history is None: 378 | history = [] 379 | if message: 380 | history.append((message, None)) 381 | return history, "" 382 | 383 | def add_assistant_message(history, message): 384 | """Add assistant message to conversation history""" 385 | if history is None: 386 | history = [] 387 | if message and len(history) > 0: 388 | # Update the last entry with the assistant message 389 | history[-1] = (history[-1][0], message) 390 | return history, "" 391 | 392 | def clear_conversation(): 393 | """Clear conversation history""" 394 | return [], "", "" 395 | 396 | def save_conversation_template(history, template_name): 397 | 
"""Save conversation as a template""" 398 | if not template_name: 399 | return "Please enter a template name" 400 | 401 | if not history: 402 | return "Conversation is empty" 403 | 404 | try: 405 | # Convert history to JSON format 406 | template_data = [] 407 | for user_msg, assistant_msg in history: 408 | if user_msg: 409 | template_data.append({"role": "user", "content": user_msg}) 410 | if assistant_msg: 411 | template_data.append({"role": "assistant", "content": assistant_msg}) 412 | 413 | # Save to file 414 | templates_dir = "templates" 415 | if not os.path.exists(templates_dir): 416 | os.makedirs(templates_dir) 417 | 418 | template_path = os.path.join(templates_dir, f"{template_name}.json") 419 | with open(template_path, "w", encoding="utf-8") as f: 420 | json.dump(template_data, f, ensure_ascii=False, indent=2) 421 | 422 | return f"Template '{template_name}' saved successfully!" 423 | except Exception as e: 424 | return f"Error saving template: {str(e)}" 425 | 426 | def load_conversation_templates(): 427 | """Load available conversation templates""" 428 | templates_dir = "templates" 429 | if not os.path.exists(templates_dir): 430 | return [] 431 | 432 | templates = [] 433 | for file in os.listdir(templates_dir): 434 | if file.endswith(".json"): 435 | templates.append(file[:-5]) # Remove .json extension 436 | return templates 437 | 438 | # Settings functions 439 | def save_api_keys(openrouter_key_input): 440 | """Save API keys to environment variables and file""" 441 | global openrouter_key 442 | if openrouter_key_input: 443 | openrouter_key = openrouter_key_input 444 | os.environ["OPENROUTER_KEY"] = openrouter_key_input 445 | # Update the OpenRouter client with the new key 446 | from services.openrouter_api import update_openrouter_client 447 | update_openrouter_client(openrouter_key_input) 448 | # Save to file 449 | try: 450 | keys = {"openrouter_key": openrouter_key_input} 451 | with open(API_KEYS_FILE, 'w') as f: 452 | json.dump(keys, f) 453 | return "API keys saved successfully!" 454 | except Exception as e: 455 | return f"API keys saved to environment, but failed to save to file: {e}" 456 | return "Please enter a valid OpenRouter API key." 
457 | 458 | def validate_openrouter_key(key_to_test): 459 | """Validate the OpenRouter API key""" 460 | if not key_to_test: 461 | return "No API key provided" 462 | 463 | try: 464 | # Test the key by fetching models 465 | url = "https://openrouter.ai/api/v1/models" 466 | headers = {"Authorization": f"Bearer {key_to_test}"} 467 | response = requests.get(url, headers=headers, timeout=10) 468 | 469 | if response.status_code == 200: 470 | return "✅ API key is valid" 471 | elif response.status_code == 401: 472 | return "❌ API key is invalid" 473 | else: 474 | return f"⚠️ Unexpected response: {response.status_code}" 475 | except Exception as e: 476 | return f"❌ Error validating key: {str(e)}" 477 | 478 | def get_system_info(): 479 | """Get system information""" 480 | info = "## System Information\n\n" 481 | 482 | # CPU Info 483 | info += f"**CPU:** {psutil.cpu_count()} cores\n" 484 | info += f"**RAM:** {psutil.virtual_memory().total / (1024**3):.1f} GB\n" 485 | 486 | # GPU Info 487 | if CUDA_AVAILABLE: 488 | info += f"**GPU:** {torch.cuda.get_device_name(0)}\n" 489 | info += f"**VRAM:** {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB\n" 490 | info += f"**CUDA Version:** {torch.version.cuda}\n" 491 | else: 492 | # Try to get GPU info using GPUtil 493 | try: 494 | gpus = GPUtil.getGPUs() 495 | if gpus: 496 | gpu = gpus[0] 497 | info += f"**GPU:** {gpu.name}\n" 498 | info += f"**VRAM:** {gpu.memoryTotal / 1024:.1f} GB\n" 499 | else: 500 | info += "**GPU:** Not available (running on CPU)\n" 501 | except: 502 | info += "**GPU:** Not available (running on CPU)\n" 503 | 504 | # Disk Space 505 | disk_usage = psutil.disk_usage('/') 506 | info += f"**Disk Space:** {disk_usage.total / (1024**3):.1f} GB total\n" 507 | 508 | return info 509 | 510 | def get_help_info(): 511 | """Get help information""" 512 | try: 513 | with open("help_text.md", "r", encoding="utf-8") as f: 514 | help_text = f.read() 515 | return help_text 516 | except FileNotFoundError: 517 | return "Help file not found." 
518 | except Exception as e: 519 | return f"Error reading help file: {str(e)}" 520 | 521 | # Load API keys at startup 522 | load_api_keys() 523 | 524 | # Main Gradio interface 525 | with gr.Blocks(title="QTune", css=""" 526 | .main-action-btn { 527 | background: linear-gradient(45deg, #FF6B6B, #4ECDC4); 528 | border: none; 529 | color: white; 530 | padding: 12px 24px; 531 | text-align: center; 532 | text-decoration: none; 533 | display: inline-block; 534 | font-size: 16px; 535 | margin: 4px 2px; 536 | cursor: pointer; 537 | border-radius: 12px; 538 | font-weight: bold; 539 | box-shadow: 0 4px 8px rgba(0,0,0,0.2); 540 | transition: all 0.3s ease; 541 | } 542 | .main-action-btn:hover { 543 | transform: translateY(-2px); 544 | box-shadow: 0 6px 12px rgba(0,0,0,0.3); 545 | } 546 | .model-history-btn { 547 | margin: 2px; 548 | font-size: 12px; 549 | padding: 4px 8px; 550 | } 551 | """) as demo: 552 | gr.Markdown("# 🚀 QTune") 553 | gr.Markdown("### Fine-tune language models on consumer GPUs") 554 | 555 | with gr.Tabs(): 556 | # Model Selection Tab 557 | with gr.TabItem("🤖 Model Selection"): 558 | gr.Markdown("## Select a Model for Fine-tuning") 559 | 560 | with gr.Row(): 561 | with gr.Column(scale=1): 562 | model_history_container = gr.Column(visible=False) 563 | with model_history_container: 564 | gr.Markdown("### 🕒 Recent Models") 565 | model_history_btn1 = gr.Button(visible=False, variant="secondary", size="sm", elem_classes=["model-history-btn"]) 566 | model_history_btn2 = gr.Button(visible=False, variant="secondary", size="sm", elem_classes=["model-history-btn"]) 567 | model_history_btn3 = gr.Button(visible=False, variant="secondary", size="sm", elem_classes=["model-history-btn"]) 568 | model_history_btn4 = gr.Button(visible=False, variant="secondary", size="sm", elem_classes=["model-history-btn"]) 569 | model_history_btn5 = gr.Button(visible=False, variant="secondary", size="sm", elem_classes=["model-history-btn"]) 570 | 571 | with gr.Column(scale=2): 572 | model_dropdown = gr.Dropdown( 573 | choices=fetch_hf_models(), 574 | label="🎯 Popular Models", 575 | value="google/gemma-3-4b-it" if fetch_hf_models() else None, 576 | allow_custom_value=True 577 | ) 578 | 579 | # Add a search function for the model dropdown 580 | def update_model_choices(query): 581 | """Update model choices based on search query""" 582 | if query and len(query) > 2: # Only search if query is long enough 583 | search_results = search_all_models(query) 584 | # Include the current query as a valid choice even if not in search results 585 | if query not in search_results: 586 | search_results.insert(0, query) 587 | return gr.Dropdown(choices=search_results) 588 | else: 589 | # If no query or too short, show default HF models 590 | default_models = fetch_hf_models() 591 | return gr.Dropdown(choices=default_models) 592 | 593 | # Add Apply Model button and status display 594 | with gr.Row(): 595 | apply_model_btn = gr.Button("Apply Model", variant="primary") 596 | model_validation_status = gr.Textbox( 597 | label="Model Status", 598 | interactive=False 599 | ) 600 | 601 | gr.Markdown("[Browse Hugging Face Models](https://huggingface.co/models)") 602 | 603 | # Function to handle model application 604 | def apply_model(model_name): 605 | """Validate and apply the selected model""" 606 | status, color = validate_model(model_name) 607 | # Return the status message 608 | return status 609 | 610 | # Connect the Apply Model button 611 | apply_model_btn.click( 612 | fn=apply_model, 613 | inputs=model_dropdown, 
outputs=model_validation_status 615 | ) 616 | 617 | with gr.Row(): 618 | with gr.Column(scale=2): 619 | model_info = gr.Textbox(label="📊 Model Information", interactive=False, lines=8) 620 | 621 | def update_model_info_and_history(model_name): 622 | # Add to history only if it's a valid model selection 623 | if model_name and model_name != "": 624 | # Check if model_name is in the current list or validate it on HuggingFace 625 | current_models = fetch_hf_models() or [] 626 | if model_name in current_models or validate_hf_model(model_name): 627 | history = add_to_model_history(model_name) 628 | else: 629 | history = model_history # Use existing history 630 | else: 631 | history = model_history # Use existing history 632 | # Update model info 633 | info = load_model_info(model_name) 634 | # Update history buttons 635 | updates = [] 636 | for i in range(5): 637 | if i < len(history): 638 | updates.extend([gr.Button(visible=True, value=history[i]), history[i]]) 639 | else: 640 | updates.extend([gr.Button(visible=False), ""]) 641 | return [info, gr.Column(visible=len(history) > 0)] + updates 642 | 643 | def select_history_model(model_name): 644 | return model_name 645 | 646 | model_dropdown.change( 647 | fn=update_model_info_and_history, 648 | inputs=model_dropdown, 649 | outputs=[model_info, model_history_container, 650 | model_history_btn1, model_history_btn1, 651 | model_history_btn2, model_history_btn2, 652 | model_history_btn3, model_history_btn3, 653 | model_history_btn4, model_history_btn4, 654 | model_history_btn5, model_history_btn5] 655 | ) 656 | 657 | # Connect history buttons 658 | model_history_btn1.click(fn=select_history_model, inputs=model_history_btn1, outputs=model_dropdown) 659 | model_history_btn2.click(fn=select_history_model, inputs=model_history_btn2, outputs=model_dropdown) 660 | model_history_btn3.click(fn=select_history_model, inputs=model_history_btn3, outputs=model_dropdown) 661 | model_history_btn4.click(fn=select_history_model, inputs=model_history_btn4, outputs=model_dropdown) 662 | model_history_btn5.click(fn=select_history_model, inputs=model_history_btn5, outputs=model_dropdown) 663 | 664 | # Dataset Preparation Tab 665 | with gr.TabItem("📂 Dataset Preparation"): 666 | gr.Markdown("## Prepare Your Dataset") 667 | 668 | with gr.Tabs(): 669 | with gr.TabItem("📝 Create Dataset"): 670 | with gr.Row(): 671 | with gr.Column(): 672 | gr.Markdown("### 💬 Conversation Builder") 673 | conversation_builder = gr.Chatbot( 674 | label="Build Conversation Template", 675 | height=400 676 | ) 677 | 678 | with gr.Row(): 679 | user_msg = gr.Textbox( 680 | label="User Message", 681 | placeholder="Enter user message...", 682 | scale=3 683 | ) 684 | add_user_btn = gr.Button("👤 Add User Message", scale=1) 685 | 686 | with gr.Row(): 687 | assistant_msg = gr.Textbox( 688 | label="Assistant Message", 689 | placeholder="Enter assistant message...", 690 | scale=3 691 | ) 692 | add_assistant_btn = gr.Button("🤖 Add Assistant Message", scale=1) 693 | 694 | with gr.Row(): 695 | clear_conv_btn = gr.Button("🗑️ Clear Conversation") 696 | save_template_btn = gr.Button("💾 Save Template") 697 | 698 | template_name = gr.Textbox( 699 | label="Template Name", 700 | placeholder="Enter template name..." 701 | ) 702 | 703 | with gr.Column(): 704 | gr.Markdown("### ⚙️ Dataset Generation") 705 | system_prompt_input = gr.Textbox( 706 | label="System Prompt", 707 | lines=3, 708 | placeholder="Enter system prompt for the model..." 
709 | ) 710 | 711 | num_examples = gr.Number( 712 | label="Number of Examples to Generate", 713 | value=1, 714 | precision=0 715 | ) 716 | 717 | 718 | model_selector = gr.Dropdown( 719 | choices=fetch_openrouter_models() or ["⚠️ OpenRouter not connected - Please set API key in Settings"], 720 | label="Generation Model", 721 | value=(fetch_openrouter_models()[0] if fetch_openrouter_models() else 722 | "⚠️ OpenRouter not connected - Please set API key in Settings"), 723 | allow_custom_value=True 724 | ) 725 | gr.Markdown("[Browse OpenRouter Models](https://openrouter.ai/models)") 726 | output_filename = gr.Textbox( 727 | label="Output Filename", 728 | value="dataset.json" 729 | ) 730 | 731 | append_to_existing = gr.Checkbox( 732 | label="Append to existing dataset", 733 | value=False 734 | ) 735 | 736 | create_btn = gr.Button("🚀 Create Dataset", variant="primary", elem_classes=["main-action-btn"]) 737 | 738 | gr.Markdown("### 📊 Dataset Creation") 739 | dataset_output = gr.Textbox( 740 | label="Status", 741 | interactive=False 742 | ) 743 | dataset_preview = gr.Dataframe( 744 | label="Dataset Preview", 745 | interactive=False 746 | ) 747 | 748 | with gr.TabItem("📁 Load Dataset"): 749 | with gr.Row(): 750 | dataset_file = gr.File(label="📁 Upload Dataset File", file_types=[".json"]) 751 | with gr.Column(): 752 | load_dataset_btn = gr.Button("📥 Load Dataset") 753 | dataset_info = gr.Dataframe( 754 | label="📊 Dataset Information", 755 | interactive=False 756 | ) 757 | 758 | # Event handlers for Dataset Preparation 759 | create_btn.click( 760 | fn=create_dataset_gradio, 761 | inputs=[conversation_builder, system_prompt_input, model_selector, output_filename, num_examples, append_to_existing], 762 | outputs=[dataset_output, dataset_preview] 763 | ) 764 | 765 | load_dataset_btn.click( 766 | fn=load_dataset_preview, 767 | inputs=dataset_file, 768 | outputs=dataset_info 769 | ) 770 | 771 | # Event handlers for Conversation Builder 772 | add_user_btn.click( 773 | fn=add_user_message, 774 | inputs=[conversation_builder, user_msg], 775 | outputs=[conversation_builder, user_msg] 776 | ) 777 | 778 | add_assistant_btn.click( 779 | fn=add_assistant_message, 780 | inputs=[conversation_builder, assistant_msg], 781 | outputs=[conversation_builder, assistant_msg] 782 | ) 783 | 784 | clear_conv_btn.click( 785 | fn=clear_conversation, 786 | inputs=None, 787 | outputs=[conversation_builder, user_msg, assistant_msg] 788 | ) 789 | 790 | save_template_btn.click( 791 | fn=save_conversation_template, 792 | inputs=[conversation_builder, template_name], 793 | outputs=dataset_output 794 | ) 795 | 796 | # Training Configuration Tab 797 | with gr.TabItem("⚙️ Training Configuration"): 798 | gr.Markdown("## Configure Training Parameters") 799 | 800 | with gr.Row(): 801 | with gr.Column(): 802 | gr.Markdown("### 🔧 Primary Parameters") 803 | lora_r = gr.Slider( 804 | minimum=1, maximum=128, value=16, step=1, 805 | label="Rank (r)" 806 | ) 807 | lora_alpha = gr.Slider( 808 | minimum=1, maximum=256, value=32, step=1, 809 | label="Alpha" 810 | ) 811 | num_epochs = gr.Slider( 812 | minimum=1, maximum=50, value=8, step=1, 813 | label="Number of Epochs" 814 | ) 815 | batch_size = gr.Slider( 816 | minimum=1, maximum=16, value=1, step=1, 817 | label="Batch Size" 818 | ) 819 | 820 | with gr.Column(): 821 | with gr.Accordion("🔧 QLoRA Parameters", open=False): 822 | lora_dropout = gr.Slider( 823 | minimum=0.0, maximum=0.5, value=0.05, step=0.01, 824 | label="Dropout" 825 | ) 826 | target_modules = gr.Radio( 827 | choices=["all-linear", 
"q_proj,v_proj", "k_proj,o_proj"], 828 | value="all-linear", 829 | label="Target Modules" 830 | ) 831 | 832 | with gr.Accordion("💾 Memory Optimization", open=False): 833 | gradient_checkpointing = gr.Checkbox( 834 | label="Gradient Checkpointing", 835 | value=True 836 | ) 837 | fp16 = gr.Checkbox( 838 | label="FP16 Precision", 839 | value=True 840 | ) 841 | optim = gr.Dropdown( 842 | choices=["paged_adamw_8bit", "adamw_torch", "adamw_hf"], 843 | value="paged_adamw_8bit", 844 | label="Optimizer", 845 | visible=False 846 | ) 847 | 848 | with gr.Accordion("📊 Logging & Saving", open=False): 849 | logging_steps = gr.Number( 850 | label="Logging Steps", 851 | value=20, 852 | precision=0 853 | ) 854 | save_steps = gr.Number( 855 | label="Save Steps", 856 | value=200, 857 | precision=0 858 | ) 859 | save_strategy = gr.Radio( 860 | choices=["steps", "epoch"], 861 | value="steps", 862 | label="Save Strategy" 863 | ) 864 | grad_accum = gr.Slider( 865 | minimum=1, maximum=32, value=8, step=1, 866 | label="Gradient Accumulation Steps" 867 | ) 868 | learning_rate = gr.Number( 869 | label="Learning Rate", 870 | value=1e-4, 871 | precision=6 872 | ) 873 | 874 | # Add some spacing 875 | gr.Markdown("") 876 | 877 | # Training Execution Tab 878 | with gr.TabItem("🏃 Training"): 879 | gr.Markdown("## Start Training Process") 880 | 881 | with gr.Row(): 882 | train_btn = gr.Button("🚀 Start Training", variant="primary", elem_classes=["main-action-btn"]) 883 | stop_btn = gr.Button("⏹️ Stop Training", variant="secondary") 884 | 885 | with gr.Row(): 886 | training_progress = gr.Slider( 887 | minimum=0, maximum=100, value=0, 888 | label="Training Progress", 889 | interactive=False 890 | ) 891 | training_progress_label = gr.Label(value="0% done", label="Progress") 892 | 893 | with gr.Row(): 894 | training_logs = gr.Code( 895 | label="📋 Training Logs", 896 | lines=20, 897 | interactive=False 898 | ) 899 | 900 | # Event handlers for Training 901 | train_btn.click( 902 | fn=start_training_wrapper, 903 | inputs=[model_dropdown, dataset_file, lora_r, lora_alpha, lora_dropout, target_modules, 904 | num_epochs, batch_size, grad_accum, learning_rate, gradient_checkpointing, fp16], 905 | outputs=training_logs 906 | ) 907 | 908 | stop_btn.click( 909 | fn=stop_training, 910 | outputs=training_logs 911 | ) 912 | 913 | # Model Conversion Tab 914 | with gr.TabItem("🔄 Model Conversion"): 915 | gr.Markdown("## Convert Model to GGUF and Quantize") 916 | 917 | with gr.Row(): 918 | with gr.Column(): 919 | gr.Markdown("### 📦 GGUF Conversion") 920 | quantization_type = gr.Dropdown( 921 | choices=["Q4_K_M", "Q5_K_M", "Q8_0", "F16"], 922 | value="Q4_K_M", 923 | label="Quantization Type" 924 | ) 925 | output_format = gr.Radio( 926 | choices=["gguf", "bin"], 927 | value="gguf", 928 | label="Output Format" 929 | ) 930 | model_path_input = gr.Textbox( 931 | label="📁 Model Path", 932 | placeholder="Path to your trained model" 933 | ) 934 | convert_btn = gr.Button("🔨 Convert Model", variant="primary") 935 | 936 | with gr.Column(): 937 | gr.Markdown("### 🐳 Ollama Integration") 938 | ollama_model_name = gr.Textbox( 939 | label="🏷️ Ollama Model Name", 940 | placeholder="Enter model name for Ollama" 941 | ) 942 | push_to_ollama_btn = gr.Button("📤 Push to Ollama", variant="secondary") 943 | 944 | with gr.Row(): 945 | conversion_logs = gr.Textbox( 946 | label="📋 Conversion Logs", 947 | lines=10, 948 | interactive=False 949 | ) 950 | 951 | # Event handlers for Conversion 952 | convert_btn.click( 953 | fn=convert_to_gguf_wrapper, 954 | 
inputs=[model_path_input, quantization_type, output_format], 955 | outputs=conversion_logs 956 | ) 957 | 958 | push_to_ollama_btn.click( 959 | fn=push_to_ollama_wrapper, 960 | inputs=[model_path_input, ollama_model_name], 961 | outputs=conversion_logs 962 | ) 963 | 964 | # Template Management Tab 965 | with gr.TabItem("📝 Template Management"): 966 | gr.Markdown("## Chat Template Configuration") 967 | 968 | with gr.Row(): 969 | with gr.Column(): 970 | gr.Markdown("### 🤖 Automatic Template") 971 | auto_template_btn = gr.Button("🔍 Fetch Template from Model") 972 | template_preview = gr.Textbox( 973 | label="👁️ Template Preview", 974 | lines=10, 975 | interactive=False 976 | ) 977 | 978 | with gr.Column(): 979 | gr.Markdown("### ✏️ Custom Template") 980 | custom_template = gr.Textbox( 981 | label="📝 Custom Template", 982 | lines=10, 983 | placeholder="Enter your custom chat template here..." 984 | ) 985 | save_template_btn = gr.Button("💾 Save Template") 986 | template_list = gr.Dropdown( 987 | choices=load_conversation_templates(), 988 | label="💾 Available Templates", 989 | interactive=True 990 | ) 991 | refresh_templates_btn = gr.Button("🔄 Refresh Templates") 992 | 993 | with gr.Row(): 994 | template_status = gr.Textbox( 995 | label="📊 Template Status", 996 | interactive=False 997 | ) 998 | 999 | # Event handlers for Template Management 1000 | auto_template_btn.click( 1001 | fn=fetch_template_from_model_wrapper, 1002 | inputs=model_dropdown, 1003 | outputs=template_preview 1004 | ) 1005 | 1006 | save_template_btn.click( 1007 | fn=save_custom_template_wrapper, 1008 | inputs=custom_template, 1009 | outputs=template_status 1010 | ) 1011 | 1012 | def refresh_templates(): 1013 | return gr.Dropdown(choices=load_conversation_templates()) 1014 | 1015 | refresh_templates_btn.click( 1016 | fn=refresh_templates, 1017 | outputs=template_list 1018 | ) 1019 | 1020 | def load_selected_template(template_name): 1021 | if not template_name: 1022 | return "" 1023 | try: 1024 | templates_dir = "templates" 1025 | template_path = os.path.join(templates_dir, f"{template_name}.json") 1026 | with open(template_path, "r", encoding="utf-8") as f: 1027 | template_data = json.load(f) 1028 | # Convert to string format 1029 | template_str = json.dumps(template_data, ensure_ascii=False, indent=2) 1030 | return template_str 1031 | except Exception as e: 1032 | return f"Error loading template: {str(e)}" 1033 | 1034 | template_list.change( 1035 | fn=load_selected_template, 1036 | inputs=template_list, 1037 | outputs=custom_template 1038 | ) 1039 | 1040 | # Settings Tab 1041 | with gr.TabItem("⚙️ Settings"): 1042 | gr.Markdown("# Settings") 1043 | 1044 | with gr.Tabs(): 1045 | with gr.TabItem("🔑 API Keys"): 1046 | gr.Markdown("## API Key Configuration") 1047 | 1048 | with gr.Row(): 1049 | with gr.Column(): 1050 | openrouter_key_input = gr.Textbox( 1051 | label="OpenRouter API Key", 1052 | value=openrouter_key or "", # Use loaded key or empty string 1053 | placeholder="Enter your OpenRouter API key", 1054 | type="password" 1055 | ) 1056 | with gr.Row(): 1057 | save_keys_btn = gr.Button("💾 Save API Keys", variant="primary") 1058 | validate_key_btn = gr.Button("🔍 Validate Key", variant="secondary") 1059 | 1060 | with gr.Column(): 1061 | gr.Markdown(""" 1062 | ### Where to get API keys: 1063 | - **OpenRouter**: [https://openrouter.ai/](https://openrouter.ai/) 1064 | - Required for dataset generation using large models 1065 | """) 1066 | key_status = gr.Markdown(f"**Key Status:** {'✅ Key saved' if openrouter_key else '❌ No key 
saved'}") 1067 | 1068 | keys_status = gr.Textbox( 1069 | label="Status", 1070 | interactive=False 1071 | ) 1072 | 1073 | with gr.TabItem("🖥️ System Info"): 1074 | gr.Markdown("## System Information") 1075 | system_info = gr.Markdown() 1076 | refresh_system_btn = gr.Button("🔄 Refresh System Info") 1077 | 1078 | with gr.TabItem("❓ Help"): 1079 | gr.Markdown("## Help & Documentation") 1080 | help_info = gr.Markdown() 1081 | refresh_help_btn = gr.Button("🔄 Refresh Help") 1082 | 1083 | # Event handlers for Settings 1084 | def save_and_update_status(key_input): 1085 | # Save the API key 1086 | save_result = save_api_keys(key_input) 1087 | # Update the key status display 1088 | status_text = f"**Key Status:** {'✅ Key saved' if openrouter_key else '❌ No key saved'}" 1089 | return save_result, gr.Markdown(status_text) 1090 | 1091 | def refresh_model_selector(): 1092 | """Refresh the model selector with updated API key""" 1093 | models = fetch_openrouter_models() or ["⚠️ OpenRouter not connected - Please set API key in Settings"] 1094 | first_model = models[0] if models else "⚠️ OpenRouter not connected - Please set API key in Settings" 1095 | return gr.Dropdown(choices=models, value=first_model) 1096 | 1097 | save_keys_btn.click( 1098 | fn=save_and_update_status, 1099 | inputs=openrouter_key_input, 1100 | outputs=[keys_status, key_status] 1101 | ).then( 1102 | fn=refresh_model_selector, 1103 | inputs=None, 1104 | outputs=model_selector 1105 | ) 1106 | 1107 | validate_key_btn.click( 1108 | fn=validate_openrouter_key, 1109 | inputs=openrouter_key_input, 1110 | outputs=keys_status 1111 | ) 1112 | 1113 | refresh_system_btn.click( 1114 | fn=get_system_info, 1115 | outputs=system_info 1116 | ) 1117 | 1118 | refresh_help_btn.click( 1119 | fn=get_help_info, 1120 | outputs=help_info 1121 | ) 1122 | 1123 | # Initialize system info and help on load 1124 | demo.load(get_system_info, None, system_info) 1125 | demo.load(get_help_info, None, help_info) 1126 | 1127 | # Launch the app 1128 | if __name__ == "__main__": 1129 | demo.launch() --------------------------------------------------------------------------------