├── examples
│   ├── .MD
│   ├── shap_examples.png
│   ├── text_example.txt
│   └── sentiment_analysis.ipynb
├── data.zip
├── hebEMO1.png
├── heBERT_logo.png
├── requirements.txt
├── LICENSE
├── src
│   ├── spider_plot.py
│   ├── HeBERT_ner_app.py
│   ├── HebEMO_app.py
│   ├── HebEMO.py
│   └── train_model.py
└── README.md
/examples/.MD:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/data.zip
--------------------------------------------------------------------------------
/hebEMO1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/hebEMO1.png
--------------------------------------------------------------------------------
/heBERT_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/heBERT_logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.14.1
2 | pyplutchik==0.0.7
3 | torch
4 | streamlit
5 | localtunnel
6 |
--------------------------------------------------------------------------------
/examples/shap_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avichaychriqui/HeBERT/HEAD/examples/shap_examples.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Avichay Chriqui and Inbal Yahav
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/spider_plot.py:
--------------------------------------------------------------------------------
1 | def spider_plot(df):
2 | '''
3 | retrieved from https://www.python-graph-gallery.com/391-radar-chart-with-several-individuals
4 | '''
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | from math import pi
8 |
9 |     # number of variables
10 |     categories = list(df)
11 | N = len(categories)
12 |
13 |     # We plot the first row of the data frame,
14 |     # repeating the first value to close the circular graph:
15 |     values = df.loc[0].values.flatten().tolist()
16 |     values += values[:1]
17 |
18 |
19 |     # Angle of each axis in the plot (the circle is divided by the number of variables)
20 | angles = [n / float(N) * 2 * pi for n in range(N)]
21 | angles += angles[:1]
22 |
23 |     # Initialise the spider plot (plt.subplot returns a single Axes, so use plt.subplots)
24 |     fig, ax = plt.subplots(subplot_kw={'polar': True})
25 |
26 |     # Draw one axis per variable and add labels
27 | plt.xticks(angles[:-1], categories, color='grey', size=8)
28 |
29 |     # Draw y-labels (tick positions must lie within ylim, which is (0, 1) here)
30 |     ax.set_rlabel_position(0)
31 |     plt.yticks([0.25, 0.5, 0.75], ["0.25", "0.5", "0.75"], color="grey", size=7)
32 |     plt.ylim(0, 1)
33 |
34 | # Plot data
35 | ax.plot(angles, values, linewidth=1, linestyle='solid')
36 |
37 | # Fill area
38 | ax.fill(angles, values, 'b', alpha=0.1)
39 |
40 |     # Return the figure and axes for further customization
41 |     return fig, ax
42 |
--------------------------------------------------------------------------------
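A minimal usage sketch for spider_plot (not repository code; the import path and the [0, 1] range of the values are assumptions, matching the ylim(0, 1) set inside the function):

```
import pandas as pd
from spider_plot import spider_plot  # assumes src/ is on the Python path

# The function plots the first row (df.loc[0]); values should fall within ylim(0, 1).
df = pd.DataFrame([{'joy': 0.9, 'trust': 0.6, 'fear': 0.1, 'anger': 0.2}])
fig, ax = spider_plot(df)
fig.savefig('spider.png')
```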
/src/HeBERT_ner_app.py:
--------------------------------------------------------------------------------
1 | from transformers import pipeline
2 | import streamlit as st
3 | from io import StringIO
4 | import pandas as pd
5 |
6 | st.title("NER in Hebrew Texts")
7 | st.write("Named-entity recognition (NER) (also known as (named) entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc. [wikipedia]")
8 | st.write("Write Hebrew sentences in the text box below to analyze (each sentence in a different rew). ")
9 | max_length = st.slider('What is the maximum length of the text to analyze? \n (The longer the text, the longer the calculation time will be)',
10 | 5, 512, 20)
11 |
12 | NER = pipeline(
13 | "token-classification",
14 | model="avichr/heBERT_NER",
15 | tokenizer="avichr/heBERT_NER",
16 |     ignore_labels = [],  # do not drop any labels (keeps non-entity 'O' tokens in the output)
17 |     aggregation_strategy = 'simple'  # merge sub-word tokens into whole entities
18 | )
19 |
20 | method = st.selectbox(
21 | 'What is your input?',
22 | ('File', 'Free Text'))
23 |
24 | if method == 'File':
25 | uploaded_file = st.file_uploader("Choose a file")
26 |
27 | if uploaded_file is not None:
28 | # To read file as bytes:
29 | bytes_data = uploaded_file.getvalue()
30 |
31 | # To convert to a string based IO:
32 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
33 |
34 | # To read file as string:
35 | string_data = stringio.read()
36 |
37 | else:
38 | string_data = st.text_area(label = 'Text to analyze', value="",
39 | placeholder = '''
40 | It was the best of times, it was the worst of times, it was
41 | the age of wisdom, it was the age of foolishness, it was
42 | the epoch of belief, it was the epoch of incredulity, it
43 | was the season of Light, it was the season of Darkness, it
44 | was the spring of hope, it was the winter of despair, (...)
45 | ''')
46 | if 'string_data' in locals():
47 | if (string_data != '' and string_data is not None):
48 |         st.write('This may take a while, please be patient :)')
49 | ner_df = pd.DataFrame(NER(string_data))
50 | @st.cache_data
51 | def convert_df(df):
52 | return df.to_csv(index=False).encode('utf-8')
53 |
54 |
55 | csv = convert_df(ner_df)
56 |
57 | st.download_button(
58 | "Press to Download",
59 | csv,
60 | "ner_df.csv",
61 | "text/csv",
62 | key='download-csv'
63 | )
64 |
65 | st.write(ner_df)
66 |
--------------------------------------------------------------------------------
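For reference, a minimal non-Streamlit sketch of the same NER pipeline configuration the app builds; the example sentence is taken from the README:

```
from transformers import pipeline

NER = pipeline(
    "token-classification",
    model="avichr/heBERT_NER",
    tokenizer="avichr/heBERT_NER",
    ignore_labels=[],
    aggregation_strategy='simple',
)
print(NER('דויד לומד באוניברסיטה העברית שבירושלים'))
```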
/examples/text_example.txt:
--------------------------------------------------------------------------------
1 | " וביבי יצילנו מידם... הוא ולא שרף. 1. תחילה הסתה וגזענות - איבדתה את הרוב של הכנסת2.עתה נגד הדמוקרטיה - הצלחתה לאחד מולך את כל האופוזציה, כולל מבית3. צעד נוסף באותו כיוון - אשריך !"
2 | " וביבי שותק. הם לא נעמה. הם נמצאים בבידוד ולא במאסר, הם בסביבה שדואגים להם והם לא נמצאים בסכנה מיידית כלשהי, ובנוסף הם עלולים להוות סיכון לאחרים. עם כל אי הנוחות, להישאר על הספינה זה הפתרון הטוב ביותר כרגע"
3 | ובישראל חסר מסביר כי ביבי לא ייתן לאף אחד לגנוב לו את ההצגה. היועץ הלאומי של שישה נשיאים ברציפות. מיגור הסארס בארה'ב נזקף לזכותו. הוא הפילטר להכרזות ההזויות של הנשיא הנוכחי (רק מלהזכיר את השם שלו אני חוטף קורו......)
4 | " וגם הרוב עצות יעילות וטובות, מנסיון "
5 | וגם השחיקה של התלמידים. אני אחזור ולא אוכל לדאוג להורי או לסכן את עצמי. גם גיל 45 גיל מסוכן
6 | והבס שלך לא עבריין??? 3 תיקי אישומים.
7 | " והוא יעשה נזק עוד יותר חמור במשרד השיכון. מענין מה קרה לקופסת הכותב? אפילו בבדיחה לא הייתי חושב ככה על ליצמן, אפילו שאני רואה טלווזיה כל יום."
8 | " והם ממש לוגיים. אסור לגולש גלים לגלוש בים הפתוח.... לו אנשים בישראל היו ממושמעים, אז היה הגיוני לאפשר הקלות פה ושם, אבל רק היום ראיתי מלא (מעל 6 אנשים) ללא מסיכות על הפנים, כמה היו עם מסיכה בהישג יד, בכיס האחורי !!!נורא"
9 | והמסכות במקור מסין בתקווה שהיבוא לא יביא גם את הוירוס
10 | והסינים ממשיכים לחיות כרגיל . יש להם טריליוני דולרים בכספות.... רוב החוב הגולמי של ארהב שוחב בכספות הסיניות... אם זה היה נגיף יהודי....היו דורשים מאיתנו את מחזור הדם...
11 | והרכבת ? זה זמן עוד יותר טוב לרכבת שלא פעילה בכלל.. כל הכבוד זה הזמן אבל350 אלף כלי רכב בתקופת הקורונה????
12 | " וואו ! התמונה של תל אביב מדהימה ,מריחים את העושר של ישראל . בכל זאת להרגיש ביחד. חג פסח שמח לכולכם."
13 | " וואו ! התמונה של תל אביב מדהימה ,מריחים את העושר של ישראל. בכל זאת להרגיש ביחד. חג פסח שמח לכולכם."
14 | וואו שלווה את מלכה וחזקה.. מותר לשלוח למבודדים משלוח מנות?. נו תרימו ראשכם! נכון זה לא שפעת אבל היי. השד לא נורא
15 | וואו! בהצלחה גם בשם חולי פיברו וcfs!!! . הגעתי אליה לפני יותר מעשרים שנה. זכורה לי מאז. רופאה מקצועית אנושית ומדברת עינייני ובגובה העיניים.
16 | " וואו, מתאים לספר השיאים של גינס. מקסים ומרגש!!!! . לא יודעת מי המתורגמנית. אבל יודעת בוודאות כמה זה חשוב!!!! כל כך חשוב. יש מתורגמנית שכולם מכירים מהחדשות בטלויזיה, אפרת נגר המדהימה שהיא גם דולה... עבודת קודש אמיתית. מזל טוב יקרים. הרבה בריאות ושמחה."
17 | " וואי ביבי הגנב,הנוכל, לוקח השוחד, מפר האמונים ומחמש את האויב. יהרגו את בני האדםחאלס זה מקרה חירום אין מקום לזכויות אדם. רק מקום לבטחון של האזרחים של יחלו ולא ימותו"
18 | " וואלה צודק. אם לא היינו נלחצים אולי גם ליצמן היה נפטר . בגלל שלא נזהר, הוא ואשתו נדבקו בקורונהבגלל שהציבור החרדי לא נזהר, אצלם מספר החולים והמתים הגבוה ביותר - בכל תחשיב שהוא."
19 | " ווואו היא מקסימה ומצחיקה . ואני מכירה את העולם הזה. נמאס מהזיוף והשטויות, שינצלו את הכוח שלהם, עלק כוח, בשביל דברים חשובים יותר. קרן מוכשרת חכמה ויפהפיה והיא מאלה שישרדו"
20 | " וועדת חקירה עכשיו. לחקור לאן נעלמו חומרי ההרדמה ועוד ציוד. בראש הנתונים המתפרסם במדיה נמצא מס' הנדבקים מתחילת המגפה, צריך לשנות כך שהנתון הבולט יהיה מס' החולים המאומתים עכשיו ובהמשך אותם יתר הנתונים. כמו מס'"
21 | אני אוהב את העולם
22 | אני לא אוהב את העולם
23 | אני ממש אוהב את העולם
24 | אני לא ממש אוהב את העולם
25 | אני ממש לא אוהב את העולם
26 | קפה זה טעים
27 | קפה זה לא טעים
28 | קפה זה סבבה
29 | קפה זה לא ממש טעים
30 | קפה זה טעים אש
31 | מה אוכלים היום?
32 |
--------------------------------------------------------------------------------
/src/HebEMO_app.py:
--------------------------------------------------------------------------------
1 | from HebEMO import HebEMO
2 | from transformers import pipeline
3 | import streamlit as st
4 | from io import StringIO
5 | import pandas as pd
6 |
7 | st.title("Emotion Recognition in Hebrew Texts")
8 | st.write("HebEMO is a tool to detect polarity and extract emotions from Hebrew user-generated content (UGC), which was trained on a unique Covid-19 related dataset that we collected and annotated. HebEMO yielded a high performance of weighted average F1-score = 0.96 for polarity classification. Emotion detection reached an F1-score of 0.78-0.97, with the exception of *surprise*, which the model failed to capture (F1 = 0.41). More information can be found in our git: https://github.com/avichaychriqui/HeBERT")
9 | st.write("Write Hebrew sentences in the text box below to analyze (each sentence in a different rew). An additional demo can be found in the Colab notebook: https://colab.research.google.com/drive/1Jw3gOWjwVMcZslu-ttXoNeD17lms1-ff ")
10 | max_length = st.slider('What is the maximum length of the text to analyze? \n (The longer the text, the longer the calculation time will be)',
11 | 5, 512, 20)
12 |
13 |
14 | option = st.selectbox(
15 | 'What do you want to analyze?',
16 | ('Sentiment', 'Emotion'))
17 |
18 | if option == 'Emotion':
19 | emotions = st.multiselect(
20 | 'Which emotion to analyze?:',
21 | ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger',
22 | 'sadness', 'disgust'],
23 | ['joy'])
24 | else:
25 | emotions = ['sentiment']
26 |
27 | method = st.selectbox(
28 | 'What is your input?',
29 | ('File', 'Free Text'))
30 |
31 | if method == 'File':
32 | uploaded_file = st.file_uploader("Choose a file")
33 |
34 | if uploaded_file is not None:
35 | # To read file as bytes:
36 | bytes_data = uploaded_file.getvalue()
37 |
38 | # To convert to a string based IO:
39 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
40 |
41 | # To read file as string:
42 | string_data = stringio.read()
43 |
44 | else:
45 | string_data = st.text_area(label = 'Text to analyze', value="",
46 | placeholder = '''
47 | It was the best of times, it was the worst of times, it was
48 | the age of wisdom, it was the age of foolishness, it was
49 | the epoch of belief, it was the epoch of incredulity, it
50 | was the season of Light, it was the season of Darkness, it
51 | was the spring of hope, it was the winter of despair, (...)
52 | ''')
53 | if 'string_data' in locals():
54 |
55 | if (string_data != '' and string_data is not None):
56 |             st.write('This may take a while, please be patient :)')
57 | HebEMO_model = HebEMO(device= 0, emotions = emotions)
58 |
59 | hebEMO_df = HebEMO_model.hebemo(string_data, read_lines=True, plot=False, batch_size=32, max_length = max_length)
60 |
61 | if option == 'Emotion':
62 | hebEMO = pd.DataFrame(hebEMO_df[0])
63 |         for emo in hebEMO_df.columns[1::2]:
64 |             hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo]))  # map binary label + confidence to an intensity in [0, 1]
65 | else:
66 | hebEMO = hebEMO_df
67 |
68 | @st.cache_data
69 | def convert_df(df):
70 | return df.to_csv(index=False).encode('utf-8')
71 |
72 |
73 | csv = convert_df(hebEMO)
74 |
75 | st.download_button(
76 | "Press to Download",
77 | csv,
78 | "hebEMO-ed.csv",
79 | "text/csv",
80 | key='download-csv'
81 | )
82 |
83 | st.write(hebEMO)
84 |
--------------------------------------------------------------------------------
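The emotion post-processing above folds each binary label and its confidence into a single intensity score in [0, 1]. A small illustration of the arithmetic, with hypothetical values (not repository code):

```
def intensity(label, confidence):
    # label=1 with confidence 0.9 -> abs(1 - 0.1) = 0.9 (emotion strongly present)
    # label=0 with confidence 0.9 -> abs(0 - 0.1) = 0.1 (emotion strongly absent)
    return abs(label - (1 - confidence))

assert round(intensity(1, 0.9), 10) == 0.9
assert round(intensity(0, 0.9), 10) == 0.1
```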
/src/HebEMO.py:
--------------------------------------------------------------------------------
1 | class HebEMO:
2 | def __init__(self, device=-1, emotions = ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger',
3 | 'sadness', 'disgust']):
4 | from transformers import pipeline
5 | from tqdm import tqdm
6 | self.device = device
7 |         if isinstance(emotions, str):
8 |             self.emotions = [emotions]
9 |         elif not isinstance(emotions, list):
10 |             raise ValueError('emotions should be a single emotion string or a list of emotions.')
11 | else:
12 | self.emotions = emotions
13 | self.hebemo_models = {}
14 | for emo in tqdm(self.emotions):
15 | if emo == 'sentiment':
16 | model_name = 'avichr/heBERT_sentiment_analysis'
17 | else:
18 | model_name = "avichr/hebEMO_"+emo
19 | self.hebemo_models[emo] = pipeline(
20 | "sentiment-analysis",
21 | model=model_name,
22 | tokenizer="avichr/heBERT",
23 | device = self.device #-1 run on CPU, else - device ID
24 | )
25 |
26 | def hebemo(self, text = None, input_path=False, save_results=False, read_lines=False, plot=False, batch_size=32, max_length = 512, truncation=True):
27 |         '''
28 |         text (str or list): a string or list of strings to analyze
29 |         input_path (str): path to a text file (txt file, one instance per row)
30 |         '''
31 | from pyplutchik import plutchik
32 | import matplotlib.pyplot as plt
33 | import pandas as pd
34 | import time
35 | import torch
36 | from tqdm import tqdm
37 | if text is None and type(input_path) is str:
38 | # read the file
39 | with open(input_path, encoding='utf8') as p:
40 | txt = p.readlines()
41 | elif text is not None and (input_path is None or input_path is False):
42 | if type(text) is str:
43 | if read_lines:
44 | txt = text.split('\n')
45 | else:
46 | txt = [text]
47 | elif type(text) is list:
48 | txt = text
49 | else:
50 |                 raise ValueError('text should be a string or a list of strings.')
51 |         else:
52 |             raise ValueError('you should provide a text string, a list of strings, or a path to a text file.')
53 | # run hebEMO
54 | hebEMO_df = pd.DataFrame(txt)
55 | for emo in tqdm(self.emotions):
56 | x = self.hebemo_models[emo](txt, truncation=truncation, max_length=max_length, batch_size=batch_size)
57 | hebEMO_df = hebEMO_df.join(pd.DataFrame(x).rename(columns = {'label': emo, 'score':'confidence_'+emo}))
58 | del x
59 | torch.cuda.empty_cache()
60 |         hebEMO_df = hebEMO_df.applymap(lambda x: 0 if x=='LABEL_0' else 1 if x=='LABEL_1' else x)  # convert pipeline labels to binary 0/1
61 | if save_results is not False:
62 | gen_name = str(int(time.time()*1e7))
63 | if type(save_results) is str:
64 | hebEMO_df.to_csv(save_results+'/'+gen_name+'_heEMOed.csv', encoding='utf8')
65 | else:
66 | hebEMO_df.to_csv(gen_name+'_heEMOed.csv', encoding='utf8')
67 | if plot:
68 | hebEMO = pd.DataFrame()
69 | for emo in hebEMO_df.columns[1::2]:
70 | hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo]))
71 |             for i in range(0, 1):  # plots and returns only the first instance
72 |                 ax = plutchik(hebEMO.to_dict(orient='records')[i])
73 |                 print(hebEMO_df[0][i])
74 |                 plt.show()
75 |                 return (hebEMO_df[0][i], ax)
76 |         else:
77 |             return hebEMO_df
78 |
79 |
80 |
--------------------------------------------------------------------------------
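A minimal usage sketch of the HebEMO class above (not repository code; assumes HebEMO.py is importable):

```
from HebEMO import HebEMO

hebemo = HebEMO(device=-1, emotions=['joy', 'sadness'])  # device=-1 runs on CPU
df = hebemo.hebemo(text='אני אוהב את העולם\nאני לא אוהב את העולם',
                   read_lines=True, plot=False)
print(df)
```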
/src/train_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformers import BertForSequenceClassification, BertTokenizer, BertModel, BertConfig
6 | import random
7 |
8 |
9 | class BertTrainer:
10 | def __init__(self, num_classes, max_length=15, model_name = 'bert-base-uncased', device="cpu", seed_num=42):
11 | self.tokenizer = BertTokenizer.from_pretrained(model_name)
12 | self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
13 | self.model.to(device)
14 | self.max_length = max_length
15 | self.device = device
16 |
17 | random.seed(seed_num)
18 | torch.manual_seed(seed_num)
19 | np.random.seed(seed_num)
20 | torch.cuda.manual_seed_all(seed_num)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 |
24 |
25 | def tokenize(self, texts):
26 | encoded_dict = self.tokenizer.batch_encode_plus(
27 | texts,
28 | add_special_tokens=True,
29 | max_length=self.max_length,
30 | padding='max_length',
31 | truncation=True,
32 | return_attention_mask=True,
33 | return_tensors='pt',
34 | )
35 |
36 | input_ids = encoded_dict['input_ids'].to(self.device)
37 | attention_masks = encoded_dict['attention_mask'].to(self.device)
38 | return input_ids, attention_masks
39 |
40 | def train(self, texts, labels, batch_size=32, epochs=5, learning_rate=2e-5, return_loss = False):
41 | self.model.train()
42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
43 |
44 | Epochs_Losses = []
45 | for epoch in range(epochs):
46 | epoch_loss = 0
47 | epoch_steps = 0
48 |
49 | for i in range(0, len(texts), batch_size):
50 | optimizer.zero_grad()
51 |
52 | input_ids, attention_masks = self.tokenize(texts[i:i+batch_size])
53 | batch_labels = torch.tensor(labels[i:i+batch_size]).to(self.device)
54 |
55 | outputs = self.model(input_ids, attention_mask=attention_masks, labels=batch_labels)
56 | loss = outputs[0]
57 |
58 |
59 | loss.backward()
60 | optimizer.step()
61 |
62 | epoch_loss += loss.item()
63 | epoch_steps += 1
64 |
65 | print(f"Epoch {epoch+1} Loss: {epoch_loss / epoch_steps:.4f}")
66 | Epochs_Losses.append(epoch_loss / epoch_steps)
67 |
68 | if return_loss:
69 | return(Epochs_Losses)
70 |
71 | def predict(self, texts):
72 | self.model.eval()
73 | input_ids, attention_masks = self.tokenize(texts)
74 | with torch.no_grad():
75 | outputs = self.model(input_ids, attention_mask=attention_masks)
76 | logits = outputs[0]
77 | predictions = torch.argmax(logits, dim=1)
78 |
79 | return_dic = {}
80 | return_dic['predictions'] = predictions.cpu().numpy()
81 | return_dic['logits'] = logits.cpu().numpy()
82 |
83 | return return_dic
84 |
85 |     def eval(self, texts, y_true):
86 |         from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
87 |
88 |         predictions_outputs = self.predict(texts)
89 |         st.write(confusion_matrix(y_true, predictions_outputs['predictions']))
90 |         st.write(classification_report(y_true, predictions_outputs['predictions'], output_dict=True))
91 |         st.write('auc : ', roc_auc_score(y_true, predictions_outputs['logits'][:, 1]))  # assumes binary classification
92 |
93 |
94 |
95 | from transformers import pipeline
96 | import streamlit as st
97 | from io import StringIO
98 | import pandas as pd
99 | from sklearn.model_selection import train_test_split
100 |
101 |
102 | st.title("Train your own model for Hebrew texts classification")
103 |
104 | model_name = st.selectbox('What is the base language model to train a model?', ('avichr/heBERT', 'avichr/Legal-heBERT_ft', 'avichr/Legal-heBERT',
105 | 'bert-base-uncased'))
106 |
107 | max_length = st.slider('What is the maximum length of the text to analyze? \n (The longer the text, the longer the calculation time will be)',
108 | 5, 512, 20)
109 |
110 | num_classes = st.number_input('How many classes does your data have?', 2,10,2)
111 | learning_rate = st.number_input('What is your learning rate?', min_value = 1e-10, max_value= 1e-1, value = 5e-5, step = 1e-10)
112 | batch_size = st.slider('What is your batch size?', 2,512, 32)
113 | epochs = st.slider('How many epochs to learn?', 1, 50, 5)
114 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
115 | trainer = BertTrainer(num_classes=num_classes, max_length=max_length, device=device, model_name= model_name)
116 |
117 | uploaded_file = st.file_uploader("Choose a file to train the model")
118 | to_predict_file = st.file_uploader("Choose a file to predict", key = 'to_predict_file')
119 |
120 | if uploaded_file is not None and to_predict_file is not None:
121 | df = pd.read_csv(uploaded_file)
122 | to_predict = pd.read_csv(to_predict_file)
123 |
124 | text_column = st.selectbox('What is the text column in your data?', (df.columns))
125 | label_column = st.selectbox('What is the label column in your data? \n labels should be integers only', (df.columns), key = '_')
126 | text_column_to_predict = st.selectbox('What is the text column in your data to predict?', (to_predict.columns))
127 |
128 | df = df[[text_column, label_column]].dropna()
129 |
130 |
131 |     test_size = st.slider('What fraction of the data to hold out as a test set? \n The selected file will be automatically split into training and test sets (recommended).', 0.0, 1.0, .33)
132 |
133 | if 'text_column' in locals() and 'label_column' in locals():
134 | X_train, X_test, y_train, y_test = train_test_split(df[text_column].to_list(), df[label_column].to_list(), test_size=test_size, stratify= df[label_column], random_state=42)
135 |
136 | if 'X_train' in locals() and 'trainer' in locals() and st.button('Train', key='train'):
137 |         trainer.train(X_train, y_train, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate)  # use the learning rate selected above
138 | trained = True
139 |
140 |         st.write('model performance:')
141 |         trainer.eval(X_test, y_test)  # eval reports metrics via st.write and returns None
142 |
143 |
144 | st.write('Predictions:')
145 | preds = pd.DataFrame({'text': to_predict[text_column_to_predict],
146 | 'predictions': trainer.predict(to_predict[text_column_to_predict])['predictions']
147 | })
148 |
149 | @st.cache_data
150 | def convert_df(df):
151 | return df.to_csv(index=False).encode('utf-8')
152 |
153 |
154 | preds_df = convert_df(preds)
155 |
156 | st.download_button(
157 | "Press to Download",
158 | preds_df,
159 | "preds.csv",
160 | "text/csv",
161 | key='download-csv'
162 | )
163 |
164 |         st.write(preds)
165 |
--------------------------------------------------------------------------------
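A minimal sketch of training with the BertTrainer class above outside Streamlit (toy data for illustration only; note that importing train_model.py as-is also executes the Streamlit app code below the class, so in practice the class would be copied into its own module first):

```
from train_model import BertTrainer  # assumption: the class has been moved to an importable module

trainer = BertTrainer(num_classes=2, max_length=32, model_name='avichr/heBERT', device='cpu')
texts = ['קפה זה טעים', 'קפה זה לא טעים', 'אני אוהב את העולם', 'אני לא אוהב את העולם']
labels = [1, 0, 1, 0]
trainer.train(texts, labels, batch_size=2, epochs=1)
print(trainer.predict(['קפה זה סבבה'])['predictions'])
```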
/examples/sentiment_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "sentiment_analysis.ipynb",
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyOqltqt2sTyWBv32xl/QbV/",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | ""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "id": "ydpDDgHtObR-"
32 | },
33 | "source": [
34 | "# !pip install shap==0.36.0\r\n",
35 | "# !pip install transformers\r\n",
36 | "# !pip install datasets"
37 | ],
38 | "execution_count": 1,
39 | "outputs": []
40 | },
41 | {
42 | "cell_type": "code",
43 | "metadata": {
44 | "id": "MIDsbPBYNHOf"
45 | },
46 | "source": [
47 | "import shap\r\n",
48 | "import transformers\r\n",
49 | "import torch\r\n",
50 | "import numpy as np\r\n",
51 | "import scipy as sp\r\n",
52 | "import pandas as pd"
53 | ],
54 | "execution_count": 2,
55 | "outputs": []
56 | },
57 | {
58 | "cell_type": "code",
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "id": "n8tGWji9NhDv",
64 | "outputId": "ae1daf52-2090-4f69-f09c-4c5a76988e04"
65 | },
66 | "source": [
67 | "from datasets import load_dataset\r\n",
68 | "dataset = load_dataset('text', data_files='text_example.txt', )\r\n"
69 | ],
70 | "execution_count": 3,
71 | "outputs": [
72 | {
73 | "output_type": "stream",
74 | "text": [
75 | "Using custom data configuration default\n",
76 | "Reusing dataset text (/root/.cache/huggingface/datasets/text/default-dba86d70c11ab66c/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)\n"
77 | ],
78 | "name": "stderr"
79 | }
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "metadata": {
85 | "colab": {
86 | "base_uri": "https://localhost:8080/"
87 | },
88 | "id": "HdWEAfR3O4AA",
89 | "outputId": "c1d6b49a-c668-4b93-88b5-02bba1ab5b27"
90 | },
91 | "source": [
92 | "dataset['train']['text'][::-5]"
93 | ],
94 | "execution_count": 4,
95 | "outputs": [
96 | {
97 | "output_type": "execute_result",
98 | "data": {
99 | "text/plain": [
100 | "['מה אוכלים היום?',\n",
101 | " 'קפה זה טעים',\n",
102 | " 'אני אוהב את העולם',\n",
103 | " '\" וואו, מתאים לספר השיאים של גינס. מקסים ומרגש!!!! . לא יודעת מי המתורגמנית. אבל יודעת בוודאות כמה זה חשוב!!!! כל כך חשוב. יש מתורגמנית שכולם מכירים מהחדשות בטלויזיה, אפרת נגר המדהימה שהיא גם דולה... עבודת קודש אמיתית. מזל טוב יקרים. הרבה בריאות ושמחה.\"',\n",
104 | " ' והרכבת ? זה זמן עוד יותר טוב לרכבת שלא פעילה בכלל.. כל הכבוד זה הזמן אבל350 אלף כלי רכב בתקופת הקורונה????',\n",
105 | " ' והבס שלך לא עבריין??? 3 תיקי אישומים. ',\n",
106 | " '\" וביבי יצילנו מידם... הוא ולא שרף. 1. תחילה הסתה וגזענות - איבדתה את הרוב של הכנסת2.עתה נגד הדמוקרטיה - הצלחתה לאחד מולך את כל האופוזציה, כולל מבית3. צעד נוסף באותו כיוון - אשריך !\"']"
107 | ]
108 | },
109 | "metadata": {
110 | "tags": []
111 | },
112 | "execution_count": 4
113 | }
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "metadata": {
119 | "id": "5jlc1kleRxp9"
120 | },
121 | "source": [
122 | "\r\n",
123 | "from transformers import AutoTokenizer, AutoModel, pipeline\r\n",
124 | "\r\n",
125 | "# how to use?\r\n",
126 | "sentiment_analysis = pipeline(\r\n",
127 | " \"sentiment-analysis\",\r\n",
128 | " model=\"avichr/heBERT_sentiment_analysis\",\r\n",
129 | " tokenizer=\"avichr/heBERT_sentiment_analysis\", \r\n",
130 | " return_all_scores = True, \r\n",
131 | ")\r\n"
132 | ],
133 | "execution_count": 5,
134 | "outputs": []
135 | },
136 | {
137 | "cell_type": "code",
138 | "metadata": {
139 | "colab": {
140 | "base_uri": "https://localhost:8080/"
141 | },
142 | "id": "T88MJxFBR53F",
143 | "outputId": "2f273d1e-9af1-4698-8b12-c152642d8996"
144 | },
145 | "source": [
146 | "for i in dataset['train']['text'][0:5]:\r\n",
147 | " print(i, sentiment_analysis(i), sep='\\n')\r\n"
148 | ],
149 | "execution_count": 6,
150 | "outputs": [
151 | {
152 | "output_type": "stream",
153 | "text": [
154 | ],
155 | "name": "stdout"
156 | }
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "metadata": {
162 | "colab": {
163 | "base_uri": "https://localhost:8080/"
164 | },
165 | "id": "x2yC3wdRO_k3",
166 | "outputId": "fea1da5c-2136-4fd7-c627-6203d93b93bf"
167 | },
168 | "source": [
169 | "\r\n",
170 | "\r\n",
171 | "# load a BERT sentiment analysis model\r\n",
172 | "tokenizer = transformers.BertTokenizerFast.from_pretrained(\"avichr/heBERT_sentiment_analysis\")\r\n",
173 | "model = transformers.BertForSequenceClassification.from_pretrained(\"avichr/heBERT_sentiment_analysis\").cuda()\r\n",
174 | "\r\n",
175 | "# define a prediction function\r\n",
176 | "def f(x):\r\n",
177 | " tv = torch.tensor([tokenizer.encode(v, pad_to_max_length=True, max_length=500, truncation=True, padding='longest') for v in x]).cuda()\r\n",
178 | " outputs = model(tv)[0].detach().cpu().numpy()\r\n",
179 | " scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T\r\n",
180 | " val = sp.special.logit(scores[:,1]) # use one vs rest logit units\r\n",
181 | " return val\r\n",
182 | "\r\n",
183 | "# build an explainer using a token masker\r\n",
184 | "explainer = shap.Explainer(f, tokenizer)\r\n",
185 | "\r\n",
186 | "# explain the model's predictions on IMDB reviews\r\n",
187 | "train_dataset = dataset['train']['text']\r\n",
188 | "shap_values = explainer(train_dataset)\r\n"
189 | ],
190 | "execution_count": 7,
191 | "outputs": [
192 | {
193 | "output_type": "stream",
194 | "text": [
195 | "explainers.Partition is still in an alpha state, so use with caution...\n"
196 | ],
197 | "name": "stderr"
198 | }
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "metadata": {
204 | "id": "Ti31Ki0VPTwu"
205 | },
206 | "source": [
207 | "# shap.initjs()\r\n",
208 | "\r\n",
209 | "# shap.plots.text(shap_values)\r\n"
210 | ],
211 | "execution_count": 8,
212 | "outputs": []
213 | }
214 | ]
215 | }
216 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HeBERT: Pre-trained BERT for Polarity Analysis and Emotion Recognition
2 |
3 |
4 | HeBERT is a Hebrew pre-trained language model. It is based on [Google's BERT](https://arxiv.org/abs/1810.04805) architecture and uses the BERT-Base configuration.
5 |
6 | HeBERT was trained on three datasets:
7 | 1. A Hebrew version of [OSCAR](https://oscar-corpus.com/): ~9.8 GB of data, including 1 billion words and over 20.8 million sentences.
8 | 2. A Hebrew dump of [Wikipedia](https://dumps.wikimedia.org/): ~650 MB of data, including over 63 million words and 3.8 million sentences.
9 | 3. Emotion User Generated Content (UGC) data that was collected for the purpose of this study (described below).
10 |
11 |
74 |
75 |
76 |
77 | ### For the masked-LM model (can be [fine-tuned to any downstream task](https://colab.research.google.com/drive/1_t-9ALMQ6WqVExJT2NEdR7x67CDaE1FP?usp=sharing))
78 |
79 | from transformers import AutoTokenizer, AutoModel
80 | tokenizer = AutoTokenizer.from_pretrained("avichr/heBERT")
81 | model = AutoModel.from_pretrained("avichr/heBERT")
82 |
83 | from transformers import pipeline
84 | fill_mask = pipeline(
85 | "fill-mask",
86 | model="avichr/heBERT",
87 | tokenizer="avichr/heBERT"
88 | )
89 | fill_mask("הקורונה לקחה את [MASK] ולנו לא נשאר דבר.")
90 |
91 | ### For the sentiment classification model (polarity ONLY):
92 | from transformers import AutoTokenizer, AutoModel, pipeline
93 | tokenizer = AutoTokenizer.from_pretrained("avichr/heBERT_sentiment_analysis") #same as 'avichr/heBERT' tokenizer
94 | model = AutoModel.from_pretrained("avichr/heBERT_sentiment_analysis")
95 |
96 | # how to use?
97 | sentiment_analysis = pipeline(
98 | "sentiment-analysis",
99 | model="avichr/heBERT_sentiment_analysis",
100 | tokenizer="avichr/heBERT_sentiment_analysis",
101 | return_all_scores = True
102 | )
103 |
104 | sentiment_analysis('אני מתלבט מה לאכול לארוחת צהריים')
105 | >>> [[{'label': 'natural', 'score': 0.9978172183036804},
106 | >>> {'label': 'positive', 'score': 0.0014792329166084528},
107 | >>> {'label': 'negative', 'score': 0.0007035882445052266}]]
108 |
109 | sentiment_analysis('קפה זה טעים')
110 | >>> [[{'label': 'natural', 'score': 0.00047328314394690096},
111 | >>> {'label': 'possitive', 'score': 0.9994067549705505},
112 | >>> {'label': 'negetive', 'score': 0.00011996887042187154}]]
113 |
114 | sentiment_analysis('אני לא אוהב את העולם')
115 | >>> [[{'label': 'natural', 'score': 9.214012970915064e-05},
116 | >>> {'label': 'possitive', 'score': 8.876807987689972e-05},
117 | >>> {'label': 'negetive', 'score': 0.9998190999031067}]]
118 |
119 |
120 | Our model is also available on AWS! For more information, visit the [AWS samples repository](https://github.com/aws-samples/aws-lambda-docker-serverless-inference/tree/main/hebert-sentiment-analysis-inference-docker-lambda).
121 |
122 | ## Named-entity recognition (NER)
123 | The model can classify named entities in text, such as person names, organizations, and locations. It was tested on a labeled dataset from [Ben Mordecai and M Elhadad (2005)](https://www.cs.bgu.ac.il/~elhadad/nlpproj/naama/) and evaluated with the F1-score.
124 | [Colab notebook](https://colab.research.google.com/drive/18Uhq3HMxudo1XWaYZ14GI-_Op5Ds9EGo?usp=sharing)
125 |
126 | ### How to use
127 | ```
128 | from transformers import pipeline
129 |
130 | # how to use?
131 | NER = pipeline(
132 | "token-classification",
133 | model="avichr/heBERT_NER",
134 | tokenizer="avichr/heBERT_NER",
135 | )
136 | NER('דויד לומד באוניברסיטה העברית שבירושלים')
137 | ```
138 |
139 |
140 | ## Contact us
141 | [Avichay Chriqui](mailto:avichayc@mail.tau.ac.il)