├── examples
│   ├── .MD
│   ├── shap_examples.png
│   ├── text_example.txt
│   └── sentiment_analysis.ipynb
├── data.zip
├── hebEMO1.png
├── heBERT_logo.png
├── requirements.txt
├── LICENSE
├── src
│   ├── spider_plot.py
│   ├── HeBERT_ner_app.py
│   ├── HebEMO_app.py
│   ├── HebEMO.py
│   └── train_model.py
└── README.md

--------------------------------------------------------------------------------
/examples/.MD:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/data.zip
--------------------------------------------------------------------------------
/hebEMO1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/hebEMO1.png
--------------------------------------------------------------------------------
/heBERT_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/heBERT_logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers==4.14.1
pyplutchik==0.0.7
torch
streamlit
localtunnel
--------------------------------------------------------------------------------
/examples/shap_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/examples/shap_examples.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Avichay Chriqui and Inbal Yahav

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/spider_plot.py:
--------------------------------------------------------------------------------
def spider_plot(df):
    '''
    Radar (spider) chart for a single-row DataFrame of scores in [0, 1],
    e.g., HebEMO emotion intensities.
    Adapted from https://www.python-graph-gallery.com/391-radar-chart-with-several-individuals
    '''
    import matplotlib.pyplot as plt
    import pandas as pd
    from math import pi

    # number of variables
    categories = list(df)
    N = len(categories)

    # We plot the first row of the data frame, repeating the first value
    # to close the circular graph:
    values = df.loc[0].values.flatten().tolist()
    values += values[:1]

    # Angle of each axis in the plot (divide the circle by the number of variables)
    angles = [n / float(N) * 2 * pi for n in range(N)]
    angles += angles[:1]

    # Initialise the spider plot (plt.subplots, not plt.subplot, so we get both
    # the figure and a polar Axes object)
    fig, ax = plt.subplots(subplot_kw={'polar': True})

    # Draw one axis per variable and add labels
    plt.xticks(angles[:-1], categories, color='grey', size=8)

    # Draw y-labels; ticks must lie inside ylim, which is (0, 1) for scores
    ax.set_rlabel_position(0)
    plt.yticks([0.25, 0.5, 0.75], ["0.25", "0.5", "0.75"], color="grey", size=7)
    plt.ylim(0, 1)

    # Plot data
    ax.plot(angles, values, linewidth=1, linestyle='solid')

    # Fill area
    ax.fill(angles, values, 'b', alpha=0.1)

    # Return the figure and axes so callers can customise or show them
    return fig, ax
--------------------------------------------------------------------------------
/src/HeBERT_ner_app.py:
--------------------------------------------------------------------------------
from transformers import pipeline
import streamlit as st
from io import StringIO
import pandas as pd

st.title("NER in Hebrew Texts")
st.write("Named-entity recognition (NER) (also known as (named) entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc. [wikipedia]")
st.write("Write Hebrew sentences in the text box below to analyze (each sentence in a different row). ")
# note: the slider is shown for parity with the other apps; the NER pipeline call below does not use it
max_length = st.slider('What is the maximum length of the text to analyze? 
\n (The longer the text, the longer the calculation time will be)', 10 | 5, 512, 20) 11 | 12 | NER = pipeline( 13 | "token-classification", 14 | model="avichr/heBERT_NER", 15 | tokenizer="avichr/heBERT_NER", 16 | ignore_labels = [], 17 | aggregation_strategy = 'simple' 18 | ) 19 | 20 | method = st.selectbox( 21 | 'What is your input?', 22 | ('File', 'Free Text')) 23 | 24 | if method == 'File': 25 | uploaded_file = st.file_uploader("Choose a file") 26 | 27 | if uploaded_file is not None: 28 | # To read file as bytes: 29 | bytes_data = uploaded_file.getvalue() 30 | 31 | # To convert to a string based IO: 32 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) 33 | 34 | # To read file as string: 35 | string_data = stringio.read() 36 | 37 | else: 38 | string_data = st.text_area(label = 'Text to analyze', value="", 39 | placeholder = ''' 40 | It was the best of times, it was the worst of times, it was 41 | the age of wisdom, it was the age of foolishness, it was 42 | the epoch of belief, it was the epoch of incredulity, it 43 | was the season of Light, it was the season of Darkness, it 44 | was the spring of hope, it was the winter of despair, (...) 45 | ''') 46 | if 'string_data' in locals(): 47 | if (string_data != '' and string_data is not None): 48 | st.write('It takes a while, be patient :)') 49 | ner_df = pd.DataFrame(NER(string_data)) 50 | @st.cache_data 51 | def convert_df(df): 52 | return df.to_csv(index=False).encode('utf-8') 53 | 54 | 55 | csv = convert_df(ner_df) 56 | 57 | st.download_button( 58 | "Press to Download", 59 | csv, 60 | "ner_df.csv", 61 | "text/csv", 62 | key='download-csv' 63 | ) 64 | 65 | st.write (ner_df) 66 | -------------------------------------------------------------------------------- /examples/text_example.txt: -------------------------------------------------------------------------------- 1 | " וביבי יצילנו מידם... הוא ולא שרף. 1. תחילה הסתה וגזענות - איבדתה את הרוב של הכנסת2.עתה נגד הדמוקרטיה - הצלחתה לאחד מולך את כל האופוזציה, כולל מבית3. צעד נוסף באותו כיוון - אשריך !" 2 | " וביבי שותק. הם לא נעמה. הם נמצאים בבידוד ולא במאסר, הם בסביבה שדואגים להם והם לא נמצאים בסכנה מיידית כלשהי, ובנוסף הם עלולים להוות סיכון לאחרים. עם כל אי הנוחות, להישאר על הספינה זה הפתרון הטוב ביותר כרגע" 3 | ובישראל חסר מסביר כי ביבי לא ייתן לאף אחד לגנוב לו את ההצגה. היועץ הלאומי של שישה נשיאים ברציפות. מיגור הסארס בארה'ב נזקף לזכותו. הוא הפילטר להכרזות ההזויות של הנשיא הנוכחי (רק מלהזכיר את השם שלו אני חוטף קורו......) 4 | " וגם הרוב עצות יעילות וטובות, מנסיון " 5 | וגם השחיקה של התלמידים. אני אחזור ולא אוכל לדאוג להורי או לסכן את עצמי. גם גיל 45 גיל מסוכן 6 | והבס שלך לא עבריין??? 3 תיקי אישומים. 7 | " והוא יעשה נזק עוד יותר חמור במשרד השיכון. מענין מה קרה לקופסת הכותב? אפילו בבדיחה לא הייתי חושב ככה על ליצמן, אפילו שאני רואה טלווזיה כל יום." 8 | " והם ממש לוגיים. אסור לגולש גלים לגלוש בים הפתוח.... לו אנשים בישראל היו ממושמעים, אז היה הגיוני לאפשר הקלות פה ושם, אבל רק היום ראיתי מלא (מעל 6 אנשים) ללא מסיכות על הפנים, כמה היו עם מסיכה בהישג יד, בכיס האחורי !!!נורא" 9 | והמסכות במקור מסין בתקווה שהיבוא לא יביא גם את הוירוס 10 | והסינים ממשיכים לחיות כרגיל . יש להם טריליוני דולרים בכספות.... רוב החוב הגולמי של ארהב שוחב בכספות הסיניות... אם זה היה נגיף יהודי....היו דורשים מאיתנו את מחזור הדם... 11 | והרכבת ? זה זמן עוד יותר טוב לרכבת שלא פעילה בכלל.. כל הכבוד זה הזמן אבל350 אלף כלי רכב בתקופת הקורונה???? 12 | " וואו ! התמונה של תל אביב מדהימה ,מריחים את העושר של ישראל . בכל זאת להרגיש ביחד. חג פסח שמח לכולכם." 13 | " וואו ! 
התמונה של תל אביב מדהימה ,מריחים את העושר של ישראל. בכל זאת להרגיש ביחד. חג פסח שמח לכולכם."
וואו שלווה את מלכה וחזקה.. מותר לשלוח למבודדים משלוח מנות?. נו תרימו ראשכם! נכון זה לא שפעת אבל היי. השד לא נורא
וואו! בהצלחה גם בשם חולי פיברו וcfs!!! . הגעתי אליה לפני יותר מעשרים שנה. זכורה לי מאז. רופאה מקצועית אנושית ומדברת עינייני ובגובה העיניים.
" וואו, מתאים לספר השיאים של גינס. מקסים ומרגש!!!! . לא יודעת מי המתורגמנית. אבל יודעת בוודאות כמה זה חשוב!!!! כל כך חשוב. יש מתורגמנית שכולם מכירים מהחדשות בטלויזיה, אפרת נגר המדהימה שהיא גם דולה... עבודת קודש אמיתית. מזל טוב יקרים. הרבה בריאות ושמחה."
" וואי ביבי הגנב,הנוכל, לוקח השוחד, מפר האמונים ומחמש את האויב. יהרגו את בני האדםחאלס זה מקרה חירום אין מקום לזכויות אדם. רק מקום לבטחון של האזרחים של יחלו ולא ימותו"
" וואלה צודק. אם לא היינו נלחצים אולי גם ליצמן היה נפטר . בגלל שלא נזהר, הוא ואשתו נדבקו בקורונהבגלל שהציבור החרדי לא נזהר, אצלם מספר החולים והמתים הגבוה ביותר - בכל תחשיב שהוא."
" ווואו היא מקסימה ומצחיקה . ואני מכירה את העולם הזה. נמאס מהזיוף והשטויות, שינצלו את הכוח שלהם, עלק כוח, בשביל דברים חשובים יותר. קרן מוכשרת חכמה ויפהפיה והיא מאלה שישרדו"
" וועדת חקירה עכשיו. לחקור לאן נעלמו חומרי ההרדמה ועוד ציוד. בראש הנתונים המתפרסם במדיה נמצא מס' הנדבקים מתחילת המגפה, צריך לשנות כך שהנתון הבולט יהיה מס' החולים המאומתים עכשיו ובהמשך אותם יתר הנתונים. כמו מס'"
אני אוהב את העולם
אני לא אוהב את העולם
אני ממש אוהב את העולם
אני לא ממש אוהב את העולם
אני ממש לא אוהב את העולם
קפה זה טעים
קפה זה לא טעים
קפה זה סבבה
קפה זה לא ממש טעים
קפה זה טעים אש
מה אוכלים היום?
--------------------------------------------------------------------------------
/src/HebEMO_app.py:
--------------------------------------------------------------------------------
from HebEMO import HebEMO
from transformers import pipeline
import streamlit as st
from io import StringIO
import pandas as pd

st.title("Emotion Recognition in Hebrew Texts")
st.write("HebEMO is a tool to detect polarity and extract emotions from Hebrew user-generated content (UGC). It was trained on a unique Covid-19-related dataset that we collected and annotated. HebEMO yielded a weighted average F1-score of 0.96 for polarity classification. Emotion detection reached F1-scores of 0.78-0.97, with the exception of *surprise*, which the model failed to capture (F1 = 0.41). More information can be found in our git: https://github.com/avichaychriqui/HeBERT")
st.write("Write Hebrew sentences in the text box below to analyze (each sentence in a different row). An additional demo can be found in the Colab notebook: https://colab.research.google.com/drive/1Jw3gOWjwVMcZslu-ttXoNeD17lms1-ff ")
max_length = st.slider('What is the maximum length of the text to analyze? 
\n (The longer the text, the longer the calculation time will be)', 11 | 5, 512, 20) 12 | 13 | 14 | option = st.selectbox( 15 | 'What do you want to analyze?', 16 | ('Sentiment', 'Emotion')) 17 | 18 | if option == 'Emotion': 19 | emotions = st.multiselect( 20 | 'Which emotion to analyze?:', 21 | ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger', 22 | 'sadness', 'disgust'], 23 | ['joy']) 24 | else: 25 | emotions = ['sentiment'] 26 | 27 | method = st.selectbox( 28 | 'What is your input?', 29 | ('File', 'Free Text')) 30 | 31 | if method == 'File': 32 | uploaded_file = st.file_uploader("Choose a file") 33 | 34 | if uploaded_file is not None: 35 | # To read file as bytes: 36 | bytes_data = uploaded_file.getvalue() 37 | 38 | # To convert to a string based IO: 39 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) 40 | 41 | # To read file as string: 42 | string_data = stringio.read() 43 | 44 | else: 45 | string_data = st.text_area(label = 'Text to analyze', value="", 46 | placeholder = ''' 47 | It was the best of times, it was the worst of times, it was 48 | the age of wisdom, it was the age of foolishness, it was 49 | the epoch of belief, it was the epoch of incredulity, it 50 | was the season of Light, it was the season of Darkness, it 51 | was the spring of hope, it was the winter of despair, (...) 52 | ''') 53 | if 'string_data' in locals(): 54 | 55 | if (string_data != '' and string_data is not None): 56 | st.write('It takes a while, be patient :)') 57 | HebEMO_model = HebEMO(device= 0, emotions = emotions) 58 | 59 | hebEMO_df = HebEMO_model.hebemo(string_data, read_lines=True, plot=False, batch_size=32, max_length = max_length) 60 | 61 | if option == 'Emotion': 62 | hebEMO = pd.DataFrame(hebEMO_df[0]) 63 | for emo in hebEMO_df.columns[1::2]: 64 | hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo])) 65 | else: 66 | hebEMO = hebEMO_df 67 | 68 | @st.cache_data 69 | def convert_df(df): 70 | return df.to_csv(index=False).encode('utf-8') 71 | 72 | 73 | csv = convert_df(hebEMO) 74 | 75 | st.download_button( 76 | "Press to Download", 77 | csv, 78 | "hebEMO-ed.csv", 79 | "text/csv", 80 | key='download-csv' 81 | ) 82 | 83 | st.write (hebEMO) 84 | -------------------------------------------------------------------------------- /src/HebEMO.py: -------------------------------------------------------------------------------- 1 | class HebEMO: 2 | def __init__(self, device=-1, emotions = ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger', 3 | 'sadness', 'disgust']): 4 | from transformers import pipeline 5 | from tqdm import tqdm 6 | self.device = device 7 | if type(emotions) == str: 8 | self.emotions = [emotions] 9 | elif type(emotions) != list: 10 | raise ValueError('emotions should be emotion as a text or list of emotions.') 11 | else: 12 | self.emotions = emotions 13 | self.hebemo_models = {} 14 | for emo in tqdm(self.emotions): 15 | if emo == 'sentiment': 16 | model_name = 'avichr/heBERT_sentiment_analysis' 17 | else: 18 | model_name = "avichr/hebEMO_"+emo 19 | self.hebemo_models[emo] = pipeline( 20 | "sentiment-analysis", 21 | model=model_name, 22 | tokenizer="avichr/heBERT", 23 | device = self.device #-1 run on CPU, else - device ID 24 | ) 25 | 26 | def hebemo(self, text = None, input_path=False, save_results=False, read_lines=False, plot=False, batch_size=32, max_length = 512, truncation=True): 27 | ''' 28 | text (str): a text or list of text to analyze 29 | input_path(str): the path to the text file (txt file, each row for different instance) 30 | ''' 
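        # Parameter notes (added for clarity; behavior as implemented below):
        #   save_results: False (default) skips saving; a directory path string saves a
        #       timestamped CSV there; any other truthy value saves to the working directory.
        #   read_lines: when `text` is a single string, split it on newlines and score each line separately.
        #   plot: if True, draw a Plutchik wheel for the first instance and return (text, axes)
        #       instead of the full results DataFrame.
        #   batch_size / max_length / truncation: passed through to the underlying transformers pipelines.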
31 | from pyplutchik import plutchik 32 | import matplotlib.pyplot as plt 33 | import pandas as pd 34 | import time 35 | import torch 36 | from tqdm import tqdm 37 | if text is None and type(input_path) is str: 38 | # read the file 39 | with open(input_path, encoding='utf8') as p: 40 | txt = p.readlines() 41 | elif text is not None and (input_path is None or input_path is False): 42 | if type(text) is str: 43 | if read_lines: 44 | txt = text.split('\n') 45 | else: 46 | txt = [text] 47 | elif type(text) is list: 48 | txt = text 49 | else: 50 | raise ValueError('text should be text or list of text.') 51 | else: 52 | raise ValueError('you should provide a text string, list of strings or text path.') 53 | # run hebEMO 54 | hebEMO_df = pd.DataFrame(txt) 55 | for emo in tqdm(self.emotions): 56 | x = self.hebemo_models[emo](txt, truncation=truncation, max_length=max_length, batch_size=batch_size) 57 | hebEMO_df = hebEMO_df.join(pd.DataFrame(x).rename(columns = {'label': emo, 'score':'confidence_'+emo})) 58 | del x 59 | torch.cuda.empty_cache() 60 | hebEMO_df = hebEMO_df.applymap(lambda x: 0 if x=='LABEL_0' else 1 if x=='LABEL_1' else x) 61 | if save_results is not False: 62 | gen_name = str(int(time.time()*1e7)) 63 | if type(save_results) is str: 64 | hebEMO_df.to_csv(save_results+'/'+gen_name+'_heEMOed.csv', encoding='utf8') 65 | else: 66 | hebEMO_df.to_csv(gen_name+'_heEMOed.csv', encoding='utf8') 67 | if plot: 68 | hebEMO = pd.DataFrame() 69 | for emo in hebEMO_df.columns[1::2]: 70 | hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo])) 71 | for i in range(0,1): 72 | ax = plutchik(hebEMO.to_dict(orient='records')[i]) 73 | print(hebEMO_df[0][i]) 74 | plt.show() 75 | return (hebEMO_df[0][i], ax) 76 | else: 77 | return (hebEMO_df) 78 | 79 | 80 | -------------------------------------------------------------------------------- /src/train_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformers import BertForSequenceClassification, BertTokenizer, BertModel, BertConfig 6 | import random 7 | 8 | 9 | class BertTrainer: 10 | def __init__(self, num_classes, max_length=15, model_name = 'bert-base-uncased', device="cpu", seed_num=42): 11 | self.tokenizer = BertTokenizer.from_pretrained(model_name) 12 | self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_classes) 13 | self.model.to(device) 14 | self.max_length = max_length 15 | self.device = device 16 | 17 | random.seed(seed_num) 18 | torch.manual_seed(seed_num) 19 | np.random.seed(seed_num) 20 | torch.cuda.manual_seed_all(seed_num) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | 25 | def tokenize(self, texts): 26 | encoded_dict = self.tokenizer.batch_encode_plus( 27 | texts, 28 | add_special_tokens=True, 29 | max_length=self.max_length, 30 | padding='max_length', 31 | truncation=True, 32 | return_attention_mask=True, 33 | return_tensors='pt', 34 | ) 35 | 36 | input_ids = encoded_dict['input_ids'].to(self.device) 37 | attention_masks = encoded_dict['attention_mask'].to(self.device) 38 | return input_ids, attention_masks 39 | 40 | def train(self, texts, labels, batch_size=32, epochs=5, learning_rate=2e-5, return_loss = False): 41 | self.model.train() 42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) 43 | 44 | Epochs_Losses = [] 45 | for epoch in range(epochs): 46 | epoch_loss = 
0
            epoch_steps = 0

            for i in range(0, len(texts), batch_size):
                optimizer.zero_grad()

                input_ids, attention_masks = self.tokenize(texts[i:i+batch_size])
                batch_labels = torch.tensor(labels[i:i+batch_size]).to(self.device)

                outputs = self.model(input_ids, attention_mask=attention_masks, labels=batch_labels)
                loss = outputs[0]

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_steps += 1

            print(f"Epoch {epoch+1} Loss: {epoch_loss / epoch_steps:.4f}")
            Epochs_Losses.append(epoch_loss / epoch_steps)

        if return_loss:
            return Epochs_Losses

    def predict(self, texts):
        self.model.eval()
        input_ids, attention_masks = self.tokenize(texts)
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_masks)
            logits = outputs[0]
            predictions = torch.argmax(logits, dim=1)

        return_dic = {}
        return_dic['predictions'] = predictions.cpu().numpy()
        return_dic['logits'] = logits.cpu().numpy()

        return return_dic

    def eval(self, texts, y_true):
        # `st` is imported at module level below, in the Streamlit app section
        from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

        predictions_outputs = self.predict(texts)
        # st.write takes no `sep` argument; write each result separately.
        # Note: the AUC below assumes binary classification (it uses the positive-class logit).
        st.write(confusion_matrix(y_true, predictions_outputs['predictions']))
        st.write(classification_report(y_true, predictions_outputs['predictions'], output_dict=True))
        st.write('auc : ', roc_auc_score(y_true, predictions_outputs['logits'][:, 1]))


from transformers import pipeline
import streamlit as st
from io import StringIO
import pandas as pd
from sklearn.model_selection import train_test_split


st.title("Train your own model for Hebrew texts classification")

model_name = st.selectbox('What is the base language model to train a model?', ('avichr/heBERT', 'avichr/Legal-heBERT_ft', 'avichr/Legal-heBERT',
                                                                                'bert-base-uncased'))

max_length = st.slider('What is the maximum length of the text to analyze? \n (The longer the text, the longer the calculation time will be)',
                       5, 512, 20)

num_classes = st.number_input('How many classes do you have in the data?', 2, 10, 2)
learning_rate = st.number_input('What is your learning rate?', min_value=1e-10, max_value=1e-1, value=5e-5, step=1e-10)
batch_size = st.slider('What is your batch size?', 2, 512, 32)
epochs = st.slider('How many epochs to learn?', 1, 50, 5)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = BertTrainer(num_classes=num_classes, max_length=max_length, device=device, model_name=model_name)

uploaded_file = st.file_uploader("Choose a file to train the model")
to_predict_file = st.file_uploader("Choose a file to predict", key='to_predict_file')

if uploaded_file is not None and to_predict_file is not None:
    df = pd.read_csv(uploaded_file)
    to_predict = pd.read_csv(to_predict_file)

    text_column = st.selectbox('What is the text column in your data?', (df.columns))
    label_column = st.selectbox('What is the label column in your data? \n labels should be integers only', (df.columns), key='_')
    text_column_to_predict = st.selectbox('What is the text column in your data to predict?', (to_predict.columns))

    df = df[[text_column, label_column]].dropna()


    test_size = st.slider('% to split test from data? 
\n The selected file will be automatically divided into training and validation splits (recommended).', 0.0, 1.0, .33)

    if 'text_column' in locals() and 'label_column' in locals():
        X_train, X_test, y_train, y_test = train_test_split(df[text_column].to_list(), df[label_column].to_list(), test_size=test_size, stratify=df[label_column], random_state=42)

    if 'X_train' in locals() and 'trainer' in locals() and st.button('Train', key='train'):
        # use the learning rate chosen in the UI (it was previously hard-coded to 5e-5)
        trainer.train(X_train, y_train, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate)
        trained = True

        st.write('model performance:')
        # eval() writes its metrics itself and returns None, so it is not wrapped in st.write
        trainer.eval(X_test, y_test)


        st.write('Predictions:')
        preds = pd.DataFrame({'text': to_predict[text_column_to_predict],
                              'predictions': trainer.predict(to_predict[text_column_to_predict].to_list())['predictions']
                              })

        @st.cache_data
        def convert_df(df):
            return df.to_csv(index=False).encode('utf-8')


        preds_df = convert_df(preds)

        st.download_button(
            "Press to Download",
            preds_df,
            "preds.csv",
            "text/csv",
            key='download-csv'
        )

        st.write(preds)
--------------------------------------------------------------------------------
/examples/sentiment_analysis.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "sentiment_analysis.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyOqltqt2sTyWBv32xl/QbV/",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "\"Open"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ydpDDgHtObR-"
      },
      "source": [
        "# !pip install shap==0.36.0\r\n",
        "# !pip install transformers\r\n",
        "# !pip install datasets"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MIDsbPBYNHOf"
      },
      "source": [
        "import shap\r\n",
        "import transformers\r\n",
        "import torch\r\n",
        "import numpy as np\r\n",
        "import scipy as sp\r\n",
        "import pandas as pd"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n8tGWji9NhDv",
        "outputId": "ae1daf52-2090-4f69-f09c-4c5a76988e04"
      },
      "source": [
        "from datasets import load_dataset\r\n",
        "dataset = load_dataset('text', data_files='text_example.txt', )\r\n"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using custom data configuration default\n",
            "Reusing dataset text (/root/.cache/huggingface/datasets/text/default-dba86d70c11ab66c/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HdWEAfR3O4AA",
        "outputId": "c1d6b49a-c668-4b93-88b5-02bba1ab5b27"
      },
      "source": [
        "dataset['train']['text'][::-5]"
93 | ], 94 | "execution_count": 4, 95 | "outputs": [ 96 | { 97 | "output_type": "execute_result", 98 | "data": { 99 | "text/plain": [ 100 | "['מה אוכלים היום?',\n", 101 | " 'קפה זה טעים',\n", 102 | " 'אני אוהב את העולם',\n", 103 | " '\" וואו, מתאים לספר השיאים של גינס. מקסים ומרגש!!!! . לא יודעת מי המתורגמנית. אבל יודעת בוודאות כמה זה חשוב!!!! כל כך חשוב. יש מתורגמנית שכולם מכירים מהחדשות בטלויזיה, אפרת נגר המדהימה שהיא גם דולה... עבודת קודש אמיתית. מזל טוב יקרים. הרבה בריאות ושמחה.\"',\n", 104 | " ' והרכבת ? זה זמן עוד יותר טוב לרכבת שלא פעילה בכלל.. כל הכבוד זה הזמן אבל350 אלף כלי רכב בתקופת הקורונה????',\n", 105 | " ' והבס שלך לא עבריין??? 3 תיקי אישומים. ',\n", 106 | " '\" וביבי יצילנו מידם... הוא ולא שרף. 1. תחילה הסתה וגזענות - איבדתה את הרוב של הכנסת2.עתה נגד הדמוקרטיה - הצלחתה לאחד מולך את כל האופוזציה, כולל מבית3. צעד נוסף באותו כיוון - אשריך !\"']" 107 | ] 108 | }, 109 | "metadata": { 110 | "tags": [] 111 | }, 112 | "execution_count": 4 113 | } 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "5jlc1kleRxp9" 120 | }, 121 | "source": [ 122 | "\r\n", 123 | "from transformers import AutoTokenizer, AutoModel, pipeline\r\n", 124 | "\r\n", 125 | "# how to use?\r\n", 126 | "sentiment_analysis = pipeline(\r\n", 127 | " \"sentiment-analysis\",\r\n", 128 | " model=\"avichr/heBERT_sentiment_analysis\",\r\n", 129 | " tokenizer=\"avichr/heBERT_sentiment_analysis\", \r\n", 130 | " return_all_scores = True, \r\n", 131 | ")\r\n" 132 | ], 133 | "execution_count": 5, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/" 141 | }, 142 | "id": "T88MJxFBR53F", 143 | "outputId": "2f273d1e-9af1-4698-8b12-c152642d8996" 144 | }, 145 | "source": [ 146 | "for i in dataset['train']['text'][0:5]:\r\n", 147 | " print(i, sentiment_analysis(i), sep='\\n')\r\n" 148 | ], 149 | "execution_count": 6, 150 | "outputs": [ 151 | { 152 | "output_type": "stream", 153 | "text": [ 154 | ], 155 | "name": "stdout" 156 | } 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "metadata": { 162 | "colab": { 163 | "base_uri": "https://localhost:8080/" 164 | }, 165 | "id": "x2yC3wdRO_k3", 166 | "outputId": "fea1da5c-2136-4fd7-c627-6203d93b93bf" 167 | }, 168 | "source": [ 169 | "\r\n", 170 | "\r\n", 171 | "# load a BERT sentiment analysis model\r\n", 172 | "tokenizer = transformers.BertTokenizerFast.from_pretrained(\"avichr/heBERT_sentiment_analysis\")\r\n", 173 | "model = transformers.BertForSequenceClassification.from_pretrained(\"avichr/heBERT_sentiment_analysis\").cuda()\r\n", 174 | "\r\n", 175 | "# define a prediction function\r\n", 176 | "def f(x):\r\n", 177 | " tv = torch.tensor([tokenizer.encode(v, pad_to_max_length=True, max_length=500, truncation=True, padding='longest') for v in x]).cuda()\r\n", 178 | " outputs = model(tv)[0].detach().cpu().numpy()\r\n", 179 | " scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T\r\n", 180 | " val = sp.special.logit(scores[:,1]) # use one vs rest logit units\r\n", 181 | " return val\r\n", 182 | "\r\n", 183 | "# build an explainer using a token masker\r\n", 184 | "explainer = shap.Explainer(f, tokenizer)\r\n", 185 | "\r\n", 186 | "# explain the model's predictions on IMDB reviews\r\n", 187 | "train_dataset = dataset['train']['text']\r\n", 188 | "shap_values = explainer(train_dataset)\r\n" 189 | ], 190 | "execution_count": 7, 191 | "outputs": [ 192 | { 193 | "output_type": "stream", 194 | "text": [ 195 | "explainers.Partition is still in an alpha state, so 
use with caution...\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ti31Ki0VPTwu"
      },
      "source": [
        "# shap.initjs()\r\n",
        "\r\n",
        "# shap.plots.text(shap_values)\r\n"
      ],
      "execution_count": 8,
      "outputs": []
    }
  ]
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HeBERT: Pre-trained BERT for Polarity Analysis and Emotion Recognition


HeBERT is a Hebrew pre-trained language model. It is based on [Google's BERT](https://arxiv.org/abs/1810.04805) architecture and uses the BERT-Base configuration.
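
For orientation, here is a minimal sketch that loads the published checkpoint and confirms the BERT-Base dimensions; it assumes only the `avichr/heBERT` checkpoint on the Hugging Face Hub, which is the one used throughout this README:

```
from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("avichr/heBERT")
# BERT-Base dimensions: hidden size 768, 12 layers, 12 attention heads
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)

tokenizer = AutoTokenizer.from_pretrained("avichr/heBERT")
model = AutoModel.from_pretrained("avichr/heBERT")
```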

HeBERT was trained on three datasets:
1. A Hebrew version of [OSCAR](https://oscar-corpus.com/): ~9.8 GB of data, including 1 billion words and over 20.8 million sentences.
2. A Hebrew dump of [Wikipedia](https://dumps.wikimedia.org/): ~650 MB of data, including over 63 million words and 3.8 million sentences.
3. Emotion User Generated Content (UGC) data collected for the purposes of this study (described below).

We evaluated the model on two downstream tasks: emotion recognition and sentiment analysis.

## Emotion UGC Data Description
Our UGC data includes comments posted on news articles collected from three major Israeli news sites between January 2020 and August 2020. The total size of the data is ~150 MB, including over 7 million words and 350K sentences.
~4,000 sentences were annotated by crowd members (3-10 annotators per sentence) for overall sentiment (polarity) and [eight emotions](https://en.wikipedia.org/wiki/Robert_Plutchik#Plutchik's_wheel_of_emotions): anger, disgust, expectation, fear, joy, sadness, surprise and trust.

For our robustness analyses, we also collected and annotated two additional datasets. The first contains a random set of comments taken from our in-domain dataset (that is, comments that were posted on Covid-related news articles). The second is a random set of comments taken from an out-of-domain dataset containing comments that were posted in response to non-Covid-related articles from the same news sites. An additional explanation can be found in section 5.1 of our article.
The percentage of sentences in which each emotion appears is shown in the table below.

| | anger | disgust | expectation | fear | happy | sadness | surprise | trust | sentiment |
|------: |------:|--------:|------------:|-----:|------:|--------:|---------:|------:|-----------|
| **Main Dataset** | 0.78 | 0.83 | 0.58 | 0.45 | 0.12 | 0.59 | 0.17 | 0.11 | 0.25 |
| **Random Comments from the Corpus** | 0.79 | 0.87 | 0.46 | 0.17 | 0.03 | 0.30 | 0.00 | 0.03 | 0.02 |
| **Out of Domain** | 0.76 | 0.89 | 0.62 | 0.10 | 0.08 | 0.36 | 0.02 | 0.13 | 0.12 |


All the datasets can be found in data.zip in this repository (where each row stands for a different annotator of a sentence). The agreed score, which we used to train and test our models, can be found in the column 'agreed score' (where we found sufficient agreement). See our article for more details on the annotation process.
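
Because the archive layout is not documented here, the following is only a hedged loading sketch: it first lists the files inside data.zip, and both the annotation file chosen and the sentence column name are assumptions to verify against the actual contents.

```
import zipfile
import pandas as pd

with zipfile.ZipFile("data.zip") as zf:
    print(zf.namelist())  # inspect the archive; picking the first file below is an assumption
    with zf.open(zf.namelist()[0]) as f:
        annotations = pd.read_csv(f)  # expected: one row per (sentence, annotator) pair

# Keep one consensus label per sentence where agreement was sufficient.
# 'agreed score' is the column described above; the 'sentence' column name is hypothetical.
# agreed = annotations.dropna(subset=['agreed score']).drop_duplicates(subset=['sentence'])
```
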
If you use our datasets, please cite us (citation details below).

## Performance
### Emotion Recognition
| emotion | f1-score | precision | recall |
|-------------|----------|-----------|----------|
| anger | 0.96 | 0.99 | 0.93 |
| disgust | 0.97 | 0.98 | 0.96 |
| expectation | 0.82 | 0.80 | 0.87 |
| fear | 0.79 | 0.88 | 0.72 |
| happy | 0.90 | 0.97 | 0.84 |
| sadness | 0.90 | 0.86 | 0.94 |
| sentiment | 0.88 | 0.90 | 0.87 |
| surprise | 0.40 | 0.44 | 0.37 |
| trust | 0.83 | 0.86 | 0.80 |

*The metrics above are for the positive class (meaning, the emotion is reflected in the text), on the main dataset.*

### Sentiment (Polarity) Analysis
| | precision | recall | f1-score |
|--------------|-----------|--------|----------|
| natural | 0.83 | 0.56 | 0.67 |
| positive | 0.96 | 0.92 | 0.94 |
| negative | 0.97 | 0.99 | 0.98 |
| accuracy | | | 0.97 |
| macro avg | 0.92 | 0.82 | 0.86 |
| weighted avg | 0.96 | 0.97 | 0.96 |

## How to use
### For Emotion Recognition Model
An online model can be found at [huggingface spaces](https://huggingface.co/spaces/avichr/HebEMO_demo) or as a [colab notebook](https://colab.research.google.com/drive/1Jw3gOWjwVMcZslu-ttXoNeD17lms1-ff?usp=sharing)
```
# !pip install pyplutchik==0.0.7
# !pip install transformers==4.14.1

!git clone https://github.com/avichaychriqui/HeBERT.git
from HeBERT.src.HebEMO import *
HebEMO_model = HebEMO()

HebEMO_model.hebemo(input_path = 'examples/text_example.txt')
# returns the analyzed pandas.DataFrame

hebEMO_df = HebEMO_model.hebemo(text='החיים יפים ומאושרים', plot=True)
```




### For masked-LM model (can be [fine-tuned to any downstream task](https://colab.research.google.com/drive/1_t-9ALMQ6WqVExJT2NEdR7x67CDaE1FP?usp=sharing))

    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained("avichr/heBERT")
    model = AutoModel.from_pretrained("avichr/heBERT")
    
    from transformers import pipeline
    fill_mask = pipeline(
        "fill-mask",
        model="avichr/heBERT",
        tokenizer="avichr/heBERT"
    )
    fill_mask("הקורונה לקחה את [MASK] ולנו לא נשאר דבר.")

### For sentiment classification model (polarity ONLY):
    from transformers import AutoTokenizer, AutoModel, pipeline
    tokenizer = AutoTokenizer.from_pretrained("avichr/heBERT_sentiment_analysis") #same as 'avichr/heBERT' tokenizer
    model = AutoModel.from_pretrained("avichr/heBERT_sentiment_analysis")
    
    # how to use? 
    sentiment_analysis = pipeline(
        "sentiment-analysis",
        model="avichr/heBERT_sentiment_analysis",
        tokenizer="avichr/heBERT_sentiment_analysis",
        return_all_scores = True
    )
    
    sentiment_analysis('אני מתלבט מה לאכול לארוחת צהריים')
    >>> [[{'label': 'natural', 'score': 0.9978172183036804},
    >>> {'label': 'positive', 'score': 0.0014792329166084528},
    >>> {'label': 'negative', 'score': 0.0007035882445052266}]]
    
    sentiment_analysis('קפה זה טעים')
    >>> [[{'label': 'natural', 'score': 0.00047328314394690096},
    >>> {'label': 'positive', 'score': 0.9994067549705505},
    >>> {'label': 'negative', 'score': 0.00011996887042187154}]]
    
    sentiment_analysis('אני לא אוהב את העולם')
    >>> [[{'label': 'natural', 'score': 9.214012970915064e-05},
    >>> {'label': 'positive', 'score': 8.876807987689972e-05},
    >>> {'label': 'negative', 'score': 0.9998190999031067}]]


Our model is also available on AWS! For more information, visit [AWS' git](https://github.com/aws-samples/aws-lambda-docker-serverless-inference/tree/main/hebert-sentiment-analysis-inference-docker-lambda)

## Named-entity recognition (NER)
The ability of the model to classify named entities in text, such as persons' names, organizations, and locations; tested on a labeled dataset from [Ben Mordecai and M Elhadad (2005)](https://www.cs.bgu.ac.il/~elhadad/nlpproj/naama/), and evaluated with F1-score.
[Colab notebook](https://colab.research.google.com/drive/18Uhq3HMxudo1XWaYZ14GI-_Op5Ds9EGo?usp=sharing)

### How to use
```
from transformers import pipeline

# how to use?
NER = pipeline(
    "token-classification",
    model="avichr/heBERT_NER",
    tokenizer="avichr/heBERT_NER",
)
NER('דויד לומד באוניברסיטה העברית שבירושלים')
```


## Contact us
[Avichay Chriqui](mailto:avichayc@mail.tau.ac.il) <br>
[Inbal Yahav](mailto:inbalyahav@tauex.tau.ac.il) <br>
The Coller Semitic Languages AI Lab <br>
Thank you, תודה, شكرا <br>

## If you use this model, please cite us:
Chriqui, A., & Yahav, I. (2022). HeBERT & HebEMO: a Hebrew BERT Model and a Tool for Polarity Analysis and Emotion Recognition. INFORMS Journal on Data Science, forthcoming.
```
@article{chriqui2021hebert,
  title={HeBERT \& HebEMO: a Hebrew BERT Model and a Tool for Polarity Analysis and Emotion Recognition},
  author={Chriqui, Avichay and Yahav, Inbal},
  journal={INFORMS Journal on Data Science},
  year={2022}
}
```
--------------------------------------------------------------------------------