├── .gitignore
├── .streamlit
│   └── config.toml
├── README.md
├── Welcome.py
├── assets
│   ├── font.woff
│   ├── font.woff2
│   ├── studio_logo.png
│   └── studio_logo.svg
├── constants.py
├── pages
│   ├── 10_Contextual_Answers.py
│   ├── 11_Multi_Document_Q&A.py
│   ├── 1_Blog_Post_Generator.py
│   ├── 2_Product_Description_Generator.py
│   ├── 4_Pitch_Email_Generator.py
│   ├── 5_Social_Media_Generator.py
│   ├── 6_Intent_Classifier.py
│   ├── 7_Topic_Classifier.py
│   ├── 8_Rewrite_Tool.py
│   └── 9_Document_Summarizer.py
├── requirements.in
├── requirements.txt
└── utils
    ├── __init__.py
    ├── completion.py
    ├── components
    │   ├── __init__.py
    │   └── completion_log.py
    ├── filters.py
    └── studio_style.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .streamlit
3 | .streamlit/secrets.toml
4 | .python-version
5 |
6 | .DS_Store
7 | file/
8 | __pycache__/
9 | .devcontainer/
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | primaryColor="#E91E63"
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # AI21 Studio Demos
4 |
5 | This repository contains Streamlit apps that demonstrate the following AI21 LLM use cases:
6 | - Blog Post Generator
7 | - Product Description Generator
8 | - Pitch Email and Social Media Generator with Article Summarization
9 | - Topic Classifier
10 | - Paraphrases & Rewrites Tool
11 | - Document and Website Summarizer
12 | - Contextual Answers
13 | - Multiple Document Q&A
14 |
15 | [Streamlit Demo App](https://ai21-studio-demos.streamlit.app/)
16 |
17 | # Set up the app with your AI21 account
18 | - Create `secrets.toml` in the `.streamlit` folder, and add your credentials (paste your AI21 API key between the quotes):
19 | ```toml
20 | [api-keys]
21 | ai21-api-key = ""
22 | ```
23 | - Run the app: `streamlit run Welcome.py`
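From a fresh clone, a typical local setup looks like the following (assuming Python 3 and `pip` are available; Streamlit serves the app on `http://localhost:8501` by default):

```shell
# Install the pinned dependencies
pip install -r requirements.txt

# Start the multi-page demo app (Welcome.py is the entry point)
streamlit run Welcome.py
```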
24 |
25 | ## Request AI21 account and API key
26 | - [Create](https://studio.ai21.com/login) your AI21 Account
27 | - Locate your [AI21 API Key](https://studio.ai21.com/account/api-key)
28 |
--------------------------------------------------------------------------------
/Welcome.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils.studio_style import apply_studio_style
3 |
4 |
5 | if __name__ == '__main__':
6 | st.set_page_config(
7 | page_title="Welcome"
8 | )
9 | apply_studio_style()
10 | st.title("Welcome to AI21 Studio demos")
11 | st.markdown("Experience the incredible power of large language models first-hand. With these demos, you can explore a variety of unique use cases that showcase what our sophisticated technology is truly capable of. From instant content generation to a paraphraser that can rewrite any text, the world of AI text generation will be at your fingertips." )
12 | st.markdown("Check out the brains behind the demos here: https://www.ai21.com/studio")
13 | st.markdown("Please note that this is a limited demonstration of AI21 Studio's capabilities. If you're interested in learning more, contact us at studio@ai21.com")
14 |
--------------------------------------------------------------------------------
/assets/font.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/assets/font.woff
--------------------------------------------------------------------------------
/assets/font.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/assets/font.woff2
--------------------------------------------------------------------------------
/assets/studio_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/assets/studio_logo.png
--------------------------------------------------------------------------------
/assets/studio_logo.svg:
--------------------------------------------------------------------------------
1 |
12 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | from ai21 import AI21Client
2 | import streamlit as st
3 |
4 | client = AI21Client(api_key=st.secrets['api-keys']['ai21-api-key'])
5 |
6 | DEFAULT_MODEL = 'j2-ultra'
7 |
8 | SUMMARIZATION_URL = "https://www.ai21.com/blog/announcing-ai21-studio-and-jurassic-1"
9 | SUMMARIZATION_TEXT = '''Perhaps no other crisis in modern history has had as great an impact on daily human existence as COVID-19. And none has forced businesses throughout the world to accelerate their evolution as their leaders worked to respond and recover on the way to thriving in the postpandemic environment.
10 |
11 | Deloitte Private’s latest global survey of private enterprises reveals that executives in every region used the crisis as a catalyst, accelerating change in virtually all aspects of how we work and live. They stepped up their digital transformation through greater technology investment and deployment. In-progress initiatives were pushed toward completion, while those that were on the drawing board came to life. They sought out new partnerships and alliances. They pursued new opportunities to strengthen their supply networks and grow markets. They increased efforts to understand their purpose beyond profits, seeking new ways to grow sustainably and strengthen trust with their employees, customers, and other key stakeholders. They also embraced new possibilities in how and where work gets done.
12 | '''
13 |
14 | CLASSIFICATION_FEWSHOT="""Classify the following news article into one of the following topics:
15 | 1. World
16 | 2. Sports
17 | 3. Business
18 | 4. Science and Technology
19 | Title:
20 | Astronomers Observe Collision of Galaxies, Formation of Larger
21 | Summary:
22 | An international team of astronomers has obtained the clearest images yet of the merger of two distant clusters of galaxies, calling it one of the most powerful cosmic events ever witnessed.
23 | The topic of this article is:
24 | Science and Technology
25 |
26 | ===
27 |
28 | Classify the following news article into one of the following topics:
29 | 1. World
30 | 2. Sports
31 | 3. Business
32 | 4. Science and Technology
33 | Title:
34 | Bomb Explodes Near U.S. Military Convoy (AP)
35 | Summary:
36 | AP - A car bomb exploded early Sunday near a U.S. military convoy on the road leading to Baghdad's airport, Iraqi police said, and a witness said two Humvees were destroyed.
37 | The topic of this article is:
38 | World
39 |
40 | ===
41 |
42 | Classify the following news article into one of the following topics:
43 | 1. World
44 | 2. Sports
45 | 3. Business
46 | 4. Science and Technology
47 | Title:
48 | Maradona goes to Cuba
49 | Summary:
50 | The former Argentine football star, Diego Armando Maradona, traveled on Monday to Cuba to continue his treatment against his addiction to drugs.
51 | The topic of this article is:
52 | Sports
53 |
54 | ===
55 |
56 | Classify the following news article into one of the following topics:
57 | 1. World
58 | 2. Sports
59 | 3. Business
60 | 4. Science and Technology
61 | Title:
62 | Duke earnings jump in third quarter
63 | Summary:
64 | Duke Energy Corp. reports third-quarter net income of $389 million, or 41 cents per diluted share, sharply above earnings of $49 million, or 5 cents per diluted share, in the same period last year.
65 | The topic of this article is:
66 | Business
67 |
68 | ===
69 |
70 | """
71 |
72 | CLASSIFICATION_PROMPT="""Classify the following news article into one of the following topics:
73 | 1. World
74 | 2. Sports
75 | 3. Business
76 | 4. Science and Technology"""
77 |
78 | CLASSIFICATION_TITLE = "D.C. Unveils Stadium Plan"
79 |
80 | CLASSIFICATION_DESCRIPTION = "Rumors spread that Major League Baseball is edging closer to moving the Expos to Washington as D.C. officials announce plans for a stadium on the Anacostia waterfront."
81 |
82 | PRODUCT_DESCRIPTION_FEW_SHOT = '''Write product descriptions for fashion eCommerce site based on a list of features.
83 | Product: On Every Spectrum Fit and Flare Dress
84 | Features:
85 | - Designed by Retrolicious
86 | - Stretch cotton fabric
87 | - Side pockets
88 | - Rainbow stripes print
89 | Description: In a bold rainbow-striped print, made up of exceptionally vibrant hues, this outstanding skater dress from Retrolicious is on every spectrum of vintage-inspired style. Made from a stretchy cotton fabric and boasting a round neckline, a sleeveless fitted bodice, and a gathered flare skirt with handy side pockets, this adorable fit-and-flare dress is truly unique and so retro-chic.
90 |
91 | ##
92 |
93 | Write product descriptions for fashion eCommerce site based on a list of features.
94 | Product: Camp Director Crossbody Bag
95 | Features:
96 | - Black canvas purse
97 | - Rainbow space print
98 | - Leather trim
99 | - Two securely-zipped compartments
100 | Description: Take a bit of camp charm with you wherever you go with this black canvas purse! Adorned with a rainbow space motif print, black faux-leather trim, two securely-zipped compartments, and adjustable crossbody strap, this ModCloth-exclusive bag makes sure you command a smile wherever you wander.
101 |
102 | ##
103 |
104 | Write product descriptions for fashion eCommerce site based on a list of features.'''
105 |
106 | OBQA_CONTEXT = """Large Language Models
107 | Introduction to the core of our product
108 |
109 | Natural language processing (NLP) has seen rapid growth in the last few years since large language models (LLMs) were introduced. Those huge models are based on the Transformers architecture, which allowed for the training of much larger and more powerful language models.
110 | We divide LLMs into two main categories: Autoregressive and Masked LM (language model). On this page we will focus on Autoregressive LLMs, as our language models, the Jurassic-1 series, belong to this category.
111 |
112 | ⚡ The task: predict the next word
113 | An Autoregressive LLM is a neural network model composed of billions of parameters. It was trained on a massive amount of text with one goal: to predict the next word, based on the given text. By repeating this action several times, each time appending the predicted word to the provided text, you end up with a complete text (e.g. full sentences, paragraphs, articles, books, and more). In terms of terminology, the textual output (the complete text) is called a completion, while the input (the given, original text) is called a prompt.
114 |
115 | 🎓 Added value: knowledge acquisition
116 | Imagine you had to read all of Shakespeare's works repeatedly to learn a language. Eventually, you would be able to not only memorize all of his plays and poems, but also imitate his writing style.
117 | In similar fashion, we trained the LLMs by supplying them with many textual sources. This enabled them to gain an in-depth understanding of English as well as general knowledge.
118 |
119 | 🗣️ Interacting with Large Language Models
120 | The LLMs are queried using natural language, also known as prompt engineering.
121 | Rather than writing lines of code and loading a model, you write a natural language prompt and pass it to the model as the input.
122 |
123 | ⚙️ Resource-intensive
124 | Data, computation, and engineering resources are required for training and deploying large language models. LLMs, such as our Jurassic-1 models, play an important role here, providing access to this type of technology to academic researchers and developers.
125 |
126 | Tokenizer & Tokenization
127 |
128 | Now that you know what large language models are, you must be wondering: “How does a neural network use text as input and output?”.
129 |
130 | The answer is: Tokenization 🧩
131 | Any language can be broken down into basic pieces (in our case, tokens). Each of those pieces is translated into its own vector representation, which is eventually fed into the model. For example:
132 | Each model has its own dictionary of tokens, which determines the language it "speaks". Each text in the input will be decomposed into these tokens, and every text generated by the model will be composed of them.
133 | But how do we break down a language? Which pieces are we choosing as our tokens? There are several approaches to solve this:
134 |
135 | 🔡 Character-level tokenization
136 | As a simple solution, each character can be treated as its own token. By doing so, we can represent the entire English language with just 26 characters (okay, double it for capital letters and add some punctuation). This would give us a small token dictionary, thereby reducing the width we need for those vectors and saving us some valuable memory. However, those tokens don’t have any inherent meaning - we all know what the meaning of “Cat” is, but what is the meaning of “C”? The key to understanding language is context. Although it is clear to us readers that a "Cat" and a "Cradle" have different meanings, for a language model with this tokenizer - the "C" is the same.
137 |
138 | 🆒 Word-level tokenization
139 | Another approach we can try is breaking our text into words, just like in the example above ("I want to break free").
140 | Now, every token has a meaning that the model can learn and use. We are gaining meaning, but that requires a much larger dictionary. Also, it raises another question: what about words stemming from the same root-word like ”helped”, “helping”, and “helpful”? In this approach each of these words will get a different token with no inherent relation between them, whereas for us readers it's clear that they all have a similar meaning.
141 | Furthermore, words may have fundamentally different meanings when strung together - for instance, my run-down car isn't running anywhere. What if we went a step further?
142 |
143 | 💬 Sentence-level tokenization
144 | In this approach we break our text into sentences. This will capture meaningful phrases! However, this would result in an absurdly large dictionary, with some tokens being so rare that we would require an enormous amount of data to teach the model the meaning of each token.
145 |
146 | 🏅 Which is best?
147 | Each method has pros and cons, and like any real-life problem, the best solution involves a number of compromises. AI21 Studio uses a large token dictionary (250K), which contains some from every method: separate characters, words, word parts such as prefixes and suffixes, and many multi-word tokens."""
148 |
149 | OBQA_QUESTION = "Which tokenization methods are there?"
150 |
151 | DOC_QA = "What would you like to know?"
152 |
--------------------------------------------------------------------------------
/pages/10_Contextual_Answers.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils.completion import tokenize
3 | from utils.studio_style import apply_studio_style
4 | from constants import OBQA_CONTEXT, OBQA_QUESTION, client
5 |
6 | st.set_page_config(
7 | page_title="Answers",
8 | )
9 |
10 | max_tokens = 2048 - 200
11 |
12 |
13 | if __name__ == '__main__':
14 |
15 | apply_studio_style()
16 | st.title("Contextual Answers")
17 |
18 | st.write("Ask a question on a given context.")
19 |
20 | context = st.text_area(label="Context:", value=OBQA_CONTEXT, height=300)
21 | question = st.text_input(label="Question:", value=OBQA_QUESTION)
22 |
23 | if st.button(label="Answer"):
24 | with st.spinner("Loading..."):
25 | num_tokens = tokenize(context + question)
26 | if num_tokens > max_tokens:
27 |                 st.write("Text is too long. Input is limited to 2048 tokens. Try using a shorter text.")
28 |                 if 'answer' in st.session_state:
29 |                     del st.session_state['answer']
30 | else:
31 | response = client.answer.create(context=context, question=question)
32 | st.session_state["answer"] = response.answer
33 |
34 | if "answer" in st.session_state:
35 | st.write(st.session_state['answer'])
36 |
--------------------------------------------------------------------------------
/pages/11_Multi_Document_Q&A.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import re
3 | import pdfplumber
4 | from ai21.errors import UnprocessableEntity
5 | from utils.studio_style import apply_studio_style
6 | from constants import DOC_QA, client
7 | import os
8 | from datetime import date
9 | from ai21.models.chat import ChatMessage
10 |
11 | max_chars = 200000
12 | label = 'multi_doc'+str(date.today())
13 |
14 |
15 | def write_to_library(segmented_text, file_name):
16 | folder_name = "file"
17 | if not os.path.exists(folder_name):
18 | os.mkdir(folder_name)
19 | path = f"./{folder_name}/{file_name}.txt"
20 |     with open(path, "w") as f:
21 |         f.write(segmented_text)
23 |
24 | return path
25 |
26 |
27 | def parse_file(user_file):
28 | file_type = user_file.type
29 | with st.spinner("File is being processed..."):
30 | if file_type == "text/plain":
31 | all_text = str(user_file.read(), "utf-8", errors='ignore')
32 | else:
33 | with pdfplumber.open(user_file) as pdf:
34 |                 all_text = "\n".join(p.extract_text() or "" for p in pdf.pages)
35 |
36 | file_path_p = write_to_library(all_text, user_file.name)
37 | return file_path_p
38 |
39 |
40 | def upload_file(file_path_p):
41 | try:
42 | file_id_p = client.library.files.create(file_path=file_path_p, labels=label)
43 | st.session_state['files_ids'] = file_id_p
44 | st.session_state['file_uploaded'] = True
45 | except UnprocessableEntity:
46 | file_id_p = None
47 | return file_id_p
48 |
49 |
50 | st.set_page_config(page_title="Multi-Document Q&A")
51 |
52 | if __name__ == '__main__':
53 | apply_studio_style()
54 | st.title("Multi-Document Q&A")
55 | st.markdown("**Upload documents**")
56 |
57 |     uploaded_files = st.file_uploader("Choose a .pdf/.txt file",
58 | accept_multiple_files=True,
59 | type=["pdf", "text", "txt"],
60 | key="a")
61 | file_id_list = list()
62 | file_path_list = list()
63 | for uploaded_file in uploaded_files:
64 | file_path = parse_file(uploaded_file)
65 | file_id = upload_file(file_path)
66 | file_id_list.append(file_id)
67 | file_path_list.append(file_path)
68 |
69 | if st.button("Remove file"):
70 | for file in file_path_list:
71 | try:
72 | os.remove(file)
73 |             except OSError:
74 | pass
75 | try:
76 | client.library.files.delete(st.session_state['files_ids'])
77 | except UnprocessableEntity:
78 | pass
79 | # for file in file_id_list:
80 | # with st.spinner("Loading..."):
81 | # try:
82 | # client.library.files.delete(file)
83 | # except:
84 | # continue
85 |
86 |         st.write("Files removed successfully.")
87 |
88 |     st.markdown("**Ask a question about the uploaded documents:**")
89 |
90 | question = st.chat_input(DOC_QA)
91 | if question:
92 | messages=[ChatMessage(content=question, role="user")]
93 | response = client.beta.conversational_rag.create(messages=messages, label=label)
94 | if response.choices[0].content is None:
95 | st.write("I'm sorry, I cannot answer your questions based on the documents I have access to.")
96 | else:
97 | st.write(response.choices[0].content)
98 |
99 |
--------------------------------------------------------------------------------
/pages/1_Blog_Post_Generator.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import numpy as np
3 | import asyncio
4 | from constants import DEFAULT_MODEL
5 | from utils.studio_style import apply_studio_style
6 | import argparse
7 | from utils.completion import async_complete
8 | from utils.completion import paraphrase_req
9 | from constants import client
10 |
11 | st.set_page_config(
12 | page_title="Blog Post Generator",
13 | )
14 |
15 |
16 | def get_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--port',
19 | type=int,
20 | default=8888)
21 |
22 | parser.add_argument('--num_results',
23 | type=int,
24 | default=5)
25 |
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def build_prompt(title, sections, section_heading):
31 | sections_text = '\n'.join(sections)
32 | prompt = f"Write a descriptive section in a blog post according to the following details.\n\nBlog Title:\n{title}\n\nBlog Sections:\n{sections_text}\n\nCurrent Section Heading:\n{section_heading}\n\nCurrent Section Text:\n"
33 | return prompt
34 |
35 |
36 | def generate_sections_content(num_results, sections, title):
37 |
38 | loop = asyncio.new_event_loop()
39 | asyncio.set_event_loop(loop)
40 |
41 | config = {
42 | "numResults": num_results,
43 | "maxTokens": 256,
44 | "minTokens": 10,
45 | "temperature": 0.7,
46 | "topKReturn": 0,
47 | "topP": 1,
48 | "stopSequences": []
49 | }
50 | group = asyncio.gather(*[async_complete(DEFAULT_MODEL, build_prompt(title, sections, s), config) for s in sections])
51 | results = loop.run_until_complete(group)
52 | loop.close()
53 | return results
54 |
55 |
56 | def build_generate_outline(title):
57 | return lambda: generate_outline(title)
58 |
59 |
60 | def generate_outline(title):
61 | st.session_state['show_outline'] = True
62 | st.session_state['show_sections'] = False
63 |
64 | res = _generate_outline(title)
65 |
66 | st.session_state["outline"] = res.completions[0].data.text.strip()
67 |
68 |
69 | def _generate_outline(title):
70 | prompt = f"Write sections to a great blog post for the following title.\nBlog title: How to start a personal blog \nBlog sections:\n1. Pick a personal blog template\n2. Develop your brand\n3. Choose a hosting plan and domain name\n4. Create a content calendar \n5. Optimize your content for SEO\n6. Build an email list\n7. Get the word out\n\n##\n\nWrite sections to a great blog post for the following title.\nBlog title: A real-world example on Improving JavaScript performance\nBlog sections:\n1. Why I needed to Improve my JavaScript performance\n2. Three common ways to find performance issues in Javascript\n3. How I found the JavaScript performance issue using console.time\n4. How does lodash cloneDeep work?\n5. What is the alternative to lodash cloneDeep?\n6. Conclusion\n\n##\n\nWrite sections to a great blog post for the following title.\nBlog title: Is a Happy Life Different from a Meaningful One?\nBlog sections:\n1. Five differences between a happy life and a meaningful one\n2. What is happiness, anyway?\n3. Is the happiness without pleasure?\n4. Can you have it all?\n\n##\n\nWrite sections to a great blog post for the following title.\nBlog title: {title}\nBlog Sections:\n"
71 |
72 | res = client.completion.create(
73 | model=DEFAULT_MODEL,
74 | prompt=prompt,
75 | num_results=1,
76 | max_tokens=296,
77 | temperature=0.84,
78 | top_k_return=0,
79 | top_p=1,
80 | stop_sequences=["##"]
81 | )
82 | return res
83 |
84 |
85 | def generate_sections():
86 | st.session_state['show_sections'] = True
87 |
88 |
89 | def build_on_next_click(section_heading, section_index, completions, arg_sorted_by_length):
90 | return lambda: on_next_click(section_heading, section_index, completions, arg_sorted_by_length)
91 |
92 |
93 | def on_next_click(section_heading, section_index, completions, arg_sorted_by_length):
94 | st.session_state['show_paraphrase'][section_heading] = False
95 | new_comp_index = (st.session_state['generated_sections_data'][section_heading]["text_area_index"] + 1) % 5
96 | section_i_text = completions[arg_sorted_by_length[new_comp_index]]["data"]["text"]
97 | st.session_state['generated_sections_data'][section_heading]["text_area_index"] = new_comp_index
98 | st.session_state['generated_sections_data'][section_heading]["text_area_data"].text_area(label=section_heading,
99 | height=300,
100 | value=section_i_text,
101 | key=section_index)
102 |
103 |
104 | def build_on_prev_click(section_heading, section_index, completions, arg_sorted_by_length):
105 | return lambda: on_prev_click(section_heading, section_index, completions, arg_sorted_by_length)
106 |
107 |
108 | def on_prev_click(section_heading, section_index, completions, arg_sorted_by_length):
109 | st.session_state['show_paraphrase'][section_heading] = False
110 |
111 | new_comp_index = (st.session_state['generated_sections_data'][section_heading]["text_area_index"] - 1) % 5
112 | section_i_text = completions[arg_sorted_by_length[new_comp_index]]["data"]["text"]
113 | st.session_state['generated_sections_data'][section_heading]["text_area_index"] = new_comp_index
114 | st.session_state['generated_sections_data'][section_heading]["text_area_data"].text_area(label=section_heading,
115 | height=300,
116 | value=section_i_text,
117 | key=section_index)
118 |
119 |
120 | def get_event_loop(title, sections, num_results):
121 | st.session_state['show_sections'] = True
122 |
123 | for s in sections:
124 | st.session_state['generated_sections_data'][s] = {}
125 | st.session_state['show_paraphrase'][s] = False
126 |
127 | # perform request, actually generate sections
128 | results = generate_sections_content(num_results, sections, title)
129 |
136 | # rank/filter
137 | for i, s in enumerate(sections):
138 | response_json = results[i]
139 | section_completions = response_json["completions"]
140 | st.session_state['generated_sections_data'][s]["completions"] = section_completions
141 |
142 | lengths = []
143 | for c in range(len(section_completions)):
144 | l = len(section_completions[c]["data"]["text"])
145 | lengths.append(l)
146 |
147 | arg_sort = np.argsort(lengths)
148 |         index = 2  # show the median-length completion (of the 5 length-sorted candidates) first
149 | st.session_state['generated_sections_data'][s]["text_area_index"] = index
150 | st.session_state['generated_sections_data'][s]["arg_sort"] = arg_sort
151 |
152 | st.session_state['generated_sections_data'][s]["rewrites"] = ["" for c in range(len(section_completions))]
153 |
154 |
155 | def build_event_loop(title, section_heading, num_results):
156 | return lambda: get_event_loop(title, section_heading, num_results)
157 |
158 |
159 | def build_event_loop_one_section(title, section, num_results):
160 | return lambda: get_event_loop(title, [section], num_results)
161 |
162 |
163 | def on_outline_change():
164 | st.session_state['show_sections'] = False
165 |
166 |
167 | def paraphrase(text, tone, times):
168 | len_text = len(text)
169 | entire_text = text
170 | for i in range(times):
171 | if len_text > 500:
172 | sentences = text.split(".")
173 | else:
174 | sentences = [text]
175 |
176 | filtered_sentences = []
177 | for sentence in sentences:
178 | sentence = sentence.strip()
179 | len_sent = len(sentence)
180 | if len_sent > 1:
181 | filtered_sentences.append(sentence)
182 |
183 | loop = asyncio.new_event_loop()
184 | asyncio.set_event_loop(loop)
185 | group = asyncio.gather(*[paraphrase_req(sentence, tone) for sentence in filtered_sentences])
186 |
187 | results = loop.run_until_complete(group)
188 | loop.close()
189 |
190 | final_text = []
191 | for r in results:
192 | sugg = r["suggestions"][0]["text"]
193 | final_text.append(sugg)
194 |
195 | entire_text = ". ".join(final_text)
196 | entire_text = (entire_text + ".").replace(",.", ".").replace("?.", ".").replace("..", ".")
197 |
198 | text = entire_text
199 | len_text = len(entire_text)
200 |
201 | return entire_text
202 |
203 |
204 | def on_paraphrase_click(s, tone, times):
205 |
206 | all_sections_data = st.session_state['generated_sections_data']
207 |
208 | index = st.session_state['generated_sections_data'][s]["text_area_index"]
209 | section_completions = all_sections_data[s]["completions"]
210 |
211 | sec_text = section_completions[index]["data"]["text"]
212 | paraphrased_section = paraphrase(sec_text, tone, times)
213 |
214 | st.session_state['generated_sections_data'][s]["rewrites"][index] = paraphrased_section
215 | st.session_state['show_paraphrase'][s] = True
216 |
217 |
218 | def build_paraphrase(s, tone, times):
219 | return lambda: on_paraphrase_click(s, tone, times)
220 |
221 |
222 | def on_heading_change():
223 | st.session_state['show_sections'] = False
224 |
225 |
226 | def on_title_change():
227 | st.session_state['show_sections'] = False
228 |
229 |
230 | if __name__ == '__main__':
231 | args = get_args()
232 | apply_studio_style()
233 | num_results = args.num_results
234 |
235 | # Initialization
236 | if 'show_outline' not in st.session_state:
237 | st.session_state['show_outline'] = False
238 |
239 | if 'show_sections' not in st.session_state:
240 | st.session_state['show_sections'] = False
241 |
242 | if 'show_paraphrase' not in st.session_state:
243 | st.session_state['show_paraphrase'] = {}
244 |
245 | if 'generated_sections_data' not in st.session_state:
246 | st.session_state['generated_sections_data'] = {}
247 |
248 | st.title("Blog Post Generator")
249 | st.markdown("Using only a title, you can instantly generate an entire article with the click of a button! Simply select your topic and this tool will create an engaging article from beginning to end.")
250 | st.markdown("#### Blog Title")
251 | title = st.text_input(label="Write the title of your article here:", placeholder="",
252 | value="5 Strategies to overcome writer's block").strip()
253 | st.markdown("#### Blog Outline")
254 | st.text("Click the button to generate the blog outline")
255 | st.button(label="Generate Outline", on_click=build_generate_outline(title))
256 |
257 | sections = []
258 | if st.session_state['show_outline']:
259 | text_area_outline = st.text_area(label=" ", height=250, value=st.session_state["outline"],
260 | on_change=on_outline_change)
261 | sections = text_area_outline.split("\n")
262 | st.text("Unsatisfied with the generated outline? Click the 'Generate Outline' button again to re-generate it, or edit it inline.")
263 |
264 | st.markdown("#### Blog Sections")
265 | st.text("Click the button to effortlessly generate an outline for your blog post:")
266 | st.button(label="Generate Sections", on_click=build_event_loop(title, sections, num_results))
267 |
268 | if st.session_state['show_sections']:
269 | st.markdown(f"**{title}**")
270 | for s in sections:
271 | st.session_state['generated_sections_data'][s]["text_area_data"] = st.empty()
272 | st.session_state['generated_sections_data'][s]["cols"] = st.empty()
273 |
274 | all_sections_data = st.session_state['generated_sections_data']
275 | for i, s in enumerate(st.session_state['generated_sections_data'].keys()):
276 | index = st.session_state['generated_sections_data'][s]["text_area_index"]
277 | section_completions = all_sections_data[s]["completions"]
278 | arg_sort = st.session_state['generated_sections_data'][s]["arg_sort"]
279 |
280 | section_text_area_value = st.session_state['generated_sections_data'][s]["rewrites"][index] if st.session_state['show_paraphrase'][s] == True else section_completions[
281 | index]["data"]["text"]
282 | section_i_text = st.session_state['generated_sections_data'][s]["text_area_data"].text_area(label=s,
283 | height=300,
284 | value=section_text_area_value,
285 | key="generated-section"+s)
286 | st.session_state['generated_sections_data'][s]["completions"][index]["data"]["text"] = section_i_text
287 | col1, col2, col3, col4, col5, col6 = st.session_state['generated_sections_data'][s]["cols"].columns(
288 | [0.2, 0.2, 0.06, 0.047, 0.05, 0.4])
289 |
290 | with col1:
291 | st.button("Generate Again", on_click=build_event_loop_one_section(title, s, num_results),
292 | key="generate-again-" + s)
293 |
294 | with col2:
295 | st.button("Paraphrase", on_click=build_paraphrase(s, tone="general", times=1),
296 | key="paraphrase-button-" + s)
297 |
298 | with col3:
299 | st.button("<", on_click=build_on_prev_click(s, i, section_completions, arg_sort), key="<" + s)
300 |
301 | with col4:
302 | st.text(f"{index+1}/{num_results}")
303 |
304 | with col5:
305 | st.button(">", on_click=build_on_next_click(s, i, section_completions, arg_sort), key=">" + s)
306 |
--------------------------------------------------------------------------------
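The blog page's `build_*` helpers follow a closure-factory pattern: each returns a zero-argument callback with the section's arguments already bound, so Streamlit's `on_click` can invoke it later without arguments. A framework-free sketch of that pattern (the `state` dict and names are illustrative, standing in for `st.session_state`):

```python
state = {"index": {}}

def build_on_next_click(section, num_items):
    # Bind the section name and item count now; Streamlit would call the
    # returned zero-argument function when the button is clicked.
    def on_next():
        current = state["index"].get(section, 0)
        state["index"][section] = (current + 1) % num_items
    return on_next
```

Because the arguments are captured per call, each section's `<` and `>` buttons get an independent handler rather than all sharing the loop variable.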
/pages/2_Product_Description_Generator.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from constants import PRODUCT_DESCRIPTION_FEW_SHOT, DEFAULT_MODEL
3 | from utils.studio_style import apply_studio_style
4 | from constants import client
5 | from ai21.models import Penalty
6 |
7 | st.set_page_config(
8 | page_title="Product Description Generator",
9 | )
10 |
11 |
12 | def query(prompt):
13 |
14 | res = client.completion.create(
15 | model=DEFAULT_MODEL,
16 | prompt=prompt,
17 | num_results=1,
18 | max_tokens=240,
19 | temperature=1,
20 | top_k_return=0,
21 | top_p=0.98,
22 | count_penalty=Penalty(
23 | scale=0,
24 | apply_to_emojis=False,
25 | apply_to_numbers=False,
26 | apply_to_stopwords=False,
27 | apply_to_punctuation=False,
28 | apply_to_whitespaces=False,
29 | ),
30 | frequency_penalty=Penalty(
31 | scale=225,
32 | apply_to_emojis=False,
33 | apply_to_numbers=False,
34 | apply_to_stopwords=False,
35 | apply_to_punctuation=False,
36 | apply_to_whitespaces=False,
37 | ),
38 | presence_penalty=Penalty(
39 | scale=1.2,
40 | apply_to_emojis=False,
41 | apply_to_numbers=False,
42 | apply_to_stopwords=False,
43 | apply_to_punctuation=False,
44 | apply_to_whitespaces=False,
45 | )
46 | )
47 |
48 | return res.completions[0].data.text
49 |
50 |
51 | if __name__ == '__main__':
52 |
53 | apply_studio_style()
54 | st.title("Product Description Generator")
55 | st.markdown("###### Create valuable marketing copy for product pages that describes your product and its benefits within seconds! Simply choose a fashion accessory, a few key features, and let our tool work its magic.")
56 |
57 | product_input = st.text_input("Enter the name of your product:", value="Talking Picture Oxford Flat")
58 | features = st.text_area("List your product features here:", value="- Flat shoes\n- Amazing chestnut color\n- Man made materials")
59 |
60 | prompt = PRODUCT_DESCRIPTION_FEW_SHOT + f"Product: {product_input}\nFeatures:\n{features}\nDescription:"
61 |
62 | if st.button(label="Generate Description"):
63 | st.session_state["short-form-save_results_ind"] = []
64 | with st.spinner("Loading..."):
65 | st.session_state["short-form-result"] = {
66 | "completion": query(prompt),
67 | }
68 |
69 | if "short-form-result" in st.session_state:
70 | result = st.session_state["short-form-result"]["completion"]
71 | st.text("")
72 | st.text_area("Generated Product Description", result, height=200)
73 |
--------------------------------------------------------------------------------
/pages/4_Pitch_Email_Generator.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from constants import DEFAULT_MODEL
3 | from utils.completion import tokenize
4 | from utils.studio_style import apply_studio_style
5 | import re
6 | from constants import client
7 |
8 |
9 | st.set_page_config(
10 | page_title="Pitch Email Generator",
11 | )
12 |
13 | max_tokens = 2048 - 200
14 |
15 | WORDS_LIMIT = {
16 | "pitch": (150, 200),
17 | }
18 |
19 | title_placeholder = "PetSmart Charities® Commits $100 Million to Improve Access to Veterinary Care"
20 | article_placeholder = """The inability of many families to access veterinary care is a pressing issue facing animal welfare nationally. To combat this, PetSmart Charities announced a commitment of $100 million over the next five years to help break down the geographic, cultural, language and financial barriers that prevent pets from receiving the veterinary care they need to thrive.
21 | Zora is a lovable pug whose owner knew something was wrong with her breathing, but was struggling to find a vet that would provide non-routine care at low cost. Despite the challenges of transitioning from homelessness and landing in the hospital himself, Zora’s owner found someone to take her to a free clinic offered by Ruthless Kindness, a PetSmart Charities grantee. Thanks to the care she received, Zora and her owner are now thriving together.
22 | Veterinary access impacts the animal welfare industry and individual families in every community in the country. More than 70 percent of homes in the United States now include pets, but 50 million pets in the U.S. lack even basic veterinary care, including spay/neuter surgeries, annual exams and vaccinations. Without regular veterinary care, minor pet health issues often become bigger, costlier problems; and preventable diseases can be passed on to people and other animals. Pet parents may be forced to relinquish their beloved furry family members to already overcrowded animal shelters or be forced to watch them suffer when they can't access treatment. With pets being universally recognized as beloved family members, the challenges posed by an inability to access veterinary care can have a profound impact.
23 | PetSmart Charities estimates it would cost more than $20 billion annually to bridge the gap for pets in need of veterinary care at standard veterinary prices. More needs to be done to expand availability of lower-cost services, ensure access for remote and bilingual communities and ensure there are enough veterinarians able to perform a variety of services through clinics and urgent care centers. To help lead the charge, the nonprofit is taking a leadership role in marshaling partners and stakeholders to develop and execute solutions to solving the gap in veterinary care access.
24 | "The challenges facing the veterinary care system are vast and varied and no single organization can solve them alone," said Aimee Gilbreath, president of PetSmart Charities. "Through PetSmart Charities' commitment, we plan to invest further in our partners and build new alliances to innovate solutions across the entire system — while also funding long-term solutions already in place such as low-cost vet clinics and veterinary student scholarships. We're confident this approach will produce sustainable change within the veterinary care industry. Our best friends deserve access to adequate health care like any family members."
25 |
26 | Barriers to Veterinary Care
27 | While affordability remains the most prominent barrier to veterinary care, additional challenges contribute to the current veterinary care gap, including:
28 | Veterinary Shortage: With pet ownership steadily on the rise, a 33% increase in pet healthcare service spending is expected over the next 10 years. 90 million U.S. households now include pets, but the number of nationwide veterinarians has increased by just 2.7 percent each year since 2007. To meet the growing need for veterinary care, an additional 41,000 veterinarians would be needed by 2030.
29 | Veterinary Deserts and Cultural Inclusion: Within rural and underserved regions, veterinary practices are difficult to find close-by, making trips to the veterinarian costly and sometimes impossible. Additionally, as veterinarians are currently listed among the top five least diverse professions, cultural and language divides can often occur between clients and veterinarians, discouraging some pet parents from seeking care.
30 | Economic Challenges: The cost of veterinary care has spiked 10 percent in the last year alone and amidst the ongoing housing crisis and economic uncertainty, 63 percent of pet parents find it difficult to cover surprise vet expenses.
31 | Regulatory Challenges: Nationwide, fragmented and varied veterinary regulations pose challenges to the development of easy, efficient and consistent solutions such as telemedicine.
32 | How PetSmart Charities Will Innovate Solutions
33 | Through a $100 million commitment in funding over the next five years, PetSmart Charities will take a multifaceted approach to improve access to adequate veterinary care for all pets, including:
34 | Funding solutions across the system of veterinary care – from investing in new and more affordable types of clinics to working directly with providers to help them overcome challenges in care delivery.
35 | Supporting innovative solutions such as new telehealth care and delayed payment models that reduce and help manage the cost of care for pet parents.
36 | Partnering with universities and thought leaders to research the evolving needs of pets while developing innovative, cost-effective ways to deliver care.
37 | Awarding scholarships to veterinary students pursuing community-based practices and establishing a training program for Master's-level veterinary practitioners to offer basic care at affordable prices.
38 | Expanding access to lower-cost veterinary care through sustainable nonprofit clinics.
39 | Developing community-based models led by local changemakers to improve access to veterinary care to underserved communities through an emphasis on their unique challenges.
40 | For more information on how PetSmart Charities is working to expand access to veterinary care nationwide or to help support initiatives like this for pets and their families, visit petsmartcharities.org.
41 |
42 | About PetSmart Charities®
43 | PetSmart Charities is committed to making the world a better place for pets and all who love them. Through its in-store adoption program in all PetSmart® stores across the U.S. and Puerto Rico, PetSmart Charities helps up to 600,000 pets connect with loving families each year. PetSmart Charities also provides grant funding to support organizations that advocate and care for the well-being of all pets and their families. PetSmart Charities' grants and efforts connect pets with loving homes through adoption, improve access to affordable veterinary care and support families in times of crises with access to food, shelter and disaster response. Each year, millions of generous supporters help pets in need by donating to PetSmart Charities directly at PetSmartCharities.org, while shopping at PetSmart.com, and by using the PIN pads at checkout registers inside PetSmart® stores. In turn, PetSmart Charities efficiently uses more than 90 cents of every dollar donated to fulfill its role as the leading funder of animal welfare in North America, granting more than $500 million since its inception in 1994. Independent from PetSmart LLC, PetSmart Charities is a 501(c)(3) organization that has received the Four-Star Rating from Charity Navigator for the past 18 years in a row – placing it among the top one percent of rated charities. To learn more visit www.petsmartcharities.org."""
44 |
45 |
46 | def anonymize(text):
47 |     text = re.sub(r'https?://\S+', '[URL]', text)
48 |     return re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '[EMAIL]', text)
49 |
50 |
51 | def generate(prompt, category, max_retries=2):
52 | min_length, max_length = WORDS_LIMIT[category]
53 | completions_filtered = []
54 | try_count = 0
55 | while not len(completions_filtered) and try_count < max_retries:
56 | res = client.completion.create(
57 | model=DEFAULT_MODEL,
58 | prompt=prompt,
59 | max_tokens=200,
60 | temperature=0.8,
61 | num_results=16
62 | )
63 | completions_filtered = [comp.data.text for comp in res.completions
64 | if comp.finish_reason.reason == "endoftext"
65 | and min_length <= len(comp.data.text.split()) <= max_length]
66 | try_count += 1
67 | st.session_state["completions"] = [anonymize(i) for i in completions_filtered]
68 |
69 |
70 | def on_next():
71 | st.session_state['index'] = (st.session_state['index'] + 1) % len(st.session_state['completions'])
72 |
73 |
74 | def on_prev():
75 | st.session_state['index'] = (st.session_state['index'] - 1) % len(st.session_state['completions'])
76 |
77 |
78 | def toolbar():
79 | cols = st.columns([0.2, 0.2, 0.2, 0.2, 0.2])
80 | with cols[1]:
81 |         st.button(label='<', key='prev',
82 |                   on_click=on_prev)
83 |     with cols[2]:
84 |         st.text(f"{st.session_state['index'] + 1}/{len(st.session_state['completions'])}")
85 |     with cols[3]:
86 |         st.button(label='\>', key='next',
87 |                   on_click=on_next)
88 |
89 |
90 | if __name__ == '__main__':
91 | apply_studio_style()
92 |     st.title("Pitch Email Generator")
93 |
94 | st.session_state['title'] = st.text_input(label="Title", value=title_placeholder).strip()
95 | st.session_state['article'] = st.text_area(label="Article", value=article_placeholder, height=500).strip()
96 |
97 | domain = st.radio(
98 | "Select domain of reporter 👉",
99 | options=['Technology', 'Healthcare', 'Venture Funding', 'Other'],
100 | )
101 |
102 | if domain == 'Other':
103 | instruction = "Write a pitch to reporters persuading them why they should write about this for their publication."
104 | else:
105 | instruction = f"Write a pitch to reporters that cover {domain} stories persuading them why they should write about this for their publication."
106 | suffix = "Email Introduction"
107 | prompt = f"{instruction}\nTitle: {st.session_state['title']}\nPress Release:\n{st.session_state['article']}\n\n{suffix}:\n"
108 | category = 'pitch'
109 |
110 | if st.button(label="Compose"):
111 | with st.spinner("Loading..."):
112 | num_tokens = tokenize(prompt)
113 | if num_tokens > max_tokens:
114 |                 st.write("Text is too long. The prompt and completion are limited to 2048 tokens combined. Please try a shorter text.")
115 | if 'completions' in st.session_state:
116 | del st.session_state['completions']
117 | else:
118 | generate(prompt, category=category)
119 | st.session_state['index'] = 0
120 |
121 | if 'completions' in st.session_state:
122 | if len(st.session_state['completions']) == 0:
123 | st.write("Please try again 😔")
124 |
125 | else:
126 | curr_text = st.session_state['completions'][st.session_state['index']]
127 |             st.subheader('Generated Email')
128 | st.text_area(label=" ", value=curr_text.strip(), height=400)
129 | st.write(f"Number of words: {len(curr_text.split())}")
130 | if len(st.session_state['completions']) > 1:
131 | toolbar()
132 |
--------------------------------------------------------------------------------
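The `anonymize` helper above strips contact details from generated emails before they are shown. A standalone sketch of the same pattern, with a tightened email regex (the exact patterns here are illustrative, not the repo's verbatim code):

```python
import re

def anonymize(text):
    # Replace URLs first, then email addresses; both become fixed placeholders.
    text = re.sub(r'https?://\S+', '[URL]', text)
    return re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '[EMAIL]', text)
```

Note that a character class like `[.-_]` would be read as a range from `.` to `_`, which is why the alternatives above are ordered so `-` sits at the end of each class.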
/pages/5_Social_Media_Generator.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | from utils.filters import *
3 | from utils.studio_style import apply_studio_style
4 | from constants import client
5 |
6 |
7 | def create_prompt(media, article):
8 | post_type = "tweet" if media == "Twitter" else "Linkedin post"
9 | instruction = f"Write a {post_type} touting the following press release."
10 | return f"{instruction}\nArticle:\n{article}\n\n{post_type}:\n"
11 |
12 |
13 | def generate(article, media, max_retries=2, top=3):
14 | prompt = create_prompt(media, article)
15 | completions_filtered = []
16 | try_count = 0
17 | while not len(completions_filtered) and try_count < max_retries:
18 | res = client.completion.create(
19 | model=DEFAULT_MODEL,
20 | prompt=prompt,
21 | max_tokens=200,
22 | temperature=0.8,
23 | num_results=16
24 | )
25 | completions_filtered = [comp.data.text for comp in res.completions
26 | if apply_filters(comp, article, media)]
27 | try_count += 1
28 | res = filter_duplicates(completions_filtered)[:top]
29 | return [remove_utf_emojis(anonymize(i)) for i in res]
30 |
31 |
32 | def on_next():
33 | st.session_state['index'] = (st.session_state['index'] + 1) % len(st.session_state['completions'])
34 |
35 |
36 | def on_prev():
37 | st.session_state['index'] = (st.session_state['index'] - 1) % len(st.session_state['completions'])
38 |
39 |
40 | def toolbar():
41 | cols = st.columns([0.35, 0.1, 0.1, 0.1, 0.35])
42 | with cols[1]:
43 | st.button(label='<', key='prev', on_click=on_prev)
44 | with cols[2]:
45 | st.text(f"{st.session_state['index'] + 1}/{len(st.session_state['completions'])}")
46 | with cols[3]:
47 | st.button(label="\>", key='next', on_click=on_next)
48 | with cols[4]:
49 |         st.button(label="🔄", on_click=compose)
50 |
51 |
52 | def extract():
53 | with st.spinner("Summarizing article..."):
54 | try:
55 | st.session_state['article'] = client.summarize.create(source=st.session_state['url'], source_type='URL').summary
56 |         except Exception:
57 | st.session_state['article'] = False
58 |
59 |
60 | def compose():
61 | with st.spinner("Generating post..."):
62 | st.session_state["completions"] = generate(st.session_state['article'], media=st.session_state['media'])
63 | st.session_state['index'] = 0
64 |
65 |
66 | if __name__ == '__main__':
67 | apply_studio_style()
68 | st.title("Social Media Generator")
69 |
70 | st.session_state['url'] = st.text_input(label="Enter your article URL",
71 | value=st.session_state.get('url', 'https://www.ai21.com/blog/announcing-ai21-studio-and-jurassic-1')).strip()
72 |
73 | if st.button(label='Summarize'):
74 | extract()
75 |
76 | if 'article' in st.session_state:
77 | if not st.session_state['article']:
78 | st.write("This article is not supported, please try another one")
79 |
80 | else:
81 | st.text_area(label='Summary', value=st.session_state['article'], height=200)
82 |
83 | st.session_state['media'] = st.radio(
84 | "Compose a post for this article for 👉",
85 | options=['Twitter', 'Linkedin'],
86 | horizontal=True
87 | )
88 |
89 |     st.button(label="Compose", on_click=compose)
90 |
91 | if 'completions' in st.session_state:
92 | if len(st.session_state['completions']) == 0:
93 | st.write("Please try again 😔")
94 |
95 | else:
96 | curr_text = st.session_state['completions'][st.session_state['index']]
97 | st.text_area(label="Your awesome generated post", value=curr_text.strip(), height=200)
98 | if len(st.session_state['completions']) > 1:
99 | toolbar()
100 |
--------------------------------------------------------------------------------
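The `generate` function above re-samples from the model until at least one completion survives the filters, bounded by `max_retries`. The retry loop in isolation (a sketch; `sample` and `accept` stand in for the API call and `apply_filters`):

```python
def generate_with_retries(sample, accept, max_retries=2):
    # Keep sampling until something passes the filter, up to max_retries rounds.
    results = []
    tries = 0
    while not results and tries < max_retries:
        results = [c for c in sample() if accept(c)]
        tries += 1
    return results
```

If every round is filtered out, the function returns an empty list, which is why the page shows "Please try again" when `completions` is empty.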
/pages/6_Intent_Classifier.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import math
3 | import streamlit as st
4 | from constants import DEFAULT_MODEL
5 | from utils.studio_style import apply_studio_style
6 | from utils.completion import async_complete, tokenize
7 | import re
8 |
9 |
10 | OTHER_THRESHOLD = 0.2
11 |
12 | st.set_page_config(
13 | page_title="Intent Classifier",
14 | )
15 |
16 |
17 | def generate_response(prompt, delay):
18 | config = {"maxTokens": 0, "temperature": 1}
19 | res = async_complete(model_type=DEFAULT_MODEL,
20 | prompt=prompt,
21 | config=config, delay=delay)
22 | return res
23 |
24 |
25 | def batch_responses(prompts, delay=0.25):
26 |     delay = delay if len(prompts) >= 5 else 0  # fewer than 5 requests fit in the per-second rate limit, so no stagger is needed
27 | loop = asyncio.new_event_loop()
28 | asyncio.set_event_loop(loop)
29 | group = asyncio.gather(*[generate_response(p, i*delay) for i, p in enumerate(prompts)])
30 | results = loop.run_until_complete(group)
31 | loop.close()
32 | return results
33 |
34 |
35 | if __name__ == "__main__":
36 | apply_studio_style()
37 | st.title("Intent Classifier")
38 | instruction = "Classify the following question into one of the following classes:"
39 | st.write(instruction)
40 |
41 | st.session_state['classes'] = st.text_area(label='Classes', value="Parking\nFood/Restaurants\nRoom Facilities\nSpa\nTransport", height=180)
42 | st.session_state['classes'] = re.sub('\n+', '\n', st.session_state['classes'])
43 |
44 | st.session_state['question'] = st.text_input(label="Question", value="How to get from the airport?")
45 |
46 | if st.button('Classify'):
47 | prompt = f"{instruction}\n{st.session_state['classes']}\n\nQuestion:\n{st.session_state['question']}\n\nClass:\n"
48 | num_tokens = tokenize(prompt)
49 | responses = batch_responses([prompt + c for c in st.session_state['classes'].split('\n')])
50 | results = {}
51 | for i, r in enumerate(responses):
52 | token_list = [t['generatedToken'] for t in r['prompt']['tokens'][num_tokens:] if t['generatedToken']['token'] != '<|newline|>']
53 | class_name = ''.join([t['token'] for t in token_list]).replace('▁', ' ')
54 | sum_logprobs = sum([t['logprob'] for t in token_list])
55 | results[class_name] = round(math.exp(sum_logprobs), 2)
56 |
57 |         results['Other'] = max(0.0, 1 - sum(results.values()))  # clamp in case the class probabilities sum above 1
58 |
59 | sorted_results = {k: v for k, v in sorted(results.items(), key=lambda x: x[1], reverse=True)}
60 | for name, prob in sorted_results.items():
61 | st.write(name, round(prob, 2))
62 |
63 |
--------------------------------------------------------------------------------
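The classifier above scores each class by summing the log-probabilities of the class-name tokens and exponentiating, which turns the sum back into a probability. The arithmetic in isolation:

```python
import math

def class_probability(token_logprobs):
    # A multi-token class name's probability is the product of its per-token
    # probabilities, i.e. exp(sum of their log-probabilities).
    return math.exp(sum(token_logprobs))
```

Summing in log space avoids the underflow that multiplying many small probabilities directly would cause.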
/pages/7_Topic_Classifier.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils.studio_style import apply_studio_style
3 | from constants import CLASSIFICATION_FEWSHOT, CLASSIFICATION_PROMPT, CLASSIFICATION_TITLE, CLASSIFICATION_DESCRIPTION, \
4 | DEFAULT_MODEL
5 | from constants import client
6 |
7 |
8 | st.set_page_config(
9 | page_title="Topic Classifier",
10 | )
11 |
12 |
13 | def query(prompt):
14 |
15 | res = client.completion.create(
16 | model=st.session_state['classification_model'],
17 | prompt=prompt,
18 | num_results=1,
19 | max_tokens=5,
20 | temperature=0,
21 | stop_sequences=["##"]
22 | )
23 | return res.completions[0].data.text
24 |
25 |
26 | if __name__ == '__main__':
27 |
28 | apply_studio_style()
29 | st.title("Topic Classifier")
30 | st.write("Read any interesting news lately? Let's see if our topic classifier can skim through it and identify whether its category is sports, business, world news, or science and technology.")
31 | st.session_state['classification_model'] = DEFAULT_MODEL
32 |
33 | st.text(CLASSIFICATION_PROMPT)
34 | classification_title = st.text_input(label="Title:", value=CLASSIFICATION_TITLE)
35 | classification_description = st.text_area(label="Description:", value=CLASSIFICATION_DESCRIPTION, height=100)
36 |
37 | if st.button(label="Classify"):
38 | with st.spinner("Loading..."):
39 |             classification_prompt = f"{CLASSIFICATION_PROMPT}\nTitle:\n{classification_title}\n" \
40 |                                     f"Description:\n{classification_description}\nThe topic of this article is:\n"
41 | st.session_state["classification_result"] = query(CLASSIFICATION_FEWSHOT + classification_prompt)
42 |
43 | if "classification_result" in st.session_state:
44 | st.subheader(f"Topic: {st.session_state['classification_result']}")
45 |
--------------------------------------------------------------------------------
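The topic classifier's prompt is assembled by joining the instruction, title, and description with newline separators so each field starts on its own line, matching the layout of the few-shot examples; the completion is then cut at the `##` stop sequence. A sketch of that assembly (`build_prompt` is a hypothetical helper, not part of the repo):

```python
def build_prompt(few_shot, instruction, title, description):
    # Separate each field with newlines so the model sees distinct sections
    # laid out exactly like the few-shot examples it is asked to continue.
    return (f"{few_shot}{instruction}\n"
            f"Title:\n{title}\n"
            f"Description:\n{description}\n"
            f"The topic of this article is:\n")
```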
/pages/8_Rewrite_Tool.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils.studio_style import apply_studio_style
3 | from constants import client
4 | from ai21.models import ParaphraseStyleType
5 |
6 | st.set_page_config(
7 | page_title="Rewrite Tool",
8 | )
9 |
10 |
11 | def get_suggestions(text, intent=ParaphraseStyleType.GENERAL, span_start=0, span_end=None):
12 | rewrite_resp = client.paraphrase.create(
13 | text=text,
14 | style=intent,
15 | start_index=span_start,
16 | end_index=span_end or len(text))
17 | rewritten_texts = [sug.text for sug in rewrite_resp.suggestions]
18 | st.session_state["rewrite_rewritten_texts"] = rewritten_texts
19 |
20 |
21 | def show_next(cycle_length):
22 | # From streamlit docs: "When updating Session state in response to events, a callback function gets executed first, and then the app is executed from top to bottom."
23 | # This means this function just needs to update the current index. The text itself would be shown since the entire app is executed again
24 | curr_index = st.session_state["rewrite_curr_index"]
25 | next_index = (curr_index + 1) % cycle_length
26 | st.session_state["rewrite_curr_index"] = next_index
27 |
28 |
29 | def show_prev(cycle_length):
30 | curr_index = st.session_state["rewrite_curr_index"]
31 | prev_index = (curr_index - 1) % cycle_length
32 | st.session_state["rewrite_curr_index"] = prev_index
33 |
34 |
35 | if __name__ == '__main__':
36 | apply_studio_style()
37 |
38 | st.title("Rewrite Tool")
39 | st.write("Rephrase with ease! Find fresh new ways to reword your sentences with an AI writing companion that paraphrases & rewrites any text. Select rewrite suggestions that clearly convey your ideas with a range of different tones to choose from.")
40 | text = st.text_area(label="Write your text here to see what the rewrite tool can do:",
41 | max_chars=500,
42 | placeholder="AI21 Studio is a platform that provides developers and businesses with top-tier natural language processing (NLP) solutions, powered by AI21 Labs’ state-of-the-art language models.",
43 | value="AI21 Studio is a platform that provides developers and businesses with top-tier natural language processing (NLP) solutions, powered by AI21 Labs’ state-of-the-art language models.").strip()
44 |
45 | intent = st.radio(
46 | "Set your tone 👉",
47 | key="intent",
48 | options=["general", "formal", "casual", "long", "short"],
49 | horizontal=True
50 | )
51 |
52 | st.button(label="Rewrite ✍️", on_click=lambda: get_suggestions(text, intent=intent))
53 | if "rewrite_rewritten_texts" in st.session_state:
54 | suggestions = st.session_state["rewrite_rewritten_texts"]
55 |
56 | ph = st.empty()
57 | if "rewrite_curr_index" not in st.session_state:
58 | st.session_state["rewrite_curr_index"] = 0
59 | curr_index = st.session_state["rewrite_curr_index"]
60 | ph.text_area(label="Suggestions", value=suggestions[curr_index])
61 |
62 | col1, col2, col3, *_ = st.columns([1, 1, 1, 10])
63 | with col1:
64 | st.button("<", on_click=show_prev, args=(len(suggestions),))
65 | with col2:
66 | st.markdown(f"{curr_index+1}/{len(suggestions)}")
67 | with col3:
68 | st.button("\>", on_click=show_next, args=(len(suggestions),))
69 |
--------------------------------------------------------------------------------
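The `<` / `>` navigation in this page (and in the generator pages above) relies on Python's modulo wrapping negative values back into range. The wrap-around arithmetic on its own:

```python
def cycle(index, step, length):
    # Python's % returns a value in [0, length) for positive length,
    # so stepping back from index 0 wraps to length - 1.
    return (index + step) % length
```

This is why `show_prev` needs no special case for the first suggestion: `(0 - 1) % n` is `n - 1`, not `-1` as it would be in C.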
/pages/9_Document_Summarizer.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from ai21.errors import UnprocessableEntity
3 |
4 | from utils.studio_style import apply_studio_style
5 | from constants import client, SUMMARIZATION_URL, SUMMARIZATION_TEXT
6 |
7 | st.set_page_config(
8 | page_title="Document Summarizer",
9 | )
10 |
11 | if __name__ == '__main__':
12 | apply_studio_style()
13 |
14 | st.title("Document Summarizer")
15 | st.write(
16 | "Effortlessly transform lengthy material into a focused summary. Whether it’s an article, research paper or even your own notes - this tool will sum up the key points!")
17 |     source_type = st.radio(label="Source type", options=['Text', 'URL'])
18 |     if source_type == 'Text':
19 |         source = st.text_area(label="Paste your text here:",
20 |                               height=400,
21 |                               value=SUMMARIZATION_TEXT).strip()
22 |     else:
23 |         source = st.text_input(label="Paste your URL here:",
24 |                                value=SUMMARIZATION_URL).strip()
25 |
26 |     if st.button(label="Summarize"):
27 |         with st.spinner("Loading..."):
28 |             try:
29 |                 response = client.summarize.create(source=source, source_type=source_type.upper())
30 | st.text_area(label="Summary", height=250, value=response.summary)
31 | except UnprocessableEntity:
32 |                 st.write('Text is too long for the document summarizer; please try the segment summarizer instead.')
33 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | tqdm
2 | streamlit
3 | streamlit-chat
4 | aiohttp
5 | ai21
6 | requests>=2.31.0
7 | tornado>=6.3.2
8 | pdfplumber
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with python 3.9
3 | # To update, run:
4 | #
5 | # pip-compile requirements.in
6 | #
7 | --no-binary grpcio
8 |
9 | ai21==2.15.1
10 | # via -r requirements.in
11 | aiohttp==3.10.10
12 | # via -r requirements.in
13 | aiosignal==1.3.1
14 | # via aiohttp
15 | altair==5.1.2
16 | # via streamlit
17 | async-timeout==4.0.3
18 | # via aiohttp
19 | attrs==23.1.0
20 | # via
21 | # aiohttp
22 | # jsonschema
23 | # referencing
24 | blinker==1.6.3
25 | # via streamlit
26 | cachetools==5.3.2
27 | # via streamlit
28 | certifi==2023.7.22
29 | # via requests
30 | cffi==1.16.0
31 | # via cryptography
32 | charset-normalizer==3.3.2
33 | # via
34 | # aiohttp
35 | # pdfminer-six
36 | # requests
37 | click==8.1.7
38 | # via streamlit
39 | cryptography==41.0.5
40 | # via pdfminer-six
41 | frozenlist==1.4.0
42 | # via
43 | # aiohttp
44 | # aiosignal
45 | gitdb==4.0.11
46 | # via gitpython
47 | gitpython==3.1.40
48 | # via streamlit
49 | idna==3.4
50 | # via
51 | # requests
52 | # yarl
53 | importlib-metadata==6.8.0
54 | # via streamlit
55 | jinja2==3.1.2
56 | # via
57 | # altair
58 | # pydeck
59 | jsonschema==4.19.2
60 | # via altair
61 | jsonschema-specifications==2023.7.1
62 | # via jsonschema
63 | markdown-it-py==3.0.0
64 | # via rich
65 | markupsafe==2.1.3
66 | # via jinja2
67 | mdurl==0.1.2
68 | # via markdown-it-py
69 | multidict==6.0.4
70 | # via
71 | # aiohttp
72 | # yarl
73 | numpy==1.26.1
74 | # via
75 | # altair
76 | # pandas
77 | # pyarrow
78 | # pydeck
79 | # streamlit
80 | packaging==23.2
81 | # via
82 | # altair
83 | # streamlit
84 | pandas==2.1.2
85 | # via
86 | # altair
87 | # streamlit
88 | pdfminer-six==20221105
89 | # via pdfplumber
90 | pdfplumber==0.10.3
91 | # via -r requirements.in
92 | pillow==10.1.0
93 | # via
94 | # pdfplumber
95 | # streamlit
96 | protobuf==4.24.4
97 | # via streamlit
98 | pyarrow==14.0.0
99 | # via streamlit
100 | pycparser==2.21
101 | # via cffi
102 | pydeck==0.8.0
103 | # via streamlit
104 | pygments==2.16.1
105 | # via rich
106 | pypdfium2==4.23.1
107 | # via pdfplumber
108 | python-dateutil==2.8.2
109 | # via
110 | # pandas
111 | # streamlit
112 | pytz==2023.3.post1
113 | # via pandas
114 | referencing==0.30.2
115 | # via
116 | # jsonschema
117 | # jsonschema-specifications
118 | requests==2.31.0
119 | # via
120 | # -r requirements.in
121 | # ai21
122 | # streamlit
123 | rich==13.6.0
124 | # via streamlit
125 | rpds-py==0.10.6
126 | # via
127 | # jsonschema
128 | # referencing
129 | six==1.16.0
130 | # via python-dateutil
131 | smmap==5.0.1
132 | # via gitdb
133 | streamlit==1.31.1
134 | # via
135 | # -r requirements.in
136 | # streamlit-chat
137 | streamlit-chat==0.1.1
138 | # via -r requirements.in
139 | tenacity==8.5.0
140 | # via streamlit
141 | toml==0.10.2
142 | # via streamlit
143 | toolz==0.12.0
144 | # via altair
145 | tornado==6.3.3
146 | # via
147 | # -r requirements.in
148 | # streamlit
149 | tqdm==4.66.1
150 | # via -r requirements.in
151 | typing-extensions==4.10.0
152 | # via
153 | # altair
154 | # streamlit
155 | tzdata==2023.3
156 | # via pandas
157 | tzlocal==5.2
158 | # via streamlit
159 | urllib3==2.0.7
160 | # via requests
161 | validators==0.22.0
162 | # via streamlit
163 | yarl==1.17.0
164 | # via aiohttp
165 | zipp==3.17.0
166 | # via importlib-metadata
167 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/utils/__init__.py
--------------------------------------------------------------------------------
/utils/completion.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from aiohttp import ClientSession
3 | from constants import client
4 | import streamlit as st
5 |
6 | api_key = st.secrets['api-keys']['ai21-api-key']
7 |
8 | endpoint = lambda model_type: f"https://api.ai21.com/studio/v1/{model_type}/complete"
9 |
10 |
11 | async def async_complete(model_type, prompt, config, delay=0):
12 | async with ClientSession() as session:
13 | await asyncio.sleep(delay)
14 | auth_header = f"Bearer {api_key}"
15 | res = await session.post(
16 | endpoint(model_type),
17 | headers={"Authorization": auth_header},
18 | json={"prompt": prompt, **config}
19 | )
20 | res = await res.json()
21 | return res
22 |
23 |
24 | def tokenize(text):
25 | res = client.count_tokens(text)
26 | return res
27 |
28 |
29 | async def paraphrase_req(sentence, tone):
30 | async with ClientSession() as session:
31 | res = await session.post(
32 | "https://api.ai21.com/studio/v1/paraphrase",
33 | headers={"Authorization": f"Bearer {api_key}"},
34 | json={
35 | "text": sentence,
36 | "intent": tone.lower(),
37 | "spanStart": 0,
38 | "spanEnd": len(sentence)
39 | }
40 | )
41 | res = await res.json()
42 | return res
43 |
--------------------------------------------------------------------------------
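The staggered-delay fan-out that `async_complete` builds on (several requests, each offset by a small `delay`, gathered concurrently) can be exercised without touching the network; `fake_complete` below is a hypothetical stand-in for the API round-trip, and the response shape is illustrative:

```python
import asyncio


async def fake_complete(prompt, delay=0):
    # Stand-in for async_complete: wait out the stagger, then "respond".
    await asyncio.sleep(delay)
    return {"prompt": prompt, "completions": [{"data": {"text": f"echo: {prompt}"}}]}


async def fan_out(prompts, stagger=0.01):
    # Offset each request by a multiple of `stagger` to avoid bursting the API,
    # then await all responses concurrently; gather preserves input order.
    tasks = [fake_complete(p, delay=i * stagger) for i, p in enumerate(prompts)]
    return await asyncio.gather(*tasks)


results = asyncio.run(fan_out(["a", "b", "c"]))
```

Because `asyncio.gather` returns results in submission order, the stagger changes request timing but not the order of the output list.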
/utils/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI21Labs/studio-demos/3ffdb3341a18f839c4c069202e1f99dd273ff1b9/utils/components/__init__.py
--------------------------------------------------------------------------------
/utils/components/completion_log.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import datetime
4 |
5 |
6 | def init_log(log_cols):
7 | if "completion_log" not in st.session_state:
8 | st.session_state["completion_log"] = CompletionLog(cols=log_cols)
9 |
10 |
11 | class CompletionLog:
12 |
13 | def __init__(self, cols):
14 | self.completion_log = pd.DataFrame(columns=cols)
15 |
16 | def get_completions_log(self):
17 | return self.completion_log
18 |
19 | def add_completion(self, completion_data):
20 | self.completion_log = pd.concat([self.completion_log, pd.DataFrame(completion_data)], ignore_index=True)
21 |
22 | def add_completion_button(self, completion_data):
23 | if st.button("Add item to dataset"):
24 | self.add_completion(completion_data)
25 |
26 | def remove_completion(self, ind):
27 | self.completion_log = self.completion_log.drop(ind, axis=0)
28 |
29 | def download_history(self):
30 | st.download_button(
31 | "Download Completion Log",
32 | self.completion_log.to_csv().encode('utf-8'),
33 | f"completion_log_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}.csv",
34 | "text/csv",
35 | key='download-csv'
36 | )
37 |
38 | def display(self):
39 | st.markdown('#')
40 | st.subheader("Completion Log")
41 | st.dataframe(self.completion_log)
42 | self.download_history()
43 |
--------------------------------------------------------------------------------
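Stripped of the Streamlit widgets, the log's DataFrame mechanics reduce to `pd.concat` with `ignore_index=True` for appends and a reassigned `drop` for removals; a minimal sketch, with illustrative column names and data:

```python
import pandas as pd

cols = ["prompt", "completion"]  # illustrative column names
log = pd.DataFrame(columns=cols)

# add_completion: append a row-shaped dict of lists, renumbering the index
log = pd.concat([log, pd.DataFrame({"prompt": ["Write a tagline"],
                                    "completion": ["Ship faster."]})],
                ignore_index=True)
log = pd.concat([log, pd.DataFrame({"prompt": ["x"], "completion": ["y"]})],
                ignore_index=True)

# remove_completion: drop by index label -- reassign, since drop is not in-place
log = log.drop(0, axis=0)
```

Reassigning after `drop` matters: `DataFrame.drop` returns a new frame by default, so the bare call discards its own result.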
/utils/filters.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | CHAR_LIMIT = {
4 | "Twitter": (30, 280),
5 | "Linkedin": (100, 1500),
6 | }
7 |
8 | def anonymize(text):
9 | text = re.sub(r'https?:\/\/.*', '[URL]', text)
10 | return re.sub(r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+', '[EMAIL]', text)
11 |
12 |
13 | def is_duplicate_prefix(input_text, output_text, th=0.7):
14 | input_words = input_text.strip().split()
15 | output_words = output_text.strip().split()
16 | if len(input_words) == 0 or len(output_words) == 0:
17 | return True
18 | output_prefix = output_words[:len(input_words)]
19 | overlap = set(output_prefix) & set(input_words)
20 | return len(overlap) / len(output_prefix) > th
21 |
22 |
23 | def apply_filters(completion, prompt, media):
24 | min_length, max_length = CHAR_LIMIT[media]
25 | text = completion.data.text
26 | return completion.finish_reason.reason == "endoftext" \
27 | and min_length <= len(text) <= max_length \
28 | and not is_duplicate_prefix(text, prompt) \
29 | and "[" not in text and "]" not in text
30 |
31 |
32 | def filter_duplicates(completions):
33 | results = list()
34 | for curr in completions:
35 | if not any(is_duplicate_prefix(r, curr) for r in results):
36 | results.append(curr)
37 | return results
38 |
39 |
40 | def remove_utf_emojis(s):
41 | return re.sub(r'<0x([0-9A-Fa-f]+)>', "", s)
--------------------------------------------------------------------------------
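The prefix-overlap heuristic in `is_duplicate_prefix` (compare the first `len(input)` words of the output against the input's word set) is easy to check in isolation; this is the same logic as above, restated standalone:

```python
def is_duplicate_prefix(input_text, output_text, th=0.7):
    # Treat an output as a duplicate when the fraction of its leading words
    # that also appear anywhere in the input exceeds the threshold.
    input_words = input_text.strip().split()
    output_words = output_text.strip().split()
    if len(input_words) == 0 or len(output_words) == 0:
        return True
    output_prefix = output_words[:len(input_words)]
    overlap = set(output_prefix) & set(input_words)
    return len(overlap) / len(output_prefix) > th


# A completion that re-states the prompt is flagged as a duplicate...
dup = is_duplicate_prefix("the quick brown fox", "the quick brown fox jumps")
# ...while a fresh continuation passes.
fresh = is_duplicate_prefix("the quick brown fox", "meanwhile somewhere else entirely")
```

Note the heuristic is order-insensitive within the prefix: it counts set overlap, so a reshuffled echo of the prompt is still caught.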
/utils/studio_style.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | def apply_studio_style():
4 | st.markdown(
5 | """
6 |
13 | """,
14 | unsafe_allow_html=True,
15 | )
16 |
17 | st.image("./assets/studio_logo.png")
--------------------------------------------------------------------------------