├── .gitignore
├── LICENSE
├── README.md
├── app_config.py
├── index.html
├── main.py
├── openai_client.py
├── podcast_generator.py
├── requirements.txt
├── static
├── script.js
└── styles.css
└── turn_handler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 aymenfurter
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AI Podcast Generator
2 |
AI Podcast Generator is a demo app that leverages OpenAI's language models to create engaging podcast scripts effortlessly. Simply input your desired topic or content, and the AI will generate a comprehensive podcast. Enhance your podcasting experience by asking audience questions in real time.
4 |
5 | [Watch the Demo Video](https://www.youtube.com/watch?v=CrdZtwO6x6o)
6 |
7 | ## Installation
8 |
9 | 1. **Clone the Repository:**
10 | ```bash
11 | git clone https://github.com/aymenfurter/ai-podcast-generator.git
12 | cd ai-podcast-generator
13 | ```
14 |
15 | 2. **Create a Virtual Environment:**
16 | ```bash
17 | python -m venv venv
18 | source venv/bin/activate # On Windows: venv\Scripts\activate
19 | ```
20 |
21 | 3. **Install Dependencies:**
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 |
26 | 4. **Set Up Environment Variables:**
27 | - Create a `.env` file in the root directory.
28 | - Add the following variables (You'll need an Azure OpenAI Instance):
29 | ```env
30 | OPENAI_API_KEY=your_openai_api_key
31 | OPENAI_API_BASE=your_openai_api_base
32 | OPENAI_DEPLOYMENT_NAME=your_deployment_name
33 | OPENAI_REALTIME_DEPLOYMENT_NAME=your_realtime_deployment_name
34 | OPENAI_API_KEY_B=your_secondary_openai_api_key
35 | OPENAI_API_BASE_B=your_secondary_openai_api_base
36 | ```
37 |
38 | ## Usage
39 |
40 | 1. **Run the Application:**
41 | ```bash
42 | python main.py
43 | ```
44 |
45 | 2. **Access the Web Interface:**
46 | - Open your browser and navigate to `http://localhost:8000`.
47 | - Enter your podcast topic and generate your script.
48 |
49 | ## License
50 |
51 | This project is licensed under the [MIT License](LICENSE).
--------------------------------------------------------------------------------
/app_config.py:
--------------------------------------------------------------------------------
import os
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
load_dotenv()

# Azure OpenAI connection settings (primary instance).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
OPENAI_DEPLOYMENT_NAME = os.getenv("OPENAI_DEPLOYMENT_NAME")
OPENAI_REALTIME_DEPLOYMENT_NAME = os.getenv("OPENAI_REALTIME_DEPLOYMENT_NAME")
# Secondary Azure OpenAI instance (used for the second realtime voice).
OPENAI_API_KEY_B = os.getenv("OPENAI_API_KEY_B")
OPENAI_API_BASE_B = os.getenv("OPENAI_API_BASE_B")
MAX_RETRIES = 15  # maximum attempts for retried operations
RETRY_DELAY = 5   # seconds to wait between retries

required_env_vars = [
    "OPENAI_API_KEY",
    "OPENAI_API_BASE",
    "OPENAI_DEPLOYMENT_NAME",
    "OPENAI_REALTIME_DEPLOYMENT_NAME",
    "OPENAI_API_KEY_B",
    "OPENAI_API_BASE_B"
]

# Fail fast at import time, reporting every missing variable at once instead
# of only the first one found (the original raised on the first miss, forcing
# users to fix variables one at a time).
_missing = [var for var in required_env_vars if not os.getenv(var)]
if _missing:
    raise EnvironmentError(
        f"Missing required environment variable: {', '.join(_missing)}"
    )
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | AI Podcast Generator
7 |
8 |
9 |
10 |
11 |
AI Podcast Generator
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import logging
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from openai_client import OpenAIClient, OpenAIClientError
from podcast_generator import PodcastGenerator
from turn_handler import TurnHandler, TurnHandlerError
from app_config import (
    OPENAI_API_KEY,
    OPENAI_API_BASE,
    OPENAI_DEPLOYMENT_NAME,
    OPENAI_REALTIME_DEPLOYMENT_NAME,
    OPENAI_API_KEY_B,
    OPENAI_API_BASE_B
)
19 |
# Root logging configuration for the whole application (timestamp, level,
# logger name, message). Module loggers below inherit this setup.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s'
)
logger = logging.getLogger(__name__)
25 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager handling application startup and shutdown.

    Initializes the shared OpenAIClient and TurnHandler on startup and
    ensures both are closed on shutdown.

    NOTE: FastAPI's ``lifespan=`` parameter expects an async context
    manager; a bare async generator function is not accepted, so the
    ``@asynccontextmanager`` decorator is required for the app to start.
    """
    logger.info("Lifespan startup: Initializing OpenAIClient and TurnHandler.")

    app.state.openai_client = OpenAIClient(
        api_key=OPENAI_API_KEY,
        api_base=OPENAI_API_BASE,
        deployment_name=OPENAI_DEPLOYMENT_NAME,
        realtime_deployment_name=OPENAI_REALTIME_DEPLOYMENT_NAME
    )
    # Open the client's HTTP session eagerly so the first request is fast.
    await app.state.openai_client.__aenter__()

    app.state.turn_handler = TurnHandler(
        api_key=OPENAI_API_KEY,
        api_base=OPENAI_API_BASE,
        api_key_b=OPENAI_API_KEY_B,
        api_base_b=OPENAI_API_BASE_B,
        realtime_deployment_name=OPENAI_REALTIME_DEPLOYMENT_NAME
    )
    # NOTE(review): warms up the handler via a private method — consider
    # exposing a public open()/start() on TurnHandler instead.
    await app.state.turn_handler._get_session()

    logger.info("OpenAIClient and TurnHandler initialized successfully.")

    try:
        yield
    finally:
        logger.info("Lifespan shutdown: Closing OpenAIClient and TurnHandler.")

        await app.state.openai_client.close()
        await app.state.turn_handler.close()

        logger.info("OpenAIClient and TurnHandler closed successfully.")
61 |
# Application instance; `lifespan` manages startup/shutdown of shared clients.
app = FastAPI(
    title="Podcast Generator API",
    lifespan=lifespan
)

# Serve the frontend assets (script.js, styles.css) from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
68 |
class PodcastRequest(BaseModel):
    """Request body for /generate_podcast_script."""
    # Free-form topic or source text the podcast script is generated from.
    topic: str
71 |
class TurnRequest(BaseModel):
    """Request body for /next_turn."""
    # Full generated script (talking points + summary + source document).
    podcast_script: str
    # Transcript of all turns spoken so far, used as conversational context.
    combined_transcript: str
    # Optional audience question to answer in this turn; may be empty.
    audience_question: Optional[str] = None
    # Zero-based index of the turn being requested.
    turn: int
77 |
async def get_openai_client() -> OpenAIClient:
    """Dependency provider: the shared OpenAIClient created in `lifespan`."""
    return app.state.openai_client
80 |
async def get_podcast_generator(client: OpenAIClient = Depends(get_openai_client)) -> PodcastGenerator:
    """Dependency provider: a lightweight per-request generator over the shared client."""
    return PodcastGenerator(client)
83 |
async def get_turn_handler() -> TurnHandler:
    """Dependency provider: the shared TurnHandler created in `lifespan`."""
    return app.state.turn_handler
86 |
@app.post("/generate_podcast_script")
async def generate_podcast_script(
    request: PodcastRequest,
    generator: PodcastGenerator = Depends(get_podcast_generator)
) -> Dict[str, Any]:
    """
    Generate a podcast script based on the given topic.

    Args:
        request (PodcastRequest): The request containing the podcast topic.
        generator (PodcastGenerator): The injected PodcastGenerator instance.

    Returns:
        dict: A dictionary containing the generated podcast script.

    Raises:
        HTTPException: If the podcast script generation fails.
    """
    logger.info(f"Received request to generate podcast script for topic: '{request.topic}'")
    try:
        script = await generator.generate_full_podcast(request.topic)
    except OpenAIClientError as e:
        # Known failure mode from the OpenAI client: report as a 500.
        logger.error(f"Failed to generate podcast script: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate podcast script.") from e
    except Exception as e:
        # Anything else is unexpected; log with traceback.
        logger.exception(f"Unexpected error during podcast script generation: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e
    logger.info("Podcast script generated successfully.")
    return {"podcast_script": script}
116 |
@app.post("/next_turn")
async def next_turn(
    request: TurnRequest,
    turn_handler: TurnHandler = Depends(get_turn_handler)
) -> Dict[str, Any]:
    """
    Handle the next turn in the podcast conversation.

    Args:
        request (TurnRequest): The request containing the turn information.
        turn_handler (TurnHandler): The injected TurnHandler instance.

    Returns:
        dict: A dictionary containing the speaker, transcript, and audio data.

    Raises:
        HTTPException: If handling the turn fails.
    """
    logger.info(f"Received request to handle turn {request.turn}.")
    try:
        result = await turn_handler.handle_turn(request)
    except TurnHandlerError as e:
        # Known failure mode from the turn handler: report as a 500.
        logger.error(f"Failed to handle turn {request.turn}: {e}")
        raise HTTPException(status_code=500, detail="Failed to handle the turn.") from e
    except Exception as e:
        # Anything else is unexpected; log with traceback.
        logger.exception(f"Unexpected error during turn {request.turn} handling: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e
    logger.info(f"Turn {request.turn} handled successfully.")
    return result
146 |
@app.get("/", response_class=HTMLResponse)
async def read_root() -> str:
    """
    Serve the main HTML page.

    Returns:
        str: The content of the index.html file.

    Raises:
        HTTPException: If index.html is not found or cannot be read.
    """
    try:
        with open("index.html", "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        logger.error("index.html not found.")
        raise HTTPException(status_code=404, detail="index.html not found.")
    except Exception as e:
        logger.exception(f"Error serving index.html: {e}")
        raise HTTPException(status_code=500, detail="Failed to serve the main page.") from e
    logger.info("Served index.html successfully.")
    return content
169 |
def main():
    """Run the FastAPI application with Uvicorn on 0.0.0.0:8000 (auto-reload enabled)."""
    # Local import: uvicorn is only needed when running as a script.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)


if __name__ == "__main__":
    main()
179 |
--------------------------------------------------------------------------------
/openai_client.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | import logging
3 | from typing import List, Optional, Dict, Any
4 |
# Module logger; basicConfig is a no-op if the application configured logging first.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
7 |
class OpenAIClient:
    """
    Async client for the Azure OpenAI chat-completions REST API.

    Manages a lazily-created aiohttp session; usable directly (call
    `close()` when done) or as an async context manager.
    """

    # Azure OpenAI REST API version appended to every request URL.
    API_VERSION = "2023-05-15"

    def __init__(
        self,
        api_key: str,
        api_base: str,
        deployment_name: str,
        realtime_deployment_name: str
    ) -> None:
        """
        Initialize the OpenAIClient.

        Args:
            api_key (str): The API key for authentication.
            api_base (str): The base URL for the API.
            deployment_name (str): The name of the chat deployment.
            realtime_deployment_name (str): The name of the realtime deployment.
        """
        self.api_key = api_key
        self.api_base = api_base.rstrip('/')  # Ensure no trailing slash
        self.deployment_name = deployment_name
        self.realtime_deployment_name = realtime_deployment_name
        self.session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """
        Get or create an aiohttp ClientSession.

        Returns:
            aiohttp.ClientSession: The aiohttp session.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
            logger.debug("Created new aiohttp ClientSession.")
        return self.session

    async def close(self) -> None:
        """Close the aiohttp ClientSession if it's open."""
        if self.session and not self.session.closed:
            await self.session.close()
            logger.debug("Closed aiohttp ClientSession.")

    async def create_chat_completion(
        self,
        messages: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Create a chat completion using the OpenAI API.

        Args:
            messages (List[Dict[str, Any]]): A list of message dictionaries.
            max_tokens (Optional[int], optional): Maximum number of tokens to generate.

        Returns:
            str: The generated chat completion content.

        Raises:
            OpenAIClientError: If the HTTP request fails, the response is not
                200, or the response contains no usable content.
        """
        url = (
            f"{self.api_base}/openai/deployments/{self.deployment_name}/chat/completions"
            f"?api-version={self.API_VERSION}"
        )
        headers = {
            "Content-Type": "application/json",
            "api-key": self.api_key
        }
        payload: Dict[str, Any] = {
            "messages": messages
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens

        logger.debug(f"Sending request to {url} with payload: {payload}")

        session = await self._get_session()
        try:
            async with session.post(url, headers=headers, json=payload) as response:
                response_text = await response.text()
                logger.debug(f"Received response status: {response.status}")
                logger.debug(f"Response text: {response_text}")

                if response.status != 200:
                    logger.error(
                        f"Chat completion failed with status {response.status}: {response_text}"
                    )
                    raise OpenAIClientError(
                        f"Chat completion failed: {response.status} - {response_text}"
                    )

                result = await response.json()
                # Guard against an empty `choices` list: unguarded
                # `choices[0]` indexing raised a bare IndexError here instead
                # of the documented OpenAIClientError.
                choices = result.get('choices') or [{}]
                completion = choices[0].get('message', {}).get('content', '')

                if not completion:
                    logger.warning("No content found in the response.")
                    raise OpenAIClientError("No content found in the response.")

                logger.info("Chat completion successful.")
                return completion

        except aiohttp.ClientError as e:
            logger.exception("HTTP request failed.")
            raise OpenAIClientError(f"HTTP request failed: {e}") from e

    async def __aenter__(self) -> 'OpenAIClient':
        """Enter the async context: ensure the session exists and return self."""
        await self._get_session()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Exit the async context and close the session."""
        await self.close()
135 |
136 |
class OpenAIClientError(Exception):
    """Raised when an OpenAI API request fails or returns unusable content."""
    pass
--------------------------------------------------------------------------------
/podcast_generator.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | import logging
3 | from openai_client import OpenAIClient
4 | from openai_client import OpenAIClientError
5 |
6 | logger = logging.getLogger(__name__)
7 |
class PodcastGenerator:
    """
    Generates a podcast script (talking points + summary + source document)
    from raw input text using OpenAI's language model.
    """

    def __init__(self, client: OpenAIClient) -> None:
        """
        Initialize the PodcastGenerator.

        Args:
            client (OpenAIClient): An instance of the OpenAIClient.
        """
        self.client = client
        logger.debug("PodcastGenerator initialized with OpenAIClient.")

    async def generate_summary(self, text: str, max_tokens: int = 300) -> str:
        """
        Generate a summary of the given text.

        Args:
            text (str): The input text to summarize.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 300.

        Returns:
            str: The generated summary.

        Raises:
            OpenAIClientError: If the summary generation fails.
        """
        messages = [
            {
                "role": "system",
                "content": "You are a skilled summarizer. Create a concise summary of the following text."
            },
            {
                "role": "user",
                "content": text
            }
        ]
        logger.debug("Generating summary with provided text.")
        summary = await self.client.create_chat_completion(messages, max_tokens=max_tokens)
        logger.info("Summary generated successfully.")
        return summary

    async def chain_of_density(self, summary: str, iterations: int = 3, max_tokens: int = 300) -> str:
        """
        Apply the chain of density technique to compress the summary.

        Args:
            summary (str): The initial summary.
            iterations (int, optional): The number of compression iterations. Defaults to 3.
            max_tokens (int, optional): The maximum number of tokens per iteration. Defaults to 300.

        Returns:
            str: The compressed summary.

        Raises:
            OpenAIClientError: If the compression process fails.
        """
        dense_summary = summary
        logger.debug(f"Starting chain of density with {iterations} iterations.")
        # Each pass feeds the previous pass's output back in for further compression.
        for i in range(1, iterations + 1):
            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are an expert in information compression. Your task is to make the given text "
                        "more concise while preserving all key information. Aim to reduce the word count by "
                        "25% without losing important content."
                    )
                },
                {
                    "role": "user",
                    "content": (
                        f"Original text:\n{dense_summary}\n\n"
                        "Compress this text, maintaining all key points but reducing verbosity."
                    )
                }
            ]
            logger.debug(f"Compression iteration {i} of {iterations}.")
            dense_summary = await self.client.create_chat_completion(messages, max_tokens=max_tokens)
            logger.info(f"Compression iteration {i} completed successfully.")
        logger.info("Chain of density compression completed.")
        return dense_summary

    async def create_podcast_outline(self, compressed_summary: str, max_tokens: int = 300) -> str:
        """
        Create a podcast outline based on the compressed summary.

        Args:
            compressed_summary (str): The compressed summary of the topic.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 300.

        Returns:
            str: The generated podcast outline.

        Raises:
            OpenAIClientError: If the outline creation fails.
        """
        messages = [
            {
                "role": "system",
                "content": (
                    "Create a high-level outline for a podcast episode based on the following summary. "
                    "Include 3-5 main topics to discuss."
                )
            },
            {
                "role": "user",
                "content": compressed_summary
            }
        ]
        logger.debug("Creating podcast outline with compressed summary.")
        outline = await self.client.create_chat_completion(messages, max_tokens=max_tokens)
        logger.info("Podcast outline created successfully.")
        return outline

    async def generate_full_podcast(self, input_text: str) -> str:
        """
        Generate a full podcast script based on the input text.

        Pipeline: summarize -> compress (chain of density) -> outline ->
        assemble the final script from outline, summary, and source document.

        Args:
            input_text (str): The input text to base the podcast on.

        Returns:
            str: The generated podcast script.

        Raises:
            OpenAIClientError: If any step in the podcast generation process fails.
        """
        try:
            logger.info("Starting podcast generation process.")

            logger.info("Generating summary...")
            summary = await self.generate_summary(input_text)

            logger.info("Compressing summary...")
            compressed_summary = await self.chain_of_density(summary)

            logger.info("Creating podcast outline...")
            # The outline is built from the compressed summary only; a dead
            # `podcast_input` local that was assigned here but never used has
            # been removed.
            outline = await self.create_podcast_outline(compressed_summary)

            # The source document is truncated to 20k characters to keep the
            # final prompt within model context limits.
            full_podcast = (
                f"## Talking Points:\n{outline}\n\n"
                f"## Topic Summary:\n{compressed_summary}\n"
                f"## Full Document:\n{input_text[:20000]}"
            )
            logger.info("Podcast generation completed successfully.")
            return full_podcast

        except OpenAIClientError as e:
            logger.error(f"Podcast generation failed: {e}")
            raise
159 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fastapi
uvicorn
pydantic
aiohttp
python-multipart
python-dotenv
--------------------------------------------------------------------------------
/static/script.js:
--------------------------------------------------------------------------------
// --- Podcast playback state (module-level, mutated across async callbacks) ---
let podcastScript = '';                 // Full generated script from the backend
let combinedTranscript = '';            // Accumulated transcript of turns played so far
let currentTurnNumber = 0;              // Index of the turn currently playing
let preloadedTurn = null;               // Next turn fetched ahead of time, or null
let isPreloading = false;               // True while a preload fetch is in flight
let isPlaying = false;                  // True while a turn's audio is playing
let isProcessingPreloadedTurn = false;  // Guards against double-consuming the preload
let pendingAudienceQuestion = null;     // Question to inject into the next turn request

const MAX_RETRIES = 3;     // Attempts per turn fetch before giving up
const RETRY_DELAY = 2000;  // Milliseconds between retry attempts

// Web Audio objects shared between playback and the visualizer.
let audioContext, analyser, dataArray;
const FFT_SIZE = 32;          // Small FFT: coarse frequency bars for the visualizer
let animationFrameId = null;  // requestAnimationFrame handle for the visualizer loop
16 |
function initAudio() {
    // Lazily create the shared AudioContext (with the WebKit fallback) and a
    // small analyser node used by the playback visualizer.
    const ContextClass = window.AudioContext || window.webkitAudioContext;
    audioContext = new ContextClass();
    analyser = audioContext.createAnalyser();
    analyser.fftSize = FFT_SIZE;
    dataArray = new Uint8Array(analyser.frequencyBinCount);
}
23 |
async function generatePodcastScript() {
    // Read and validate the topic field before calling the backend.
    const topic = document.getElementById('topic').value.trim();
    if (!topic) {
        showNotification('Please enter content for the podcast.');
        return;
    }

    updateStatusMessage('Generating podcast...', true);
    clearRetryMessage();

    try {
        const response = await fetch('/generate_podcast_script', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ topic }),
        });

        if (!response.ok) {
            throw new Error('Network response was not ok');
        }

        const payload = await response.json();
        podcastScript = payload.podcast_script;

        // Swap the topic form for the audience-question form, then start playback.
        document.getElementById('audienceQuestionInput').classList.remove('hidden');
        document.getElementById('topicInput').classList.add('hidden');

        startPodcast();
    } catch (error) {
        console.error('Error generating podcast script:', error);
        showNotification('An error occurred while generating the podcast. Please try again.');
    } finally {
        clearStatusMessage();
    }
}
61 |
async function startPodcast() {
    // Ignore repeated start requests while audio is already playing.
    if (isPlaying) {
        return;
    }

    isPlaying = true;
    currentTurnNumber = 0;
    combinedTranscript = '';
    clearRetryMessage();

    const openingTurn = await fetchTurnWithRetry(currentTurnNumber);
    if (!openingTurn) {
        showNotification('No turns available to play.');
        isPlaying = false;
        return;
    }

    combinedTranscript += formatTranscript(openingTurn);

    // Fetch the following turn ahead of time so playback can chain seamlessly.
    preloadedTurn = await fetchTurnWithRetry(currentTurnNumber + 1);

    await playTurn(openingTurn);
}
85 |
async function fetchTurnWithRetry(turnNumber, attempt = 1) {
    // Hard cap: the podcast is 7 turns (0-6); anything beyond signals the end.
    if (turnNumber >= 7) {
        return null;
    }

    updateStatusMessage(`Fetching turn ${turnNumber} (Attempt ${attempt}/${MAX_RETRIES})...`, true);
    clearRetryMessage();

    try {
        let body = {
            podcast_script: podcastScript,
            combined_transcript: combinedTranscript,
            turn: turnNumber,
        };

        // Consume a pending audience question exactly once: append it to the
        // transcript sent to the backend, then clear it so it isn't repeated.
        if (pendingAudienceQuestion) {
            body.audience_question = pendingAudienceQuestion;
            body.combined_transcript += `\n\nAudience: ${pendingAudienceQuestion}\nANSWER THE AUDIENCE QUESTION FIRST THEN MOVE ON`;
            pendingAudienceQuestion = null;
        } else {
            body.audience_question = "";
        }

        const response = await fetch('/next_turn', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(body),
        });

        if (!response.ok) {
            throw new Error(`Failed to fetch turn ${turnNumber}`);
        }

        const data = await response.json();

        if (!data.transcript || !data.audio_base64) {
            throw new Error(`Incomplete data for turn ${turnNumber}`);
        }

        let transcript = data.transcript.trim();

        // If the model stopped mid-sentence, trim back to the last complete
        // sentence; if there is no sentence-ending punctuation at all, drop it.
        if (!/[.!?]$/.test(transcript)) {
            const lastPeriod = transcript.lastIndexOf('.');
            const lastQuestion = transcript.lastIndexOf('?');
            const lastExclamation = transcript.lastIndexOf('!');
            const lastPunc = Math.max(lastPeriod, lastQuestion, lastExclamation);

            if (lastPunc !== -1) {
                transcript = transcript.substring(0, lastPunc + 1).trim();
            } else {
                transcript = '';
            }
        }

        return {
            turnNumber: turnNumber,
            // Ensure the transcript carries the speaker prefix exactly once.
            transcript: transcript.startsWith(`${data.speaker}:`) ? transcript : `${data.speaker}: ${transcript}`,
            audio_base64: data.audio_base64,
        };
    } catch (error) {
        console.error(`Error fetching turn ${turnNumber}:`, error);
        // Retry with a fixed delay up to MAX_RETRIES, then give up (null).
        if (attempt < MAX_RETRIES) {
            showRetryMessage(`Failed to load turn ${turnNumber}. Retrying in ${RETRY_DELAY / 1000} seconds...`);
            await delay(RETRY_DELAY);
            return await fetchTurnWithRetry(turnNumber, attempt + 1);
        } else {
            showRetryMessage(`Failed to load turn ${turnNumber} after ${MAX_RETRIES} attempts.`);
            return null;
        }
    } finally {
        clearStatusMessage();
    }
}
161 |
async function playTurn(turnData) {
    // Decode this turn's PCM audio, play it through the analyser, and chain
    // into the preloaded next turn when playback ends.
    if (!turnData || !turnData.audio_base64) {
        console.error('Invalid turn data:', turnData);
        return;
    }

    updateAvatars();

    try {
        if (!audioContext) initAudio();

        // base64 -> raw PCM byte array.
        const binaryString = atob(turnData.audio_base64);
        const len = binaryString.length;
        const pcmData = new Uint8Array(len);
        for (let i = 0; i < len; i++) {
            pcmData[i] = binaryString.charCodeAt(i);
        }

        // Backend audio is treated as 24 kHz mono 16-bit PCM; wrap it in a
        // WAV header so decodeAudioData can parse it.
        const sampleRate = 24000;
        const numChannels = 1;
        const wavBuffer = pcm16ToWav(pcmData, sampleRate, numChannels);

        const audioBuffer = await audioContext.decodeAudioData(wavBuffer);

        const source = audioContext.createBufferSource();
        source.buffer = audioBuffer;
        source.connect(analyser);
        analyser.connect(audioContext.destination);

        // When this turn finishes: stop the visualizer, consume the preloaded
        // turn (guarded against re-entry), and preload the turn after it.
        source.onended = async () => {
            isPlaying = false;
            currentTurnNumber++;

            if (animationFrameId) {
                cancelAnimationFrame(animationFrameId);
                animationFrameId = null;
            }

            stopVisualizer(document.querySelector('.avatar.active'));

            // The podcast is capped at 7 turns (0-6).
            if (currentTurnNumber >= 7) {
                showNotification('Podcast has ended.');
                return;
            }

            if (preloadedTurn && !isProcessingPreloadedTurn) {
                isProcessingPreloadedTurn = true;
                combinedTranscript += `\n# Turn Number ${preloadedTurn.turnNumber}\n${preloadedTurn.transcript}`;

                await playTurn(preloadedTurn);

                preloadedTurn = null;
                isProcessingPreloadedTurn = false;

                if (currentTurnNumber < 6) {
                    preloadedTurn = await fetchTurnWithRetry(currentTurnNumber + 1);
                    if (!preloadedTurn) {
                        showNotification('Podcast has ended.');
                    }
                }
            } else {
                showNotification('Podcast has ended.');
            }
        };

        isPlaying = true;
        source.start(0);
        animateVisualizer();
    } catch (error) {
        console.error('Error playing audio:', error);
        showNotification('An error occurred during audio playback.');
        isPlaying = false
    }
}
236 |
function pcm16ToWav(pcmData, sampleRate = 24000, numChannels = 1) {
    // Prepend a canonical 44-byte WAV header to raw 16-bit PCM bytes so the
    // result can be fed to AudioContext.decodeAudioData.
    const bytesPerFrame = numChannels * 2; // 16-bit samples
    const dataSize = pcmData.length;
    const wav = new ArrayBuffer(44 + dataSize);
    const header = new DataView(wav);

    // RIFF chunk descriptor.
    writeString(header, 0, 'RIFF');
    header.setUint32(4, 36 + dataSize, true); // total size minus first 8 bytes
    writeString(header, 8, 'WAVE');

    // "fmt " sub-chunk: uncompressed PCM, 16 bits per sample.
    writeString(header, 12, 'fmt ');
    header.setUint32(16, 16, true);                          // fmt chunk size
    header.setUint16(20, 1, true);                           // audio format: PCM
    header.setUint16(22, numChannels, true);
    header.setUint32(24, sampleRate, true);
    header.setUint32(28, sampleRate * bytesPerFrame, true);  // byte rate
    header.setUint16(32, bytesPerFrame, true);               // block align
    header.setUint16(34, 16, true);                          // bits per sample

    // "data" sub-chunk followed by the PCM payload itself.
    writeString(header, 36, 'data');
    header.setUint32(40, dataSize, true);
    new Uint8Array(wav, 44).set(pcmData);

    return wav;
}
261 |
function writeString(view, offset, string) {
    // Write each character's code unit as a single byte at consecutive
    // positions starting from `offset` (ASCII chunk IDs like 'RIFF').
    Array.from(string).forEach((ch, idx) => {
        view.setUint8(offset + idx, ch.charCodeAt(0));
    });
}
267 |
function formatTranscript(turnData) {
    // Currently a plain passthrough; kept as a hook for future formatting.
    return String(turnData.transcript);
}
271 |
function addAudienceQuestion() {
    // Queue the typed audience question; it is picked up on the next turn.
    const input = document.getElementById('question');
    const question = input.value.trim();

    if (!question) {
        showNotification('Please enter a question.');
        return;
    }

    pendingAudienceQuestion = question;
    input.value = '';
    showNotification('Question submitted. It will be answered in the next turn.');
}
283 |
function animateVisualizer() {
    // Drive the active avatar's equalizer bars and pulse halo from live
    // frequency data; reschedules itself via requestAnimationFrame while
    // audio is playing. Reads module globals: analyser, dataArray, isPlaying.
    const activeAvatar = document.querySelector('.avatar.active');
    if (!activeAvatar || !isPlaying) return;

    const bars = activeAvatar.querySelectorAll('.bar');
    const pulse = activeAvatar.querySelector('.pulse');

    analyser.getByteFrequencyData(dataArray);

    // Map the first N frequency bins (byte values 0-255) onto the N bars.
    for (let i = 0; i < bars.length; i++) {
        bars[i].style.height = `${(dataArray[i] / 255) * 100}%`;
    }

    // BUG FIX: the original summed only the first bars.length bins but
    // divided by dataArray.length, understating overall loudness. Average
    // over the whole spectrum instead.
    let sum = 0;
    for (let i = 0; i < dataArray.length; i++) {
        sum += dataArray[i];
    }
    const average = sum / dataArray.length;
    const pulseScale = 1 + (average / 255) * 0.3; // Scale between 1 and 1.3
    pulse.style.transform = `scale(${pulseScale})`;

    animationFrameId = requestAnimationFrame(animateVisualizer);
}
307 |
function stopVisualizer(avatar) {
    // Cancel the animation loop and reset the avatar's bars/pulse to idle.
    if (animationFrameId) {
        cancelAnimationFrame(animationFrameId);
        animationFrameId = null;
    }

    avatar.querySelectorAll('.bar').forEach((bar) => {
        bar.style.height = '0%';
    });
    avatar.querySelector('.pulse').style.transform = 'scale(1)';
}
321 |
function updateAvatars() {
    // Highlight the avatar of the current speaker: even turns -> avatar1,
    // odd turns -> avatar2. Also reset both visualizers to idle.
    const firstSpeaks = currentTurnNumber % 2 === 0;
    const avatar1 = document.getElementById('avatar1');
    const avatar2 = document.getElementById('avatar2');

    avatar1.classList.toggle('active', firstSpeaks);
    avatar2.classList.toggle('active', !firstSpeaks);

    stopVisualizer(avatar1);
    stopVisualizer(avatar2);
}
331 |
function updateStatusMessage(message, isBlinking = false) {
    // Show the status banner with the given text; optionally make it blink.
    const banner = document.getElementById('statusMessage');
    banner.textContent = message;
    banner.classList.remove('hidden');
    banner.classList.toggle('blink', isBlinking);
}
342 |
function clearStatusMessage() {
    // Empty and hide the status banner, removing any blink effect.
    const banner = document.getElementById('statusMessage');
    banner.textContent = '';
    banner.classList.remove('blink');
    banner.classList.add('hidden');
}
349 |
function showRetryMessage(message) {
    // Display the retry banner with the given error text.
    const banner = document.getElementById('retryMessage');
    banner.textContent = message;
    banner.classList.remove('hidden');
}
355 |
function clearRetryMessage() {
    // Empty and hide the retry banner.
    const banner = document.getElementById('retryMessage');
    banner.textContent = '';
    banner.classList.add('hidden');
}
361 |
function showNotification(message, duration = 3000) {
    // Pop a toast notification for `duration` ms, then fade it out via CSS.
    const toast = document.getElementById('notification');
    toast.textContent = message;
    toast.classList.add('show');
    setTimeout(() => toast.classList.remove('show'), duration);
}
370 |
function delay(ms) {
    // Promise-based sleep for `ms` milliseconds.
    return new Promise((resolve) => {
        setTimeout(resolve, ms);
    });
}
--------------------------------------------------------------------------------
/static/styles.css:
--------------------------------------------------------------------------------
/* Monospace display font used throughout the UI. */
@import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');

/* Design tokens: gold accents on a dark charcoal background. */
:root {
    --primary-color: #FFD700;
    --secondary-color: #FFA500;
    --background-color: #1A1A1A;
    --container-color: #2A2A2A;
    --text-color: #FFFFFF;
    --box-shadow: 0 0 10px rgba(255, 215, 0, 0.3);
}

/* Global reset. */
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

/* Center the app container in the viewport. */
body {
    font-family: 'Roboto Mono', monospace;
    line-height: 1.6;
    color: var(--text-color);
    background-color: var(--background-color);
    min-height: 100vh;
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 20px;
}

.container {
    max-width: 600px;
    width: 100%;
    background-color: var(--container-color);
    padding: 30px;
    border-radius: 10px;
    box-shadow: var(--box-shadow);
}

h1 {
    color: var(--primary-color);
    text-align: center;
    margin-bottom: 30px;
    font-size: 2.5em;
}

/* Form controls (script input and audience question). */
label {
    display: block;
    margin-bottom: 10px;
    font-weight: bold;
    color: var(--primary-color);
}

textarea, input[type="text"] {
    width: 100%;
    padding: 12px;
    margin-bottom: 20px;
    background-color: var(--background-color);
    color: var(--text-color);
    border: 1px solid var(--primary-color);
    font-family: 'Roboto Mono', monospace;
    font-size: 1em;
    transition: all 0.3s ease;
    border-radius: 5px;
}

textarea {
    resize: vertical;
    min-height: 100px;
}

textarea:focus, input[type="text"]:focus {
    outline: none;
    box-shadow: 0 0 5px var(--primary-color);
}

button {
    background-color: var(--primary-color);
    color: var(--background-color);
    padding: 12px 20px;
    border: none;
    cursor: pointer;
    font-family: 'Roboto Mono', monospace;
    font-size: 1em;
    transition: all 0.3s ease;
    border-radius: 5px;
}

button:hover {
    background-color: var(--secondary-color);
}

button:disabled {
    background-color: #555555;
    cursor: not-allowed;
}

/* Utility class toggled from JS to show/hide elements. */
.hidden {
    display: none;
}

#loading {
    text-align: center;
    margin-top: 20px;
    font-size: 1.2em;
    color: var(--primary-color);
}

/* Error text shown while a failed turn fetch is being retried. */
#retryMessage {
    color: #FF6B6B;
    margin-top: 20px;
    text-align: center;
    font-weight: bold;
}

/* Speaker avatars with embedded audio visualizer. */
#avatarContainer {
    display: flex;
    justify-content: space-between;
    margin-top: 30px;
}

.avatar {
    width: 80px;
    height: 80px;
    background-color: var(--secondary-color);
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 24px;
    transition: all 0.3s ease;
    overflow: hidden;
    position: relative;
}

/* Avatar of whoever is currently speaking. */
.avatar.active {
    background-color: var(--primary-color);
    transform: scale(1.1);
}

.avatar .visualizer {
    display: flex;
    justify-content: space-around;
    align-items: flex-end;
    width: 100%;
    height: 100%;
    padding: 10px;
}

/* Equalizer bars; heights are set from JS each animation frame. */
.avatar .bar {
    width: 4px;
    background-color: var(--background-color);
    transition: height 0.1s ease;
}

/* Loudness halo; its scale is set from JS each animation frame. */
.avatar .pulse {
    position: absolute;
    width: 100%;
    height: 100%;
    background-color: rgba(255, 255, 255, 0.2);
    transform: scale(1);
    transition: transform 0.1s ease;
}

#audienceQuestionInput {
    margin-top: 30px;
    padding: 20px;
    border: 1px solid var(--primary-color);
    border-radius: 5px;
    background-color: var(--background-color);
}

/* Toast notification, faded in via the .show modifier. */
#notification {
    position: fixed;
    top: 20px;
    right: 20px;
    padding: 15px 20px;
    background-color: var(--primary-color);
    color: var(--background-color);
    border-radius: 5px;
    font-size: 1em;
    opacity: 0;
    transition: opacity 0.3s ease;
    pointer-events: none;
    z-index: 1000;
    box-shadow: var(--box-shadow);
}

#notification.show {
    opacity: 1;
}

#statusMessage {
    text-align: center;
    margin-top: 20px;
    font-size: 1.2em;
    color: var(--primary-color);
}

/* Blinking text effect for in-progress status messages. */
.blink {
    animation: blink 1s infinite;
}

@keyframes blink {
    0% { opacity: 1; }
    50% { opacity: 0; }
    100% { opacity: 1; }
}
--------------------------------------------------------------------------------
/turn_handler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | import base64
4 | import json
5 | import logging
6 | from typing import Any, Dict, Tuple, Optional
7 |
8 | from app_config import MAX_RETRIES, RETRY_DELAY
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class TurnHandlerError(Exception):
    """Raised when a podcast turn cannot be generated or handled."""
16 |
17 |
class TurnHandler:
    """
    A class for handling turns in a podcast conversation.

    Alternates between two speakers — Dan (host) on even turns via the
    primary API credentials, and Anna (guest) on odd turns via the
    secondary credentials — generating audio plus a transcript for each
    turn over a realtime WebSocket API.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str,
        api_key_b: str,
        api_base_b: str,
        realtime_deployment_name: str
    ) -> None:
        """
        Initialize the TurnHandler.

        Args:
            api_key (str): The primary API key (used for Dan's turns).
            api_base (str): The primary API base URL.
            api_key_b (str): The secondary API key (used for Anna's turns).
            api_base_b (str): The secondary API base URL.
            realtime_deployment_name (str): The name of the realtime deployment.
        """
        self.api_key = api_key
        # Strip trailing slashes so URL concatenation in generate_response
        # cannot produce a double slash.
        self.api_base = api_base.rstrip('/')
        self.api_key_b = api_key_b
        self.api_base_b = api_base_b.rstrip('/')
        self.realtime_deployment_name = realtime_deployment_name
        # Lazily created, reused client session (see _get_session / close).
        self.session: Optional[aiohttp.ClientSession] = None
        logger.debug("TurnHandler initialized with primary and secondary API credentials.")

    async def _get_session(self) -> aiohttp.ClientSession:
        """
        Get or create an aiohttp ClientSession.

        Returns:
            aiohttp.ClientSession: The aiohttp session.
        """
        # Recreate the session if it was never opened or has been closed.
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
            logger.debug("Created new aiohttp ClientSession for TurnHandler.")
        return self.session

    async def close(self) -> None:
        """
        Close the aiohttp ClientSession if it's open.
        """
        if self.session and not self.session.closed:
            await self.session.close()
            logger.debug("Closed aiohttp ClientSession for TurnHandler.")

    async def handle_turn(self, request: Any) -> Dict[str, Any]:
        """
        Handle a turn in the podcast conversation.

        Args:
            request (Any): The turn request containing podcast script, transcript, and other information.

        Returns:
            Dict[str, Any]: A dictionary containing the speaker, transcript, and audio data
                (keys: "speaker", "transcript", "audio_base64").

        Raises:
            TurnHandlerError: If handling the turn fails.
        """
        try:
            # Read request fields defensively: missing attributes fall back
            # to neutral defaults instead of raising AttributeError.
            audience_question: Optional[str] = getattr(request, 'audience_question', None)
            turn_number: int = getattr(request, 'turn', 0)
            podcast_script: str = getattr(request, 'podcast_script', '')
            combined_transcript: str = getattr(request, 'combined_transcript', '')

            logger.debug(f"Handling turn {turn_number} with audience question: {audience_question}")

            # Even turns: Dan (host) on the primary endpoint; odd turns:
            # Anna (guest) on the secondary endpoint.
            if turn_number % 2 == 0:
                speaker = "Dan"
                instructions = (
                    f"You are Dan, the host of a podcast. Make sure to move to the next talking point by asking "
                    f"questions to Anna. Discuss the following topic: {podcast_script}"
                )
                voice = "dan"
                api_key = self.api_key
                api_base = self.api_base
            else:
                speaker = "Anna"
                instructions = (
                    f"You are Anna, a guest on a podcast. Discuss the following topic: {podcast_script}"
                )
                voice = "marilyn"
                api_key = self.api_key_b
                api_base = self.api_base_b

            if audience_question:
                instructions += (
                    f" Answer the audience question before proceeding (start with 'Oh I see we have a question from the audience'): "
                    f"{audience_question}"
                )

            # The final turn (6) replaces the instructions entirely with an
            # outro prompt. NOTE(review): this also discards any audience
            # question appended above for that turn — confirm intentional.
            if turn_number == 6:
                instructions = (
                    "You are Dan, you just held a podcast with your guest Anna. Ask her and the audience for her time "
                    "and thoughts on the podcast. (i.e., do the outro)"
                )

            logger.info(f"Generating response for speaker: {speaker}, turn: {turn_number}")
            audio_base64, transcript = await self.generate_response(
                api_key=api_key,
                api_base=api_base,
                deployment_name=self.realtime_deployment_name,
                context=instructions,
                combined_transcript=combined_transcript,
                voice=voice,
                speaker=speaker
            )

            logger.info(f"Turn {turn_number} handled successfully for speaker: {speaker}")
            return {
                "speaker": speaker,
                "transcript": transcript,
                "audio_base64": audio_base64
            }

        except Exception as e:
            logger.exception(f"Failed to handle turn {getattr(request, 'turn', 'unknown')}: {e}")
            raise TurnHandlerError(f"Failed to handle turn: {e}") from e

    async def generate_response(
        self,
        api_key: str,
        api_base: str,
        deployment_name: str,
        context: str,
        combined_transcript: str,
        voice: str,
        speaker: str
    ) -> Tuple[str, str]:
        """
        Generate a response for the current turn.

        Connects to the realtime WebSocket endpoint, configures the session,
        requests a single response, and accumulates the streamed audio deltas
        and final transcript. Retries the whole exchange up to MAX_RETRIES
        times on connection or protocol errors.

        Args:
            api_key (str): The API key to use.
            api_base (str): The API base URL.
            deployment_name (str): The name of the deployment.
            context (str): The context for the response.
            combined_transcript (str): The combined transcript of the conversation so far.
            voice (str): The voice to use for the response.
            speaker (str): The name of the speaker.

        Returns:
            Tuple[str, str]: A tuple containing the base64 encoded audio and the transcript.

        Raises:
            TurnHandlerError: If all retry attempts fail.
        """
        # The realtime endpoint is WebSocket-based, so swap the URL scheme.
        url = f"{api_base.replace('https://', 'wss://')}/openai/realtime"
        query_params = f"?api-version=2024-10-01-preview&deployment={deployment_name}"
        full_url = f"{url}{query_params}"
        headers = {
            "api-key": api_key,
            "OpenAI-Beta": "realtime=v1"
        }

        for attempt in range(1, MAX_RETRIES + 1):
            try:
                session = await self._get_session()
                async with session.ws_connect(full_url, headers=headers) as ws:
                    logger.debug(f"WebSocket connection established to {full_url}")

                    # Wait for session.created message: the server's handshake
                    # must arrive before any configuration is sent.
                    msg = await ws.receive_json()
                    if msg.get("type") != "session.created":
                        raise TurnHandlerError(f"Expected 'session.created', got '{msg.get('type')}'")
                    logger.debug("Session created successfully.")

                    # Send session.update message with instructions
                    await ws.send_json({
                        "type": "session.update",
                        "session": {
                            "modalities": ["text", "audio"],
                            "instructions": (
                                f"You are {speaker}. {context}. Always speak 1-2 sentences at a time. "
                                f"Continue the conversation (CONTINUE WHERE IT LEFT OFF at !):\n"
                                f"{combined_transcript}\n\n{speaker}: "
                            ),
                            "voice": voice,
                            "input_audio_format": "pcm16",
                            "output_audio_format": "pcm16",
                            # No server-side turn detection: a response is
                            # requested explicitly below instead.
                            "turn_detection": None,
                            "temperature": 0.6,
                        }
                    })
                    logger.debug("Sent session.update with instructions.")

                    # Send response.create message
                    await ws.send_json({
                        "type": "response.create",
                        "response": {
                            "modalities": ["audio", "text"],
                            "voice": voice,
                            "output_audio_format": "pcm16",
                            "temperature": 0.6,
                        }
                    })
                    logger.debug("Sent response.create message.")

                    # Accumulate streamed audio chunks and the final transcript.
                    audio_data = bytearray()
                    transcript = ""

                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            data = json.loads(msg.data)
                            msg_type = data.get("type")
                            logger.debug(f"Received WebSocket message of type: {msg_type}")

                            if msg_type == "response.audio.delta":
                                # Each delta carries a base64-encoded PCM chunk.
                                delta = data.get("delta", "")
                                audio_bytes = base64.b64decode(delta)
                                audio_data.extend(audio_bytes)
                                logger.debug(f"Appended audio delta: {len(audio_bytes)} bytes.")

                            elif msg_type == "response.audio_transcript.done":
                                transcript = data.get("transcript", "")
                                logger.debug("Received completed transcript.")

                            elif msg_type == "response.done":
                                logger.debug("Received response.done message. Closing WebSocket.")
                                break

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            error_msg = msg.data
                            logger.error(f"WebSocket error: {error_msg}")
                            raise TurnHandlerError(f"WebSocket error: {error_msg}")

                    await ws.close()
                    logger.debug("WebSocket connection closed.")

                    # Encode audio data to base64 for JSON-safe transport.
                    audio_base64 = base64.b64encode(audio_data).decode('utf-8')
                    logger.debug("Audio data encoded to base64.")

                    # An empty transcript means the response never completed;
                    # raise so the retry loop below gets another attempt.
                    if not transcript:
                        logger.warning("Transcript is empty after response.")
                        raise TurnHandlerError("Transcript is empty after response.")

                    return audio_base64, transcript

            except (aiohttp.ClientError, TurnHandlerError) as e:
                logger.error(f"Attempt {attempt} failed with error: {e}")
                if attempt < MAX_RETRIES:
                    logger.info(f"Retrying in {RETRY_DELAY} seconds...")
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    logger.critical("All retry attempts failed.")
                    raise TurnHandlerError(f"All retry attempts failed: {e}") from e
270 |
--------------------------------------------------------------------------------