├── .streamlit
└── config.toml
├── 30days_logo.png
├── README.md
├── dashboard_utils
├── __pycache__
│ └── gui.cpython-37.pyc
└── gui.py
├── logo.png
├── requirements.txt
└── streamlit_app.py
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | primaryColor="#F63366"
3 | backgroundColor="#FFFFFF"
4 | secondaryBackgroundColor="#F0F2F6"
5 | textColor="#262730"
6 | font="sans serif"
7 |
--------------------------------------------------------------------------------
/30days_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/30days_logo.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # 🤗 Zero-shot Text Classifier
3 |
4 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/charlywargnier/zero-shot-classifier/main)
5 |
6 | Classify keyphrases fast and on-the-fly with this mighty app. No ML training needed! Create classifying labels, paste your keyphrases, and you're off! 🚀
7 |
8 | You can set these labels anything, e.g.:
9 | - `Positive`, `Negative` and `Neutral` for sentiment analysis
10 | - `Angry`, `Happy`, `Emotional` for emotion analysis
11 | - `Navigational`, `Transactional`, `Informational` for intent classification purposes
12 | - Your product range (`Bags`, `Shoes`, `Boots` etc.)
13 |
14 | You decide!
15 |
16 | ### About the app
17 |
18 | - App created by [Datachaz](https://twitter.com/DataChaz) using 🎈[Streamlit](https://streamlit.io/) and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.
19 | - Deployed on [Streamlit Cloud](https://streamlit.io/cloud) ☁️
20 | - Created as part of the [30 days of streamlit challenge](https://blog.streamlit.io/30-days-of-streamlit/)
21 |
22 |
23 |
24 | ### Questions? Comments?
25 |
26 | Please ask in the [Streamlit community](https://discuss.streamlit.io).
27 |
--------------------------------------------------------------------------------
/dashboard_utils/__pycache__/gui.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/dashboard_utils/__pycache__/gui.cpython-37.pyc
--------------------------------------------------------------------------------
/dashboard_utils/gui.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | # Import for keyboard shortcuts
4 | import streamlit.components.v1 as components
5 |
6 | def load_keyboard_class():
7 | """This class enables rendering some elements as if they were HTML `kbd` (keyboard key) elements.
8 | Without this class, `kbd` elements currently look the same as inline code in Streamlit.
9 | Usage:
10 | load_keyboard_class()
11 | st.write(' Press here ', unsafe_allow_html=True)
12 | """
13 | st.write(
14 | """""",
29 | unsafe_allow_html=True,
30 | )
31 |
32 |
33 |
34 |
35 | def keyboard_to_url(
36 | key: str = None,
37 | key_code: int = None,
38 | url: str = None,
39 | ):
40 | """Map a keyboard key to open a new tab with a given URL.
41 | Args:
42 | key (str, optional): Key to trigger (example 'k'). Defaults to None.
43 | key_code (int, optional): If key doesn't work, try hard-coding the key_code instead. Defaults to None.
44 | url (str, optional): Opens the input URL in new tab. Defaults to None.
45 | """
46 |
47 | assert not (
48 | key and key_code
49 | ), """You can not provide key and key_code.
50 | Either give key and we'll try to find its associated key_code. Or directly
51 | provide the key_code."""
52 |
53 | assert (key or key_code) and url, """You must provide key or key_code, and a URL"""
54 |
55 | if key:
56 | key_code_js_row = f"const keyCode = '{key}'.toUpperCase().charCodeAt(0);"
57 | if key_code:
58 | key_code_js_row = f"const keyCode = {key_code};"
59 |
60 | components.html(
61 | f"""
62 |
79 | """,
80 | height=0,
81 | width=0,
82 | )
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | streamlit-tags
3 | streamlit-aggrid
4 | streamlit-option-menu
5 |
--------------------------------------------------------------------------------
/streamlit_app.py:
--------------------------------------------------------------------------------
1 | # Import Streamlit and Pandas
2 | import streamlit as st
3 | import pandas as pd
4 |
5 | # Import for API calls
6 | import requests
7 |
8 | # Import for navbar
9 | from streamlit_option_menu import option_menu
10 |
11 | # Import for dynamic tagging
12 | from streamlit_tags import st_tags, st_tags_sidebar
13 |
14 | # Imports for aggrid
15 | from st_aggrid import AgGrid
16 | from st_aggrid import AgGrid
17 | import pandas as pd
18 | from st_aggrid.grid_options_builder import GridOptionsBuilder
19 | from st_aggrid.shared import JsCode
20 | from st_aggrid import GridUpdateMode, DataReturnMode
21 |
22 | # Import for keyboard shortcuts
23 | from dashboard_utils.gui import keyboard_to_url
24 | from dashboard_utils.gui import load_keyboard_class
25 |
26 | #######################################################
27 |
28 | # The code below is to control the layout width of the app.
29 | if "widen" not in st.session_state:
30 | layout = "centered"
31 | else:
32 | layout = "wide" if st.session_state.widen else "centered"
33 |
34 | #######################################################
35 |
36 | # The code below is for the title and logo.
37 | st.set_page_config(layout=layout, page_title="Zero-Shot Text Classifier", page_icon="🤗")
38 |
39 | #######################################################
40 |
41 | # The class below is for adding some formatting to the shortcut button on the left sidebar.
42 | load_keyboard_class()
43 |
44 | #######################################################
45 |
46 | # Set up session state so app interactions don't reset the app.
47 | if not "valid_inputs_received" in st.session_state:
48 | st.session_state["valid_inputs_received"] = False
49 |
50 | #######################################################
51 |
52 | # The block of code below is to display the title, logos and introduce the app.
53 |
54 | c1, c2 = st.columns([0.4, 2])
55 |
56 | with c1:
57 |
58 | st.image(
59 | "logo.png",
60 | width=110,
61 | )
62 |
63 | with c2:
64 |
65 | st.caption("")
66 | st.title("Zero-Shot Text Classifier")
67 |
68 |
69 | st.sidebar.image(
70 | "30days_logo.png",
71 | )
72 |
73 | st.write("")
74 |
75 | st.markdown(
76 | """
77 |
78 | Classify keyphrases fast and on-the-fly with this mighty app. No ML training needed!
79 |
80 | Create classifying labels (e.g. `Positive`, `Negative` and `Neutral`), paste your keyphrases, and you're off!
81 |
82 | """
83 | )
84 |
85 | st.write("")
86 |
87 | st.sidebar.write("")
88 |
89 | #######################################################
90 |
91 | # The code below is to display the menu bar.
92 | with st.sidebar:
93 | selected = option_menu(
94 | "",
95 | ["Demo (5 phrases max)", "Unlocked Mode"],
96 | icons=["bi-joystick", "bi-key-fill"],
97 | menu_icon="",
98 | default_index=0,
99 | )
100 |
101 | #######################################################
102 |
103 | # The code below is to display the shortcuts.
104 | st.sidebar.header("Shortcuts")
105 | st.sidebar.write(
106 | 'G GitHub',
107 | unsafe_allow_html=True,
108 | )
109 |
110 | st.sidebar.write(
111 | ' . GitHub Dev (VS Code)',
112 | unsafe_allow_html=True,
113 | )
114 |
115 | #######################################################
116 |
117 | # The block of code below is to display information about Streamlit.
118 |
119 | st.sidebar.markdown("---")
120 |
121 | # Sidebar
122 | st.sidebar.header("About")
123 |
124 | st.sidebar.markdown(
125 | """
126 |
127 | App created by [Datachaz](https://twitter.com/DataChaz) using 🎈[Streamlit](https://streamlit.io/) and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.
128 |
129 | """
130 | )
131 |
132 | st.sidebar.markdown(
133 | "[Streamlit](https://streamlit.io) is a Python library that allows the creation of interactive, data-driven web applications in Python."
134 | )
135 |
136 | st.sidebar.header("Resources")
137 | st.sidebar.markdown(
138 | """
139 | - [Streamlit Documentation](https://docs.streamlit.io/)
140 | - [Cheat sheet](https://docs.streamlit.io/library/cheatsheet)
141 | - [Book](https://www.amazon.com/dp/180056550X) (Getting Started with Streamlit for Data Science)
142 | - [Blog](https://blog.streamlit.io/how-to-master-streamlit-for-data-science/) (How to master Streamlit for data science)
143 | """
144 | )
145 |
146 | st.sidebar.header("Deploy")
147 | st.sidebar.markdown(
148 | "You can quickly deploy Streamlit apps using [Streamlit Cloud](https://streamlit.io/cloud) in just a few clicks."
149 | )
150 |
151 |
def main():
    """Emit an empty caption as a spacer.

    NOTE(review): the mode handling below runs at module level, not inside
    this function - confirm against the original file's indentation.
    """
    st.caption("")


# ---------------------------------------------------------------------------
# Shared configuration for the "Demo" and "Unlocked" modes.
# ---------------------------------------------------------------------------

# Endpoint of the hosted zero-shot classification model used by both modes.
API_URL = (
    "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
)

# Keyphrase caps per mode. The unlocked-mode messages tell users to tweak
# 'MAX_LINES_FULL' in the code, so that name must not be renamed.
MAX_LINES = 5
MAX_LINES_FULL = 50

DEFAULT_LABELS = ["Transactional", "Informational", "Navigational"]
LABEL_SUGGESTIONS = [
    "Informational",
    "Transactional",
    "Navigational",
    "Positive",
    "Negative",
    "Neutral",
]

# Pre-filled content of the keyphrase text area, one phrase per line.
SAMPLE_KEYPHRASES = "\n".join(
    [
        "I want to buy something in this store",
        "How to ask a question about a product",
        "Request a refund through the Google Play store",
        "I have a broken screen, what should I do?",
        "Can I have the link to the product?",
    ]
)

HELP_PREFIX = "At least two keyphrases for the classifier to work, one per line, "


def _parse_keyphrases(text, limit):
    """Turn the raw text-area content into the list of phrases to classify.

    Lines are stripped, empty lines are dropped and duplicates are removed
    while keeping first-seen order.

    Args:
        text: Raw multi-line string entered by the user.
        limit: Maximum number of phrases to keep.

    Returns:
        A ``(phrases, truncated)`` tuple; ``truncated`` is True when more than
        ``limit`` unique phrases were supplied and the tail was cut off.
    """
    stripped = (line.strip() for line in text.split("\n"))
    unique = list(dict.fromkeys(line for line in stripped if line))
    return unique[:limit], len(unique) > limit


def _form_inputs(limit, help_text, limit_notice):
    """Render the label picker, keyphrase box and submit button.

    Must be called inside a ``with st.form(...)`` block.

    Args:
        limit: Maximum number of keyphrases accepted in the current mode.
        help_text: Tooltip shown on the keyphrase text area.
        limit_notice: Message displayed when the input exceeds ``limit``.

    Returns:
        A ``(labels, phrases, submitted)`` tuple.
    """
    labels = st_tags(
        label="",
        text="Add labels - 3 max",
        value=DEFAULT_LABELS,
        suggestions=LABEL_SUGGESTIONS,
        maxtags=3,
    )

    text = st.text_area(
        "Enter keyphrases to classify",
        SAMPLE_KEYPHRASES,
        height=200,
        key="2",
        help=help_text,
    )

    phrases, truncated = _parse_keyphrases(text, limit)
    if truncated:
        st.info(limit_notice)

    submitted = st.form_submit_button(label="Submit")
    return labels, phrases, submitted


def _stop_unless_ready(submitted, labels, phrases):
    """Halt the script run unless there is a valid request to process.

    Nothing is processed until the form has been submitted at least once
    (tracked in ``st.session_state.valid_inputs_received`` so that later
    reruns, e.g. toggling "Widen layout", keep showing results).

    The emptiness check is done on the parsed ``phrases`` rather than on the
    raw text, so input made only of blank lines is rejected here instead of
    failing later on a missing "scores" column.
    """
    if not submitted and not st.session_state.valid_inputs_received:
        st.stop()

    if not phrases:
        problem = "❄️ There is no keyphrases to classify"
    elif not labels:
        problem = "❄️ You have not added any labels, please add some! "
    elif len(labels) == 1:
        problem = "❄️ Please make sure to add at least two labels for classification"
    else:
        problem = None

    if problem is not None:
        st.warning(problem)
        st.session_state.valid_inputs_received = False
        st.stop()

    st.session_state.valid_inputs_received = True


def _classify(phrases, labels, headers):
    """Classify each phrase against ``labels`` through the inference API.

    Args:
        phrases: Keyphrases to classify, sent one request per phrase.
        labels: Candidate labels for the zero-shot model.
        headers: HTTP headers carrying the bearer token.

    Returns:
        The list of JSON payloads returned by the API, one per phrase.

    Raises:
        ValueError: If a response is not valid JSON or does not contain a
            "scores" entry (for example an error payload returned for an
            invalid API key).
    """
    results = []
    for phrase in phrases:
        response = requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": phrase,
                "parameters": {"candidate_labels": labels},
                "options": {"wait_for_model": True},
            },
        )
        payload = response.json()
        # Fail explicitly on anything that is not a classification result
        # instead of letting it surface later as an obscure pandas error.
        if not isinstance(payload, dict) or "scores" not in payload:
            raise ValueError(f"Unexpected API response: {payload!r}")
        results.append(payload)
    return results


def _show_results(results, *, pad_grid=False, pad_download=False):
    """Display the classification results grid and the CSV download button.

    Args:
        results: API payloads as returned by ``_classify``.
        pad_grid: Emit an empty caption as a spacer above the grid.
        pad_download: Emit an empty caption as a spacer above the download
            button.
    """
    st.success("✅ Done!")

    df = pd.DataFrame.from_dict(results)

    st.caption("")
    st.markdown("### Check classifier results")
    st.caption("")

    st.checkbox(
        "Widen layout",
        key="widen",
        help="Tick this box to toggle the layout to 'Wide' mode",
    )

    if pad_grid:
        st.caption("")

    # Show the scores as percentages in a dedicated column.
    df["classification scores"] = [
        [f"{score:.2%}" for score in row] for row in df["scores"]
    ]
    df.drop("scores", inplace=True, axis=1)
    df.rename(columns={"sequence": "keyphrase"}, inplace=True)

    # Ag-grid configuration: pivoting enabled on all columns.
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_default_column(
        enablePivot=True, enableValue=True, enableRowGroup=True
    )
    gb.configure_selection(selection_mode="multiple", use_checkbox=True)
    gb.configure_side_bar()

    AgGrid(
        df,
        gridOptions=gb.build(),
        enable_enterprise_modules=True,
        update_mode=GridUpdateMode.MODEL_CHANGED,
        data_return_mode=DataReturnMode.FILTERED_AND_SORTED,
        height=300,
        fit_columns_on_grid_load=False,
        configure_side_bar=True,
    )

    download_col, _ = st.columns([2, 2])

    with download_col:
        # The table holds at most MAX_LINES_FULL rows, so the CSV is built
        # directly; the deprecated st.cache wrapper is not needed.
        csv = df.to_csv().encode("utf-8")

        if pad_download:
            st.caption("")

        st.download_button(
            label="Download results as CSV",
            data=csv,
            file_name="results.csv",
            mime="text/csv",
        )


if selected == "Demo (5 phrases max)":

    # The demo uses the app owner's key from the Streamlit secrets store.
    API_KEY = st.secrets["API_KEY"]

    with st.form(key="my_form"):
        labels, phrases, submitted = _form_inputs(
            MAX_LINES,
            help_text=f"{HELP_PREFIX}{MAX_LINES} keyphrases max as part of the demo",
            limit_notice=(
                f"❄️ Only the first {MAX_LINES} keyprases will be reviewed. "
                "Unlock that limit by switching to 'Unlocked Mode'"
            ),
        )

    _stop_unless_ready(submitted, labels, phrases)

    try:
        results = _classify(phrases, labels, {"Authorization": f"Bearer {API_KEY}"})
    except ValueError:
        st.warning(
            "❄️ The classifier API did not return valid results, please try again later"
        )
        st.stop()

    _show_results(results, pad_grid=True)

elif selected == "Unlocked Mode":

    with st.form(key="my_form"):
        # In unlocked mode users bring their own HuggingFace token.
        API_KEY2 = st.text_input(
            "Enter your 🤗 HuggingFace API key",
            help="Once you created you HuggiginFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
        )

        labels, phrases, submitted = _form_inputs(
            MAX_LINES_FULL,
            help_text=(
                f"{HELP_PREFIX}{MAX_LINES_FULL} keyphrases max in 'unlocked mode'. "
                "You can tweak 'MAX_LINES_FULL' in the code to change this"
            ),
            limit_notice=(
                f"❄️ Note that only the first {MAX_LINES_FULL} keyprases will be "
                "reviewed to preserve performance. Fork the repo and tweak "
                "'MAX_LINES_FULL' in the code to increase that limit."
            ),
        )

    _stop_unless_ready(submitted, labels, phrases)

    try:
        results = _classify(phrases, labels, {"Authorization": f"Bearer {API_KEY2}"})
    except ValueError:
        # The most common cause here is a missing or invalid user token.
        st.warning("❄️ Add a valid HuggingFace API key in the text box above ☝️")
        st.stop()

    _show_results(results, pad_download=True)
529 |
530 | if __name__ == "__main__":
531 | main()
532 |
533 | keyboard_to_url(
534 | key="g",
535 | url="https://github.com/CharlyWargnier/zero-shot-classifier/blob/main/streamlit_app.py",
536 | )
537 | keyboard_to_url(
538 | key_code=190,
539 | url="https://github.dev/CharlyWargnier/zero-shot-classifier/blob/main/streamlit_app.py",
540 | )
541 |
--------------------------------------------------------------------------------