├── .gitignore ├── LICENSE ├── README.md ├── app_config.py ├── index.html ├── main.py ├── openai_client.py ├── podcast_generator.py ├── requirements.txt ├── static ├── script.js └── styles.css └── turn_handler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include 
Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 aymenfurter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Podcast Generator 2 | 3 | AI Podcast Generator is a demo app that leverages OpenAI's language models to create engaging podcast scripts effortlessly. Simply input your desired topic or content, and the AI will generate a comprehensive podcast. Enhance your podcasting experience by asking audience questions in real time. 4 | 5 | [Watch the Demo Video](https://www.youtube.com/watch?v=CrdZtwO6x6o) 6 | 7 | ## Installation 8 | 9 | 1. **Clone the Repository:** 10 | ```bash 11 | git clone https://github.com/aymenfurter/ai-podcast-generator.git 12 | cd ai-podcast-generator 13 | ``` 14 | 15 | 2. **Create a Virtual Environment:** 16 | ```bash 17 | python -m venv venv 18 | source venv/bin/activate  # On Windows: venv\Scripts\activate 19 | ``` 20 | 21 | 3. **Install Dependencies:** 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | 4. **Set Up Environment Variables:** 27 | - Create a `.env` file in the root directory. 
28 | - Add the following variables (You'll need an Azure OpenAI Instance): 29 | ```env 30 | OPENAI_API_KEY=your_openai_api_key 31 | OPENAI_API_BASE=your_openai_api_base 32 | OPENAI_DEPLOYMENT_NAME=your_deployment_name 33 | OPENAI_REALTIME_DEPLOYMENT_NAME=your_realtime_deployment_name 34 | OPENAI_API_KEY_B=your_secondary_openai_api_key 35 | OPENAI_API_BASE_B=your_secondary_openai_api_base 36 | ``` 37 | 38 | ## Usage 39 | 40 | 1. **Run the Application:** 41 | ```bash 42 | python main.py 43 | ``` 44 | 45 | 2. **Access the Web Interface:** 46 | - Open your browser and navigate to `http://localhost:8000`. 47 | - Enter your podcast topic and generate your script. 48 | 49 | ## License 50 | 51 | This project is licensed under the [MIT License](LICENSE). -------------------------------------------------------------------------------- /app_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | load_dotenv() 5 | 6 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 7 | OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") 8 | OPENAI_DEPLOYMENT_NAME = os.getenv("OPENAI_DEPLOYMENT_NAME") 9 | OPENAI_REALTIME_DEPLOYMENT_NAME = os.getenv("OPENAI_REALTIME_DEPLOYMENT_NAME") 10 | OPENAI_API_KEY_B = os.getenv("OPENAI_API_KEY_B") 11 | OPENAI_API_BASE_B = os.getenv("OPENAI_API_BASE_B") 12 | MAX_RETRIES = 15 13 | RETRY_DELAY = 5 14 | 15 | required_env_vars = [ 16 | "OPENAI_API_KEY", 17 | "OPENAI_API_BASE", 18 | "OPENAI_DEPLOYMENT_NAME", 19 | "OPENAI_REALTIME_DEPLOYMENT_NAME", 20 | "OPENAI_API_KEY_B", 21 | "OPENAI_API_BASE_B" 22 | ] 23 | 24 | for var in required_env_vars: 25 | if not os.getenv(var): 26 | raise EnvironmentError(f"Missing required environment variable: {var}") 27 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | AI Podcast Generator 7 | 8 | 9 | 
10 |
11 |

AI Podcast Generator

12 |
13 | 14 | 15 | 16 |
17 | 22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | 45 | 46 |
47 |
48 | 49 | 50 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import FastAPI, HTTPException, Depends 3 | from fastapi.responses import HTMLResponse 4 | from fastapi.staticfiles import StaticFiles 5 | from pydantic import BaseModel 6 | from typing import Optional, Dict, Any 7 | 8 | from openai_client import OpenAIClient, OpenAIClientError 9 | from podcast_generator import PodcastGenerator 10 | from turn_handler import TurnHandler, TurnHandlerError 11 | from app_config import ( 12 | OPENAI_API_KEY, 13 | OPENAI_API_BASE, 14 | OPENAI_DEPLOYMENT_NAME, 15 | OPENAI_REALTIME_DEPLOYMENT_NAME, 16 | OPENAI_API_KEY_B, 17 | OPENAI_API_BASE_B 18 | ) 19 | 20 | logging.basicConfig( 21 | level=logging.INFO, 22 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s' 23 | ) 24 | logger = logging.getLogger(__name__) 25 | 26 | async def lifespan(app: FastAPI): 27 | """ 28 | Lifespan event handler to manage startup and shutdown events. 29 | Initializes shared clients on startup and ensures proper closure on shutdown. 
30 | """ 31 | logger.info("Lifespan startup: Initializing OpenAIClient and TurnHandler.") 32 | 33 | app.state.openai_client = OpenAIClient( 34 | api_key=OPENAI_API_KEY, 35 | api_base=OPENAI_API_BASE, 36 | deployment_name=OPENAI_DEPLOYMENT_NAME, 37 | realtime_deployment_name=OPENAI_REALTIME_DEPLOYMENT_NAME 38 | ) 39 | await app.state.openai_client.__aenter__() 40 | 41 | app.state.turn_handler = TurnHandler( 42 | api_key=OPENAI_API_KEY, 43 | api_base=OPENAI_API_BASE, 44 | api_key_b=OPENAI_API_KEY_B, 45 | api_base_b=OPENAI_API_BASE_B, 46 | realtime_deployment_name=OPENAI_REALTIME_DEPLOYMENT_NAME 47 | ) 48 | await app.state.turn_handler._get_session() 49 | 50 | logger.info("OpenAIClient and TurnHandler initialized successfully.") 51 | 52 | try: 53 | yield 54 | finally: 55 | logger.info("Lifespan shutdown: Closing OpenAIClient and TurnHandler.") 56 | 57 | await app.state.openai_client.close() 58 | await app.state.turn_handler.close() 59 | 60 | logger.info("OpenAIClient and TurnHandler closed successfully.") 61 | 62 | app = FastAPI( 63 | title="Podcast Generator API", 64 | lifespan=lifespan 65 | ) 66 | 67 | app.mount("/static", StaticFiles(directory="static"), name="static") 68 | 69 | class PodcastRequest(BaseModel): 70 | topic: str 71 | 72 | class TurnRequest(BaseModel): 73 | podcast_script: str 74 | combined_transcript: str 75 | audience_question: Optional[str] = None 76 | turn: int 77 | 78 | async def get_openai_client() -> OpenAIClient: 79 | return app.state.openai_client 80 | 81 | async def get_podcast_generator(client: OpenAIClient = Depends(get_openai_client)) -> PodcastGenerator: 82 | return PodcastGenerator(client) 83 | 84 | async def get_turn_handler() -> TurnHandler: 85 | return app.state.turn_handler 86 | 87 | @app.post("/generate_podcast_script") 88 | async def generate_podcast_script( 89 | request: PodcastRequest, 90 | generator: PodcastGenerator = Depends(get_podcast_generator) 91 | ) -> Dict[str, Any]: 92 | """ 93 | Generate a podcast script based on the 
given topic. 94 | 95 | Args: 96 | request (PodcastRequest): The request containing the podcast topic. 97 | generator (PodcastGenerator): The injected PodcastGenerator instance. 98 | 99 | Returns: 100 | dict: A dictionary containing the generated podcast script. 101 | 102 | Raises: 103 | HTTPException: If the podcast script generation fails. 104 | """ 105 | logger.info(f"Received request to generate podcast script for topic: '{request.topic}'") 106 | try: 107 | podcast_script = await generator.generate_full_podcast(request.topic) 108 | logger.info("Podcast script generated successfully.") 109 | return {"podcast_script": podcast_script} 110 | except OpenAIClientError as e: 111 | logger.error(f"Failed to generate podcast script: {e}") 112 | raise HTTPException(status_code=500, detail="Failed to generate podcast script.") from e 113 | except Exception as e: 114 | logger.exception(f"Unexpected error during podcast script generation: {e}") 115 | raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e 116 | 117 | @app.post("/next_turn") 118 | async def next_turn( 119 | request: TurnRequest, 120 | turn_handler: TurnHandler = Depends(get_turn_handler) 121 | ) -> Dict[str, Any]: 122 | """ 123 | Handle the next turn in the podcast conversation. 124 | 125 | Args: 126 | request (TurnRequest): The request containing the turn information. 127 | turn_handler (TurnHandler): The injected TurnHandler instance. 128 | 129 | Returns: 130 | dict: A dictionary containing the speaker, transcript, and audio data. 131 | 132 | Raises: 133 | HTTPException: If handling the turn fails. 
134 | """ 135 | logger.info(f"Received request to handle turn {request.turn}.") 136 | try: 137 | response = await turn_handler.handle_turn(request) 138 | logger.info(f"Turn {request.turn} handled successfully.") 139 | return response 140 | except TurnHandlerError as e: 141 | logger.error(f"Failed to handle turn {request.turn}: {e}") 142 | raise HTTPException(status_code=500, detail="Failed to handle the turn.") from e 143 | except Exception as e: 144 | logger.exception(f"Unexpected error during turn {request.turn} handling: {e}") 145 | raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e 146 | 147 | @app.get("/", response_class=HTMLResponse) 148 | async def read_root() -> str: 149 | """ 150 | Serve the main HTML page. 151 | 152 | Returns: 153 | str: The content of the index.html file. 154 | 155 | Raises: 156 | HTTPException: If index.html is not found. 157 | """ 158 | try: 159 | with open("index.html", "r", encoding="utf-8") as f: 160 | content = f.read() 161 | logger.info("Served index.html successfully.") 162 | return content 163 | except FileNotFoundError: 164 | logger.error("index.html not found.") 165 | raise HTTPException(status_code=404, detail="index.html not found.") 166 | except Exception as e: 167 | logger.exception(f"Error serving index.html: {e}") 168 | raise HTTPException(status_code=500, detail="Failed to serve the main page.") from e 169 | 170 | def main(): 171 | """ 172 | Entry point to run the FastAPI application using Uvicorn. 
173 | """ 174 | import uvicorn 175 | uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /openai_client.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import logging 3 | from typing import List, Optional, Dict, Any 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger(__name__) 7 | 8 | class OpenAIClient: 9 | """ 10 | A client for interacting with the OpenAI API. 11 | """ 12 | 13 | API_VERSION = "2023-05-15" 14 | 15 | def __init__( 16 | self, 17 | api_key: str, 18 | api_base: str, 19 | deployment_name: str, 20 | realtime_deployment_name: str 21 | ) -> None: 22 | """ 23 | Initialize the OpenAIClient. 24 | 25 | Args: 26 | api_key (str): The API key for authentication. 27 | api_base (str): The base URL for the API. 28 | deployment_name (str): The name of the deployment. 29 | realtime_deployment_name (str): The name of the realtime deployment. 30 | """ 31 | self.api_key = api_key 32 | self.api_base = api_base.rstrip('/') # Ensure no trailing slash 33 | self.deployment_name = deployment_name 34 | self.realtime_deployment_name = realtime_deployment_name 35 | self.session: Optional[aiohttp.ClientSession] = None 36 | 37 | async def _get_session(self) -> aiohttp.ClientSession: 38 | """ 39 | Get or create an aiohttp ClientSession. 40 | 41 | Returns: 42 | aiohttp.ClientSession: The aiohttp session. 43 | """ 44 | if self.session is None or self.session.closed: 45 | self.session = aiohttp.ClientSession() 46 | logger.debug("Created new aiohttp ClientSession.") 47 | return self.session 48 | 49 | async def close(self) -> None: 50 | """ 51 | Close the aiohttp ClientSession if it's open. 
52 | """ 53 | if self.session and not self.session.closed: 54 | await self.session.close() 55 | logger.debug("Closed aiohttp ClientSession.") 56 | 57 | async def create_chat_completion( 58 | self, 59 | messages: List[Dict[str, Any]], 60 | max_tokens: Optional[int] = None 61 | ) -> str: 62 | """ 63 | Create a chat completion using the OpenAI API. 64 | 65 | Args: 66 | messages (List[Dict[str, Any]]): A list of message dictionaries. 67 | max_tokens (Optional[int], optional): The maximum number of tokens to generate. 68 | 69 | Returns: 70 | str: The generated chat completion content. 71 | 72 | Raises: 73 | OpenAIClientError: If the chat completion request fails. 74 | """ 75 | url = ( 76 | f"{self.api_base}/openai/deployments/{self.deployment_name}/chat/completions" 77 | f"?api-version={self.API_VERSION}" 78 | ) 79 | headers = { 80 | "Content-Type": "application/json", 81 | "api-key": self.api_key 82 | } 83 | payload: Dict[str, Any] = { 84 | "messages": messages 85 | } 86 | if max_tokens is not None: 87 | payload["max_tokens"] = max_tokens 88 | 89 | logger.debug(f"Sending request to {url} with payload: {payload}") 90 | 91 | session = await self._get_session() 92 | try: 93 | async with session.post(url, headers=headers, json=payload) as response: 94 | response_text = await response.text() 95 | logger.debug(f"Received response status: {response.status}") 96 | logger.debug(f"Response text: {response_text}") 97 | 98 | if response.status != 200: 99 | logger.error( 100 | f"Chat completion failed with status {response.status}: {response_text}" 101 | ) 102 | raise OpenAIClientError( 103 | f"Chat completion failed: {response.status} - {response_text}" 104 | ) 105 | 106 | result = await response.json() 107 | completion = result.get('choices', [{}])[0].get('message', {}).get('content', '') 108 | 109 | if not completion: 110 | logger.warning("No content found in the response.") 111 | raise OpenAIClientError("No content found in the response.") 112 | 113 | logger.info("Chat 
completion successful.") 114 | return completion 115 | 116 | except aiohttp.ClientError as e: 117 | logger.exception("HTTP request failed.") 118 | raise OpenAIClientError(f"HTTP request failed: {e}") from e 119 | 120 | async def __aenter__(self) -> 'OpenAIClient': 121 | """ 122 | Enter the runtime context related to this object. 123 | 124 | Returns: 125 | OpenAIClient: The client instance. 126 | """ 127 | await self._get_session() 128 | return self 129 | 130 | async def __aexit__(self, exc_type, exc, tb) -> None: 131 | """ 132 | Exit the runtime context and close the session. 133 | """ 134 | await self.close() 135 | 136 | 137 | class OpenAIClientError(Exception): 138 | """Custom exception class for OpenAIClient errors.""" 139 | pass -------------------------------------------------------------------------------- /podcast_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import logging 3 | from openai_client import OpenAIClient 4 | from openai_client import OpenAIClientError 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class PodcastGenerator: 9 | """ 10 | A class for generating podcast content using OpenAI's language model. 11 | """ 12 | 13 | def __init__(self, client: OpenAIClient) -> None: 14 | """ 15 | Initialize the PodcastGenerator. 16 | 17 | Args: 18 | client (OpenAIClient): An instance of the OpenAIClient. 19 | """ 20 | self.client = client 21 | logger.debug("PodcastGenerator initialized with OpenAIClient.") 22 | 23 | async def generate_summary(self, text: str, max_tokens: int = 300) -> str: 24 | """ 25 | Generate a summary of the given text. 26 | 27 | Args: 28 | text (str): The input text to summarize. 29 | max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 300. 30 | 31 | Returns: 32 | str: The generated summary. 33 | 34 | Raises: 35 | OpenAIClientError: If the summary generation fails. 
36 | """ 37 | messages = [ 38 | { 39 | "role": "system", 40 | "content": "You are a skilled summarizer. Create a concise summary of the following text." 41 | }, 42 | { 43 | "role": "user", 44 | "content": text 45 | } 46 | ] 47 | logger.debug("Generating summary with provided text.") 48 | summary = await self.client.create_chat_completion(messages, max_tokens=max_tokens) 49 | logger.info("Summary generated successfully.") 50 | return summary 51 | 52 | async def chain_of_density(self, summary: str, iterations: int = 3, max_tokens: int = 300) -> str: 53 | """ 54 | Apply the chain of density technique to compress the summary. 55 | 56 | Args: 57 | summary (str): The initial summary. 58 | iterations (int, optional): The number of compression iterations. Defaults to 3. 59 | max_tokens (int, optional): The maximum number of tokens to generate per iteration. Defaults to 300. 60 | 61 | Returns: 62 | str: The compressed summary. 63 | 64 | Raises: 65 | OpenAIClientError: If the compression process fails. 66 | """ 67 | dense_summary = summary 68 | logger.debug(f"Starting chain of density with {iterations} iterations.") 69 | for i in range(1, iterations + 1): 70 | messages = [ 71 | { 72 | "role": "system", 73 | "content": ( 74 | "You are an expert in information compression. Your task is to make the given text " 75 | "more concise while preserving all key information. Aim to reduce the word count by " 76 | "25% without losing important content." 77 | ) 78 | }, 79 | { 80 | "role": "user", 81 | "content": ( 82 | f"Original text:\n{dense_summary}\n\n" 83 | "Compress this text, maintaining all key points but reducing verbosity." 
84 | ) 85 | } 86 | ] 87 | logger.debug(f"Compression iteration {i} of {iterations}.") 88 | dense_summary = await self.client.create_chat_completion(messages, max_tokens=max_tokens) 89 | logger.info(f"Compression iteration {i} completed successfully.") 90 | logger.info("Chain of density compression completed.") 91 | return dense_summary 92 | 93 | async def create_podcast_outline(self, compressed_summary: str, max_tokens: int = 300) -> str: 94 | """ 95 | Create a podcast outline based on the compressed summary. 96 | 97 | Args: 98 | compressed_summary (str): The compressed summary of the topic. 99 | max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 300. 100 | 101 | Returns: 102 | str: The generated podcast outline. 103 | 104 | Raises: 105 | OpenAIClientError: If the outline creation fails. 106 | """ 107 | messages = [ 108 | { 109 | "role": "system", 110 | "content": ( 111 | "Create a high-level outline for a podcast episode based on the following summary. " 112 | "Include 3-5 main topics to discuss." 113 | ) 114 | }, 115 | { 116 | "role": "user", 117 | "content": compressed_summary 118 | } 119 | ] 120 | logger.debug("Creating podcast outline with compressed summary.") 121 | outline = await self.client.create_chat_completion(messages, max_tokens=max_tokens) 122 | logger.info("Podcast outline created successfully.") 123 | return outline 124 | 125 | async def generate_full_podcast(self, input_text: str) -> str: 126 | """ 127 | Generate a full podcast script based on the input text. 128 | 129 | Args: 130 | input_text (str): The input text to base the podcast on. 131 | 132 | Returns: 133 | str: The generated podcast script. 134 | 135 | Raises: 136 | OpenAIClientError: If any step in the podcast generation process fails. 
137 | """ 138 | try: 139 | logger.info("Starting podcast generation process.") 140 | 141 | logger.info("Generating summary...") 142 | summary = await self.generate_summary(input_text) 143 | 144 | logger.info("Compressing summary...") 145 | compressed_summary = await self.chain_of_density(summary) 146 | 147 | logger.info("Creating podcast outline...") 148 | podcast_input = f"## Topic Summary:\n{compressed_summary}\n## Full Document:\n{input_text[:20000]}" 149 | outline = await self.create_podcast_outline(compressed_summary) 150 | 151 | full_podcast = f"## Talking Points:\n{outline}\n\n## Topic Summary:\n{compressed_summary}\n## Full Document:\n{input_text[:20000]}" 152 | logger.info("Podcast generation completed successfully.") 153 | return full_podcast 154 | 155 | except OpenAIClientError as e: 156 | logger.error(f"Podcast generation failed: {e}") 157 | raise 158 | 159 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn 3 | pydantic 4 | aiohttp 5 | python-multipart -------------------------------------------------------------------------------- /static/script.js: -------------------------------------------------------------------------------- 1 | let podcastScript = ''; 2 | let combinedTranscript = ''; 3 | let currentTurnNumber = 0; 4 | let preloadedTurn = null; 5 | let isPreloading = false; 6 | let isPlaying = false; 7 | let isProcessingPreloadedTurn = false; 8 | let pendingAudienceQuestion = null; 9 | 10 | const MAX_RETRIES = 3; 11 | const RETRY_DELAY = 2000; 12 | 13 | let audioContext, analyser, dataArray; 14 | const FFT_SIZE = 32; 15 | let animationFrameId = null; 16 | 17 | function initAudio() { 18 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 19 | analyser = audioContext.createAnalyser(); 20 | analyser.fftSize = FFT_SIZE; 21 | dataArray = new 
Uint8Array(analyser.frequencyBinCount); 22 | } 23 | 24 | async function generatePodcastScript() { 25 | const topic = document.getElementById('topic').value.trim(); 26 | if (!topic) { 27 | showNotification('Please enter content for the podcast.'); 28 | return; 29 | } 30 | 31 | updateStatusMessage('Generating podcast...', true); 32 | clearRetryMessage(); 33 | 34 | try { 35 | const response = await fetch('/generate_podcast_script', { 36 | method: 'POST', 37 | headers: { 38 | 'Content-Type': 'application/json', 39 | }, 40 | body: JSON.stringify({ topic }), 41 | }); 42 | 43 | if (!response.ok) { 44 | throw new Error('Network response was not ok'); 45 | } 46 | 47 | const data = await response.json(); 48 | podcastScript = data.podcast_script; 49 | 50 | document.getElementById('audienceQuestionInput').classList.remove('hidden'); 51 | document.getElementById('topicInput').classList.add('hidden'); 52 | 53 | startPodcast(); 54 | } catch (error) { 55 | console.error('Error generating podcast script:', error); 56 | showNotification('An error occurred while generating the podcast. 
Please try again.'); 57 | } finally { 58 | clearStatusMessage(); 59 | } 60 | } 61 | 62 | async function startPodcast() { 63 | if (isPlaying) { 64 | return; 65 | } 66 | 67 | isPlaying = true; 68 | currentTurnNumber = 0; 69 | combinedTranscript = ''; 70 | clearRetryMessage(); 71 | 72 | const firstTurn = await fetchTurnWithRetry(currentTurnNumber); 73 | if (!firstTurn) { 74 | showNotification('No turns available to play.'); 75 | isPlaying = false; 76 | return; 77 | } 78 | 79 | combinedTranscript += formatTranscript(firstTurn); 80 | 81 | preloadedTurn = await fetchTurnWithRetry(currentTurnNumber + 1); 82 | 83 | await playTurn(firstTurn); 84 | } 85 | 86 | async function fetchTurnWithRetry(turnNumber, attempt = 1) { 87 | if (turnNumber >= 7) { 88 | return null; 89 | } 90 | 91 | updateStatusMessage(`Fetching turn ${turnNumber} (Attempt ${attempt}/${MAX_RETRIES})...`, true); 92 | clearRetryMessage(); 93 | 94 | try { 95 | let body = { 96 | podcast_script: podcastScript, 97 | combined_transcript: combinedTranscript, 98 | turn: turnNumber, 99 | }; 100 | 101 | if (pendingAudienceQuestion) { 102 | body.audience_question = pendingAudienceQuestion; 103 | body.combined_transcript += `\n\nAudience: ${pendingAudienceQuestion}\nANSWER THE AUDIENCE QUESTION FIRST THEN MOVE ON`; 104 | pendingAudienceQuestion = null; 105 | } else { 106 | body.audience_question = ""; 107 | } 108 | 109 | const response = await fetch('/next_turn', { 110 | method: 'POST', 111 | headers: { 112 | 'Content-Type': 'application/json', 113 | }, 114 | body: JSON.stringify(body), 115 | }); 116 | 117 | if (!response.ok) { 118 | throw new Error(`Failed to fetch turn ${turnNumber}`); 119 | } 120 | 121 | const data = await response.json(); 122 | 123 | if (!data.transcript || !data.audio_base64) { 124 | throw new Error(`Incomplete data for turn ${turnNumber}`); 125 | } 126 | 127 | let transcript = data.transcript.trim(); 128 | 129 | if (!/[.!?]$/.test(transcript)) { 130 | const lastPeriod = transcript.lastIndexOf('.'); 
131 | const lastQuestion = transcript.lastIndexOf('?'); 132 | const lastExclamation = transcript.lastIndexOf('!'); 133 | const lastPunc = Math.max(lastPeriod, lastQuestion, lastExclamation); 134 | 135 | if (lastPunc !== -1) { 136 | transcript = transcript.substring(0, lastPunc + 1).trim(); 137 | } else { 138 | transcript = ''; 139 | } 140 | } 141 | 142 | return { 143 | turnNumber: turnNumber, 144 | transcript: transcript.startsWith(`${data.speaker}:`) ? transcript : `${data.speaker}: ${transcript}`, 145 | audio_base64: data.audio_base64, 146 | }; 147 | } catch (error) { 148 | console.error(`Error fetching turn ${turnNumber}:`, error); 149 | if (attempt < MAX_RETRIES) { 150 | showRetryMessage(`Failed to load turn ${turnNumber}. Retrying in ${RETRY_DELAY / 1000} seconds...`); 151 | await delay(RETRY_DELAY); 152 | return await fetchTurnWithRetry(turnNumber, attempt + 1); 153 | } else { 154 | showRetryMessage(`Failed to load turn ${turnNumber} after ${MAX_RETRIES} attempts.`); 155 | return null; 156 | } 157 | } finally { 158 | clearStatusMessage(); 159 | } 160 | } 161 | 162 | async function playTurn(turnData) { 163 | if (!turnData || !turnData.audio_base64) { 164 | console.error('Invalid turn data:', turnData); 165 | return; 166 | } 167 | 168 | updateAvatars(); 169 | 170 | try { 171 | if (!audioContext) initAudio(); 172 | 173 | const binaryString = atob(turnData.audio_base64); 174 | const len = binaryString.length; 175 | const pcmData = new Uint8Array(len); 176 | for (let i = 0; i < len; i++) { 177 | pcmData[i] = binaryString.charCodeAt(i); 178 | } 179 | 180 | const sampleRate = 24000; 181 | const numChannels = 1; 182 | const wavBuffer = pcm16ToWav(pcmData, sampleRate, numChannels); 183 | 184 | const audioBuffer = await audioContext.decodeAudioData(wavBuffer); 185 | 186 | const source = audioContext.createBufferSource(); 187 | source.buffer = audioBuffer; 188 | source.connect(analyser); 189 | analyser.connect(audioContext.destination); 190 | 191 | source.onended = 
async () => { 192 | isPlaying = false; 193 | currentTurnNumber++; 194 | 195 | if (animationFrameId) { 196 | cancelAnimationFrame(animationFrameId); 197 | animationFrameId = null; 198 | } 199 | 200 | stopVisualizer(document.querySelector('.avatar.active')); 201 | 202 | if (currentTurnNumber >= 7) { 203 | showNotification('Podcast has ended.'); 204 | return; 205 | } 206 | 207 | if (preloadedTurn && !isProcessingPreloadedTurn) { 208 | isProcessingPreloadedTurn = true; 209 | combinedTranscript += `\n# Turn Number ${preloadedTurn.turnNumber}\n${preloadedTurn.transcript}`; 210 | 211 | await playTurn(preloadedTurn); 212 | 213 | preloadedTurn = null; 214 | isProcessingPreloadedTurn = false; 215 | 216 | if (currentTurnNumber < 6) { 217 | preloadedTurn = await fetchTurnWithRetry(currentTurnNumber + 1); 218 | if (!preloadedTurn) { 219 | showNotification('Podcast has ended.'); 220 | } 221 | } 222 | } else { 223 | showNotification('Podcast has ended.'); 224 | } 225 | }; 226 | 227 | isPlaying = true; 228 | source.start(0); 229 | animateVisualizer(); 230 | } catch (error) { 231 | console.error('Error playing audio:', error); 232 | showNotification('An error occurred during audio playback.'); 233 | isPlaying = false 234 | } 235 | } 236 | 237 | function pcm16ToWav(pcmData, sampleRate = 24000, numChannels = 1) { 238 | const byteRate = sampleRate * numChannels * 2; 239 | const buffer = new ArrayBuffer(44 + pcmData.length); 240 | const view = new DataView(buffer); 241 | 242 | writeString(view, 0, 'RIFF'); 243 | view.setUint32(4, 36 + pcmData.length, true); 244 | writeString(view, 8, 'WAVE'); 245 | writeString(view, 12, 'fmt '); 246 | view.setUint32(16, 16, true); 247 | view.setUint16(20, 1, true); 248 | view.setUint16(22, numChannels, true); 249 | view.setUint32(24, sampleRate, true); 250 | view.setUint32(28, byteRate, true); 251 | view.setUint16(32, numChannels * 2, true); 252 | view.setUint16(34, 16, true); 253 | writeString(view, 36, 'data'); 254 | view.setUint32(40, 
pcmData.length, true); 255 | 256 | const uint8Array = new Uint8Array(buffer, 44); 257 | uint8Array.set(pcmData); 258 | 259 | return buffer; 260 | } 261 | 262 | function writeString(view, offset, string) { 263 | for (let i = 0; i < string.length; i++) { 264 | view.setUint8(offset + i, string.charCodeAt(i)); 265 | } 266 | } 267 | 268 | function formatTranscript(turnData) { 269 | return `${turnData.transcript}`; 270 | } 271 | 272 | function addAudienceQuestion() { 273 | const question = document.getElementById('question').value.trim(); 274 | if (!question) { 275 | showNotification('Please enter a question.'); 276 | return; 277 | } 278 | 279 | pendingAudienceQuestion = question; 280 | document.getElementById('question').value = ''; 281 | showNotification('Question submitted. It will be answered in the next turn.'); 282 | } 283 | 284 | function animateVisualizer() { 285 | const activeAvatar = document.querySelector('.avatar.active'); 286 | if (!activeAvatar || !isPlaying) return; 287 | 288 | const bars = activeAvatar.querySelectorAll('.bar'); 289 | const pulse = activeAvatar.querySelector('.pulse'); 290 | 291 | analyser.getByteFrequencyData(dataArray); 292 | 293 | let sum = 0; 294 | for (let i = 0; i < bars.length; i++) { 295 | const value = dataArray[i]; 296 | const percent = value / 255 * 100; 297 | bars[i].style.height = `${percent}%`; 298 | sum += value; 299 | } 300 | 301 | const average = sum / dataArray.length; 302 | const pulseScale = 1 + (average / 255) * 0.3; // Scale between 1 and 1.3 303 | pulse.style.transform = `scale(${pulseScale})`; 304 | 305 | animationFrameId = requestAnimationFrame(animateVisualizer); 306 | } 307 | 308 | function stopVisualizer(avatar) { 309 | if (animationFrameId) { 310 | cancelAnimationFrame(animationFrameId); 311 | animationFrameId = null; 312 | } 313 | 314 | const bars = avatar.querySelectorAll('.bar'); 315 | const pulse = avatar.querySelector('.pulse'); 316 | bars.forEach(bar => { 317 | bar.style.height = '0%'; 318 | }); 319 | 
pulse.style.transform = 'scale(1)'; 320 | } 321 | 322 | function updateAvatars() { 323 | const avatar1 = document.getElementById('avatar1'); 324 | const avatar2 = document.getElementById('avatar2'); 325 | avatar1.classList.toggle('active', currentTurnNumber % 2 === 0); 326 | avatar2.classList.toggle('active', currentTurnNumber % 2 !== 0); 327 | 328 | stopVisualizer(avatar1); 329 | stopVisualizer(avatar2); 330 | } 331 | 332 | function updateStatusMessage(message, isBlinking = false) { 333 | const statusMessage = document.getElementById('statusMessage'); 334 | statusMessage.textContent = message; 335 | statusMessage.classList.remove('hidden'); 336 | if (isBlinking) { 337 | statusMessage.classList.add('blink'); 338 | } else { 339 | statusMessage.classList.remove('blink'); 340 | } 341 | } 342 | 343 | function clearStatusMessage() { 344 | const statusMessage = document.getElementById('statusMessage'); 345 | statusMessage.textContent = ''; 346 | statusMessage.classList.add('hidden'); 347 | statusMessage.classList.remove('blink'); 348 | } 349 | 350 | function showRetryMessage(message) { 351 | const retryDiv = document.getElementById('retryMessage'); 352 | retryDiv.textContent = message; 353 | retryDiv.classList.remove('hidden'); 354 | } 355 | 356 | function clearRetryMessage() { 357 | const retryDiv = document.getElementById('retryMessage'); 358 | retryDiv.textContent = ''; 359 | retryDiv.classList.add('hidden'); 360 | } 361 | 362 | function showNotification(message, duration = 3000) { 363 | const notification = document.getElementById('notification'); 364 | notification.textContent = message; 365 | notification.classList.add('show'); 366 | setTimeout(() => { 367 | notification.classList.remove('show'); 368 | }, duration); 369 | } 370 | 371 | function delay(ms) { 372 | return new Promise(resolve => setTimeout(resolve, ms)); 373 | } -------------------------------------------------------------------------------- /static/styles.css: 
--------------------------------------------------------------------------------
@import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');

/* Gold-on-dark theme palette shared by the whole page. */
:root {
    --primary-color: #FFD700;
    --secondary-color: #FFA500;
    --background-color: #1A1A1A;
    --container-color: #2A2A2A;
    --text-color: #FFFFFF;
    --box-shadow: 0 0 10px rgba(255, 215, 0, 0.3);
}

/* Global reset. */
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

/* Center the single content card in the viewport. */
body {
    font-family: 'Roboto Mono', monospace;
    line-height: 1.6;
    color: var(--text-color);
    background-color: var(--background-color);
    min-height: 100vh;
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 20px;
}

.container {
    max-width: 600px;
    width: 100%;
    background-color: var(--container-color);
    padding: 30px;
    border-radius: 10px;
    box-shadow: var(--box-shadow);
}

h1 {
    color: var(--primary-color);
    text-align: center;
    margin-bottom: 30px;
    font-size: 2.5em;
}

label {
    display: block;
    margin-bottom: 10px;
    font-weight: bold;
    color: var(--primary-color);
}

/* Form controls: dark fields with gold outline. */
textarea, input[type="text"] {
    width: 100%;
    padding: 12px;
    margin-bottom: 20px;
    background-color: var(--background-color);
    color: var(--text-color);
    border: 1px solid var(--primary-color);
    font-family: 'Roboto Mono', monospace;
    font-size: 1em;
    transition: all 0.3s ease;
    border-radius: 5px;
}

textarea {
    resize: vertical;
    min-height: 100px;
}

textarea:focus, input[type="text"]:focus {
    outline: none;
    box-shadow: 0 0 5px var(--primary-color);
}

button {
    background-color: var(--primary-color);
    color: var(--background-color);
    padding: 12px 20px;
    border: none;
    cursor: pointer;
    font-family: 'Roboto Mono', monospace;
    font-size: 1em;
    transition: all 0.3s ease;
    border-radius: 5px;
}

button:hover {
    background-color: var(--secondary-color);
}

button:disabled {
    background-color: #555555;
    cursor: not-allowed;
}

.hidden {
    display: none;
}

#loading {
    text-align: center;
    margin-top: 20px;
    font-size: 1.2em;
    color: var(--primary-color);
}

/* Error line shown while fetchTurnWithRetry is retrying. */
#retryMessage {
    color: #FF6B6B;
    margin-top: 20px;
    text-align: center;
    font-weight: bold;
}

/* Speaker avatars with the audio visualizer inside each one. */
#avatarContainer {
    display: flex;
    justify-content: space-between;
    margin-top: 30px;
}

.avatar {
    width: 80px;
    height: 80px;
    background-color: var(--secondary-color);
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 24px;
    transition: all 0.3s ease;
    overflow: hidden;
    position: relative;
}

/* The avatar whose turn it is: gold and slightly enlarged. */
.avatar.active {
    background-color: var(--primary-color);
    transform: scale(1.1);
}

.avatar .visualizer {
    display: flex;
    justify-content: space-around;
    align-items: flex-end;
    width: 100%;
    height: 100%;
    padding: 10px;
}

/* Frequency bars; heights are driven from JS (animateVisualizer). */
.avatar .bar {
    width: 4px;
    background-color: var(--background-color);
    transition: height 0.1s ease;
}

/* Full-avatar overlay scaled from JS in step with average loudness. */
.avatar .pulse {
    position: absolute;
    width: 100%;
    height: 100%;
    background-color: rgba(255, 255, 255, 0.2);
    transform: scale(1);
    transition: transform 0.1s ease;
}

#audienceQuestionInput {
    margin-top: 30px;
    padding: 20px;
    border: 1px solid var(--primary-color);
    border-radius: 5px;
    background-color: var(--background-color);
}

/* Toast notification (rule continues in the next chunk). */
#notification {
    position: fixed;
    top: 20px;
    right: 20px;
    padding: 15px 20px;
    background-color: var(--primary-color);
    color: var(--background-color);
    border-radius: 5px;
    font-size: 1em;
    opacity: 0;
    transition: opacity 0.3s ease;
    pointer-events: none;
    z-index: 1000;
    box-shadow: var(--box-shadow);
}

#notification.show {
    opacity: 1;
}

#statusMessage {
    text-align: center;
    margin-top: 20px;
    font-size: 1.2em;
    color: var(--primary-color);
}

.blink {
    animation: blink 1s infinite;
}

@keyframes blink {
    0% { opacity: 1; }
    50% { opacity: 0; }
    100% { opacity: 1; }
}
--------------------------------------------------------------------------------
/turn_handler.py:
--------------------------------------------------------------------------------
"""Turn handling for the two-speaker podcast conversation.

Connects to an Azure OpenAI realtime deployment over WebSocket, generates
one speaker turn (audio + transcript) per request, alternating between the
host (Dan, even turns) and the guest (Anna, odd turns).
"""
import asyncio
import aiohttp
import base64
import json
import logging
from typing import Any, Dict, Tuple, Optional

from app_config import MAX_RETRIES, RETRY_DELAY

logger = logging.getLogger(__name__)


class TurnHandlerError(Exception):
    """Custom exception class for TurnHandler errors."""
    pass


class TurnHandler:
    """
    A class for handling turns in a podcast conversation.

    Holds two sets of API credentials (primary for the host, secondary for
    the guest) and a lazily created aiohttp session reused across turns.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str,
        api_key_b: str,
        api_base_b: str,
        realtime_deployment_name: str
    ) -> None:
        """
        Initialize the TurnHandler.

        Args:
            api_key (str): The primary API key.
            api_base (str): The primary API base URL.
            api_key_b (str): The secondary API key.
            api_base_b (str): The secondary API base URL.
            realtime_deployment_name (str): The name of the realtime deployment.
        """
        self.api_key = api_key
        # Trailing slashes are stripped so URL construction below is predictable.
        self.api_base = api_base.rstrip('/')
        self.api_key_b = api_key_b
        self.api_base_b = api_base_b.rstrip('/')
        self.realtime_deployment_name = realtime_deployment_name
        # Created on first use by _get_session(); shared across turns.
        self.session: Optional[aiohttp.ClientSession] = None
        logger.debug("TurnHandler initialized with primary and secondary API credentials.")

    async def _get_session(self) -> aiohttp.ClientSession:
        """
        Get or create an aiohttp ClientSession.

        A new session is created lazily on first call and whenever the
        previous one has been closed.

        Returns:
            aiohttp.ClientSession: The aiohttp session.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
            logger.debug("Created new aiohttp ClientSession for TurnHandler.")
        return self.session

    async def close(self) -> None:
        """
        Close the aiohttp ClientSession if it's open.
        """
        if self.session and not self.session.closed:
            await self.session.close()
            logger.debug("Closed aiohttp ClientSession for TurnHandler.")

    async def handle_turn(self, request: Any) -> Dict[str, Any]:
        """
        Handle a turn in the podcast conversation.

        Args:
            request (Any): The turn request containing podcast script, transcript, and other information.

        Returns:
            Dict[str, Any]: A dictionary containing the speaker, transcript, and audio data.

        Raises:
            TurnHandlerError: If handling the turn fails.
        """
        try:
            # getattr with defaults tolerates requests missing any field.
            audience_question: Optional[str] = getattr(request, 'audience_question', None)
            turn_number: int = getattr(request, 'turn', 0)
            podcast_script: str = getattr(request, 'podcast_script', '')
            combined_transcript: str = getattr(request, 'combined_transcript', '')

            logger.debug(f"Handling turn {turn_number} with audience question: {audience_question}")

            # Even turns: host Dan on the primary deployment; odd turns:
            # guest Anna on the secondary deployment.
            if turn_number % 2 == 0:
                speaker = "Dan"
                instructions = (
                    f"You are Dan, the host of a podcast. Make sure to move to the next talking point by asking "
                    f"questions to Anna. Discuss the following topic: {podcast_script}"
                )
                # NOTE(review): voice names "dan"/"marilyn" look deployment-specific —
                # confirm they exist on the target realtime deployments.
                voice = "dan"
                api_key = self.api_key
                api_base = self.api_base
            else:
                speaker = "Anna"
                instructions = (
                    f"You are Anna, a guest on a podcast. Discuss the following topic: {podcast_script}"
                )
                voice = "marilyn"
                api_key = self.api_key_b
                api_base = self.api_base_b

            if audience_question:
                instructions += (
                    f" Answer the audience question before proceeding (start with 'Oh I see we have a question from the audience'): "
                    f"{audience_question}"
                )

            # The final turn (6) always replaces the instructions with the
            # outro prompt — including any audience-question suffix added above.
            if turn_number == 6:
                instructions = (
                    "You are Dan, you just held a podcast with your guest Anna. Ask her and the audience for her time "
                    "and thoughts on the podcast. (i.e., do the outro)"
                )

            logger.info(f"Generating response for speaker: {speaker}, turn: {turn_number}")
            audio_base64, transcript = await self.generate_response(
                api_key=api_key,
                api_base=api_base,
                deployment_name=self.realtime_deployment_name,
                context=instructions,
                combined_transcript=combined_transcript,
                voice=voice,
                speaker=speaker
            )

            logger.info(f"Turn {turn_number} handled successfully for speaker: {speaker}")
            return {
                "speaker": speaker,
                "transcript": transcript,
                "audio_base64": audio_base64
            }

        except Exception as e:
            # Wrap any failure so callers only need to catch TurnHandlerError.
            logger.exception(f"Failed to handle turn {getattr(request, 'turn', 'unknown')}: {e}")
            raise TurnHandlerError(f"Failed to handle turn: {e}") from e

    async def generate_response(
        self,
        api_key: str,
        api_base: str,
        deployment_name: str,
        context: str,
        combined_transcript: str,
        voice: str,
        speaker: str
    ) -> Tuple[str, str]:
        """
        Generate a response for the current turn.

        Opens a realtime WebSocket, configures the session, requests one
        response, and accumulates the streamed audio deltas and transcript.

        Args:
            api_key (str): The API key to use.
            api_base (str): The API base URL.
            deployment_name (str): The name of the deployment.
            context (str): The context for the response.
            combined_transcript (str): The combined transcript of the conversation so far.
            voice (str): The voice to use for the response.
            speaker (str): The name of the speaker.

        Returns:
            Tuple[str, str]: A tuple containing the base64 encoded audio and the transcript.

        Raises:
            TurnHandlerError: If all retry attempts fail.
        """
        # assumes api_base uses https:// — the replace() yields wss://; any
        # other scheme would pass through unchanged. TODO confirm config.
        url = f"{api_base.replace('https://', 'wss://')}/openai/realtime"
        # Hard-coded preview API version for the Azure OpenAI realtime endpoint.
        query_params = f"?api-version=2024-10-01-preview&deployment={deployment_name}"
        full_url = f"{url}{query_params}"
        headers = {
            "api-key": api_key,
            "OpenAI-Beta": "realtime=v1"
        }

        # Each attempt reopens the WebSocket from scratch.
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                session = await self._get_session()
                async with session.ws_connect(full_url, headers=headers) as ws:
                    logger.debug(f"WebSocket connection established to {full_url}")

                    # The server must announce the session before we configure it.
                    msg = await ws.receive_json()
                    if msg.get("type") != "session.created":
                        raise TurnHandlerError(f"Expected 'session.created', got '{msg.get('type')}'")
                    logger.debug("Session created successfully.")

                    # Configure the session: instructions embed the running
                    # transcript so the model continues where it left off.
                    # turn_detection None: no server-side voice-activity
                    # detection — presumably the client drives turns; confirm
                    # against the realtime API reference.
                    await ws.send_json({
                        "type": "session.update",
                        "session": {
                            "modalities": ["text", "audio"],
                            "instructions": (
                                f"You are {speaker}. {context}. Always speak 1-2 sentences at a time. "
                                f"Continue the conversation (CONTINUE WHERE IT LEFT OFF at !):\n"
                                f"{combined_transcript}\n\n{speaker}: "
                            ),
                            "voice": voice,
                            "input_audio_format": "pcm16",
                            "output_audio_format": "pcm16",
                            "turn_detection": None,
                            "temperature": 0.6,
                        }
                    })
                    logger.debug("Sent session.update with instructions.")

                    # Ask for exactly one model response.
                    await ws.send_json({
                        "type": "response.create",
                        "response": {
                            "modalities": ["audio", "text"],
                            "voice": voice,
                            "output_audio_format": "pcm16",
                            "temperature": 0.6,
                        }
                    })
                    logger.debug("Sent response.create message.")

                    audio_data = bytearray()
                    transcript = ""

                    # Accumulate streamed audio deltas until response.done.
                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            data = json.loads(msg.data)
                            msg_type = data.get("type")
                            logger.debug(f"Received WebSocket message of type: {msg_type}")

                            if msg_type == "response.audio.delta":
                                delta = data.get("delta", "")
                                audio_bytes = base64.b64decode(delta)
                                audio_data.extend(audio_bytes)
                                logger.debug(f"Appended audio delta: {len(audio_bytes)} bytes.")

                            elif msg_type == "response.audio_transcript.done":
                                transcript = data.get("transcript", "")
                                logger.debug("Received completed transcript.")

                            elif msg_type == "response.done":
                                logger.debug("Received response.done message. Closing WebSocket.")
                                break

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            error_msg = msg.data
                            logger.error(f"WebSocket error: {error_msg}")
                            raise TurnHandlerError(f"WebSocket error: {error_msg}")

                    await ws.close()
                    logger.debug("WebSocket connection closed.")

                    # Re-encode the concatenated PCM bytes for the JSON response.
                    audio_base64 = base64.b64encode(audio_data).decode('utf-8')
                    logger.debug("Audio data encoded to base64.")

                    # An empty transcript means the stream ended without a
                    # usable response; treat it as a retryable failure.
                    if not transcript:
                        logger.warning("Transcript is empty after response.")
                        raise TurnHandlerError("Transcript is empty after response.")

                    return audio_base64, transcript

            except (aiohttp.ClientError, TurnHandlerError) as e:
                logger.error(f"Attempt {attempt} failed with error: {e}")
                if attempt < MAX_RETRIES:
                    logger.info(f"Retrying in {RETRY_DELAY} seconds...")
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    logger.critical("All retry attempts failed.")
                    raise TurnHandlerError(f"All retry attempts failed: {e}") from e
--------------------------------------------------------------------------------