├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── PodcastSnippet.mp3 ├── PodcastSocialMediaCopilot.py ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── dalle_helper.py ├── images └── PodcastCopilotDataFlow.png ├── instruct_pipeline.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 
171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | !.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /PodcastSnippet.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PodcastCopilot/c2f863ad87688e140e934296925fd46c0255a1d0/PodcastSnippet.mp3 -------------------------------------------------------------------------------- /PodcastSocialMediaCopilot.py: -------------------------------------------------------------------------------- 1 | # The Podcast Copilot will automatically create and post a LinkedIn promotional post for a new episode of the Behind the Tech podcast. 2 | # Given the audio recording of the episode, the copilot will use a locally-hosted Whisper model to transcribe the audio recording. 3 | # The copilot uses the Dolly 2 model to extract the guest's name from the transcript. 4 | # The copilot uses the Bing Search Grounding API to retrieve a bio for the guest. 5 | # The copilot uses the GPT-4 model in the Azure OpenAI Service to generate a social media blurb for the episode, given the transcript and the guest's bio. 6 | # The copilot uses the DALL-E 2 model to generate an image for the post. 7 | # The copilot calls a LinkedIn plugin to post. 
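# Environment note (a practical assumption, not enforced by the code below): pydub and the
# openai-whisper package both shell out to the ffmpeg binary when decoding MP3 audio. The
# ffmpeg-python entry in requirements.txt provides only the Python bindings, so make sure the
# ffmpeg executable itself is installed and on your PATH before running this script.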
8 | 9 | from pydub import AudioSegment 10 | from pydub.silence import split_on_silence 11 | import whisper 12 | import torch 13 | from langchain.chains import TransformChain, LLMChain, SequentialChain 14 | from langchain.chat_models import AzureChatOpenAI 15 | from langchain.llms import HuggingFacePipeline 16 | from langchain.prompts import ( 17 | PromptTemplate, 18 | ChatPromptTemplate, 19 | SystemMessagePromptTemplate, 20 | AIMessagePromptTemplate, 21 | HumanMessagePromptTemplate, 22 | ) 23 | from langchain.schema import ( 24 | AIMessage, 25 | HumanMessage, 26 | SystemMessage 27 | ) 28 | import requests 29 | import time 30 | from PIL import Image 31 | from io import BytesIO 32 | import datetime 33 | import json 34 | from dalle_helper import ImageClient 35 | 36 | # For Dolly 2 37 | from transformers import AutoTokenizer, TextStreamer 38 | from optimum.onnxruntime import ORTModelForCausalLM 39 | from instruct_pipeline import InstructionTextGenerationPipeline 40 | import onnxruntime as ort 41 | ort.set_default_logger_severity(3) 42 | 43 | print("Imports are complete") 44 | 45 | 46 | # Endpoint Settings 47 | bing_search_url = "https://api.bing.microsoft.com/v7.0/search" 48 | bing_subscription_key = "TODO" # Your key will look something like this: 00000000000000000000000000000000 49 | openai_api_type = "azure" 50 | openai_api_base = "https://TODO.openai.azure.com/" # Your endpoint will look something like this: https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/ 51 | openai_api_key = "TODO" # Your key will look something like this: 00000000000000000000000000000000 52 | gpt4_deployment_name="gpt-4" 53 | 54 | # We are assuming that you have all model deployments on the same Azure OpenAI service resource above. If not, you can change these settings below to point to different resources. 
55 | gpt4_endpoint = openai_api_base # Your endpoint will look something like this: https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/ 56 | gpt4_api_key = openai_api_key # Your key will look something like this: 00000000000000000000000000000000 57 | dalle_endpoint = openai_api_base # Your endpoint will look something like this: https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/ 58 | dalle_api_key = openai_api_key # Your key will look something like this: 00000000000000000000000000000000 59 | plugin_model_url = openai_api_base 60 | plugin_model_api_key = openai_api_key # Your key will look something like this: 00000000000000000000000000000000 61 | 62 | # Inputs about the podcast 63 | podcast_url = "https://www.microsoft.com/behind-the-tech" 64 | podcast_audio_file = ".\PodcastSnippet.mp3" 65 | 66 | 67 | # Step 1 - Call Whisper to transcribe audio 68 | print("Calling Whisper to transcribe audio...\n") 69 | 70 | # Chunk up the audio file 71 | sound_file = AudioSegment.from_mp3(podcast_audio_file) 72 | audio_chunks = split_on_silence(sound_file, min_silence_len=1000, silence_thresh=-40 ) 73 | count = len(audio_chunks) 74 | print("Audio split into " + str(count) + " audio chunks") 75 | 76 | # Call Whisper to transcribe audio 77 | model = whisper.load_model("base") 78 | transcript = "" 79 | for i, chunk in enumerate(audio_chunks): 80 | # If you have a long audio file, you can enable this to only run for a subset of chunks 81 | if i < 10 or i > count - 10: 82 | out_file = "chunk{0}.wav".format(i) 83 | print("Exporting", out_file) 84 | chunk.export(out_file, format="wav") 85 | result = model.transcribe(out_file) 86 | transcriptChunk = result["text"] 87 | print(transcriptChunk) 88 | 89 | # Append transcript in memory if you have sufficient memory 90 | transcript += " " + transcriptChunk 91 | 92 | # Alternatively, here's how to write the transcript to disk if you have memory constraints 93 | #textfile = open("chunk{0}.txt".format(i), "w") 94 | #textfile.write(transcript) 95 | #textfile.close() 96 | #print("Exported chunk{0}.txt".format(i)) 97 | 98 | print("Transcript: \n") 99 | print(transcript) 100 | print("\n") 101 | 102 | 103 | # Step 2 - Make a call to a local Dolly 2.0 model optimized for Windows to extract the name of who I'm interviewing from the transcript 104 | print("Calling a local Dolly 2.0 model optimized for Windows to extract the name of the podcast guest...\n") 105 | repo_id = "microsoft/dolly-v2-7b-olive-optimized" 106 | tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left") 107 | model = ORTModelForCausalLM.from_pretrained(repo_id, provider="DmlExecutionProvider", use_cache=True, use_merged=True, use_io_binding=False) 108 | streamer = TextStreamer(tokenizer, skip_prompt=True) 109 | generate_text = InstructionTextGenerationPipeline(model=model, streamer=streamer, tokenizer=tokenizer, max_new_tokens=128, return_full_text=True, task="text-generation") 110 | hf_pipeline = HuggingFacePipeline(pipeline=generate_text) 111 | 112 | dolly2_prompt = PromptTemplate( 113 | input_variables=["transcript"], 114 | template="Extract the guest name on the Beyond the Tech podcast from the following transcript. Beyond the Tech is hosted by Kevin Scott and Christina Warren, so they will never be the guests. 
\n\n Transcript: {transcript}\n\n Host name: Kevin Scott\n\n Guest name: " 115 | ) 116 | 117 | extract_llm_chain = LLMChain(llm=hf_pipeline, prompt=dolly2_prompt, output_key="guest") 118 | guest = extract_llm_chain.predict(transcript=transcript) 119 | 120 | print("Guest:\n") 121 | print(guest) 122 | print("\n") 123 | 124 | 125 | # Step 3 - Make a call to the Bing Search Grounding API to retrieve a bio for the guest 126 | def bing_grounding(input_dict:dict) -> dict: 127 | print("Calling Bing Search API to get bio for guest...\n") 128 | search_term = input_dict["guest"] 129 | print("Search term is " + search_term) 130 | 131 | headers = {"Ocp-Apim-Subscription-Key": bing_subscription_key} 132 | params = {"q": search_term, "textDecorations": True, "textFormat": "HTML"} 133 | response = requests.get(bing_search_url, headers=headers, params=params) 134 | response.raise_for_status() 135 | search_results = response.json() 136 | #print(search_results) 137 | 138 | # Parse out a bio. 139 | bio = search_results["webPages"]["value"][0]["snippet"] 140 | 141 | print("Bio:\n") 142 | print(bio) 143 | print("\n") 144 | 145 | return {"bio": bio} 146 | 147 | bing_chain = TransformChain(input_variables=["guest"], output_variables=["bio"], transform=bing_grounding) 148 | bio = bing_chain.run(guest) 149 | 150 | 151 | # Step 4 - Put bio in the prompt with the transcript 152 | system_template="You are a helpful large language model that can create a LinkedIn promo blurb for episodes of the podcast Behind the Tech, when given transcripts of the podcasts. The Behind the Tech podcast is hosted by Kevin Scott.\n" 153 | system_message_prompt = SystemMessagePromptTemplate.from_template(system_template) 154 | 155 | user_prompt=PromptTemplate( 156 | template="Create a short summary of this podcast episode that would be appropriate to post on LinkedIn to promote the podcast episode. The post should be from the first-person perspective of Kevin Scott, who hosts the podcast.\n" + 157 | "Here is the transcript of the podcast episode: {transcript} \n" + 158 | "Here is the bio of the guest: {bio} \n", 159 | input_variables=["transcript", "bio"], 160 | ) 161 | human_message_prompt = HumanMessagePromptTemplate(prompt=user_prompt) 162 | chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) 163 | 164 | # Get formatted messages for the chat completion 165 | blurb_messages = chat_prompt.format_prompt(transcript={transcript}, bio={bio}).to_messages() 166 | 167 | 168 | # Step 5 - Make a call to Azure OpenAI Service to get a social media blurb, 169 | print("Calling GPT-4 model on Azure OpenAI Service to get a social media blurb...\n") 170 | gpt4 = AzureChatOpenAI( 171 | openai_api_base=gpt4_endpoint, 172 | openai_api_version="2023-03-15-preview", 173 | deployment_name=gpt4_deployment_name, 174 | openai_api_key=gpt4_api_key, 175 | openai_api_type = openai_api_type, 176 | ) 177 | #print(gpt4) #shows parameters 178 | 179 | output = gpt4(blurb_messages) 180 | social_media_copy = output.content 181 | 182 | gpt4_chain = LLMChain(llm=gpt4, prompt=chat_prompt, output_key="social_media_copy") 183 | 184 | print("Social Media Copy:\n") 185 | print(social_media_copy) 186 | print("\n") 187 | 188 | 189 | # Step 6 - Use GPT-4 to generate a DALL-E prompt 190 | system_template="You are a helpful large language model that generates DALL-E prompts, that when given to the DALL-E model can generate beautiful high-quality images to use in social media posts about a podcast on technology. 
Good DALL-E prompts will contain mention of related objects, and will not contain people or words. Good DALL-E prompts should include a reference to podcasting along with items from the domain of the podcast guest.\n" 191 | system_message_prompt = SystemMessagePromptTemplate.from_template(system_template) 192 | 193 | user_prompt=PromptTemplate( 194 | template="Create a DALL-E prompt to create an image to post along with this social media text: {social_media_copy}", 195 | input_variables=["social_media_copy"], 196 | ) 197 | human_message_prompt = HumanMessagePromptTemplate(prompt=user_prompt) 198 | chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) 199 | 200 | # Get formatted messages for the chat completion 201 | dalle_messages = chat_prompt.format_prompt(social_media_copy={social_media_copy}).to_messages() 202 | 203 | # Call Azure OpenAI Service to get a DALL-E prompt 204 | print("Calling GPT-4 model on Azure OpenAI Service to get a DALL-E prompt...\n") 205 | gpt4 = AzureChatOpenAI( 206 | openai_api_base=gpt4_endpoint, 207 | openai_api_version="2023-03-15-preview", 208 | deployment_name=gpt4_deployment_name, 209 | openai_api_key=gpt4_api_key, 210 | openai_api_type = openai_api_type, 211 | ) 212 | #print(gpt4) #shows parameters 213 | 214 | output = gpt4(dalle_messages) 215 | dalle_prompt = output.content 216 | 217 | dalle_prompt_chain = LLMChain(llm=gpt4, prompt=chat_prompt, output_key="dalle_prompt") 218 | 219 | print("DALL-E Prompt:\n") 220 | print(dalle_prompt) 221 | print("\n") 222 | 223 | 224 | # For the demo, we showed the step by step execution of each chain above, but you can also run the entire chain in one step. 225 | # You can uncomment and run the following code for an example. Feel free to substitute your own transcript. 226 | ''' 227 | transcript = "Hello, and welcome to Beyond the Tech podcast. I am your host, Kevin Scott. I am the CTO of Microsoft. I am joined today by an amazing guest, Lionel Messi. Messi is an accomplished soccer player for the Paris Saint-Germain football club. Lionel, how are you doing today?" 228 | 229 | podcast_copilot_chain = SequentialChain( 230 | chains=[extract_llm_chain, bing_chain, gpt4_chain, dalle_prompt_chain], 231 | input_variables=["transcript"], 232 | output_variables=["guest", "bio", "social_media_copy", "dalle_prompt"], 233 | verbose=True) 234 | podcast_copilot = podcast_copilot_chain({"transcript":transcript}) 235 | print(podcast_copilot) # This is helpful for debugging. 
236 | social_media_copy = podcast_copilot["social_media_copy"] 237 | dalle_prompt = podcast_copilot["dalle_prompt"] 238 | 239 | print("Social Media Copy:\n") 240 | print(social_media_copy) 241 | print("\n") 242 | ''' 243 | 244 | 245 | # Append "high-quality digital art" to the generated DALL-E prompt 246 | dalle_prompt = dalle_prompt + ", high-quality digital art" 247 | 248 | 249 | # Step 7 - Make a call to DALL-E model on the Azure OpenAI Service to generate an image 250 | print("Calling DALL-E model on Azure OpenAI Service to get an image for social media...\n") 251 | 252 | # Establish the client class instance 253 | client = ImageClient(dalle_endpoint, dalle_api_key, verbose=False) # change verbose to True for including debug print statements 254 | 255 | # Generate an image 256 | imageURL, postImage = client.generateImage(dalle_prompt) 257 | print("Image URL: " + imageURL + "\n") 258 | 259 | # Write image to file - this is optional if you would like to have a local copy of the image 260 | stream = BytesIO(postImage) 261 | image = Image.open(stream).convert("RGB") 262 | stream.close() 263 | photo_path = ".\PostImage.jpg" 264 | image.save(photo_path) 265 | print("Image: saved to PostImage.jpg\n") 266 | 267 | 268 | # Append the podcast URL to the generated social media copy 269 | social_media_copy = social_media_copy + " " + podcast_url 270 | 271 | 272 | # Step 8 - Call the LinkedIn Plugin for Copilots to do the post. 273 | # Currently there is not support in the SDK for the plugin model on Azure OpenAI, so we are using the REST API directly. 274 | PROMPT_MESSAGES = [ 275 | { 276 | "role": "system", 277 | "content": "You are a helpful large language model that can post a LinkedIn promo blurb for episodes of Behind the Tech with Kevin Scott, when given some text and a link to an image.\n", 278 | }, 279 | { 280 | "role": "user", 281 | "content": 282 | "Post the following social media text to LinkedIn to promote my latest podcast episode: \n" + 283 | "Here is the text to post: \n" + social_media_copy + "\n" + 284 | "Here is a link to the image that should be included with the post: \n" + imageURL + "\n", 285 | }, 286 | ] 287 | 288 | print("Calling GPT-4 model with plugin support on Azure OpenAI Service to post to LinkedIn...\n") 289 | 290 | payload = { 291 | "messages": PROMPT_MESSAGES, 292 | "max_tokens": 1024, 293 | "temperature": 0.5, 294 | "n": 1, 295 | "stop": None 296 | } 297 | 298 | headers = { 299 | "Content-Type": "application/json", 300 | "api-key": plugin_model_api_key, 301 | } 302 | 303 | # Confirm whether it is okay to post, to follow Responsible AI best practices 304 | print("The following will be posted to LinkedIn:\n") 305 | print(social_media_copy + "\n") 306 | confirm = input("Do you want to post this to LinkedIn? (y/n): ") 307 | if confirm == "y": 308 | # Call a model with plugin support. 309 | response = requests.post(plugin_model_url, headers=headers, data=json.dumps(payload)) 310 | 311 | #print (type(response)) 312 | print("Response:\n") 313 | print(response) 314 | print("Headers:\n") 315 | print(response.headers) 316 | print("Json:\n") 317 | print(response.json()) 318 | 319 | response_dict = response.json() 320 | print(response_dict["choices"][0]["messages"][-1]["content"]) 321 | 322 | # To use plugins, you must call a model that understands how to leverage them. Support for plugins is in limited private preview 323 | # for the Azure OpenAI service, and a LinkedIn plugin is coming soon! 
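# Note: plugin_model_url above is currently set to the bare Azure OpenAI resource endpoint. The
# standard (non-plugin) chat completions REST route has the form
# {endpoint}/openai/deployments/{deployment_name}/chat/completions?api-version=<version>, so the
# URL will likely need to point at the full plugin-capable route once that preview is available
# to you; treat the exact path as an assumption until the plugin preview documents it.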
324 | 325 | 326 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Podcast Copilot 2 | 3 | This code was demonstrated at the Build 2023 keynote by Microsoft CTO Kevin Scott, illustrating the architecture of a Copilot. 4 | 5 | Kevin Scott hosts a podcast, [Behind the Tech](https://www.microsoft.com/behind-the-tech). This Podcast Copilot makes it easier to generate a social media post promoting a new episode of the podcast, when given the audio file for the podcast. The Podcast Copilot uses a series of machine learning models orchestrated by LangChain to do this: 6 | + Given the podcast audio file, the Whisper model performs speech-to-text to generate a transcript of the podcast. 7 | + Given this transcript, the Dolly 2 model extracts the name of the guest on the podcast. 8 | + Given the guest name, the Bing Search Grounding API retrieves a bio for the guest from the internet. 9 | + Given the transcript and guest's bio, the GPT-4 model generates a social media post promoting the podcast episode. 10 | + Given the social media post, we use GPT-4 to create a relevant DALL-E prompt. 11 | + Given that DALL-E prompt, the DALL-E model generates a corresponding image for the post. 12 | + Finally, the user has an opportunity to review the content before posting, and if approved, a LinkedIn plugin will post the social media copy and image to LinkedIn. 13 | 14 | ![Diagram of the data flow and chain of machine learning models described above](./images/PodcastCopilotDataFlow.png) 15 | 16 | For the demo, we ran Whisper and Dolly 2 locally. The Bing Search Grounding API is available on Azure. We used model deployments of GPT-4, DALL-E 2, and a plugins-capable model on the Azure OpenAI service. 17 | 18 | Please note that as of Build (May 2023): 19 | + The DALL-E models are still in private preview. For the DALL-E model, you must request access using the form at https://aka.ms/oai/access and in question #22, request access to the DALL-E models for image generation. 20 | + The plugins-capable models are not publicly released yet. 21 | 22 | ## Setup 23 | 24 | This project requires creating an Azure OpenAI resource to run several cloud-based models. 25 | + You can request access to Azure OpenAI at https://aka.ms/oai/access. 26 | + After approval, create an Azure OpenAI resource at https://portal.azure.com/#create/Microsoft.CognitiveServicesOpenAI following the instructions at https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource. 27 | + You will need to create model deployments of the following models: gpt-4, dalle, and a plugins-capable model. Follow the instructions [here](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource#deploy-a-model). 28 | 29 | You will also need to create a Bing search resource at https://portal.azure.com/#create/Microsoft.BingSearch. 30 | 31 | Next, update the PodcastSocialMediaCopilot.py file with your settings. 32 | + Update **bing_subscription_key** with the API key of your Bing resource on Azure. 33 | + Update **openai_api_base** with the name of your Azure OpenAI resource; this value should look like this: "https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/" 34 | + Update **openai_api_key** with the corresponding API key for your Azure OpenAI resource. 35 | + Update **gpt4_deployment_name** with the name of your model deployment for GPT-4 in your Azure OpenAI resource. 
36 | + If your model deployments for gpt-4, dalle, and the plugins-capable model are all on the same Azure OpenAI resource, you're all set! If not, you can override the individual endpoints and keys for the resources for the various model deployments using the variables **gpt4_endpoint**, **gpt4_api_key**, **dalle_endpoint**, **dalle_api_key**, **plugin_model_url**, and **plugin_model_api_key**. 37 | + Optionally, you can also update the **podcast_url** and **podcast_audio_file** to reflect your own podcast. 38 | 39 | Finally, set up your environment and run the code using the following commands: 40 | ``` 41 | pip install -r requirements.txt 42 | python PodcastSocialMediaCopilot.py 43 | ``` 44 | 45 | ## Contributing 46 | 47 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 48 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 49 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 50 | 51 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 52 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 53 | provided by the bot. You will only need to do this once across all repos using our CLA. 54 | 55 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 56 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 57 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 58 | 59 | ## Trademarks 60 | 61 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 62 | trademarks or logos is subject to and must follow 63 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 64 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 65 | Any use of third-party trademarks or logos are subject to those third-party's policies. 66 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please file a GitHub issue on this repo. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this project is limited to the resources listed above. 14 | -------------------------------------------------------------------------------- /dalle_helper.py: -------------------------------------------------------------------------------- 1 | # Helper class for DALL-E 2 | # The following class creates a simple wrapper on the Azure OpenAI REST endpoints. It will simplify the steps for calling the text-to-image API to submit your request and then poll for the results 3 | 4 | import requests 5 | import time 6 | 7 | class ImageClient: 8 | def __init__(self, endpoint, key, api_version = "2022-08-03-preview", verbose=False): 9 | # These are the paramters for the class: 10 | # ### endpoint: The endpoint for your Azure OpenAI resource 11 | # ### key: The API key for your Azure OpenAI resource 12 | # ### api_version: The API version to use. This is optional and defaults to the latest version 13 | self.endpoint = endpoint 14 | self.api_key = key 15 | self.api_version = api_version 16 | self.verbose = verbose 17 | 18 | def text_to_image(self, prompt): 19 | # this method makes the text-to-image API call. 
It will return the raw response from the API call 20 | 21 | reqURL = requests.models.PreparedRequest() 22 | params = {'api-version':self.api_version} 23 | #the full endpoint will look something like this https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/dalle/text-to-image 24 | reqURL.prepare_url(self.endpoint + "dalle/text-to-image", params) 25 | if self.verbose: 26 | print("Sending a POST call to the following URL: {URL}".format(URL=reqURL.url)) 27 | 28 | #Construct the data payload for the call. This includes the prompt text as well as many optional parameters. 29 | payload = { "caption": prompt} 30 | 31 | r = requests.post(reqURL.url, 32 | headers={ 33 | "Api-key": self.api_key, 34 | "Content-Type": "application/json" 35 | }, 36 | json = payload 37 | ) 38 | # Response Body example: { "id": "80b095cb-4248-4fa7-90c2-933f0907fb2a", "status": "Running" } 39 | # Key headers: 40 | # Operation-Location: URL to get response 41 | # Retry-after: 3 //seconds to respond 42 | 43 | if r.status_code != 202: 44 | print("Error: {error}".format(error=r.json())) 45 | 46 | data = r.json() 47 | if self.verbose: 48 | print('text-to-image API response body:') 49 | print(data) 50 | return r 51 | 52 | def getImageResults(self, operation_location): 53 | # This method will make an API call to get the status/results of the text-to-image API call using the 54 | # Operation-Location header from the original API call 55 | 56 | params = {'api-version':self.api_version} 57 | # the full endpoint will look something like this 58 | # https://YOUR_RESOURCE_NAME.openai.azure.com/dalle/text-to-image/operations/OPERATION_ID_FROM_PRIOR_RESPONSE?api-version=2022-08-03-preview 59 | 60 | if self.verbose: 61 | print("Sending a POST call to the following URL: {URL}".format(URL=operation_location)) 62 | 63 | r = requests.get(operation_location, 64 | headers={ 65 | "Api-key": self.api_key, 66 | "Content-Type": "application/json" 67 | } 68 | ) 69 | 70 | data = r.json() 71 | 72 | if self.verbose: 73 | print('Get Image results call response body') 74 | print(data) 75 | return r 76 | 77 | # Sending a POST call to the following URL: 78 | # {'id': 'd63fc675-f751-40b7-a297-e692c3b966b9', 'result': {'caption': 'An avocado chair.', 'contentUrl': '', 'contentUrlExpiresAt': '2022-08-13T22:52:45Z', 'createdDateTime': '2022-08-13T21:50:55Z'}, 'status': 'Succeeded'} 79 | 80 | 81 | def getImage(self, contentUrl): 82 | # Download the images from the given URL 83 | r = requests.get(contentUrl) 84 | return r 85 | 86 | 87 | def generateImage(self, prompt): 88 | submission = self.text_to_image( prompt) 89 | if self.verbose: 90 | print('Response code from submission') 91 | print(submission.status_code) 92 | print('Response body:') 93 | print(submission.json()) 94 | if submission.status_code == 202: 95 | operation_location = submission.headers['Operation-Location'] 96 | retry_after = submission.headers['Retry-after'] 97 | else: 98 | print('Not a 202 response') 99 | return "-1" 100 | 101 | #wait to request 102 | status = "not running" 103 | while status != "Succeeded": 104 | if self.verbose: 105 | print('retry after: ' + retry_after) 106 | time.sleep(int(retry_after)) 107 | r = self.getImageResults(operation_location) 108 | # print(r.status_code) 109 | # print(r.headers) 110 | # print(r.json()) 111 | status = r.json()['status'] 112 | # print(status) 113 | if status == "Failed": 114 | return "-1" 115 | 116 | contentUrl = r.json()['result']['contentUrl'] 117 | image = self.getImage(contentUrl) 118 | return contentUrl, image.content 119 | 120 | 121 | 122 | 
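# A minimal usage sketch of ImageClient, mirroring how PodcastSocialMediaCopilot.py calls this
# helper. The endpoint, key, and prompt below are placeholders you would replace with your own
# Azure OpenAI resource values; the block only runs when this module is executed directly.
if __name__ == "__main__":
    # Placeholder values -- substitute your own resource endpoint and API key.
    demo_endpoint = "https://YOUR_AOAI_RESOURCE_NAME.openai.azure.com/"
    demo_key = "YOUR_API_KEY"

    demo_client = ImageClient(demo_endpoint, demo_key, verbose=True)
    content_url, image_bytes = demo_client.generateImage("A podcast microphone on a desk, high-quality digital art")
    print("Generated image URL: " + content_url)

    # Save the raw image bytes returned by the service (format depends on the service response).
    with open("demo_image.png", "wb") as f:
        f.write(image_bytes)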
-------------------------------------------------------------------------------- /images/PodcastCopilotDataFlow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PodcastCopilot/c2f863ad87688e140e934296925fd46c0255a1d0/images/PodcastCopilotDataFlow.png -------------------------------------------------------------------------------- /instruct_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import List 4 | 5 | import numpy as np 6 | from transformers import Pipeline, PreTrainedTokenizer 7 | 8 | from transformers.utils import is_tf_available 9 | from transformers import TextStreamer 10 | 11 | if is_tf_available(): 12 | import tensorflow as tf 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | INSTRUCTION_KEY = "### Instruction:" 17 | RESPONSE_KEY = "### Response:" 18 | END_KEY = "### End" 19 | INTRO_BLURB = ( 20 | "Below is an instruction that describes a task. Write a response that appropriately completes the request." 21 | ) 22 | 23 | # This is the prompt that is used for generating responses using an already trained model. It ends with the response 24 | # key, where the job of the model is to provide the completion that follows it (i.e. the response itself). 25 | PROMPT_FOR_GENERATION_FORMAT = """{intro} 26 | {instruction_key} 27 | {instruction} 28 | {response_key} 29 | """.format( 30 | intro=INTRO_BLURB, 31 | instruction_key=INSTRUCTION_KEY, 32 | instruction="{instruction}", 33 | response_key=RESPONSE_KEY, 34 | ) 35 | 36 | 37 | def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: 38 | """Gets the token ID for a given string that has been added to the tokenizer as a special token. 39 | When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are 40 | treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. 41 | Args: 42 | tokenizer (PreTrainedTokenizer): the tokenizer 43 | key (str): the key to convert to a single token 44 | Raises: 45 | RuntimeError: if more than one ID was generated 46 | Returns: 47 | int: the token ID for the given key 48 | """ 49 | token_ids = tokenizer.encode(key) 50 | if len(token_ids) > 1: 51 | raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}") 52 | return token_ids[0] 53 | 54 | 55 | class InstructionTextGenerationPipeline(Pipeline): 56 | def __init__( 57 | self, *args, do_sample: bool = True, max_new_tokens: int = 256, streamer: TextStreamer, top_p: float = 0.92, top_k: int = 0, **kwargs 58 | ): 59 | """Initialize the pipeline 60 | Args: 61 | do_sample (bool, optional): Whether or not to use sampling. Defaults to True. 62 | max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 128. 63 | top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with 64 | probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92. 65 | top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering. 66 | Defaults to 0. 
67 | """ 68 | super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, 69 | **kwargs) 70 | self.streamer = streamer 71 | 72 | def _sanitize_parameters(self, 73 | return_full_text: bool = None, 74 | **generate_kwargs): 75 | preprocess_params = {} 76 | 77 | # newer versions of the tokenizer configure the response key as a special token. newer versions still may 78 | # append a newline to yield a single token. find whatever token is configured for the response key. 79 | tokenizer_response_key = next( 80 | (token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None 81 | ) 82 | 83 | response_key_token_id = None 84 | end_key_token_id = None 85 | if tokenizer_response_key: 86 | try: 87 | response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key) 88 | end_key_token_id = get_special_token_id(self.tokenizer, END_KEY) 89 | 90 | # Ensure generation stops once it generates "### End" 91 | generate_kwargs["eos_token_id"] = end_key_token_id 92 | except ValueError: 93 | pass 94 | 95 | forward_params = generate_kwargs 96 | postprocess_params = { 97 | "response_key_token_id": response_key_token_id, 98 | "end_key_token_id": end_key_token_id 99 | } 100 | 101 | if return_full_text is not None: 102 | postprocess_params["return_full_text"] = return_full_text 103 | 104 | return preprocess_params, forward_params, postprocess_params 105 | 106 | def preprocess(self, instruction_text, **generate_kwargs): 107 | prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text) 108 | inputs = self.tokenizer( 109 | prompt_text, 110 | return_tensors="pt", 111 | ) 112 | inputs["prompt_text"] = prompt_text 113 | inputs["instruction_text"] = instruction_text 114 | return inputs 115 | 116 | def _forward(self, model_inputs, **generate_kwargs): 117 | input_ids = model_inputs["input_ids"] 118 | attention_mask = model_inputs.get("attention_mask", None) 119 | 120 | if input_ids.shape[1] == 0: 121 | input_ids = None 122 | attention_mask = None 123 | in_b = 1 124 | else: 125 | in_b = input_ids.shape[0] 126 | 127 | generated_sequence = self.model.generate( 128 | input_ids=input_ids.to(self.model.device), 129 | attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, 130 | pad_token_id=self.tokenizer.pad_token_id, 131 | streamer=self.streamer, 132 | **generate_kwargs, 133 | ) 134 | 135 | out_b = generated_sequence.shape[0] 136 | if self.framework == "pt": 137 | generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) 138 | elif self.framework == "tf": 139 | generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) 140 | 141 | instruction_text = model_inputs.pop("instruction_text") 142 | return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text} 143 | 144 | def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_full_text: bool = False): 145 | 146 | generated_sequence = model_outputs["generated_sequence"][0] 147 | instruction_text = model_outputs["instruction_text"] 148 | 149 | generated_sequence: List[List[int]] = generated_sequence.numpy().tolist() 150 | records = [] 151 | for sequence in generated_sequence: 152 | 153 | # The response will be set to this variable if we can identify it. 
154 | decoded = None 155 | 156 | # If we have token IDs for the response and end, then we can find the tokens and only decode between them. 157 | if response_key_token_id and end_key_token_id: 158 | # Find where "### Response:" is first found in the generated tokens. Considering this is part of the 159 | # prompt, we should definitely find it. We will return the tokens found after this token. 160 | try: 161 | response_pos = sequence.index(response_key_token_id) 162 | except ValueError: 163 | logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}") 164 | response_pos = None 165 | 166 | if response_pos: 167 | # Next find where "### End" is located. The model has been trained to end its responses with this 168 | # sequence (or actually, the token ID it maps to, since it is a special token). We may not find 169 | # this token, as the response could be truncated. If we don't find it then just return everything 170 | # to the end. Note that even though we set eos_token_id, we still see the this token at the end. 171 | try: 172 | end_pos = sequence.index(end_key_token_id) 173 | except ValueError: 174 | end_pos = None 175 | 176 | decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip() 177 | 178 | if not decoded: 179 | # Otherwise we'll decode everything and use a regex to find the response and end. 180 | 181 | fully_decoded = self.tokenizer.decode(sequence) 182 | 183 | # The response appears after "### Response:". The model has been trained to append "### End" at the 184 | # end. 185 | m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL) 186 | 187 | if m: 188 | decoded = m.group(1).strip() 189 | else: 190 | # The model might not generate the "### End" sequence before reaching the max tokens. In this case, 191 | # return everything after "### Response:". 192 | m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL) 193 | if m: 194 | decoded = m.group(1).strip() 195 | else: 196 | logger.warn(f"Failed to find response in:\n{fully_decoded}") 197 | 198 | # If the full text is requested, then append the decoded text to the original instruction. 199 | # This technically isn't the full text, as we format the instruction in the prompt the model has been 200 | # trained on, but to the client it will appear to be the full text. 201 | if return_full_text: 202 | decoded = f"{instruction_text}\n{decoded}" 203 | 204 | rec = {"generated_text": decoded} 205 | 206 | records.append(rec) 207 | 208 | return records -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | optimum 4 | onnx 5 | openai-whisper 6 | langchain 7 | pydub 8 | openai 9 | ffmpeg-python 10 | onnxruntime-directml>=1.15.0 11 | --------------------------------------------------------------------------------