├── .streamlit └── config.toml ├── 30days_logo.png ├── README.md ├── dashboard_utils ├── __pycache__ │ └── gui.cpython-37.pyc └── gui.py ├── logo.png ├── requirements.txt └── streamlit_app.py /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | primaryColor="#F63366" 3 | backgroundColor="#FFFFFF" 4 | secondaryBackgroundColor="#F0F2F6" 5 | textColor="#262730" 6 | font="sans serif" 7 | -------------------------------------------------------------------------------- /30days_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/30days_logo.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # 🤗 Zero-shot Text Classifier 3 | 4 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/charlywargnier/zero-shot-classifier/main) 5 | 6 | Classify keyphrases fast and on-the-fly with this mighty app. No ML training needed! Create classifying labels, paste your keyphrases, and you're off! 🚀 7 | 8 | You can set these labels to anything, e.g.: 9 | - `Positive`, `Negative` and `Neutral` for sentiment analysis 10 | - `Angry`, `Happy`, `Emotional` for emotion analysis 11 | - `Navigational`, `Transactional`, `Informational` for intent classification purposes 12 | - Your product range (`Bags`, `Shoes`, `Boots` etc.) 13 | 14 | You decide! 15 | 16 | ### About the app 17 | 18 | - App created by [Datachaz](https://twitter.com/DataChaz) using 🎈[Streamlit](https://streamlit.io/) and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model. 
19 | - Deployed on [Streamlit Cloud](https://streamlit.io/cloud) ☁️ 20 | - Created as part of the [30 days of streamlit challenge](https://blog.streamlit.io/30-days-of-streamlit/) 21 | 22 | 23 | 24 | ### Questions? Comments? 25 | 26 | Please ask in the [Streamlit community](https://discuss.streamlit.io). 27 | -------------------------------------------------------------------------------- /dashboard_utils/__pycache__/gui.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/dashboard_utils/__pycache__/gui.cpython-37.pyc -------------------------------------------------------------------------------- /dashboard_utils/gui.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | # Import for keyboard shortcuts 4 | import streamlit.components.v1 as components 5 | 6 | def load_keyboard_class(): 7 | """This class enables to render some elements as if they were . 8 | Without this class, currently looks the same as in Streamlit. 9 | Usage: 10 | load_keyboard_class() 11 | st.write(' Press here ', unsafe_allow_html=True) 12 | """ 13 | st.write( 14 | """""", 29 | unsafe_allow_html=True, 30 | ) 31 | 32 | 33 | 34 | 35 | def keyboard_to_url( 36 | key: str = None, 37 | key_code: int = None, 38 | url: str = None, 39 | ): 40 | """Map a keyboard key to open a new tab with a given URL. 41 | Args: 42 | key (str, optional): Key to trigger (example 'k'). Defaults to None. 43 | key_code (int, optional): If key doesn't work, try hard-coding the key_code instead. Defaults to None. 44 | url (str, optional): Opens the input URL in new tab. Defaults to None. 45 | """ 46 | 47 | assert not ( 48 | key and key_code 49 | ), """You can not provide key and key_code. 50 | Either give key and we'll try to find its associated key_code. 
Or directly 51 | provide the key_code.""" 52 | 53 | assert (key or key_code) and url, """You must provide key or key_code, and a URL""" 54 | 55 | if key: 56 | key_code_js_row = f"const keyCode = '{key}'.toUpperCase().charCodeAt(0);" 57 | if key_code: 58 | key_code_js_row = f"const keyCode = {key_code};" 59 | 60 | components.html( 61 | f""" 62 | 79 | """, 80 | height=0, 81 | width=0, 82 | ) -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CharlyWargnier/zero-shot-classifier/7d1b25ed86d8fd9d99fe02babf1b4fdd841ce109/logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | streamlit-tags 3 | streamlit-aggrid 4 | streamlit-option-menu 5 | -------------------------------------------------------------------------------- /streamlit_app.py: -------------------------------------------------------------------------------- 1 | # Import Streamlit and Pandas 2 | import streamlit as st 3 | import pandas as pd 4 | 5 | # Import for API calls 6 | import requests 7 | 8 | # Import for navbar 9 | from streamlit_option_menu import option_menu 10 | 11 | # Import for dyanmic tagging 12 | from streamlit_tags import st_tags, st_tags_sidebar 13 | 14 | # Imports for aggrid 15 | from st_aggrid import AgGrid 16 | from st_aggrid import AgGrid 17 | import pandas as pd 18 | from st_aggrid.grid_options_builder import GridOptionsBuilder 19 | from st_aggrid.shared import JsCode 20 | from st_aggrid import GridUpdateMode, DataReturnMode 21 | 22 | # Import for keyboard shortcuts 23 | from dashboard_utils.gui import keyboard_to_url 24 | from dashboard_utils.gui import load_keyboard_class 25 | 26 | ####################################################### 27 | 28 | # The code below is to control the 
# layout width of the app. The "widen" key is written by the "Widen layout"
# checkbox rendered with the results, so it is absent on the very first run.
if "widen" not in st.session_state:
    layout = "centered"
else:
    layout = "wide" if st.session_state.widen else "centered"

#######################################################

# The code below is for the title and logo.
st.set_page_config(layout=layout, page_title="Zero-Shot Text Classifier", page_icon="🤗")

#######################################################

# The class below is for adding some formatting to the shortcut button on the left sidebar.
load_keyboard_class()

#######################################################

# Set up session state so app interactions don't reset the app.
if "valid_inputs_received" not in st.session_state:
    st.session_state["valid_inputs_received"] = False

#######################################################

# Constants shared by the menu bar and by both classifier modes.

DEMO_MODE = "Demo (5 phrases max)"
UNLOCKED_MODE = "Unlocked Mode"

# Keyphrase limits. MAX_LINES_FULL is referenced by name in user-facing
# messages, so keep the name in sync with those messages if it is renamed.
MAX_LINES = 5
MAX_LINES_FULL = 50

API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"

# "wait_for_model" is requested below, so a call can legitimately take a while
# when the model is cold; still bound it so a dead connection cannot hang the app.
REQUEST_TIMEOUT_SECONDS = 120

DEFAULT_LABELS = ["Transactional", "Informational", "Navigational"]
SUGGESTED_LABELS = [
    "Informational",
    "Transactional",
    "Navigational",
    "Positive",
    "Negative",
    "Neutral",
]
SAMPLE_KEYPHRASES = [
    "I want to buy something in this store",
    "How to ask a question about a product",
    "Request a refund through the Google Play store",
    "I have a broken screen, what should I do?",
    "Can I have the link to the product?",
]

#######################################################

# The block of code below is to display the title, logos and introduce the app.

c1, c2 = st.columns([0.4, 2])

with c1:
    st.image(
        "logo.png",
        width=110,
    )

with c2:
    st.caption("")
    st.title("Zero-Shot Text Classifier")


st.sidebar.image(
    "30days_logo.png",
)

st.write("")

st.markdown(
    """

Classify keyphrases fast and on-the-fly with this mighty app. No ML training needed!

Create classifying labels (e.g. `Positive`, `Negative` and `Neutral`), paste your keyphrases, and you're off!

"""
)

st.write("")

st.sidebar.write("")

#######################################################

# The code below is to display the menu bar.
with st.sidebar:
    selected = option_menu(
        "",
        [DEMO_MODE, UNLOCKED_MODE],
        icons=["bi-joystick", "bi-key-fill"],
        menu_icon="",
        default_index=0,
    )

#######################################################

# The code below is to display the shortcuts.
# NOTE(review): these two strings are passed with unsafe_allow_html=True but
# contain no markup in this copy of the file -- the HTML appears to have been
# stripped; confirm against the repository before shipping.
st.sidebar.header("Shortcuts")
st.sidebar.write(
    'G   GitHub',
    unsafe_allow_html=True,
)

st.sidebar.write(
    ' .    GitHub Dev (VS Code)',
    unsafe_allow_html=True,
)

#######################################################

# The block of code below is to display information about Streamlit.

st.sidebar.markdown("---")

# Sidebar
st.sidebar.header("About")

st.sidebar.markdown(
    """

App created by [Datachaz](https://twitter.com/DataChaz) using 🎈[Streamlit](https://streamlit.io/) and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.

"""
)

st.sidebar.markdown(
    "[Streamlit](https://streamlit.io) is a Python library that allows the creation of interactive, data-driven web applications in Python."
)

st.sidebar.header("Resources")
st.sidebar.markdown(
    """
- [Streamlit Documentation](https://docs.streamlit.io/)
- [Cheat sheet](https://docs.streamlit.io/library/cheatsheet)
- [Book](https://www.amazon.com/dp/180056550X) (Getting Started with Streamlit for Data Science)
- [Blog](https://blog.streamlit.io/how-to-master-streamlit-for-data-science/) (How to master Streamlit for data science)
"""
)

st.sidebar.header("Deploy")
st.sidebar.markdown(
    "You can quickly deploy Streamlit apps using [Streamlit Cloud](https://streamlit.io/cloud) in just a few clicks."
)


def _unique_keyphrases(text, max_lines):
    """Split the text area content into de-duplicated, non-blank keyphrases.

    Args:
        text (str): Raw content of the text area, one keyphrase per line.
        max_lines (int): Maximum number of keyphrases to keep.

    Returns:
        tuple: ``(keyphrases, truncated)`` where ``keyphrases`` holds at most
        ``max_lines`` entries in their original order and ``truncated`` tells
        whether some keyphrases were dropped to honour the limit.
    """
    stripped = (line.strip() for line in text.split("\n"))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique = [line for line in dict.fromkeys(stripped) if line]
    return unique[:max_lines], len(unique) > max_lines


def _validation_problem(labels, keyphrases):
    """Return the warning to show for invalid inputs, or None when they are valid.

    Args:
        labels (list): Candidate labels chosen in the tag widget.
        keyphrases (list): Cleaned keyphrases to classify.
    """
    if not keyphrases:
        return "❄️ There are no keyphrases to classify"
    if not labels:
        return "❄️ You have not added any labels, please add some! "
    if len(labels) == 1:
        return "❄️ Please make sure to add at least two labels for classification"
    return None


def _classify(keyphrases, labels, api_key):
    """Classify each keyphrase against the labels through the inference API.

    Args:
        keyphrases (list): Keyphrases to classify, one API call each.
        labels (list): Candidate labels sent with every request.
        api_key (str): Token sent as a bearer ``Authorization`` header.

    Returns:
        list: One decoded JSON payload per keyphrase, in input order.

    Raises:
        ValueError: If a response is not JSON or carries no ``scores`` entry.
        requests.RequestException: If the request cannot be completed.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    results = []
    for phrase in keyphrases:
        response = requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": phrase,
                "parameters": {"candidate_labels": labels},
                "options": {"wait_for_model": True},
            },
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        payload = response.json()  # raises a ValueError subclass on non-JSON bodies
        # Anything without "scores" cannot be rendered by the results table,
        # so reject it here instead of failing later while building the grid.
        if not isinstance(payload, dict) or "scores" not in payload:
            raise ValueError(
                f"Unexpected API response (HTTP {response.status_code}): {payload!r}"
            )
        results.append(payload)
    return results


def _show_results(results):
    """Render the results grid, the layout toggle and the CSV download button.

    Args:
        results (list): Payloads returned by ``_classify``.
    """
    st.success("✅ Done!")

    df = pd.DataFrame(results)

    st.caption("")
    st.markdown("### Check classifier results")
    st.caption("")

    # The checkbox state is read at the top of the script on the next rerun.
    st.checkbox(
        "Widen layout",
        key="widen",
        help="Tick this box to toggle the layout to 'Wide' mode",
    )

    st.caption("")

    # Show the scores as percentages and use friendlier column names.
    df["classification scores"] = [
        [f"{score:.2%}" for score in scores] for scores in df["scores"]
    ]
    df = df.drop(columns="scores").rename(columns={"sequence": "keyphrase"})

    # The code below is for Ag-grid
    gb = GridOptionsBuilder.from_dataframe(df)
    # enables pivoting on all columns
    gb.configure_default_column(enablePivot=True, enableValue=True, enableRowGroup=True)
    gb.configure_selection(selection_mode="multiple", use_checkbox=True)
    gb.configure_side_bar()

    AgGrid(
        df,
        gridOptions=gb.build(),
        enable_enterprise_modules=True,
        update_mode=GridUpdateMode.MODEL_CHANGED,
        data_return_mode=DataReturnMode.FILTERED_AND_SORTED,
        height=300,
        fit_columns_on_grid_load=False,
        configure_side_bar=True,
    )

    # The code below is for the download button
    download_col, _ = st.columns([2, 2])

    with download_col:
        st.caption("")
        st.download_button(
            label="Download results as CSV",
            # The table is capped at MAX_LINES_FULL rows, so converting it on
            # every rerun is cheap and needs no cache.
            data=df.to_csv().encode("utf-8"),
            file_name="results.csv",
            mime="text/csv",
        )


def _run_classifier(*, unlocked):
    """Render the input form for one mode and, once valid, the results.

    Args:
        unlocked (bool): False for the demo (app-owned API key from the
            secrets, MAX_LINES keyphrases); True for the unlocked mode
            (user-supplied API key, MAX_LINES_FULL keyphrases).
    """
    if unlocked:
        max_lines = MAX_LINES_FULL
        help_suffix = " keyphrases max in 'unlocked mode'. You can tweak 'MAX_LINES_FULL' in the code to change this"
        limit_notice = (
            "❄️ Note that only the first "
            + str(max_lines)
            + " keyphrases will be reviewed to preserve performance. Fork the repo and tweak 'MAX_LINES_FULL' in the code to increase that limit."
        )
        bad_response_warning = "❄️ Add a valid HuggingFace API key in the text box above ☝️"
    else:
        max_lines = MAX_LINES
        help_suffix = " keyphrases max as part of the demo"
        limit_notice = (
            "❄️ Only the first "
            + str(max_lines)
            + " keyphrases will be reviewed. Unlock that limit by switching to 'Unlocked Mode'"
        )
        bad_response_warning = "❄️ The classifier returned an unexpected response, please try again in a moment"

    with st.form(key="my_form"):
        if unlocked:
            api_key = st.text_input(
                "Enter your 🤗 HuggingFace API key",
                help="Once you have created your HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
            )
        else:
            api_key = st.secrets["API_KEY"]

        labels = st_tags(
            label="",
            text="Add labels - 3 max",
            value=DEFAULT_LABELS,
            suggestions=SUGGESTED_LABELS,
            maxtags=3,
        )

        text = st.text_area(
            "Enter keyphrases to classify",
            "\n".join(SAMPLE_KEYPHRASES),
            height=200,
            key="2",
            help="At least two keyphrases for the classifier to work, one per line, "
            + str(max_lines)
            + help_suffix,
        )

        keyphrases, truncated = _unique_keyphrases(text, max_lines)
        if truncated:
            st.info(limit_notice)

        submitted = st.form_submit_button(label="Submit")

    # Nothing submitted yet and no earlier valid submission to redisplay.
    if not submitted and not st.session_state.valid_inputs_received:
        st.stop()

    problem = _validation_problem(labels, keyphrases)
    if problem:
        st.warning(problem)
        st.session_state.valid_inputs_received = False
        st.stop()

    st.session_state.valid_inputs_received = True

    # Keep the try body to the network call only, so that rendering errors are
    # not misreported as an API problem.
    try:
        results = _classify(keyphrases, labels, api_key)
    except ValueError:
        # Covers undecodable bodies and error payloads such as a rejected token.
        st.warning(bad_response_warning)
        st.stop()
    except requests.RequestException:
        st.warning("❄️ Could not reach the HuggingFace API, please try again in a moment")
        st.stop()

    _show_results(results)


def main():
    """Render the classifier page for the mode selected in the sidebar menu."""
    st.caption("")

    if selected == DEMO_MODE:
        _run_classifier(unlocked=False)
    elif selected == UNLOCKED_MODE:
        _run_classifier(unlocked=True)


if __name__ == "__main__":
    # Register the shortcuts before main(): main() halts the script with
    # st.stop() until a valid form has been submitted, which would otherwise
    # leave the shortcuts unregistered on first load.
    keyboard_to_url(
        key="g",
        url="https://github.com/CharlyWargnier/zero-shot-classifier/blob/main/streamlit_app.py",
    )
    keyboard_to_url(
        key_code=190,
        url="https://github.dev/CharlyWargnier/zero-shot-classifier/blob/main/streamlit_app.py",
    )

    main()