├── app.py
├── .gitignore
├── ai_gradio
│   ├── __init__.py
│   └── providers
│       ├── kokoro_gradio.py
│       ├── cerebras_gradio.py
│       ├── anthropic_gradio.py
│       ├── browser_use_gradio.py
│       ├── swarms_gradio.py
│       ├── langchain_gradio.py
│       ├── xai_gradio.py
│       ├── cohere_gradio.py
│       ├── together_gradio.py
│       ├── sambanova_gradio.py
│       ├── smolagents_gradio.py
│       ├── transformers_gradio.py
│       ├── lumaai_gradio.py
│       ├── fireworks_gradio.py
│       ├── hyperbolic_gradio.py
│       ├── perplexity_gradio.py
│       ├── crewai_gradio.py
│       ├── mistral_gradio.py
│       ├── qwen_gradio.py
│       ├── deepseek_gradio.py
│       ├── replicate_gradio.py
│       └── __init__.py
├── pyproject.toml
└── README.md
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import ai_gradio
3 |
4 |
5 | gr.load(
6 | name='gemini:gemini-2.0-flash-lite-preview-02-05',
7 | src=ai_gradio.registry,
8 | coder=True
9 | ).launch()
10 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache files
2 | __pycache__/
3 | *.pyc
4 |
5 | # Virtual environment
6 | env/
7 | .venv/
8 | venv/**
9 | venv/
10 |
11 | # Package artifacts
12 | dist/
13 | build/
14 | *.egg-info/
15 | .codegpt
16 | .vscode
17 |
--------------------------------------------------------------------------------
/ai_gradio/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | try:
4 | __version__ = version("ai-gradio")
5 | except Exception:
6 | __version__ = "unknown"
7 |
8 | from .providers import registry
9 |
10 | __all__ = ["registry"]
11 |
--------------------------------------------------------------------------------
/ai_gradio/providers/kokoro_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import soundfile as sf
3 | from kokoro_onnx import Kokoro
4 | import gradio as gr
5 | from typing import Callable
6 | from huggingface_hub import hf_hub_download
7 |
8 | __version__ = "0.0.1"
9 |
10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable):
11 | # Download model and voices.json from HuggingFace Hub
12 | model_path = hf_hub_download(
13 | repo_id="hexgrad/Kokoro-82M",
14 | filename="kokoro-v0_19.onnx",
15 | repo_type="model"
16 | )
17 |
18 | voices_path = hf_hub_download(
19 | repo_id="akhaliq/Kokoro-82M",
20 | filename="voices.json",
21 | repo_type="model"
22 | )
23 |
24 | def chat_response(message, history, voice="af_sarah", speed=1.0, lang="en-us"):
25 | try:
26 | kokoro = Kokoro(model_path, voices_path)
27 | samples, sample_rate = kokoro.create(
28 | text=message,
29 | voice=voice,
30 | speed=speed,
31 | lang=lang
32 | )
33 |
34 | # Save to temporary file with unique name based on history length
35 | output_path = f"response_{len(history)}.wav"
36 | sf.write(output_path, samples, sample_rate)
37 |
38 | return {
39 | "role": "assistant",
40 | "content": {
41 | "path": output_path
42 | }
43 | }
44 |
45 | except Exception as e:
46 | return f"Error generating audio: {str(e)}"
47 |
48 | return chat_response
49 |
50 | def registry(name: str, **kwargs):
51 | """Register kokoro TTS interface"""
52 |
53 | interface = gr.ChatInterface(
54 | fn=get_fn(name, lambda x: x, lambda x: x),
55 | additional_inputs=[
56 | gr.Dropdown(
57 | choices=["af_sarah", "en_jenny", "en_ryan"],
58 | value="af_sarah",
59 | label="Voice"
60 | ),
61 | gr.Slider(
62 | minimum=0.5,
63 | maximum=2.0,
64 | value=1.0,
65 | step=0.1,
66 | label="Speed"
67 | ),
68 | gr.Dropdown(
69 | choices=["en-us", "en-gb"],
70 | value="en-us",
71 | label="Language"
72 | )
73 | ],
74 | title="Kokoro Text-to-Speech",
75 | description="Generate speech from text using Kokoro TTS",
76 | **kwargs
77 | )
78 |
79 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/cerebras_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | from typing import Callable
4 | from cerebras.cloud.sdk import Cerebras
5 |
6 | __version__ = "0.0.1"
7 |
8 |
9 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
10 | def fn(message, history):
11 | inputs = preprocess(message, history)
12 | client = Cerebras(
13 | api_key=api_key,
14 | )
15 | completion = client.chat.completions.create(
16 | model=model_name,
17 | messages=inputs["messages"],
18 | stream=True,
19 | max_completion_tokens=1024,
20 | temperature=0.2,
21 | top_p=1
22 | )
23 |
24 | # Streaming response to Gradio ChatInterface UI
25 | response_text = ""
26 | for chunk in completion:
27 | delta = chunk.choices[0].delta.content or ""
28 | response_text += delta
29 | yield postprocess(response_text)
30 |
31 | return fn
32 |
33 |
34 | def get_interface_args(pipeline):
35 | if pipeline == "chat":
36 | inputs = None
37 | outputs = None
38 |
39 | def preprocess(message, history):
40 | messages = []
41 | for user_msg, assistant_msg in history:
42 | messages.append({"role": "user", "content": user_msg})
43 | messages.append({"role": "assistant", "content": assistant_msg})
44 | messages.append({"role": "user", "content": message})
45 | return {"messages": messages}
46 |
47 | postprocess = lambda x: x
48 | else:
49 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
50 | return inputs, outputs, preprocess, postprocess
51 |
52 |
53 | def get_pipeline(model_name):
54 | # Determine the pipeline type based on the model name
55 | # For simplicity, assuming all models are chat models at the moment
56 | return "chat"
57 |
58 |
59 | def registry(name: str, token: str | None = None, **kwargs):
60 | """
61 | Create a Gradio Interface for a model on Cerebras.
62 |
63 | Parameters:
64 | - name (str): The name of the model on Cerebras.
65 | - token (str, optional): The API key for Cerebras.
66 | """
67 | api_key = token or os.environ.get("CEREBRAS_API_KEY")
68 | if not api_key:
69 | raise ValueError("CEREBRAS_API_KEY environment variable is not set.")
70 |
71 | pipeline = get_pipeline(name)
72 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
73 | fn = get_fn(name, preprocess, postprocess, api_key)
74 |
75 | if pipeline == "chat":
76 | interface = gr.ChatInterface(fn=fn, **kwargs)
77 | else:
78 | # For other pipelines, create a standard Interface (not implemented yet)
79 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
80 |
81 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/anthropic_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import anthropic
3 | import gradio as gr
4 | from typing import Callable
5 |
6 | __version__ = "0.0.1"
7 |
8 |
9 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
10 | def fn(message, history):
11 | inputs = preprocess(message, history)
12 | client = anthropic.Anthropic(api_key=api_key)
13 | with client.messages.stream(
14 | model=model_name,
15 | max_tokens=1000,
16 | messages=inputs["messages"]
17 | ) as stream:
18 | response_text = ""
19 | for chunk in stream:
20 | if chunk.type == "content_block_delta":
21 | delta = chunk.delta.text
22 | response_text += delta
23 | yield postprocess(response_text)
24 |
25 | return fn
26 |
27 |
28 | def get_interface_args(pipeline):
29 | if pipeline == "chat":
30 | inputs = None
31 | outputs = None
32 |
33 | def preprocess(message, history):
34 | messages = []
35 | for user_msg, assistant_msg in history:
36 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
37 | messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
38 | messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
39 | return {"messages": messages}
40 |
41 | postprocess = lambda x: x # No post-processing needed
42 | else:
 43 |         # Add other pipeline types when needed
44 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
45 | return inputs, outputs, preprocess, postprocess
46 |
47 |
48 | def get_pipeline(model_name):
49 | # Determine the pipeline type based on the model name
50 | # For simplicity, assuming all models are chat models at the moment
51 | return "chat"
52 |
53 |
54 | def registry(name: str, token: str | None = None, **kwargs):
55 | """
56 | Create a Gradio Interface for a model on Anthropic.
57 |
58 | Parameters:
59 | - name (str): The name of the Anthropic model.
60 | - token (str, optional): The API key for Anthropic.
61 | """
62 | api_key = token or os.environ.get("ANTHROPIC_API_KEY")
63 | if not api_key:
64 | raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
65 |
66 | pipeline = get_pipeline(name)
67 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
68 | fn = get_fn(name, preprocess, postprocess, api_key)
69 |
70 | if pipeline == "chat":
71 | interface = gr.ChatInterface(fn=fn, **kwargs)
72 | else:
73 | # For other pipelines, create a standard Interface (not implemented yet)
74 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
75 |
76 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/browser_use_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from langchain_openai import ChatOpenAI
3 | from browser_use import Agent
4 | import gradio as gr
5 | from typing import Callable
6 | import asyncio
7 |
8 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
9 | async def fn(message, history):
10 | inputs = preprocess(message, history)
11 |
12 | agent = Agent(
13 | task=inputs["message"],
14 | llm=ChatOpenAI(
15 | api_key=api_key,
16 | model=model_name,
17 | disabled_params={"parallel_tool_calls": None}
18 | ),
19 | use_vision=(model_name != "o3-mini-2025-01-31") # Only disable vision for o3-mini
20 | )
21 |
22 | try:
23 | result = await agent.run()
24 | return postprocess(result)
25 | except Exception as e:
26 | return f"Error: {str(e)}"
27 |
28 | return fn # Remove sync_wrapper, return async function directly
29 |
30 | def get_interface_args(pipeline):
31 | if pipeline == "browser":
32 | def preprocess(message, history):
33 | return {"message": message}
34 |
35 | def postprocess(result):
36 | if hasattr(result, 'all_results') and hasattr(result, 'all_model_outputs'):
37 | # Get the thought process from non-final results
38 | thoughts = [r.extracted_content for r in result.all_results if not r.is_done]
39 | # Get the final answer from the last result
40 | final_answer = next((r.extracted_content for r in reversed(result.all_results) if r.is_done), None)
41 |
42 | if final_answer:
43 | # Return in Gradio's message format
44 | return {
45 | "role": "assistant",
46 | "content": final_answer,
47 | "metadata": {
48 | "title": "🔍 " + " → ".join(thoughts)
49 | }
50 | }
51 |
52 | # Fallback to simple message format
53 | return {
54 | "role": "assistant",
55 | "content": str(result) if not isinstance(result, str) else result
56 | }
57 |
58 | return preprocess, postprocess
59 | else:
60 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
61 |
62 | def registry(name: str, token: str | None = None, **kwargs):
63 | api_key = token or os.environ.get("OPENAI_API_KEY")
64 | if not api_key:
65 | raise ValueError("OPENAI_API_KEY environment variable is not set.")
66 |
67 | pipeline = "browser"
68 | preprocess, postprocess = get_interface_args(pipeline)
69 | fn = get_fn(name, preprocess, postprocess, api_key)
70 |
71 | # Remove title and description from kwargs if they exist
72 | kwargs.pop('title', None)
73 | kwargs.pop('description', None)
74 |
75 | interface = gr.ChatInterface(
76 | fn=fn,
77 | title="Browser Use Agent",
78 | description="Chat with an AI agent that can perform browser tasks.",
79 | examples=["Go to amazon.com and find the best rated laptop and return the price.",
80 | "Find the current weather in New York"],
81 | **kwargs
82 | )
83 |
84 | return interface
85 |
--------------------------------------------------------------------------------
/ai_gradio/providers/swarms_gradio.py:
--------------------------------------------------------------------------------
1 | from swarms import Agent
2 | from swarm_models import OpenAIChat
3 | import gradio as gr
4 | from typing import Generator, List, Dict, Callable
5 | import base64
6 | import os
7 |
8 | __version__ = "0.0.1"
9 |
10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str, agent_name: str = "Assistant"):
11 | def fn(message, history):
12 | inputs = preprocess(message, history)
13 | agent = create_agent(model_name, agent_name)
14 |
15 | try:
16 | for response in stream_agent_response(agent, inputs["message"]):
17 | yield postprocess(response)
18 | except Exception as e:
19 | yield postprocess({
20 | "role": "assistant",
21 | "content": f"Error: {str(e)}",
22 | "metadata": {"title": "❌ Error"}
23 | })
24 |
25 | return fn
26 |
27 | def get_interface_args(pipeline):
28 | if pipeline == "chat":
29 | def preprocess(message, history):
30 | messages = []
31 | for user_msg, assistant_msg in history:
32 | if assistant_msg is not None:
33 | messages.append({"role": "user", "content": user_msg})
34 | messages.append({"role": "assistant", "content": assistant_msg})
35 | messages.append({"role": "user", "content": message})
36 | return {"message": message, "history": messages}
37 |
38 | def postprocess(response):
39 | return response
40 |
41 | return None, None, preprocess, postprocess
42 | else:
43 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
44 |
45 | def create_agent(model_name: str = None, agent_name: str = "Assistant", system_prompt: str = None):
46 | """Create a Swarms Agent with the specified model and configuration"""
47 | model = OpenAIChat(model_name=model_name) if model_name else OpenAIChat()
48 |
49 | return Agent(
50 | agent_name=agent_name,
51 | system_prompt=system_prompt or "You are a helpful AI assistant.",
52 | llm=model,
53 | max_loops=1,
54 | autosave=False,
55 | dashboard=False,
56 | verbose=True,
57 | dynamic_temperature_enabled=True,
58 | streaming_on=True,
59 | saved_state_path=f"{agent_name.lower()}_state.json",
60 | user_name="gradio_user",
61 | retry_attempts=1,
62 | context_length=200000,
63 | return_step_meta=False,
64 | output_type="string"
65 | )
66 |
67 | def stream_agent_response(agent: Agent, prompt: str) -> Generator[Dict, None, None]:
68 | # Initial thinking message
69 | yield {
70 | "role": "assistant",
71 | "content": "Let me think about that...",
72 | "metadata": {"title": "🤔 Thinking"}
73 | }
74 |
75 | try:
76 | # Get response from agent
77 | response = agent.run(prompt)
78 |
79 | # Stream final response
80 | yield {
81 | "role": "assistant",
82 | "content": response,
83 | "metadata": {"title": "💬 Response"}
84 | }
85 |
86 | except Exception as e:
87 | yield {
88 | "role": "assistant",
89 | "content": f"Error: {str(e)}",
90 | "metadata": {"title": "❌ Error"}
91 | }
92 |
93 | def registry(name: str, token: str | None = None, agent_name: str = "Assistant", **kwargs):
94 | api_key = token or os.environ.get("OPENAI_API_KEY")
95 | if not api_key:
96 | raise ValueError("API key is not set. Please provide a token or set OPENAI_API_KEY environment variable.")
97 |
98 | pipeline = "chat" # Swarms only supports chat for now
99 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
100 | fn = get_fn(name, preprocess, postprocess, api_key, agent_name)
101 |
102 | interface = gr.ChatInterface(fn=fn, **kwargs)
103 | return interface
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "ai-gradio"
7 | version = "0.2.38"
8 | description = "A Python package for creating Gradio applications with AI models"
9 | authors = [
10 | { name = "AK", email = "ahsen.khaliq@gmail.com" }
11 | ]
12 | readme = "README.md"
13 | requires-python = ">=3.10"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 | dependencies = [
20 | "torch>=2.0.0",
21 | "numpy",
22 | "accelerate>=0.27.0",
23 | "bitsandbytes>=0.41.0",
24 | "gradio>=5.9.1",
25 | "gradio-webrtc",
26 | "websockets",
27 | "twilio",
28 | "Pillow",
29 | "opencv-python",
30 | "librosa",
31 | "pydub",
32 | "gradio_webrtc[vad]",
33 | "numba==0.60.0",
34 | "python-dotenv",
35 | "modelscope-studio",
36 | ]
37 |
38 | [project.urls]
39 | homepage = "https://github.com/AK391/ai-gradio"
40 | repository = "https://github.com/AK391/ai-gradio"
41 |
42 | [project.optional-dependencies]
43 | dev = [
44 | "pytest",
45 | "black",
46 | "isort",
47 | "flake8"
48 | ]
49 | transformers = ["transformers>=4.37.0", "torch>=2.0.0", "accelerate>=0.27.0", "bitsandbytes>=0.41.0", "einops>=0.8.0", "Pillow>=10.4.0", "pyvips-binary>=8.16.0", "pyvips>=2.2.3", "torchvision>=0.18.1"]
50 | openai = ["openai>=1.58.1"]
51 | gemini = ["google-generativeai>=0.8.3", "google-genai==0.3.0"]
52 | crewai = [
53 | "crewai>=0.1.0",
54 | "langchain>=0.1.0",
55 | "langchain-openai>=0.0.2",
56 | "crewai-tools>=0.0.1"
57 | ]
58 | anthropic = ["anthropic>=1.0.0"]
59 | lumaai = ["lumaai>=0.0.3"]
60 | xai = ["openai>=1.58.1"]
61 | cohere = ["cohere>=5.0.0"]
62 | sambanova = ["openai>=1.58.1"]
63 | hyperbolic = ["openai>=1.58.1"]
64 | qwen = [
65 | "openai>=1.58.1",
66 | ]
67 | browser = [
68 | "browser-use>=0.1.16",
69 | "playwright>=1.49.1"
70 | ]
71 | swarms = [
72 | "swarms>=6.8.9,<7.0.0",
73 | "swarm-models>=0.3.0,<0.4.0",
74 | "swarms-memory>=0.1.2,<0.2.0",
75 | "langchain>=0.1.0",
76 | "langchain-community>=0.0.10",
77 | "pydantic>=2.0.0,<3.0.0"
78 | ]
79 | kokoro = [
80 | "kokoro-onnx>=0.3.3",
81 | "soundfile>=0.13.0",
82 | "huggingface-hub>=0.27.1"
83 | ]
84 | all = [
85 | "openai>=1.58.1",
86 | "google-generativeai",
87 | "crewai>=0.1.0",
88 | "langchain>=0.1.0",
89 | "langchain-openai>=0.0.2",
90 | "crewai-tools>=0.0.1",
91 | "anthropic>=1.0.0",
92 | "lumaai>=0.0.3",
93 | "cohere>=5.0.0",
94 | "sambanova>=0.0.1",
95 | "hyperbolic>=1.58.1",
96 | "groq>=0.3.0",
97 | "browser-use>=0.1.0",
98 | "swarms>=6.8.9,<7.0.0",
99 | "swarm-models>=0.3.0,<0.4.0",
100 | "swarms-memory>=0.1.2,<0.2.0",
101 | "langchain>=0.1.0",
102 | "pydantic>=2.0.0,<3.0.0",
103 | "langchain-openai>=0.0.2",
104 | "langchain-community>=0.0.10",
105 | "langchain-core>=0.1.0",
106 | "tavily-python>=0.3.0",
107 | "requests>=2.31.0",
108 | "kokoro-onnx>=0.1.0",
109 | "soundfile>=0.12.0",
110 | "huggingface-hub>=0.20.0"
111 | ]
112 | groq = ["openai>=1.58.1"]
113 | fireworks = ["openai>=1.58.1"]
114 | together = ["openai>=1.58.1"]
115 | deepseek = ["openai>=1.58.1"]
116 | smolagents = ["smolagents>=0.1.3"]
117 | jupyter = [
118 | "nbformat>=5.9.0",
119 | "nbconvert>=7.16.0",
120 | "e2b-code-interpreter>=0.5.0",
121 | "huggingface-hub>=0.20.0",
122 | ]
123 |
124 | langchain = [
125 | "langchain",
126 | "langchain-community",
127 | "langchain-core",
128 | "tavily-python",
129 | "langchain-openai"
130 | ]
131 |
132 | mistral = [
133 | "mistralai"
134 | ]
135 |
136 | nvidia = [
137 | "openai>=1.58.1"
138 | ]
139 |
140 | minimax = ["requests>=2.31.0"]
141 |
142 | perplexity = [
143 | "openai>=1.58.1"
144 | ]
145 |
146 | replicate = [
147 | "replicate>=1.0.4"
148 | ]
149 |
150 | [tool.hatch.build.targets.wheel]
151 | packages = ["ai_gradio"]
152 |
153 |
--------------------------------------------------------------------------------
/ai_gradio/providers/langchain_gradio.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 | from langchain.agents import create_tool_calling_agent, AgentExecutor
3 | from langchain_community.tools.tavily_search import TavilySearchResults
4 | from langchain import hub
5 | from langchain_core.messages import AIMessage, HumanMessage
6 | from langchain_community.chat_message_histories import ChatMessageHistory
7 | from langchain_core.runnables.history import RunnableWithMessageHistory
8 | import gradio as gr
9 | from typing import Generator, List, Dict
10 |
11 | def create_agent(model_name: str = None):
12 | # Initialize search tool
13 | search = TavilySearchResults()
14 | tools = [search]
15 |
16 | # Initialize LLM
17 | llm = ChatOpenAI(
18 | model_name=model_name if model_name else "gpt-3.5-turbo-0125",
19 | temperature=0
20 | )
21 |
22 | # Get the prompt
23 | prompt = hub.pull("hwchase17/openai-functions-agent")
24 |
25 | # Create the agent
26 | agent = create_tool_calling_agent(llm, tools, prompt)
27 |
28 | # Create the executor
29 | return AgentExecutor(agent=agent, tools=tools, verbose=True)
30 |
31 | def stream_agent_response(agent: AgentExecutor, message: str, history: List) -> Generator[Dict, None, None]:
32 | # First yield the thinking message
33 | yield {
34 | "role": "assistant",
35 | "content": "Let me think about that...",
36 | "metadata": {"title": "🤔 Thinking"}
37 | }
38 |
39 | try:
40 | # Convert history to LangChain format
41 | chat_history = []
42 | for msg in history:
43 | if msg["role"] == "user":
44 | chat_history.append(HumanMessage(content=msg["content"]))
45 | elif msg["role"] == "assistant":
46 | chat_history.append(AIMessage(content=msg["content"]))
47 |
48 | # Run the agent
49 | response = agent.invoke({
50 | "input": message,
51 | "chat_history": chat_history
52 | })
53 |
54 | # Yield the final response
55 | yield {
56 | "role": "assistant",
57 | "content": response["output"]
58 | }
59 |
60 | except Exception as e:
61 | yield {
62 | "role": "assistant",
63 | "content": f"Error: {str(e)}",
64 | "metadata": {"title": "❌ Error"}
65 | }
66 |
67 | async def interact_with_agent(message: str, history: List, model_name: str = None) -> Generator[List, None, None]:
68 | # Add user message
69 | history.append({"role": "user", "content": message})
70 | yield history
71 |
72 | # Create agent instance with specified model
73 | agent = create_agent(model_name)
74 |
75 | # Stream agent responses
76 | for response in stream_agent_response(agent, message, history):
77 | history.append(response)
78 | yield history
79 |
80 | def registry(name: str, **kwargs):
81 | # Extract model name from the name parameter
82 | model_name = name.split(':')[-1] if ':' in name else None
83 |
84 | with gr.Blocks() as demo:
85 | gr.Markdown("# LangChain Assistant 🦜️")
86 |
87 | chatbot = gr.Chatbot(
88 | type="messages",
89 | label="Agent",
90 | avatar_images=(None, "https://python.langchain.com/img/favicon.ico"),
91 | height=500
92 | )
93 |
94 | msg = gr.Textbox(
95 | label="Your message",
96 | placeholder="Type your message here...",
97 | lines=1
98 | )
99 |
100 | async def handle_message(message, history):
101 | async for response in interact_with_agent(message, history, model_name=model_name):
102 | yield response
103 |
104 | msg.submit(
105 | fn=handle_message,
106 | inputs=[msg, chatbot],
107 | outputs=[chatbot],
108 | api_name="predict"
109 | ).then(lambda _:"", msg, msg)
110 |
111 | return demo
112 |
--------------------------------------------------------------------------------
/ai_gradio/providers/xai_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | from openai import OpenAI
4 | import gradio as gr
5 | from typing import Callable
6 |
7 | __version__ = "0.0.1"
8 |
9 | def get_image_base64(url: str, ext: str):
10 | with open(url, "rb") as image_file:
11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
12 | return "data:image/" + ext + ";base64," + encoded_string
13 |
14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
15 | def fn(message, history):
16 | inputs = preprocess(message, history)
17 | client = OpenAI(
18 | api_key=api_key,
19 | base_url="https://api.x.ai/v1",
20 | )
21 | completion = client.chat.completions.create(
22 | model=model_name,
23 | messages=inputs["messages"],
24 | stream=True,
25 | )
26 | response_text = ""
27 | for chunk in completion:
28 | delta = chunk.choices[0].delta.content or ""
29 | response_text += delta
30 | yield postprocess(response_text)
31 |
32 | return fn
33 |
34 | def handle_user_msg(message: str):
35 | if type(message) is str:
36 | return message
37 | elif type(message) is dict:
38 | if message["files"] is not None and len(message["files"]) > 0:
39 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
40 | if ext.lower() in ["png", "jpg", "jpeg", "gif"]:
41 | encoded_str = get_image_base64(message["files"][-1], ext)
42 | else:
43 | raise NotImplementedError(f"Not supported file type {ext}")
44 | content = [
45 | {"type": "text", "text": message["text"]},
46 | {
47 | "type": "image_url",
48 | "image_url": {
49 | "url": encoded_str,
50 | }
51 | },
52 | ]
53 | else:
54 | content = message["text"]
55 | return content
56 | else:
57 | raise NotImplementedError
58 |
59 | def get_interface_args(model_name: str):
60 | inputs = None
61 | outputs = None
62 |
63 | def preprocess(message, history):
 64 |         messages = [{"role": "system", "content": "You are Grok, a chatbot inspired by the Hitchhiker's Guide to the Galaxy."}]
65 | files = None
66 | for user_msg, assistant_msg in history:
67 | if assistant_msg is not None:
68 | messages.append({"role": "user", "content": handle_user_msg(user_msg)})
69 | messages.append({"role": "assistant", "content": assistant_msg})
70 | else:
71 | files = user_msg
72 |
73 | if type(message) is str and files is not None:
74 | message = {"text": message, "files": files}
75 | elif type(message) is dict and files is not None:
76 | if message["files"] is None or len(message["files"]) == 0:
77 | message["files"] = files
78 |
79 | messages.append({"role": "user", "content": handle_user_msg(message)})
80 | return {"messages": messages}
81 |
82 | postprocess = lambda x: x # No post-processing needed
83 | return inputs, outputs, preprocess, postprocess
84 |
85 | def registry(name: str = "grok-beta", token: str | None = None, **kwargs):
86 | """
87 | Create a Gradio Interface for X.AI's Grok model.
88 |
89 | Parameters:
 90 |     - name (str): The model name (defaults to "grok-beta"; use "grok-vision-beta" for image input)
91 | - token (str, optional): The X.AI API key
92 | """
93 | api_key = token or os.environ.get("XAI_API_KEY")
94 | if not api_key:
95 | raise ValueError("XAI_API_KEY environment variable is not set.")
96 |
97 | inputs, outputs, preprocess, postprocess = get_interface_args(name)
98 | fn = get_fn(name, preprocess, postprocess, api_key)
99 |
100 | # Always set multimodal=True
101 | kwargs["multimodal"] = True
102 |
103 | interface = gr.ChatInterface(fn=fn, **kwargs)
104 |
105 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/cohere_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cohere
3 | import gradio as gr
4 | from typing import Callable
5 | import base64
6 |
7 | __version__ = "0.0.3"
8 |
9 |
10 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
11 | def fn(message, history):
12 | inputs = preprocess(message, history)
13 | client = cohere.Client(api_key=api_key)
14 | stream = client.chat_stream(
15 | model=model_name,
16 | message=inputs["message"],
17 | chat_history=inputs["chat_history"],
18 | )
19 | response_text = ""
20 | for chunk in stream:
21 | if chunk.event_type == "text-generation":
22 | delta = chunk.text
23 | response_text += delta
24 | yield postprocess(response_text)
25 |
26 | return fn
27 |
28 |
29 | def get_image_base64(url: str, ext: str):
30 | with open(url, "rb") as image_file:
31 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
32 | return "data:image/" + ext + ";base64," + encoded_string
33 |
34 |
35 | def handle_user_msg(message: str):
36 | if type(message) is str:
37 | return message
38 | elif type(message) is dict:
39 | if message["files"] is not None and len(message["files"]) > 0:
40 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
41 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
42 | encoded_str = get_image_base64(message["files"][-1], ext)
43 | else:
44 | raise NotImplementedError(f"Not supported file type {ext}")
45 | return {
46 | "message": message["text"],
47 | "attachments": [
48 | {
49 | "source": encoded_str,
50 | "type": "image",
51 | }
52 | ]
53 | }
54 | else:
55 | return message["text"]
56 | else:
57 | raise NotImplementedError
58 |
59 |
60 | def get_interface_args(pipeline):
61 | if pipeline == "chat":
62 | inputs = None
63 | outputs = None
64 |
65 | def preprocess(message, history):
66 | chat_history = []
67 | files = None
68 | for user_msg, assistant_msg in history:
69 | if assistant_msg is not None:
70 | chat_history.append({"role": "USER", "message": handle_user_msg(user_msg)})
71 | chat_history.append({"role": "ASSISTANT", "message": assistant_msg})
72 | else:
73 | files = user_msg
74 | if type(message) is str and files is not None:
75 | message = {"text": message, "files": files}
76 | elif type(message) is dict and files is not None:
77 | if message["files"] is None or len(message["files"]) == 0:
78 | message["files"] = files
79 |
80 | return {
81 | "message": handle_user_msg(message),
82 | "chat_history": chat_history
83 | }
84 |
85 | postprocess = lambda x: x
86 | else:
 87 |         # Add other pipeline types when needed
88 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
89 | return inputs, outputs, preprocess, postprocess
90 |
91 |
92 | def get_pipeline(model_name):
93 | # Determine the pipeline type based on the model name
94 | # For simplicity, assuming all models are chat models at the moment
95 | return "chat"
96 |
97 |
98 | def registry(name: str, token: str | None = None, **kwargs):
99 | """
100 | Create a Gradio Interface for a model on Cohere.
101 |
102 | Parameters:
103 | - name (str): The name of the Cohere model.
104 | - token (str, optional): The API key for Cohere.
105 | """
106 | api_key = token or os.environ.get("COHERE_API_KEY")
107 | if not api_key:
108 | raise ValueError("COHERE_API_KEY environment variable is not set.")
109 |
110 | pipeline = get_pipeline(name)
111 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
112 | fn = get_fn(name, preprocess, postprocess, api_key)
113 |
114 | if pipeline == "chat":
115 | interface = gr.ChatInterface(fn=fn, multimodal=True, **kwargs)
116 | else:
117 | # For other pipelines, create a standard Interface (not implemented yet)
118 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
119 |
120 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/together_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | from huggingface_hub import InferenceClient
4 | import gradio as gr
5 | from typing import Callable
6 |
7 | __version__ = "0.0.1"
8 |
9 | def get_image_base64(url: str, ext: str):
10 | with open(url, "rb") as image_file:
11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
12 | return "data:image/" + ext + ";base64," + encoded_string
13 |
14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
15 | def fn(message, history):
16 | inputs = preprocess(message, history)
17 | client = InferenceClient(
18 | provider="together",
19 | token=api_key
20 | )
21 | try:
22 | completion = client.chat.completions.create(
23 | model=model_name,
24 | messages=inputs["messages"],
25 | stream=True,
26 | max_tokens=1000
27 | )
28 |
29 | partial_message = ""
30 | for chunk in completion:
31 | if chunk.choices:
32 | delta = chunk.choices[0].delta.content or ""
 33 |                     delta = delta.replace("<think>", "[think]").replace("</think>", "[/think]")
34 | partial_message += delta
35 | yield postprocess(partial_message)
36 |
37 | except Exception as e:
38 | error_message = f"Error: {str(e)}"
39 | yield error_message
40 |
41 | return fn
42 |
43 | def handle_user_msg(message: str):
44 | if type(message) is str:
45 | return message
46 | elif type(message) is dict:
47 | if message["files"] is not None and len(message["files"]) > 0:
48 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
49 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
50 | encoded_str = get_image_base64(message["files"][-1], ext)
51 | else:
52 | raise NotImplementedError(f"Not supported file type {ext}")
53 | content = [
54 | {"type": "text", "text": message["text"]},
55 | {
56 | "type": "image_url",
57 | "image_url": {
58 | "url": encoded_str,
59 | }
60 | },
61 | ]
62 | else:
63 | content = message["text"]
64 | return content
65 | else:
66 | raise NotImplementedError
67 |
68 | def get_interface_args(pipeline):
69 | if pipeline == "chat":
70 | inputs = None
71 | outputs = None
72 |
73 | def preprocess(message, history):
74 | messages = []
75 | # Process history first
76 | for user_msg, assistant_msg in history:
77 | messages.append({"role": "user", "content": str(user_msg)})
78 | if assistant_msg is not None:
79 | messages.append({"role": "assistant", "content": str(assistant_msg)})
80 |
81 | # Add current message
82 | messages.append({"role": "user", "content": str(message)})
83 | return {"messages": messages}
84 |
85 | postprocess = lambda x: x # No post-processing needed
86 | else:
87 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
88 | return inputs, outputs, preprocess, postprocess
89 |
90 |
91 | def get_pipeline(model_name):
92 | # Determine the pipeline type based on the model name
93 | # For simplicity, assuming all models are chat models at the moment
94 | return "chat"
95 |
96 |
97 | def registry(name: str, token: str | None = None, **kwargs):
98 | """
99 | Create a Gradio Interface for a model on Together.
100 |
101 | Parameters:
102 | - name (str): The name of the model on Together.
103 |     - token (str, optional): A Hugging Face token (falls back to the HF_TOKEN environment variable); requests are routed to Together via the huggingface_hub InferenceClient.
104 | """
105 | api_key = token or os.environ.get("HF_TOKEN")
106 | if not api_key:
107 | raise ValueError("HF_TOKEN environment variable is not set.")
108 |
109 | pipeline = get_pipeline(name)
110 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
111 | fn = get_fn(name, preprocess, postprocess, api_key)
112 |
113 | if pipeline == "chat":
114 | interface = gr.ChatInterface(fn=fn, **kwargs)
115 | else:
116 | # For other pipelines, create a standard Interface (not implemented yet)
117 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
118 |
119 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/sambanova_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | from openai import OpenAI
4 | import gradio as gr
5 | from typing import Callable
6 |
7 | __version__ = "0.0.1"
8 |
9 | def get_image_base64(url: str, ext: str):
10 | with open(url, "rb") as image_file:
11 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
12 | return "data:image/" + ext + ";base64," + encoded_string
13 |
14 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
15 | def fn(message, history):
16 | inputs = preprocess(message, history)
17 | client = OpenAI(
18 | base_url="https://api.sambanova.ai/v1/",
19 | api_key=api_key,
20 | )
21 | try:
22 | completion = client.chat.completions.create(
23 | model=model_name,
24 | messages=inputs["messages"],
25 | stream=True,
26 | )
27 | response_text = ""
28 | for chunk in completion:
29 | delta = chunk.choices[0].delta.content or ""
30 | response_text += delta
31 | yield postprocess(response_text)
32 | except Exception as e:
33 | error_message = f"Error: {str(e)}"
 34 |             yield error_message  # yield (not return) so the error is shown in the chat
35 |
36 | return fn
37 |
38 | def handle_user_msg(message: str):
39 | if type(message) is str:
40 | return message
41 | elif type(message) is dict:
42 | if message["files"] is not None and len(message["files"]) > 0:
43 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
44 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
45 | encoded_str = get_image_base64(message["files"][-1], ext)
46 | else:
47 | raise NotImplementedError(f"Not supported file type {ext}")
48 | content = [
49 | {"type": "text", "text": message["text"]},
50 | {
51 | "type": "image_url",
52 | "image_url": {
53 | "url": encoded_str,
54 | }
55 | },
56 | ]
57 | else:
58 | content = message["text"]
59 | return content
60 | else:
61 | raise NotImplementedError
62 |
63 | def get_interface_args(pipeline):
64 | if pipeline == "chat":
65 | inputs = None
66 | outputs = None
67 |
68 | def preprocess(message, history):
69 | messages = []
70 | files = None
71 | for user_msg, assistant_msg in history:
72 | if assistant_msg is not None:
73 | messages.append({"role": "user", "content": handle_user_msg(user_msg)})
74 | messages.append({"role": "assistant", "content": assistant_msg})
75 | else:
76 | files = user_msg
77 | if type(message) is str and files is not None:
78 | message = {"text":message, "files":files}
79 | elif type(message) is dict and files is not None:
80 | if message["files"] is None or len(message["files"]) == 0:
81 | message["files"] = files
82 | messages.append({"role": "user", "content": handle_user_msg(message)})
83 | return {"messages": messages}
84 |
85 | postprocess = lambda x: x # No post-processing needed
86 | else:
 87 |         # Add other pipeline types when needed
88 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
89 | return inputs, outputs, preprocess, postprocess
90 |
91 |
92 | def get_pipeline(model_name):
93 | # Determine the pipeline type based on the model name
94 | # For simplicity, assuming all models are chat models at the moment
95 | return "chat"
96 |
97 |
98 | def registry(name: str, token: str | None = None, **kwargs):
99 | """
100 | Create a Gradio Interface for a model on Sambanova.
101 |
102 | Parameters:
103 | - name (str): The name of the model on Sambanova.
104 | - token (str, optional): The API key for Sambanova.
105 | """
106 | api_key = token or os.environ.get("SAMBANOVA_API_KEY")
107 | if not api_key:
108 | raise ValueError("SAMBANOVA_API_KEY environment variable is not set.")
109 |
110 | pipeline = get_pipeline(name)
111 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
112 | fn = get_fn(name, preprocess, postprocess, api_key)
113 |
114 | if pipeline == "chat":
115 | interface = gr.ChatInterface(fn=fn, **kwargs)
116 | else:
117 | # For other pipelines, create a standard Interface (not implemented yet)
118 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
119 |
120 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/smolagents_gradio.py:
--------------------------------------------------------------------------------
1 | from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel
2 | from smolagents.agents import ActionStep
3 | import gradio as gr
4 | from typing import Generator, List, Dict
5 |
6 | def create_agent(model_name: str = None):
7 | model = HfApiModel(model_name) if model_name else HfApiModel()
8 | return CodeAgent(tools=[DuckDuckGoSearchTool()], model=model)
9 |
10 | def stream_agent_response(agent: CodeAgent, prompt: str) -> Generator[Dict, None, None]:
11 | # First yield the thinking message
12 | yield {
13 | "role": "assistant",
14 | "content": "Let me think about that...",
15 | "metadata": {"title": "🤔 Thinking"}
16 | }
17 |
18 | # Run the agent and capture its response
19 | try:
20 | # Get the agent's response
21 | for step in agent.run(prompt, stream=True):
22 | if isinstance(step, ActionStep):
23 | # Show LLM output if present (as collapsible thought)
24 | if step.llm_output:
25 | yield {
26 | "role": "assistant",
27 | "content": step.llm_output,
28 | "metadata": {"title": "🧠 Thought Process"}
29 | }
30 |
31 | # Show tool call if present
32 | if step.tool_call:
33 | content = step.tool_call.arguments
34 | if step.tool_call.name == "python_interpreter":
35 | content = f"```python\n{content}\n```"
36 | yield {
37 | "role": "assistant",
38 | "content": str(content),
39 | "metadata": {"title": f"🛠️ Using {step.tool_call.name}"}
40 | }
41 |
42 | # Show observations if present
43 | if step.observations:
44 | yield {
45 | "role": "assistant",
46 | "content": f"```\n{step.observations}\n```",
47 | "metadata": {"title": "👁️ Observations"}
48 | }
49 |
50 | # Show errors if present
51 | if step.error:
52 | yield {
53 | "role": "assistant",
54 | "content": str(step.error),
55 | "metadata": {"title": "❌ Error"}
56 | }
57 |
58 | # Show final output if present (without metadata to keep it expanded)
 59 |                 if step.action_output is not None and not step.error:
 60 |                     # Surface the step's action output without a metadata
 61 |                     # title so it renders expanded in the chat
 62 |                     yield {
 63 |                         "role": "assistant",
 64 |                         "content": str(step.action_output)
 65 |                     }
66 | else:
67 | # For any other type of step output
68 | yield {
69 | "role": "assistant",
70 | "content": str(step),
71 | "metadata": {"title": "🔄 Processing"}
72 | }
73 |
74 | except Exception as e:
75 | yield {
76 | "role": "assistant",
77 | "content": f"Error: {str(e)}",
78 | "metadata": {"title": "❌ Error"}
79 | }
80 |
81 | async def interact_with_agent(message: str, history: List, model_name: str = None) -> Generator[List, None, None]:
82 | # Add user message
83 | history.append({"role": "user", "content": message})
84 | yield history
85 |
86 | # Create agent instance with specified model
87 | agent = create_agent(model_name)
88 |
89 | # Stream agent responses
90 | for response in stream_agent_response(agent, message):
91 | history.append(response)
92 | yield history
93 |
94 | def registry(name: str, **kwargs):
95 | # Extract model name from the name parameter
96 | model_name = name.split(':')[-1] if ':' in name else None
97 |
98 | with gr.Blocks() as demo:
99 | gr.Markdown("# SmolagentsAI Assistant 🤖")
100 |
101 | chatbot = gr.Chatbot(
102 | type="messages",
103 | label="Agent",
104 | avatar_images=(None, "https://cdn-lfs.hf.co/repos/96/a2/96a2c8468c1546e660ac2609e49404b8588fcf5a748761fa72c154b2836b4c83/9cf16f4f32604eaf76dabbdf47701eea5a768ebcc7296acc1d1758181f71db73?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27hf-logo.png%3B+filename%3D%22hf-logo.png%22%3B&response-content-type=image%2Fpng&Expires=1735927745&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczNTkyNzc0NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy85Ni9hMi85NmEyYzg0NjhjMTU0NmU2NjBhYzI2MDllNDk0MDRiODU4OGZjZjVhNzQ4NzYxZmE3MmMxNTRiMjgzNmI0YzgzLzljZjE2ZjRmMzI2MDRlYWY3NmRhYmJkZjQ3NzAxZWVhNWE3NjhlYmNjNzI5NmFjYzFkMTc1ODE4MWY3MWRiNzM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=a%7EGHhtB%7EWzd7pjl5F5wJHVxQPymZEQcgcAoc5FkJggZcWPDrcZ82CtDpIQafi0CultDPMy8SiZgHA0PDDk31c5a6wzdIbsbm7zZ5NvGTTlZpXskL3x7Gbr-f2E3yOA%7EHR%7E2heEJlpim78-xLqkWA92CYo-tKLg-yHKMx0acQcBvhptHOZtwlb9%7EyHlqlzNpcLo4iqEgEH39ADRNhpkf54-Zj6SQNBod7AkjFA3-iIzX5LVzW6EEYyFs03Ba0AfBUODgZIt8cjglULQ2a02rgiM%7EjKMBmB2eKNDFtvoe7YSGlFbVcLt21pWjhzA-z9MgQsw-U3ZDY539iHkMMkfoQzA__&Key-Pair-Id=K3RPWS32NSSJCE"),
105 | height=500
106 | )
107 |
108 | msg = gr.Textbox(
109 | label="Your message",
110 | placeholder="Type your message here...",
111 | lines=1
112 | )
113 |
114 | # Make the wrapper function async and await the generator
115 | async def handle_message(message, history):
116 | async for response in interact_with_agent(message, history, model_name=model_name):
117 | yield response
118 |
119 | msg.submit(
120 | fn=handle_message,
121 | inputs=[msg, chatbot],
122 | outputs=[chatbot],
123 | api_name="predict"
124 | )
125 |
126 | return demo
--------------------------------------------------------------------------------
/ai_gradio/providers/transformers_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3 | import torch
4 | import gradio as gr
5 | from typing import Callable
6 | from PIL import Image
7 |
8 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, **kwargs):
9 | device = "cuda" if torch.cuda.is_available() else "cpu"
10 |
11 | # Get the full model path
12 | if "/" in model_name:
13 | model_path = model_name
14 | else:
15 | model_mapping = {
16 | "tulu-3": "allenai/llama-tulu-3-8b",
17 | "olmo-2-13b": "allenai/OLMo-2-1124-13B-Instruct",
18 | "smolvlm": "HuggingFaceTB/SmolVLM-Instruct",
19 | "phi-4": "microsoft/phi-4",
20 | "moondream": "vikhyatk/moondream2",
21 | }
22 | model_path = model_mapping.get(model_name)
23 | if not model_path:
24 | raise ValueError(f"Unknown model name: {model_name}")
25 |
26 | # Load model and tokenizer
27 | tokenizer = AutoTokenizer.from_pretrained(model_path)
28 |
29 | if model_name == "moondream":
30 | model = AutoModelForCausalLM.from_pretrained(
31 | model_path,
32 | revision="2025-01-09",
33 | trust_remote_code=True,
34 | device_map={"": "cuda" if torch.cuda.is_available() else "cpu"}
35 | )
36 | elif device == "cuda":
37 | model = AutoModelForCausalLM.from_pretrained(
38 | model_path,
39 | device_map="auto",
40 | torch_dtype=torch.float16
41 | )
42 | else:
43 | model = AutoModelForCausalLM.from_pretrained(
44 | model_path,
45 | device_map="auto",
46 | torch_dtype=torch.float32
47 | )
48 |
49 | def predict(message, history, temperature=0.7, max_tokens=512, image=None):
50 | # Create a new list for history to avoid sharing between sessions
51 | history = list(history) if history else []
52 |
53 | if model_name == "moondream":
54 | if isinstance(message, dict):
55 | text = message["text"]
56 | # Get the first image from files if available
57 | files = message.get("files", [])
58 | image = files[0] if files else image # Use provided image if no files
59 |
60 | if image is not None:
61 | # Ensure image is a PIL Image
62 | if not isinstance(image, Image.Image):
63 | try:
64 | image = Image.open(image)
65 | except Exception as e:
66 | return f"Error processing image: {str(e)}"
67 |
68 | # Generate response and return as a string
69 | response = model.query(image, text)["answer"]
70 | return response
71 | else:
72 | return "Please provide an image to analyze."
73 |
74 | # Format conversation history
75 | if isinstance(message, dict):
76 | message = message["text"]
77 |
78 | messages = []
79 | for user_msg, assistant_msg in history:
80 | messages.append({"role": "user", "content": user_msg})
81 | messages.append({"role": "assistant", "content": assistant_msg})
82 | messages.append({"role": "user", "content": message})
83 |
84 | # Convert to model format
85 | input_text = tokenizer.apply_chat_template(messages, tokenize=False)
86 | inputs = tokenizer(input_text, return_tensors="pt").to(device)
87 |
88 | # Generate response
89 | generation_config = {
90 | "input_ids": inputs["input_ids"],
91 | "max_new_tokens": max_tokens,
92 | "temperature": float(temperature), # Ensure temperature is a float
93 | "do_sample": True,
94 | "pad_token_id": tokenizer.eos_token_id
95 | }
96 |
97 | outputs = model.generate(**generation_config)
98 |
99 | # For phi-4, extract only the new generated text
100 | if model_name == "phi-4":
101 | input_length = inputs["input_ids"].shape[1]
102 | generated_tokens = outputs[0][input_length:]
103 | response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
104 | else:
105 | response = tokenizer.decode(outputs[0], skip_special_tokens=True)
106 |
107 | return response
108 |
109 | return predict
110 |
111 | def get_interface_args(pipeline):
112 | if pipeline == "chat":
113 | def preprocess(message, history):
114 | return {"message": message, "history": history}
115 |
116 | def postprocess(response):
117 | return response
118 |
119 | return None, None, preprocess, postprocess
120 | elif pipeline == "vision-chat":
121 | def preprocess(message, history):
122 | return {"message": message, "history": history}
123 |
124 | def postprocess(response):
125 | return response
126 |
127 | return [gr.Textbox(label="Message"), gr.Image(type="pil")], None, preprocess, postprocess
128 | else:
129 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
130 |
131 | def get_pipeline(model_name):
132 | if model_name == "moondream":
133 | return "vision-chat"
134 | return "chat"
135 |
136 | def registry(name: str = None, **kwargs):
137 | pipeline = get_pipeline(name)
138 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
139 | fn = get_fn(name, preprocess, postprocess, **kwargs)
140 |
141 | if pipeline == "vision-chat":
142 | interface = gr.ChatInterface(
143 | fn=fn,
144 | additional_inputs=[
145 | gr.Slider(0, 1, 0.7, label="Temperature"),
146 | gr.Slider(1, 2048, 512, label="Max tokens"),
147 | ],
148 | multimodal=True # Enable multimodal input
149 | )
150 | else:
151 | interface = gr.ChatInterface(
152 | fn=fn,
153 | additional_inputs=[
154 | gr.Slider(0, 1, 0.7, label="Temperature"),
155 | gr.Slider(1, 2048, 512, label="Max tokens"),
156 | ]
157 | )
158 |
159 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/lumaai_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from lumaai import LumaAI
3 | import gradio as gr
4 | from typing import Callable
5 | import base64
6 | import time
7 |
8 | __version__ = "0.0.3"
9 |
10 |
11 | def get_fn(preprocess: Callable, postprocess: Callable, api_key: str, pipeline: str):
12 | def fn(message, history, generation_type):
13 | try:
14 | inputs = preprocess(message, history)
15 | # Create a fresh client instance for each generation
16 | client = LumaAI(auth_token=api_key)
17 |
18 | # Validate generation type
19 | if generation_type not in ["video", "image"]:
20 | raise ValueError(f"Invalid generation type: {generation_type}")
21 |
22 | try:
23 | if generation_type == "video":
24 | generation = client.generations.create(
25 | prompt=inputs["prompt"],
26 | **inputs.get("additional_params", {})
27 | )
28 | else: # image
29 | generation = client.generations.image.create(
30 | prompt=inputs["prompt"],
31 | **inputs.get("additional_params", {})
32 | )
33 | except Exception as e:
34 | raise RuntimeError(f"Failed to create generation: {str(e)}")
35 |
36 | # Poll for completion with timeout
37 | start_time = time.time()
38 | timeout = 300 # 5 minutes timeout
39 |
40 | while True:
41 | if time.time() - start_time > timeout:
42 | raise RuntimeError("Generation timed out after 5 minutes")
43 |
44 | try:
45 | generation = client.generations.get(id=generation.id)
46 | if generation.state == "completed":
47 | asset_url = generation.assets.video if generation_type == "video" else generation.assets.image
48 | break
49 | elif generation.state == "failed":
50 | raise RuntimeError(f"Generation failed: {generation.failure_reason}")
51 | time.sleep(3)
52 | yield f"Generating {generation_type}... (Status: {generation.state})"
53 | except Exception as e:
54 | raise RuntimeError(f"Error checking generation status: {str(e)}")
55 |
56 | # Return asset URL wrapped in appropriate format for display
57 | yield postprocess(asset_url, generation_type)
58 |
59 | except Exception as e:
60 | yield f"Error: {str(e)}"
61 | raise
62 |
63 | return fn
64 |
65 |
66 | def get_image_base64(url: str, ext: str):
67 | with open(url, "rb") as image_file:
68 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
69 | return "data:image/" + ext + ";base64," + encoded_string
70 |
71 |
72 | def handle_user_msg(message: str):
73 | if type(message) is str:
74 | return message
75 | elif type(message) is dict:
76 | if message["files"] is not None and len(message["files"]) > 0:
77 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
78 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
79 | encoded_str = get_image_base64(message["files"][-1], ext)
80 | else:
81 | raise NotImplementedError(f"Not supported file type {ext}")
82 | content = [
83 | {"type": "text", "text": message["text"]},
84 | {
85 | "type": "image_url",
86 | "image_url": {
87 | "url": encoded_str,
88 | }
89 | },
90 | ]
91 | else:
92 | content = message["text"]
93 | return content
94 | else:
95 | raise NotImplementedError
96 |
97 |
98 | def get_interface_args(pipeline):
99 | generation_type = gr.Dropdown(
100 | choices=["video", "image"],
101 | label="Generation Type",
102 | value="video" if pipeline == "video" else "image"
103 | )
104 | outputs = None
105 |
106 | def preprocess(message, history):
107 | if isinstance(message, str):
108 | return {"prompt": message}
109 | elif isinstance(message, dict):
110 | prompt = message.get("text", "")
111 | additional_params = {}
112 |
113 | # Handle optional parameters
114 | if message.get("aspect_ratio"):
115 | additional_params["aspect_ratio"] = message["aspect_ratio"]
116 | if message.get("model"):
117 | additional_params["model"] = message["model"]
118 |
119 | return {
120 | "prompt": prompt,
121 | "additional_params": additional_params
122 | }
123 |
124 |     def postprocess(url, generation_type):
125 |         if generation_type == "video":
126 |             return f'<video src="{url}" controls autoplay></video>'
127 |         else:
128 |             return f'<img src="{url}" alt="Generated image">'
129 |
130 | return generation_type, outputs, preprocess, postprocess
131 |
132 |
133 | def get_pipeline(model_name):
134 | # Support both video and image pipelines
135 | return "video" if model_name == "dream-machine" else "image"
136 |
137 |
138 | def registry(name: str = "dream-machine", token: str = None, **kwargs):
139 | """
140 | Create a Gradio Interface for LumaAI generation.
141 |
142 | Parameters:
143 | - name (str): Model name (defaults to 'dream-machine' for video, use 'photon-1' or 'photon-flash-1' for images)
144 | - token (str, optional): The API key for LumaAI
145 | - **kwargs: Additional keyword arguments passed to gr.Interface
146 | """
147 | api_key = token or kwargs.pop('api_key', None) or os.environ.get("LUMAAI_API_KEY")
148 | if not api_key:
149 | raise ValueError("API key must be provided either through token parameter, kwargs, or LUMAAI_API_KEY environment variable.")
150 |
151 | pipeline = get_pipeline(name)
152 | generation_type, outputs, preprocess, postprocess = get_interface_args(pipeline)
153 | fn = get_fn(preprocess, postprocess, api_key, pipeline)
154 |
155 | interface = gr.ChatInterface(
156 | fn=fn,
157 | additional_inputs=[generation_type],
158 | type="messages",
159 | title="LumaAI Generation",
160 | description="Generate videos or images from text prompts using LumaAI",
161 | **kwargs
162 | )
163 |
164 | return interface
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ai-gradio
2 |
3 | A Python package that makes it easy for developers to create machine learning apps powered by various AI providers. Built on top of Gradio, it provides a unified interface for multiple AI models and services.
4 |
5 | ## Features
6 |
7 | ### Core Features
8 | - **Multi-Provider Support**: Integrate with 15+ AI providers including OpenAI, Google Gemini, Anthropic, and more
9 | - **Text Chat**: Interactive chat interfaces for all text models
10 | - **Voice Chat**: Real-time voice interactions with OpenAI models
11 | - **Video Chat**: Video processing capabilities with Gemini models
12 | - **Code Generation**: Specialized interfaces for coding assistance
13 | - **Multi-Modal**: Support for text, image, and video inputs
14 | - **Agent Teams**: CrewAI integration for collaborative AI tasks
15 | - **Browser Automation**: AI agents that can perform web-based tasks (see the sketch just below)
16 |
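Browser automation can be tried with a couple of lines; a minimal sketch, assuming the `browser` extra is installed, `OPENAI_API_KEY` is set, and that "gpt-4o" is an acceptable model choice here:

```python
from ai_gradio.providers import browser_use_gradio

# The name passed to registry() is the OpenAI model that drives the browsing agent
demo = browser_use_gradio.registry("gpt-4o")
demo.launch()
```
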
17 | ### Model Support
18 |
19 | #### Core Language Models
20 | | Provider | Models |
21 | |----------|---------|
22 | | OpenAI | gpt-4-turbo, gpt-4, gpt-3.5-turbo |
23 | | Anthropic | claude-3-opus, claude-3-sonnet, claude-3-haiku |
24 | | Gemini | gemini-pro, gemini-pro-vision, gemini-2.0-flash-exp |
25 | | Groq | llama-3.2-70b-chat, mixtral-8x7b-chat |
26 |
27 | #### Specialized Models
28 | | Provider | Type | Models |
29 | |----------|------|---------|
30 | | LumaAI | Generation | dream-machine, photon-1 |
31 | | DeepSeek | Multi-purpose | deepseek-chat, deepseek-coder, deepseek-vision |
32 | | CrewAI | Agent Teams | Support Team, Article Team |
33 | | Qwen | Language | qwen-turbo, qwen-plus, qwen-max |
34 | | Browser | Automation | browser-use-agent |
35 |
36 | ## Installation
37 |
38 | ### Basic Installation
39 | ```bash
40 | # Install core package
41 | pip install ai-gradio
42 |
43 | # Install with specific provider support
44 | pip install 'ai-gradio[openai]' # OpenAI support
45 | pip install 'ai-gradio[gemini]' # Google Gemini support
46 | pip install 'ai-gradio[anthropic]' # Anthropic Claude support
47 | pip install 'ai-gradio[groq]' # Groq support
48 |
49 | # Install all providers
50 | pip install 'ai-gradio[all]'
51 | ```
52 |
53 | ### Additional Providers
54 | ```bash
55 | pip install 'ai-gradio[crewai]' # CrewAI support
56 | pip install 'ai-gradio[lumaai]' # LumaAI support
57 | pip install 'ai-gradio[xai]' # XAI/Grok support
58 | pip install 'ai-gradio[cohere]' # Cohere support
59 | pip install 'ai-gradio[sambanova]' # SambaNova support
60 | pip install 'ai-gradio[hyperbolic]' # Hyperbolic support
61 | pip install 'ai-gradio[deepseek]' # DeepSeek support
62 | pip install 'ai-gradio[smolagents]' # SmolagentsAI support
63 | pip install 'ai-gradio[fireworks]' # Fireworks support
64 | pip install 'ai-gradio[together]' # Together support
65 | pip install 'ai-gradio[qwen]' # Qwen support
66 | pip install 'ai-gradio[browser]' # Browser support
67 | ```
68 |
69 | ## Usage
70 |
71 | ### API Key Configuration
72 | ```bash
73 | # Core Providers
74 | export OPENAI_API_KEY=
75 | export GEMINI_API_KEY=
76 | export ANTHROPIC_API_KEY=
77 | export GROQ_API_KEY=
78 | export TAVILY_API_KEY= # Required for Langchain agents
79 |
80 | # Additional Providers (as needed)
81 | export LUMAAI_API_KEY=
82 | export XAI_API_KEY=
83 | export COHERE_API_KEY=
84 | # ... (other provider keys)
85 |
86 | # Twilio credentials (required for WebRTC voice chat)
87 | export TWILIO_ACCOUNT_SID=
88 | export TWILIO_AUTH_TOKEN=
89 | ```
90 |
91 | ### Quick Start
92 | ```python
93 | import gradio as gr
94 | import ai_gradio
95 |
96 | # Create a simple chat interface
97 | gr.load(
98 | name='openai:gpt-4-turbo', # or 'gemini:gemini-1.5-flash', 'groq:llama-3.2-70b-chat'
99 | src=ai_gradio.registry,
100 | title='AI Chat',
101 | description='Chat with an AI model'
102 | ).launch()
103 |
104 | # Create a chat interface with Transformers models
105 | gr.load(
106 | name='transformers:phi-4', # or 'transformers:tulu-3', 'transformers:olmo-2-13b'
107 | src=ai_gradio.registry,
108 | title='Local AI Chat',
109 | description='Chat with locally running models'
110 | ).launch()
111 |
112 | # Create a coding assistant with OpenAI
113 | gr.load(
114 | name='openai:gpt-4-turbo',
115 | src=ai_gradio.registry,
116 | coder=True,
117 | title='OpenAI Code Assistant',
118 | description='OpenAI Code Generator'
119 | ).launch()
120 |
121 | # Create a coding assistant with Gemini
122 | gr.load(
123 | name='gemini:gemini-2.0-flash-thinking-exp-1219', # or 'openai:gpt-4-turbo', 'anthropic:claude-3-opus'
124 | src=ai_gradio.registry,
125 | coder=True,
126 | title='Gemini Code Generator',
127 | ).launch()
128 | ```
129 |
130 | ### Advanced Features
131 |
132 | #### Voice Chat
133 | ```python
134 | gr.load(
135 | name='openai:gpt-4-turbo',
136 | src=ai_gradio.registry,
137 | enable_voice=True,
138 | title='AI Voice Assistant'
139 | ).launch()
140 | ```
141 |
142 | #### Camera Mode
143 | ```python
144 | # Create a vision-enabled interface with camera support
145 | gr.load(
146 | name='gemini:gemini-2.0-flash-exp',
147 | src=ai_gradio.registry,
148 | camera=True,
149 | ).launch()
150 | ```
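#### LumaAI Generation
The LumaAI provider (see the Specialized Models table) exposes `dream-machine` for video and `photon-1` for images. A minimal sketch, assuming the registry follows the same `provider:model` naming pattern as the other examples (the `lumaai:` prefix is not shown elsewhere in this README):

```python
import gradio as gr
import ai_gradio

# Generate videos with dream-machine (requires LUMAAI_API_KEY)
gr.load(
    name='lumaai:dream-machine',  # assumed prefix; try 'lumaai:photon-1' for images
    src=ai_gradio.registry,
    title='LumaAI Generation',
    description='Generate videos or images from text prompts'
).launch()
```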
151 |
152 | #### Multi-Provider Interface
153 | ```python
154 | import gradio as gr
155 | import ai_gradio
156 |
157 | with gr.Blocks() as demo:
158 | with gr.Tab("Text"):
159 | gr.load('openai:gpt-4-turbo', src=ai_gradio.registry)
160 | with gr.Tab("Vision"):
161 | gr.load('gemini:gemini-pro-vision', src=ai_gradio.registry)
162 | with gr.Tab("Code"):
163 | gr.load('deepseek:deepseek-coder', src=ai_gradio.registry)
164 |
165 | demo.launch()
166 | ```
167 |
168 | #### CrewAI Teams
169 | ```python
170 | # Article Creation Team
171 | gr.load(
172 | name='crewai:gpt-4-turbo',
173 | src=ai_gradio.registry,
174 | crew_type='article',
175 | title='AI Writing Team'
176 | ).launch()
177 | ```
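The provider also ships a support crew (`crew_type='support'`, the default in `crewai_gradio.py`), which adds a documentation-URL field and a quality-assurance review step:

```python
# Support Team (scrapes the provided documentation URL to answer questions)
gr.load(
    name='crewai:gpt-4-turbo',
    src=ai_gradio.registry,
    crew_type='support',
    title='AI Support Team'
).launch()
```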
178 |
179 | #### Browser Automation
180 |
181 | ```bash
182 | playwright install
183 | ```
184 |
185 | Browser automation requires Python 3.11 or newer.
186 |
187 | ```python
188 | import gradio as gr
189 | import ai_gradio
190 |
191 | # Create a browser automation interface
192 | gr.load(
193 | name='browser:gpt-4-turbo',
194 | src=ai_gradio.registry,
195 | title='AI Browser Assistant',
196 | description='Let AI help with web tasks'
197 | ).launch()
198 | ```
199 |
200 | Example tasks:
201 | - Flight searches on Google Flights
202 | - Weather lookups
203 | - Product price comparisons
204 | - News searches
205 |
206 | #### Swarms Integration
207 | ```python
208 | import gradio as gr
209 | import ai_gradio
210 |
211 | # Create a chat interface with Swarms
212 | gr.load(
213 | name='swarms:gpt-4-turbo', # or other OpenAI models
214 | src=ai_gradio.registry,
215 | agent_name="Stock-Analysis-Agent", # customize agent name
216 | title='Swarms Chat',
217 | description='Chat with an AI agent powered by Swarms'
218 | ).launch()
219 | ```
220 |
221 | #### Langchain Agents
222 | ```python
223 | import gradio as gr
224 | import ai_gradio
225 |
226 | # Create a Langchain agent interface
227 | gr.load(
228 | name='langchain:gpt-4-turbo', # or other supported models
229 | src=ai_gradio.registry,
230 | title='Langchain Agent',
231 | description='AI agent powered by Langchain'
232 | ).launch()
233 | ```
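As noted in the API key configuration section, Langchain agents also need `TAVILY_API_KEY` (Tavily is assumed here to back the agent's web-search tool) alongside your model key. A minimal sketch:

```python
import os
import gradio as gr
import ai_gradio

os.environ["TAVILY_API_KEY"] = "your-tavily-key"  # required for Langchain agents

gr.load(name='langchain:gpt-4-turbo', src=ai_gradio.registry).launch()
```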
234 |
235 | ## Requirements
236 |
237 | ### Core Requirements
238 | - Python 3.10+
239 | - gradio >= 5.9.1
240 |
241 | ### Optional Features
242 | - Voice Chat: gradio-webrtc, numba==0.60.0, pydub, librosa
243 | - Video Chat: opencv-python, Pillow
244 | - Agent Teams: crewai>=0.1.0, langchain>=0.1.0
245 |
246 | ## Troubleshooting
247 |
248 | ### Authentication Issues
249 | If you encounter 401 errors, verify your API keys:
250 | ```python
251 | import os
252 |
253 | # Set API keys manually if needed
254 | os.environ["OPENAI_API_KEY"] = "your-api-key"
255 | os.environ["GEMINI_API_KEY"] = "your-api-key"
256 | ```
257 |
258 | ### Provider Installation
259 | If you see "no providers installed" errors:
260 | ```bash
261 | # Install specific provider
262 | pip install 'ai-gradio[provider_name]'
263 |
264 | # Or install all providers
265 | pip install 'ai-gradio[all]'
266 | ```
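You can also check which providers are registered in your environment. A quick sketch, assuming `ai_gradio.registry` behaves like a dict keyed by `provider:model` names:

```python
import ai_gradio

# Lists every model key that the installed extras have registered
print(sorted(ai_gradio.registry.keys()))
```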
267 |
268 |
269 | ## Contributing
270 | Contributions are welcome! Please feel free to submit a Pull Request.
271 |
--------------------------------------------------------------------------------
/ai_gradio/providers/fireworks_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from openai import OpenAI
3 | import gradio as gr
4 | from typing import Callable
5 | from fireworks.client.audio import AudioInference
6 |
7 | __version__ = "0.0.3"
8 |
9 | LANGUAGES = {
10 | "en": "english", "zh": "chinese", "de": "german", "es": "spanish",
11 | "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese",
12 | "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan",
13 | "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
14 | "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
15 | "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay",
16 | "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian",
17 | "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu",
18 | "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin",
19 | "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak",
20 | "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali",
21 | "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada",
22 | "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque",
23 | "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
24 | "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
25 | "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala",
26 | "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
27 | "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian",
28 | "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic",
29 | "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese",
30 | "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
31 | "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar",
32 | "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese",
33 | "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa",
34 | "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese"
35 | }
36 |
37 | # Language code lookup by name, with additional aliases
38 | TO_LANGUAGE_CODE = {
39 | **{language: code for code, language in LANGUAGES.items()},
40 | "burmese": "my",
41 | "valencian": "ca",
42 | "flemish": "nl",
43 | "haitian": "ht",
44 | "letzeburgesch": "lb",
45 | "pushto": "ps",
46 | "panjabi": "pa",
47 | "moldavian": "ro",
48 | "moldovan": "ro",
49 | "sinhalese": "si",
50 | "castilian": "es",
51 | "mandarin": "zh"
52 | }
53 |
54 |
55 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
56 | if "whisper" in model_name:
57 | def fn(message, history, audio_input=None):
58 | # Handle audio input if provided
59 | if audio_input:
60 | if not audio_input.endswith('.wav'):
61 | new_path = audio_input + '.wav'
62 | os.rename(audio_input, new_path)
63 | audio_input = new_path
64 |
65 | base_url = (
66 | "https://audio-turbo.us-virginia-1.direct.fireworks.ai"
67 | if model_name == "whisper-v3-turbo"
68 | else "https://audio-prod.us-virginia-1.direct.fireworks.ai"
69 | )
70 |
71 | client = AudioInference(
72 | model=model_name,
73 | base_url=base_url,
74 | api_key=api_key
75 | )
76 |
77 | with open(audio_input, "rb") as f:
78 | audio_data = f.read()
79 | response = client.transcribe(audio=audio_data)
80 | return {"role": "assistant", "content": response.text}
81 |
82 | # Handle text message
83 | if isinstance(message, dict): # Multimodal input
84 | audio_path = (message.get("files") or [None])[0] or message.get("audio")
85 | text = message.get("text", "")
86 | if audio_path:
87 | # Process audio file
88 | return fn(None, history, audio_path)
89 | return {"role": "assistant", "content": "No audio input provided."}
90 | else: # String input
91 | return {"role": "assistant", "content": "Please upload an audio file or use the microphone to record audio."}
92 |
93 | else:
94 | def fn(message, history, audio_input=None):
95 | # Ignore audio_input for non-whisper models
96 | inputs = preprocess(message, history)
97 | client = OpenAI(
98 | base_url="https://api.fireworks.ai/inference/v1",
99 | api_key=api_key
100 | )
101 |
102 | model_path = (
103 | "accounts/fireworks/agents/f1-preview" if model_name == "f1-preview"
104 | else "accounts/fireworks/agents/f1-mini-preview" if model_name == "f1-mini"
105 | else f"accounts/fireworks/models/{model_name}"
106 | )
107 |
108 | completion = client.chat.completions.create(
109 | model=model_path,
110 | messages=[{"role": "user", "content": inputs["prompt"]}],
111 | stream=True,
112 | max_tokens=1024,
113 | temperature=0.7,
114 | top_p=1,
115 | )
116 |
117 | response_text = ""
118 | for chunk in completion:
119 | delta = chunk.choices[0].delta.content or ""
120 | response_text += delta
121 | yield {"role": "assistant", "content": response_text}
122 |
123 | return fn
124 |
125 |
126 | def get_interface_args(pipeline):
127 | if pipeline == "audio":
128 | inputs = [
129 | gr.Audio(sources=["microphone"], type="filepath"),
130 | gr.Radio(["transcribe"], label="Task", value="transcribe"),
131 | ]
132 | outputs = "text"
133 |
134 | def preprocess(audio_path, task, text, history):
135 | if audio_path and not audio_path.endswith('.wav'):
136 | new_path = audio_path + '.wav'
137 | os.rename(audio_path, new_path)
138 | audio_path = new_path
139 | return {"role": "user", "content": {"audio_path": audio_path, "task": task}}
140 |
141 | def postprocess(text):
142 | return {"role": "assistant", "content": text}
143 |
144 | elif pipeline == "chat":
145 | inputs = gr.Textbox(label="Message")
146 | outputs = "text"
147 |
148 | def preprocess(message, history):
149 | return {"prompt": message}
150 |
151 | def postprocess(response):
152 | return response
153 |
154 | else:
155 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
156 |
157 | return inputs, outputs, preprocess, postprocess
158 |
159 |
160 | def get_pipeline(model_name):
161 | if "whisper" in model_name:
162 | return "audio"
163 | return "chat"
164 |
165 |
166 | def registry(name: str, token: str | None = None, **kwargs):
167 | """
168 | Create a Gradio Interface for a model on Fireworks.
169 | Can be used directly or with gr.load:
170 |
171 | Example:
172 | # Direct usage
173 | interface = fireworks_gradio.registry("whisper-v3", token="your-api-key")
174 | interface.launch()
175 |
176 | # With gr.load
177 | gr.load(
178 | name='whisper-v3',
179 | src=fireworks_gradio.registry,
180 | ).launch()
181 |
182 | Parameters:
183 | name (str): The name of the Fireworks model.
184 | token (str, optional): The Fireworks API key.
185 | """
186 | # Make the function compatible with gr.load by accepting name as a positional argument
187 | if not isinstance(name, str):
188 | raise ValueError("Model name must be a string")
189 |
190 | api_key = token or os.environ.get("FIREWORKS_API_KEY")
191 | if not api_key:
192 | raise ValueError("FIREWORKS_API_KEY environment variable is not set.")
193 |
194 | pipeline = get_pipeline(name)
195 | _, _, preprocess, postprocess = get_interface_args(pipeline)
196 | fn = get_fn(name, preprocess, postprocess, api_key)
197 |
198 | description = kwargs.pop("description", None)
199 | if "whisper" in name:
200 | description = (description or "") + """
201 | \n\nSupported inputs:
202 | - Upload audio files using the textbox
203 | - Record audio using the microphone
204 | """
205 |
206 | with gr.Blocks() as interface:
207 | chatbot = gr.Chatbot(type="messages")
208 | with gr.Row():
209 | mic = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
210 |
211 | def process_audio(audio_path):
212 | if audio_path:
213 | if not audio_path.endswith('.wav'):
214 | new_path = audio_path + '.wav'
215 | os.rename(audio_path, new_path)
216 | audio_path = new_path
217 |
218 | # Create message format expected by fn
219 | message = {"files": [audio_path], "text": ""}
220 | response = fn(message, [])
221 |
222 | return [
223 | {"role": "user", "content": gr.Audio(value=audio_path)},
224 | {"role": "assistant", "content": response["content"]}
225 | ]
226 | return []
227 |
228 | mic.change(
229 | fn=process_audio,
230 | inputs=[mic],
231 | outputs=[chatbot]
232 | )
233 |
234 | else:
235 | # For non-whisper models, use regular ChatInterface
236 | interface = gr.ChatInterface(
237 | fn=fn,
238 | type="messages",
239 | description=description,
240 | **kwargs
241 | )
242 |
243 | return interface
244 |
245 |
246 | # Add these to make the module more discoverable
247 | MODELS = [
248 | "whisper-v3",
249 | "whisper-v3-turbo",
250 | "f1-preview",
251 | "f1-mini",
252 | # Add other supported models here
253 | ]
254 |
255 | def get_all_models():
256 | """Returns a list of all supported models."""
257 | return MODELS
--------------------------------------------------------------------------------
/ai_gradio/providers/hyperbolic_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from openai import OpenAI
3 | import gradio as gr
4 | from typing import Callable
5 | import base64
6 | import re
7 | import modelscope_studio.components.base as ms
8 | import modelscope_studio.components.legacy as legacy
9 | import modelscope_studio.components.antd as antd
10 |
11 | # Constants for coder interface
12 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications.
13 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript.
14 | Focus on:
15 | - Writing clear, maintainable code
16 | - Following best practices
17 | - Creating responsive designs
18 | - Adding appropriate styling and interactivity
19 | Return only the complete HTML code without any additional explanation."""
20 |
21 | DEMO_LIST = [
22 | {
23 | "card": {"index": 0},
24 | "title": "Simple Button",
25 | "description": "Create a button that changes color when clicked"
26 | },
27 | {
28 | "card": {"index": 1},
29 | "title": "Todo List",
30 | "description": "Create a simple todo list with add/remove functionality"
31 | },
32 | {
33 | "card": {"index": 2},
34 | "title": "Timer App",
35 | "description": "Create a countdown timer with start/pause/reset controls"
36 | }
37 | ]
38 |
39 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str, base_url: str = None):
40 | def fn(message, history):
41 | inputs = preprocess(message, history)
42 |
43 | client = OpenAI(
44 | api_key=api_key,
45 | base_url="https://api.hyperbolic.xyz/v1"
46 | )
47 |
48 | completion = client.chat.completions.create(
49 | model=model_name,
50 | messages=[
51 | *inputs["messages"]
52 | ],
53 | stream=True,
54 | temperature=0.7,
55 | max_tokens=512,
56 | top_p=0.9,
57 | )
58 | response_text = ""
59 | for chunk in completion:
60 | delta = chunk.choices[0].delta.content or ""
61 | # Replace DeepSeek-R1 special tokens
62 | if "deepseek" in model_name.lower():
63 | delta = delta.replace("<think>", "[think]").replace("</think>", "[/think]")
64 | response_text += delta
65 | yield postprocess(response_text)
66 |
67 | return fn
68 |
69 |
70 | def get_interface_args(pipeline):
71 | if pipeline == "chat":
72 | inputs = None
73 | outputs = None
74 |
75 | def preprocess(message, history):
76 | messages = []
77 | for user_msg, assistant_msg in history:
78 | messages.append({"role": "user", "content": user_msg})
79 | messages.append({"role": "assistant", "content": assistant_msg})
80 | messages.append({"role": "user", "content": message})
81 | return {"messages": messages}
82 |
83 | postprocess = lambda x: x # No post-processing needed
84 | else:
85 | # Add other pipeline types when needed
86 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
87 | return inputs, outputs, preprocess, postprocess
88 |
89 |
90 | def get_pipeline(model_name):
91 | # Determine the pipeline type based on the model name
92 | # For simplicity, assuming all models are chat models at the moment
93 | return "chat"
94 |
95 |
96 | def registry(name: str, token: str | None = None, base_url: str | None = None, coder: bool = False, **kwargs):
97 | api_key = token or os.environ.get("HYPERBOLIC_API_KEY")
98 | if not api_key:
99 | raise ValueError("API key is not set. Please provide a token or set HYPERBOLIC_API_KEY environment variable.")
100 |
101 | if coder:
102 | interface = gr.Blocks(css="""
103 | .left_header {
104 | text-align: center;
105 | margin-bottom: 20px;
106 | }
107 | .right_panel {
108 | background: white;
109 | border-radius: 8px;
110 | overflow: hidden;
111 | box-shadow: 0 2px 8px rgba(0,0,0,0.15);
112 | }
113 | .render_header {
114 | background: #f5f5f5;
115 | padding: 8px;
116 | border-bottom: 1px solid #e8e8e8;
117 | }
118 | .header_btn {
119 | display: inline-block;
120 | width: 12px;
121 | height: 12px;
122 | border-radius: 50%;
123 | margin-right: 8px;
124 | background: #ff5f56;
125 | }
126 | .header_btn:nth-child(2) {
127 | background: #ffbd2e;
128 | }
129 | .header_btn:nth-child(3) {
130 | background: #27c93f;
131 | }
132 | .right_content {
133 | padding: 24px;
134 | height: 920px;
135 | display: flex;
136 | align-items: center;
137 | justify-content: center;
138 | }
139 | .html_content {
140 | height: 920px;
141 | width: 100%;
142 | }
143 | .history_chatbot {
144 | height: 100%;
145 | }
146 | """)
147 |
148 | with interface:
149 | history = gr.State([])
150 | setting = gr.State({"system": SystemPrompt})
151 |
152 | with ms.Application() as app:
153 | with antd.ConfigProvider():
154 | with antd.Row(gutter=[32, 12]) as layout:
155 | # Left Column
156 | with antd.Col(span=24, md=8):
157 | with antd.Flex(vertical=True, gap="middle", wrap=True):
158 | header = gr.HTML("""
159 |
162 | """)
163 |
164 | input = antd.InputTextarea(
165 | size="large",
166 | allow_clear=True,
167 | placeholder="Describe the web application you want to create"
168 | )
169 | btn = antd.Button("Generate", type="primary", size="large")
170 | clear_btn = antd.Button("Clear History", type="default", size="large")
171 |
172 | antd.Divider("Examples")
173 | with antd.Flex(gap="small", wrap=True):
174 | with ms.Each(DEMO_LIST):
175 | with antd.Card(hoverable=True, as_item="card") as demoCard:
176 | antd.CardMeta()
177 |
178 | antd.Divider("Settings")
179 | with antd.Flex(gap="small", wrap=True):
180 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default")
181 | codeBtn = antd.Button("🧑💻 View Code", type="default")
182 | historyBtn = antd.Button("📜 History", type="default")
183 |
184 | # Modals and Drawers
185 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal:
186 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True)
187 |
188 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer:
189 | code_output = legacy.Markdown()
190 |
191 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer:
192 | history_output = legacy.Chatbot(
193 | show_label=False,
194 | height=960,
195 | elem_classes="history_chatbot"
196 | )
197 |
198 | # Right Column
199 | with antd.Col(span=24, md=16):
200 | with ms.Div(elem_classes="right_panel"):
201 | gr.HTML('''
202 |
207 | ''')
208 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
209 | with antd.Tabs.Item(key="empty"):
210 | empty = antd.Empty(
211 | description="Enter your request to generate code",
212 | elem_classes="right_content"
213 | )
214 | with antd.Tabs.Item(key="loading"):
215 | loading = antd.Spin(
216 | True,
217 | tip="Generating code...",
218 | size="large",
219 | elem_classes="right_content"
220 | )
221 | with antd.Tabs.Item(key="render"):
222 | preview = gr.HTML(elem_classes="html_content")
223 |
224 | # Helper functions
225 | def demo_card_click(e: gr.EventData):
226 | index = e._data['component']['index']
227 | return DEMO_LIST[index]['description']
228 |
229 | def send_to_preview(code):
230 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8')
231 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
232 | return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
233 |
234 | def remove_code_block(text):
235 | pattern = r'```html\n(.+?)\n```'
236 | match = re.search(pattern, text, re.DOTALL)
237 | if match:
238 | return match.group(1).strip()
239 | return text.strip()
240 |
241 | def generate_code(query, setting, history):
242 | client = OpenAI(
243 | api_key=api_key,
244 | base_url="https://api.hyperbolic.xyz/v1"
245 | )
246 |
247 | messages = [
248 | {"role": "system", "content": setting["system"]},
249 | ]
250 |
251 | for h in history:
252 | messages.append({"role": "user", "content": h[0]})
253 | messages.append({"role": "assistant", "content": h[1]})
254 |
255 | messages.append({"role": "user", "content": query})
256 |
257 | response = client.chat.completions.create(
258 | model=name,
259 | messages=messages,
260 | stream=True,
261 | temperature=0.7,
262 | max_tokens=2048,
263 | )
264 |
265 | response_text = ""
266 | for chunk in response:
267 | if chunk.choices[0].delta.content:
268 | response_text += chunk.choices[0].delta.content
269 | yield (
270 | response_text,
271 | history,
272 | None,
273 | gr.update(active_key="loading"),
274 | gr.update(open=True)
275 | )
276 |
277 | clean_code = remove_code_block(response_text)
278 | new_history = history + [(query, response_text)]
279 |
280 | yield (
281 | response_text,
282 | new_history,
283 | send_to_preview(clean_code),
284 | gr.update(active_key="render"),
285 | gr.update(open=False)
286 | )
287 |
288 | # Wire up event handlers
289 | demoCard.click(demo_card_click, outputs=[input])
290 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
291 | system_prompt_modal.ok(
292 | lambda input: ({"system": input}, gr.update(open=False)),
293 | inputs=[systemPromptInput],
294 | outputs=[setting, system_prompt_modal]
295 | )
296 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
297 |
298 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
299 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
300 |
301 | historyBtn.click(
302 | lambda h: (gr.update(open=True), h),
303 | inputs=[history],
304 | outputs=[history_drawer, history_output]
305 | )
306 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
307 |
308 | btn.click(
309 | generate_code,
310 | inputs=[input, setting, history],
311 | outputs=[code_output, history, preview, state_tab, code_drawer]
312 | )
313 |
314 | clear_btn.click(lambda: [], outputs=[history])
315 |
316 | return interface
317 |
318 | # Regular chat interface
319 | pipeline = get_pipeline(name)
320 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
321 | fn = get_fn(name, preprocess, postprocess, api_key, base_url)
322 |
323 | if pipeline == "chat":
324 | interface = gr.ChatInterface(fn=fn, **kwargs)
325 | else:
326 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
327 |
328 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/perplexity_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from openai import OpenAI
3 | import gradio as gr
4 | from typing import Callable
5 | import re
6 | import base64
7 | import modelscope_studio.components.base as ms
8 | import modelscope_studio.components.legacy as legacy
9 | import modelscope_studio.components.antd as antd
10 |
11 | __version__ = "0.0.1"
12 |
13 | # Add these constants at the top of the file
14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications.
15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript.
16 | Focus on:
17 | - Writing clear, maintainable code
18 | - Following best practices
19 | - Creating responsive designs
20 | - Adding appropriate styling and interactivity
21 | Return only the complete HTML code without any additional explanation."""
22 |
23 | DEMO_LIST = [
24 | {
25 | "card": {"index": 0},
26 | "title": "Simple Button",
27 | "description": "Create a button that changes color when clicked"
28 | },
29 | {
30 | "card": {"index": 1},
31 | "title": "Todo List",
32 | "description": "Create a simple todo list with add/remove functionality"
33 | },
34 | {
35 | "card": {"index": 2},
36 | "title": "Timer App",
37 | "description": "Create a countdown timer with start/pause/reset controls"
38 | }
39 | ]
40 |
41 |
42 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
43 | def fn(message, history):
44 | inputs = preprocess(message, history)
45 | client = OpenAI(
46 | api_key=api_key,
47 | base_url="https://api.perplexity.ai"
48 | )
49 | completion = client.chat.completions.create(
50 | model=model_name,
51 | messages=inputs["messages"],
52 | stream=True,
53 | )
54 | response_text = ""
55 | for chunk in completion:
56 | delta = chunk.choices[0].delta.content or ""
57 | # Replace DeepSeek special tokens if present
58 | if "deepseek" in model_name.lower():
59 | delta = delta.replace("<think>", "[think]").replace("</think>", "[/think]")
60 | response_text += delta
61 | yield postprocess(response_text)
62 |
63 | return fn
64 |
65 |
66 | def get_interface_args(pipeline):
67 | if pipeline == "chat":
68 | inputs = None
69 | outputs = None
70 |
71 | def preprocess(message, history):
72 | messages = []
73 | for user_msg, assistant_msg in history:
74 | messages.append({"role": "user", "content": user_msg})
75 | messages.append({"role": "assistant", "content": assistant_msg})
76 | messages.append({"role": "user", "content": message})
77 | return {"messages": messages}
78 |
79 | postprocess = lambda x: x # No post-processing needed
80 | else:
81 | # Add other pipeline types when needed
82 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
83 | return inputs, outputs, preprocess, postprocess
84 |
85 |
86 | def get_pipeline(model_name):
87 | # Determine the pipeline type based on the model name
88 | # For simplicity, assuming all models are chat models at the moment
89 | return "chat"
90 |
91 |
92 | def registry(
93 | name: str,
94 | token: str | None = None,
95 | examples: list | None = None,
96 | coder: bool = False, # Add coder parameter
97 | **kwargs
98 | ):
99 | """
100 | Create a Gradio Interface for a model on Perplexity.
101 |
102 | Parameters:
103 | - name (str): The name of the model
104 | - token (str, optional): The API key
105 | - examples (list, optional): Example inputs
106 | - coder (bool, optional): Whether to use coding interface
107 | """
108 | api_key = token or os.environ.get("PERPLEXITY_API_KEY")
109 | if not api_key:
110 | raise ValueError("PERPLEXITY_API_KEY environment variable is not set.")
111 |
112 | if coder:
113 | interface = gr.Blocks(css="""
114 | .left_header {
115 | text-align: center;
116 | margin-bottom: 20px;
117 | }
118 |
119 | .right_panel {
120 | background: white;
121 | border-radius: 8px;
122 | overflow: hidden;
123 | box-shadow: 0 2px 8px rgba(0,0,0,0.15);
124 | }
125 |
126 | .render_header {
127 | background: #f5f5f5;
128 | padding: 8px;
129 | border-bottom: 1px solid #e8e8e8;
130 | }
131 |
132 | .header_btn {
133 | display: inline-block;
134 | width: 12px;
135 | height: 12px;
136 | border-radius: 50%;
137 | margin-right: 8px;
138 | background: #ff5f56;
139 | }
140 |
141 | .header_btn:nth-child(2) {
142 | background: #ffbd2e;
143 | }
144 |
145 | .header_btn:nth-child(3) {
146 | background: #27c93f;
147 | }
148 |
149 | .right_content {
150 | padding: 24px;
151 | height: 920px;
152 | display: flex;
153 | align-items: center;
154 | justify-content: center;
155 | }
156 |
157 | .html_content {
158 | height: 920px;
159 | width: 100%;
160 | }
161 |
162 | .history_chatbot {
163 | height: 100%;
164 | }
165 | """)
166 | with interface:
167 | history = gr.State([])
168 | setting = gr.State({"system": SystemPrompt})
169 |
170 | with ms.Application() as app:
171 | with antd.ConfigProvider():
172 | with antd.Row(gutter=[32, 12]) as layout:
173 | # Left Column
174 | with antd.Col(span=24, md=8):
175 | with antd.Flex(vertical=True, gap="middle", wrap=True):
176 | header = gr.HTML("""
177 |
180 | """)
181 |
182 | input = antd.InputTextarea(
183 | size="large",
184 | allow_clear=True,
185 | placeholder="Describe the web application you want to create"
186 | )
187 | btn = antd.Button("Generate", type="primary", size="large")
188 | clear_btn = antd.Button("Clear History", type="default", size="large")
189 |
190 | antd.Divider("Examples")
191 | with antd.Flex(gap="small", wrap=True):
192 | with ms.Each(DEMO_LIST):
193 | with antd.Card(hoverable=True, as_item="card") as demoCard:
194 | antd.CardMeta()
195 |
196 | antd.Divider("Settings")
197 | with antd.Flex(gap="small", wrap=True):
198 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default")
199 | codeBtn = antd.Button("🧑💻 View Code", type="default")
200 | historyBtn = antd.Button("📜 History", type="default")
201 |
202 | # Modals and Drawers
203 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal:
204 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True)
205 |
206 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer:
207 | code_output = legacy.Markdown()
208 |
209 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer:
210 | history_output = legacy.Chatbot(
211 | show_label=False,
212 | height=960,
213 | elem_classes="history_chatbot"
214 | )
215 |
216 | # Right Column
217 | with antd.Col(span=24, md=16):
218 | with ms.Div(elem_classes="right_panel"):
219 | gr.HTML('''
220 |
225 | ''')
226 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
227 | with antd.Tabs.Item(key="empty"):
228 | empty = antd.Empty(
229 | description="Enter your request to generate code",
230 | elem_classes="right_content"
231 | )
232 | with antd.Tabs.Item(key="loading"):
233 | loading = antd.Spin(
234 | True,
235 | tip="Generating code...",
236 | size="large",
237 | elem_classes="right_content"
238 | )
239 | with antd.Tabs.Item(key="render"):
240 | preview = gr.HTML(elem_classes="html_content")
241 |
242 | # Event Handlers
243 | def demo_card_click(e: gr.EventData):
244 | index = e._data['component']['index']
245 | return DEMO_LIST[index]['description']
246 |
247 | def send_to_preview(code):
248 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8')
249 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
250 | return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
251 |
252 | def remove_code_block(text):
253 | # Look for HTML code block first
254 | html_pattern = r'```html\n(.*?)\n```'
255 | match = re.search(html_pattern, text, re.DOTALL)
256 | if match:
257 | return match.group(1).strip()
258 |
259 | # If no HTML block found, look for any code block
260 | code_pattern = r'```(?:\w+)?\n(.*?)\n```'
261 | match = re.search(code_pattern, text, re.DOTALL)
262 | if match:
263 | return match.group(1).strip()
264 |
265 | # If no code block found, return the whole text
266 | return text.strip()
267 |
268 | def generate_code(query, setting, history):
269 | messages = []
270 | messages.append({"role": "system", "content": setting["system"]})
271 |
272 | for h in history:
273 | messages.append({"role": "user", "content": h[0]})
274 | messages.append({"role": "assistant", "content": h[1]})
275 |
276 | messages.append({"role": "user", "content": query})
277 |
278 | client = OpenAI(
279 | api_key=api_key,
280 | base_url="https://api.perplexity.ai"
281 | )
282 |
283 | response = client.chat.completions.create(
284 | model=name,
285 | messages=messages,
286 | stream=True
287 | )
288 |
289 | response_text = ""
290 | for chunk in response:
291 | if chunk.choices[0].delta.content:
292 | response_text += chunk.choices[0].delta.content
293 | # Format the code for display
294 | formatted_response = f"```html\n{response_text}\n```"
295 | # Return all 5 required outputs
296 | yield (
297 | formatted_response, # code_output (modelscopemarkdown)
298 | history, # state
299 | None, # preview (html)
300 | gr.update(active_key="loading"), # state_tab (antdtabs)
301 | gr.update(open=True) # code_drawer (antddrawer)
302 | )
303 |
304 | clean_code = remove_code_block(response_text)
305 | new_history = history + [(query, response_text)]
306 |
307 | # Final yield with all outputs
308 | yield (
309 | f"```html\n{clean_code}\n```", # code_output (modelscopemarkdown)
310 | new_history, # state
311 | send_to_preview(clean_code), # preview (html)
312 | gr.update(active_key="render"), # state_tab (antdtabs)
313 | gr.update(open=False) # code_drawer (antddrawer)
314 | )
315 |
316 | # Wire up event handlers
317 | demoCard.click(demo_card_click, outputs=[input])
318 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
319 | system_prompt_modal.ok(
320 | lambda input: ({"system": input}, gr.update(open=False)),
321 | inputs=[systemPromptInput],
322 | outputs=[setting, system_prompt_modal]
323 | )
324 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
325 |
326 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
327 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
328 |
329 | historyBtn.click(
330 | lambda h: (gr.update(open=True), h),
331 | inputs=[history],
332 | outputs=[history_drawer, history_output]
333 | )
334 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
335 |
336 | btn.click(
337 | generate_code,
338 | inputs=[input, setting, history],
339 | outputs=[code_output, history, preview, state_tab, code_drawer]
340 | )
341 |
342 | clear_btn.click(lambda: [], outputs=[history])
343 |
344 | return interface
345 |
346 | # Continue with existing chat interface code...
347 | pipeline = get_pipeline(name)
348 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
349 | fn = get_fn(name, preprocess, postprocess, api_key)
350 |
351 | if examples:
352 | kwargs["examples"] = examples
353 |
354 | if pipeline == "chat":
355 | interface = gr.ChatInterface(fn=fn, **kwargs)
356 | else:
357 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
358 |
359 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/crewai_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Dict, Generator
3 | import gradio as gr
4 | from crewai import Agent, Task, Crew
5 | from crewai_tools import ScrapeWebsiteTool
6 | import queue
7 | import threading
8 | import asyncio
9 |
10 | class MessageQueue:
11 | def __init__(self):
12 | self.message_queue = queue.Queue()
13 | self.last_agent = None
14 |
15 | def add_message(self, message: Dict):
16 | print(f"Adding message to queue: {message}")
17 | self.message_queue.put(message)
18 |
19 | def get_messages(self) -> List[Dict]:
20 | messages = []
21 | while not self.message_queue.empty():
22 | messages.append(self.message_queue.get())
23 | return messages
24 |
25 | class CrewFactory:
26 | @staticmethod
27 | def create_crew_config(crew_type: str, topic: str) -> dict:
28 | configs = {
29 | "article": {
30 | "agents": [
31 | {
32 | "role": "Content Planner",
33 | "goal": f"Plan engaging and factually accurate content on {topic}",
34 | "backstory": "Expert content planner with focus on creating engaging outlines",
35 | "tasks": [
36 | """Create a detailed content plan for an article by:
37 | 1. Prioritizing the latest trends and key players
38 | 2. Identifying the target audience
39 | 3. Developing a detailed content outline
40 | 4. Including SEO keywords and sources"""
41 | ]
42 | },
43 | {
44 | "role": "Content Writer",
45 | "goal": f"Write insightful piece about {topic}",
46 | "backstory": "Expert content writer with focus on engaging articles",
47 | "tasks": [
48 | """1. Use the content plan to craft a compelling blog post
49 | 2. Incorporate SEO keywords naturally
50 | 3. Create proper structure with introduction, body, and conclusion"""
51 | ]
52 | },
53 | {
54 | "role": "Editor",
55 | "goal": "Polish and refine the content",
56 | "backstory": "Expert editor with eye for detail and clarity",
57 | "tasks": [
58 | """1. Review for clarity and coherence
59 | 2. Correct any errors
60 | 3. Ensure consistent tone and style"""
61 | ]
62 | }
63 | ]
64 | },
65 | "support": {
66 | "agents": [
67 | {
68 | "role": "Senior Support Representative",
69 | "goal": "Be the most helpful support representative",
70 | "backstory": "Expert at analyzing questions and providing support",
71 | "tasks": [
72 | f"""Analyze this inquiry thoroughly: {topic}
73 | Provide detailed support response."""
74 | ]
75 | },
76 | {
77 | "role": "Support Quality Assurance",
78 | "goal": "Ensure highest quality of support responses",
79 | "backstory": "Expert at reviewing and improving support responses",
80 | "tasks": [
81 | """Review and improve the support response to ensure it's:
82 | 1. Comprehensive and helpful
83 | 2. Properly formatted with clear structure"""
84 | ]
85 | }
86 | ]
87 | }
88 | }
89 | return configs.get(crew_type, configs["article"])
90 |
91 | class CrewManager:
92 | def __init__(self, api_key: str = None):
93 | self.api_key = api_key
94 | self.message_queue = MessageQueue()
95 | self.agents = []
96 | self.current_agent = None
97 | self.scrape_tool = None
98 |
99 | def initialize_agents(self, crew_type: str, topic: str, website_url: str = None):
100 | if not self.api_key:
101 | raise ValueError("OpenAI API key is required")
102 |
103 | os.environ["OPENAI_API_KEY"] = self.api_key
104 | if website_url:
105 | self.scrape_tool = ScrapeWebsiteTool(website_url=website_url)
106 |
107 | # Get crew configuration
108 | config = CrewFactory.create_crew_config(crew_type, topic)
109 |
110 | # Initialize agents from configuration
111 | self.agents = []
112 | for agent_config in config["agents"]:
113 | agent = Agent(
114 | role=agent_config["role"],
115 | goal=agent_config["goal"],
116 | backstory=agent_config["backstory"],
117 | allow_delegation=False,
118 | verbose=True
119 | )
120 | self.agents.append((agent, agent_config["tasks"]))
121 |
122 | def create_tasks(self, topic: str) -> List[Task]:
123 | tasks = []
124 | for agent, task_descriptions in self.agents:
125 | for task_desc in task_descriptions:
126 | task = Task(
127 | description=task_desc,
128 | expected_output="Detailed and well-formatted response",
129 | agent=agent,
130 | tools=[self.scrape_tool] if self.scrape_tool else []
131 | )
132 | tasks.append(task)
133 | return tasks
134 |
135 | async def process_support(self, inquiry: str, website_url: str, crew_type: str) -> Generator[List[Dict], None, None]:
136 | def add_agent_messages(agent_name: str, tasks: str, emoji: str = "🤖"):
137 | self.message_queue.add_message({
138 | "role": "assistant",
139 | "content": agent_name,
140 | "metadata": {"title": f"{emoji} {agent_name}"}
141 | })
142 |
143 | self.message_queue.add_message({
144 | "role": "assistant",
145 | "content": tasks,
146 | "metadata": {"title": f"📋 Task for {agent_name}"}
147 | })
148 |
149 | def setup_next_agent(current_agent: str):
150 | if crew_type == "support":
151 | if current_agent == "Senior Support Representative":
152 | self.current_agent = "Support Quality Assurance Specialist"
153 | add_agent_messages(
154 | "Support Quality Assurance Specialist",
155 | "Review and improve the support response"
156 | )
157 | elif crew_type == "article":
158 | if current_agent == "Content Planner":
159 | self.current_agent = "Content Writer"
160 | add_agent_messages(
161 | "Content Writer",
162 | "Write the article based on the content plan"
163 | )
164 | elif current_agent == "Content Writer":
165 | self.current_agent = "Editor"
166 | add_agent_messages(
167 | "Editor",
168 | "Review and polish the article"
169 | )
170 |
171 | def task_callback(task_output):
172 | raw_output = task_output.raw
173 | if "## Final Answer:" in raw_output:
174 | content = raw_output.split("## Final Answer:")[1].strip()
175 | else:
176 | content = raw_output.strip()
177 |
178 | if self.current_agent == "Support Quality Assurance Specialist":
179 | self.message_queue.add_message({
180 | "role": "assistant",
181 | "content": "Final response is ready!",
182 | "metadata": {"title": "✅ Final Response"}
183 | })
184 |
185 | formatted_content = content
186 | formatted_content = formatted_content.replace("\n#", "\n\n#")
187 | formatted_content = formatted_content.replace("\n-", "\n\n-")
188 | formatted_content = formatted_content.replace("\n*", "\n\n*")
189 | formatted_content = formatted_content.replace("\n1.", "\n\n1.")
190 | formatted_content = formatted_content.replace("\n\n\n", "\n\n")
191 |
192 | self.message_queue.add_message({
193 | "role": "assistant",
194 | "content": formatted_content
195 | })
196 | else:
197 | self.message_queue.add_message({
198 | "role": "assistant",
199 | "content": content,
200 | "metadata": {"title": f"✨ Output from {self.current_agent}"}
201 | })
202 | setup_next_agent(self.current_agent)
203 |
204 | try:
205 | self.initialize_agents(crew_type, inquiry, website_url)
206 | # Set initial agent based on crew type
207 | self.current_agent = "Senior Support Representative" if crew_type == "support" else "Content Planner"
208 |
209 | yield [{
210 | "role": "assistant",
211 | "content": "Starting to process your inquiry...",
212 | "metadata": {"title": "🚀 Process Started"}
213 | }]
214 |
215 | # Set initial task message based on crew type
216 | if crew_type == "support":
217 | add_agent_messages(
218 | "Senior Support Representative",
219 | "Analyze inquiry and provide comprehensive support"
220 | )
221 | else:
222 | add_agent_messages(
223 | "Content Planner",
224 | "Create a detailed content plan for the article"
225 | )
226 |
227 | crew = Crew(
228 | agents=[agent for agent, _ in self.agents],
229 | tasks=self.create_tasks(inquiry),
230 | verbose=True,
231 | task_callback=task_callback
232 | )
233 |
234 | def run_crew():
235 | try:
236 | crew.kickoff()
237 | except Exception as e:
238 | print(f"Error in crew execution: {str(e)}")
239 | self.message_queue.add_message({
240 | "role": "assistant",
241 | "content": f"Error: {str(e)}",
242 | "metadata": {"title": "❌ Error"}
243 | })
244 |
245 | thread = threading.Thread(target=run_crew)
246 | thread.start()
247 |
248 | while thread.is_alive() or not self.message_queue.message_queue.empty():
249 | messages = self.message_queue.get_messages()
250 | if messages:
251 | yield messages
252 | await asyncio.sleep(0.1)
253 |
254 | except Exception as e:
255 | print(f"Error in process_support: {str(e)}")
256 | yield [{
257 | "role": "assistant",
258 | "content": f"An error occurred: {str(e)}",
259 | "metadata": {"title": "❌ Error"}
260 | }]
261 |
262 | def registry(name: str, token: str | None = None, crew_type: str = "support", **kwargs):
263 | has_api_key = bool(token or os.environ.get("OPENAI_API_KEY"))
264 | crew_manager = None
265 |
266 | with gr.Blocks(theme=gr.themes.Soft()) as demo:
267 | title = "📝 AI Article Writing Crew" if crew_type == "article" else "🤖 CrewAI Assistant"
268 | gr.Markdown(f"# {title}")
269 |
270 | api_key = gr.Textbox(
271 | label="OpenAI API Key",
272 | type="password",
273 | placeholder="Type your OpenAI API key and press Enter...",
274 | interactive=True,
275 | visible=not has_api_key
276 | )
277 |
278 | chatbot = gr.Chatbot(
279 | label="Writing Process" if crew_type == "article" else "Process",
280 | height=700 if crew_type == "article" else 600,
281 | show_label=True,
282 | visible=has_api_key,
283 | avatar_images=(None, "https://avatars.githubusercontent.com/u/170677839?v=4"),
284 | render_markdown=True,
285 | type="messages"
286 | )
287 |
288 | with gr.Row(equal_height=True):
289 | topic = gr.Textbox(
290 | label="Article Topic" if crew_type == "article" else "Topic/Question",
291 | placeholder="Enter topic..." if crew_type == "article" else "Enter your question...",
292 | scale=4,
293 | visible=has_api_key
294 | )
295 | website_url = gr.Textbox(
296 | label="Documentation URL",
297 | placeholder="Enter documentation URL to search...",
298 | scale=4,
299 | visible=(has_api_key) and crew_type == "support"
300 | )
301 | btn = gr.Button(
302 | "Write Article" if crew_type == "article" else "Start",
303 | variant="primary",
304 | scale=1,
305 | visible=has_api_key
306 | )
307 |
308 | async def process_input(topic, website_url, history, api_key):
309 | nonlocal crew_manager
310 | effective_api_key = token or api_key or os.environ.get("OPENAI_API_KEY")
311 |
312 | if not effective_api_key:
313 | yield [
314 | {"role": "user", "content": f"Question: {topic}\nDocumentation: {website_url}"},
315 | {"role": "assistant", "content": "Please provide an OpenAI API key."}
316 | ]
317 | return # Early return without value
318 |
319 | if crew_manager is None:
320 | crew_manager = CrewManager(api_key=effective_api_key)
321 |
322 | messages = [{"role": "user", "content": f"Question: {topic}\nDocumentation: {website_url}"}]
323 | yield messages
324 |
325 | try:
326 | async for new_messages in crew_manager.process_support(topic, website_url, crew_type):
327 | for msg in new_messages:
328 | if "metadata" in msg:
329 | messages.append({
330 | "role": msg["role"],
331 | "content": msg["content"],
332 | "metadata": {"title": msg["metadata"]["title"]}
333 | })
334 | else:
335 | messages.append({
336 | "role": msg["role"],
337 | "content": msg["content"]
338 | })
339 | yield messages
340 | except Exception as e:
341 | messages.append({
342 | "role": "assistant",
343 | "content": f"An error occurred: {str(e)}",
344 | "metadata": {"title": "❌ Error"}
345 | })
346 | yield messages
347 |
348 | def show_interface():
349 | return {
350 | api_key: gr.Textbox(visible=False),
351 | chatbot: gr.Chatbot(visible=True),
352 | topic: gr.Textbox(visible=True),
353 | website_url: gr.Textbox(visible=True),
354 | btn: gr.Button(visible=True)
355 | }
356 |
357 | if not has_api_key:
358 | api_key.submit(show_interface, None, [api_key, chatbot, topic, website_url, btn])
359 |
360 | btn.click(process_input, [topic, website_url, chatbot, api_key], [chatbot])
361 | topic.submit(process_input, [topic, website_url, chatbot, api_key], [chatbot])
362 |
363 | return demo
--------------------------------------------------------------------------------
/ai_gradio/providers/mistral_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | from mistralai import Mistral
4 | import gradio as gr
5 | from typing import Callable
6 | from urllib.parse import urlparse
7 | import modelscope_studio.components.base as ms
8 | import modelscope_studio.components.legacy as legacy
9 | import modelscope_studio.components.antd as antd
10 | import re
11 |
12 | __version__ = "0.0.1"
13 |
14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications.
15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript.
16 | Focus on:
17 | - Writing clear, maintainable code
18 | - Following best practices
19 | - Creating responsive designs
20 | - Adding appropriate styling and interactivity
21 | Return only the complete HTML code without any additional explanation."""
22 |
23 |
24 | def encode_image_file(image_path):
25 | """Encode an image file to base64."""
26 | try:
27 | with open(image_path, "rb") as image_file:
28 | return base64.b64encode(image_file.read()).decode('utf-8')
29 | except Exception as e:
30 | print(f"Error encoding image: {str(e)}")
31 | return None
32 |
33 |
34 | def process_image(image):
35 | """Process image input to the format expected by Mistral API."""
36 | if isinstance(image, str):
37 | # Check if it's a URL or base64 string
38 | if image.startswith('data:'):
39 | return image # Already in base64 format
40 | elif urlparse(image).scheme in ('http', 'https'):
41 | return image # It's a URL
42 | else:
43 | # Assume it's a local file path
44 | encoded = encode_image_file(image)
45 | return f"data:image/jpeg;base64,{encoded}" if encoded else None
46 | return None
47 |
48 |
49 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
50 | def fn(message, history):
51 | inputs = preprocess(message, history)
52 | client = Mistral(api_key=api_key)
53 | try:
54 | # Check if the model is Codestral
55 | if model_name.startswith("codestral"):
56 | # Handle Codestral API calls
57 | response = client.chat.complete(
58 | model=model_name,
59 | messages=inputs["messages"]
60 | )
61 | yield postprocess(response.choices[0].message.content)
62 | else:
63 | # Create the streaming chat completion for other models
64 | stream_response = client.chat.stream(
65 | model=model_name,
66 | messages=inputs["messages"]
67 | )
68 |
69 | response_text = ""
70 | for chunk in stream_response:
71 | if chunk.data.choices[0].delta.content is not None:
72 | delta = chunk.data.choices[0].delta.content
73 | response_text += delta
74 | yield postprocess(response_text)
75 |
76 | except Exception as e:
77 | print(f"Error during chat completion: {str(e)}")
78 | yield "Sorry, there was an error processing your request."
79 |
80 | return fn
81 |
82 |
83 | def get_interface_args(pipeline):
84 | if pipeline == "chat":
85 | inputs = None
86 | outputs = None
87 |
88 | def preprocess(message, history):
89 | messages = []
90 | # Process history
91 | for user_msg, assistant_msg in history:
92 | if isinstance(user_msg, dict):
93 | # Handle multimodal history messages
94 | content = []
95 | if user_msg.get("text"):
96 | content.append({"type": "text", "text": user_msg["text"]})
97 | for file in user_msg.get("files", []):
98 | processed_image = process_image(file)
99 | if processed_image:
100 | content.append({"type": "image_url", "image_url": processed_image})
101 | messages.append({"role": "user", "content": content})
102 | else:
103 | # Handle text-only history messages
104 | messages.append({"role": "user", "content": user_msg})
105 | messages.append({"role": "assistant", "content": assistant_msg})
106 |
107 | # Process current message
108 | if isinstance(message, dict):
109 | # Handle multimodal input
110 | content = []
111 | if message.get("text"):
112 | content.append({"type": "text", "text": message["text"]})
113 | for file in message.get("files", []):
114 | processed_image = process_image(file)
115 | if processed_image:
116 | content.append({"type": "image_url", "image_url": processed_image})
117 | messages.append({"role": "user", "content": content})
118 | else:
119 | # Handle text-only input
120 | messages.append({"role": "user", "content": message})
121 |
122 | return {"messages": messages}
123 |
124 | postprocess = lambda x: x # No post-processing needed
125 | else:
126 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
127 | return inputs, outputs, preprocess, postprocess
128 |
129 |
130 | def get_pipeline(model_name):
131 | # Determine the pipeline type based on the model name
132 | # For simplicity, assuming all models are chat models at the moment
133 | return "chat"
134 |
135 |
136 | def generate_code(query, history, setting, api_key):
137 | """Generate code using Mistral API and handle UI updates."""
138 | client = Mistral(api_key=api_key)
139 |
140 | messages = []
141 | # Add system prompt
142 | messages.append({"role": "system", "content": setting["system"]})
143 |
144 | # Add history
145 | for h in history:
146 | messages.append({"role": "user", "content": h[0]})
147 | messages.append({"role": "assistant", "content": h[1]})
148 |
149 | # Add current query
150 | messages.append({"role": "user", "content": query})
151 |
152 | try:
153 | # Create the streaming chat completion
154 | stream_response = client.chat.stream(
155 | model="mistral-large-latest",
156 | messages=messages
157 | )
158 |
159 | response_text = ""
160 | for chunk in stream_response:
161 | if chunk.data.choices[0].delta.content is not None:
162 | delta = chunk.data.choices[0].delta.content
163 | response_text += delta
164 | # Yield intermediate updates
165 | yield (
166 | response_text, # code_output (for markdown display)
167 | history, # history state
168 | None, # preview HTML
169 | gr.update(active_key="loading"), # state_tab
170 | gr.update(open=True) # code_drawer
171 | )
172 |
173 | # Clean the code and prepare final preview
174 | clean_code = remove_code_block(response_text)
175 | new_history = history + [(query, response_text)]
176 |
177 | # Final yield with complete response
178 | yield (
179 | response_text, # code_output
180 | new_history, # history state
181 | send_to_preview(clean_code), # preview HTML
182 | gr.update(active_key="render"), # state_tab
183 | gr.update(open=False) # code_drawer
184 | )
185 |
186 | except Exception as e:
187 | print(f"Error generating code: {str(e)}")
188 | yield (
189 | f"Error: {str(e)}",
190 | history,
191 | None,
192 | gr.update(active_key="empty"),
193 | gr.update(open=True)
194 | )
195 |
196 |
197 | def remove_code_block(text):
198 | """Extract code from markdown code blocks."""
199 | pattern = r'```html\n(.+?)\n```'
200 | match = re.search(pattern, text, re.DOTALL)
201 | if match:
202 | return match.group(1).strip()
203 | return text.strip()
204 |
205 |
206 | def send_to_preview(code):
207 | """Convert code to base64 encoded iframe source."""
208 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8')
209 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
210 | return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
211 |
212 |
213 | def registry(name: str, token: str | None = None, coder: bool = False, **kwargs):
214 | """
215 | Create a Gradio Interface for a model on Mistral AI.
216 |
217 | Parameters:
218 | - name (str): The name of the Mistral AI model.
219 | - token (str, optional): The API key for Mistral AI.
220 | """
221 | api_key = token or os.environ.get("MISTRAL_API_KEY")
222 | if not api_key:
223 | raise ValueError("MISTRAL_API_KEY environment variable is not set.")
224 |
225 | pipeline = get_pipeline(name)
226 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
227 | fn = get_fn(name, preprocess, postprocess, api_key)
228 |
229 | if pipeline == "chat":
230 | # Always enable multimodal support
231 | interface = gr.ChatInterface(
232 | fn=fn,
233 | multimodal=True,
234 | **kwargs
235 | )
236 | else:
237 | # For other pipelines, create a standard Interface (not implemented yet)
238 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
239 |
240 | if coder:
241 | interface = gr.Blocks(css="""
242 | .left_header {
243 | text-align: center;
244 | margin-bottom: 20px;
245 | }
246 | .right_panel {
247 | background: white;
248 | border-radius: 8px;
249 | overflow: hidden;
250 | box-shadow: 0 2px 8px rgba(0,0,0,0.15);
251 | }
252 | .render_header {
253 | background: #f5f5f5;
254 | padding: 8px;
255 | border-bottom: 1px solid #e8e8e8;
256 | }
257 | .header_btn {
258 | display: inline-block;
259 | width: 12px;
260 | height: 12px;
261 | border-radius: 50%;
262 | margin-right: 8px;
263 | background: #ff5f56;
264 | }
265 | .header_btn:nth-child(2) {
266 | background: #ffbd2e;
267 | }
268 | .header_btn:nth-child(3) {
269 | background: #27c93f;
270 | }
271 | .right_content {
272 | padding: 24px;
273 | height: 920px;
274 | display: flex;
275 | align-items: center;
276 | justify-content: center;
277 | }
278 | .html_content {
279 | height: 920px;
280 | width: 100%;
281 | }
282 | .history_chatbot {
283 | height: 100%;
284 | }
285 | """)
286 |
287 | with interface:
288 | history = gr.State([])
289 | setting = gr.State({"system": SystemPrompt})
290 |
291 | with ms.Application() as app:
292 | with antd.ConfigProvider():
293 | with antd.Row(gutter=[32, 12]) as layout:
294 | # Left Column
295 | with antd.Col(span=24, md=8):
296 | with antd.Flex(vertical=True, gap="middle", wrap=True):
297 | header = gr.HTML("""
298 |
301 | """)
302 |
303 | input = antd.InputTextarea(
304 | size="large",
305 | allow_clear=True,
306 | placeholder="Describe the code you want to generate"
307 | )
308 | btn = antd.Button("Generate", type="primary", size="large")
309 | clear_btn = antd.Button("Clear History", type="default", size="large")
310 |
311 | antd.Divider("Settings")
312 | with antd.Flex(gap="small", wrap=True):
313 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default")
314 | codeBtn = antd.Button("🧑💻 View Code", type="default")
315 | historyBtn = antd.Button("📜 History", type="default")
316 |
317 | # Modals and Drawers
318 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal:
319 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True)
320 |
321 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer:
322 | code_output = legacy.Markdown()
323 |
324 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer:
325 | history_output = legacy.Chatbot(
326 | show_label=False,
327 | height=960,
328 | elem_classes="history_chatbot"
329 | )
330 |
331 | # Right Column
332 | with antd.Col(span=24, md=16):
333 | with ms.Div(elem_classes="right_panel"):
334 | gr.HTML('''
335 | <div class="render_header">
336 | <span class="header_btn"></span>
337 | <span class="header_btn"></span>
338 | <span class="header_btn"></span>
339 | </div>
340 | ''')
341 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
342 | with antd.Tabs.Item(key="empty"):
343 | empty = antd.Empty(
344 | description="Enter your request to generate code",
345 | elem_classes="right_content"
346 | )
347 | with antd.Tabs.Item(key="loading"):
348 | loading = antd.Spin(
349 | True,
350 | tip="Generating code...",
351 | size="large",
352 | elem_classes="right_content"
353 | )
354 | with antd.Tabs.Item(key="render"):
355 | preview = gr.HTML(elem_classes="html_content")
356 |
357 | # Wire up event handlers
358 | btn.click(
359 | generate_code,
360 | inputs=[input, history, setting, gr.State(api_key)],
361 | outputs=[code_output, history, preview, state_tab, code_drawer],
362 | api_name=False
363 | )
364 |
365 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
366 | system_prompt_modal.ok(
367 | lambda input: ({"system": input}, gr.update(open=False)),
368 | inputs=[systemPromptInput],
369 | outputs=[setting, system_prompt_modal]
370 | )
371 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
372 |
373 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
374 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
375 |
376 | historyBtn.click(
377 | lambda h: (gr.update(open=True), h),
378 | inputs=[history],
379 | outputs=[history_drawer, history_output]
380 | )
381 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
382 |
383 | clear_btn.click(lambda: [], outputs=[history])
384 |
385 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/qwen_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from openai import OpenAI
3 | import gradio as gr
4 | from typing import Callable
5 | import base64
6 | import re
7 | import modelscope_studio.components.base as ms
8 | import modelscope_studio.components.legacy as legacy
9 | import modelscope_studio.components.antd as antd
10 |
11 | __version__ = "0.0.3"
12 |
13 | # Constants for the coder (code-generation) interface
14 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications.
15 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript.
16 | Focus on:
17 | - Writing clear, maintainable code
18 | - Following best practices
19 | - Creating responsive designs
20 | - Adding appropriate styling and interactivity
21 | Return only the complete HTML code without any additional explanation."""
22 |
23 | DEMO_LIST = [
24 | {
25 | "card": {"index": 0},
26 | "title": "Simple Button",
27 | "description": "Create a button that changes color when clicked"
28 | },
29 | {
30 | "card": {"index": 1},
31 | "title": "Todo List",
32 | "description": "Create a simple todo list with add/remove functionality"
33 | },
34 | {
35 | "card": {"index": 2},
36 | "title": "Timer App",
37 | "description": "Create a countdown timer with start/pause/reset controls"
38 | }
39 | ]
40 |
41 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str | None = None, base_url: str | None = None, local: bool = False):
42 | if local:
43 | from vllm import LLM, SamplingParams
44 | from transformers import AutoTokenizer
45 |
46 | # Initialize tokenizer and model
47 | tokenizer = AutoTokenizer.from_pretrained(model_name)
48 | llm = LLM(
49 | model=model_name,
50 | tensor_parallel_size=1, # Adjust based on available GPUs
51 | max_model_len=1010000,
52 | enable_chunked_prefill=True,
53 | max_num_batched_tokens=131072,
54 | enforce_eager=True,
55 | )
56 |
57 | # Default sampling parameters
58 | sampling_params = SamplingParams(
59 | temperature=0.7,
60 | top_p=0.8,
61 | repetition_penalty=1.05,
62 | max_tokens=512
63 | )
64 |
65 | def fn(message, history):
66 | inputs = preprocess(message, history)
67 | messages = inputs["messages"]
68 |
69 | # Convert messages to model input format
70 | text = tokenizer.apply_chat_template(
71 | messages,
72 | tokenize=False,
73 | add_generation_prompt=True
74 | )
75 |
76 | # Generate response
77 | outputs = llm.generate([text], sampling_params)
78 | response_text = outputs[0].outputs[0].text
79 | yield postprocess(response_text)
80 |
81 | return fn
82 |
83 | # Original cloud API implementation
84 | def fn(message, history):
85 | inputs = preprocess(message, history)
86 |
87 | client = OpenAI(
88 | api_key=api_key,
89 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
90 | )
91 | completion = client.chat.completions.create(
92 | model=model_name,
93 | messages=inputs["messages"],
94 | stream=True,
95 | )
96 | response_text = ""
97 | for chunk in completion:
98 | delta = chunk.choices[0].delta.content or ""
99 | response_text += delta
100 | yield postprocess(response_text)
101 |
102 | return fn
103 |
104 |
105 | def get_interface_args(pipeline, model_name: str):
106 | if pipeline == "chat":
107 | inputs = None
108 | outputs = None
109 |
110 | def preprocess(message, history):
111 | messages = []
112 | # Add system prompt for qwq-32b-preview
113 | if model_name == "qwq-32b-preview":
114 | messages.append({
115 | "role": "system",
116 | "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
117 | })
118 |
119 | for user_msg, assistant_msg in history:
120 | messages.append({"role": "user", "content": user_msg})
121 | messages.append({"role": "assistant", "content": assistant_msg})
122 |
123 | # Handle multimodal input
124 | if isinstance(message, dict):
125 | content = []
126 | if message.get("files"):
127 | # Convert local file path to data URL
128 | with open(message["files"][0], "rb") as image_file:
129 | encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
130 | content.append({
131 | "type": "image_url",
132 | "image_url": f"data:image/jpeg;base64,{encoded_image}"
133 | })
134 | content.append({
135 | "type": "text",
136 | "text": message["text"]
137 | })
138 | messages.append({"role": "user", "content": content})
139 | else:
140 | messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
141 |
142 | return {"messages": messages}
143 |
144 | postprocess = lambda x: x # No post-processing needed
145 | else:
146 | # Add other pipeline types as they are needed
147 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
148 | return inputs, outputs, preprocess, postprocess
149 |
150 |
151 | def get_pipeline(model_name):
152 | # Determine the pipeline type based on the model name
153 | # For simplicity, assuming all models are chat models at the moment
154 | return "chat"
155 |
156 |
157 | def registry(
158 | name: str,
159 | token: str | None = None,
160 | examples: list | None = None,
161 | coder: bool = False,
162 | local: bool = False, # Add local parameter
163 | **kwargs
164 | ):
165 | api_key = None if local else (token or os.environ.get("DASHSCOPE_API_KEY"))
166 | if not local and not api_key:
167 | raise ValueError("API key not found in environment variables.")
168 |
169 | if coder:
170 | interface = gr.Blocks(css="""
171 | .left_header {
172 | text-align: center;
173 | margin-bottom: 20px;
174 | }
175 |
176 | .right_panel {
177 | background: white;
178 | border-radius: 8px;
179 | overflow: hidden;
180 | box-shadow: 0 2px 8px rgba(0,0,0,0.15);
181 | }
182 |
183 | .render_header {
184 | background: #f5f5f5;
185 | padding: 8px;
186 | border-bottom: 1px solid #e8e8e8;
187 | }
188 |
189 | .header_btn {
190 | display: inline-block;
191 | width: 12px;
192 | height: 12px;
193 | border-radius: 50%;
194 | margin-right: 8px;
195 | background: #ff5f56;
196 | }
197 |
198 | .header_btn:nth-child(2) {
199 | background: #ffbd2e;
200 | }
201 |
202 | .header_btn:nth-child(3) {
203 | background: #27c93f;
204 | }
205 |
206 | .right_content {
207 | padding: 24px;
208 | height: 920px;
209 | display: flex;
210 | align-items: center;
211 | justify-content: center;
212 | }
213 |
214 | .html_content {
215 | height: 920px;
216 | width: 100%;
217 | }
218 |
219 | .history_chatbot {
220 | height: 100%;
221 | }
222 | """)
223 | with interface:
224 | history = gr.State([])
225 | setting = gr.State({"system": SystemPrompt})
226 |
227 | with ms.Application() as app:
228 | with antd.ConfigProvider():
229 | with antd.Row(gutter=[32, 12]) as layout:
230 | # Left Column
231 | with antd.Col(span=24, md=8):
232 | with antd.Flex(vertical=True, gap="middle", wrap=True):
233 | header = gr.HTML("""
234 |
237 | """)
238 |
239 | input = antd.InputTextarea(
240 | size="large",
241 | allow_clear=True,
242 | placeholder="Describe the web application you want to create"
243 | )
244 | btn = antd.Button("Generate", type="primary", size="large")
245 | clear_btn = antd.Button("Clear History", type="default", size="large")
246 |
247 | antd.Divider("Examples")
248 | with antd.Flex(gap="small", wrap=True):
249 | with ms.Each(DEMO_LIST):
250 | with antd.Card(hoverable=True, as_item="card") as demoCard:
251 | antd.CardMeta()
252 |
253 | antd.Divider("Settings")
254 | with antd.Flex(gap="small", wrap=True):
255 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default")
256 | codeBtn = antd.Button("🧑💻 View Code", type="default")
257 | historyBtn = antd.Button("📜 History", type="default")
258 |
259 | # Modals and Drawers
260 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal:
261 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True)
262 |
263 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer:
264 | code_output = legacy.Markdown()
265 |
266 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer:
267 | history_output = legacy.Chatbot(
268 | show_label=False,
269 | height=960,
270 | elem_classes="history_chatbot"
271 | )
272 |
273 | # Right Column
274 | with antd.Col(span=24, md=16):
275 | with ms.Div(elem_classes="right_panel"):
276 | gr.HTML('''
277 | <div class="render_header">
278 | <span class="header_btn"></span>
279 | <span class="header_btn"></span>
280 | <span class="header_btn"></span>
281 | </div>
282 | ''')
283 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
284 | with antd.Tabs.Item(key="empty"):
285 | empty = antd.Empty(
286 | description="Enter your request to generate code",
287 | elem_classes="right_content"
288 | )
289 | with antd.Tabs.Item(key="loading"):
290 | loading = antd.Spin(
291 | True,
292 | tip="Generating code...",
293 | size="large",
294 | elem_classes="right_content"
295 | )
296 | with antd.Tabs.Item(key="render"):
297 | preview = gr.HTML(elem_classes="html_content")
298 |
299 | # Event Handlers
300 | def demo_card_click(e: gr.EventData):
301 | index = e._data['component']['index']
302 | return DEMO_LIST[index]['description']
303 |
304 | def send_to_preview(code):
305 | encoded_html = base64.b64encode(code.encode('utf-8')).decode('utf-8')
306 | data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
307 | return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
308 |
309 | def remove_code_block(text):
310 | pattern = r'```html\n(.+?)\n```'
311 | match = re.search(pattern, text, re.DOTALL)
312 | if match:
313 | return match.group(1).strip()
314 | return text.strip()
315 |
316 | def generate_code(query, setting, history):
317 | client = OpenAI(
318 | api_key=api_key,
319 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
320 | )
321 |
322 | messages = [
323 | {"role": "system", "content": setting["system"]},
324 | {"role": "assistant", "content": "I understand. I will help you write clean, efficient web code."}
325 | ]
326 |
327 | # Add history
328 | for h in history:
329 | messages.append({"role": "user", "content": h[0]})
330 | messages.append({"role": "assistant", "content": h[1]})
331 |
332 | messages.append({"role": "user", "content": query})
333 |
334 | completion = client.chat.completions.create(
335 | model=name,
336 | messages=messages,
337 | stream=True,
338 | )
339 |
340 | response_text = ""
341 | for chunk in completion:
342 | if chunk.choices[0].delta.content:
343 | response_text += chunk.choices[0].delta.content
344 | yield (
345 | response_text,
346 | history,
347 | None,
348 | gr.update(active_key="loading"),
349 | gr.update(open=True)
350 | )
351 |
352 | clean_code = remove_code_block(response_text)
353 | new_history = history + [(query, response_text)]
354 |
355 | yield (
356 | response_text,
357 | new_history,
358 | send_to_preview(clean_code),
359 | gr.update(active_key="render"),
360 | gr.update(open=False)
361 | )
362 |
363 | # Wire up event handlers
364 | demoCard.click(demo_card_click, outputs=[input])
365 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
366 | system_prompt_modal.ok(
367 | lambda input: ({"system": input}, gr.update(open=False)),
368 | inputs=[systemPromptInput],
369 | outputs=[setting, system_prompt_modal]
370 | )
371 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
372 |
373 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
374 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
375 |
376 | historyBtn.click(
377 | lambda h: (gr.update(open=True), h),
378 | inputs=[history],
379 | outputs=[history_drawer, history_output]
380 | )
381 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
382 |
383 | btn.click(
384 | generate_code,
385 | inputs=[input, setting, history],
386 | outputs=[code_output, history, preview, state_tab, code_drawer]
387 | )
388 |
389 | clear_btn.click(lambda: [], outputs=[history])
390 |
391 | return interface
392 |
393 | # Continue with existing chat interface code...
394 | pipeline = get_pipeline(name)
395 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline, name)
396 | fn = get_fn(name, preprocess, postprocess, api_key, local=local)
397 |
398 | if examples:
399 | formatted_examples = [[example, False] for example in examples]
400 | kwargs["examples"] = formatted_examples
401 |
402 | if pipeline == "chat":
403 | interface = gr.ChatInterface(
404 | fn=fn,
405 | additional_inputs=inputs,
406 | multimodal=True,
407 | **kwargs
408 | )
409 | else:
410 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
411 |
412 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/deepseek_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | from openai import OpenAI
4 | import gradio as gr
5 | from typing import Callable
6 | import modelscope_studio.components.base as ms
7 | import modelscope_studio.components.legacy as legacy
8 | import modelscope_studio.components.antd as antd
9 |
10 | __version__ = "0.0.1"
11 |
12 | # Constants for the coder (code-generation) interface
13 | SystemPrompt = """You are an expert web developer specializing in creating clean, efficient, and modern web applications.
14 | Your task is to write complete, self-contained HTML files that include all necessary CSS and JavaScript.
15 | Focus on:
16 | - Writing clear, maintainable code
17 | - Following best practices
18 | - Creating responsive designs
19 | - Adding appropriate styling and interactivity
20 | Return only the complete HTML code without any additional explanation."""
21 |
22 | DEMO_LIST = [
23 | {
24 | "card": {"index": 0},
25 | "title": "Simple Button",
26 | "description": "Create a button that changes color when clicked"
27 | },
28 | {
29 | "card": {"index": 1},
30 | "title": "Todo List",
31 | "description": "Create a simple todo list with add/remove functionality"
32 | },
33 | {
34 | "card": {"index": 2},
35 | "title": "Timer App",
36 | "description": "Create a countdown timer with start/pause/reset controls"
37 | }
38 | ]
39 |
40 | def get_image_base64(url: str, ext: str):
41 | with open(url, "rb") as image_file:
42 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
43 | return "data:image/" + ext + ";base64," + encoded_string
44 |
45 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
46 | def fn(message, history):
47 | inputs = preprocess(message, history)
48 | client = OpenAI(
49 | base_url="https://api.deepseek.com",
50 | api_key=api_key,
51 | )
52 | try:
53 | completion = client.chat.completions.create(
54 | model=model_name,
55 | messages=inputs["messages"],
56 | stream=True,
57 | )
58 | response_text = ""
59 | reasoning_text = ""
60 | for chunk in completion:
61 | if chunk.choices[0].delta.reasoning_content:
62 | reasoning_text += chunk.choices[0].delta.reasoning_content
63 | else:
64 | delta = chunk.choices[0].delta.content or ""
65 | response_text += delta
66 | yield postprocess(response_text, reasoning_text)
67 | except Exception as e:
68 | error_message = f"Error: {str(e)}"
69 | yield error_message  # fn is a generator; yield so the error surfaces in the chat
70 |
71 | return fn
72 |
73 | def handle_user_msg(message: str):
74 | if type(message) is str:
75 | return message
76 | elif type(message) is dict:
77 | if message["files"] is not None and len(message["files"]) > 0:
78 | ext = os.path.splitext(message["files"][-1])[1].strip(".")
79 | if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
80 | encoded_str = get_image_base64(message["files"][-1], ext)
81 | else:
82 | raise NotImplementedError(f"Not supported file type {ext}")
83 | content = [
84 | {"type": "text", "text": message["text"]},
85 | {
86 | "type": "image_url",
87 | "image_url": {
88 | "url": encoded_str,
89 | }
90 | },
91 | ]
92 | else:
93 | content = message["text"]
94 | return content
95 | else:
96 | raise NotImplementedError
97 |
98 | def get_interface_args(pipeline):
99 | if pipeline == "chat":
100 | inputs = None
101 | outputs = [
102 | gr.Textbox(label="Chain of Thought", lines=10, visible=True),
103 | gr.Textbox(label="Response")
104 | ]
105 |
106 | def preprocess(message, history):
107 | messages = []
108 | files = None
109 | for user_msg, assistant_msg in history:
110 | if assistant_msg is not None:
111 | if isinstance(assistant_msg, dict):
112 | assistant_msg = assistant_msg["visible"][1] # Get the response text
113 | messages.append({"role": "user", "content": handle_user_msg(user_msg)})
114 | messages.append({"role": "assistant", "content": assistant_msg})
115 | else:
116 | files = user_msg
117 | if type(message) is str and files is not None:
118 | message = {"text": message, "files": files}
119 | elif type(message) is dict and files is not None:
120 | if message["files"] is None or len(message["files"]) == 0:
121 | message["files"] = files
122 | messages.append({"role": "user", "content": handle_user_msg(message)})
123 | return {"messages": messages}
124 |
125 | def postprocess(response, reasoning=None):
126 | # Update both textboxes but return only the response for the chat history
127 | return response
128 | else:
129 | # Add other pipeline types as they are needed
130 | raise ValueError(f"Unsupported pipeline type: {pipeline}")
131 | return inputs, outputs, preprocess, postprocess
132 |
133 |
134 | def get_pipeline(model_name):
135 | if model_name == "deepseek-reasoner":
136 | return "chat"
137 | # For other models, assume chat pipeline
138 | return "chat"
139 |
140 |
141 | def registry(
142 | name: str,
143 | token: str | None = None,
144 | examples: list | None = None,
145 | coder: bool = False,
146 | **kwargs
147 | ):
148 | api_key = token or os.environ.get("DEEPSEEK_API_KEY")
149 | if not api_key:
150 | raise ValueError("DEEPSEEK_API_KEY environment variable is not set.")
151 |
152 | if coder:
153 | interface = gr.Blocks(css="""
154 | .left_header {
155 | text-align: center;
156 | margin-bottom: 20px;
157 | }
158 |
159 | .right_panel {
160 | background: white;
161 | border-radius: 8px;
162 | overflow: hidden;
163 | box-shadow: 0 2px 8px rgba(0,0,0,0.15);
164 | }
165 |
166 | .render_header {
167 | background: #f5f5f5;
168 | padding: 8px;
169 | border-bottom: 1px solid #e8e8e8;
170 | }
171 |
172 | .header_btn {
173 | display: inline-block;
174 | width: 12px;
175 | height: 12px;
176 | border-radius: 50%;
177 | margin-right: 8px;
178 | background: #ff5f56;
179 | }
180 |
181 | .header_btn:nth-child(2) {
182 | background: #ffbd2e;
183 | }
184 |
185 | .header_btn:nth-child(3) {
186 | background: #27c93f;
187 | }
188 |
189 | .right_content {
190 | padding: 24px;
191 | height: 920px;
192 | display: flex;
193 | align-items: center;
194 | justify-content: center;
195 | }
196 |
197 | .html_content {
198 | height: 920px;
199 | width: 100%;
200 | }
201 |
202 | .history_chatbot {
203 | height: 100%;
204 | }
205 | """)
206 | with interface:
207 | history = gr.State([])
208 | setting = gr.State({"system": SystemPrompt})
209 |
210 | with ms.Application() as app:
211 | with antd.ConfigProvider():
212 | with antd.Row(gutter=[32, 12]) as layout:
213 | # Left Column
214 | with antd.Col(span=24, md=8):
215 | with antd.Flex(vertical=True, gap="middle", wrap=True):
216 | header = gr.HTML("""
217 |
220 | """)
221 |
222 | input = antd.InputTextarea(
223 | size="large",
224 | allow_clear=True,
225 | placeholder="Describe the web application you want to create"
226 | )
227 | btn = antd.Button("Generate", type="primary", size="large")
228 | clear_btn = antd.Button("Clear History", type="default", size="large")
229 |
230 | antd.Divider("Examples")
231 | with antd.Flex(gap="small", wrap=True):
232 | with ms.Each(DEMO_LIST):
233 | with antd.Card(hoverable=True, as_item="card") as demoCard:
234 | antd.CardMeta()
235 |
236 | antd.Divider("Settings")
237 | with antd.Flex(gap="small", wrap=True):
238 | settingPromptBtn = antd.Button("⚙️ System Prompt", type="default")
239 | codeBtn = antd.Button("🧑💻 View Code", type="default")
240 | historyBtn = antd.Button("📜 History", type="default")
241 |
242 | # Modals and Drawers
243 | with antd.Modal(open=False, title="System Prompt", width="800px") as system_prompt_modal:
244 | systemPromptInput = antd.InputTextarea(SystemPrompt, auto_size=True)
245 |
246 | with antd.Drawer(open=False, title="Code", placement="left", width="750px") as code_drawer:
247 | code_output = legacy.Markdown()
248 |
249 | with antd.Drawer(open=False, title="History", placement="left", width="900px") as history_drawer:
250 | history_output = legacy.Chatbot(
251 | show_label=False,
252 | height=960,
253 | elem_classes="history_chatbot"
254 | )
255 |
256 | # Right Column
257 | with antd.Col(span=24, md=16):
258 | with ms.Div(elem_classes="right_panel"):
259 | gr.HTML('''
260 | <div class="render_header">
261 | <span class="header_btn"></span>
262 | <span class="header_btn"></span>
263 | <span class="header_btn"></span>
264 | </div>
265 | ''')
266 | with antd.Tabs(active_key="empty", render_tab_bar="() => null") as state_tab:
267 | with antd.Tabs.Item(key="empty"):
268 | empty = antd.Empty(
269 | description="Enter your request to generate code",
270 | elem_classes="right_content"
271 | )
272 | with antd.Tabs.Item(key="loading"):
273 | loading = antd.Spin(
274 | True,
275 | tip="Generating code...",
276 | size="large",
277 | elem_classes="right_content"
278 | )
279 | with antd.Tabs.Item(key="render"):
280 | preview = gr.HTML(elem_classes="html_content")
281 |
282 | # Event Handlers
283 | def demo_card_click(e: gr.EventData):
284 | index = e._data['component']['index']
285 | return DEMO_LIST[index]['description']
286 |
287 | def send_to_preview(code):
288 | # Clean the code and escape special characters for HTML
289 | clean_code = code.replace("```html", "").replace("```", "").strip()
290 | escaped_code = clean_code.replace('"', '&quot;').replace('<', '&lt;').replace('>', '&gt;')
291 | return f'''
292 | <iframe srcdoc="{escaped_code}" width="100%" height="920px" frameborder="0"></iframe>
298 | '''
299 |
300 | def generate_code(query, setting, history):
301 | messages = [{"role": "system", "content": setting["system"]}]
302 |
303 | # Add history in alternating user/assistant pattern
304 | for user_msg, assistant_msg in history:
305 | messages.append({"role": "user", "content": user_msg})
306 | messages.append({"role": "assistant", "content": assistant_msg})
307 |
308 | # Add current query
309 | messages.append({"role": "user", "content": query})
310 |
311 | try:
312 | client = OpenAI(
313 | base_url="https://api.deepseek.com",
314 | api_key=api_key,
315 | )
316 |
317 | response = client.chat.completions.create(
318 | model=name,
319 | messages=messages,
320 | stream=True,
321 | )
322 |
323 | code = ""
324 | for chunk in response:
325 | if chunk.choices[0].delta.content:
326 | code += chunk.choices[0].delta.content
327 | # Return all 5 required outputs
328 | yield (
329 | f"```html\n{code}\n```", # code_output (modelscopemarkdown)
330 | history, # state
331 | None, # preview (html)
332 | gr.update(active_key="loading"), # state_tab (antdtabs)
333 | gr.update(open=True) # code_drawer (antddrawer)
334 | )
335 |
336 | new_history = history + [(query, code)]
337 |
338 | # Final yield with all outputs
339 | yield (
340 | f"```html\n{code}\n```", # code_output (modelscopemarkdown)
341 | new_history, # state
342 | send_to_preview(code), # preview (html)
343 | gr.update(active_key="render"), # state_tab (antdtabs)
344 | gr.update(open=False) # code_drawer (antddrawer)
345 | )
346 |
347 | except Exception as e:
348 | error_msg = f"Error: {str(e)}"
349 | yield (
350 | error_msg,
351 | history,
352 | None,
353 | gr.update(active_key="empty"),
354 | gr.update(open=True)
355 | )
356 |
357 | # Wire up event handlers
358 | demoCard.click(demo_card_click, outputs=[input])
359 | settingPromptBtn.click(lambda: gr.update(open=True), outputs=[system_prompt_modal])
360 | system_prompt_modal.ok(
361 | lambda input: ({"system": input}, gr.update(open=False)),
362 | inputs=[systemPromptInput],
363 | outputs=[setting, system_prompt_modal]
364 | )
365 | system_prompt_modal.cancel(lambda: gr.update(open=False), outputs=[system_prompt_modal])
366 |
367 | codeBtn.click(lambda: gr.update(open=True), outputs=[code_drawer])
368 | code_drawer.close(lambda: gr.update(open=False), outputs=[code_drawer])
369 |
370 | historyBtn.click(
371 | lambda h: (gr.update(open=True), h),
372 | inputs=[history],
373 | outputs=[history_drawer, history_output]
374 | )
375 | history_drawer.close(lambda: gr.update(open=False), outputs=[history_drawer])
376 |
377 | btn.click(
378 | generate_code,
379 | inputs=[input, setting, history],
380 | outputs=[code_output, history, preview, state_tab, code_drawer]
381 | )
382 |
383 | clear_btn.click(lambda: [], outputs=[history])
384 |
385 | return interface
386 |
387 | # Continue with existing chat interface code...
388 | pipeline = get_pipeline(name)
389 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline)
390 | fn = get_fn(name, preprocess, postprocess, api_key)
391 |
392 | if examples:
393 | kwargs["examples"] = examples
394 |
395 | if pipeline == "chat":
396 | interface = gr.ChatInterface(fn=fn, **kwargs)
397 | else:
398 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
399 |
400 | return interface
--------------------------------------------------------------------------------
/ai_gradio/providers/replicate_gradio.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import replicate
3 | import asyncio
4 | import os
5 | from typing import Callable, Dict, Any, List, Tuple
6 | import httpx
7 | from PIL import Image
8 | import io
9 | import base64
10 | import numpy as np
11 | import tempfile
12 | import time
13 |
14 | def resize_image_if_needed(image, max_size=1024):
15 | """Resize image if either dimension exceeds max_size while maintaining aspect ratio"""
16 | if isinstance(image, str) and image.startswith('data:image'):
17 | return image # Already a data URI, skip processing
18 |
19 | if isinstance(image, np.ndarray):
20 | image = Image.fromarray(image)
21 | elif not isinstance(image, Image.Image):
22 | image = Image.open(image)
23 |
24 | # Get original dimensions
25 | width, height = image.size
26 |
27 | # Calculate new dimensions if needed
28 | if width > max_size or height > max_size:
29 | if width > height:
30 | new_width = max_size
31 | new_height = int(height * (max_size / width))
32 | else:
33 | new_height = max_size
34 | new_width = int(width * (max_size / height))
35 |
36 | image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
37 |
38 | # Convert to RGB if necessary
39 | if image.mode != 'RGB':
40 | image = image.convert('RGB')
41 |
42 | # Convert to base64
43 | buffered = io.BytesIO()
44 | image.save(buffered, format="JPEG", quality=85)
45 | img_str = base64.b64encode(buffered.getvalue()).decode()
46 | return f"data:image/jpeg;base64,{img_str}"
47 |
48 | def bytes_to_image(byte_data):
49 | """Convert bytes to PIL Image and ensure we get fresh data"""
50 | if isinstance(byte_data, bytes):
51 | return Image.open(io.BytesIO(byte_data))
52 | # For file-like objects
53 | if hasattr(byte_data, 'seek'):
54 | byte_data.seek(0)
55 | return Image.open(io.BytesIO(byte_data.read()))
56 |
57 | def save_bytes_to_video(video_bytes):
58 | """Save video bytes to a temporary file and return the path"""
59 | if not isinstance(video_bytes, bytes):
60 | raise ValueError(f"Expected bytes input, got {type(video_bytes)}")
61 |
62 | # Create a temporary file with .mp4 extension
63 | temp_dir = tempfile.gettempdir()
64 | temp_path = os.path.join(temp_dir, f"temp_{int(time.time())}_{os.urandom(4).hex()}.mp4")
65 |
66 | try:
67 | # Write the bytes to the temporary file
68 | with open(temp_path, "wb") as f:
69 | f.write(video_bytes)
70 |
71 | # Ensure the file exists and has content
72 | if not os.path.exists(temp_path) or os.path.getsize(temp_path) == 0:
73 | raise ValueError("Failed to save video file or file is empty")
74 |
75 | return str(temp_path) # Return string path as expected by Gradio
76 | except Exception:
77 | # On failure, remove the partial file before re-raising; on success the caller still needs the file on disk
78 | if os.path.exists(temp_path): os.remove(temp_path)
79 | raise
80 |
81 | PIPELINE_REGISTRY = {
82 | "text-to-image": {
83 | "inputs": [
84 | ("prompt", gr.Textbox, {"label": "Prompt"}),
85 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}),
86 | ("width", gr.Number, {"label": "Width", "value": 1024, "minimum": 512, "maximum": 2048, "step": 64, "optional": True}),
87 | ("height", gr.Number, {"label": "Height", "value": 1024, "minimum": 512, "maximum": 2048, "step": 64, "optional": True}),
88 | ("num_outputs", gr.Number, {"label": "Number of Images", "value": 1, "minimum": 1, "maximum": 4, "step": 1, "optional": True}),
89 | ("scheduler", gr.Dropdown, {"label": "Scheduler", "choices": ["DPM++ 2M", "DPM++ 2M Karras", "DPM++ 2M SDE", "DPM++ 2M SDE Karras"], "optional": True}),
90 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}),
91 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}),
92 | ("seed", gr.Number, {"label": "Seed", "optional": True})
93 | ],
94 | "outputs": [("images", gr.Gallery, {})],
95 | "preprocess": lambda *args: {
96 | k: v for k, v in zip([
97 | "prompt", "negative_prompt", "width", "height", "num_outputs",
98 | "scheduler", "num_inference_steps", "guidance_scale", "seed"
99 | ], args) if v is not None and v != ""
100 | },
101 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)]
102 | },
103 |
104 | "image-to-image": {
105 | "inputs": [
106 | ("prompt", gr.Textbox, {"label": "Prompt"}),
107 | ("image", gr.Image, {"label": "Input Image", "type": "pil"}),
108 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}),
109 | ("strength", gr.Slider, {"label": "Strength", "minimum": 0, "maximum": 1, "value": 0.7, "optional": True}),
110 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}),
111 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}),
112 | ("seed", gr.Number, {"label": "Seed", "optional": True})
113 | ],
114 | "outputs": [("images", gr.Gallery, {})],
115 | "preprocess": lambda *args: {
116 | k: (resize_image_if_needed(v) if k == "image" else v)
117 | for k, v in zip([
118 | "prompt", "image", "negative_prompt", "strength",
119 | "num_inference_steps", "guidance_scale", "seed"
120 | ], args) if v is not None and v != ""
121 | },
122 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)]
123 | },
124 |
125 | "control-net": {
126 | "inputs": [
127 | ("prompt", gr.Textbox, {"label": "Prompt"}),
128 | ("control_image", gr.Image, {"label": "Control Image", "type": "pil"}),
129 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}),
130 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}),
131 | ("control_guidance_scale", gr.Slider, {"label": "Control Guidance Scale", "minimum": 1, "maximum": 20, "value": 1.5, "optional": True}),
132 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}),
133 | ("seed", gr.Number, {"label": "Seed", "optional": True})
134 | ],
135 | "outputs": [("images", gr.Gallery, {})],
136 | "preprocess": lambda *args: {
137 | k: (resize_image_if_needed(v) if k == "control_image" else v)
138 | for k, v in zip([
139 | "prompt", "control_image", "negative_prompt", "guidance_scale",
140 | "control_guidance_scale", "num_inference_steps", "seed"
141 | ], args) if v is not None and v != ""
142 | },
143 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)]
144 | },
145 |
146 | "inpainting": {
147 | "inputs": [
148 | ("prompt", gr.Textbox, {"label": "Prompt"}),
149 | ("image", gr.Image, {"label": "Original Image", "type": "pil"}),
150 | ("mask", gr.Image, {"label": "Mask Image", "type": "pil"}),
151 | ("negative_prompt", gr.Textbox, {"label": "Negative Prompt", "optional": True}),
152 | ("num_inference_steps", gr.Slider, {"label": "Steps", "minimum": 1, "maximum": 100, "value": 30, "optional": True}),
153 | ("guidance_scale", gr.Slider, {"label": "Guidance Scale", "minimum": 1, "maximum": 20, "value": 7.5, "optional": True}),
154 | ("seed", gr.Number, {"label": "Seed", "optional": True})
155 | ],
156 | "outputs": [("images", gr.Gallery, {})],
157 | "preprocess": lambda *args: {
158 | k: (resize_image_if_needed(v) if k in ["image", "mask"] else v)
159 | for k, v in zip([
160 | "prompt", "image", "mask", "negative_prompt",
161 | "num_inference_steps", "guidance_scale", "seed"
162 | ], args) if v is not None and v != ""
163 | },
164 | "postprocess": lambda x: [bytes_to_image(img) for img in x] if isinstance(x, list) else [bytes_to_image(x)]
165 | },
166 |
167 | "text-to-video": {
168 | "inputs": [
169 | ("prompt", gr.Textbox, {
170 | "label": "Prompt",
171 | "value": "A cat walks on the grass, realistic style.",
172 | "info": "Text prompt to generate video."
173 | }),
174 | ("height", gr.Number, {
175 | "label": "Height",
176 | "value": 480,
177 | "minimum": 1,
178 | "info": "Height of the video in pixels."
179 | }),
180 | ("width", gr.Number, {
181 | "label": "Width",
182 | "value": 854,
183 | "minimum": 1,
184 | "info": "Width of the video in pixels."
185 | }),
186 | ("video_length", gr.Number, {
187 | "label": "Video Length",
188 | "value": 129,
189 | "minimum": 1,
190 | "info": "Length of the video in frames."
191 | }),
192 | ("infer_steps", gr.Number, {
193 | "label": "Infer Steps",
194 | "value": 30,
195 | "minimum": 1,
196 | "maximum": 50,
197 | "info": "Number of inference steps."
198 | }),
199 | ("flow_shift", gr.Number, {
200 | "label": "Flow Shift",
201 | "value": 7,
202 | "info": "Flow-shift parameter."
203 | }),
204 | ("embedded_guidance_scale", gr.Slider, {
205 | "label": "Embedded Guidance Scale",
206 | "value": 6,
207 | "minimum": 1,
208 | "maximum": 6,
209 | "info": "Embedded guidance scale for generation."
210 | }),
211 | ("seed", gr.Number, {
212 | "label": "Seed",
213 | "optional": True,
214 | "info": "Random seed for reproducibility."
215 | })
216 | ],
217 | "outputs": [
218 | ("video", gr.Video, {
219 | "format": "mp4",
220 | "autoplay": True,
221 | "show_label": True,
222 | "label": "Generated Video",
223 | "height": 480,
224 | "width": 854,
225 | "interactive": False,
226 | "show_download_button": True
227 | })
228 | ],
229 | "preprocess": lambda *args: {
230 | k: (int(v) if k in ["height", "width", "video_length", "infer_steps", "seed"] else
231 | float(v) if k in ["flow_shift", "embedded_guidance_scale"] else v)
232 | for k, v in zip([
233 | "prompt", "height", "width", "video_length",
234 | "infer_steps", "flow_shift", "embedded_guidance_scale", "seed"
235 | ], args) if v is not None and v != ""
236 | },
237 | "postprocess": lambda x: (
238 | x.url if hasattr(x, 'url')
239 | else (lambda p: os.remove(p) or p)(x) # Delete file after getting path
240 | ),
241 | },
242 |
243 | "text-generation": {
244 | "inputs": [
245 | ("message", gr.Textbox, {
246 | "label": "Message",
247 | "lines": 3,
248 | "placeholder": "Enter your message here..."
249 | }),
250 | ],
251 | "outputs": [
252 | ("response", gr.Textbox, {
253 | "label": "Assistant",
254 | "lines": 10,
255 | "show_copy_button": True
256 | })
257 | ],
258 | "preprocess": lambda *args: {
259 | "prompt": args[0] if args[0] is not None and args[0] != "" else ""
260 | },
261 | "postprocess": lambda x: x,
262 | "is_chat": True # New flag to indicate this is a chat interface
263 | },
264 | }
265 |
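# A sketch of how to extend this registry (assumption, mirroring the entries above): a new
# pipeline entry needs the same four keys, "inputs" (name/component/config triples),
# "outputs", and the "preprocess"/"postprocess" callables, plus a row in MODEL_TO_PIPELINE
# below so get_pipeline() can route model names to it.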
266 | MODEL_TO_PIPELINE = {
267 | "stability-ai/sdxl": "text-to-image",
268 | "black-forest-labs/flux-pro": "text-to-image",
269 | "stability-ai/stable-diffusion": "text-to-image",
270 |
271 | "black-forest-labs/flux-depth-pro": "control-net",
272 | "black-forest-labs/flux-canny-pro": "control-net",
273 | "black-forest-labs/flux-depth-dev": "control-net",
274 |
275 | "black-forest-labs/flux-fill-pro": "inpainting",
276 | "stability-ai/stable-diffusion-inpainting": "inpainting",
277 | "tencent/hunyuan-video:140176772be3b423d14fdaf5403e6d4e38b85646ccad0c3fd2ed07c211f0cad1": "text-to-video",
278 | "deepseek-ai/deepseek-r1": "text-generation",
279 | }
280 |
281 | def create_component(comp_type: type, name: str, config: Dict[str, Any]) -> gr.components.Component:
282 | # Remove 'optional' from config as it's not a valid Gradio parameter
283 | config = config.copy()
284 | is_optional = config.pop('optional', False)
285 |
286 | # Add "(Optional)" to label if the field is optional
287 | if is_optional:
288 | label = config.get('label', name)
289 | config['label'] = f"{label} (Optional)"
290 |
291 | return comp_type(label=config.get("label", name), **{k:v for k,v in config.items() if k != "label"})
292 |
293 | def get_pipeline(model: str) -> str:
294 | return MODEL_TO_PIPELINE.get(model, "text-to-image")
295 |
296 | def get_interface_args(pipeline: str) -> Tuple[List, List, Callable, Callable]:
297 | if pipeline not in PIPELINE_REGISTRY:
298 | raise ValueError(f"Unsupported pipeline: {pipeline}")
299 |
300 | config = PIPELINE_REGISTRY[pipeline]
301 |
302 | inputs = [create_component(comp_type, name, conf)
303 | for name, comp_type, conf in config["inputs"]]
304 |
305 | outputs = [create_component(comp_type, name, conf)
306 | for name, comp_type, conf in config["outputs"]]
307 |
308 | return inputs, outputs, config["preprocess"], config["postprocess"]
309 |
310 | async def async_run_with_timeout(model_name: str, args: dict):
311 | try:
312 | stream = await replicate.async_stream(
313 | model_name,
314 | input=args
315 | )
316 | async for output in stream:
317 | yield output
318 | except Exception as e:
319 | print(f"Error during model prediction: {str(e)}")
320 | raise gr.Error(f"Model prediction failed: {str(e)}")
321 |
322 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable):
323 | async def fn(*args):
324 | try:
325 | args = preprocess(*args)
326 | if model_name == "deepseek-ai/deepseek-r1":
327 | response = ""
328 | async for chunk in async_run_with_timeout(model_name, args):
329 | chunk = str(chunk)
330 | # Replace XML-like tags with square bracket versions
331 | chunk = (chunk.replace("", "[think]")
332 | .replace("", "[/think]")
333 | .replace("", "[Answer]")
334 | .replace("", "[/Answer]"))
335 | response += chunk
336 | return response.strip()
337 | output = await async_run_with_timeout(model_name, args)
338 | return postprocess(output)
339 | except Exception as e:
340 | raise gr.Error(f"Error: {str(e)}")
341 | return fn
342 |
343 | def registry(name: str | Dict, token: str | None = None, inputs=None, outputs=None, src=None, accept_token: bool = False, **kwargs) -> gr.Interface:
344 | """
345 | Create a Gradio Interface for a model on Replicate.
346 | Parameters:
347 | - name (str | Dict): The name of the model on Replicate, or a dict with model info.
348 | - token (str, optional): The API token for the model on Replicate.
349 | - inputs (List[gr.Component], optional): The input components to use instead of the default.
350 | - outputs (List[gr.Component], optional): The output components to use instead of the default.
351 | - src (callable, optional): Ignored, used by gr.load for routing.
352 | - accept_token (bool, optional): Whether to accept a token input field.
353 | Returns:
354 | gr.Interface or gr.ChatInterface: A Gradio interface for the model.
355 | """
356 | # Handle both string names and dict configurations
357 | if isinstance(name, dict):
358 | model_name = name.get('name', name.get('model_name', ''))
359 | else:
360 | model_name = name
361 |
362 | if token:
363 | os.environ["REPLICATE_API_TOKEN"] = token
364 |
365 | pipeline = get_pipeline(model_name)
366 | inputs_, outputs_, preprocess, postprocess = get_interface_args(pipeline)
367 |
368 | # Add token input if accept_token is True
369 | if accept_token:
370 | token_input = gr.Textbox(label="API Token", type="password")
371 | inputs_ = [token_input] + inputs_
372 |
373 | # Modify preprocess function to handle token
374 | original_preprocess = preprocess
375 | def new_preprocess(token, *args):
376 | if token:
377 | os.environ["REPLICATE_API_TOKEN"] = token
378 | return original_preprocess(*args)
379 | preprocess = new_preprocess
380 |
381 | inputs, outputs = inputs or inputs_, outputs or outputs_
382 | fn = get_fn(model_name, preprocess, postprocess)
383 |
384 | # Use ChatInterface for text-generation models
385 | if pipeline == "text-generation":
386 | return gr.ChatInterface(
387 | fn=fn,
388 | **kwargs
389 | )
390 |
391 | return gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
392 |
393 |
394 | __version__ = "0.1.0"
--------------------------------------------------------------------------------
/ai_gradio/providers/__init__.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | def custom_load(name: str, src: dict, **kwargs):
4 | # Only use custom loading if name contains provider prefix
5 | if ':' in name:
6 | provider, model = name.split(':', 1)  # split on the first colon only; some model ids contain ':'
7 | # Create provider-specific model key
8 | model_key = f"{provider}:{model}"
9 |
10 | if model_key not in src:
11 | available_models = [k for k in src.keys()]
12 | raise ValueError(f"Model {model_key} not found. Available models: {available_models}")
13 | return src[model_key](name=model, **kwargs)
14 |
15 | # Fall back to original gradio behavior if no provider prefix
16 | return original_load(name, src, **kwargs)
17 |
18 | # Store original load function before overriding
19 | original_load = gr.load
20 | gr.load = custom_load
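# Illustrative effect of the override above (a sketch): gr.load("openai:gpt-4o", src=registry)
# resolves the "openai:" prefix through custom_load and dispatches to the OpenAI provider
# registered below, while names without a provider prefix fall back to the stock gr.load.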
21 |
22 | registry = {}
23 |
24 |
25 | try:
26 | from .openai_gradio import registry as openai_registry
27 | registry.update({f"openai:{k}": openai_registry for k in [
28 | "gpt-4o-2024-11-20",
29 | "gpt-4o",
30 | "gpt-4o-2024-08-06",
31 | "gpt-4o-2024-05-13",
32 | "chatgpt-4o-latest",
33 | "gpt-4o-mini",
34 | "gpt-4o-mini-2024-07-18",
35 | "o1-preview",
36 | "o1-preview-2024-09-12",
37 | "o1-mini",
38 | "o1-mini-2024-09-12",
39 | "gpt-4-turbo",
40 | "gpt-4-turbo-2024-04-09",
41 | "gpt-4-turbo-preview",
42 | "gpt-4-0125-preview",
43 | "gpt-4-1106-preview",
44 | "gpt-4",
45 | "gpt-4-0613",
46 | "o1-2024-12-17",
47 | "gpt-4o-realtime-preview-2024-10-01",
48 | "gpt-4o-realtime-preview",
49 | "gpt-4o-realtime-preview-2024-12-17",
50 | "gpt-4o-mini-realtime-preview",
51 | "gpt-4o-mini-realtime-preview-2024-12-17",
52 | "o3-mini-2025-01-31"
53 | ]})
54 | except ImportError:
55 | pass
56 |
57 | try:
58 | from .gemini_gradio import registry as gemini_registry
59 | registry.update({f"gemini:{k}": gemini_registry for k in [
60 | 'gemini-1.5-flash',
61 | 'gemini-1.5-flash-8b',
62 | 'gemini-1.5-pro',
63 | 'gemini-exp-1114',
64 | 'gemini-exp-1121',
65 | 'gemini-exp-1206',
66 | 'gemini-2.0-flash-exp',
67 | 'gemini-2.0-flash-thinking-exp-1219',
68 | 'gemini-2.0-flash-thinking-exp-01-21',
69 | 'gemini-2.0-pro-exp-02-05',
70 | 'gemini-2.0-flash-lite-preview-02-05'
71 | ]})
72 | except ImportError:
73 | pass
74 |
75 | try:
76 | from .crewai_gradio import registry as crewai_registry
77 | # Add CrewAI models with their own prefix
78 | registry.update({f"crewai:{k}": crewai_registry for k in ['gpt-4-turbo', 'gpt-4', 'gpt-3.5-turbo']})
79 | except ImportError:
80 | pass
81 |
82 | try:
83 | from .anthropic_gradio import registry as anthropic_registry
84 | registry.update({f"anthropic:{k}": anthropic_registry for k in [
85 | 'claude-3-5-sonnet-20241022',
86 | 'claude-3-5-haiku-20241022',
87 | 'claude-3-opus-20240229',
88 | 'claude-3-sonnet-20240229',
89 | 'claude-3-haiku-20240307',
90 | ]})
91 | except ImportError:
92 | pass
93 |
94 | try:
95 | from .lumaai_gradio import registry as lumaai_registry
96 | registry.update({f"lumaai:{k}": lumaai_registry for k in [
97 | 'dream-machine',
98 | 'photon-1',
99 | 'photon-flash-1'
100 | ]})
101 | except ImportError:
102 | pass
103 |
104 | try:
105 | from .xai_gradio import registry as xai_registry
106 | registry.update({f"xai:{k}": xai_registry for k in [
107 | 'grok-beta',
108 | 'grok-vision-beta'
109 | ]})
110 | except ImportError:
111 | pass
112 |
113 | try:
114 | from .cohere_gradio import registry as cohere_registry
115 | registry.update({f"cohere:{k}": cohere_registry for k in [
116 | 'command-r7b-12-2024',
117 | 'command-light',
118 | 'command-nightly',
119 | 'command-light-nightly'
120 | ]})
121 | except ImportError:
122 | pass
123 |
124 | try:
125 | from .sambanova_gradio import registry as sambanova_registry
126 | registry.update({f"sambanova:{k}": sambanova_registry for k in [
127 | 'Meta-Llama-3.1-405B-Instruct',
128 | 'Meta-Llama-3.1-8B-Instruct',
129 | 'Meta-Llama-3.1-70B-Instruct',
130 | 'Meta-Llama-3.1-405B-Instruct-Preview',
131 | 'Meta-Llama-3.1-8B-Instruct-Preview',
132 | 'Meta-Llama-3.3-70B-Instruct',
133 | 'Meta-Llama-3.2-3B-Instruct',
134 | ]})
135 | except ImportError:
136 | pass
137 |
138 | try:
139 | from .hyperbolic_gradio import registry as hyperbolic_registry
140 | registry.update({f"hyperbolic:{k}": hyperbolic_registry for k in [
141 | 'Qwen/Qwen2.5-Coder-32B-Instruct',
142 | 'meta-llama/Llama-3.2-3B-Instruct',
143 | 'meta-llama/Meta-Llama-3.1-8B-Instruct',
144 | 'meta-llama/Meta-Llama-3.1-70B-Instruct',
145 | 'meta-llama/Meta-Llama-3-70B-Instruct',
146 | 'NousResearch/Hermes-3-Llama-3.1-70B',
147 | 'Qwen/Qwen2.5-72B-Instruct',
148 | 'deepseek-ai/DeepSeek-V2.5',
149 | 'meta-llama/Meta-Llama-3.1-405B-Instruct',
150 | 'Qwen/QwQ-32B-Preview',
151 | 'meta-llama/Llama-3.3-70B-Instruct',
152 | 'deepseek-ai/DeepSeek-V3',
153 | 'deepseek-ai/DeepSeek-R1',
154 | 'deepseek-ai/DeepSeek-R1-Zero'
155 | ]})
156 | except ImportError:
157 | pass
158 |
159 | try:
160 | from .qwen_gradio import registry as qwen_registry
161 | registry.update({f"qwen:{k}": qwen_registry for k in [
162 | "qwen-turbo-latest",
163 | "qwen-turbo",
164 | "qwen-plus",
165 | "qwen-max",
166 | "qwen1.5-110b-chat",
167 | "qwen1.5-72b-chat",
168 | "qwen1.5-32b-chat",
169 | "qwen1.5-14b-chat",
170 | "qwen1.5-7b-chat",
171 | "qwq-32b-preview",
172 | 'qvq-72b-preview',
173 | 'qwen2.5-14b-instruct-1m',
174 | 'qwen2.5-7b-instruct-1m',
175 | 'qwen-max-0125'
176 | ]})
177 | except ImportError:
178 | pass
179 |
180 | try:
181 | from .fireworks_gradio import registry as fireworks_registry
182 | registry.update({f"fireworks:{k}": fireworks_registry for k in [
183 | 'whisper-v3',
184 | 'whisper-v3-turbo',
185 | 'f1-preview',
186 | 'f1-mini'
187 | ]})
188 | except ImportError:
189 | pass
190 |
191 | try:
192 | from .together_gradio import registry as together_registry
193 | registry.update({f"together:{k}": together_registry for k in [
194 | # Vision Models
195 | 'meta-llama/Llama-Vision-Free',
196 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo',
197 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo',
198 |
199 | # Llama 3 Series
200 | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
201 | 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
202 | 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo',
203 | 'meta-llama/Meta-Llama-3-8B-Instruct-Turbo',
204 | 'meta-llama/Meta-Llama-3-70B-Instruct-Turbo',
205 | 'meta-llama/Llama-3.2-3B-Instruct-Turbo',
206 | 'meta-llama/Meta-Llama-3-8B-Instruct-Lite',
207 | 'meta-llama/Meta-Llama-3-70B-Instruct-Lite',
208 | 'meta-llama/Llama-3-8b-chat-hf',
209 | 'meta-llama/Llama-3-70b-chat-hf',
210 | 'meta-llama/Llama-3.3-70B-Instruct-Turbo',
211 |
212 | # Other Large Models
213 | 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
214 | 'Qwen/Qwen2.5-Coder-32B-Instruct',
215 | 'microsoft/WizardLM-2-8x22B',
216 | 'databricks/dbrx-instruct',
217 |
218 | # Gemma Models
219 | 'google/gemma-2-27b-it',
220 | 'google/gemma-2-9b-it',
221 | 'google/gemma-2b-it',
222 |
223 | # Mixtral Models
224 | 'mistralai/Mixtral-8x7B-Instruct-v0.1',
225 | 'mistralai/Mixtral-8x22B-Instruct-v0.1',
226 |
227 | # Qwen Models
228 | 'Qwen/Qwen2.5-7B-Instruct-Turbo',
229 | 'Qwen/Qwen2.5-72B-Instruct-Turbo',
230 | 'Qwen/Qwen2-72B-Instruct',
231 |
232 | # Other Models
233 | 'deepseek-ai/deepseek-llm-67b-chat',
234 | 'Gryphe/MythoMax-L2-13b',
235 | 'meta-llama/Llama-2-13b-chat-hf',
236 | 'mistralai/Mistral-7B-Instruct-v0.1',
237 | 'mistralai/Mistral-7B-Instruct-v0.2',
238 | 'mistralai/Mistral-7B-Instruct-v0.3',
239 | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO',
240 | 'togethercomputer/StripedHyena-Nous-7B',
241 | 'upstage/SOLAR-10.7B-Instruct-v1.0',
242 | 'deepseek-ai/DeepSeek-V3',
243 | 'deepseek-ai/DeepSeek-R1',
244 | 'mistralai/Mistral-Small-24B-Instruct-2501'
245 | ]})
246 | except ImportError:
247 | pass
248 |
249 | try:
250 | from .deepseek_gradio import registry as deepseek_registry
251 | registry.update({f"deepseek:{k}": deepseek_registry for k in [
252 | 'deepseek-chat',
253 | 'deepseek-coder',
254 | 'deepseek-vision',
255 | 'deepseek-reasoner'
256 | ]})
257 | except ImportError:
258 | pass
259 |
260 | try:
261 | from .smolagents_gradio import registry as smolagents_registry
262 | registry.update({f"smolagents:{k}": smolagents_registry for k in [
263 | 'Qwen/Qwen2.5-72B-Instruct',
264 | 'Qwen/Qwen2.5-4B-Instruct',
265 | 'Qwen/Qwen2.5-1.8B-Instruct',
266 | 'meta-llama/Llama-3.3-70B-Instruct',
267 | 'meta-llama/Llama-3.1-8B-Instruct'
268 | ]})
269 | except ImportError:
270 | pass
271 |
272 | try:
273 | from .groq_gradio import registry as groq_registry
274 | registry.update({f"groq:{k}": groq_registry for k in [
275 | "llama3-groq-8b-8192-tool-use-preview",
276 | "llama3-groq-70b-8192-tool-use-preview",
277 | "llama-3.2-1b-preview",
278 | "llama-3.2-3b-preview",
279 | "llama-3.2-11b-vision-preview",
280 | "llama-3.2-90b-vision-preview",
281 | "mixtral-8x7b-32768",
282 | "gemma2-9b-it",
283 | "gemma-7b-it",
284 | "llama-3.3-70b-versatile",
285 | "llama-3.3-70b-specdec",
286 | "deepseek-r1-distill-llama-70b"
287 | ]})
288 | except ImportError:
289 | pass
290 |
291 | try:
292 | from .browser_use_gradio import registry as browser_use_registry
293 | registry.update({f"browser:{k}": browser_use_registry for k in [
294 | "gpt-4o-2024-11-20",
295 | "gpt-4o",
296 | "gpt-4o-2024-08-06",
297 | "gpt-4o-2024-05-13",
298 | "chatgpt-4o-latest",
299 | "gpt-4o-mini",
300 | "gpt-4o-mini-2024-07-18",
301 | "o1-preview",
302 | "o1-preview-2024-09-12",
303 | "o1-mini",
304 | "o1-mini-2024-09-12",
305 | "gpt-4-turbo",
306 | "gpt-4-turbo-2024-04-09",
307 | "gpt-4-turbo-preview",
308 | "gpt-4-0125-preview",
309 | "gpt-4-1106-preview",
310 | "gpt-4",
311 | "gpt-4-0613",
312 | "o1-2024-12-17",
313 | "gpt-4o-realtime-preview-2024-10-01",
314 | "gpt-4o-realtime-preview",
315 | "gpt-4o-realtime-preview-2024-12-17",
316 | "gpt-4o-mini-realtime-preview",
317 | "gpt-4o-mini-realtime-preview-2024-12-17",
318 | "gpt-3.5-turbo",
319 | "o3-mini-2025-01-31"
320 | ]})
321 | except ImportError:
322 | pass
323 |
324 | try:
325 | from .swarms_gradio import registry as swarms_registry
326 | registry.update({f"swarms:{k}": swarms_registry for k in [
327 | 'gpt-4-turbo',
328 | 'gpt-4o-mini',
329 | 'gpt-4',
330 | 'gpt-3.5-turbo'
331 | ]})
332 | except ImportError:
333 | pass
334 |
335 | try:
336 | from .transformers_gradio import registry as transformers_registry
337 | registry.update({f"transformers:{k}": transformers_registry for k in [
338 | "phi-4",
339 | "tulu-3",
340 | "olmo-2-13b",
341 | "smolvlm",
342 | "moondream",
343 | # Add other default transformers models here
344 | ]})
345 | except ImportError:
346 | pass
347 |
348 | try:
349 | from .jupyter_agent import registry as jupyter_registry
350 | registry.update({f"jupyter:{k}": jupyter_registry for k in [
351 | 'meta-llama/Llama-3.2-3B-Instruct',
352 | 'meta-llama/Llama-3.1-8B-Instruct',
353 | 'meta-llama/Llama-3.1-70B-Instruct'
354 | ]})
355 | except ImportError:
356 | pass
357 |
358 | try:
359 | from .langchain_gradio import registry as langchain_registry
360 | registry.update({f"langchain:{k}": langchain_registry for k in [
361 | 'gpt-4-turbo',
362 | 'gpt-4',
363 | 'gpt-3.5-turbo',
364 | 'gpt-3.5-turbo-0125'
365 | ]})
366 | except ImportError as e:
367 | print(f"Failed to import LangChain registry: {e}")
368 | # Optionally add more detailed error handling here
369 |
370 | try:
371 | from .mistral_gradio import registry as mistral_registry
372 | registry.update({f"mistral:{k}": mistral_registry for k in [
373 | "mistral-large-latest",
374 | "pixtral-large-latest",
375 | "ministral-3b-latest",
376 | "ministral-8b-latest",
377 | "mistral-small-latest",
378 | "codestral-latest",
379 | "mistral-embed",
380 | "mistral-moderation-latest",
381 | "pixtral-12b-2409",
382 | "open-mistral-nemo",
383 | "open-codestral-mamba",
384 | ]})
385 | except ImportError:
386 | pass
387 |
388 | try:
389 | from .nvidia_gradio import registry as nvidia_registry
390 | registry.update({f"nvidia:{k}": nvidia_registry for k in [
391 | "nvidia/llama3-chatqa-1.5-70b",
392 | "nvidia/cosmos-nemotron-34b",
393 | "nvidia/llama3-chatqa-1.5-8b",
394 | "nvidia-nemotron-4-340b-instruct",
395 | "meta/llama-3.1-70b-instruct",
396 | "meta/codellama-70b",
397 | "meta/llama2-70b",
398 | "meta/llama3-8b",
399 | "meta/llama3-70b",
400 | "mistralai/codestral-22b-instruct-v0.1",
401 | "mistralai/mathstral-7b-v0.1",
402 | "mistralai/mistral-large-2-instruct",
403 | "mistralai/mistral-7b-instruct",
404 | "mistralai/mistral-7b-instruct-v0.3",
405 | "mistralai/mixtral-8x7b-instruct",
406 | "mistralai/mixtral-8x22b-instruct",
407 | "mistralai/mistral-large",
408 | "google/gemma-2b",
409 | "google/gemma-7b",
410 | "google/gemma-2-2b-it",
411 | "google/gemma-2-9b-it",
412 | "google/gemma-2-27b-it",
413 | "google/codegemma-1.1-7b",
414 | "google/codegemma-7b",
415 | "google/recurrentgemma-2b",
416 | "google/shieldgemma-9b",
417 | "microsoft/phi-3-medium-128k-instruct",
418 | "microsoft/phi-3-medium-4k-instruct",
419 | "microsoft/phi-3-mini-128k-instruct",
420 | "microsoft/phi-3-mini-4k-instruct",
421 | "microsoft/phi-3-small-128k-instruct",
422 | "microsoft/phi-3-small-8k-instruct",
423 | "qwen/qwen2-7b-instruct",
424 | "databricks/dbrx-instruct",
425 | "deepseek-ai/deepseek-coder-6.7b-instruct",
426 | "upstage/solar-10.7b-instruct",
427 | "snowflake/arctic",
428 | "qwen/qwen2.5-7b-instruct",
429 | "deepseek-ai/deepseek-r1"
430 | ]})
431 | except ImportError:
432 | pass
433 |
434 | try:
435 | from .minimax_gradio import registry as minimax_registry
436 | registry.update({f"minimax:{k}": minimax_registry for k in [
437 | "MiniMax-Text-01",
438 | ]})
439 | except ImportError:
440 | pass
441 |
442 | try:
443 | from .kokoro_gradio import registry as kokoro_registry
444 | registry.update({f"kokoro:{k}": kokoro_registry for k in [
445 | "kokoro-v0_19"
446 | ]})
447 | except ImportError:
448 | pass
449 |
450 | try:
451 | from .perplexity_gradio import registry as perplexity_registry
452 | registry.update({f"perplexity:{k}": perplexity_registry for k in [
453 | 'sonar-pro',
454 | 'sonar',
455 | 'sonar-reasoning'
456 | ]})
457 | except ImportError:
458 | pass
459 |
460 | try:
461 | from .cerebras_gradio import registry as cerebras_registry
462 | registry.update({f"cerebras:{k}": cerebras_registry for k in [
463 | 'deepseek-r1-distill-llama-70b',
464 | ]})
465 | except ImportError:
466 | pass
467 |
468 | try:
469 | from .replicate_gradio import registry as replicate_registry
470 | registry.update({f"replicate:{k}": replicate_registry for k in [
471 | # Text to Image Models
472 | "stability-ai/sdxl",
473 | "black-forest-labs/flux-pro",
474 | "stability-ai/stable-diffusion",
475 |
476 | # Control Net Models
477 | "black-forest-labs/flux-depth-pro",
478 | "black-forest-labs/flux-canny-pro",
479 | "black-forest-labs/flux-depth-dev",
480 |
481 | # Inpainting Models
482 | "black-forest-labs/flux-fill-pro",
483 | "stability-ai/stable-diffusion-inpainting",
484 |
485 | # Text to Video Models
486 | "tencent/hunyuan-video:140176772be3b423d14fdaf5403e6d4e38b85646ccad0c3fd2ed07c211f0cad1",
487 |
488 | # Text Generation Models
489 | "deepseek-ai/deepseek-r1"
490 | ]})
491 | except ImportError:
492 | pass
493 |
494 | if not registry:
495 | raise ImportError(
496 | "No providers installed. Install with either:\n"
497 | "pip install 'ai-gradio[openai]' for OpenAI support\n"
498 | "pip install 'ai-gradio[gemini]' for Gemini support\n"
499 | "pip install 'ai-gradio[crewai]' for CrewAI support\n"
500 | "pip install 'ai-gradio[anthropic]' for Anthropic support\n"
501 | "pip install 'ai-gradio[lumaai]' for LumaAI support\n"
502 | "pip install 'ai-gradio[xai]' for X.AI support\n"
503 | "pip install 'ai-gradio[cohere]' for Cohere support\n"
504 | "pip install 'ai-gradio[sambanova]' for SambaNova support\n"
505 | "pip install 'ai-gradio[hyperbolic]' for Hyperbolic support\n"
506 | "pip install 'ai-gradio[qwen]' for Qwen support\n"
507 | "pip install 'ai-gradio[fireworks]' for Fireworks support\n"
508 | "pip install 'ai-gradio[deepseek]' for DeepSeek support\n"
509 | "pip install 'ai-gradio[smolagents]' for SmolaAgents support\n"
510 | "pip install 'ai-gradio[jupyter]' for Jupyter support\n"
511 | "pip install 'ai-gradio[langchain]' for LangChain support\n"
512 | "pip install 'ai-gradio[mistral]' for Mistral support\n"
513 | "pip install 'ai-gradio[nvidia]' for NVIDIA support\n"
514 | "pip install 'ai-gradio[minimax]' for MiniMax support\n"
515 | "pip install 'ai-gradio[kokoro]' for Kokoro support\n"
516 | "pip install 'ai-gradio[perplexity]' for Perplexity support\n"
517 | "pip install 'ai-gradio[cerebras]' for Cerebras support\n"
518 | "pip install 'ai-gradio[replicate]' for Replicate support\n"
519 | "pip install 'ai-gradio[all]' for all providers\n"
520 | "pip install 'ai-gradio[swarms]' for Swarms support"
521 | )
522 |
523 | __all__ = ["registry"]
524 |
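525 | # Usage sketch (illustrative only; the names below are not defined elsewhere in
526 | # the package): every key in `registry` is a "provider:model" string such as
527 | # "deepseek:deepseek-chat", so the installed providers and their models can be
528 | # inspected directly from the dictionary, e.g.:
529 | #
530 | #     available_providers = sorted({key.split(":", 1)[0] for key in registry})
531 | #     deepseek_models = [key for key in registry if key.startswith("deepseek:")]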
--------------------------------------------------------------------------------