├── chat-gemini.png ├── chatinterface.png ├── requirements.txt ├── .gitignore ├── app.py ├── composition.py ├── custom_app.py ├── pyproject.toml ├── README.md └── gemini_gradio └── __init__.py /chat-gemini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AK391/gemini-gradio/HEAD/chat-gemini.png -------------------------------------------------------------------------------- /chatinterface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AK391/gemini-gradio/HEAD/chatinterface.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | google-generativeai 2 | gradio>=5.9.1 3 | websockets 4 | numpy 5 | gradio-webrtc 6 | twilio 7 | Pillow 8 | opencv-python 9 | librosa 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__/ 3 | *.pyc 4 | 5 | # Virtual environment 6 | env/ 7 | .venv/ 8 | 9 | # Package artifacts 10 | dist/ 11 | build/ 12 | *.egg-info/ -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import gemini_gradio 3 | 4 | gr.load( 5 | name='gemini-2.0-flash-exp', 6 | src=gemini_gradio.registry, 7 | enable_video=True, 8 | enable_voice=True 9 | ).launch() 10 | -------------------------------------------------------------------------------- /composition.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import gemini_gradio 3 | 4 | with gr.Blocks() as demo: 5 | with gr.Tab("Gemini 1.5 Pro"): 6 | gr.load('gemini-1.5-flash', src=gemini_gradio.registry) 7 | with gr.Tab("Gemini 1.5 Flash"): 8 | gr.load('gemini-1.5-pro', src=gemini_gradio.registry) 9 | 10 | demo.launch() -------------------------------------------------------------------------------- /custom_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import gemini_gradio 3 | 4 | gr.load( 5 | name='gemini-1.5-pro', 6 | src=gemini_gradio.registry, 7 | title='Gemini Pro Integration', 8 | description="Chat with Google's Gemini 1.5 Pro model.", 9 | examples=["Explain quantum gravity to a 5-year old.", "Write a creative story about a magical library."] 10 | ).launch() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "setuptools>=61.0"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "gemini-gradio" 7 | version = "0.0.3" 8 | authors = [ 9 | {name = "AK", email = "ahsen.khaliq@gmail.com"} 10 | ] 11 | description = "A Python package for creating Gradio applications with Google Gemini models" 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | license = {text = "MIT"} 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ] 20 | dependencies = [ 21 | "gradio>=5.9.1", 22 | "google-generativeai", 23 | "gradio-webrtc", 24 | "numpy", 25 | "websockets", 26 | "twilio", 27 | "Pillow", 28 | "opencv-python", 29 | "librosa", 30 | ] 31 | 32 | [project.urls] 33 | homepage = "https://github.com/AK391/gemini-gradio" 34 | repository = "https://github.com/AK391/gemini-gradio" 35 | 36 | [project.optional-dependencies] 37 | dev = ["pytest"] 38 | 39 | [tool.hatch.build.targets.wheel] 40 | packages = ["gemini_gradio"] 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `gemini-gradio` 2 | 3 | is a Python package that makes it very easy for developers to create machine learning apps that are powered by Google's Gemini API. 4 | 5 | # Installation 6 | 7 | You can install gemini-gradio directly using pip: 8 | 9 | ```bash 10 | pip install gemini-gradio 11 | ``` 12 | 13 | # Basic Usage 14 | 15 | You'll need to set up your Gemini API key first: 16 | 17 | ```bash 18 | export GEMINI_API_KEY= 19 | ``` 20 | 21 | Then in a Python file, write: 22 | 23 | ```python 24 | import gradio as gr 25 | import gemini_gradio 26 | 27 | gr.load( 28 | name='gemini-1.5-pro-002', 29 | src=gemini_gradio.registry, 30 | ).launch() 31 | ``` 32 | 33 | Run the Python file, and you should see a Gradio Interface connected to the Gemini model! 34 | 35 | ![ChatInterface](chatinterface.png) 36 | 37 | # Voice Chat 38 | 39 | You can enable voice chat with Gemini by setting the `enable_voice` parameter: 40 | 41 | ```python 42 | import gradio as gr 43 | import gemini_gradio 44 | 45 | gr.load( 46 | name='gemini-2.0-flash-exp', 47 | src=gemini_gradio.registry, 48 | enable_voice=True 49 | ).launch() 50 | ``` 51 | 52 | This will create a voice interface where you can have a spoken conversation with the Gemini model using your microphone. 53 | 54 | ## Required API Keys for Voice Chat 55 | 56 | For voice chat functionality, you'll need: 57 | 1. `GEMINI_API_KEY` - Your Google Gemini API key 58 | 2. `GOOGLE_API_KEY` - Your Google API key (required for multimodal features) 59 | 3. Twilio credentials for WebRTC functionality: 60 | - `TWILIO_ACCOUNT_SID` 61 | - `TWILIO_AUTH_TOKEN` 62 | 63 | Make sure these environment variables are set before using the voice chat feature. 64 | 65 | # Customization 66 | 67 | Once you can create a Gradio UI from a Gemini endpoint, you can customize it by setting your own input and output components, or any other arguments to `gr.Interface`. For example, the screenshot below was generated with: 68 | 69 | ```py 70 | import gradio as gr 71 | import gemini_gradio 72 | 73 | gr.load( 74 | name='gemini-2.0-flash-exp', 75 | src=gemini_gradio.registry, 76 | title='Gemini-Gradio Integration', 77 | description="Chat with Gemini Pro model.", 78 | examples=["Explain quantum gravity to a 5-year old.", "How many R are there in the word Strawberry?"] 79 | ).launch() 80 | ``` 81 | ![ChatInterface with customizations](chat-gemini.png) 82 | 83 | # Composition 84 | 85 | Or use your loaded Interface within larger Gradio Web UIs, e.g. 86 | 87 | ```python 88 | import gradio as gr 89 | import gemini_gradio 90 | 91 | with gr.Blocks() as demo: 92 | with gr.Tab("Gemini Pro"): 93 | gr.load('gemini-1.5-pro-002', src=gemini_gradio.registry) 94 | with gr.Tab("gemini-1.5-flash"): 95 | gr.load('gemini-1.5-flash', src=gemini_gradio.registry) 96 | 97 | demo.launch() 98 | ``` 99 | 100 | # Under the Hood 101 | 102 | The `gemini-gradio` Python library has two dependencies: `google-generativeai` and `gradio`. It defines a "registry" function `gemini_gradio.registry`, which takes in a model name and returns a Gradio app. 103 | 104 | # Supported Models in Gemini 105 | 106 | All chat API models supported by Google's Gemini are compatible with this integration. For a comprehensive list of available models and their specifications, please refer to the [Google AI Studio documentation](https://ai.google.dev/models/gemini). 107 | 108 | ------- 109 | 110 | Note: if you are getting an authentication error, then the Gemini API Client is not able to get the API token from the environment variable. You can set it in your Python session like this: 111 | 112 | ```python 113 | import os 114 | 115 | os.environ["GEMINI_API_KEY"] = ... 116 | ``` -------------------------------------------------------------------------------- /gemini_gradio/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | import gradio as gr 4 | import google.generativeai as genai 5 | import base64 6 | import json 7 | import numpy as np 8 | import websockets.sync.client 9 | from gradio_webrtc import StreamHandler, WebRTC, get_twilio_turn_credentials 10 | import cv2 11 | import PIL.Image 12 | import io 13 | 14 | __version__ = "0.0.3" 15 | 16 | 17 | class GeminiConfig: 18 | def __init__(self): 19 | self.api_key = os.getenv("GEMINI_API_KEY") 20 | self.host = "generativelanguage.googleapis.com" 21 | self.model = "models/gemini-2.0-flash-exp" 22 | self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}" 23 | 24 | 25 | class AudioProcessor: 26 | @staticmethod 27 | def encode_audio(data, sample_rate): 28 | encoded = base64.b64encode(data.tobytes()).decode("UTF-8") 29 | return { 30 | "realtimeInput": { 31 | "mediaChunks": [ 32 | { 33 | "mimeType": f"audio/pcm;rate={sample_rate}", 34 | "data": encoded, 35 | } 36 | ], 37 | }, 38 | } 39 | 40 | @staticmethod 41 | def process_audio_response(data): 42 | audio_data = base64.b64decode(data) 43 | return np.frombuffer(audio_data, dtype=np.int16) 44 | 45 | 46 | def detection(frame, conf_threshold=0.3): 47 | """Process video frame.""" 48 | try: 49 | # Convert BGR to RGB 50 | image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 51 | 52 | # Create PIL Image 53 | pil_image = PIL.Image.fromarray(image_rgb) 54 | pil_image.thumbnail([1024, 1024]) 55 | 56 | # Convert back to numpy array 57 | processed_frame = np.array(pil_image) 58 | 59 | # Convert back to BGR for OpenCV 60 | processed_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR) 61 | 62 | return processed_frame 63 | except Exception as e: 64 | print(f"Error processing frame: {e}") 65 | return frame 66 | 67 | 68 | class GeminiHandler(StreamHandler): 69 | def __init__(self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480) -> None: 70 | super().__init__(expected_layout, output_sample_rate, output_frame_size, input_sample_rate=24000) 71 | self.config = GeminiConfig() 72 | self.ws = None 73 | self.all_output_data = None 74 | self.audio_processor = AudioProcessor() 75 | self.current_frame = None 76 | 77 | def copy(self): 78 | handler = GeminiHandler( 79 | expected_layout=self.expected_layout, 80 | output_sample_rate=self.output_sample_rate, 81 | output_frame_size=self.output_frame_size, 82 | ) 83 | return handler 84 | 85 | def _initialize_websocket(self): 86 | try: 87 | self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30) 88 | initial_request = { 89 | "setup": { 90 | "model": self.config.model, 91 | } 92 | } 93 | self.ws.send(json.dumps(initial_request)) 94 | setup_response = json.loads(self.ws.recv()) 95 | print(f"Setup response: {setup_response}") 96 | except websockets.exceptions.WebSocketException as e: 97 | print(f"WebSocket connection failed: {str(e)}") 98 | self.ws = None 99 | except Exception as e: 100 | print(f"Setup failed: {str(e)}") 101 | self.ws = None 102 | 103 | def process_video_frame(self, frame): 104 | self.current_frame = frame 105 | _, buffer = cv2.imencode('.jpg', frame) 106 | image_data = base64.b64encode(buffer).decode('utf-8') 107 | return image_data 108 | 109 | def receive(self, frame: tuple[int, np.ndarray]) -> None: 110 | try: 111 | if not self.ws: 112 | self._initialize_websocket() 113 | 114 | _, array = frame 115 | array = array.squeeze() 116 | 117 | audio_data = self.audio_processor.encode_audio(array, self.output_sample_rate) 118 | 119 | message = { 120 | "realtimeInput": { 121 | "mediaChunks": [ 122 | { 123 | "mimeType": f"audio/pcm;rate={self.output_sample_rate}", 124 | "data": audio_data["realtimeInput"]["mediaChunks"][0]["data"], 125 | } 126 | ], 127 | } 128 | } 129 | 130 | if self.current_frame is not None: 131 | image_data = self.process_video_frame(self.current_frame) 132 | message["realtimeInput"]["mediaChunks"].append({ 133 | "mimeType": "image/jpeg", 134 | "data": image_data 135 | }) 136 | 137 | self.ws.send(json.dumps(message)) 138 | except Exception as e: 139 | print(f"Error in receive: {str(e)}") 140 | if self.ws: 141 | self.ws.close() 142 | self.ws = None 143 | 144 | def _process_server_content(self, content): 145 | for part in content.get("parts", []): 146 | data = part.get("inlineData", {}).get("data", "") 147 | if data: 148 | audio_array = self.audio_processor.process_audio_response(data) 149 | if self.all_output_data is None: 150 | self.all_output_data = audio_array 151 | else: 152 | self.all_output_data = np.concatenate((self.all_output_data, audio_array)) 153 | 154 | while self.all_output_data.shape[-1] >= self.output_frame_size: 155 | yield (self.output_sample_rate, self.all_output_data[: self.output_frame_size].reshape(1, -1)) 156 | self.all_output_data = self.all_output_data[self.output_frame_size :] 157 | 158 | def generator(self): 159 | while True: 160 | if not self.ws: 161 | print("WebSocket not connected") 162 | yield None 163 | continue 164 | 165 | try: 166 | message = self.ws.recv(timeout=5) 167 | msg = json.loads(message) 168 | 169 | if "serverContent" in msg: 170 | content = msg["serverContent"].get("modelTurn", {}) 171 | yield from self._process_server_content(content) 172 | except TimeoutError: 173 | print("Timeout waiting for server response") 174 | yield None 175 | except Exception as e: 176 | print(f"Error in generator: {str(e)}") 177 | yield None 178 | 179 | def emit(self) -> tuple[int, np.ndarray] | None: 180 | if not self.ws: 181 | return None 182 | if not hasattr(self, "_generator"): 183 | self._generator = self.generator() 184 | try: 185 | return next(self._generator) 186 | except StopIteration: 187 | self.reset() 188 | return None 189 | 190 | def reset(self) -> None: 191 | if hasattr(self, "_generator"): 192 | delattr(self, "_generator") 193 | self.all_output_data = None 194 | 195 | def shutdown(self) -> None: 196 | if self.ws: 197 | self.ws.close() 198 | 199 | def check_connection(self): 200 | try: 201 | if not self.ws or self.ws.closed: 202 | self._initialize_websocket() 203 | return True 204 | except Exception as e: 205 | print(f"Connection check failed: {str(e)}") 206 | return False 207 | 208 | 209 | def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str): 210 | def fn(message, history, enable_search): 211 | inputs = preprocess(message, history, enable_search) 212 | is_gemini = model_name.startswith("gemini-") 213 | 214 | if is_gemini: 215 | genai.configure(api_key=api_key) 216 | 217 | generation_config = { 218 | "temperature": 1, 219 | "top_p": 0.95, 220 | "top_k": 40, 221 | "max_output_tokens": 8192, 222 | "response_mime_type": "text/plain", 223 | } 224 | 225 | model = genai.GenerativeModel( 226 | model_name=model_name, 227 | generation_config=generation_config 228 | ) 229 | 230 | chat = model.start_chat(history=inputs.get("history", [])) 231 | 232 | if inputs.get("enable_search"): 233 | response = chat.send_message( 234 | inputs["message"], 235 | stream=True, 236 | tools='google_search_retrieval' 237 | ) 238 | else: 239 | response = chat.send_message(inputs["message"], stream=True) 240 | 241 | response_text = "" 242 | for chunk in response: 243 | if chunk.text: 244 | response_text += chunk.text 245 | yield {"role": "assistant", "content": response_text} 246 | 247 | return fn 248 | 249 | 250 | def get_interface_args(pipeline, model_name: str): 251 | if pipeline == "chat": 252 | inputs = [gr.Checkbox(label="Enable Search", value=False)] 253 | outputs = None 254 | 255 | def preprocess(message, history, enable_search): 256 | is_gemini = model_name.startswith("gemini-") 257 | if is_gemini: 258 | # Handle multimodal input 259 | if isinstance(message, dict): 260 | parts = [] 261 | if message.get("text"): 262 | parts.append({"text": message["text"]}) 263 | if message.get("files"): 264 | for file in message["files"]: 265 | # Determine file type and handle accordingly 266 | if isinstance(file, str): # If it's a file path 267 | mime_type = None 268 | if file.lower().endswith('.pdf'): 269 | mime_type = "application/pdf" 270 | elif file.lower().endswith('.txt'): 271 | mime_type = "text/plain" 272 | elif file.lower().endswith('.html'): 273 | mime_type = "text/html" 274 | elif file.lower().endswith('.md'): 275 | mime_type = "text/md" 276 | elif file.lower().endswith('.csv'): 277 | mime_type = "text/csv" 278 | elif file.lower().endswith(('.js', '.javascript')): 279 | mime_type = "application/x-javascript" 280 | elif file.lower().endswith('.py'): 281 | mime_type = "application/x-python" 282 | 283 | if mime_type: 284 | try: 285 | uploaded_file = genai.upload_file(file) 286 | parts.append(uploaded_file) 287 | except Exception as e: 288 | print(f"Error uploading file: {e}") 289 | else: 290 | with open(file, "rb") as f: 291 | image_data = f.read() 292 | import base64 293 | image_data = base64.b64encode(image_data).decode() 294 | parts.append({ 295 | "inline_data": { 296 | "mime_type": "image/jpeg", 297 | "data": image_data 298 | } 299 | }) 300 | else: # If it's binary data, treat as image 301 | import base64 302 | image_data = base64.b64encode(file).decode() 303 | parts.append({ 304 | "inline_data": { 305 | "mime_type": "image/jpeg", 306 | "data": image_data 307 | } 308 | }) 309 | message_parts = parts 310 | else: 311 | message_parts = [{"text": message}] 312 | 313 | # Process history 314 | gemini_history = [] 315 | for entry in history: 316 | # Handle different history formats 317 | if isinstance(entry, (list, tuple)): 318 | user_msg, assistant_msg = entry 319 | else: 320 | # If it's a dict with role/content format 321 | if entry.get("role") == "user": 322 | user_msg = entry.get("content") 323 | continue # Skip to next iteration to get assistant message 324 | elif entry.get("role") == "assistant": 325 | assistant_msg = entry.get("content") 326 | continue # Skip to next iteration 327 | else: 328 | continue # Skip unknown roles 329 | 330 | # Process user message 331 | if isinstance(user_msg, dict): 332 | parts = [] 333 | if user_msg.get("text"): 334 | parts.append({"text": user_msg["text"]}) 335 | if user_msg.get("files"): 336 | for file in user_msg["files"]: 337 | if isinstance(file, str): 338 | mime_type = None 339 | if file.lower().endswith('.pdf'): 340 | mime_type = "application/pdf" 341 | # ... (same mime type checks as before) 342 | 343 | if mime_type: 344 | try: 345 | uploaded_file = genai.upload_file(file) 346 | parts.append(uploaded_file) 347 | except Exception as e: 348 | print(f"Error uploading file in history: {e}") 349 | else: 350 | with open(file, "rb") as f: 351 | image_data = f.read() 352 | import base64 353 | image_data = base64.b64encode(image_data).decode() 354 | parts.append({ 355 | "inline_data": { 356 | "mime_type": "image/jpeg", 357 | "data": image_data 358 | } 359 | }) 360 | else: 361 | import base64 362 | image_data = base64.b64encode(file).decode() 363 | parts.append({ 364 | "inline_data": { 365 | "mime_type": "image/jpeg", 366 | "data": image_data 367 | } 368 | }) 369 | gemini_history.append({ 370 | "role": "user", 371 | "parts": parts 372 | }) 373 | else: 374 | gemini_history.append({ 375 | "role": "user", 376 | "parts": [{"text": str(user_msg)}] 377 | }) 378 | 379 | # Process assistant message 380 | gemini_history.append({ 381 | "role": "model", 382 | "parts": [{"text": str(assistant_msg)}] 383 | }) 384 | 385 | return { 386 | "history": gemini_history, 387 | "message": message_parts, 388 | "enable_search": enable_search 389 | } 390 | else: 391 | messages = [] 392 | for user_msg, assistant_msg in history: 393 | messages.append({"role": "user", "content": user_msg}) 394 | messages.append({"role": "assistant", "content": assistant_msg}) 395 | messages.append({"role": "user", "content": message}) 396 | return {"messages": messages} 397 | 398 | postprocess = lambda x: x 399 | else: 400 | raise ValueError(f"Unsupported pipeline type: {pipeline}") 401 | return inputs, outputs, preprocess, postprocess 402 | 403 | 404 | def get_pipeline(model_name): 405 | return "chat" 406 | 407 | 408 | def registry( 409 | name: str, 410 | token: str | None = None, 411 | examples: list | None = None, 412 | enable_voice: bool = False, 413 | enable_video: bool = False, 414 | **kwargs 415 | ): 416 | env_key = "GEMINI_API_KEY" 417 | api_key = token or os.environ.get(env_key) 418 | if not api_key: 419 | raise ValueError(f"{env_key} environment variable is not set.") 420 | 421 | pipeline = get_pipeline(name) 422 | inputs, outputs, preprocess, postprocess = get_interface_args(pipeline, name) 423 | fn = get_fn(name, preprocess, postprocess, api_key) 424 | 425 | if examples: 426 | formatted_examples = [[example, False] for example in examples] 427 | kwargs["examples"] = formatted_examples 428 | 429 | if pipeline == "chat": 430 | if enable_voice or enable_video: 431 | interface = gr.Blocks() 432 | with interface: 433 | gr.HTML( 434 | """ 435 |
436 |

Gemini Chat

437 |
438 | """ 439 | ) 440 | 441 | gemini_handler = GeminiHandler() 442 | 443 | with gr.Row(): 444 | with gr.Column(scale=1): 445 | if enable_video: 446 | video = WebRTC( 447 | label="Stream", 448 | mode="send-receive", 449 | modality="video", 450 | rtc_configuration=get_twilio_turn_credentials() 451 | ) 452 | 453 | if enable_voice: 454 | audio = WebRTC( 455 | label="Voice Chat", 456 | modality="audio", 457 | mode="send-receive", 458 | rtc_configuration=get_twilio_turn_credentials(), 459 | ) 460 | 461 | if enable_video: 462 | video.stream( 463 | fn=lambda frame: (frame, detection(frame)), 464 | inputs=[video], 465 | outputs=[video], 466 | time_limit=90, 467 | concurrency_limit=10 468 | ) 469 | 470 | if enable_voice: 471 | audio.stream( 472 | gemini_handler, 473 | inputs=[audio], 474 | outputs=[audio], 475 | time_limit=90, 476 | concurrency_limit=10 477 | ) 478 | else: 479 | interface = gr.ChatInterface( 480 | fn=fn, 481 | additional_inputs=inputs, 482 | multimodal=True, 483 | type="messages", 484 | **kwargs 485 | ) 486 | else: 487 | interface = gr.Interface(fn=fn, inputs=inputs, outputs=outputs, **kwargs) 488 | 489 | return interface 490 | --------------------------------------------------------------------------------