├── app.py ├── .gitignore ├── ai_gradio ├── __init__.py └── providers │ ├── kokoro_gradio.py │ ├── cerebras_gradio.py │ ├── anthropic_gradio.py │ ├── browser_use_gradio.py │ ├── swarms_gradio.py │ ├── langchain_gradio.py │ ├── xai_gradio.py │ ├── cohere_gradio.py │ ├── together_gradio.py │ ├── sambanova_gradio.py │ ├── smolagents_gradio.py │ ├── transformers_gradio.py │ ├── lumaai_gradio.py │ ├── fireworks_gradio.py │ ├── hyperbolic_gradio.py │ ├── perplexity_gradio.py │ ├── crewai_gradio.py │ ├── mistral_gradio.py │ ├── qwen_gradio.py │ ├── deepseek_gradio.py │ ├── replicate_gradio.py │ └── __init__.py ├── pyproject.toml └── README.md /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import ai_gradio 3 | 4 | 5 | gr.load( 6 | name='gemini:gemini-2.0-flash-lite-preview-02-05', 7 | src=ai_gradio.registry, 8 | coder=True 9 | ).launch() 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__/ 3 | *.pyc 4 | 5 | # Virtual environment 6 | env/ 7 | .venv/ 8 | venv/** 9 | venv/ 10 | 11 | # Package artifacts 12 | dist/ 13 | build/ 14 | *.egg-info/ 15 | .codegpt 16 | .vscode 17 | -------------------------------------------------------------------------------- /ai_gradio/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | try: 4 | __version__ = version("ai-gradio") 5 | except Exception: 6 | __version__ = "unknown" 7 | 8 | from .providers import registry 9 | 10 | __all__ = ["registry"] 11 | -------------------------------------------------------------------------------- /ai_gradio/providers/kokoro_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import soundfile as sf 3 | from kokoro_onnx import Kokoro 4 | import gradio as gr 5 | from typing import Callable 6 | from huggingface_hub import hf_hub_download 7 | 8 | __version__ = "0.0.1" 9 | 10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable): 11 | # Download model and voices.json from HuggingFace Hub 12 | model_path = hf_hub_download( 13 | repo_id="hexgrad/Kokoro-82M", 14 | filename="kokoro-v0_19.onnx", 15 | repo_type="model" 16 | ) 17 | 18 | voices_path = hf_hub_download( 19 | repo_id="akhaliq/Kokoro-82M", 20 | filename="voices.json", 21 | repo_type="model" 22 | ) 23 | 24 | def chat_response(message, history, voice="af_sarah", speed=1.0, lang="en-us"): 25 | try: 26 | kokoro = Kokoro(model_path, voices_path) 27 | samples, sample_rate = kokoro.create( 28 | text=message, 29 | voice=voice, 30 | speed=speed, 31 | lang=lang 32 | ) 33 | 34 | # Save to temporary file with unique name based on history length 35 | output_path = f"response_{len(history)}.wav" 36 | sf.write(output_path, samples, sample_rate) 37 | 38 | return { 39 | "role": "assistant", 40 | "content": { 41 | "path": output_path 42 | } 43 | } 44 | 45 | except Exception as e: 46 | return f"Error generating audio: {str(e)}" 47 | 48 | return chat_response 49 | 50 | def registry(name: str, **kwargs): 51 | """Register kokoro TTS interface""" 52 | 53 | interface = gr.ChatInterface( 54 | fn=get_fn(name, lambda x: x, lambda x: x), 55 | additional_inputs=[ 56 | gr.Dropdown( 57 | choices=["af_sarah", "en_jenny", "en_ryan"], 58 | value="af_sarah", 59 | label="Voice" 60 | ), 61 | gr.Slider( 62 | 
minimum=0.5, 63 | maximum=2.0, 64 | value=1.0, 65 | step=0.1, 66 | label="Speed" 67 | ), 68 | gr.Dropdown( 69 | choices=["en-us", "en-gb"], 70 | value="en-us", 71 | label="Language" 72 | ) 73 | ], 74 | title="Kokoro Text-to-Speech", 75 | description="Generate speech from text using Kokoro TTS", 76 | **kwargs 77 | ) 78 | 79 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/cerebras_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | from typing import Callable 4 | from cerebras.cloud.sdk import Cerebras 5 | 6 | __version__ = "0.0.1" 7 | 8 | 9 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 10 | def fn(message, history): 11 | inputs = preprocess(message, history) 12 | client = Cerebras( 13 | api_key=api_key, 14 | ) 15 | completion = client.chat.completions.create( 16 | model=model_name, 17 | messages=inputs["messages"], 18 | stream=True, 19 | max_completion_tokens=1024, 20 | temperature=0.2, 21 | top_p=1 22 | ) 23 | 24 | # Streaming response to Gradio ChatInterface UI 25 | response_text = "" 26 | for chunk in completion: 27 | delta = chunk.choices[0].delta.content or "" 28 | response_text += delta 29 | yield postprocess(response_text) 30 | 31 | return fn 32 | 33 | 34 | def get_interface_args(pipeline): 35 | if pipeline == "chat": 36 | inputs = None 37 | outputs = None 38 | 39 | def preprocess(message, history): 40 | messages = [] 41 | for user_msg, assistant_msg in history: 42 | messages.append({"role": "user", "content": user_msg}) 43 | messages.append({"role": "assistant", "content": assistant_msg}) 44 | messages.append({"role": "user", "content": message}) 45 | return {"messages": messages} 46 | 47 | postprocess = lambda x: x 48 | else: 49 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 50 | return inputs, outputs, preprocess, postprocess 51 | 52 | 53 | def get_pipeline(model_name): 54 | # Determine the pipeline type based on the model name 55 | # For simplicity, assuming all models are chat models at the moment 56 | return "chat" 57 | 58 | 59 | def registry(name: str, token: str | None = None, **kwargs): 60 | """ 61 | Create a Gradio Interface for a model on Cerebras. 62 | 63 | Parameters: 64 | - name (str): The name of the model on Cerebras. 65 | - token (str, optional): The API key for Cerebras. 
66 | """ 67 | api_key = token or os.environ.get("CEREBRAS_API_KEY") 68 | if not api_key: 69 | raise ValueError("CEREBRAS_API_KEY environment variable is not set.") 70 | 71 | pipeline = get_pipeline(name) 72 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 73 | fn = get_fn(name, preprocess, postprocess, api_key) 74 | 75 | if pipeline == "chat": 76 | interface = gr.ChatInterface(fn=fn, **kwargs) 77 | else: 78 | # For other pipelines, create a standard Interface (not implemented yet) 79 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 80 | 81 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/anthropic_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import anthropic 3 | import gradio as gr 4 | from typing import Callable 5 | 6 | __version__ = "0.0.1" 7 | 8 | 9 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 10 | def fn(message, history): 11 | inputs = preprocess(message, history) 12 | client = anthropic.Anthropic(api_key=api_key) 13 | with client.messages.stream( 14 | model=model_name, 15 | max_tokens=1000, 16 | messages=inputs["messages"] 17 | ) as stream: 18 | response_text = "" 19 | for chunk in stream: 20 | if chunk.type == "content_block_delta": 21 | delta = chunk.delta.text 22 | response_text += delta 23 | yield postprocess(response_text) 24 | 25 | return fn 26 | 27 | 28 | def get_interface_args(pipeline): 29 | if pipeline == "chat": 30 | inputs = None 31 | outputs = None 32 | 33 | def preprocess(message, history): 34 | messages = [] 35 | for user_msg, assistant_msg in history: 36 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]}) 37 | messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]}) 38 | messages.append({"role": "user", "content": [{"type": "text", "text": message}]}) 39 | return {"messages": messages} 40 | 41 | postprocess = lambda x: x # No post-processing needed 42 | else: 43 | # Add other pipeline types when they will be needed 44 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 45 | return inputs, outputs, preprocess, postprocess 46 | 47 | 48 | def get_pipeline(model_name): 49 | # Determine the pipeline type based on the model name 50 | # For simplicity, assuming all models are chat models at the moment 51 | return "chat" 52 | 53 | 54 | def registry(name: str, token: str | None = None, **kwargs): 55 | """ 56 | Create a Gradio Interface for a model on Anthropic. 57 | 58 | Parameters: 59 | - name (str): The name of the Anthropic model. 60 | - token (str, optional): The API key for Anthropic. 
61 | """ 62 | api_key = token or os.environ.get("ANTHROPIC_API_KEY") 63 | if not api_key: 64 | raise ValueError("ANTHROPIC_API_KEY environment variable is not set.") 65 | 66 | pipeline = get_pipeline(name) 67 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 68 | fn = get_fn(name, preprocess, postprocess, api_key) 69 | 70 | if pipeline == "chat": 71 | interface = gr.ChatInterface(fn=fn, **kwargs) 72 | else: 73 | # For other pipelines, create a standard Interface (not implemented yet) 74 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 75 | 76 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/browser_use_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain_openai import ChatOpenAI 3 | from browser_use import Agent 4 | import gradio as gr 5 | from typing import Callable 6 | import asyncio 7 | 8 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 9 | async def fn(message, history): 10 | inputs = preprocess(message, history) 11 | 12 | agent = Agent( 13 | task=inputs["message"], 14 | llm=ChatOpenAI( 15 | api_key=api_key, 16 | model=model_name, 17 | disabled_params={"parallel_tool_calls": None} 18 | ), 19 | use_vision=(model_name != "o3-mini-2025-01-31") # Only disable vision for o3-mini 20 | ) 21 | 22 | try: 23 | result = await agent.run() 24 | return postprocess(result) 25 | except Exception as e: 26 | return f"Error: {str(e)}" 27 | 28 | return fn # Remove sync_wrapper, return async function directly 29 | 30 | def get_interface_args(pipeline): 31 | if pipeline == "browser": 32 | def preprocess(message, history): 33 | return {"message": message} 34 | 35 | def postprocess(result): 36 | if hasattr(result, 'all_results') and hasattr(result, 'all_model_outputs'): 37 | # Get the thought process from non-final results 38 | thoughts = [r.extracted_content for r in result.all_results if not r.is_done] 39 | # Get the final answer from the last result 40 | final_answer = next((r.extracted_content for r in reversed(result.all_results) if r.is_done), None) 41 | 42 | if final_answer: 43 | # Return in Gradio's message format 44 | return { 45 | "role": "assistant", 46 | "content": final_answer, 47 | "metadata": { 48 | "title": "🔍 " + " → ".join(thoughts) 49 | } 50 | } 51 | 52 | # Fallback to simple message format 53 | return { 54 | "role": "assistant", 55 | "content": str(result) if not isinstance(result, str) else result 56 | } 57 | 58 | return preprocess, postprocess 59 | else: 60 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 61 | 62 | def registry(name: str, token: str | None = None, **kwargs): 63 | api_key = token or os.environ.get("OPENAI_API_KEY") 64 | if not api_key: 65 | raise ValueError("OPENAI_API_KEY environment variable is not set.") 66 | 67 | pipeline = "browser" 68 | preprocess, postprocess = get_interface_args(pipeline) 69 | fn = get_fn(name, preprocess, postprocess, api_key) 70 | 71 | # Remove title and description from kwargs if they exist 72 | kwargs.pop('title', None) 73 | kwargs.pop('description', None) 74 | 75 | interface = gr.ChatInterface( 76 | fn=fn, 77 | title="Browser Use Agent", 78 | description="Chat with an AI agent that can perform browser tasks.", 79 | examples=["Go to amazon.com and find the best rated laptop and return the price.", 80 | "Find the current weather in New York"], 81 | **kwargs 82 | ) 83 | 84 | return interface 85 | 
-------------------------------------------------------------------------------- /ai_gradio/providers/swarms_gradio.py: -------------------------------------------------------------------------------- 1 | from swarms import Agent 2 | from swarm_models import OpenAIChat 3 | import gradio as gr 4 | from typing import Generator, List, Dict, Callable 5 | import base64 6 | import os 7 | 8 | __version__ = "0.0.1" 9 | 10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str, agent_name: str = "Assistant"): 11 | def fn(message, history): 12 | inputs = preprocess(message, history) 13 | agent = create_agent(model_name, agent_name) 14 | 15 | try: 16 | for response in stream_agent_response(agent, inputs["message"]): 17 | yield postprocess(response) 18 | except Exception as e: 19 | yield postprocess({ 20 | "role": "assistant", 21 | "content": f"Error: {str(e)}", 22 | "metadata": {"title": "❌ Error"} 23 | }) 24 | 25 | return fn 26 | 27 | def get_interface_args(pipeline): 28 | if pipeline == "chat": 29 | def preprocess(message, history): 30 | messages = [] 31 | for user_msg, assistant_msg in history: 32 | if assistant_msg is not None: 33 | messages.append({"role": "user", "content": user_msg}) 34 | messages.append({"role": "assistant", "content": assistant_msg}) 35 | messages.append({"role": "user", "content": message}) 36 | return {"message": message, "history": messages} 37 | 38 | def postprocess(response): 39 | return response 40 | 41 | return None, None, preprocess, postprocess 42 | else: 43 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 44 | 45 | def create_agent(model_name: str = None, agent_name: str = "Assistant", system_prompt: str = None): 46 | """Create a Swarms Agent with the specified model and configuration""" 47 | model = OpenAIChat(model_name=model_name) if model_name else OpenAIChat() 48 | 49 | return Agent( 50 | agent_name=agent_name, 51 | system_prompt=system_prompt or "You are a helpful AI assistant.", 52 | llm=model, 53 | max_loops=1, 54 | autosave=False, 55 | dashboard=False, 56 | verbose=True, 57 | dynamic_temperature_enabled=True, 58 | streaming_on=True, 59 | saved_state_path=f"{agent_name.lower()}_state.json", 60 | user_name="gradio_user", 61 | retry_attempts=1, 62 | context_length=200000, 63 | return_step_meta=False, 64 | output_type="string" 65 | ) 66 | 67 | def stream_agent_response(agent: Agent, prompt: str) -> Generator[Dict, None, None]: 68 | # Initial thinking message 69 | yield { 70 | "role": "assistant", 71 | "content": "Let me think about that...", 72 | "metadata": {"title": "🤔 Thinking"} 73 | } 74 | 75 | try: 76 | # Get response from agent 77 | response = agent.run(prompt) 78 | 79 | # Stream final response 80 | yield { 81 | "role": "assistant", 82 | "content": response, 83 | "metadata": {"title": "💬 Response"} 84 | } 85 | 86 | except Exception as e: 87 | yield { 88 | "role": "assistant", 89 | "content": f"Error: {str(e)}", 90 | "metadata": {"title": "❌ Error"} 91 | } 92 | 93 | def registry(name: str, token: str | None = None, agent_name: str = "Assistant", **kwargs): 94 | api_key = token or os.environ.get("OPENAI_API_KEY") 95 | if not api_key: 96 | raise ValueError("API key is not set. 
Please provide a token or set OPENAI_API_KEY environment variable.") 97 | 98 | pipeline = "chat" # Swarms only supports chat for now 99 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 100 | fn = get_fn(name, preprocess, postprocess, api_key, agent_name) 101 | 102 | interface = gr.ChatInterface(fn=fn, **kwargs) 103 | return interface -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "ai-gradio" 7 | version = "0.2.38" 8 | description = "A Python package for creating Gradio applications with AI models" 9 | authors = [ 10 | { name = "AK", email = "ahsen.khaliq@gmail.com" } 11 | ] 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies = [ 20 | "torch>=2.0.0", 21 | "numpy", 22 | "accelerate>=0.27.0", 23 | "bitsandbytes>=0.41.0", 24 | "gradio>=5.9.1", 25 | "gradio-webrtc", 26 | "websockets", 27 | "twilio", 28 | "Pillow", 29 | "opencv-python", 30 | "librosa", 31 | "pydub", 32 | "gradio_webrtc[vad]", 33 | "numba==0.60.0", 34 | "python-dotenv", 35 | "modelscope-studio", 36 | ] 37 | 38 | [project.urls] 39 | homepage = "https://github.com/AK391/ai-gradio" 40 | repository = "https://github.com/AK391/ai-gradio" 41 | 42 | [project.optional-dependencies] 43 | dev = [ 44 | "pytest", 45 | "black", 46 | "isort", 47 | "flake8" 48 | ] 49 | transformers = ["transformers>=4.37.0", "torch>=2.0.0", "accelerate>=0.27.0", "bitsandbytes>=0.41.0", "einops>=0.8.0", "Pillow>=10.4.0", "pyvips-binary>=8.16.0", "pyvips>=2.2.3", "torchvision>=0.18.1"] 50 | openai = ["openai>=1.58.1"] 51 | gemini = ["google-generativeai>=0.8.3", "google-genai==0.3.0"] 52 | crewai = [ 53 | "crewai>=0.1.0", 54 | "langchain>=0.1.0", 55 | "langchain-openai>=0.0.2", 56 | "crewai-tools>=0.0.1" 57 | ] 58 | anthropic = ["anthropic>=1.0.0"] 59 | lumaai = ["lumaai>=0.0.3"] 60 | xai = ["openai>=1.58.1"] 61 | cohere = ["cohere>=5.0.0"] 62 | sambanova = ["openai>=1.58.1"] 63 | hyperbolic = ["openai>=1.58.1"] 64 | qwen = [ 65 | "openai>=1.58.1", 66 | ] 67 | browser = [ 68 | "browser-use>=0.1.16", 69 | "playwright>=1.49.1" 70 | ] 71 | swarms = [ 72 | "swarms>=6.8.9,<7.0.0", 73 | "swarm-models>=0.3.0,<0.4.0", 74 | "swarms-memory>=0.1.2,<0.2.0", 75 | "langchain>=0.1.0", 76 | "langchain-community>=0.0.10", 77 | "pydantic>=2.0.0,<3.0.0" 78 | ] 79 | kokoro = [ 80 | "kokoro-onnx>=0.3.3", 81 | "soundfile>=0.13.0", 82 | "huggingface-hub>=0.27.1" 83 | ] 84 | all = [ 85 | "openai>=1.58.1", 86 | "google-generativeai", 87 | "crewai>=0.1.0", 88 | "langchain>=0.1.0", 89 | "langchain-openai>=0.0.2", 90 | "crewai-tools>=0.0.1", 91 | "anthropic>=1.0.0", 92 | "lumaai>=0.0.3", 93 | "cohere>=5.0.0", 94 | "sambanova>=0.0.1", 95 | "hyperbolic>=1.58.1", 96 | "groq>=0.3.0", 97 | "browser-use>=0.1.0", 98 | "swarms>=6.8.9,<7.0.0", 99 | "swarm-models>=0.3.0,<0.4.0", 100 | "swarms-memory>=0.1.2,<0.2.0", 101 | "langchain>=0.1.0", 102 | "pydantic>=2.0.0,<3.0.0", 103 | "langchain-openai>=0.0.2", 104 | "langchain-community>=0.0.10", 105 | "langchain-core>=0.1.0", 106 | "tavily-python>=0.3.0", 107 | "requests>=2.31.0", 108 | "kokoro-onnx>=0.1.0", 109 | "soundfile>=0.12.0", 110 | "huggingface-hub>=0.20.0" 111 | ] 112 | groq = ["openai>=1.58.1"] 
113 | fireworks = ["openai>=1.58.1"] 114 | together = ["openai>=1.58.1"] 115 | deepseek = ["openai>=1.58.1"] 116 | smolagents = ["smolagents>=0.1.3"] 117 | jupyter = [ 118 | "nbformat>=5.9.0", 119 | "nbconvert>=7.16.0", 120 | "e2b-code-interpreter>=0.5.0", 121 | "huggingface-hub>=0.20.0", 122 | ] 123 | 124 | langchain = [ 125 | "langchain", 126 | "langchain-community", 127 | "langchain-core", 128 | "tavily-python", 129 | "langchain-openai" 130 | ] 131 | 132 | mistral = [ 133 | "mistralai" 134 | ] 135 | 136 | nvidia = [ 137 | "openai>=1.58.1" 138 | ] 139 | 140 | minimax = ["requests>=2.31.0"] 141 | 142 | perplexity = [ 143 | "openai>=1.58.1" 144 | ] 145 | 146 | replicate = [ 147 | "replicate>=1.0.4" 148 | ] 149 | 150 | [tool.hatch.build.targets.wheel] 151 | packages = ["ai_gradio"] 152 | 153 | -------------------------------------------------------------------------------- /ai_gradio/providers/langchain_gradio.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | from langchain.agents import create_tool_calling_agent, AgentExecutor 3 | from langchain_community.tools.tavily_search import TavilySearchResults 4 | from langchain import hub 5 | from langchain_core.messages import AIMessage, HumanMessage 6 | from langchain_community.chat_message_histories import ChatMessageHistory 7 | from langchain_core.runnables.history import RunnableWithMessageHistory 8 | import gradio as gr 9 | from typing import Generator, List, Dict 10 | 11 | def create_agent(model_name: str = None): 12 | # Initialize search tool 13 | search = TavilySearchResults() 14 | tools = [search] 15 | 16 | # Initialize LLM 17 | llm = ChatOpenAI( 18 | model_name=model_name if model_name else "gpt-3.5-turbo-0125", 19 | temperature=0 20 | ) 21 | 22 | # Get the prompt 23 | prompt = hub.pull("hwchase17/openai-functions-agent") 24 | 25 | # Create the agent 26 | agent = create_tool_calling_agent(llm, tools, prompt) 27 | 28 | # Create the executor 29 | return AgentExecutor(agent=agent, tools=tools, verbose=True) 30 | 31 | def stream_agent_response(agent: AgentExecutor, message: str, history: List) -> Generator[Dict, None, None]: 32 | # First yield the thinking message 33 | yield { 34 | "role": "assistant", 35 | "content": "Let me think about that...", 36 | "metadata": {"title": "🤔 Thinking"} 37 | } 38 | 39 | try: 40 | # Convert history to LangChain format 41 | chat_history = [] 42 | for msg in history: 43 | if msg["role"] == "user": 44 | chat_history.append(HumanMessage(content=msg["content"])) 45 | elif msg["role"] == "assistant": 46 | chat_history.append(AIMessage(content=msg["content"])) 47 | 48 | # Run the agent 49 | response = agent.invoke({ 50 | "input": message, 51 | "chat_history": chat_history 52 | }) 53 | 54 | # Yield the final response 55 | yield { 56 | "role": "assistant", 57 | "content": response["output"] 58 | } 59 | 60 | except Exception as e: 61 | yield { 62 | "role": "assistant", 63 | "content": f"Error: {str(e)}", 64 | "metadata": {"title": "❌ Error"} 65 | } 66 | 67 | async def interact_with_agent(message: str, history: List, model_name: str = None) -> Generator[List, None, None]: 68 | # Add user message 69 | history.append({"role": "user", "content": message}) 70 | yield history 71 | 72 | # Create agent instance with specified model 73 | agent = create_agent(model_name) 74 | 75 | # Stream agent responses 76 | for response in stream_agent_response(agent, message, history): 77 | history.append(response) 78 | yield history 79 | 80 | def 
registry(name: str, **kwargs): 81 | # Extract model name from the name parameter 82 | model_name = name.split(':')[-1] if ':' in name else None 83 | 84 | with gr.Blocks() as demo: 85 | gr.Markdown("# LangChain Assistant 🦜️") 86 | 87 | chatbot = gr.Chatbot( 88 | type="messages", 89 | label="Agent", 90 | avatar_images=(None, "https://python.langchain.com/img/favicon.ico"), 91 | height=500 92 | ) 93 | 94 | msg = gr.Textbox( 95 | label="Your message", 96 | placeholder="Type your message here...", 97 | lines=1 98 | ) 99 | 100 | async def handle_message(message, history): 101 | async for response in interact_with_agent(message, history, model_name=model_name): 102 | yield response 103 | 104 | msg.submit( 105 | fn=handle_message, 106 | inputs=[msg, chatbot], 107 | outputs=[chatbot], 108 | api_name="predict" 109 | ).then(lambda _:"", msg, msg) 110 | 111 | return demo 112 | -------------------------------------------------------------------------------- /ai_gradio/providers/xai_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from openai import OpenAI 4 | import gradio as gr 5 | from typing import Callable 6 | 7 | __version__ = "0.0.1" 8 | 9 | def get_image_base64(url: str, ext: str): 10 | with open(url, "rb") as image_file: 11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 12 | return "data:image/" + ext + ";base64," + encoded_string 13 | 14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 15 | def fn(message, history): 16 | inputs = preprocess(message, history) 17 | client = OpenAI( 18 | api_key=api_key, 19 | base_url="https://api.x.ai/v1", 20 | ) 21 | completion = client.chat.completions.create( 22 | model=model_name, 23 | messages=inputs["messages"], 24 | stream=True, 25 | ) 26 | response_text = "" 27 | for chunk in completion: 28 | delta = chunk.choices[0].delta.content or "" 29 | response_text += delta 30 | yield postprocess(response_text) 31 | 32 | return fn 33 | 34 | def handle_user_msg(message: str): 35 | if type(message) is str: 36 | return message 37 | elif type(message) is dict: 38 | if message["files"] is not None and len(message["files"]) > 0: 39 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 40 | if ext.lower() in ["png", "jpg", "jpeg", "gif"]: 41 | encoded_str = get_image_base64(message["files"][-1], ext) 42 | else: 43 | raise NotImplementedError(f"Not supported file type {ext}") 44 | content = [ 45 | {"type": "text", "text": message["text"]}, 46 | { 47 | "type": "image_url", 48 | "image_url": { 49 | "url": encoded_str, 50 | } 51 | }, 52 | ] 53 | else: 54 | content = message["text"] 55 | return content 56 | else: 57 | raise NotImplementedError 58 | 59 | def get_interface_args(model_name: str): 60 | inputs = None 61 | outputs = None 62 | 63 | def preprocess(message, history): 64 | messages = [{"role": "system", "content": "You are Grok, a chatbot inspired by the Hitchhikers Guide to the Galaxy."}] 65 | files = None 66 | for user_msg, assistant_msg in history: 67 | if assistant_msg is not None: 68 | messages.append({"role": "user", "content": handle_user_msg(user_msg)}) 69 | messages.append({"role": "assistant", "content": assistant_msg}) 70 | else: 71 | files = user_msg 72 | 73 | if type(message) is str and files is not None: 74 | message = {"text": message, "files": files} 75 | elif type(message) is dict and files is not None: 76 | if message["files"] is None or len(message["files"]) == 0: 77 | message["files"] = files 78 | 79 
| messages.append({"role": "user", "content": handle_user_msg(message)}) 80 | return {"messages": messages} 81 | 82 | postprocess = lambda x: x # No post-processing needed 83 | return inputs, outputs, preprocess, postprocess 84 | 85 | def registry(name: str = "grok-beta", token: str | None = None, **kwargs): 86 | """ 87 | Create a Gradio Interface for X.AI's Grok model. 88 | 89 | Parameters: 90 | - name (str): The name of the model (defaults to "grok-beta" or "grok-vision-beta") 91 | - token (str, optional): The X.AI API key 92 | """ 93 | api_key = token or os.environ.get("XAI_API_KEY") 94 | if not api_key: 95 | raise ValueError("XAI_API_KEY environment variable is not set.") 96 | 97 | inputs, outputs, preprocess, postprocess = get_interface_args(name) 98 | fn = get_fn(name, preprocess, postprocess, api_key) 99 | 100 | # Always set multimodal=True 101 | kwargs["multimodal"] = True 102 | 103 | interface = gr.ChatInterface(fn=fn, **kwargs) 104 | 105 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/cohere_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cohere 3 | import gradio as gr 4 | from typing import Callable 5 | import base64 6 | 7 | __version__ = "0.0.3" 8 | 9 | 10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 11 | def fn(message, history): 12 | inputs = preprocess(message, history) 13 | client = cohere.Client(api_key=api_key) 14 | stream = client.chat_stream( 15 | model=model_name, 16 | message=inputs["message"], 17 | chat_history=inputs["chat_history"], 18 | ) 19 | response_text = "" 20 | for chunk in stream: 21 | if chunk.event_type == "text-generation": 22 | delta = chunk.text 23 | response_text += delta 24 | yield postprocess(response_text) 25 | 26 | return fn 27 | 28 | 29 | def get_image_base64(url: str, ext: str): 30 | with open(url, "rb") as image_file: 31 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 32 | return "data:image/" + ext + ";base64," + encoded_string 33 | 34 | 35 | def handle_user_msg(message: str): 36 | if type(message) is str: 37 | return message 38 | elif type(message) is dict: 39 | if message["files"] is not None and len(message["files"]) > 0: 40 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 41 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]: 42 | encoded_str = get_image_base64(message["files"][-1], ext) 43 | else: 44 | raise NotImplementedError(f"Not supported file type {ext}") 45 | return { 46 | "message": message["text"], 47 | "attachments": [ 48 | { 49 | "source": encoded_str, 50 | "type": "image", 51 | } 52 | ] 53 | } 54 | else: 55 | return message["text"] 56 | else: 57 | raise NotImplementedError 58 | 59 | 60 | def get_interface_args(pipeline): 61 | if pipeline == "chat": 62 | inputs = None 63 | outputs = None 64 | 65 | def preprocess(message, history): 66 | chat_history = [] 67 | files = None 68 | for user_msg, assistant_msg in history: 69 | if assistant_msg is not None: 70 | chat_history.append({"role": "USER", "message": handle_user_msg(user_msg)}) 71 | chat_history.append({"role": "ASSISTANT", "message": assistant_msg}) 72 | else: 73 | files = user_msg 74 | if type(message) is str and files is not None: 75 | message = {"text": message, "files": files} 76 | elif type(message) is dict and files is not None: 77 | if message["files"] is None or len(message["files"]) == 0: 78 | message["files"] = files 79 | 80 | return { 81 | 
"message": handle_user_msg(message), 82 | "chat_history": chat_history 83 | } 84 | 85 | postprocess = lambda x: x 86 | else: 87 | # Add other pipeline types when they will be needed 88 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 89 | return inputs, outputs, preprocess, postprocess 90 | 91 | 92 | def get_pipeline(model_name): 93 | # Determine the pipeline type based on the model name 94 | # For simplicity, assuming all models are chat models at the moment 95 | return "chat" 96 | 97 | 98 | def registry(name: str, token: str | None = None, **kwargs): 99 | """ 100 | Create a Gradio Interface for a model on Cohere. 101 | 102 | Parameters: 103 | - name (str): The name of the Cohere model. 104 | - token (str, optional): The API key for Cohere. 105 | """ 106 | api_key = token or os.environ.get("COHERE_API_KEY") 107 | if not api_key: 108 | raise ValueError("COHERE_API_KEY environment variable is not set.") 109 | 110 | pipeline = get_pipeline(name) 111 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 112 | fn = get_fn(name, preprocess, postprocess, api_key) 113 | 114 | if pipeline == "chat": 115 | interface = gr.ChatInterface(fn=fn, multimodal=True, **kwargs) 116 | else: 117 | # For other pipelines, create a standard Interface (not implemented yet) 118 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 119 | 120 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/together_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from huggingface_hub import InferenceClient 4 | import gradio as gr 5 | from typing import Callable 6 | 7 | __version__ = "0.0.1" 8 | 9 | def get_image_base64(url: str, ext: str): 10 | with open(url, "rb") as image_file: 11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 12 | return "data:image/" + ext + ";base64," + encoded_string 13 | 14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 15 | def fn(message, history): 16 | inputs = preprocess(message, history) 17 | client = InferenceClient( 18 | provider="together", 19 | token=api_key 20 | ) 21 | try: 22 | completion = client.chat.completions.create( 23 | model=model_name, 24 | messages=inputs["messages"], 25 | stream=True, 26 | max_tokens=1000 27 | ) 28 | 29 | partial_message = "" 30 | for chunk in completion: 31 | if chunk.choices: 32 | delta = chunk.choices[0].delta.content or "" 33 | delta = delta.replace("", "[think]").replace("", "[/think]") 34 | partial_message += delta 35 | yield postprocess(partial_message) 36 | 37 | except Exception as e: 38 | error_message = f"Error: {str(e)}" 39 | yield error_message 40 | 41 | return fn 42 | 43 | def handle_user_msg(message: str): 44 | if type(message) is str: 45 | return message 46 | elif type(message) is dict: 47 | if message["files"] is not None and len(message["files"]) > 0: 48 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 49 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]: 50 | encoded_str = get_image_base64(message["files"][-1], ext) 51 | else: 52 | raise NotImplementedError(f"Not supported file type {ext}") 53 | content = [ 54 | {"type": "text", "text": message["text"]}, 55 | { 56 | "type": "image_url", 57 | "image_url": { 58 | "url": encoded_str, 59 | } 60 | }, 61 | ] 62 | else: 63 | content = message["text"] 64 | return content 65 | else: 66 | raise NotImplementedError 67 | 68 | def 
get_interface_args(pipeline): 69 | if pipeline == "chat": 70 | inputs = None 71 | outputs = None 72 | 73 | def preprocess(message, history): 74 | messages = [] 75 | # Process history first 76 | for user_msg, assistant_msg in history: 77 | messages.append({"role": "user", "content": str(user_msg)}) 78 | if assistant_msg is not None: 79 | messages.append({"role": "assistant", "content": str(assistant_msg)}) 80 | 81 | # Add current message 82 | messages.append({"role": "user", "content": str(message)}) 83 | return {"messages": messages} 84 | 85 | postprocess = lambda x: x # No post-processing needed 86 | else: 87 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 88 | return inputs, outputs, preprocess, postprocess 89 | 90 | 91 | def get_pipeline(model_name): 92 | # Determine the pipeline type based on the model name 93 | # For simplicity, assuming all models are chat models at the moment 94 | return "chat" 95 | 96 | 97 | def registry(name: str, token: str | None = None, **kwargs): 98 | """ 99 | Create a Gradio Interface for a model on Together. 100 | 101 | Parameters: 102 | - name (str): The name of the model on Together. 103 | - token (str, optional): The API key for Together. 104 | """ 105 | api_key = token or os.environ.get("HF_TOKEN") 106 | if not api_key: 107 | raise ValueError("HF_TOKEN environment variable is not set.") 108 | 109 | pipeline = get_pipeline(name) 110 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 111 | fn = get_fn(name, preprocess, postprocess, api_key) 112 | 113 | if pipeline == "chat": 114 | interface = gr.ChatInterface(fn=fn, **kwargs) 115 | else: 116 | # For other pipelines, create a standard Interface (not implemented yet) 117 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 118 | 119 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/sambanova_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from openai import OpenAI 4 | import gradio as gr 5 | from typing import Callable 6 | 7 | __version__ = "0.0.1" 8 | 9 | def get_image_base64(url: str, ext: str): 10 | with open(url, "rb") as image_file: 11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 12 | return "data:image/" + ext + ";base64," + encoded_string 13 | 14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 15 | def fn(message, history): 16 | inputs = preprocess(message, history) 17 | client = OpenAI( 18 | base_url="https://api.sambanova.ai/v1/", 19 | api_key=api_key, 20 | ) 21 | try: 22 | completion = client.chat.completions.create( 23 | model=model_name, 24 | messages=inputs["messages"], 25 | stream=True, 26 | ) 27 | response_text = "" 28 | for chunk in completion: 29 | delta = chunk.choices[0].delta.content or "" 30 | response_text += delta 31 | yield postprocess(response_text) 32 | except Exception as e: 33 | error_message = f"Error: {str(e)}" 34 | return error_message 35 | 36 | return fn 37 | 38 | def handle_user_msg(message: str): 39 | if type(message) is str: 40 | return message 41 | elif type(message) is dict: 42 | if message["files"] is not None and len(message["files"]) > 0: 43 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 44 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]: 45 | encoded_str = get_image_base64(message["files"][-1], ext) 46 | else: 47 | raise NotImplementedError(f"Not supported file type 
{ext}") 48 | content = [ 49 | {"type": "text", "text": message["text"]}, 50 | { 51 | "type": "image_url", 52 | "image_url": { 53 | "url": encoded_str, 54 | } 55 | }, 56 | ] 57 | else: 58 | content = message["text"] 59 | return content 60 | else: 61 | raise NotImplementedError 62 | 63 | def get_interface_args(pipeline): 64 | if pipeline == "chat": 65 | inputs = None 66 | outputs = None 67 | 68 | def preprocess(message, history): 69 | messages = [] 70 | files = None 71 | for user_msg, assistant_msg in history: 72 | if assistant_msg is not None: 73 | messages.append({"role": "user", "content": handle_user_msg(user_msg)}) 74 | messages.append({"role": "assistant", "content": assistant_msg}) 75 | else: 76 | files = user_msg 77 | if type(message) is str and files is not None: 78 | message = {"text":message, "files":files} 79 | elif type(message) is dict and files is not None: 80 | if message["files"] is None or len(message["files"]) == 0: 81 | message["files"] = files 82 | messages.append({"role": "user", "content": handle_user_msg(message)}) 83 | return {"messages": messages} 84 | 85 | postprocess = lambda x: x # No post-processing needed 86 | else: 87 | # Add other pipeline types when they will be needed 88 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 89 | return inputs, outputs, preprocess, postprocess 90 | 91 | 92 | def get_pipeline(model_name): 93 | # Determine the pipeline type based on the model name 94 | # For simplicity, assuming all models are chat models at the moment 95 | return "chat" 96 | 97 | 98 | def registry(name: str, token: str | None = None, **kwargs): 99 | """ 100 | Create a Gradio Interface for a model on Sambanova. 101 | 102 | Parameters: 103 | - name (str): The name of the model on Sambanova. 104 | - token (str, optional): The API key for Sambanova. 
105 | """ 106 | api_key = token or os.environ.get("SAMBANOVA_API_KEY") 107 | if not api_key: 108 | raise ValueError("SAMBANOVA_API_KEY environment variable is not set.") 109 | 110 | pipeline = get_pipeline(name) 111 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 112 | fn = get_fn(name, preprocess, postprocess, api_key) 113 | 114 | if pipeline == "chat": 115 | interface = gr.ChatInterface(fn=fn, **kwargs) 116 | else: 117 | # For other pipelines, create a standard Interface (not implemented yet) 118 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 119 | 120 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/smolagents_gradio.py: -------------------------------------------------------------------------------- 1 | from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel 2 | from smolagents.agents import ActionStep 3 | import gradio as gr 4 | from typing import Generator, List, Dict 5 | 6 | def create_agent(model_name: str = None): 7 | model = HfApiModel(model_name) if model_name else HfApiModel() 8 | return CodeAgent(tools=[DuckDuckGoSearchTool()], model=model) 9 | 10 | def stream_agent_response(agent: CodeAgent, prompt: str) -> Generator[Dict, None, None]: 11 | # First yield the thinking message 12 | yield { 13 | "role": "assistant", 14 | "content": "Let me think about that...", 15 | "metadata": {"title": "🤔 Thinking"} 16 | } 17 | 18 | # Run the agent and capture its response 19 | try: 20 | # Get the agent's response 21 | for step in agent.run(prompt, stream=True): 22 | if isinstance(step, ActionStep): 23 | # Show LLM output if present (as collapsible thought) 24 | if step.llm_output: 25 | yield { 26 | "role": "assistant", 27 | "content": step.llm_output, 28 | "metadata": {"title": "🧠 Thought Process"} 29 | } 30 | 31 | # Show tool call if present 32 | if step.tool_call: 33 | content = step.tool_call.arguments 34 | if step.tool_call.name == "python_interpreter": 35 | content = f"```python\n{content}\n```" 36 | yield { 37 | "role": "assistant", 38 | "content": str(content), 39 | "metadata": {"title": f"🛠️ Using {step.tool_call.name}"} 40 | } 41 | 42 | # Show observations if present 43 | if step.observations: 44 | yield { 45 | "role": "assistant", 46 | "content": f"```\n{step.observations}\n```", 47 | "metadata": {"title": "👁️ Observations"} 48 | } 49 | 50 | # Show errors if present 51 | if step.error: 52 | yield { 53 | "role": "assistant", 54 | "content": str(step.error), 55 | "metadata": {"title": "❌ Error"} 56 | } 57 | 58 | # Show final output if present (without metadata to keep it expanded) 59 | if step.action_output is not None and not step.error: 60 | # Only show the final output if it's actually the last step 61 | if step == step.action_output: 62 | yield { 63 | "role": "assistant", 64 | "content": str(step.action_output) 65 | } 66 | else: 67 | # For any other type of step output 68 | yield { 69 | "role": "assistant", 70 | "content": str(step), 71 | "metadata": {"title": "🔄 Processing"} 72 | } 73 | 74 | except Exception as e: 75 | yield { 76 | "role": "assistant", 77 | "content": f"Error: {str(e)}", 78 | "metadata": {"title": "❌ Error"} 79 | } 80 | 81 | async def interact_with_agent(message: str, history: List, model_name: str = None) -> Generator[List, None, None]: 82 | # Add user message 83 | history.append({"role": "user", "content": message}) 84 | yield history 85 | 86 | # Create agent instance with specified model 87 | agent = create_agent(model_name) 
88 | 89 | # Stream agent responses 90 | for response in stream_agent_response(agent, message): 91 | history.append(response) 92 | yield history 93 | 94 | def registry(name: str, **kwargs): 95 | # Extract model name from the name parameter 96 | model_name = name.split(':')[-1] if ':' in name else None 97 | 98 | with gr.Blocks() as demo: 99 | gr.Markdown("# SmolagentsAI Assistant 🤖") 100 | 101 | chatbot = gr.Chatbot( 102 | type="messages", 103 | label="Agent", 104 | avatar_images=(None, "https://cdn-lfs.hf.co/repos/96/a2/96a2c8468c1546e660ac2609e49404b8588fcf5a748761fa72c154b2836b4c83/9cf16f4f32604eaf76dabbdf47701eea5a768ebcc7296acc1d1758181f71db73?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27hf-logo.png%3B+filename%3D%22hf-logo.png%22%3B&response-content-type=image%2Fpng&Expires=1735927745&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczNTkyNzc0NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy85Ni9hMi85NmEyYzg0NjhjMTU0NmU2NjBhYzI2MDllNDk0MDRiODU4OGZjZjVhNzQ4NzYxZmE3MmMxNTRiMjgzNmI0YzgzLzljZjE2ZjRmMzI2MDRlYWY3NmRhYmJkZjQ3NzAxZWVhNWE3NjhlYmNjNzI5NmFjYzFkMTc1ODE4MWY3MWRiNzM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=a%7EGHhtB%7EWzd7pjl5F5wJHVxQPymZEQcgcAoc5FkJggZcWPDrcZ82CtDpIQafi0CultDPMy8SiZgHA0PDDk31c5a6wzdIbsbm7zZ5NvGTTlZpXskL3x7Gbr-f2E3yOA%7EHR%7E2heEJlpim78-xLqkWA92CYo-tKLg-yHKMx0acQcBvhptHOZtwlb9%7EyHlqlzNpcLo4iqEgEH39ADRNhpkf54-Zj6SQNBod7AkjFA3-iIzX5LVzW6EEYyFs03Ba0AfBUODgZIt8cjglULQ2a02rgiM%7EjKMBmB2eKNDFtvoe7YSGlFbVcLt21pWjhzA-z9MgQsw-U3ZDY539iHkMMkfoQzA__&Key-Pair-Id=K3RPWS32NSSJCE"), 105 | height=500 106 | ) 107 | 108 | msg = gr.Textbox( 109 | label="Your message", 110 | placeholder="Type your message here...", 111 | lines=1 112 | ) 113 | 114 | # Make the wrapper function async and await the generator 115 | async def handle_message(message, history): 116 | async for response in interact_with_agent(message, history, model_name=model_name): 117 | yield response 118 | 119 | msg.submit( 120 | fn=handle_message, 121 | inputs=[msg, chatbot], 122 | outputs=[chatbot], 123 | api_name="predict" 124 | ) 125 | 126 | return demo -------------------------------------------------------------------------------- /ai_gradio/providers/transformers_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 3 | import torch 4 | import gradio as gr 5 | from typing import Callable 6 | from PIL import Image 7 | 8 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, **kwargs): 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | 11 | # Get the full model path 12 | if "/" in model_name: 13 | model_path = model_name 14 | else: 15 | model_mapping = { 16 | "tulu-3": "allenai/llama-tulu-3-8b", 17 | "olmo-2-13b": "allenai/OLMo-2-1124-13B-Instruct", 18 | "smolvlm": "HuggingFaceTB/SmolVLM-Instruct", 19 | "phi-4": "microsoft/phi-4", 20 | "moondream": "vikhyatk/moondream2", 21 | } 22 | model_path = model_mapping.get(model_name) 23 | if not model_path: 24 | raise ValueError(f"Unknown model name: {model_name}") 25 | 26 | # Load model and tokenizer 27 | tokenizer = AutoTokenizer.from_pretrained(model_path) 28 | 29 | if model_name == "moondream": 30 | model = AutoModelForCausalLM.from_pretrained( 31 | model_path, 32 | revision="2025-01-09", 33 | trust_remote_code=True, 34 | device_map={"": "cuda" if torch.cuda.is_available() else "cpu"} 35 
| ) 36 | elif device == "cuda": 37 | model = AutoModelForCausalLM.from_pretrained( 38 | model_path, 39 | device_map="auto", 40 | torch_dtype=torch.float16 41 | ) 42 | else: 43 | model = AutoModelForCausalLM.from_pretrained( 44 | model_path, 45 | device_map="auto", 46 | torch_dtype=torch.float32 47 | ) 48 | 49 | def predict(message, history, temperature=0.7, max_tokens=512, image=None): 50 | # Create a new list for history to avoid sharing between sessions 51 | history = list(history) if history else [] 52 | 53 | if model_name == "moondream": 54 | if isinstance(message, dict): 55 | text = message["text"] 56 | # Get the first image from files if available 57 | files = message.get("files", []) 58 | image = files[0] if files else image # Use provided image if no files 59 | 60 | if image is not None: 61 | # Ensure image is a PIL Image 62 | if not isinstance(image, Image.Image): 63 | try: 64 | image = Image.open(image) 65 | except Exception as e: 66 | return f"Error processing image: {str(e)}" 67 | 68 | # Generate response and return as a string 69 | response = model.query(image, text)["answer"] 70 | return response 71 | else: 72 | return "Please provide an image to analyze." 73 | 74 | # Format conversation history 75 | if isinstance(message, dict): 76 | message = message["text"] 77 | 78 | messages = [] 79 | for user_msg, assistant_msg in history: 80 | messages.append({"role": "user", "content": user_msg}) 81 | messages.append({"role": "assistant", "content": assistant_msg}) 82 | messages.append({"role": "user", "content": message}) 83 | 84 | # Convert to model format 85 | input_text = tokenizer.apply_chat_template(messages, tokenize=False) 86 | inputs = tokenizer(input_text, return_tensors="pt").to(device) 87 | 88 | # Generate response 89 | generation_config = { 90 | "input_ids": inputs["input_ids"], 91 | "max_new_tokens": max_tokens, 92 | "temperature": float(temperature), # Ensure temperature is a float 93 | "do_sample": True, 94 | "pad_token_id": tokenizer.eos_token_id 95 | } 96 | 97 | outputs = model.generate(**generation_config) 98 | 99 | # For phi-4, extract only the new generated text 100 | if model_name == "phi-4": 101 | input_length = inputs["input_ids"].shape[1] 102 | generated_tokens = outputs[0][input_length:] 103 | response = tokenizer.decode(generated_tokens, skip_special_tokens=True) 104 | else: 105 | response = tokenizer.decode(outputs[0], skip_special_tokens=True) 106 | 107 | return response 108 | 109 | return predict 110 | 111 | def get_interface_args(pipeline): 112 | if pipeline == "chat": 113 | def preprocess(message, history): 114 | return {"message": message, "history": history} 115 | 116 | def postprocess(response): 117 | return response 118 | 119 | return None, None, preprocess, postprocess 120 | elif pipeline == "vision-chat": 121 | def preprocess(message, history): 122 | return {"message": message, "history": history} 123 | 124 | def postprocess(response): 125 | return response 126 | 127 | return [gr.Textbox(label="Message"), gr.Image(type="pil")], None, preprocess, postprocess 128 | else: 129 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 130 | 131 | def get_pipeline(model_name): 132 | if model_name == "moondream": 133 | return "vision-chat" 134 | return "chat" 135 | 136 | def registry(name: str = None, **kwargs): 137 | pipeline = get_pipeline(name) 138 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 139 | fn = get_fn(name, preprocess, postprocess, **kwargs) 140 | 141 | if pipeline == "vision-chat": 142 | interface = 
gr.ChatInterface( 143 | fn=fn, 144 | additional_inputs=[ 145 | gr.Slider(0, 1, 0.7, label="Temperature"), 146 | gr.Slider(1, 2048, 512, label="Max tokens"), 147 | ], 148 | multimodal=True # Enable multimodal input 149 | ) 150 | else: 151 | interface = gr.ChatInterface( 152 | fn=fn, 153 | additional_inputs=[ 154 | gr.Slider(0, 1, 0.7, label="Temperature"), 155 | gr.Slider(1, 2048, 512, label="Max tokens"), 156 | ] 157 | ) 158 | 159 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/lumaai_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lumaai import LumaAI 3 | import gradio as gr 4 | from typing import Callable 5 | import base64 6 | import time 7 | 8 | __version__ = "0.0.3" 9 | 10 | 11 | def get_fn(preprocess: Callable, postprocess: Callable, api_key: str, pipeline: str): 12 | def fn(message, history, generation_type): 13 | try: 14 | inputs = preprocess(message, history) 15 | # Create a fresh client instance for each generation 16 | client = LumaAI(auth_token=api_key) 17 | 18 | # Validate generation type 19 | if generation_type not in ["video", "image"]: 20 | raise ValueError(f"Invalid generation type: {generation_type}") 21 | 22 | try: 23 | if generation_type == "video": 24 | generation = client.generations.create( 25 | prompt=inputs["prompt"], 26 | **inputs.get("additional_params", {}) 27 | ) 28 | else: # image 29 | generation = client.generations.image.create( 30 | prompt=inputs["prompt"], 31 | **inputs.get("additional_params", {}) 32 | ) 33 | except Exception as e: 34 | raise RuntimeError(f"Failed to create generation: {str(e)}") 35 | 36 | # Poll for completion with timeout 37 | start_time = time.time() 38 | timeout = 300 # 5 minutes timeout 39 | 40 | while True: 41 | if time.time() - start_time > timeout: 42 | raise RuntimeError("Generation timed out after 5 minutes") 43 | 44 | try: 45 | generation = client.generations.get(id=generation.id) 46 | if generation.state == "completed": 47 | asset_url = generation.assets.video if generation_type == "video" else generation.assets.image 48 | break 49 | elif generation.state == "failed": 50 | raise RuntimeError(f"Generation failed: {generation.failure_reason}") 51 | time.sleep(3) 52 | yield f"Generating {generation_type}... 
(Status: {generation.state})" 53 | except Exception as e: 54 | raise RuntimeError(f"Error checking generation status: {str(e)}") 55 | 56 | # Return asset URL wrapped in appropriate format for display 57 | yield postprocess(asset_url, generation_type) 58 | 59 | except Exception as e: 60 | yield f"Error: {str(e)}" 61 | raise 62 | 63 | return fn 64 | 65 | 66 | def get_image_base64(url: str, ext: str): 67 | with open(url, "rb") as image_file: 68 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 69 | return "data:image/" + ext + ";base64," + encoded_string 70 | 71 | 72 | def handle_user_msg(message: str): 73 | if type(message) is str: 74 | return message 75 | elif type(message) is dict: 76 | if message["files"] is not None and len(message["files"]) > 0: 77 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 78 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]: 79 | encoded_str = get_image_base64(message["files"][-1], ext) 80 | else: 81 | raise NotImplementedError(f"Not supported file type {ext}") 82 | content = [ 83 | {"type": "text", "text": message["text"]}, 84 | { 85 | "type": "image_url", 86 | "image_url": { 87 | "url": encoded_str, 88 | } 89 | }, 90 | ] 91 | else: 92 | content = message["text"] 93 | return content 94 | else: 95 | raise NotImplementedError 96 | 97 | 98 | def get_interface_args(pipeline): 99 | generation_type = gr.Dropdown( 100 | choices=["video", "image"], 101 | label="Generation Type", 102 | value="video" if pipeline == "video" else "image" 103 | ) 104 | outputs = None 105 | 106 | def preprocess(message, history): 107 | if isinstance(message, str): 108 | return {"prompt": message} 109 | elif isinstance(message, dict): 110 | prompt = message.get("text", "") 111 | additional_params = {} 112 | 113 | # Handle optional parameters 114 | if message.get("aspect_ratio"): 115 | additional_params["aspect_ratio"] = message["aspect_ratio"] 116 | if message.get("model"): 117 | additional_params["model"] = message["model"] 118 | 119 | return { 120 | "prompt": prompt, 121 | "additional_params": additional_params 122 | } 123 | 124 | def postprocess(url, generation_type): 125 | if generation_type == "video": 126 | return f'' 127 | else: 128 | return f"![Generated Image]({url})" 129 | 130 | return generation_type, outputs, preprocess, postprocess 131 | 132 | 133 | def get_pipeline(model_name): 134 | # Support both video and image pipelines 135 | return "video" if model_name == "dream-machine" else "image" 136 | 137 | 138 | def registry(name: str = "dream-machine", token: str = None, **kwargs): 139 | """ 140 | Create a Gradio Interface for LumaAI generation. 
141 | 142 | Parameters: 143 | - name (str): Model name (defaults to 'dream-machine' for video, use 'photon-1' or 'photon-flash-1' for images) 144 | - token (str, optional): The API key for LumaAI 145 | - **kwargs: Additional keyword arguments passed to gr.Interface 146 | """ 147 | api_key = token or kwargs.pop('api_key', None) or os.environ.get("LUMAAI_API_KEY") 148 | if not api_key: 149 | raise ValueError("API key must be provided either through token parameter, kwargs, or LUMAAI_API_KEY environment variable.") 150 | 151 | pipeline = get_pipeline(name) 152 | generation_type, outputs, preprocess, postprocess = get_interface_args(pipeline) 153 | fn = get_fn(preprocess, postprocess, api_key, pipeline) 154 | 155 | interface = gr.ChatInterface( 156 | fn=fn, 157 | additional_inputs=[generation_type], 158 | type="messages", 159 | title="LumaAI Generation", 160 | description="Generate videos or images from text prompts using LumaAI", 161 | **kwargs 162 | ) 163 | 164 | return interface -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ai-gradio 2 | 3 | A Python package that makes it easy for developers to create machine learning apps powered by various AI providers. Built on top of Gradio, it provides a unified interface for multiple AI models and services. 4 | 5 | ## Features 6 | 7 | ### Core Features 8 | - **Multi-Provider Support**: Integrate with 15+ AI providers including OpenAI, Google Gemini, Anthropic, and more 9 | - **Text Chat**: Interactive chat interfaces for all text models 10 | - **Voice Chat**: Real-time voice interactions with OpenAI models 11 | - **Video Chat**: Video processing capabilities with Gemini models 12 | - **Code Generation**: Specialized interfaces for coding assistance 13 | - **Multi-Modal**: Support for text, image, and video inputs 14 | - **Agent Teams**: CrewAI integration for collaborative AI tasks 15 | - **Browser Automation**: AI agents that can perform web-based tasks 16 | 17 | ### Model Support 18 | 19 | #### Core Language Models 20 | | Provider | Models | 21 | |----------|---------| 22 | | OpenAI | gpt-4-turbo, gpt-4, gpt-3.5-turbo | 23 | | Anthropic | claude-3-opus, claude-3-sonnet, claude-3-haiku | 24 | | Gemini | gemini-pro, gemini-pro-vision, gemini-2.0-flash-exp | 25 | | Groq | llama-3.2-70b-chat, mixtral-8x7b-chat | 26 | 27 | #### Specialized Models 28 | | Provider | Type | Models | 29 | |----------|------|---------| 30 | | LumaAI | Generation | dream-machine, photon-1 | 31 | | DeepSeek | Multi-purpose | deepseek-chat, deepseek-coder, deepseek-vision | 32 | | CrewAI | Agent Teams | Support Team, Article Team | 33 | | Qwen | Language | qwen-turbo, qwen-plus, qwen-max | 34 | | Browser | Automation | browser-use-agent | 35 | 36 | ## Installation 37 | 38 | ### Basic Installation 39 | ```bash 40 | # Install core package 41 | pip install ai-gradio 42 | 43 | # Install with specific provider support 44 | pip install 'ai-gradio[openai]' # OpenAI support 45 | pip install 'ai-gradio[gemini]' # Google Gemini support 46 | pip install 'ai-gradio[anthropic]' # Anthropic Claude support 47 | pip install 'ai-gradio[groq]' # Groq support 48 | 49 | # Install all providers 50 | pip install 'ai-gradio[all]' 51 | ``` 52 | 53 | ### Additional Providers 54 | ```bash 55 | pip install 'ai-gradio[crewai]' # CrewAI support 56 | pip install 'ai-gradio[lumaai]' # LumaAI support 57 | pip install 'ai-gradio[xai]' # XAI/Grok support 58 | pip install 
'ai-gradio[cohere]' # Cohere support 59 | pip install 'ai-gradio[sambanova]' # SambaNova support 60 | pip install 'ai-gradio[hyperbolic]' # Hyperbolic support 61 | pip install 'ai-gradio[deepseek]' # DeepSeek support 62 | pip install 'ai-gradio[smolagents]' # SmolagentsAI support 63 | pip install 'ai-gradio[fireworks]' # Fireworks support 64 | pip install 'ai-gradio[together]' # Together support 65 | pip install 'ai-gradio[qwen]' # Qwen support 66 | pip install 'ai-gradio[browser]' # Browser support 67 | ``` 68 | 69 | ## Usage 70 | 71 | ### API Key Configuration 72 | ```bash 73 | # Core Providers 74 | export OPENAI_API_KEY= 75 | export GEMINI_API_KEY= 76 | export ANTHROPIC_API_KEY= 77 | export GROQ_API_KEY= 78 | export TAVILY_API_KEY= # Required for Langchain agents 79 | 80 | # Additional Providers (as needed) 81 | export LUMAAI_API_KEY= 82 | export XAI_API_KEY= 83 | export COHERE_API_KEY= 84 | # ... (other provider keys) 85 | 86 | # Twilio credentials (required for WebRTC voice chat) 87 | export TWILIO_ACCOUNT_SID= 88 | export TWILIO_AUTH_TOKEN= 89 | ``` 90 | 91 | ### Quick Start 92 | ```python 93 | import gradio as gr 94 | import ai_gradio 95 | 96 | # Create a simple chat interface 97 | gr.load( 98 | name='openai:gpt-4-turbo', # or 'gemini:gemini-1.5-flash', 'groq:llama-3.2-70b-chat' 99 | src=ai_gradio.registry, 100 | title='AI Chat', 101 | description='Chat with an AI model' 102 | ).launch() 103 | 104 | # Create a chat interface with Transformers models 105 | gr.load( 106 | name='transformers:phi-4', # or 'transformers:tulu-3', 'transformers:olmo-2-13b' 107 | src=ai_gradio.registry, 108 | title='Local AI Chat', 109 | description='Chat with locally running models' 110 | ).launch() 111 | 112 | # Create a coding assistant with OpenAI 113 | gr.load( 114 | name='openai:gpt-4-turbo', 115 | src=ai_gradio.registry, 116 | coder=True, 117 | title='OpenAI Code Assistant', 118 | description='OpenAI Code Generator' 119 | ).launch() 120 | 121 | # Create a coding assistant with Gemini 122 | gr.load( 123 | name='gemini:gemini-2.0-flash-thinking-exp-1219', # or 'openai:gpt-4-turbo', 'anthropic:claude-3-opus' 124 | src=ai_gradio.registry, 125 | coder=True, 126 | title='Gemini Code Generator', 127 | ).launch() 128 | ``` 129 | 130 | ### Advanced Features 131 | 132 | #### Voice Chat 133 | ```python 134 | gr.load( 135 | name='openai:gpt-4-turbo', 136 | src=ai_gradio.registry, 137 | enable_voice=True, 138 | title='AI Voice Assistant' 139 | ).launch() 140 | ``` 141 | 142 | #### Camera Mode 143 | ```python 144 | # Create a vision-enabled interface with camera support 145 | gr.load( 146 | name='gemini:gemini-2.0-flash-exp', 147 | src=ai_gradio.registry, 148 | camera=True, 149 | ).launch() 150 | ``` 151 | 152 | #### Multi-Provider Interface 153 | ```python 154 | import gradio as gr 155 | import ai_gradio 156 | 157 | with gr.Blocks() as demo: 158 | with gr.Tab("Text"): 159 | gr.load('openai:gpt-4-turbo', src=ai_gradio.registry) 160 | with gr.Tab("Vision"): 161 | gr.load('gemini:gemini-pro-vision', src=ai_gradio.registry) 162 | with gr.Tab("Code"): 163 | gr.load('deepseek:deepseek-coder', src=ai_gradio.registry) 164 | 165 | demo.launch() 166 | ``` 167 | 168 | #### CrewAI Teams 169 | ```python 170 | # Article Creation Team 171 | gr.load( 172 | name='crewai:gpt-4-turbo', 173 | src=ai_gradio.registry, 174 | crew_type='article', 175 | title='AI Writing Team' 176 | ).launch() 177 | ``` 178 | 179 | #### Browser Automation 180 | 181 | ```bash 182 | playwright install 183 | ``` 184 | 185 | use python 3.11+ for browser use 
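If you are unsure which interpreter your environment is using, a small guard like the one below fails fast with a clear message. This is only a sketch — the 3.11 minimum is the requirement stated above; everything else is illustrative:

```python
import sys

# browser-use requires Python 3.11 or newer; stop early if the interpreter is older
if sys.version_info < (3, 11):
    raise RuntimeError(f"Python 3.11+ is required for browser automation, found {sys.version.split()[0]}")
```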
186 | 187 | ```python 188 | import gradio as gr 189 | import ai_gradio 190 | 191 | # Create a browser automation interface 192 | gr.load( 193 | name='browser:gpt-4-turbo', 194 | src=ai_gradio.registry, 195 | title='AI Browser Assistant', 196 | description='Let AI help with web tasks' 197 | ).launch() 198 | ``` 199 | 200 | Example tasks: 201 | - Flight searches on Google Flights 202 | - Weather lookups 203 | - Product price comparisons 204 | - News searches 205 | 206 | #### Swarms Integration 207 | ```python 208 | import gradio as gr 209 | import ai_gradio 210 | 211 | # Create a chat interface with Swarms 212 | gr.load( 213 | name='swarms:gpt-4-turbo', # or other OpenAI models 214 | src=ai_gradio.registry, 215 | agent_name="Stock-Analysis-Agent", # customize agent name 216 | title='Swarms Chat', 217 | description='Chat with an AI agent powered by Swarms' 218 | ).launch() 219 | ``` 220 | 221 | #### Langchain Agents 222 | ```python 223 | import gradio as gr 224 | import ai_gradio 225 | 226 | # Create a Langchain agent interface 227 | gr.load( 228 | name='langchain:gpt-4-turbo', # or other supported models 229 | src=ai_gradio.registry, 230 | title='Langchain Agent', 231 | description='AI agent powered by Langchain' 232 | ).launch() 233 | ``` 234 | 235 | ## Requirements 236 | 237 | ### Core Requirements 238 | - Python 3.10+ 239 | - gradio >= 5.9.1 240 | 241 | ### Optional Features 242 | - Voice Chat: gradio-webrtc, numba==0.60.0, pydub, librosa 243 | - Video Chat: opencv-python, Pillow 244 | - Agent Teams: crewai>=0.1.0, langchain>=0.1.0 245 | 246 | ## Troubleshooting 247 | 248 | ### Authentication Issues 249 | If you encounter 401 errors, verify your API keys: 250 | ```python 251 | import os 252 | 253 | # Set API keys manually if needed 254 | os.environ["OPENAI_API_KEY"] = "your-api-key" 255 | os.environ["GEMINI_API_KEY"] = "your-api-key" 256 | ``` 257 | 258 | ### Provider Installation 259 | If you see "no providers installed" errors: 260 | ```bash 261 | # Install specific provider 262 | pip install 'ai-gradio[provider_name]' 263 | 264 | # Or install all providers 265 | pip install 'ai-gradio[all]' 266 | ``` 267 | 268 | 269 | ## Contributing 270 | Contributions are welcome! Please feel free to submit a Pull Request. 
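If authentication errors keep recurring, it can also help to verify up front that every key your app depends on is set before launching an interface. The snippet below is only a sketch — adjust the list to the providers you actually load; the variable names shown are the ones used in the API key configuration section above:

```python
import os

# Adjust this list to the providers you use
required_keys = ["OPENAI_API_KEY", "GEMINI_API_KEY", "ANTHROPIC_API_KEY"]
missing = [key for key in required_keys if not os.environ.get(key)]
if missing:
    raise SystemExit(f"Missing API keys: {', '.join(missing)}")
```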
271 | 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /ai_gradio/providers/fireworks_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from openai import OpenAI 3 | import gradio as gr 4 | from typing import Callable 5 | from fireworks.client.audio import AudioInference 6 | 7 | __version__ = "0.0.3" 8 | 9 | LANGUAGES = { 10 | "en": "english", "zh": "chinese", "de": "german", "es": "spanish", 11 | "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", 12 | "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", 13 | "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", 14 | "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", 15 | "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", 16 | "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", 17 | "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", 18 | "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", 19 | "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", 20 | "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", 21 | "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", 22 | "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", 23 | "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", 24 | "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", 25 | "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", 26 | "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", 27 | "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", 28 | "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", 29 | "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", 30 | "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", 31 | "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", 32 | "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", 33 | "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", 34 | "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese" 35 | } 36 | 37 | # Language code lookup by name, with additional aliases 38 | TO_LANGUAGE_CODE = { 39 | **{language: code for code, language in LANGUAGES.items()}, 40 | "burmese": "my", 41 | "valencian": "ca", 42 | "flemish": "nl", 43 | "haitian": "ht", 44 | "letzeburgesch": "lb", 45 | "pushto": "ps", 46 | "panjabi": "pa", 47 | "moldavian": "ro", 48 | "moldovan": "ro", 49 | "sinhalese": "si", 50 | "castilian": "es", 51 | "mandarin": "zh" 52 | } 53 | 54 | 55 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 56 | if "whisper" in model_name: 57 | def fn(message, history, audio_input=None): 58 | # Handle audio input if provided 59 | if audio_input: 60 | if not audio_input.endswith('.wav'): 61 | new_path = audio_input + '.wav' 62 | os.rename(audio_input, new_path) 63 | audio_input = new_path 64 | 65 | base_url = ( 66 | "https://audio-turbo.us-virginia-1.direct.fireworks.ai" 67 | if model_name == "whisper-v3-turbo" 68 | else "https://audio-prod.us-virginia-1.direct.fireworks.ai" 69 | ) 70 | 71 | client = AudioInference( 72 | model=model_name, 73 | base_url=base_url, 74 | api_key=api_key 75 | ) 76 | 77 | with open(audio_input, "rb") as f: 78 | audio_data = f.read() 79 | response = 
client.transcribe(audio=audio_data) 80 | return {"role": "assistant", "content": response.text} 81 | 82 | # Handle text message 83 | if isinstance(message, dict): # Multimodal input 84 | audio_path = message.get("files", [None])[0] or message.get("audio") 85 | text = message.get("text", "") 86 | if audio_path: 87 | # Process audio file 88 | return fn(None, history, audio_path) 89 | return {"role": "assistant", "content": "No audio input provided."} 90 | else: # String input 91 | return {"role": "assistant", "content": "Please upload an audio file or use the microphone to record audio."} 92 | 93 | else: 94 | def fn(message, history, audio_input=None): 95 | # Ignore audio_input for non-whisper models 96 | inputs = preprocess(message, history) 97 | client = OpenAI( 98 | base_url="https://api.fireworks.ai/inference/v1", 99 | api_key=api_key 100 | ) 101 | 102 | model_path = ( 103 | "accounts/fireworks/agents/f1-preview" if model_name == "f1-preview" 104 | else "accounts/fireworks/agents/f1-mini-preview" if model_name == "f1-mini" 105 | else f"accounts/fireworks/models/{model_name}" 106 | ) 107 | 108 | completion = client.chat.completions.create( 109 | model=model_path, 110 | messages=[{"role": "user", "content": inputs["prompt"]}], 111 | stream=True, 112 | max_tokens=1024, 113 | temperature=0.7, 114 | top_p=1, 115 | ) 116 | 117 | response_text = "" 118 | for chunk in completion: 119 | delta = chunk.choices[0].delta.content or "" 120 | response_text += delta 121 | yield {"role": "assistant", "content": response_text} 122 | 123 | return fn 124 | 125 | 126 | def get_interface_args(pipeline): 127 | if pipeline == "audio": 128 | inputs = [ 129 | gr.Audio(sources=["microphone"], type="filepath"), 130 | gr.Radio(["transcribe"], label="Task", value="transcribe"), 131 | ] 132 | outputs = "text" 133 | 134 | def preprocess(audio_path, task, text, history): 135 | if audio_path and not audio_path.endswith('.wav'): 136 | new_path = audio_path + '.wav' 137 | os.rename(audio_path, new_path) 138 | audio_path = new_path 139 | return {"role": "user", "content": {"audio_path": audio_path, "task": task}} 140 | 141 | def postprocess(text): 142 | return {"role": "assistant", "content": text} 143 | 144 | elif pipeline == "chat": 145 | inputs = gr.Textbox(label="Message") 146 | outputs = "text" 147 | 148 | def preprocess(message, history): 149 | return {"prompt": message} 150 | 151 | def postprocess(response): 152 | return response 153 | 154 | else: 155 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 156 | 157 | return inputs, outputs, preprocess, postprocess 158 | 159 | 160 | def get_pipeline(model_name): 161 | if "whisper" in model_name: 162 | return "audio" 163 | return "chat" 164 | 165 | 166 | def registry(name: str, token: str | None = None, **kwargs): 167 | """ 168 | Create a Gradio Interface for a model on Fireworks. 169 | Can be used directly or with gr.load: 170 | 171 | Example: 172 | # Direct usage 173 | interface = fireworks_gradio.registry("whisper-v3", token="your-api-key") 174 | interface.launch() 175 | 176 | # With gr.load 177 | gr.load( 178 | name='whisper-v3', 179 | src=fireworks_gradio.registry, 180 | ).launch() 181 | 182 | Parameters: 183 | name (str): The name of the OpenAI model. 184 | token (str, optional): The API key for OpenAI. 
185 | """ 186 | # Make the function compatible with gr.load by accepting name as a positional argument 187 | if not isinstance(name, str): 188 | raise ValueError("Model name must be a string") 189 | 190 | api_key = token or os.environ.get("FIREWORKS_API_KEY") 191 | if not api_key: 192 | raise ValueError("FIREWORKS_API_KEY environment variable is not set.") 193 | 194 | pipeline = get_pipeline(name) 195 | _, _, preprocess, postprocess = get_interface_args(pipeline) 196 | fn = get_fn(name, preprocess, postprocess, api_key) 197 | 198 | description = kwargs.pop("description", None) 199 | if "whisper" in name: 200 | description = (description or "") + """ 201 | \n\nSupported inputs: 202 | - Upload audio files using the textbox 203 | - Record audio using the microphone 204 | """ 205 | 206 | with gr.Blocks() as interface: 207 | chatbot = gr.Chatbot(type="messages") 208 | with gr.Row(): 209 | mic = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio") 210 | 211 | def process_audio(audio_path): 212 | if audio_path: 213 | if not audio_path.endswith('.wav'): 214 | new_path = audio_path + '.wav' 215 | os.rename(audio_path, new_path) 216 | audio_path = new_path 217 | 218 | # Create message format expected by fn 219 | message = {"files": [audio_path], "text": ""} 220 | response = fn(message, []) 221 | 222 | return [ 223 | {"role": "user", "content": gr.Audio(value=audio_path)}, 224 | {"role": "assistant", "content": response["content"]} 225 | ] 226 | return [] 227 | 228 | mic.change( 229 | fn=process_audio, 230 | inputs=[mic], 231 | outputs=[chatbot] 232 | ) 233 | 234 | else: 235 | # For non-whisper models, use regular ChatInterface 236 | interface = gr.ChatInterface( 237 | fn=fn, 238 | type="messages", 239 | description=description, 240 | **kwargs 241 | ) 242 | 243 | return interface 244 | 245 | 246 | # Add these to make the module more discoverable 247 | MODELS = [ 248 | "whisper-v3", 249 | "whisper-v3-turbo", 250 | "f1-preview", 251 | "f1-mini", 252 | # Add other supported models here 253 | ] 254 | 255 | def get_all_models(): 256 | """Returns a list of all supported models.""" 257 | return MODELS -------------------------------------------------------------------------------- /ai_gradio/providers/hyperbolic_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from openai import OpenAI 3 | import gradio as gr 4 | from typing import Callable 5 | import base64 6 | import re 7 | import modelscope_studio.components.base as ms 8 | import modelscope_studio.components.legacy as legacy 9 | import modelscope_studio.components.antd as antd 10 | 11 | # Constants for coder interface 12 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications. 13 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript. 
14 | Focus on: 15 | - Writing clear, maintainable code 16 | - Following best practices 17 | - Creating responsive designs 18 | - Adding appropriate styling and interactivity 19 | Return only the complete HTML code without any additional explanation.""" 20 | 21 | DEMO_LIST = [ 22 | { 23 | "card": {"index": 0}, 24 | "title": "Simple Button", 25 | "description": "Create a button that changes color when clicked" 26 | }, 27 | { 28 | "card": {"index": 1}, 29 | "title": "Todo List", 30 | "description": "Create a simple todo list with add/remove functionality" 31 | }, 32 | { 33 | "card": {"index": 2}, 34 | "title": "Timer App", 35 | "description": "Create a countdown timer with start/pause/reset controls" 36 | } 37 | ] 38 | 39 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str, base_url: str = None): 40 | def fn(message, history): 41 | inputs = preprocess(message, history) 42 | 43 | client = OpenAI( 44 | api_key=api_key, 45 | base_url="https://api.hyperbolic.xyz/v1" 46 | ) 47 | 48 | completion = client.chat.completions.create( 49 | model=model_name, 50 | messages=[ 51 | *inputs["messages"] 52 | ], 53 | stream=True, 54 | temperature=0.7, 55 | max_tokens=512, 56 | top_p=0.9, 57 | ) 58 | response_text = "" 59 | for chunk in completion: 60 | delta = chunk.choices[0].delta.content or "" 61 | # Replace DeepSeek-R1 special tokens 62 | if "deepseek" in model_name.lower(): 63 | delta = delta.replace("", "[think]").replace("", "[/think]") 64 | response_text += delta 65 | yield postprocess(response_text) 66 | 67 | return fn 68 | 69 | 70 | def get_interface_args(pipeline): 71 | if pipeline == "chat": 72 | inputs = None 73 | outputs = None 74 | 75 | def preprocess(message, history): 76 | messages = [] 77 | for user_msg, assistant_msg in history: 78 | messages.append({"role": "user", "content": user_msg}) 79 | messages.append({"role": "assistant", "content": assistant_msg}) 80 | messages.append({"role": "user", "content": message}) 81 | return {"messages": messages} 82 | 83 | postprocess = lambda x: x # No post-processing needed 84 | else: 85 | # Add other pipeline types when they will be needed 86 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 87 | return inputs, outputs, preprocess, postprocess 88 | 89 | 90 | def get_pipeline(model_name): 91 | # Determine the pipeline type based on the model name 92 | # For simplicity, assuming all models are chat models at the moment 93 | return "chat" 94 | 95 | 96 | def registry(name: str, token: str | None = None, base_url: str | None = None, coder: bool = False, **kwargs): 97 | api_key = token or os.environ.get("HYPERBOLIC_API_KEY") 98 | if not api_key: 99 | raise ValueError("API key is not set. 
Please provide a token or set HYPERBOLIC_API_KEY environment variable.") 100 | 101 | if coder: 102 | interface = gr.Blocks(css=""" 103 | .left_header { 104 | text-align: center; 105 | margin-bottom: 20px; 106 | } 107 | .right_panel { 108 | background: white; 109 | border-radius: 8px; 110 | overflow: hidden; 111 | box-shadow: 0 2px 8px rgba(0,0,0,0.15); 112 | } 113 | .render_header { 114 | background: #f5f5f5; 115 | padding: 8px; 116 | border-bottom: 1px solid #e8e8e8; 117 | } 118 | .header_btn { 119 | display: inline-block; 120 | width: 12px; 121 | height: 12px; 122 | border-radius: 50%; 123 | margin-right: 8px; 124 | background: #ff5f56; 125 | } 126 | .header_btn:nth-child(2) { 127 | background: #ffbd2e; 128 | } 129 | .header_btn:nth-child(3) { 130 | background: #27c93f; 131 | } 132 | .right_content { 133 | padding: 24px; 134 | height: 920px; 135 | display: flex; 136 | align-items: center; 137 | justify-content: center; 138 | } 139 | .html_content { 140 | height: 920px; 141 | width: 100%; 142 | } 143 | .history_chatbot { 144 | height: 100%; 145 | } 146 | """) 147 | 148 | with interface: 149 | history = gr.State([]) 150 | setting = gr.State({"system": SystemPrompt}) 151 | 152 | with ms.Application() as app: 153 | with antd.ConfigProvider(): 154 | with antd.Row(gutter=[32, 12]) as layout: 155 | # Left Column 156 | with antd.Col(span=24, md=8): 157 | with antd.Flex(vertical=True, gap="middle", wrap=True): 158 | header = gr.HTML(""" 159 |
160 |                             <div class="left_header">
                                      <h1>Hyperbolic Code Generator</h1>
161 |                             </div>
162 | """) 163 | 164 | input = antd.InputTextarea( 165 | size="large", 166 | allow_clear=True, 167 | placeholder="Describe the web application you want to create" 168 | ) 169 | btn = antd.Button("Generate", type="primary", size="large") 170 | clear_btn = antd.Button("Clear History", type="default", size="large") 171 | 172 | antd.Divider("Examples") 173 | with antd.Flex(gap="small", wrap=True): 174 | with ms.Each(DEMO_LIST): 175 | with antd.Card(hoverable=True, as_item="card") as demoCard: 176 | antd.CardMeta() 177 | 178 | antd.Divider("Settings") 179 | with antd.Flex(gap="small", wrap=True): 180 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default") 181 | codeBtn = antd.Button("🧑‍💻 View Code", type="default") 182 | historyBtn = antd.Button("📜 History", type="default") 183 | 184 | # Modals and Drawers 185 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal: 186 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True) 187 | 188 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer: 189 | code_output = legacy.Markdown() 190 | 191 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer: 192 | history_output = legacy.Chatbot( 193 | show_label=False, 194 | height=960, 195 | elem_classes="history_chatbot" 196 | ) 197 | 198 | # Right Column 199 | with antd.Col(span=24, md=16): 200 | with ms.Div(elem_classes="right_panel"): 201 | gr.HTML(''' 202 |
203 | 204 | 205 | 206 |
207 | ''') 208 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab: 209 | with antd.Tabs.Item(key="empty"): 210 | empty = antd.Empty( 211 | description="Enter your request to generate code", 212 | elem_classes="right_content" 213 | ) 214 | with antd.Tabs.Item(key="loading"): 215 | loading = antd.Spin( 216 | True, 217 | tip="Generating code...", 218 | size="large", 219 | elem_classes="right_content" 220 | ) 221 | with antd.Tabs.Item(key="render"): 222 | preview = gr.HTML(elem_classes="html_content") 223 | 224 | # Helper functions 225 | def demo_card_click(e: gr.EventData): 226 | index = e._data['component']['index'] 227 | return DEMO_LIST[index]['description'] 228 | 229 | def send_to_preview(code): 230 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8') 231 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}" 232 | return f'' 233 | 234 | def remove_code_block(text): 235 | pattern = r'```html\n(.+?)\n```' 236 | match = re.search(pattern, text, re.DOTALL) 237 | if match: 238 | return match.group(1).strip() 239 | return text.strip() 240 | 241 | def generate_code(query, setting, history): 242 | client = OpenAI( 243 | api_key=api_key, 244 | base_url="https://api.hyperbolic.xyz/v1" 245 | ) 246 | 247 | messages = [ 248 | {"role": "system", "content": setting["system"]}, 249 | ] 250 | 251 | for h in history: 252 | messages.append({"role": "user", "content": h[0]}) 253 | messages.append({"role": "assistant", "content": h[1]}) 254 | 255 | messages.append({"role": "user", "content": query}) 256 | 257 | response = client.chat.completions.create( 258 | model=name, 259 | messages=messages, 260 | stream=True, 261 | temperature=0.7, 262 | max_tokens=2048, 263 | ) 264 | 265 | response_text = "" 266 | for chunk in response: 267 | if chunk.choices[0].delta.content: 268 | response_text += chunk.choices[0].delta.content 269 | yield ( 270 | response_text, 271 | history, 272 | None, 273 | gr.update(active_key="loading"), 274 | gr.update(open=True) 275 | ) 276 | 277 | clean_code = remove_code_block(response_text) 278 | new_history = history + [(query, response_text)] 279 | 280 | yield ( 281 | response_text, 282 | new_history, 283 | send_to_preview(clean_code), 284 | gr.update(active_key="render"), 285 | gr.update(open=False) 286 | ) 287 | 288 | # Wire up event handlers 289 | demoCard.click(demo_card_click, outputs=[input]) 290 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal]) 291 | system_prompt_modal.ok( 292 | lambda input: ({"system": input}, gr.update(open=False)), 293 | inputs=[systemPromptInput], 294 | outputs=[setting, system_prompt_modal] 295 | ) 296 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal]) 297 | 298 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer]) 299 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer]) 300 | 301 | historyBtn.click( 302 | lambda h: (gr.update(open=True), h), 303 | inputs=[history], 304 | outputs=[history_drawer, history_output] 305 | ) 306 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer]) 307 | 308 | btn.click( 309 | generate_code, 310 | inputs=[input, setting, history], 311 | outputs=[code_output, history, preview, state_tab, code_drawer] 312 | ) 313 | 314 | clear_btn.click(lambda: [], outputs=[history]) 315 | 316 | return interface 317 | 318 | # Regular chat interface 319 | pipeline = get_pipeline(name) 320 | inputs, outputs, preprocess, postprocess = 
get_interface_args(pipeline) 321 | fn = get_fn(name, preprocess, postprocess, api_key, base_url) 322 | 323 | if pipeline == "chat": 324 | interface = gr.ChatInterface(fn=fn, **kwargs) 325 | else: 326 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 327 | 328 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/perplexity_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from openai import OpenAI 3 | import gradio as gr 4 | from typing import Callable 5 | import re 6 | import base64 7 | import modelscope_studio.components.base as ms 8 | import modelscope_studio.components.legacy as legacy 9 | import modelscope_studio.components.antd as antd 10 | 11 | __version__ = "0.0.1" 12 | 13 | # Add these constants at the top of the file 14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications. 15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript. 16 | Focus on: 17 | - Writing clear, maintainable code 18 | - Following best practices 19 | - Creating responsive designs 20 | - Adding appropriate styling and interactivity 21 | Return only the complete HTML code without any additional explanation.""" 22 | 23 | DEMO_LIST = [ 24 | { 25 | "card": {"index": 0}, 26 | "title": "Simple Button", 27 | "description": "Create a button that changes color when clicked" 28 | }, 29 | { 30 | "card": {"index": 1}, 31 | "title": "Todo List", 32 | "description": "Create a simple todo list with add/remove functionality" 33 | }, 34 | { 35 | "card": {"index": 2}, 36 | "title": "Timer App", 37 | "description": "Create a countdown timer with start/pause/reset controls" 38 | } 39 | ] 40 | 41 | 42 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 43 | def fn(message, history): 44 | inputs = preprocess(message, history) 45 | client = OpenAI( 46 | api_key=api_key, 47 | base_url="https://api.perplexity.ai" 48 | ) 49 | completion = client.chat.completions.create( 50 | model=model_name, 51 | messages=inputs["messages"], 52 | stream=True, 53 | ) 54 | response_text = "" 55 | for chunk in completion: 56 | delta = chunk.choices[0].delta.content or "" 57 | # Replace DeepSeek special tokens if present 58 | if "deepseek" in model_name.lower(): 59 | delta = delta.replace("", "[think]").replace("", "[/think]") 60 | response_text += delta 61 | yield postprocess(response_text) 62 | 63 | return fn 64 | 65 | 66 | def get_interface_args(pipeline): 67 | if pipeline == "chat": 68 | inputs = None 69 | outputs = None 70 | 71 | def preprocess(message, history): 72 | messages = [] 73 | for user_msg, assistant_msg in history: 74 | messages.append({"role": "user", "content": user_msg}) 75 | messages.append({"role": "assistant", "content": assistant_msg}) 76 | messages.append({"role": "user", "content": message}) 77 | return {"messages": messages} 78 | 79 | postprocess = lambda x: x # No post-processing needed 80 | else: 81 | # Add other pipeline types when they will be needed 82 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 83 | return inputs, outputs, preprocess, postprocess 84 | 85 | 86 | def get_pipeline(model_name): 87 | # Determine the pipeline type based on the model name 88 | # For simplicity, assuming all models are chat models at the moment 89 | return "chat" 90 | 91 | 92 | def registry( 93 | name: str, 94 | token: str | 
None = None, 95 | examples: list | None = None, 96 | coder: bool = False, # Add coder parameter 97 | **kwargs 98 | ): 99 | """ 100 | Create a Gradio Interface for a model on Perplexity. 101 | 102 | Parameters: 103 | - name (str): The name of the model 104 | - token (str, optional): The API key 105 | - examples (list, optional): Example inputs 106 | - coder (bool, optional): Whether to use coding interface 107 | """ 108 | api_key = token or os.environ.get("PERPLEXITY_API_KEY") 109 | if not api_key: 110 | raise ValueError("PERPLEXITY_API_KEY environment variable is not set.") 111 | 112 | if coder: 113 | interface = gr.Blocks(css=""" 114 | .left_header { 115 | text-align: center; 116 | margin-bottom: 20px; 117 | } 118 | 119 | .right_panel { 120 | background: white; 121 | border-radius: 8px; 122 | overflow: hidden; 123 | box-shadow: 0 2px 8px rgba(0,0,0,0.15); 124 | } 125 | 126 | .render_header { 127 | background: #f5f5f5; 128 | padding: 8px; 129 | border-bottom: 1px solid #e8e8e8; 130 | } 131 | 132 | .header_btn { 133 | display: inline-block; 134 | width: 12px; 135 | height: 12px; 136 | border-radius: 50%; 137 | margin-right: 8px; 138 | background: #ff5f56; 139 | } 140 | 141 | .header_btn:nth-child(2) { 142 | background: #ffbd2e; 143 | } 144 | 145 | .header_btn:nth-child(3) { 146 | background: #27c93f; 147 | } 148 | 149 | .right_content { 150 | padding: 24px; 151 | height: 920px; 152 | display: flex; 153 | align-items: center; 154 | justify-content: center; 155 | } 156 | 157 | .html_content { 158 | height: 920px; 159 | width: 100%; 160 | } 161 | 162 | .history_chatbot { 163 | height: 100%; 164 | } 165 | """) 166 | with interface: 167 | history = gr.State([]) 168 | setting = gr.State({"system": SystemPrompt}) 169 | 170 | with ms.Application() as app: 171 | with antd.ConfigProvider(): 172 | with antd.Row(gutter=[32, 12]) as layout: 173 | # Left Column 174 | with antd.Col(span=24, md=8): 175 | with antd.Flex(vertical=True, gap="middle", wrap=True): 176 | header = gr.HTML(""" 177 |
178 |                             <div class="left_header">
                                      <h1>Perplexity Code Generator</h1>
179 |                             </div>
180 | """) 181 | 182 | input = antd.InputTextarea( 183 | size="large", 184 | allow_clear=True, 185 | placeholder="Describe the web application you want to create" 186 | ) 187 | btn = antd.Button("Generate", type="primary", size="large") 188 | clear_btn = antd.Button("Clear History", type="default", size="large") 189 | 190 | antd.Divider("Examples") 191 | with antd.Flex(gap="small", wrap=True): 192 | with ms.Each(DEMO_LIST): 193 | with antd.Card(hoverable=True, as_item="card") as demoCard: 194 | antd.CardMeta() 195 | 196 | antd.Divider("Settings") 197 | with antd.Flex(gap="small", wrap=True): 198 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default") 199 | codeBtn = antd.Button("🧑‍💻 View Code", type="default") 200 | historyBtn = antd.Button("📜 History", type="default") 201 | 202 | # Modals and Drawers 203 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal: 204 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True) 205 | 206 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer: 207 | code_output = legacy.Markdown() 208 | 209 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer: 210 | history_output = legacy.Chatbot( 211 | show_label=False, 212 | height=960, 213 | elem_classes="history_chatbot" 214 | ) 215 | 216 | # Right Column 217 | with antd.Col(span=24, md=16): 218 | with ms.Div(elem_classes="right_panel"): 219 | gr.HTML(''' 220 |
221 | 222 | 223 | 224 |
225 | ''') 226 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab: 227 | with antd.Tabs.Item(key="empty"): 228 | empty = antd.Empty( 229 | description="Enter your request to generate code", 230 | elem_classes="right_content" 231 | ) 232 | with antd.Tabs.Item(key="loading"): 233 | loading = antd.Spin( 234 | True, 235 | tip="Generating code...", 236 | size="large", 237 | elem_classes="right_content" 238 | ) 239 | with antd.Tabs.Item(key="render"): 240 | preview = gr.HTML(elem_classes="html_content") 241 | 242 | # Event Handlers 243 | def demo_card_click(e: gr.EventData): 244 | index = e._data['component']['index'] 245 | return DEMO_LIST[index]['description'] 246 | 247 | def send_to_preview(code): 248 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8') 249 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}" 250 | return f'' 251 | 252 | def remove_code_block(text): 253 | # Look for HTML code block first 254 | html_pattern = r'```html\n(.*?)\n```' 255 | match = re.search(html_pattern, text, re.DOTALL) 256 | if match: 257 | return match.group(1).strip() 258 | 259 | # If no HTML block found, look for any code block 260 | code_pattern = r'```(?:\w+)?\n(.*?)\n```' 261 | match = re.search(code_pattern, text, re.DOTALL) 262 | if match: 263 | return match.group(1).strip() 264 | 265 | # If no code block found, return the whole text 266 | return text.strip() 267 | 268 | def generate_code(query, setting, history): 269 | messages = [] 270 | messages.append({"role": "system", "content": setting["system"]}) 271 | 272 | for h in history: 273 | messages.append({"role": "user", "content": h[0]}) 274 | messages.append({"role": "assistant", "content": h[1]}) 275 | 276 | messages.append({"role": "user", "content": query}) 277 | 278 | client = OpenAI( 279 | api_key=api_key, 280 | base_url="https://api.perplexity.ai" 281 | ) 282 | 283 | response = client.chat.completions.create( 284 | model=name, 285 | messages=messages, 286 | stream=True 287 | ) 288 | 289 | response_text = "" 290 | for chunk in response: 291 | if chunk.choices[0].delta.content: 292 | response_text += chunk.choices[0].delta.content 293 | # Format the code for display 294 | formatted_response = f"```html\n{response_text}\n```" 295 | # Return all 5 required outputs 296 | yield ( 297 | formatted_response, # code_output (modelscopemarkdown) 298 | history, # state 299 | None, # preview (html) 300 | gr.update(active_key="loading"), # state_tab (antdtabs) 301 | gr.update(open=True) # code_drawer (antddrawer) 302 | ) 303 | 304 | clean_code = remove_code_block(response_text) 305 | new_history = history + [(query, response_text)] 306 | 307 | # Final yield with all outputs 308 | yield ( 309 | f"```html\n{clean_code}\n```", # code_output (modelscopemarkdown) 310 | new_history, # state 311 | send_to_preview(clean_code), # preview (html) 312 | gr.update(active_key="render"), # state_tab (antdtabs) 313 | gr.update(open=False) # code_drawer (antddrawer) 314 | ) 315 | 316 | # Wire up event handlers 317 | demoCard.click(demo_card_click, outputs=[input]) 318 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal]) 319 | system_prompt_modal.ok( 320 | lambda input: ({"system": input}, gr.update(open=False)), 321 | inputs=[systemPromptInput], 322 | outputs=[setting, system_prompt_modal] 323 | ) 324 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal]) 325 | 326 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer]) 327 | 
code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer]) 328 | 329 | historyBtn.click( 330 | lambda h: (gr.update(open=True), h), 331 | inputs=[history], 332 | outputs=[history_drawer, history_output] 333 | ) 334 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer]) 335 | 336 | btn.click( 337 | generate_code, 338 | inputs=[input, setting, history], 339 | outputs=[code_output, history, preview, state_tab, code_drawer] 340 | ) 341 | 342 | clear_btn.click(lambda: [], outputs=[history]) 343 | 344 | return interface 345 | 346 | # Continue with existing chat interface code... 347 | pipeline = get_pipeline(name) 348 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 349 | fn = get_fn(name, preprocess, postprocess, api_key) 350 | 351 | if examples: 352 | kwargs["examples"] = examples 353 | 354 | if pipeline == "chat": 355 | interface = gr.ChatInterface(fn=fn, **kwargs) 356 | else: 357 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 358 | 359 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/crewai_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict, Generator 3 | import gradio as gr 4 | from crewai import Agent, Task, Crew 5 | from crewai_tools import ScrapeWebsiteTool 6 | import queue 7 | import threading 8 | import asyncio 9 | 10 | class MessageQueue: 11 | def __init__(self): 12 | self.message_queue = queue.Queue() 13 | self.last_agent = None 14 | 15 | def add_message(self, message: Dict): 16 | print(f"Adding message to queue: {message}") 17 | self.message_queue.put(message) 18 | 19 | def get_messages(self) -> List[Dict]: 20 | messages = [] 21 | while not self.message_queue.empty(): 22 | messages.append(self.message_queue.get()) 23 | return messages 24 | 25 | class CrewFactory: 26 | @staticmethod 27 | def create_crew_config(crew_type: str, topic: str) -> dict: 28 | configs = { 29 | "article": { 30 | "agents": [ 31 | { 32 | "role": "Content Planner", 33 | "goal": f"Plan engaging and factually accurate content on {topic}", 34 | "backstory": "Expert content planner with focus on creating engaging outlines", 35 | "tasks": [ 36 | """Create a detailed content plan for an article by: 37 | 1. Prioritizing the latest trends and key players 38 | 2. Identifying the target audience 39 | 3. Developing a detailed content outline 40 | 4. Including SEO keywords and sources""" 41 | ] 42 | }, 43 | { 44 | "role": "Content Writer", 45 | "goal": f"Write insightful piece about {topic}", 46 | "backstory": "Expert content writer with focus on engaging articles", 47 | "tasks": [ 48 | """1. Use the content plan to craft a compelling blog post 49 | 2. Incorporate SEO keywords naturally 50 | 3. Create proper structure with introduction, body, and conclusion""" 51 | ] 52 | }, 53 | { 54 | "role": "Editor", 55 | "goal": "Polish and refine the content", 56 | "backstory": "Expert editor with eye for detail and clarity", 57 | "tasks": [ 58 | """1. Review for clarity and coherence 59 | 2. Correct any errors 60 | 3. 
Ensure consistent tone and style""" 61 | ] 62 | } 63 | ] 64 | }, 65 | "support": { 66 | "agents": [ 67 | { 68 | "role": "Senior Support Representative", 69 | "goal": "Be the most helpful support representative", 70 | "backstory": "Expert at analyzing questions and providing support", 71 | "tasks": [ 72 | f"""Analyze this inquiry thoroughly: {topic} 73 | Provide detailed support response.""" 74 | ] 75 | }, 76 | { 77 | "role": "Support Quality Assurance", 78 | "goal": "Ensure highest quality of support responses", 79 | "backstory": "Expert at reviewing and improving support responses", 80 | "tasks": [ 81 | """Review and improve the support response to ensure it's: 82 | 1. Comprehensive and helpful 83 | 2. Properly formatted with clear structure""" 84 | ] 85 | } 86 | ] 87 | } 88 | } 89 | return configs.get(crew_type, configs["article"]) 90 | 91 | class CrewManager: 92 | def __init__(self, api_key: str = None): 93 | self.api_key = api_key 94 | self.message_queue = MessageQueue() 95 | self.agents = [] 96 | self.current_agent = None 97 | self.scrape_tool = None 98 | 99 | def initialize_agents(self, crew_type: str, topic: str, website_url: str = None): 100 | if not self.api_key: 101 | raise ValueError("OpenAI API key is required") 102 | 103 | os.environ["OPENAI_API_KEY"] = self.api_key 104 | if website_url: 105 | self.scrape_tool = ScrapeWebsiteTool(website_url=website_url) 106 | 107 | # Get crew configuration 108 | config = CrewFactory.create_crew_config(crew_type, topic) 109 | 110 | # Initialize agents from configuration 111 | self.agents = [] 112 | for agent_config in config["agents"]: 113 | agent = Agent( 114 | role=agent_config["role"], 115 | goal=agent_config["goal"], 116 | backstory=agent_config["backstory"], 117 | allow_delegation=False, 118 | verbose=True 119 | ) 120 | self.agents.append((agent, agent_config["tasks"])) 121 | 122 | def create_tasks(self, topic: str) -> List[Task]: 123 | tasks = [] 124 | for agent, task_descriptions in self.agents: 125 | for task_desc in task_descriptions: 126 | task = Task( 127 | description=task_desc, 128 | expected_output="Detailed and well-formatted response", 129 | agent=agent, 130 | tools=[self.scrape_tool] if self.scrape_tool else [] 131 | ) 132 | tasks.append(task) 133 | return tasks 134 | 135 | async def process_support(self, inquiry: str, website_url: str, crew_type: str) -> Generator[List[Dict], None, None]: 136 | def add_agent_messages(agent_name: str, tasks: str, emoji: str = "🤖"): 137 | self.message_queue.add_message({ 138 | "role": "assistant", 139 | "content": agent_name, 140 | "metadata": {"title": f"{emoji} {agent_name}"} 141 | }) 142 | 143 | self.message_queue.add_message({ 144 | "role": "assistant", 145 | "content": tasks, 146 | "metadata": {"title": f"📋 Task for {agent_name}"} 147 | }) 148 | 149 | def setup_next_agent(current_agent: str): 150 | if crew_type == "support": 151 | if current_agent == "Senior Support Representative": 152 | self.current_agent = "Support Quality Assurance Specialist" 153 | add_agent_messages( 154 | "Support Quality Assurance Specialist", 155 | "Review and improve the support response" 156 | ) 157 | elif crew_type == "article": 158 | if current_agent == "Content Planner": 159 | self.current_agent = "Content Writer" 160 | add_agent_messages( 161 | "Content Writer", 162 | "Write the article based on the content plan" 163 | ) 164 | elif current_agent == "Content Writer": 165 | self.current_agent = "Editor" 166 | add_agent_messages( 167 | "Editor", 168 | "Review and polish the article" 169 | ) 170 | 171 | def 
task_callback(task_output): 172 | raw_output = task_output.raw 173 | if "## Final Answer:" in raw_output: 174 | content = raw_output.split("## Final Answer:")[1].strip() 175 | else: 176 | content = raw_output.strip() 177 | 178 | if self.current_agent == "Support Quality Assurance Specialist": 179 | self.message_queue.add_message({ 180 | "role": "assistant", 181 | "content": "Final response is ready!", 182 | "metadata": {"title": "✅ Final Response"} 183 | }) 184 | 185 | formatted_content = content 186 | formatted_content = formatted_content.replace("\n#", "\n\n#") 187 | formatted_content = formatted_content.replace("\n-", "\n\n-") 188 | formatted_content = formatted_content.replace("\n*", "\n\n*") 189 | formatted_content = formatted_content.replace("\n1.", "\n\n1.") 190 | formatted_content = formatted_content.replace("\n\n\n", "\n\n") 191 | 192 | self.message_queue.add_message({ 193 | "role": "assistant", 194 | "content": formatted_content 195 | }) 196 | else: 197 | self.message_queue.add_message({ 198 | "role": "assistant", 199 | "content": content, 200 | "metadata": {"title": f"✨ Output from {self.current_agent}"} 201 | }) 202 | setup_next_agent(self.current_agent) 203 | 204 | try: 205 | self.initialize_agents(crew_type, inquiry, website_url) 206 | # Set initial agent based on crew type 207 | self.current_agent = "Senior Support Representative" if crew_type == "support" else "Content Planner" 208 | 209 | yield [{ 210 | "role": "assistant", 211 | "content": "Starting to process your inquiry...", 212 | "metadata": {"title": "🚀 Process Started"} 213 | }] 214 | 215 | # Set initial task message based on crew type 216 | if crew_type == "support": 217 | add_agent_messages( 218 | "Senior Support Representative", 219 | "Analyze inquiry and provide comprehensive support" 220 | ) 221 | else: 222 | add_agent_messages( 223 | "Content Planner", 224 | "Create a detailed content plan for the article" 225 | ) 226 | 227 | crew = Crew( 228 | agents=[agent for agent, _ in self.agents], 229 | tasks=self.create_tasks(inquiry), 230 | verbose=True, 231 | task_callback=task_callback 232 | ) 233 | 234 | def run_crew(): 235 | try: 236 | crew.kickoff() 237 | except Exception as e: 238 | print(f"Error in crew execution: {str(e)}") 239 | self.message_queue.add_message({ 240 | "role": "assistant", 241 | "content": f"Error: {str(e)}", 242 | "metadata": {"title": "❌ Error"} 243 | }) 244 | 245 | thread = threading.Thread(target=run_crew) 246 | thread.start() 247 | 248 | while thread.is_alive() or not self.message_queue.message_queue.empty(): 249 | messages = self.message_queue.get_messages() 250 | if messages: 251 | yield messages 252 | await asyncio.sleep(0.1) 253 | 254 | except Exception as e: 255 | print(f"Error in process_support: {str(e)}") 256 | yield [{ 257 | "role": "assistant", 258 | "content": f"An error occurred: {str(e)}", 259 | "metadata": {"title": "❌ Error"} 260 | }] 261 | 262 | def registry(name: str, token: str | None = None, crew_type: str = "support", **kwargs): 263 | has_api_key = bool(token or os.environ.get("OPENAI_API_KEY")) 264 | crew_manager = None 265 | 266 | with gr.Blocks(theme=gr.themes.Soft()) as demo: 267 | title = "📝 AI Article Writing Crew" if crew_type == "article" else "🤖 CrewAI Assistant" 268 | gr.Markdown(f"# {title}") 269 | 270 | api_key = gr.Textbox( 271 | label="OpenAI API Key", 272 | type="password", 273 | placeholder="Type your OpenAI API key and press Enter...", 274 | interactive=True, 275 | visible=not has_api_key 276 | ) 277 | 278 | chatbot = gr.Chatbot( 279 | label="Writing 
Process" if crew_type == "article" else "Process", 280 | height=700 if crew_type == "article" else 600, 281 | show_label=True, 282 | visible=has_api_key, 283 | avatar_images=(None, "https://avatars.githubusercontent.com/u/170677839?v=4"), 284 | render_markdown=True, 285 | type="messages" 286 | ) 287 | 288 | with gr.Row(equal_height=True): 289 | topic = gr.Textbox( 290 | label="Article Topic" if crew_type == "article" else "Topic/Question", 291 | placeholder="Enter topic..." if crew_type == "article" else "Enter your question...", 292 | scale=4, 293 | visible=has_api_key 294 | ) 295 | website_url = gr.Textbox( 296 | label="Documentation URL", 297 | placeholder="Enter documentation URL to search...", 298 | scale=4, 299 | visible=(has_api_key) and crew_type == "support" 300 | ) 301 | btn = gr.Button( 302 | "Write Article" if crew_type == "article" else "Start", 303 | variant="primary", 304 | scale=1, 305 | visible=has_api_key 306 | ) 307 | 308 | async def process_input(topic, website_url, history, api_key): 309 | nonlocal crew_manager 310 | effective_api_key = token or api_key or os.environ.get("OPENAI_API_KEY") 311 | 312 | if not effective_api_key: 313 | yield [ 314 | {"role": "user", "content": f"Question: {topic}\nDocumentation: {website_url}"}, 315 | {"role": "assistant", "content": "Please provide an OpenAI API key."} 316 | ] 317 | return # Early return without value 318 | 319 | if crew_manager is None: 320 | crew_manager = CrewManager(api_key=effective_api_key) 321 | 322 | messages = [{"role": "user", "content": f"Question: {topic}\nDocumentation: {website_url}"}] 323 | yield messages 324 | 325 | try: 326 | async for new_messages in crew_manager.process_support(topic, website_url, crew_type): 327 | for msg in new_messages: 328 | if "metadata" in msg: 329 | messages.append({ 330 | "role": msg["role"], 331 | "content": msg["content"], 332 | "metadata": {"title": msg["metadata"]["title"]} 333 | }) 334 | else: 335 | messages.append({ 336 | "role": msg["role"], 337 | "content": msg["content"] 338 | }) 339 | yield messages 340 | except Exception as e: 341 | messages.append({ 342 | "role": "assistant", 343 | "content": f"An error occurred: {str(e)}", 344 | "metadata": {"title": "❌ Error"} 345 | }) 346 | yield messages 347 | 348 | def show_interface(): 349 | return { 350 | api_key: gr.Textbox(visible=False), 351 | chatbot: gr.Chatbot(visible=True), 352 | topic: gr.Textbox(visible=True), 353 | website_url: gr.Textbox(visible=True), 354 | btn: gr.Button(visible=True) 355 | } 356 | 357 | if not has_api_key: 358 | api_key.submit(show_interface, None, [api_key, chatbot, topic, website_url, btn]) 359 | 360 | btn.click(process_input, [topic, website_url, chatbot, api_key], [chatbot]) 361 | topic.submit(process_input, [topic, website_url, chatbot, api_key], [chatbot]) 362 | 363 | return demo -------------------------------------------------------------------------------- /ai_gradio/providers/mistral_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from mistralai import Mistral 4 | import gradio as gr 5 | from typing import Callable 6 | from urllib.parse import urlparse 7 | import modelscope_studio.components.base as ms 8 | import modelscope_studio.components.legacy as legacy 9 | import modelscope_studio.components.antd as antd 10 | import re 11 | 12 | __version__ = "0.0.1" 13 | 14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications. 
15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript. 16 | Focus on: 17 | - Writing clear, maintainable code 18 | - Following best practices 19 | - Creating responsive designs 20 | - Adding appropriate styling and interactivity 21 | Return only the complete HTML code without any additional explanation.""" 22 | 23 | 24 | def encode_image_file(image_path): 25 | """Encode an image file to base64.""" 26 | try: 27 | with open(image_path, "rb") as image_file: 28 | return base64.b64encode(image_file.read()).decode('utf-8') 29 | except Exception as e: 30 | print(f"Error encoding image: {str(e)}") 31 | return None 32 | 33 | 34 | def process_image(image): 35 | """Process image input to the format expected by Mistral API.""" 36 | if isinstance(image, str): 37 | # Check if it's a URL or base64 string 38 | if image.startswith('data:'): 39 | return image # Already in base64 format 40 | elif urlparse(image).scheme in ('http', 'https'): 41 | return image # It's a URL 42 | else: 43 | # Assume it's a local file path 44 | encoded = encode_image_file(image) 45 | return f"data:image/jpeg;base64,{encoded}" if encoded else None 46 | return None 47 | 48 | 49 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 50 | def fn(message, history): 51 | inputs = preprocess(message, history) 52 | client = Mistral(api_key=api_key) 53 | try: 54 | # Check if the model is Codestral 55 | if model_name.startswith("codestral"): 56 | # Handle Codestral API calls 57 | response = client.chat.complete( 58 | model=model_name, 59 | messages=inputs["messages"] 60 | ) 61 | yield postprocess(response.choices[0].message.content) 62 | else: 63 | # Create the streaming chat completion for other models 64 | stream_response = client.chat.stream( 65 | model=model_name, 66 | messages=inputs["messages"] 67 | ) 68 | 69 | response_text = "" 70 | for chunk in stream_response: 71 | if chunk.data.choices[0].delta.content is not None: 72 | delta = chunk.data.choices[0].delta.content 73 | response_text += delta 74 | yield postprocess(response_text) 75 | 76 | except Exception as e: 77 | print(f"Error during chat completion: {str(e)}") 78 | yield "Sorry, there was an error processing your request." 
79 | 80 | return fn 81 | 82 | 83 | def get_interface_args(pipeline): 84 | if pipeline == "chat": 85 | inputs = None 86 | outputs = None 87 | 88 | def preprocess(message, history): 89 | messages = [] 90 | # Process history 91 | for user_msg, assistant_msg in history: 92 | if isinstance(user_msg, dict): 93 | # Handle multimodal history messages 94 | content = [] 95 | if user_msg.get("text"): 96 | content.append({"type": "text", "text": user_msg["text"]}) 97 | for file in user_msg.get("files", []): 98 | processed_image = process_image(file) 99 | if processed_image: 100 | content.append({"type": "image_url", "image_url": processed_image}) 101 | messages.append({"role": "user", "content": content}) 102 | else: 103 | # Handle text-only history messages 104 | messages.append({"role": "user", "content": user_msg}) 105 | messages.append({"role": "assistant", "content": assistant_msg}) 106 | 107 | # Process current message 108 | if isinstance(message, dict): 109 | # Handle multimodal input 110 | content = [] 111 | if message.get("text"): 112 | content.append({"type": "text", "text": message["text"]}) 113 | for file in message.get("files", []): 114 | processed_image = process_image(file) 115 | if processed_image: 116 | content.append({"type": "image_url", "image_url": processed_image}) 117 | messages.append({"role": "user", "content": content}) 118 | else: 119 | # Handle text-only input 120 | messages.append({"role": "user", "content": message}) 121 | 122 | return {"messages": messages} 123 | 124 | postprocess = lambda x: x # No post-processing needed 125 | else: 126 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 127 | return inputs, outputs, preprocess, postprocess 128 | 129 | 130 | def get_pipeline(model_name): 131 | # Determine the pipeline type based on the model name 132 | # For simplicity, assuming all models are chat models at the moment 133 | return "chat" 134 | 135 | 136 | def generate_code(query, history, setting, api_key): 137 | """Generate code using Mistral API and handle UI updates.""" 138 | client = Mistral(api_key=api_key) 139 | 140 | messages = [] 141 | # Add system prompt 142 | messages.append({"role": "system", "content": setting["system"]}) 143 | 144 | # Add history 145 | for h in history: 146 | messages.append({"role": "user", "content": h[0]}) 147 | messages.append({"role": "assistant", "content": h[1]}) 148 | 149 | # Add current query 150 | messages.append({"role": "user", "content": query}) 151 | 152 | try: 153 | # Create the streaming chat completion 154 | stream_response = client.chat.stream( 155 | model="mistral-large-latest", 156 | messages=messages 157 | ) 158 | 159 | response_text = "" 160 | for chunk in stream_response: 161 | if chunk.data.choices[0].delta.content is not None: 162 | delta = chunk.data.choices[0].delta.content 163 | response_text += delta 164 | # Yield intermediate updates 165 | yield ( 166 | response_text, # code_output (for markdown display) 167 | history, # history state 168 | None, # preview HTML 169 | gr.update(active_key="loading"), # state_tab 170 | gr.update(open=True) # code_drawer 171 | ) 172 | 173 | # Clean the code and prepare final preview 174 | clean_code = remove_code_block(response_text) 175 | new_history = history + [(query, response_text)] 176 | 177 | # Final yield with complete response 178 | yield ( 179 | response_text, # code_output 180 | new_history, # history state 181 | send_to_preview(clean_code), # preview HTML 182 | gr.update(active_key="render"), # state_tab 183 | gr.update(open=False) # code_drawer 184 | ) 185 | 
186 | except Exception as e: 187 | print(f"Error generating code: {str(e)}") 188 | yield ( 189 | f"Error: {str(e)}", 190 | history, 191 | None, 192 | gr.update(active_key="empty"), 193 | gr.update(open=True) 194 | ) 195 | 196 | 197 | def remove_code_block(text): 198 | """Extract code from markdown code blocks.""" 199 | pattern = r'```html\n(.+?)\n```' 200 | match = re.search(pattern, text, re.DOTALL) 201 | if match: 202 | return match.group(1).strip() 203 | return text.strip() 204 | 205 | 206 | def send_to_preview(code): 207 | """Convert code to base64 encoded iframe source.""" 208 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8') 209 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}" 210 | return f'' 211 | 212 | 213 | def registry(name: str, token: str | None = None, coder: bool = False, **kwargs): 214 | """ 215 | Create a Gradio Interface for a model on Mistral AI. 216 | 217 | Parameters: 218 | - name (str): The name of the Mistral AI model. 219 | - token (str, optional): The API key for Mistral AI. 220 | """ 221 | api_key = token or os.environ.get("MISTRAL_API_KEY") 222 | if not api_key: 223 | raise ValueError("MISTRAL_API_KEY environment variable is not set.") 224 | 225 | pipeline = get_pipeline(name) 226 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 227 | fn = get_fn(name, preprocess, postprocess, api_key) 228 | 229 | if pipeline == "chat": 230 | # Always enable multimodal support 231 | interface = gr.ChatInterface( 232 | fn=fn, 233 | multimodal=True, 234 | **kwargs 235 | ) 236 | else: 237 | # For other pipelines, create a standard Interface (not implemented yet) 238 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 239 | 240 | if coder: 241 | interface = gr.Blocks(css=""" 242 | .left_header { 243 | text-align: center; 244 | margin-bottom: 20px; 245 | } 246 | .right_panel { 247 | background: white; 248 | border-radius: 8px; 249 | overflow: hidden; 250 | box-shadow: 0 2px 8px rgba(0,0,0,0.15); 251 | } 252 | .render_header { 253 | background: #f5f5f5; 254 | padding: 8px; 255 | border-bottom: 1px solid #e8e8e8; 256 | } 257 | .header_btn { 258 | display: inline-block; 259 | width: 12px; 260 | height: 12px; 261 | border-radius: 50%; 262 | margin-right: 8px; 263 | background: #ff5f56; 264 | } 265 | .header_btn:nth-child(2) { 266 | background: #ffbd2e; 267 | } 268 | .header_btn:nth-child(3) { 269 | background: #27c93f; 270 | } 271 | .right_content { 272 | padding: 24px; 273 | height: 920px; 274 | display: flex; 275 | align-items: center; 276 | justify-content: center; 277 | } 278 | .html_content { 279 | height: 920px; 280 | width: 100%; 281 | } 282 | .history_chatbot { 283 | height: 100%; 284 | } 285 | """) 286 | 287 | with interface: 288 | history = gr.State([]) 289 | setting = gr.State({"system": SystemPrompt}) 290 | 291 | with ms.Application() as app: 292 | with antd.ConfigProvider(): 293 | with antd.Row(gutter=[32, 12]) as layout: 294 | # Left Column 295 | with antd.Col(span=24, md=8): 296 | with antd.Flex(vertical=True, gap="middle", wrap=True): 297 | header = gr.HTML(""" 298 |
299 | <h1>Codestral Code Generator</h1>
300 | </div>
301 | """) 302 | 303 | input = antd.InputTextarea( 304 | size="large", 305 | allow_clear=True, 306 | placeholder="Describe the code you want to generate" 307 | ) 308 | btn = antd.Button("Generate", type="primary", size="large") 309 | clear_btn = antd.Button("Clear History", type="default", size="large") 310 | 311 | antd.Divider("Settings") 312 | with antd.Flex(gap="small", wrap=True): 313 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default") 314 | codeBtn = antd.Button("🧑‍💻 View Code", type="default") 315 | historyBtn = antd.Button("📜 History", type="default") 316 | 317 | # Modals and Drawers 318 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal: 319 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True) 320 | 321 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer: 322 | code_output = legacy.Markdown() 323 | 324 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer: 325 | history_output = legacy.Chatbot( 326 | show_label=False, 327 | height=960, 328 | elem_classes="history_chatbot" 329 | ) 330 | 331 | # Right Column 332 | with antd.Col(span=24, md=16): 333 | with ms.Div(elem_classes="right_panel"): 334 | gr.HTML(''' 335 |
<div class="render_header">
336 | <span class="header_btn"></span>
337 | <span class="header_btn"></span>
338 | <span class="header_btn"></span>
339 | </div>
340 | ''') 341 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab: 342 | with antd.Tabs.Item(key="empty"): 343 | empty = antd.Empty( 344 | description="Enter your request to generate code", 345 | elem_classes="right_content" 346 | ) 347 | with antd.Tabs.Item(key="loading"): 348 | loading = antd.Spin( 349 | True, 350 | tip="Generating code...", 351 | size="large", 352 | elem_classes="right_content" 353 | ) 354 | with antd.Tabs.Item(key="render"): 355 | preview = gr.HTML(elem_classes="html_content") 356 | 357 | # Wire up event handlers 358 | btn.click( 359 | generate_code, 360 | inputs=[input, history, setting, gr.State(api_key)], 361 | outputs=[code_output, history, preview, state_tab, code_drawer], 362 | api_name=False 363 | ) 364 | 365 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal]) 366 | system_prompt_modal.ok( 367 | lambda input: ({"system": input}, gr.update(open=False)), 368 | inputs=[systemPromptInput], 369 | outputs=[setting, system_prompt_modal] 370 | ) 371 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal]) 372 | 373 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer]) 374 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer]) 375 | 376 | historyBtn.click( 377 | lambda h: (gr.update(open=True), h), 378 | inputs=[history], 379 | outputs=[history_drawer, history_output] 380 | ) 381 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer]) 382 | 383 | clear_btn.click(lambda: [], outputs=[history]) 384 | 385 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/qwen_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from openai import OpenAI 3 | import gradio as gr 4 | from typing import Callable 5 | import base64 6 | import re 7 | import modelscope_studio.components.base as ms 8 | import modelscope_studio.components.legacy as legacy 9 | import modelscope_studio.components.antd as antd 10 | 11 | __version__ = "0.0.3" 12 | 13 | # Add these constants at the top of the file 14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications. 15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript. 
16 | Focus on: 17 | - Writing clear, maintainable code 18 | - Following best practices 19 | - Creating responsive designs 20 | - Adding appropriate styling and interactivity 21 | Return only the complete HTML code without any additional explanation.""" 22 | 23 | DEMO_LIST = [ 24 | { 25 | "card": {"index": 0}, 26 | "title": "Simple Button", 27 | "description": "Create a button that changes color when clicked" 28 | }, 29 | { 30 | "card": {"index": 1}, 31 | "title": "Todo List", 32 | "description": "Create a simple todo list with add/remove functionality" 33 | }, 34 | { 35 | "card": {"index": 2}, 36 | "title": "Timer App", 37 | "description": "Create a countdown timer with start/pause/reset controls" 38 | } 39 | ] 40 | 41 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str | None = None, base_url: str | None = None, local: bool = False): 42 | if local: 43 | from vllm import LLM, SamplingParams 44 | from transformers import AutoTokenizer 45 | 46 | # Initialize tokenizer and model 47 | tokenizer = AutoTokenizer.from_pretrained(model_name) 48 | llm = LLM( 49 | model=model_name, 50 | tensor_parallel_size=1, # Adjust based on available GPUs 51 | max_model_len=1010000, 52 | enable_chunked_prefill=True, 53 | max_num_batched_tokens=131072, 54 | enforce_eager=True, 55 | ) 56 | 57 | # Default sampling parameters 58 | sampling_params = SamplingParams( 59 | temperature=0.7, 60 | top_p=0.8, 61 | repetition_penalty=1.05, 62 | max_tokens=512 63 | ) 64 | 65 | def fn(message, history): 66 | inputs = preprocess(message, history) 67 | messages = inputs["messages"] 68 | 69 | # Convert messages to model input format 70 | text = tokenizer.apply_chat_template( 71 | messages, 72 | tokenize=False, 73 | add_generation_prompt=True 74 | ) 75 | 76 | # Generate response 77 | outputs = llm.generate([text], sampling_params) 78 | response_text = outputs[0].outputs[0].text 79 | yield postprocess(response_text) 80 | 81 | return fn 82 | 83 | # Original cloud API implementation 84 | def fn(message, history): 85 | inputs = preprocess(message, history) 86 | 87 | client = OpenAI( 88 | api_key=api_key, 89 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" 90 | ) 91 | completion = client.chat.completions.create( 92 | model=model_name, 93 | messages=inputs["messages"], 94 | stream=True, 95 | ) 96 | response_text = "" 97 | for chunk in completion: 98 | delta = chunk.choices[0].delta.content or "" 99 | response_text += delta 100 | yield postprocess(response_text) 101 | 102 | return fn 103 | 104 | 105 | def get_interface_args(pipeline, model_name: str): 106 | if pipeline == "chat": 107 | inputs = None 108 | outputs = None 109 | 110 | def preprocess(message, history): 111 | messages = [] 112 | # Add system prompt for qwq-32b-preview 113 | if model_name == "qwq-32b-preview": 114 | messages.append({ 115 | "role": "system", 116 | "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." 
117 | }) 118 | 119 | for user_msg, assistant_msg in history: 120 | messages.append({"role": "user", "content": user_msg}) 121 | messages.append({"role": "assistant", "content": assistant_msg}) 122 | 123 | # Handle multimodal input 124 | if isinstance(message, dict): 125 | content = [] 126 | if message.get("files"): 127 | # Convert local file path to data URL 128 | with open(message["files"][0], "rb") as image_file: 129 | encoded_image = base64.b64encode(image_file.read()).decode('utf-8') 130 | content.append({ 131 | "type": "image_url", 132 | "image_url": f"data:image/jpeg;base64,{encoded_image}" 133 | }) 134 | content.append({ 135 | "type": "text", 136 | "text": message["text"] 137 | }) 138 | messages.append({"role": "user", "content": content}) 139 | else: 140 | messages.append({"role": "user", "content": [{"type": "text", "text": message}]}) 141 | 142 | return {"messages": messages} 143 | 144 | postprocess = lambda x: x # No post-processing needed 145 | else: 146 | # Add other pipeline types when they will be needed 147 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 148 | return inputs, outputs, preprocess, postprocess 149 | 150 | 151 | def get_pipeline(model_name): 152 | # Determine the pipeline type based on the model name 153 | # For simplicity, assuming all models are chat models at the moment 154 | return "chat" 155 | 156 | 157 | def registry( 158 | name: str, 159 | token: str | None = None, 160 | examples: list | None = None, 161 | coder: bool = False, 162 | local: bool = False, # Add local parameter 163 | **kwargs 164 | ): 165 | api_key = None if local else (token or os.environ.get("DASHSCOPE_API_KEY")) 166 | if not local and not api_key: 167 | raise ValueError("API key not found in environment variables.") 168 | 169 | if coder: 170 | interface = gr.Blocks(css=""" 171 | .left_header { 172 | text-align: center; 173 | margin-bottom: 20px; 174 | } 175 | 176 | .right_panel { 177 | background: white; 178 | border-radius: 8px; 179 | overflow: hidden; 180 | box-shadow: 0 2px 8px rgba(0,0,0,0.15); 181 | } 182 | 183 | .render_header { 184 | background: #f5f5f5; 185 | padding: 8px; 186 | border-bottom: 1px solid #e8e8e8; 187 | } 188 | 189 | .header_btn { 190 | display: inline-block; 191 | width: 12px; 192 | height: 12px; 193 | border-radius: 50%; 194 | margin-right: 8px; 195 | background: #ff5f56; 196 | } 197 | 198 | .header_btn:nth-child(2) { 199 | background: #ffbd2e; 200 | } 201 | 202 | .header_btn:nth-child(3) { 203 | background: #27c93f; 204 | } 205 | 206 | .right_content { 207 | padding: 24px; 208 | height: 920px; 209 | display: flex; 210 | align-items: center; 211 | justify-content: center; 212 | } 213 | 214 | .html_content { 215 | height: 920px; 216 | width: 100%; 217 | } 218 | 219 | .history_chatbot { 220 | height: 100%; 221 | } 222 | """) 223 | with interface: 224 | history = gr.State([]) 225 | setting = gr.State({"system": SystemPrompt}) 226 | 227 | with ms.Application() as app: 228 | with antd.ConfigProvider(): 229 | with antd.Row(gutter=[32, 12]) as layout: 230 | # Left Column 231 | with antd.Col(span=24, md=8): 232 | with antd.Flex(vertical=True, gap="middle", wrap=True): 233 | header = gr.HTML(""" 234 |
<div class="left_header">
235 | <h1>Qwen Code Generator</h1>
236 | </div>
237 | """) 238 | 239 | input = antd.InputTextarea( 240 | size="large", 241 | allow_clear=True, 242 | placeholder="Describe the web application you want to create" 243 | ) 244 | btn = antd.Button("Generate", type="primary", size="large") 245 | clear_btn = antd.Button("Clear History", type="default", size="large") 246 | 247 | antd.Divider("Examples") 248 | with antd.Flex(gap="small", wrap=True): 249 | with ms.Each(DEMO_LIST): 250 | with antd.Card(hoverable=True, as_item="card") as demoCard: 251 | antd.CardMeta() 252 | 253 | antd.Divider("Settings") 254 | with antd.Flex(gap="small", wrap=True): 255 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default") 256 | codeBtn = antd.Button("🧑‍💻 View Code", type="default") 257 | historyBtn = antd.Button("📜 History", type="default") 258 | 259 | # Modals and Drawers 260 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal: 261 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True) 262 | 263 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer: 264 | code_output = legacy.Markdown() 265 | 266 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer: 267 | history_output = legacy.Chatbot( 268 | show_label=False, 269 | height=960, 270 | elem_classes="history_chatbot" 271 | ) 272 | 273 | # Right Column 274 | with antd.Col(span=24, md=16): 275 | with ms.Div(elem_classes="right_panel"): 276 | gr.HTML(''' 277 |
<div class="render_header">
278 | <span class="header_btn"></span>
279 | <span class="header_btn"></span>
280 | <span class="header_btn"></span>
281 | </div>
282 | ''')
283 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
284 | with antd.Tabs.Item(key="empty"):
285 | empty = antd.Empty(
286 | description="Enter your request to generate code",
287 | elem_classes="right_content"
288 | )
289 | with antd.Tabs.Item(key="loading"):
290 | loading = antd.Spin(
291 | True,
292 | tip="Generating code...",
293 | size="large",
294 | elem_classes="right_content"
295 | )
296 | with antd.Tabs.Item(key="render"):
297 | preview = gr.HTML(elem_classes="html_content")
298 |
299 | # Event Handlers
300 | def demo_card_click(e: gr.EventData):
301 | index = e._data['component']['index']
302 | return DEMO_LIST[index]['description']
303 |
304 | def send_to_preview(code):
305 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8')
306 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
307 | return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
308 |
309 | def remove_code_block(text):
310 | pattern = r'```html\n(.+?)\n```'
311 | match = re.search(pattern, text, re.DOTALL)
312 | if match:
313 | return match.group(1).strip()
314 | return text.strip()
315 |
316 | def generate_code(query, setting, history):
317 | client = OpenAI(
318 | api_key=api_key,
319 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
320 | )
321 |
322 | messages = [
323 | {"role": "system", "content": setting["system"]},
324 | {"role": "assistant", "content": "I understand. I will help you write clean, efficient web code."}
325 | ]
326 |
327 | # Add history
328 | for h in history:
329 | messages.append({"role": "user", "content": h[0]})
330 | messages.append({"role": "assistant", "content": h[1]})
331 |
332 | messages.append({"role": "user", "content": query})
333 |
334 | completion = client.chat.completions.create(
335 | model=name,
336 | messages=messages,
337 | stream=True,
338 | )
339 |
340 | response_text = ""
341 | for chunk in completion:
342 | if chunk.choices[0].delta.content:
343 | response_text += chunk.choices[0].delta.content
344 | yield (
345 | response_text,
346 | history,
347 | None,
348 | gr.update(active_key="loading"),
349 | gr.update(open=True)
350 | )
351 |
352 | clean_code = remove_code_block(response_text)
353 | new_history = history + [(query, response_text)]
354 |
355 | yield (
356 | response_text,
357 | new_history,
358 | send_to_preview(clean_code),
359 | gr.update(active_key="render"),
360 | gr.update(open=False)
361 | )
362 |
363 | # Wire up event handlers
364 | demoCard.click(demo_card_click, outputs=[input])
365 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
366 | system_prompt_modal.ok(
367 | lambda input: ({"system": input}, gr.update(open=False)),
368 | inputs=[systemPromptInput],
369 | outputs=[setting, system_prompt_modal]
370 | )
371 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
372 |
373 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
374 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
375 |
376 | historyBtn.click(
377 | lambda h: (gr.update(open=True), h),
378 | inputs=[history],
379 | outputs=[history_drawer, history_output]
380 | )
381 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
382 |
383 | btn.click(
384 | generate_code,
385 | inputs=[input, setting, history],
386 | outputs=[code_output, history, preview, state_tab, code_drawer]
387 | )
388 |
389 | clear_btn.click(lambda: [], outputs=[history])
390 |
391 | return interface
392 |
393 | # Continue with existing chat
interface code... 394 | pipeline = get_pipeline(name) 395 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline, name) 396 | fn = get_fn(name, preprocess, postprocess, api_key, local=local) 397 | 398 | if examples: 399 | formatted_examples = [[example, False] for example in examples] 400 | kwargs["examples"] = formatted_examples 401 | 402 | if pipeline == "chat": 403 | interface = gr.ChatInterface( 404 | fn=fn, 405 | additional_inputs=inputs, 406 | multimodal=True, 407 | **kwargs 408 | ) 409 | else: 410 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 411 | 412 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/deepseek_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from openai import OpenAI 4 | import gradio as gr 5 | from typing import Callable 6 | import modelscope_studio.components.base as ms 7 | import modelscope_studio.components.legacy as legacy 8 | import modelscope_studio.components.antd as antd 9 | 10 | __version__ = "0.0.1" 11 | 12 | # Add these constants at the top of the file 13 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications. 14 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript. 15 | Focus on: 16 | - Writing clear, maintainable code 17 | - Following best practices 18 | - Creating responsive designs 19 | - Adding appropriate styling and interactivity 20 | Return only the complete HTML code without any additional explanation.""" 21 | 22 | DEMO_LIST = [ 23 | { 24 | "card": {"index": 0}, 25 | "title": "Simple Button", 26 | "description": "Create a button that changes color when clicked" 27 | }, 28 | { 29 | "card": {"index": 1}, 30 | "title": "Todo List", 31 | "description": "Create a simple todo list with add/remove functionality" 32 | }, 33 | { 34 | "card": {"index": 2}, 35 | "title": "Timer App", 36 | "description": "Create a countdown timer with start/pause/reset controls" 37 | } 38 | ] 39 | 40 | def get_image_base64(url: str, ext: str): 41 | with open(url, "rb") as image_file: 42 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 43 | return "data:image/" + ext + ";base64," + encoded_string 44 | 45 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 46 | def fn(message, history): 47 | inputs = preprocess(message, history) 48 | client = OpenAI( 49 | base_url="https://api.deepseek.com", 50 | api_key=api_key, 51 | ) 52 | try: 53 | completion = client.chat.completions.create( 54 | model=model_name, 55 | messages=inputs["messages"], 56 | stream=True, 57 | ) 58 | response_text = "" 59 | reasoning_text = "" 60 | for chunk in completion: 61 | if chunk.choices[0].delta.reasoning_content: 62 | reasoning_text += chunk.choices[0].delta.reasoning_content 63 | else: 64 | delta = chunk.choices[0].delta.content or "" 65 | response_text += delta 66 | yield postprocess(response_text, reasoning_text) 67 | except Exception as e: 68 | error_message = f"Error: {str(e)}" 69 | return error_message 70 | 71 | return fn 72 | 73 | def handle_user_msg(message: str): 74 | if type(message) is str: 75 | return message 76 | elif type(message) is dict: 77 | if message["files"] is not None and len(message["files"]) > 0: 78 | ext = os.path.splitext(message["files"][-1])[1].strip(".") 79 | if ext.lower() in ["png", "jpg", "jpeg", 
"gif", "pdf"]: 80 | encoded_str = get_image_base64(message["files"][-1], ext) 81 | else: 82 | raise NotImplementedError(f"Not supported file type {ext}") 83 | content = [ 84 | {"type": "text", "text": message["text"]}, 85 | { 86 | "type": "image_url", 87 | "image_url": { 88 | "url": encoded_str, 89 | } 90 | }, 91 | ] 92 | else: 93 | content = message["text"] 94 | return content 95 | else: 96 | raise NotImplementedError 97 | 98 | def get_interface_args(pipeline): 99 | if pipeline == "chat": 100 | inputs = None 101 | outputs = [ 102 | gr.Textbox(label="Chain of Thought", lines=10, visible=True), 103 | gr.Textbox(label="Response") 104 | ] 105 | 106 | def preprocess(message, history): 107 | messages = [] 108 | files = None 109 | for user_msg, assistant_msg in history: 110 | if assistant_msg is not None: 111 | if isinstance(assistant_msg, dict): 112 | assistant_msg = assistant_msg["visible"][1] # Get the response text 113 | messages.append({"role": "user", "content": handle_user_msg(user_msg)}) 114 | messages.append({"role": "assistant", "content": assistant_msg}) 115 | else: 116 | files = user_msg 117 | if type(message) is str and files is not None: 118 | message = {"text": message, "files": files} 119 | elif type(message) is dict and files is not None: 120 | if message["files"] is None or len(message["files"]) == 0: 121 | message["files"] = files 122 | messages.append({"role": "user", "content": handle_user_msg(message)}) 123 | return {"messages": messages} 124 | 125 | def postprocess(response, reasoning=None): 126 | # Update both textboxes but return only the response for the chat history 127 | return response 128 | else: 129 | # Add other pipeline types when they will be needed 130 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 131 | return inputs, outputs, preprocess, postprocess 132 | 133 | 134 | def get_pipeline(model_name): 135 | if model_name == "deepseek-reasoner": 136 | return "chat" 137 | # For other models, assume chat pipeline 138 | return "chat" 139 | 140 | 141 | def registry( 142 | name: str, 143 | token: str | None = None, 144 | examples: list | None = None, 145 | coder: bool = False, 146 | **kwargs 147 | ): 148 | api_key = token or os.environ.get("DEEPSEEK_API_KEY") 149 | if not api_key: 150 | raise ValueError("DEEPSEEK_API_KEY environment variable is not set.") 151 | 152 | if coder: 153 | interface = gr.Blocks(css=""" 154 | .left_header { 155 | text-align: center; 156 | margin-bottom: 20px; 157 | } 158 | 159 | .right_panel { 160 | background: white; 161 | border-radius: 8px; 162 | overflow: hidden; 163 | box-shadow: 0 2px 8px rgba(0,0,0,0.15); 164 | } 165 | 166 | .render_header { 167 | background: #f5f5f5; 168 | padding: 8px; 169 | border-bottom: 1px solid #e8e8e8; 170 | } 171 | 172 | .header_btn { 173 | display: inline-block; 174 | width: 12px; 175 | height: 12px; 176 | border-radius: 50%; 177 | margin-right: 8px; 178 | background: #ff5f56; 179 | } 180 | 181 | .header_btn:nth-child(2) { 182 | background: #ffbd2e; 183 | } 184 | 185 | .header_btn:nth-child(3) { 186 | background: #27c93f; 187 | } 188 | 189 | .right_content { 190 | padding: 24px; 191 | height: 920px; 192 | display: flex; 193 | align-items: center; 194 | justify-content: center; 195 | } 196 | 197 | .html_content { 198 | height: 920px; 199 | width: 100%; 200 | } 201 | 202 | .history_chatbot { 203 | height: 100%; 204 | } 205 | """) 206 | with interface: 207 | history = gr.State([]) 208 | setting = gr.State({"system": SystemPrompt}) 209 | 210 | with ms.Application() as app: 211 | with 
antd.ConfigProvider(): 212 | with antd.Row(gutter=[32, 12]) as layout: 213 | # Left Column 214 | with antd.Col(span=24, md=8): 215 | with antd.Flex(vertical=True, gap="middle", wrap=True): 216 | header = gr.HTML(""" 217 |
<div class="left_header">
218 | <h1>DeepSeek R1 Code Generator</h1>
219 | </div>
220 | """) 221 | 222 | input = antd.InputTextarea( 223 | size="large", 224 | allow_clear=True, 225 | placeholder="Describe the web application you want to create" 226 | ) 227 | btn = antd.Button("Generate", type="primary", size="large") 228 | clear_btn = antd.Button("Clear History", type="default", size="large") 229 | 230 | antd.Divider("Examples") 231 | with antd.Flex(gap="small", wrap=True): 232 | with ms.Each(DEMO_LIST): 233 | with antd.Card(hoverable=True, as_item="card") as demoCard: 234 | antd.CardMeta() 235 | 236 | antd.Divider("Settings") 237 | with antd.Flex(gap="small", wrap=True): 238 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default") 239 | codeBtn = antd.Button("🧑‍💻 View Code", type="default") 240 | historyBtn = antd.Button("📜 History", type="default") 241 | 242 | # Modals and Drawers 243 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal: 244 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True) 245 | 246 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer: 247 | code_output = legacy.Markdown() 248 | 249 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer: 250 | history_output = legacy.Chatbot( 251 | show_label=False, 252 | height=960, 253 | elem_classes="history_chatbot" 254 | ) 255 | 256 | # Right Column 257 | with antd.Col(span=24, md=16): 258 | with ms.Div(elem_classes="right_panel"): 259 | gr.HTML(''' 260 |
<div class="render_header">
261 | <span class="header_btn"></span>
262 | <span class="header_btn"></span>
263 | <span class="header_btn"></span>
264 | </div>
265 | ''')
266 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
267 | with antd.Tabs.Item(key="empty"):
268 | empty = antd.Empty(
269 | description="Enter your request to generate code",
270 | elem_classes="right_content"
271 | )
272 | with antd.Tabs.Item(key="loading"):
273 | loading = antd.Spin(
274 | True,
275 | tip="Generating code...",
276 | size="large",
277 | elem_classes="right_content"
278 | )
279 | with antd.Tabs.Item(key="render"):
280 | preview = gr.HTML(elem_classes="html_content")
281 |
282 | # Event Handlers
283 | def demo_card_click(e: gr.EventData):
284 | index = e._data['component']['index']
285 | return DEMO_LIST[index]['description']
286 |
287 | def send_to_preview(code):
288 | # Clean the code and escape special characters for HTML
289 | clean_code = code.replace("```html", "").replace("```", "").strip()
290 | escaped_code = clean_code.replace('"', '&quot;').replace('<', '&lt;').replace('>', '&gt;')
291 | return f'''
292 | <iframe
293 | srcdoc="{escaped_code}"
294 | width="100%"
295 | height="920px"
296 | style="border:none;"
297 | ></iframe>
298 | '''
299 |
300 | def generate_code(query, setting, history):
301 | messages = [{"role": "system", "content": setting["system"]}]
302 |
303 | # Add history in alternating user/assistant pattern
304 | for user_msg, assistant_msg in history:
305 | messages.append({"role": "user", "content": user_msg})
306 | messages.append({"role": "assistant", "content": assistant_msg})
307 |
308 | # Add current query
309 | messages.append({"role": "user", "content": query})
310 |
311 | try:
312 | client = OpenAI(
313 | base_url="https://api.deepseek.com",
314 | api_key=api_key,
315 | )
316 |
317 | response = client.chat.completions.create(
318 | model=name,
319 | messages=messages,
320 | stream=True,
321 | )
322 |
323 | code = ""
324 | for chunk in response:
325 | if chunk.choices[0].delta.content:
326 | code += chunk.choices[0].delta.content
327 | # Return all 5 required outputs
328 | yield (
329 | f"```html\n{code}\n```", # code_output (modelscopemarkdown)
330 | history, # state
331 | None, # preview (html)
332 | gr.update(active_key="loading"), # state_tab (antdtabs)
333 | gr.update(open=True) # code_drawer (antddrawer)
334 | )
335 |
336 | new_history = history + [(query, code)]
337 |
338 | # Final yield with all outputs
339 | yield (
340 | f"```html\n{code}\n```", # code_output (modelscopemarkdown)
341 | new_history, # state
342 | send_to_preview(code), # preview (html)
343 | gr.update(active_key="render"), # state_tab (antdtabs)
344 | gr.update(open=False) # code_drawer (antddrawer)
345 | )
346 |
347 | except Exception as e:
348 | error_msg = f"Error: {str(e)}"
349 | yield (
350 | error_msg,
351 | history,
352 | None,
353 | gr.update(active_key="empty"),
354 | gr.update(open=True)
355 | )
356 |
357 | # Wire up event handlers
358 | demoCard.click(demo_card_click, outputs=[input])
359 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
360 | system_prompt_modal.ok(
361 | lambda input: ({"system": input}, gr.update(open=False)),
362 | inputs=[systemPromptInput],
363 | outputs=[setting, system_prompt_modal]
364 | )
365 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
366 |
367 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
368 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
369 |
370 | historyBtn.click(
371 | lambda h: (gr.update(open=True), h),
372 | inputs=[history],
373 | outputs=[history_drawer, history_output]
374 | )
375 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
376 |
377 |
btn.click( 378 | generate_code, 379 | inputs=[input, setting, history], 380 | outputs=[code_output, history, preview, state_tab, code_drawer] 381 | ) 382 | 383 | clear_btn.click(lambda: [], outputs=[history]) 384 | 385 | return interface 386 | 387 | # Continue with existing chat interface code... 388 | pipeline = get_pipeline(name) 389 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline) 390 | fn = get_fn(name, preprocess, postprocess, api_key) 391 | 392 | if examples: 393 | kwargs["examples"] = examples 394 | 395 | if pipeline == "chat": 396 | interface = gr.ChatInterface(fn=fn, **kwargs) 397 | else: 398 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 399 | 400 | return interface -------------------------------------------------------------------------------- /ai_gradio/providers/replicate_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import replicate 3 | import asyncio 4 | import os 5 | from typing import Callable, Dict, Any, List, Tuple 6 | import httpx 7 | from PIL import Image 8 | import io 9 | import base64 10 | import numpy as np 11 | import tempfile 12 | import time 13 | 14 | def resize_image_if_needed(image, max_size=1024): 15 | """Resize image if either dimension exceeds max_size while maintaining aspect ratio""" 16 | if isinstance(image, str) and image.startswith('data:image'): 17 | return image # Already a data URI, skip processing 18 | 19 | if isinstance(image, np.ndarray): 20 | image = Image.fromarray(image) 21 | elif not isinstance(image, Image.Image): 22 | image = Image.open(image) 23 | 24 | # Get original dimensions 25 | width, height = image.size 26 | 27 | # Calculate new dimensions if needed 28 | if width > max_size or height > max_size: 29 | if width > height: 30 | new_width = max_size 31 | new_height = int(height * (max_size / width)) 32 | else: 33 | new_height = max_size 34 | new_width = int(width * (max_size / height)) 35 | 36 | image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 37 | 38 | # Convert to RGB if necessary 39 | if image.mode != 'RGB': 40 | image = image.convert('RGB') 41 | 42 | # Convert to base64 43 | buffered = io.BytesIO() 44 | image.save(buffered, format="JPEG", quality=85) 45 | img_str = base64.b64encode(buffered.getvalue()).decode() 46 | return f"data:image/jpeg;base64,{img_str}" 47 | 48 | def bytes_to_image(byte_data): 49 | """Convert bytes to PIL Image and ensure we get fresh data""" 50 | if isinstance(byte_data, bytes): 51 | return Image.open(io.BytesIO(byte_data)) 52 | # For file-like objects 53 | if hasattr(byte_data, 'seek'): 54 | byte_data.seek(0) 55 | return Image.open(io.BytesIO(byte_data.read())) 56 | 57 | def save_bytes_to_video(video_bytes): 58 | """Save video bytes to a temporary file and return the path""" 59 | if not isinstance(video_bytes, bytes): 60 | raise ValueError(f"Expected bytes input, got {type(video_bytes)}") 61 | 62 | # Create a temporary file with .mp4 extension 63 | temp_dir = tempfile.gettempdir() 64 | temp_path = os.path.join(temp_dir, f"temp_{int(time.time())}_{os.urandom(4).hex()}.mp4") 65 | 66 | try: 67 | # Write the bytes to the temporary file 68 | with open(temp_path, "wb") as f: 69 | f.write(video_bytes) 70 | 71 | # Ensure the file exists and has content 72 | if not os.path.exists(temp_path) or os.path.getsize(temp_path) == 0: 73 | raise ValueError("Failed to save video file or file is empty") 74 | 75 | return str(temp_path) # Return string path as expected by Gradio 76 | 
finally: 77 | # Clean up the temporary file 78 | if os.path.exists(temp_path): 79 | os.remove(temp_path) 80 | 81 | PIPELINE_REGISTRY = { 82 | "text-to-image": { 83 | "inputs": [ 84 | ("prompt", gr.Textbox, {"label": "Prompt"}), 85 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}), 86 | ("width", gr.Number, {"label": "Width", "value": 1024, "minimum": 512, "maximum": 2048, "step": 64, "optional": True}), 87 | ("height", gr.Number, {"label": "Height", "value": 1024, "minimum": 512, "maximum": 2048, "step": 64, "optional": True}), 88 | ("num_outputs", gr.Number, {"label": "Number of Images", "value": 1, "minimum": 1, "maximum": 4, "step": 1, "optional": True}), 89 | ("scheduler", gr.Dropdown, {"label": "Scheduler", "choices": ["DPM++ 2M", "DPM++ 2M Karras", "DPM++ 2M SDE", "DPM++ 2M SDE Karras"], "optional": True}), 90 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}), 91 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}), 92 | ("seed", gr.Number, {"label": "Seed", "optional": True}) 93 | ], 94 | "outputs": [("images", gr.Gallery, {})], 95 | "preprocess": lambda *args: { 96 | k: v for k, v in zip([ 97 | "prompt", "negative_prompt", "width", "height", "num_outputs", 98 | "scheduler", "num_inference_steps", "guidance_scale", "seed" 99 | ], args) if v is not None and v != "" 100 | }, 101 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)] 102 | }, 103 | 104 | "image-to-image": { 105 | "inputs": [ 106 | ("prompt", gr.Textbox, {"label": "Prompt"}), 107 | ("image", gr.Image, {"label": "Input Image", "type": "pil"}), 108 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}), 109 | ("strength", gr.Slider, {"label": "Strength", "minimum": 0, "maximum": 1, "value": 0.7, "optional": True}), 110 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}), 111 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}), 112 | ("seed", gr.Number, {"label": "Seed", "optional": True}) 113 | ], 114 | "outputs": [("images", gr.Gallery, {})], 115 | "preprocess": lambda *args: { 116 | k: (resize_image_if_needed(v) if k == "image" else v) 117 | for k, v in zip([ 118 | "prompt", "image", "negative_prompt", "strength", 119 | "num_inference_steps", "guidance_scale", "seed" 120 | ], args) if v is not None and v != "" 121 | }, 122 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)] 123 | }, 124 | 125 | "control-net": { 126 | "inputs": [ 127 | ("prompt", gr.Textbox, {"label": "Prompt"}), 128 | ("control_image", gr.Image, {"label": "Control Image", "type": "pil"}), 129 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}), 130 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}), 131 | ("control_guidance_scale", gr.Slider, {"label": "Control Guidance Scale", "minimum": 1, "maximum": 20, "value": 1.5, "optional": True}), 132 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}), 133 | ("seed", gr.Number, {"label": "Seed", "optional": True}) 134 | ], 135 | "outputs": [("images", gr.Gallery, {})], 136 | "preprocess": lambda *args: { 137 | 
k: (resize_image_if_needed(v) if k == "control_image" else v) 138 | for k, v in zip([ 139 | "prompt", "control_image", "negative_prompt", "guidance_scale", 140 | "control_guidance_scale", "num_inference_steps", "seed" 141 | ], args) if v is not None and v != "" 142 | }, 143 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)] 144 | }, 145 | 146 | "inpainting": { 147 | "inputs": [ 148 | ("prompt", gr.Textbox, {"label": "Prompt"}), 149 | ("image", gr.Image, {"label": "Original Image", "type": "pil"}), 150 | ("mask", gr.Image, {"label": "Mask Image", "type": "pil"}), 151 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}), 152 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}), 153 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}), 154 | ("seed", gr.Number, {"label": "Seed", "optional": True}) 155 | ], 156 | "outputs": [("images", gr.Gallery, {})], 157 | "preprocess": lambda *args: { 158 | k: (resize_image_if_needed(v) if k in ["image", "mask"] else v) 159 | for k, v in zip([ 160 | "prompt", "image", "mask", "negative_prompt", 161 | "num_inference_steps", "guidance_scale", "seed" 162 | ], args) if v is not None and v != "" 163 | }, 164 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)] 165 | }, 166 | 167 | "text-to-video": { 168 | "inputs": [ 169 | ("prompt", gr.Textbox, { 170 | "label": "Prompt", 171 | "value": "A cat walks on the grass, realistic style.", 172 | "info": "Text prompt to generate video." 173 | }), 174 | ("height", gr.Number, { 175 | "label": "Height", 176 | "value": 480, 177 | "minimum": 1, 178 | "info": "Height of the video in pixels." 179 | }), 180 | ("width", gr.Number, { 181 | "label": "Width", 182 | "value": 854, 183 | "minimum": 1, 184 | "info": "Width of the video in pixels." 185 | }), 186 | ("video_length", gr.Number, { 187 | "label": "Video Length", 188 | "value": 129, 189 | "minimum": 1, 190 | "info": "Length of the video in frames." 191 | }), 192 | ("infer_steps", gr.Number, { 193 | "label": "Infer Steps", 194 | "value": 30, 195 | "minimum": 1, 196 | "maximum": 50, 197 | "info": "Number of inference steps." 198 | }), 199 | ("flow_shift", gr.Number, { 200 | "label": "Flow Shift", 201 | "value": 7, 202 | "info": "Flow-shift parameter." 203 | }), 204 | ("embedded_guidance_scale", gr.Slider, { 205 | "label": "Embedded Guidance Scale", 206 | "value": 6, 207 | "minimum": 1, 208 | "maximum": 6, 209 | "info": "Embedded guidance scale for generation." 210 | }), 211 | ("seed", gr.Number, { 212 | "label": "Seed", 213 | "optional": True, 214 | "info": "Random seed for reproducibility." 
215 | }) 216 | ], 217 | "outputs": [ 218 | ("video", gr.Video, { 219 | "format": "mp4", 220 | "autoplay": True, 221 | "show_label": True, 222 | "label": "Generated Video", 223 | "height": 480, 224 | "width": 854, 225 | "interactive": False, 226 | "show_download_button": True 227 | }) 228 | ], 229 | "preprocess": lambda *args: { 230 | k: (int(v) if k in ["height", "width", "video_length", "infer_steps", "seed"] else 231 | float(v) if k in ["flow_shift", "embedded_guidance_scale"] else v) 232 | for k, v in zip([ 233 | "prompt", "height", "width", "video_length", 234 | "infer_steps", "flow_shift", "embedded_guidance_scale", "seed" 235 | ], args) if v is not None and v != "" 236 | }, 237 | "postprocess": lambda x: ( 238 | x.url if hasattr(x, 'url') 239 | else (lambda p: os.remove(p) or p)(x) # Delete file after getting path 240 | ), 241 | }, 242 | 243 | "text-generation": { 244 | "inputs": [ 245 | ("message", gr.Textbox, { 246 | "label": "Message", 247 | "lines": 3, 248 | "placeholder": "Enter your message here..." 249 | }), 250 | ], 251 | "outputs": [ 252 | ("response", gr.Textbox, { 253 | "label": "Assistant", 254 | "lines": 10, 255 | "show_copy_button": True 256 | }) 257 | ], 258 | "preprocess": lambda *args: { 259 | "prompt": args[0] if args[0] is not None and args[0] != "" else "" 260 | }, 261 | "postprocess": lambda x: x, 262 | "is_chat": True # New flag to indicate this is a chat interface 263 | }, 264 | } 265 | 266 | MODEL_TO_PIPELINE = { 267 | "stability-ai/sdxl": "text-to-image", 268 | "black-forest-labs/flux-pro": "text-to-image", 269 | "stability-ai/stable-diffusion": "text-to-image", 270 | 271 | "black-forest-labs/flux-depth-pro": "control-net", 272 | "black-forest-labs/flux-canny-pro": "control-net", 273 | "black-forest-labs/flux-depth-dev": "control-net", 274 | 275 | "black-forest-labs/flux-fill-pro": "inpainting", 276 | "stability-ai/stable-diffusion-inpainting": "inpainting", 277 | "tencent/hunyuan-video:140176772be3b423d14fdaf5403e6d4e38b85646ccad0c3fd2ed07c211f0cad1": "text-to-video", 278 | "deepseek-ai/deepseek-r1": "text-generation", 279 | } 280 | 281 | def create_component(comp_type: type, name: str, config: Dict[str, Any]) -> gr.components.Component: 282 | # Remove 'optional' from config as it's not a valid Gradio parameter 283 | config = config.copy() 284 | is_optional = config.pop('optional', False) 285 | 286 | # Add "(Optional)" to label if the field is optional 287 | if is_optional: 288 | label = config.get('label', name) 289 | config['label'] = f"{label} (Optional)" 290 | 291 | return comp_type(label=config.get("label", name), **{k:v for k,v in config.items() if k != "label"}) 292 | 293 | def get_pipeline(model: str) -> str: 294 | return MODEL_TO_PIPELINE.get(model, "text-to-image") 295 | 296 | def get_interface_args(pipeline: str) -> Tuple[List, List, Callable, Callable]: 297 | if pipeline not in PIPELINE_REGISTRY: 298 | raise ValueError(f"Unsupported pipeline: {pipeline}") 299 | 300 | config = PIPELINE_REGISTRY[pipeline] 301 | 302 | inputs = [create_component(comp_type, name, conf) 303 | for name, comp_type, conf in config["inputs"]] 304 | 305 | outputs = [create_component(comp_type, name, conf) 306 | for name, comp_type, conf in config["outputs"]] 307 | 308 | return inputs, outputs, config["preprocess"], config["postprocess"] 309 | 310 | async def async_run_with_timeout(model_name: str, args: dict): 311 | try: 312 | stream = await replicate.async_stream( 313 | model_name, 314 | input=args 315 | ) 316 | async for output in stream: 317 | yield output 318 | except 
Exception as e: 319 | print(f"Error during model prediction: {str(e)}") 320 | raise gr.Error(f"Model prediction failed: {str(e)}") 321 | 322 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable): 323 | async def fn(*args): 324 | try: 325 | args = preprocess(*args) 326 | if model_name == "deepseek-ai/deepseek-r1": 327 | response = "" 328 | async for chunk in async_run_with_timeout(model_name, args): 329 | chunk = str(chunk) 330 | # Replace XML-like tags with square bracket versions 331 | chunk = (chunk.replace("", "[think]") 332 | .replace("", "[/think]") 333 | .replace("", "[Answer]") 334 | .replace("", "[/Answer]")) 335 | response += chunk 336 | return response.strip() 337 | output = await async_run_with_timeout(model_name, args) 338 | return postprocess(output) 339 | except Exception as e: 340 | raise gr.Error(f"Error: {str(e)}") 341 | return fn 342 | 343 | def registry(name: str | Dict, token: str | None = None, inputs=None, outputs=None, src=None, accept_token: bool = False, **kwargs) -> gr.Interface: 344 | """ 345 | Create a Gradio Interface for a model on Replicate. 346 | Parameters: 347 | - name (str | Dict): The name of the model on Replicate, or a dict with model info. 348 | - token (str, optional): The API token for the model on Replicate. 349 | - inputs (List[gr.Component], optional): The input components to use instead of the default. 350 | - outputs (List[gr.Component], optional): The output components to use instead of the default. 351 | - src (callable, optional): Ignored, used by gr.load for routing. 352 | - accept_token (bool, optional): Whether to accept a token input field. 353 | Returns: 354 | gr.Interface or gr.ChatInterface: A Gradio interface for the model. 355 | """ 356 | # Handle both string names and dict configurations 357 | if isinstance(name, dict): 358 | model_name = name.get('name', name.get('model_name', '')) 359 | else: 360 | model_name = name 361 | 362 | if token: 363 | os.environ["REPLICATE_API_TOKEN"] = token 364 | 365 | pipeline = get_pipeline(model_name) 366 | inputs_, outputs_, preprocess, postprocess = get_interface_args(pipeline) 367 | 368 | # Add token input if accept_token is True 369 | if accept_token: 370 | token_input = gr.Textbox(label="API Token", type="password") 371 | inputs_ = [token_input] + inputs_ 372 | 373 | # Modify preprocess function to handle token 374 | original_preprocess = preprocess 375 | def new_preprocess(token, *args): 376 | if token: 377 | os.environ["REPLICATE_API_TOKEN"] = token 378 | return original_preprocess(*args) 379 | preprocess = new_preprocess 380 | 381 | inputs, outputs = inputs or inputs_, outputs or outputs_ 382 | fn = get_fn(model_name, preprocess, postprocess) 383 | 384 | # Use ChatInterface for text-generation models 385 | if pipeline == "text-generation": 386 | return gr.ChatInterface( 387 | fn=fn, 388 | **kwargs 389 | ) 390 | 391 | return gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 392 | 393 | 394 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /ai_gradio/providers/__init__.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | def custom_load(name: str, src: dict, **kwargs): 4 | # Only use custom loading if name contains provider prefix 5 | if ':' in name: 6 | provider, model = name.split(':') 7 | # Create provider-specific model key 8 | model_key = f"{provider}:{model}" 9 | 10 | if model_key not in src: 11 | available_models = [k for k in 
src.keys()] 12 | raise ValueError(f"Model {model_key} not found. Available models: {available_models}") 13 | return src[model_key](name=model, **kwargs) 14 | 15 | # Fall back to original gradio behavior if no provider prefix 16 | return original_load(name, src, **kwargs) 17 | 18 | # Store original load function before overriding 19 | original_load = gr.load 20 | gr.load = custom_load 21 | 22 | registry = {} 23 | 24 | 25 | try: 26 | from .openai_gradio import registry as openai_registry 27 | registry.update({f"openai:{k}": openai_registry for k in [ 28 | "gpt-4o-2024-11-20", 29 | "gpt-4o", 30 | "gpt-4o-2024-08-06", 31 | "gpt-4o-2024-05-13", 32 | "chatgpt-4o-latest", 33 | "gpt-4o-mini", 34 | "gpt-4o-mini-2024-07-18", 35 | "o1-preview", 36 | "o1-preview-2024-09-12", 37 | "o1-mini", 38 | "o1-mini-2024-09-12", 39 | "gpt-4-turbo", 40 | "gpt-4-turbo-2024-04-09", 41 | "gpt-4-turbo-preview", 42 | "gpt-4-0125-preview", 43 | "gpt-4-1106-preview", 44 | "gpt-4", 45 | "gpt-4-0613", 46 | "o1-2024-12-17", 47 | "gpt-4o-realtime-preview-2024-10-01", 48 | "gpt-4o-realtime-preview", 49 | "gpt-4o-realtime-preview-2024-12-17", 50 | "gpt-4o-mini-realtime-preview", 51 | "gpt-4o-mini-realtime-preview-2024-12-17", 52 | "o3-mini-2025-01-31" 53 | ]}) 54 | except ImportError: 55 | pass 56 | 57 | try: 58 | from .gemini_gradio import registry as gemini_registry 59 | registry.update({f"gemini:{k}": gemini_registry for k in [ 60 | 'gemini-1.5-flash', 61 | 'gemini-1.5-flash-8b', 62 | 'gemini-1.5-pro', 63 | 'gemini-exp-1114', 64 | 'gemini-exp-1121', 65 | 'gemini-exp-1206', 66 | 'gemini-2.0-flash-exp', 67 | 'gemini-2.0-flash-thinking-exp-1219', 68 | 'gemini-2.0-flash-thinking-exp-01-21', 69 | 'gemini-2.0-pro-exp-02-05', 70 | 'gemini-2.0-flash-lite-preview-02-05' 71 | ]}) 72 | except ImportError: 73 | pass 74 | 75 | try: 76 | from .crewai_gradio import registry as crewai_registry 77 | # Add CrewAI models with their own prefix 78 | registry.update({f"crewai:{k}": crewai_registry for k in ['gpt-4-turbo', 'gpt-4', 'gpt-3.5-turbo']}) 79 | except ImportError: 80 | pass 81 | 82 | try: 83 | from .anthropic_gradio import registry as anthropic_registry 84 | registry.update({f"anthropic:{k}": anthropic_registry for k in [ 85 | 'claude-3-5-sonnet-20241022', 86 | 'claude-3-5-haiku-20241022', 87 | 'claude-3-opus-20240229', 88 | 'claude-3-sonnet-20240229', 89 | 'claude-3-haiku-20240307', 90 | ]}) 91 | except ImportError: 92 | pass 93 | 94 | try: 95 | from .lumaai_gradio import registry as lumaai_registry 96 | registry.update({f"lumaai:{k}": lumaai_registry for k in [ 97 | 'dream-machine', 98 | 'photon-1', 99 | 'photon-flash-1' 100 | ]}) 101 | except ImportError: 102 | pass 103 | 104 | try: 105 | from .xai_gradio import registry as xai_registry 106 | registry.update({f"xai:{k}": xai_registry for k in [ 107 | 'grok-beta', 108 | 'grok-vision-beta' 109 | ]}) 110 | except ImportError: 111 | pass 112 | 113 | try: 114 | from .cohere_gradio import registry as cohere_registry 115 | registry.update({f"cohere:{k}": cohere_registry for k in [ 116 | 'command-r7b-12-2024', 117 | 'command-light', 118 | 'command-nightly', 119 | 'command-light-nightly' 120 | ]}) 121 | except ImportError: 122 | pass 123 | 124 | try: 125 | from .sambanova_gradio import registry as sambanova_registry 126 | registry.update({f"sambanova:{k}": sambanova_registry for k in [ 127 | 'Meta-Llama-3.1-405B-Instruct', 128 | 'Meta-Llama-3.1-8B-Instruct', 129 | 'Meta-Llama-3.1-70B-Instruct', 130 | 'Meta-Llama-3.1-405B-Instruct-Preview', 131 | 'Meta-Llama-3.1-8B-Instruct-Preview', 132 | 
'Meta-Llama-3.3-70B-Instruct', 133 | 'Meta-Llama-3.2-3B-Instruct', 134 | ]}) 135 | except ImportError: 136 | pass 137 | 138 | try: 139 | from .hyperbolic_gradio import registry as hyperbolic_registry 140 | registry.update({f"hyperbolic:{k}": hyperbolic_registry for k in [ 141 | 'Qwen/Qwen2.5-Coder-32B-Instruct', 142 | 'meta-llama/Llama-3.2-3B-Instruct', 143 | 'meta-llama/Meta-Llama-3.1-8B-Instruct', 144 | 'meta-llama/Meta-Llama-3.1-70B-Instruct', 145 | 'meta-llama/Meta-Llama-3-70B-Instruct', 146 | 'NousResearch/Hermes-3-Llama-3.1-70B', 147 | 'Qwen/Qwen2.5-72B-Instruct', 148 | 'deepseek-ai/DeepSeek-V2.5', 149 | 'meta-llama/Meta-Llama-3.1-405B-Instruct', 150 | 'Qwen/QwQ-32B-Preview', 151 | 'meta-llama/Llama-3.3-70B-Instruct', 152 | 'deepseek-ai/DeepSeek-V3', 153 | 'deepseek-ai/DeepSeek-R1', 154 | 'deepseek-ai/DeepSeek-R1-Zero' 155 | ]}) 156 | except ImportError: 157 | pass 158 | 159 | try: 160 | from .qwen_gradio import registry as qwen_registry 161 | registry.update({f"qwen:{k}": qwen_registry for k in [ 162 | "qwen-turbo-latest", 163 | "qwen-turbo", 164 | "qwen-plus", 165 | "qwen-max", 166 | "qwen1.5-110b-chat", 167 | "qwen1.5-72b-chat", 168 | "qwen1.5-32b-chat", 169 | "qwen1.5-14b-chat", 170 | "qwen1.5-7b-chat", 171 | "qwq-32b-preview", 172 | 'qvq-72b-preview', 173 | 'qwen2.5-14b-instruct-1m', 174 | 'qwen2.5-7b-instruct-1m', 175 | 'qwen-max-0125' 176 | ]}) 177 | except ImportError: 178 | pass 179 | 180 | try: 181 | from .fireworks_gradio import registry as fireworks_registry 182 | registry.update({f"fireworks:{k}": fireworks_registry for k in [ 183 | 'whisper-v3', 184 | 'whisper-v3-turbo', 185 | 'f1-preview', 186 | 'f1-mini' 187 | ]}) 188 | except ImportError: 189 | pass 190 | 191 | try: 192 | from .together_gradio import registry as together_registry 193 | registry.update({f"together:{k}": together_registry for k in [ 194 | # Vision Models 195 | 'meta-llama/Llama-Vision-Free', 196 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo', 197 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo', 198 | 199 | # Llama 3 Series 200 | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', 201 | 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 202 | 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo', 203 | 'meta-llama/Meta-Llama-3-8B-Instruct-Turbo', 204 | 'meta-llama/Meta-Llama-3-70B-Instruct-Turbo', 205 | 'meta-llama/Llama-3.2-3B-Instruct-Turbo', 206 | 'meta-llama/Meta-Llama-3-8B-Instruct-Lite', 207 | 'meta-llama/Meta-Llama-3-70B-Instruct-Lite', 208 | 'meta-llama/Llama-3-8b-chat-hf', 209 | 'meta-llama/Llama-3-70b-chat-hf', 210 | 'meta-llama/Llama-3.3-70B-Instruct-Turbo', 211 | 212 | # Other Large Models 213 | 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF', 214 | 'Qwen/Qwen2.5-Coder-32B-Instruct', 215 | 'microsoft/WizardLM-2-8x22B', 216 | 'databricks/dbrx-instruct', 217 | 218 | # Gemma Models 219 | 'google/gemma-2-27b-it', 220 | 'google/gemma-2-9b-it', 221 | 'google/gemma-2b-it', 222 | 223 | # Mixtral Models 224 | 'mistralai/Mixtral-8x7B-Instruct-v0.1', 225 | 'mistralai/Mixtral-8x22B-Instruct-v0.1', 226 | 227 | # Qwen Models 228 | 'Qwen/Qwen2.5-7B-Instruct-Turbo', 229 | 'Qwen/Qwen2.5-72B-Instruct-Turbo', 230 | 'Qwen/Qwen2-72B-Instruct', 231 | 232 | # Other Models 233 | 'deepseek-ai/deepseek-llm-67b-chat', 234 | 'Gryphe/MythoMax-L2-13b', 235 | 'meta-llama/Llama-2-13b-chat-hf', 236 | 'mistralai/Mistral-7B-Instruct-v0.1', 237 | 'mistralai/Mistral-7B-Instruct-v0.2', 238 | 'mistralai/Mistral-7B-Instruct-v0.3', 239 | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO', 240 | 'togethercomputer/StripedHyena-Nous-7B', 241 | 
'upstage/SOLAR-10.7B-Instruct-v1.0', 242 | 'deepseek-ai/DeepSeek-V3', 243 | 'deepseek-ai/DeepSeek-R1', 244 | 'mistralai/Mistral-Small-24B-Instruct-2501' 245 | ]}) 246 | except ImportError: 247 | pass 248 | 249 | try: 250 | from .deepseek_gradio import registry as deepseek_registry 251 | registry.update({f"deepseek:{k}": deepseek_registry for k in [ 252 | 'deepseek-chat', 253 | 'deepseek-coder', 254 | 'deepseek-vision', 255 | 'deepseek-reasoner' 256 | ]}) 257 | except ImportError: 258 | pass 259 | 260 | try: 261 | from .smolagents_gradio import registry as smolagents_registry 262 | registry.update({f"smolagents:{k}": smolagents_registry for k in [ 263 | 'Qwen/Qwen2.5-72B-Instruct', 264 | 'Qwen/Qwen2.5-4B-Instruct', 265 | 'Qwen/Qwen2.5-1.8B-Instruct', 266 | 'meta-llama/Llama-3.3-70B-Instruct', 267 | 'meta-llama/Llama-3.1-8B-Instruct' 268 | ]}) 269 | except ImportError: 270 | pass 271 | 272 | try: 273 | from .groq_gradio import registry as groq_registry 274 | registry.update({f"groq:{k}": groq_registry for k in [ 275 | "llama3-groq-8b-8192-tool-use-preview", 276 | "llama3-groq-70b-8192-tool-use-preview", 277 | "llama-3.2-1b-preview", 278 | "llama-3.2-3b-preview", 279 | "llama-3.2-11b-vision-preview", 280 | "llama-3.2-90b-vision-preview", 281 | "mixtral-8x7b-32768", 282 | "gemma2-9b-it", 283 | "gemma-7b-it", 284 | "llama-3.3-70b-versatile", 285 | "llama-3.3-70b-specdec", 286 | "deepseek-r1-distill-llama-70b" 287 | ]}) 288 | except ImportError: 289 | pass 290 | 291 | try: 292 | from .browser_use_gradio import registry as browser_use_registry 293 | registry.update({f"browser:{k}": browser_use_registry for k in [ 294 | "gpt-4o-2024-11-20", 295 | "gpt-4o", 296 | "gpt-4o-2024-08-06", 297 | "gpt-4o-2024-05-13", 298 | "chatgpt-4o-latest", 299 | "gpt-4o-mini", 300 | "gpt-4o-mini-2024-07-18", 301 | "o1-preview", 302 | "o1-preview-2024-09-12", 303 | "o1-mini", 304 | "o1-mini-2024-09-12", 305 | "gpt-4-turbo", 306 | "gpt-4-turbo-2024-04-09", 307 | "gpt-4-turbo-preview", 308 | "gpt-4-0125-preview", 309 | "gpt-4-1106-preview", 310 | "gpt-4", 311 | "gpt-4-0613", 312 | "o1-2024-12-17", 313 | "gpt-4o-realtime-preview-2024-10-01", 314 | "gpt-4o-realtime-preview", 315 | "gpt-4o-realtime-preview-2024-12-17", 316 | "gpt-4o-mini-realtime-preview", 317 | "gpt-4o-mini-realtime-preview-2024-12-17", 318 | "gpt-3.5-turbo", 319 | "o3-mini-2025-01-31" 320 | ]}) 321 | except ImportError: 322 | pass 323 | 324 | try: 325 | from .swarms_gradio import registry as swarms_registry 326 | registry.update({f"swarms:{k}": swarms_registry for k in [ 327 | 'gpt-4-turbo', 328 | 'gpt-4o-mini', 329 | 'gpt-4', 330 | 'gpt-3.5-turbo' 331 | ]}) 332 | except ImportError: 333 | pass 334 | 335 | try: 336 | from .transformers_gradio import registry as transformers_registry 337 | registry.update({f"transformers:{k}": transformers_registry for k in [ 338 | "phi-4", 339 | "tulu-3", 340 | "olmo-2-13b", 341 | "smolvlm", 342 | "moondream", 343 | # Add other default transformers models here 344 | ]}) 345 | except ImportError: 346 | pass 347 | 348 | try: 349 | from .jupyter_agent import registry as jupyter_registry 350 | registry.update({f"jupyter:{k}": jupyter_registry for k in [ 351 | 'meta-llama/Llama-3.2-3B-Instruct', 352 | 'meta-llama/Llama-3.1-8B-Instruct', 353 | 'meta-llama/Llama-3.1-70B-Instruct' 354 | ]}) 355 | except ImportError: 356 | pass 357 | 358 | try: 359 | from .langchain_gradio import registry as langchain_registry 360 | registry.update({f"langchain:{k}": langchain_registry for k in [ 361 | 'gpt-4-turbo', 362 | 'gpt-4', 363 | 
'gpt-3.5-turbo', 364 | 'gpt-3.5-turbo-0125' 365 | ]}) 366 | except ImportError as e: 367 | print(f"Failed to import LangChain registry: {e}") 368 | # Optionally add more detailed error handling here 369 | 370 | try: 371 | from .mistral_gradio import registry as mistral_registry 372 | registry.update({f"mistral:{k}": mistral_registry for k in [ 373 | "mistral-large-latest", 374 | "pixtral-large-latest", 375 | "ministral-3b-latest", 376 | "ministral-8b-latest", 377 | "mistral-small-latest", 378 | "codestral-latest", 379 | "mistral-embed", 380 | "mistral-moderation-latest", 381 | "pixtral-12b-2409", 382 | "open-mistral-nemo", 383 | "open-codestral-mamba", 384 | ]}) 385 | except ImportError: 386 | pass 387 | 388 | try: 389 | from .nvidia_gradio import registry as nvidia_registry 390 | registry.update({f"nvidia:{k}": nvidia_registry for k in [ 391 | "nvidia/llama3-chatqa-1.5-70b", 392 | "nvidia/cosmos-nemotron-34b", 393 | "nvidia/llama3-chatqa-1.5-8b", 394 | "nvidia-nemotron-4-340b-instruct", 395 | "meta/llama-3.1-70b-instruct", 396 | "meta/codellama-70b", 397 | "meta/llama2-70b", 398 | "meta/llama3-8b", 399 | "meta/llama3-70b", 400 | "mistralai/codestral-22b-instruct-v0.1", 401 | "mistralai/mathstral-7b-v0.1", 402 | "mistralai/mistral-large-2-instruct", 403 | "mistralai/mistral-7b-instruct", 404 | "mistralai/mistral-7b-instruct-v0.3", 405 | "mistralai/mixtral-8x7b-instruct", 406 | "mistralai/mixtral-8x22b-instruct", 407 | "mistralai/mistral-large", 408 | "google/gemma-2b", 409 | "google/gemma-7b", 410 | "google/gemma-2-2b-it", 411 | "google/gemma-2-9b-it", 412 | "google/gemma-2-27b-it", 413 | "google/codegemma-1.1-7b", 414 | "google/codegemma-7b", 415 | "google/recurrentgemma-2b", 416 | "google/shieldgemma-9b", 417 | "microsoft/phi-3-medium-128k-instruct", 418 | "microsoft/phi-3-medium-4k-instruct", 419 | "microsoft/phi-3-mini-128k-instruct", 420 | "microsoft/phi-3-mini-4k-instruct", 421 | "microsoft/phi-3-small-128k-instruct", 422 | "microsoft/phi-3-small-8k-instruct", 423 | "qwen/qwen2-7b-instruct", 424 | "databricks/dbrx-instruct", 425 | "deepseek-ai/deepseek-coder-6.7b-instruct", 426 | "upstage/solar-10.7b-instruct", 427 | "snowflake/arctic", 428 | "qwen/qwen2.5-7b-instruct", 429 | "deepseek-ai/deepseek-r1" 430 | ]}) 431 | except ImportError: 432 | pass 433 | 434 | try: 435 | from .minimax_gradio import registry as minimax_registry 436 | registry.update({f"minimax:{k}": minimax_registry for k in [ 437 | "MiniMax-Text-01", 438 | ]}) 439 | except ImportError: 440 | pass 441 | 442 | try: 443 | from .kokoro_gradio import registry as kokoro_registry 444 | registry.update({f"kokoro:{k}": kokoro_registry for k in [ 445 | "kokoro-v0_19" 446 | ]}) 447 | except ImportError: 448 | pass 449 | 450 | try: 451 | from .perplexity_gradio import registry as perplexity_registry 452 | registry.update({f"perplexity:{k}": perplexity_registry for k in [ 453 | 'sonar-pro', 454 | 'sonar', 455 | 'sonar-reasoning' 456 | ]}) 457 | except ImportError: 458 | pass 459 | 460 | try: 461 | from .cerebras_gradio import registry as cerebras_registry 462 | registry.update({f"cerebras:{k}": cerebras_registry for k in [ 463 | 'deepseek-r1-distill-llama-70b', 464 | ]}) 465 | except ImportError: 466 | pass 467 | 468 | try: 469 | from .replicate_gradio import registry as replicate_registry 470 | registry.update({f"replicate:{k}": replicate_registry for k in [ 471 | # Text to Image Models 472 | "stability-ai/sdxl", 473 | "black-forest-labs/flux-pro", 474 | "stability-ai/stable-diffusion", 475 | 476 | # Control Net Models 477 | 
"black-forest-labs/flux-depth-pro", 478 | "black-forest-labs/flux-canny-pro", 479 | "black-forest-labs/flux-depth-dev", 480 | 481 | # Inpainting Models 482 | "black-forest-labs/flux-fill-pro", 483 | "stability-ai/stable-diffusion-inpainting", 484 | 485 | # Text to Video Models 486 | "tencent/hunyuan-video:140176772be3b423d14fdaf5403e6d4e38b85646ccad0c3fd2ed07c211f0cad1", 487 | 488 | # Text Generation Models 489 | "deepseek-ai/deepseek-r1" 490 | ]}) 491 | except ImportError: 492 | pass 493 | 494 | if not registry: 495 | raise ImportError( 496 | "No providers installed. Install with either:\n" 497 | "pip install 'ai-gradio[openai]' for OpenAI support\n" 498 | "pip install 'ai-gradio[gemini]' for Gemini support\n" 499 | "pip install 'ai-gradio[crewai]' for CrewAI support\n" 500 | "pip install 'ai-gradio[anthropic]' for Anthropic support\n" 501 | "pip install 'ai-gradio[lumaai]' for LumaAI support\n" 502 | "pip install 'ai-gradio[xai]' for X.AI support\n" 503 | "pip install 'ai-gradio[cohere]' for Cohere support\n" 504 | "pip install 'ai-gradio[sambanova]' for SambaNova support\n" 505 | "pip install 'ai-gradio[hyperbolic]' for Hyperbolic support\n" 506 | "pip install 'ai-gradio[qwen]' for Qwen support\n" 507 | "pip install 'ai-gradio[fireworks]' for Fireworks support\n" 508 | "pip install 'ai-gradio[deepseek]' for DeepSeek support\n" 509 | "pip install 'ai-gradio[smolagents]' for SmolaAgents support\n" 510 | "pip install 'ai-gradio[jupyter]' for Jupyter support\n" 511 | "pip install 'ai-gradio[langchain]' for LangChain support\n" 512 | "pip install 'ai-gradio[mistral]' for Mistral support\n" 513 | "pip install 'ai-gradio[nvidia]' for NVIDIA support\n" 514 | "pip install 'ai-gradio[minimax]' for MiniMax support\n" 515 | "pip install 'ai-gradio[kokoro]' for Kokoro support\n" 516 | "pip install 'ai-gradio[perplexity]' for Perplexity support\n" 517 | "pip install 'ai-gradio[cerebras]' for Cerebras support\n" 518 | "pip install 'ai-gradio[replicate]' for Replicate support\n" 519 | "pip install 'ai-gradio[all]' for all providers\n" 520 | "pip install 'ai-gradio[swarms]' for Swarms support" 521 | ) 522 | 523 | __all__ = ["registry"] 524 | --------------------------------------------------------------------------------