├── .gitignore ├── .streamlit └── config.toml ├── README.md ├── Welcome.py ├── assets ├── font.woff ├── font.woff2 ├── studio_logo.png └── studio_logo.svg ├── constants.py ├── pages ├── 10_Contextual_Answers.py ├── 11_Multi_Document_Q&A.py ├── 1_Blog_Post_Generator.py ├── 2_Product_Description_Generator.py ├── 4_Pitch_Email_Generator.py ├── 5_Social_Media_Generator.py ├── 6_Intent_Classifier.py ├── 7_Topic_Classifier.py ├── 8_Rewrite_Tool.py └── 9_Document_Summarizer.py ├── requirements.in ├── requirements.txt └── utils ├── __init__.py ├── completion.py ├── components ├── __init__.py └── completion_log.py ├── filters.py └── studio_style.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .streamlit 3 | .streamlit/secrets.toml 4 | .python-version 5 | 6 | .DS_Store 7 | file/ 8 | __pycache__/ 9 | .devcontainer/ -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | primaryColor="#E91E63" 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | logo 2 | 3 | # AI21 Studio Demos 4 | 5 | This repository demonstrates the following examples of AI21 LLM use cases with Streamlit apps: 6 | - Blog Post Generator 7 | - Product Description Generator 8 | - Pitch Email and Social Media Generator with Article Summerization 9 | - Topic Classifier 10 | - Paraphrases & Rewrites Tool 11 | - Document and Website Summarizer 12 | - Contextual Answers 13 | - Multiple Document Q&A 14 | 15 | [Streamlit Demo App](https://ai21-studio-demos.streamlit.app/) 16 | 17 | # Setup the app with your AI21 account 18 | - Create `secrets.toml` in `.streamlit` folder, and add your credentials (replace `` with your AI21 API Key): 19 | ``` 20 | [api-keys] 21 | ai21-api-key = "" 22 | 
# Entry point for the Streamlit multi-page demo app.
# Renders the landing ("Welcome") page; the individual demos live under pages/.
if __name__ == '__main__':
    # Must be the first Streamlit call on the page.
    st.set_page_config(
        page_title="Welcome"
    )
    # Injects the shared AI21 Studio fonts/branding (utils/studio_style.py).
    apply_studio_style()
    st.title("Welcome to AI21 Studio demos")
    st.markdown("Experience the incredible power of large language models first-hand. With these demos, you can explore a variety of unique use cases that showcase what our sophisticated technology is truly capable of. From instant content generation to a paraphraser that can rewrite any text, the world of AI text generation will be at your fingertips." )
    st.markdown("Check out the brains behind the demos here: https://www.ai21.com/studio")
    st.markdown("Please note that this is a limited demonstration of AI21 Studio's capabilities. If you're interested in learning more, contact us at studio@ai21.com")
10 | 11 | Deloitte Private’s latest global survey of private enterprises reveals that executives in every region used the crisis as a catalyst, accelerating change in virtually all aspects of how we work and live. They stepped up their digital transformation through greater technology investment and deployment. In-progress initiatives were pushed toward completion, while those that were on the drawing board came to life. They sought out new partnerships and alliances. They pursued new opportunities to strengthen their supply networks and grow markets. They increased efforts to understand their purpose beyond profits, seeking new ways to grow sustainably and strengthen trust with their employees, customers, and other key stakeholders. They also embraced new possibilities in how and where work gets done. 12 | ''' 13 | 14 | CLASSIFICATION_FEWSHOT="""Classify the following news article into one of the following topics: 15 | 1. World 16 | 2. Sports 17 | 3. Business 18 | 4. Science and Technology 19 | Title: 20 | Astronomers Observe Collision of Galaxies, Formation of Larger 21 | Summary: 22 | An international team of astronomers has obtained the clearest images yet of the merger of two distant clusters of galaxies, calling it one of the most powerful cosmic events ever witnessed. 23 | The topic of this article is: 24 | Science and Technology 25 | 26 | === 27 | 28 | Classify the following news article into one of the following topics: 29 | 1. World 30 | 2. Sports 31 | 3. Business 32 | 4. Science and Technology 33 | Title: 34 | Bomb Explodes Near U.S. Military Convoy (AP) 35 | Summary: 36 | AP - A car bomb exploded early Sunday near a U.S. military convoy on the road leading to Baghdad's airport, Iraqi police said, and a witness said two Humvees were destroyed. 37 | The topic of this article is: 38 | World 39 | 40 | === 41 | 42 | Classify the following news article into one of the following topics: 43 | 1. World 44 | 2. Sports 45 | 3. Business 46 | 4. 
Science and Technology 47 | Title: 48 | Maradona goes to Cuba 49 | Summary: 50 | The former Argentine football star, Diego Armando Maradona, traveled on Monday to Cuba to continue his treatment against his addiction to drugs. 51 | The topic of this article is: 52 | Sports 53 | 54 | === 55 | 56 | Classify the following news article into one of the following topics: 57 | 1. World 58 | 2. Sports 59 | 3. Business 60 | 4. Science and Technology 61 | Title: 62 | Duke earnings jump in third quarter 63 | Summary: 64 | Duke Energy Corp. reports third-quarter net income of $389 million, or 41 cents per diluted share, sharply above earnings of $49 million, or 5 cents per diluted share, in the same period last year. 65 | The topic of this article is: 66 | Business 67 | 68 | === 69 | 70 | """ 71 | 72 | CLASSIFICATION_PROMPT="""Classify the following news article into one of the following topics: 73 | 1. World 74 | 2. Sports 75 | 3. Business 76 | 4. Science and Technology""" 77 | 78 | CLASSIFICATION_TITLE = "D.C. Unveils Stadium Plan" 79 | 80 | CLASSIFICATION_DESCRIPTION = "Rumors spread that Major League Baseball is edging closer to moving the Expos to Washington as D.C. officials announce plans for a stadium on the Anacostia waterfront." 81 | 82 | PRODUCT_DESCRIPTION_FEW_SHOT = '''Write product descriptions for fashion eCommerce site based on a list of features. 83 | Product: On Every Spectrum Fit and Flare Dress 84 | Features: 85 | - Designed by Retrolicious 86 | - Stretch cotton fabric 87 | - Side pockets 88 | - Rainbow stripes print 89 | Description: In a bold rainbow-striped print, made up of exceptionally vibrant hues, this outstanding skater dress from Retroliciousis on every spectrum of vintage-inspired style. Made from a stretchy cotton fabric and boasting a round neckline, a sleeveless fitted bodice, and a gathered flare skirt with handy side pockets, this adorable fit-and-flare dress is truly unique and so retro-chic. 
90 | 91 | ## 92 | 93 | Write product descriptions for fashion eCommerce site based on a list of features. 94 | Product: Camp Director Crossbody Bag 95 | Features: 96 | - Black canvas purse 97 | - Rainbow space print 98 | - Leather trim 99 | - Two securely-zipped compartments 100 | Description: Take a bit of camp charm with you wherever you go with this black canvas purse! Adorned with a rainbow space motif print, black faux-leather trim, two securely-zipped compartments, and adjustable crossbody strap, this ModCloth-exclusive bag makes sure you command a smile wherever you wander. 101 | 102 | ## 103 | 104 | Write product descriptions for fashion eCommerce site based on a list of features.''' 105 | 106 | OBQA_CONTEXT = """Large Language Models 107 | Introduction to the core of our product 108 | 109 | Natural language processing (NLP) has seen rapid growth in the last few years since large language models (LLMs) were introduced. Those huge models are based on the Transformers architecture, which allowed for the training of much larger and more powerful language models. 110 | We divide LLMs into two main categories, Autoregressive and Masked LM (language model). In this page we will focus on Autoregressive LLMs, as our language models, Jurassic-1 series, belongs to this category. 111 | 112 | ⚡ The task: predict the next word 113 | Autoregressive LLM is a neural network model composed from billions of parameters. It was trained on a massive amount of texts with one goal: to predict the next word, based on the given text. By repeating this action several times, every time adding the prediction word to the provided text, you will end up with a complete text (e.g. full sentences, paragraphs, articles, books, and more). In terms of terminology, the textual output (the complete text) is called a completion while the input (the given, original text) is called prompt. 
114 | 115 | 🎓 Added value: knowledge acquisition 116 | Imagine you had to read all of Shakespeare's works repeatedly to learn a language. Eventually, you would be able to not only memorize all of his plays and poems, but also imitate his writing style. 117 | In similar fashion, we trained the LLMs by supplying them with many textual sources. This enabled them to gain an in-depth understanding of English as well as general knowledge. 118 | 119 | 🗣️ Interacting with Large Language Models 120 | The LLMs are queried using natural language, also known as prompt engineering. 121 | Rather than writing lines of code and loading a model, you write a natural language prompt and pass it to the model as the input. 122 | 123 | ⚙️ Resource-intensive 124 | Data, computation, and engineering resources are required for training and deploying large language models. LLMs, such as our Jurassic-1 models, play an important role here, providing access to this type of technology to academic researchers and developers. 125 | 126 | Tokenizer & Tokenization 127 | 128 | Now that you know what large language models are, you must be wondering: “How does a neural network use text as input and output?”. 129 | 130 | The answer is: Tokenization 🧩 131 | Any language can be broken down into basic pieces (in our case, tokens). Each of those pieces is translated into its own vector representation, which is eventually fed into the model. For example: 132 | Each model has its own dictionary of tokens, which determines the language it "speaks". Each text in the input will be decomposed into these tokens, and every text generated by the model will be composed of them. 133 | But how do we break down a language? Which pieces are we choosing as our tokens? There are several approaches to solve this: 134 | 135 | 🔡 Character-level tokenization 136 | As a simple solution, each character can be treated as its own token. 
By doing so, we can represent the entire English language with just 26 characters (okay, double it for capital letters and add some punctuation). This would give us a small token dictionary, thereby reducing the width we need for those vectors and saving us some valuable memory. However, those tokens don’t have any inherent meaning - we all know what the meaning of “Cat” is, but what is the meaning of “C”? The key to understanding language is context. Although it is clear to us readers that a "Cat" and a "Cradle" have different meanings, for a language model with this tokenizer - the "C" is the same. 137 | 138 | 🆒 Word-level tokenization 139 | Another approach we can try is breaking our text into words, just like in the example above ("I want to break free"). 140 | Now, every token has a meaning that the model can learn and use. We are gaining meaning, but that requires a much larger dictionary. Also, it raises another question: what about words stemming from the same root-word like ”helped”, “helping”, and “helpful”? In this approach each of these words will get a different token with no inherent relation between them, whereas for us readers it's clear that they all have a similar meaning. 141 | Furthermore, words may have fundamentally different meanings when strung together - for instance, my run-down car isn't running anywhere. What if we went a step further? 142 | 143 | 💬 Sentence-level tokenization 144 | In this approach we break our text into sentences. This will capture meaningful phrases! However, this would result in an absurdly large dictionary, with some tokens being so rare that we would require an enormous amount of data to teach the model the meaning of each token. 145 | 146 | 🏅 Which is best? 147 | Each method has pros and cons, and like any real-life problem, the best solution involves a number of compromises. 
# Contextual Answers demo page: the model answers a question grounded ONLY in
# the user-supplied context (via client.answer.create).
if __name__ == '__main__':

    apply_studio_style()
    st.title("Contextual Answers")

    st.write("Ask a question on a given context.")

    context = st.text_area(label="Context:", value=OBQA_CONTEXT, height=300)
    question = st.text_input(label="Question:", value=OBQA_QUESTION)

    if st.button(label="Answer"):
        with st.spinner("Loading..."):
            # Reject inputs that would exceed the model's context window
            # (max_tokens is defined at module level as 2048 minus headroom).
            num_tokens = tokenize(context + question)
            if num_tokens > max_tokens:
                st.write("Text is too long. Input is limited up to 2048 tokens. Try using a shorter text.")
                # Drop any stale answer so it is not re-rendered below.
                # BUG FIX: previously deleted st.session_state['completions'],
                # a key this page never sets — raising KeyError whenever an
                # answer from a prior run existed. The key stored here is 'answer'.
                if 'answer' in st.session_state:
                    del st.session_state['answer']
            else:
                response = client.answer.create(context=context, question=question)
                st.session_state["answer"] = response.answer

    # Persisted across reruns so the answer survives widget interactions.
    if "answer" in st.session_state:
        st.write(st.session_state['answer'])
# Multi-Document Q&A page: upload .pdf/.txt files into the AI21 document
# library (tagged with today's `label`), then answer questions over them via
# the conversational RAG endpoint.
if __name__ == '__main__':
    apply_studio_style()
    st.title("Multi-Document Q&A")
    st.markdown("**Upload documents**")

    uploaded_files = st.file_uploader("choose .pdf/.txt file ",
                                      accept_multiple_files=True,
                                      type=["pdf", "text", "txt"],
                                      key="a")
    file_id_list = list()
    file_path_list = list()
    # Each uploaded file is written to the local ./file folder, then pushed
    # to the AI21 library (parse_file / upload_file are defined above).
    for uploaded_file in uploaded_files:
        file_path = parse_file(uploaded_file)
        file_id = upload_file(file_path)
        file_id_list.append(file_id)
        file_path_list.append(file_path)

    if st.button("Remove file"):
        for file in file_path_list:
            try:
                os.remove(file)
            except OSError:
                # BUG FIX: this previously caught UnprocessableEntity (an AI21
                # HTTP error that os.remove can never raise), so any real
                # filesystem error (missing file, permissions) crashed the page.
                pass
        # Guard the session-state lookup: 'files_ids' is only set after a
        # successful upload, and an unguarded access would raise KeyError.
        if 'files_ids' in st.session_state:
            try:
                client.library.files.delete(st.session_state['files_ids'])
            except UnprocessableEntity:
                pass

        st.write("files removed successfully")

    st.markdown("**Ask a question about the uploaded document, and here is the answer:**")

    question = st.chat_input(DOC_QA)
    if question:
        messages = [ChatMessage(content=question, role="user")]
        # Restrict retrieval to documents uploaded under this page's label.
        response = client.beta.conversational_rag.create(messages=messages, label=label)
        if response.choices[0].content is None:
            st.write("I'm sorry, I cannot answer your questions based on the documents I have access to.")
        else:
            st.write(response.choices[0].content)
def build_prompt(title, sections, section_heading):
    """Assemble the completion prompt asking the model to write one blog section.

    The prompt embeds the blog title, the full outline (one heading per line)
    and the heading of the section currently being generated, ending with an
    open "Current Section Text:" cue for the model to complete.
    """
    outline_block = '\n'.join(sections)
    return (
        "Write a descriptive section in a blog post according to the following details."
        f"\n\nBlog Title:\n{title}"
        f"\n\nBlog Sections:\n{outline_block}"
        f"\n\nCurrent Section Heading:\n{section_heading}"
        "\n\nCurrent Section Text:\n"
    )
def build_on_next_click(section_heading, section_index, completions, arg_sorted_by_length):
    """Bind the current arguments into a zero-argument Streamlit button callback.

    Streamlit's `on_click` expects a no-arg callable, so this closes over the
    per-section state and defers to `on_next_click` when the button fires.
    """
    def _callback():
        on_next_click(section_heading, section_index, completions, arg_sorted_by_length)
    return _callback
def on_prev_click(section_heading, section_index, completions, arg_sorted_by_length):
    """Cycle the section's text area to the previous generated candidate.

    Candidates are shown in length order (via `arg_sorted_by_length`); the
    index wraps modulo 5, which assumes the default of 5 results per section
    (the --num_results CLI default) — TODO confirm if num_results is changed.
    """
    # Leaving the paraphrased view: show the raw candidate again.
    st.session_state['show_paraphrase'][section_heading] = False

    # Wrap backwards through the 5 candidates.
    new_comp_index = (st.session_state['generated_sections_data'][section_heading]["text_area_index"] - 1) % 5
    section_i_text = completions[arg_sorted_by_length[new_comp_index]]["data"]["text"]
    st.session_state['generated_sections_data'][section_heading]["text_area_index"] = new_comp_index
    # Re-render the section's placeholder text area with the selected candidate.
    st.session_state['generated_sections_data'][section_heading]["text_area_data"].text_area(label=section_heading,
                                                                                             height=300,
                                                                                             value=section_i_text,
                                                                                             key=section_index)
| l = len(section_completions[c]["data"]["text"]) 145 | lengths.append(l) 146 | 147 | arg_sort = np.argsort(lengths) 148 | index = 2 149 | st.session_state['generated_sections_data'][s]["text_area_index"] = index 150 | st.session_state['generated_sections_data'][s]["arg_sort"] = arg_sort 151 | 152 | st.session_state['generated_sections_data'][s]["rewrites"] = ["" for c in range(len(section_completions))] 153 | 154 | 155 | def build_event_loop(title, section_heading, num_results): 156 | return lambda: get_event_loop(title, section_heading, num_results) 157 | 158 | 159 | def build_event_loop_one_section(title, section, num_results): 160 | return lambda: get_event_loop(title, [section], num_results) 161 | 162 | 163 | def on_outline_change(): 164 | st.session_state['show_sections'] = False 165 | 166 | 167 | def paraphrase(text, tone, times): 168 | len_text = len(text) 169 | entire_text = text 170 | for i in range(times): 171 | if len_text > 500: 172 | sentences = text.split(".") 173 | else: 174 | sentences = [text] 175 | 176 | filtered_sentences = [] 177 | for sentence in sentences: 178 | sentence = sentence.strip() 179 | len_sent = len(sentence) 180 | if len_sent > 1: 181 | filtered_sentences.append(sentence) 182 | 183 | loop = asyncio.new_event_loop() 184 | asyncio.set_event_loop(loop) 185 | group = asyncio.gather(*[paraphrase_req(sentence, tone) for sentence in filtered_sentences]) 186 | 187 | results = loop.run_until_complete(group) 188 | loop.close() 189 | 190 | final_text = [] 191 | for r in results: 192 | sugg = r["suggestions"][0]["text"] 193 | final_text.append(sugg) 194 | 195 | entire_text = ". 
".join(final_text) 196 | entire_text = (entire_text + ".").replace(",.", ".").replace("?.", ".").replace("..", ".") 197 | 198 | text = entire_text 199 | len_text = len(entire_text) 200 | 201 | return entire_text 202 | 203 | 204 | def on_paraphrase_click(s, tone, times): 205 | 206 | all_sections_data = st.session_state['generated_sections_data'] 207 | 208 | index = st.session_state['generated_sections_data'][s]["text_area_index"] 209 | section_completions = all_sections_data[s]["completions"] 210 | 211 | sec_text = section_completions[index]["data"]["text"] 212 | paraphrased_section = paraphrase(sec_text, tone, times) 213 | 214 | st.session_state['generated_sections_data'][s]["rewrites"][index] = paraphrased_section 215 | st.session_state['show_paraphrase'][s] = True 216 | 217 | 218 | def build_paraphrase(s, tone, times): 219 | return lambda: on_paraphrase_click(s, tone, times) 220 | 221 | 222 | def on_heading_change(): 223 | st.session_state['show_sections'] = False 224 | 225 | 226 | def on_title_change(): 227 | st.session_state['show_sections'] = False 228 | 229 | 230 | if __name__ == '__main__': 231 | args = get_args() 232 | apply_studio_style() 233 | num_results = args.num_results 234 | 235 | # Initialization 236 | if 'show_outline' not in st.session_state: 237 | st.session_state['show_outline'] = False 238 | 239 | if 'show_sections' not in st.session_state: 240 | st.session_state['show_sections'] = False 241 | 242 | if 'show_paraphrase' not in st.session_state: 243 | st.session_state['show_paraphrase'] = {} 244 | 245 | if 'generated_sections_data' not in st.session_state: 246 | st.session_state['generated_sections_data'] = {} 247 | 248 | st.title("Blog Post Generator") 249 | st.markdown("Using only a title, you can instantly generate an entire article with the click of a button! 
Simply select your topic and this tool will create an engaging article from beginning to end.") 250 | st.markdown("#### Blog Title") 251 | title = st.text_input(label="Write the title of your article here:", placeholder="", 252 | value="5 Strategies to overcome writer's block").strip() 253 | st.markdown("#### Blog Outline") 254 | st.text("Click the button to generate the blog outline") 255 | st.button(label="Generate Outline", on_click=build_generate_outline(title)) 256 | 257 | sections = [] 258 | if st.session_state['show_outline']: 259 | text_area_outline = st.text_area(label=" ", height=250, value=st.session_state["outline"], 260 | on_change=on_outline_change) 261 | sections = text_area_outline.split("\n") 262 | st.text("Unsatisfied with the generated outline? Click the 'Generate Outline' button again to re-generate it, or edit it inline.") 263 | 264 | st.markdown("#### Blog Sections") 265 | st.text("Click the button to effortlessly generate an outline for your blog post:") 266 | st.button(label="Generate Sections", on_click=build_event_loop(title, sections, num_results)) 267 | 268 | if st.session_state['show_sections']: 269 | st.markdown(f"**{title}**") 270 | for s in sections: 271 | st.session_state['generated_sections_data'][s]["text_area_data"] = st.empty() 272 | st.session_state['generated_sections_data'][s]["cols"] = st.empty() 273 | 274 | all_sections_data = st.session_state['generated_sections_data'] 275 | for i, s in enumerate(st.session_state['generated_sections_data'].keys()): 276 | index = st.session_state['generated_sections_data'][s]["text_area_index"] 277 | section_completions = all_sections_data[s]["completions"] 278 | arg_sort = st.session_state['generated_sections_data'][s]["arg_sort"] 279 | 280 | section_text_area_value = st.session_state['generated_sections_data'][s]["rewrites"][index] if st.session_state['show_paraphrase'][s] == True else section_completions[ 281 | index]["data"]["text"] 282 | section_i_text = 
def query(prompt):
    """Run one completion for `prompt` and return its generated text.

    Uses high temperature with strong frequency/presence penalties to keep
    the product copy varied and non-repetitive.
    """
    # All three penalties share the same exemptions: only plain words are
    # penalized — emojis, numbers, stopwords, punctuation and whitespace
    # are never affected.
    penalty_exemptions = dict(
        apply_to_emojis=False,
        apply_to_numbers=False,
        apply_to_stopwords=False,
        apply_to_punctuation=False,
        apply_to_whitespaces=False,
    )

    response = client.completion.create(
        model=DEFAULT_MODEL,
        prompt=prompt,
        num_results=1,
        max_tokens=240,
        temperature=1,
        top_k_return=0,
        top_p=0.98,
        count_penalty=Penalty(scale=0, **penalty_exemptions),
        frequency_penalty=Penalty(scale=225, **penalty_exemptions),
        presence_penalty=Penalty(scale=1.2, **penalty_exemptions),
    )

    return response.completions[0].data.text
from constants import client 7 | 8 | 9 | st.set_page_config( 10 | page_title="Pitch Email Generator", 11 | ) 12 | 13 | max_tokens = 2048 - 200 14 | 15 | WORDS_LIMIT = { 16 | "pitch": (150, 200), 17 | } 18 | 19 | title_placeholder = "PetSmart Charities® Commits $100 Million to Improve Access to Veterinary Care" 20 | article_placeholder = """The inability of many families to access veterinary care is a pressing issue facing animal welfare nationally. To combat this, PetSmart Charities announced a commitment of $100 million over the next five years to help break down the geographic, cultural, language and financial barriers that prevent pets from receiving the veterinary care they need to thrive. 21 | Zora is a lovable pug whose owner knew something was wrong with her breathing, but was struggling to find a vet that would provide non-routine care at low cost. Despite the challenges of transitioning from homelessness and landing in the hospital himself, Zora’s owner found someone to take her to a free clinic offered by Ruthless Kindness, a PetSmart Charities grantee. Thanks to the care she received, Zora and her owner are now thriving together. 22 | Veterinary access impacts the animal welfare industry and individual families in every community in the country. More than 70 percent of homes in the United States now include pets, but 50 million pets in the U.S. lack even basic veterinary care, including spay/neuter surgeries, annual exams and vaccinations. Without regular veterinary care, minor pet health issues often become bigger, costlier problems; and preventable diseases can be passed on to people and other animals. Pet parents may be forced to relinquish their beloved furry family members to already overcrowded animal shelters or be forced to watch them suffer when they can't access treatment. With pets being universally recognized as beloved family members, the challenges posed by an inability to access veterinary care can have a profound impact. 
23 | PetSmart Charities estimates it would cost more than $20 billion annually to bridge the gap for pets in need of veterinary care at standard veterinary prices. More needs to be done to expand availability of lower-cost services, ensure access for remote and bilingual communities and ensure there are enough veterinarians able to perform a variety of services through clinics and urgent care centers. To help lead the charge, the nonprofit is taking a leadership role in marshaling partners and stakeholders to develop and execute solutions to solving the gap in veterinary care access. 24 | ""The challenges facing the veterinary care system are vast and varied and no single organization can solve them alone,"" said Aimee Gilbreath, president of PetSmart Charities. ""Through PetSmart Charities' commitment, we plan to invest further in our partners and build new alliances to innovate solutions across the entire system — while also funding long-term solutions already in place such as low-cost vet clinics and veterinary student scholarships. We're confident this approach will produce sustainable change within the veterinary care industry. Our best friends deserve access to adequate health care like any family members."" 25 | 26 | Barriers to Veterinary Care 27 | While affordability remains the most prominent barrier to veterinary care, additional challenges contribute to the current veterinary care gap, including: 28 | Veterinary Shortage: With pet ownership steadily on the rise, a 33% increase in pet healthcare service spending is expected over the next 10 years. 90 million U.S. households now include pets, but the number of nationwide veterinarians has increased by just 2.7 percent each year since 2007. To meet the growing need for veterinary care, an additional 41,000 veterinarians would be needed by 2030. 
29 | Veterinary Deserts and Cultural Inclusion: Within rural and underserved regions, veterinary practices are difficult to find close-by, making trips to the veterinarian costly and sometimes impossible. Additionally, as veterinarians are currently listed among the top five least diverse professions, cultural and language divides can often occur between clients and veterinarians, discouraging some pet parents from seeking care. 30 | Economic Challenges: The cost of veterinary care has spiked 10 percent in the last year alone and amidst the ongoing housing crisis and economic uncertainty, 63 percent of pet parents find it difficult to cover surprise vet expenses. 31 | Regulatory Challenges: Nationwide, fragmented and varied veterinary regulations pose challenges to the development of easy, efficient and consistent solutions such as telemedicine. 32 | How PetSmart Charities Will Innovate Solutions 33 | Through a $100 million commitment in funding over the next five years, PetSmart Charities will take a multifaceted approach to improve access to adequate veterinary care for all pets, including: 34 | Funding solutions across the system of veterinary care – from investing in new and more affordable types of clinics to working directly with providers to help them overcome challenges in care delivery. 35 | Supporting innovative solutions such as new telehealth care and delayed payment models that reduce and help manage the cost of care for pet parents. 36 | Partnering with universities and thought leaders to research the evolving needs of pets while developing innovative, cost-effective ways to deliver care. 37 | Awarding scholarships to veterinary students pursuing community-based practices and establishing a training program for Master's-level veterinary practitioners to offer basic care at affordable prices. 38 | Expanding access to lower-cost veterinary care through sustainable nonprofit clinics. 
39 | Developing community-based models led by local changemakers to improve access to veterinary care to underserved communities through an emphasis on their unique challenges. 40 | For more information on how PetSmart Charities is working to expand access to veterinary care nationwide or to help support initiatives like this for pets and their families, visit petsmartcharities.org. 41 | 42 | About PetSmart Charities® 43 | PetSmart Charities is committed to making the world a better place for pets and all who love them. Through its in-store adoption program in all PetSmart® stores across the U.S. and Puerto Rico, PetSmart Charities helps up to 600,000 pets connect with loving families each year. PetSmart Charities also provides grant funding to support organizations that advocate and care for the well-being of all pets and their families. PetSmart Charities' grants and efforts connect pets with loving homes through adoption, improve access to affordable veterinary care and support families in times of crises with access to food, shelter and disaster response. Each year, millions of generous supporters help pets in need by donating to PetSmart Charities directly at PetSmartCharities.org, while shopping at PetSmart.com, and by using the PIN pads at checkout registers inside PetSmart® stores. In turn, PetSmart Charities efficiently uses more than 90 cents of every dollar donated to fulfill its role as the leading funder of animal welfare in North America, granting more than $500 million since its inception in 1994. Independent from PetSmart LLC, PetSmart Charities is a 501(c)(3) organization that has received the Four-Star Rating from Charity Navigator for the past 18 years in a row – placing it among the top one percent of rated charities. 
def anonymize(text):
    """Replace URLs and e-mail addresses in *text* with placeholder tags.

    BUGFIX: the original email pattern used `[.-_]`, an accidental character
    range (0x2E-0x5F, which includes digits, ':', '@' and uppercase letters),
    and `[A-Z|a-z]`, which matched a literal '|' in the TLD. Separators are
    now restricted to '.', '_' and '-'.
    """
    text = re.sub(r'https?:\/\/.*', '[URL]', text)
    return re.sub(r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+', '[EMAIL]', text)


def generate(prompt, category, max_retries=2):
    """Generate pitch-email candidates and store them in session state.

    Samples 16 completions per attempt and keeps those that ended naturally
    and fall inside the word limits for *category*; retries up to
    *max_retries* times when none survive the filter.
    """
    min_length, max_length = WORDS_LIMIT[category]
    completions_filtered = []
    try_count = 0
    while not completions_filtered and try_count < max_retries:
        res = client.completion.create(
            model=DEFAULT_MODEL,
            prompt=prompt,
            max_tokens=200,
            temperature=0.8,
            num_results=16
        )
        completions_filtered = [comp.data.text for comp in res.completions
                                if comp.finish_reason.reason == "endoftext"
                                and min_length <= len(comp.data.text.split()) <= max_length]
        try_count += 1
    st.session_state["completions"] = [anonymize(i) for i in completions_filtered]


def on_next():
    """Advance the carousel index, wrapping around at the end."""
    st.session_state['index'] = (st.session_state['index'] + 1) % len(st.session_state['completions'])


def on_prev():
    """Step the carousel index back, wrapping around at the start."""
    st.session_state['index'] = (st.session_state['index'] - 1) % len(st.session_state['completions'])


def toolbar():
    """Render prev / counter / next navigation for the generated candidates.

    BUGFIX: uses on_click callbacks so the index is updated *before* the
    script re-runs. With the previous `if st.button(...)` form, clicking '>'
    mutated the index after the "i/n" counter had already been drawn, leaving
    the counter one step stale until the next rerun.
    """
    cols = st.columns([0.2, 0.2, 0.2, 0.2, 0.2])
    with cols[1]:
        st.button(label='<', key='prev', on_click=on_prev)
    with cols[2]:
        st.text(f"{st.session_state['index'] + 1}/{len(st.session_state['completions'])}")
    with cols[3]:
        st.button(label='\>', key='next', on_click=on_next)


if __name__ == '__main__':
    apply_studio_style()
    st.title("Marketing Generator")

    st.session_state['title'] = st.text_input(label="Title", value=title_placeholder).strip()
    st.session_state['article'] = st.text_area(label="Article", value=article_placeholder, height=500).strip()

    domain = st.radio(
        "Select domain of reporter 👉",
        options=['Technology', 'Healthcare', 'Venture Funding', 'Other'],
    )

    # Tailor the instruction to the reporter's beat; 'Other' gets the generic form.
    if domain == 'Other':
        instruction = "Write a pitch to reporters persuading them why they should write about this for their publication."
    else:
        instruction = f"Write a pitch to reporters that cover {domain} stories persuading them why they should write about this for their publication."
    suffix = "Email Introduction"
    prompt = f"{instruction}\nTitle: {st.session_state['title']}\nPress Release:\n{st.session_state['article']}\n\n{suffix}:\n"
    category = 'pitch'

    if st.button(label="Compose"):
        with st.spinner("Loading..."):
            num_tokens = tokenize(prompt)
            if num_tokens > max_tokens:
                st.write("Text is too long. Input is limited up to 2048 tokens. Try using a shorter text.")
                # Drop stale results so the UI doesn't show output for a rejected input.
                if 'completions' in st.session_state:
                    del st.session_state['completions']
            else:
                generate(prompt, category=category)
                st.session_state['index'] = 0

    if 'completions' in st.session_state:
        if len(st.session_state['completions']) == 0:
            st.write("Please try again 😔")

        else:
            curr_text = st.session_state['completions'][st.session_state['index']]
            st.subheader('Generated Email')
            st.text_area(label=" ", value=curr_text.strip(), height=400)
            st.write(f"Number of words: {len(curr_text.split())}")
            if len(st.session_state['completions']) > 1:
                toolbar()
def generate(article, media, max_retries=2, top=3):
    """Generate up to *top* social-media posts for *article* on *media*.

    Samples 16 completions per attempt and keeps those that pass the
    media-specific length/quality filters; retries up to *max_retries* times
    when none survive. Returns the survivors de-duplicated, anonymized and
    cleaned of tokenizer emoji byte artifacts.
    """
    prompt = create_prompt(media, article)
    completions_filtered = []
    try_count = 0
    while not completions_filtered and try_count < max_retries:
        res = client.completion.create(
            model=DEFAULT_MODEL,
            prompt=prompt,
            max_tokens=200,
            temperature=0.8,
            num_results=16
        )
        completions_filtered = [comp.data.text for comp in res.completions
                                if apply_filters(comp, article, media)]
        try_count += 1
    best = filter_duplicates(completions_filtered)[:top]
    return [remove_utf_emojis(anonymize(i)) for i in best]


def on_next():
    """Advance the carousel index, wrapping around at the end."""
    st.session_state['index'] = (st.session_state['index'] + 1) % len(st.session_state['completions'])


def on_prev():
    """Step the carousel index back, wrapping around at the start."""
    st.session_state['index'] = (st.session_state['index'] - 1) % len(st.session_state['completions'])


def toolbar():
    """Render prev / counter / next / regenerate controls for the post carousel."""
    cols = st.columns([0.35, 0.1, 0.1, 0.1, 0.35])
    with cols[1]:
        st.button(label='<', key='prev', on_click=on_prev)
    with cols[2]:
        st.text(f"{st.session_state['index'] + 1}/{len(st.session_state['completions'])}")
    with cols[3]:
        st.button(label="\>", key='next', on_click=on_next)
    with cols[4]:
        # Regenerate posts for the current summary (lambda wrapper was redundant).
        st.button(label="🔄", on_click=compose)


def extract():
    """Summarize the URL in session state; store False when the source is unsupported."""
    with st.spinner("Summarizing article..."):
        try:
            st.session_state['article'] = client.summarize.create(source=st.session_state['url'], source_type='URL').summary
        # BUGFIX: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Still deliberately best-effort: the
        # UI shows a "not supported" message when article is False.
        except Exception:
            st.session_state['article'] = False


def compose():
    """Generate posts for the summarized article and reset the carousel index."""
    with st.spinner("Generating post..."):
        st.session_state["completions"] = generate(st.session_state['article'], media=st.session_state['media'])
        st.session_state['index'] = 0
def generate_response(prompt, delay):
    """Fire one async completion request for *prompt* after *delay* seconds.

    maxTokens=0 means no new tokens are generated: the response echoes the
    prompt with per-token logprobs, which is all the classifier needs.
    Returns the coroutine produced by async_complete (awaited by the caller).
    """
    config = {"maxTokens": 0, "temperature": 1}
    res = async_complete(model_type=DEFAULT_MODEL,
                         prompt=prompt,
                         config=config, delay=delay)
    return res


def batch_responses(prompts, delay=0.25):
    """Run all *prompts* concurrently on a fresh event loop and return their responses in order.

    Requests are staggered by i*delay seconds to stay under the API's
    requests-per-second limit; small batches (< 5) go out immediately.
    """
    delay = delay if len(prompts) >= 5 else 0  # may deal with less than 5 requests in 1 second
    # A fresh loop is created per call because Streamlit reruns the script
    # on every interaction and a previously-closed loop cannot be reused.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    group = asyncio.gather(*[generate_response(p, i*delay) for i, p in enumerate(prompts)])
    results = loop.run_until_complete(group)
    loop.close()
    return results


if __name__ == "__main__":
    apply_studio_style()
    st.title("Intent Classifier")
    instruction = "Classify the following question into one of the following classes:"
    st.write(instruction)

    # One class name per line; collapse blank lines so splitting on '\n' is clean.
    st.session_state['classes'] = st.text_area(label='Classes', value="Parking\nFood/Restaurants\nRoom Facilities\nSpa\nTransport", height=180)
    st.session_state['classes'] = re.sub('\n+', '\n', st.session_state['classes'])

    st.session_state['question'] = st.text_input(label="Question", value="How to get from the airport?")

    if st.button('Classify'):
        prompt = f"{instruction}\n{st.session_state['classes']}\n\nQuestion:\n{st.session_state['question']}\n\nClass:\n"
        num_tokens = tokenize(prompt)
        # Score each candidate class by appending it to the shared prompt and
        # reading back the logprobs the model assigns to the appended tokens
        # (maxTokens=0 => response is the echoed prompt with token logprobs).
        responses = batch_responses([prompt + c for c in st.session_state['classes'].split('\n')])
        results = {}
        for i, r in enumerate(responses):
            # Keep only tokens past the shared-prompt boundary (the class name),
            # dropping the newline token; '▁' marks a word start in the tokenizer.
            token_list = [t['generatedToken'] for t in r['prompt']['tokens'][num_tokens:] if t['generatedToken']['token'] != '<|newline|>']
            class_name = ''.join([t['token'] for t in token_list]).replace('▁', ' ')
            # Sequence probability = exp(sum of per-token logprobs).
            sum_logprobs = sum([t['logprob'] for t in token_list])
            results[class_name] = round(math.exp(sum_logprobs), 2)

        # Remaining probability mass is attributed to 'Other'.
        # NOTE(review): this can come out slightly negative due to rounding,
        # and the module-level OTHER_THRESHOLD constant is never applied
        # here — confirm intended behavior.
        results['Other'] = 1 - sum(results.values())

        sorted_results = {k: v for k, v in sorted(results.items(), key=lambda x: x[1], reverse=True)}
        for name, prob in sorted_results.items():
            st.write(name, round(prob, 2))
def query(prompt):
    """Send *prompt* to the classification model and return the completion text.

    temperature=0 keeps the classification deterministic; max_tokens=5 is
    enough for a short topic label; generation stops at the '##' separator
    used between few-shot examples.
    """
    res = client.completion.create(
        model=st.session_state['classification_model'],
        prompt=prompt,
        num_results=1,
        max_tokens=5,
        temperature=0,
        stop_sequences=["##"]
    )
    return res.completions[0].data.text


if __name__ == '__main__':

    apply_studio_style()
    st.title("Topic Classifier")
    st.write("Read any interesting news lately? Let's see if our topic classifier can skim through it and identify whether its category is sports, business, world news, or science and technology.")
    st.session_state['classification_model'] = DEFAULT_MODEL

    st.text(CLASSIFICATION_PROMPT)
    classification_title = st.text_input(label="Title:", value=CLASSIFICATION_TITLE)
    classification_description = st.text_area(label="Description:", value=CLASSIFICATION_DESCRIPTION, height=100)

    if st.button(label="Classify"):
        with st.spinner("Loading..."):
            # BUGFIX: the original string concatenation ran the title straight
            # into "Description:" and the description straight into "The topic
            # of this article is:" with no separating newlines, producing a
            # prompt that did not match the few-shot examples' layout.
            classification_prompt = f"{CLASSIFICATION_PROMPT}\nTitle:\n{classification_title}\n" \
                                    f"Description:\n{classification_description}\nThe topic of this article is:\n"
            st.session_state["classification_result"] = query(CLASSIFICATION_FEWSHOT + classification_prompt)

    if "classification_result" in st.session_state:
        st.subheader(f"Topic: {st.session_state['classification_result']}")
def get_suggestions(text, intent=ParaphraseStyleType.GENERAL, span_start=0, span_end=None):
    """Request paraphrase suggestions for *text* and stash them in session state.

    When *span_end* is None (or 0), the rewritten span extends to the end of
    the text.
    """
    response = client.paraphrase.create(
        text=text,
        style=intent,
        start_index=span_start,
        end_index=span_end or len(text))
    st.session_state["rewrite_rewritten_texts"] = [suggestion.text for suggestion in response.suggestions]


def show_next(cycle_length):
    # Streamlit runs on_click callbacks before re-executing the script, so
    # updating the index here is enough — the rerun renders the suggestion
    # at the new position. Modulo wraps from the last suggestion to the first.
    st.session_state["rewrite_curr_index"] = (st.session_state["rewrite_curr_index"] + 1) % cycle_length


def show_prev(cycle_length):
    # Wraps around to the last suggestion when stepping back from the first.
    st.session_state["rewrite_curr_index"] = (st.session_state["rewrite_curr_index"] - 1) % cycle_length
if __name__ == '__main__':
    apply_studio_style()

    # Page header and intro copy.
    st.title("Document Summarizer")
    st.write(
        "Effortlessly transform lengthy material into a focused summary. Whether it’s an article, research paper or even your own notes - this tool will sum up the key points!")

    # The user supplies either raw text or a URL; the API distinguishes the
    # two via its source_type parameter ('TEXT' / 'URL').
    source_type = st.radio(label="Source type", options=['Text', 'URL'])
    if source_type == 'Text':
        source = st.text_area(label="Paste your text here:",
                              height=400,
                              value=SUMMARIZATION_TEXT).strip()
    else:
        source = st.text_input(label="Paste your URL here:",
                               value=SUMMARIZATION_URL).strip()

    if st.button(label="Answer"):
        with st.spinner("Loading..."):
            try:
                response = client.summarize.create(source=source, source_type=source_type.upper())
                st.text_area(label="Summary", height=250, value=response.summary)
            except UnprocessableEntity:
                # The summarize endpoint rejects inputs over its length limit.
                st.write('Text is too long for the document summarizer, please try the segment summarizer instead.')
24 | blinker==1.6.3 25 | # via streamlit 26 | cachetools==5.3.2 27 | # via streamlit 28 | certifi==2023.7.22 29 | # via requests 30 | cffi==1.16.0 31 | # via cryptography 32 | charset-normalizer==3.3.2 33 | # via 34 | # aiohttp 35 | # pdfminer-six 36 | # requests 37 | click==8.1.7 38 | # via streamlit 39 | cryptography==41.0.5 40 | # via pdfminer-six 41 | frozenlist==1.4.0 42 | # via 43 | # aiohttp 44 | # aiosignal 45 | gitdb==4.0.11 46 | # via gitpython 47 | gitpython==3.1.40 48 | # via streamlit 49 | idna==3.4 50 | # via 51 | # requests 52 | # yarl 53 | importlib-metadata==6.8.0 54 | # via streamlit 55 | jinja2==3.1.2 56 | # via 57 | # altair 58 | # pydeck 59 | jsonschema==4.19.2 60 | # via altair 61 | jsonschema-specifications==2023.7.1 62 | # via jsonschema 63 | markdown-it-py==3.0.0 64 | # via rich 65 | markupsafe==2.1.3 66 | # via jinja2 67 | mdurl==0.1.2 68 | # via markdown-it-py 69 | multidict==6.0.4 70 | # via 71 | # aiohttp 72 | # yarl 73 | numpy==1.26.1 74 | # via 75 | # altair 76 | # pandas 77 | # pyarrow 78 | # pydeck 79 | # streamlit 80 | packaging==23.2 81 | # via 82 | # altair 83 | # streamlit 84 | pandas==2.1.2 85 | # via 86 | # altair 87 | # streamlit 88 | pdfminer-six==20221105 89 | # via pdfplumber 90 | pdfplumber==0.10.3 91 | # via -r requirements.in 92 | pillow==10.1.0 93 | # via 94 | # pdfplumber 95 | # streamlit 96 | protobuf==4.24.4 97 | # via streamlit 98 | pyarrow==14.0.0 99 | # via streamlit 100 | pycparser==2.21 101 | # via cffi 102 | pydeck==0.8.0 103 | # via streamlit 104 | pygments==2.16.1 105 | # via rich 106 | pypdfium2==4.23.1 107 | # via pdfplumber 108 | python-dateutil==2.8.2 109 | # via 110 | # pandas 111 | # streamlit 112 | pytz==2023.3.post1 113 | # via pandas 114 | referencing==0.30.2 115 | # via 116 | # jsonschema 117 | # jsonschema-specifications 118 | requests==2.31.0 119 | # via 120 | # -r requirements.in 121 | # ai21 122 | # streamlit 123 | rich==13.6.0 124 | # via streamlit 125 | rpds-py==0.10.6 126 | # via 127 | # 
jsonschema 128 | # referencing 129 | six==1.16.0 130 | # via python-dateutil 131 | smmap==5.0.1 132 | # via gitdb 133 | streamlit==1.31.1 134 | # via 135 | # -r requirements.in 136 | # streamlit-chat 137 | streamlit-chat==0.1.1 138 | # via -r requirements.in 139 | tenacity==8.5.0 140 | # via streamlit 141 | toml==0.10.2 142 | # via streamlit 143 | toolz==0.12.0 144 | # via altair 145 | tornado==6.3.3 146 | # via 147 | # -r requirements.in 148 | # streamlit 149 | tqdm==4.66.1 150 | # via -r requirements.in 151 | typing-extensions==4.10.0 152 | # via 153 | # altair 154 | # streamlit 155 | tzdata==2023.3 156 | # via pandas 157 | tzlocal==5.2 158 | # via streamlit 159 | urllib3==2.0.7 160 | # via requests 161 | validators==0.22.0 162 | # via streamlit 163 | yarl==1.17.0 164 | # via aiohttp 165 | zipp==3.17.0 166 | # via importlib-metadata 167 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/utils/__init__.py -------------------------------------------------------------------------------- /utils/completion.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from aiohttp import ClientSession 3 | from constants import client 4 | import streamlit as st 5 | 6 | api_key = st.secrets['api-keys']['ai21-api-key'] 7 | 8 | endpoint = lambda model_type: f"https://api.ai21.com/studio/v1/{model_type}/complete" 9 | 10 | 11 | async def async_complete(model_type, prompt, config, delay=0): 12 | async with ClientSession() as session: 13 | await asyncio.sleep(delay) 14 | auth_header = f"Bearer {api_key}" 15 | res = await session.post( 16 | endpoint(model_type), 17 | headers={"Authorization": auth_header}, 18 | json={"prompt": prompt, **config} 19 | ) 20 | res = await res.json() 21 | return res 22 | 23 | 24 
def tokenize(text):
    """Return the token count of *text* according to the AI21 tokenizer."""
    res = client.count_tokens(text)
    return res


async def paraphrase_req(sentence, tone):
    """POST *sentence* to the paraphrase endpoint with the given *tone*.

    Returns the parsed JSON response. (Removed a pointless f-string on the
    constant "Authorization" header key.)
    """
    async with ClientSession() as session:
        res = await session.post(
            "https://api.ai21.com/studio/v1/paraphrase",
            headers={"Authorization": f"Bearer {st.secrets['api-keys']['ai21-api-key']}"},
            json={
                "text": sentence,
                "intent": tone.lower(),
                "spanStart": 0,
                "spanEnd": len(sentence)
            }
        )
        res = await res.json()
        return res


class CompletionLog:
    """In-memory log of model completions, backed by a pandas DataFrame."""

    def __init__(self, cols):
        # cols: column names for the log (e.g. prompt / completion fields).
        self.completion_log = pd.DataFrame(columns=cols)

    def get_completions_log(self):
        """Return the accumulated log DataFrame.

        BUGFIX: originally returned ``self.completion_history``, an attribute
        that never exists, so every call raised AttributeError.
        """
        return self.completion_log

    def add_completion(self, completion_data):
        """Append one or more rows (a DataFrame-constructible mapping) to the log."""
        self.completion_log = pd.concat([self.completion_log, pd.DataFrame(completion_data)], ignore_index=True)

    def add_completion_button(self, completion_data):
        """Render a button that appends *completion_data* to the log when clicked."""
        if st.button("Add item to dataset"):
            self.add_completion(completion_data)

    def remove_completion(self, ind):
        """Remove the row with index label *ind* from the log.

        BUGFIX: ``DataFrame.drop`` returns a new frame; the original discarded
        the return value, so nothing was ever removed.
        """
        self.completion_log = self.completion_log.drop(ind, axis=0)

    def download_history(self):
        """Render a button that downloads the log as a timestamped CSV."""
        st.download_button(
            "Download Completion Log",
            self.completion_log.to_csv().encode('utf-8'),
            f"completion_log_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}.csv",
            "text/csv",
            key='download-csv'
        )

    def display(self):
        """Render the log as a table plus the CSV download button."""
        st.markdown('#')
        st.subheader("Completion Log")
        st.dataframe(self.completion_log)
        self.download_history()


# (min_chars, max_chars) allowed for a generated post, per platform.
CHAR_LIMIT = {
    "Twitter": (30, 280),
    "Linkedin": (100, 1500),
}


def anonymize(text):
    """Replace URLs and e-mail addresses with '[URL]' / '[EMAIL]' placeholders.

    BUGFIX: ``[.-_]`` was an accidental character range (0x2E-0x5F, including
    digits, ':', '@' and uppercase letters) and ``[A-Z|a-z]`` matched a
    literal '|'; separators are now restricted to '.', '_' and '-'.
    """
    text = re.sub(r'https?:\/\/.*', '[URL]', text)
    return re.sub(r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+', '[EMAIL]', text)


def is_duplicate_prefix(input_text, output_text, th=0.7):
    """True when *output_text*'s leading words overlap *input_text*'s words by more than *th*.

    Either side being empty counts as a duplicate (nothing useful generated).
    """
    input_words = input_text.strip().split()
    output_words = output_text.strip().split()
    if not input_words or not output_words:
        return True
    output_prefix = output_words[:len(input_words)]
    overlap = set(output_prefix) & set(input_words)
    return len(overlap) / len(output_prefix) > th


def apply_filters(completion, prompt, media):
    """Keep a completion only if it terminated naturally, fits *media*'s
    character limits, does not parrot the prompt, and contains no leftover
    bracket placeholders (e.g. '[URL]')."""
    min_length, max_length = CHAR_LIMIT[media]
    text = completion.data.text
    return completion.finish_reason.reason == "endoftext" \
        and min_length <= len(text) <= max_length \
        and not is_duplicate_prefix(text, prompt) \
        and "[" not in text and "]" not in text


def filter_duplicates(completions):
    """Drop completions sharing a near-identical word prefix with an earlier one."""
    results = list()
    for curr in completions:
        if not any(is_duplicate_prefix(r, curr) for r in results):
            results.append(curr)
    return results


def remove_utf_emojis(s):
    """Strip '<0x..>' byte-escape artifacts the tokenizer emits for emojis."""
    return re.sub(r'<0x([0-9A-Fa-f]+)>', "", s)
-------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | def apply_studio_style(): 4 | st.markdown( 5 | """ 6 | 13 | """, 14 | unsafe_allow_html=True, 15 | ) 16 | 17 | st.image("./assets/studio_logo.png") --------------------------------------------------------------------------------