├── .gitignore ├── dialog_tag ├── __init__.py ├── label_mapping.py ├── config.py ├── get_weights.py └── DialogTag.py ├── setup.py ├── LICENSE ├── requirements.txt └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .python-version 2 | DialogTag.egg-info/ 3 | dist/ 4 | build/ -------------------------------------------------------------------------------- /dialog_tag/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.3" 2 | from .DialogTag import DialogTag -------------------------------------------------------------------------------- /dialog_tag/label_mapping.py: -------------------------------------------------------------------------------- 1 | from .config import class_labels 2 | 3 | class LabelMapping(): # Translates model class indices into human-readable dialogue-act names. 4 | def __init__(self, label_mapping_path): # label_mapping_path: path to the label_map.txt shipped with the model 5 | self.__class_names_file = label_mapping_path 6 | 7 | def __load_class_names(self): # Read the label-map file, one entry per line, trailing whitespace stripped. 8 | class_names = [] 9 | with open(self.__class_names_file, 'r') as fp: 10 | lines = fp.readlines() 11 | for line in lines: 12 | line = line.rstrip() 13 | class_names.append(line) 14 | return class_names 15 | 16 | def helper(self): # Build the two lookup dicts used by DialogTag.predict_tag. 17 | class_names = self.__load_class_names() 18 | label_map = {} 19 | for row in class_names: 20 | act,label = row.strip().split('|') # presumably each row is "<Full act name>|<short tag>" (e.g. "Wh-Question|qw") — verify against label_map.txt 21 | label_map[label] = act 22 | 23 | mapping = class_labels["MAPPING"] 24 | inv_map = {v: k for k, v in mapping.items()} # stringified class index ('0'..'37') -> short tag 25 | 26 | sts = {} 27 | for k,v in inv_map.items(): # short tag -> full act name; raises KeyError if label_map.txt misses a tag 28 | sts[v] = label_map[v] 29 | 30 | return inv_map, sts # (index-str -> short tag, short tag -> full act name) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", mode="r", encoding="utf-8") as readme_file: 4 | long_description = readme_file.read() 5 | 6 | setuptools.setup( 7 | name="DialogTag", 8 | version="1.1.3", # NOTE(review): dialog_tag/__init__.py declares __version__ = "1.0.3" — confirm which is current 9 | author="Bhavitvya Malik", 10 | 
author_email="bhavitvya.malik@gmail.com", 11 | description="A python library to classify dialogue tag.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/bhavitvyamalik/DialogTag", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | 'transformers>=3.0.0', 18 | 'tqdm', 19 | 'tensorflow>=2.0.0' 20 | ], 21 | classifiers=[ 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | ], 26 | python_requires='>=3.7', 27 | keywords="Tensorflow BERT NLP deep learning Transformer Networks " 28 | ) 29 | -------------------------------------------------------------------------------- /dialog_tag/config.py: -------------------------------------------------------------------------------- 1 | class_labels = { # short dialogue-act tag -> class index (kept as a string); 38 entries, indices '0'-'37' 2 | "MAPPING" : {'fo_o_fw_"_by_bc': '0', 'ft': '1', 'fc': '2', 'qw': '3', '^g': '4', 'bh': '5', 'qy': '6', 'qrr': '7', 3 | 'fp': '8', 'qo': '9', 'bk': '10', 'h': '11', 'sv': '12', 'ba': '13', 'nn': '14', '^h': '15', '^2': '16', 'aap_am': '17', 4 | 'qw^d': '18', 'qy^d': '19', 'ng': '20', 'fa': '21', 'b': '22', 'ny': '23', 't3': '24', 'sd': '25', 'br': '26', 5 | 'oo_co_cc': '27', 'arp_nd': '28', 't1': '29', '^q': '30', 'aa': '31', 'na': '32', 'b^m': '33', 'bd': '34', 'ad': '35', 6 | 'bf': '36', 'qh': '37'} 7 | } 8 | 9 | model_params = { 10 | "num_labels": 38 # must equal len(class_labels["MAPPING"]) 11 | } 12 | 13 | model_location = { # path fragments appended to Path.home() by DialogTag and get_weights 14 | "MODEL" : "/.dialog-tag/models", 15 | "label_mapping" : "/label_map.txt" 16 | } 17 | 18 | model_download_link = { # zip archives fetched by get_weights.Download (dl=1 forces direct download) 19 | "distilbert-base-uncased" : "https://www.dropbox.com/sh/crq7khtdd99u4mo/AABdQb8W1lJ37Cm-CfOSISuBa?dl=1", 20 | "bert-base-uncased" : "https://www.dropbox.com/sh/ajlwp36obho2cbe/AADHTY4_PhSOAQzveCJIna4Va?dl=1" 21 | } 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT 
License 2 | 3 | Copyright (c) 2020 Bhavitvya Malik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | aiohttp==3.6.2 3 | astroid==2.4.2 4 | astunparse==1.6.3 5 | async-timeout==3.0.1 6 | attrs==20.2.0 7 | cachetools==4.1.1 8 | certifi==2020.6.20 9 | chardet==3.0.4 10 | click==7.1.2 11 | filelock==3.0.12 12 | gast==0.3.3 13 | google-auth==1.22.0 14 | google-auth-oauthlib==0.4.1 15 | google-pasta==0.2.0 16 | grpcio==1.32.0 17 | h5py==2.10.0 18 | idna==2.10 19 | importlib-metadata==2.0.0 20 | isort==5.5.4 21 | joblib==0.16.0 22 | Keras-Preprocessing==1.1.2 23 | lazy-object-proxy==1.4.3 24 | Markdown==3.2.2 25 | mccabe==0.6.1 26 | multidict==4.7.6 27 | numpy==1.19.2 28 | oauthlib==3.1.0 29 | opt-einsum==3.3.0 30 | packaging==20.4 31 | protobuf==3.13.0 32 | pyasn1==0.4.8 33 | pyasn1-modules==0.2.8 34 | pylint==2.6.0 35 | pyparsing==2.4.7 36 | regex==2020.9.27 37 | requests==2.24.0 38 | requests-oauthlib==1.3.0 39 | rsa==4.6 40 | sacremoses==0.0.43 41 | scipy==1.4.1 42 | sentencepiece==0.1.91 43 | six==1.15.0 44 | tensorboard==2.2.2 45 | tensorboard-plugin-wit==1.7.0 46 | tensorflow==2.2.0 47 | tensorflow-estimator==2.2.0 48 | termcolor==1.1.0 49 | tokenizers==0.8.0rc4 50 | toml==0.10.1 51 | tqdm==4.50.0 52 | transformers==3.0.0 53 | typed-ast==1.4.1 54 | typing-extensions==3.7.4.3 55 | urllib3==1.25.10 56 | Werkzeug==1.0.1 57 | wrapt==1.12.1 58 | yarl==1.6.0 59 | zipp==3.2.0 60 | -------------------------------------------------------------------------------- /dialog_tag/get_weights.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from zipfile import ZipFile 3 | import os 4 | from tqdm import tqdm as tqdm 5 | import logging 6 | from .config import model_download_link, model_location 7 | from pathlib import Path 8 | 9 | 10 | class Download: 11 | def __init__(self, model_name): 12 | self.__model_name = 
model_name 13 | self.__url = model_download_link[self.__model_name] 14 | 15 | 16 | try: 17 | self.__lib_path = f"{str(Path.home())}"+model_location["MODEL"] 18 | Path(self.__lib_path).mkdir(parents=True, exist_ok=False) 19 | except: 20 | print("Model directory already exists, chill!") 21 | 22 | 23 | self.__model_zip_path = os.path.join(self.__lib_path, self.__model_name+".zip") 24 | self.__final_path = os.path.join(self.__lib_path, self.__model_name) 25 | 26 | def download_file(self): 27 | if os.path.dirname(self.__model_zip_path) != '': 28 | os.makedirs(os.path.dirname(self.__model_zip_path), exist_ok=True) 29 | 30 | req = requests.get(self.__url, stream=True) 31 | if req.status_code != 200: 32 | print("Exception when trying to download {}. Response {}".format(self.__url, req.status_code), file=sys.stderr) 33 | req.raise_for_status() 34 | return 35 | 36 | download_filepath = self.__model_zip_path +"_part" 37 | with open(download_filepath, "wb") as file_binary: 38 | content_length = req.headers.get('Content-Length') 39 | total = int(content_length) if content_length is not None else None 40 | progress = tqdm(unit="B", total=total, unit_scale=True) 41 | for chunk in req.iter_content(chunk_size=1024): 42 | if chunk: # filter out keep-alive new chunks 43 | progress.update(len(chunk)) 44 | file_binary.write(chunk) 45 | 46 | os.rename(download_filepath, self.__model_zip_path) 47 | progress.close() 48 | 49 | zf = ZipFile(self.__model_zip_path, 'r') 50 | zf.extractall(path = self.__final_path) 51 | zf.close() 52 | 53 | os.remove(self.__model_zip_path) -------------------------------------------------------------------------------- /dialog_tag/DialogTag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import operator 4 | from pathlib import Path 5 | 6 | from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, AutoConfig 7 | 8 | from .config import model_params, 
model_location 9 | from .get_weights import Download 10 | from .label_mapping import LabelMapping 11 | from . import __version__ 12 | 13 | 14 | class DialogTag: 15 | def __init__(self, model_name): 16 | 17 | self.__model_name = model_name 18 | 19 | self.__lib_path = f"{str(Path.home())}"+ model_location["MODEL"] 20 | self.__model_path = os.path.join(self.__lib_path, self.__model_name) 21 | self.__label_mapping_path = os.path.join(self.__lib_path, self.__model_name) + model_location["label_mapping"] 22 | 23 | # print(self.__lib_path, self.__model_path, self.__label_mapping_path) 24 | path_exists = os.path.exists(self.__model_path) 25 | self.__num = 0 26 | if(path_exists==True): 27 | self.__num = len(os.listdir(self.__model_path)) 28 | 29 | if(self.__num<3 or path_exists==False): 30 | print("Model not found in cache. Downloading...") 31 | self.__model_file = Download(self.__model_name) 32 | self.__model_file.download_file() 33 | else: 34 | print(f"{self.__model_name} found in cache. Loading model...") 35 | 36 | self.__tokenizer = AutoTokenizer.from_pretrained(self.__model_name, do_lower_case=True) 37 | self.__config = AutoConfig.from_pretrained(self.__model_path, num_labels=model_params["num_labels"]) 38 | self.__model = TFAutoModelForSequenceClassification.from_pretrained(self.__model_path, config=self.__config) 39 | 40 | def __classhelper(self): 41 | mapping_object = LabelMapping(self.__label_mapping_path) 42 | 43 | logits_class, class_expanded = mapping_object.helper() 44 | 45 | return logits_class, class_expanded 46 | 47 | def predict_tag(self, sentence): 48 | predict_input = self.__tokenizer.encode(sentence, 49 | truncation=True, 50 | padding=True, 51 | return_tensors="tf") 52 | 53 | tf_output = self.__model.predict(predict_input)[0] 54 | tf_prediction = tf.nn.softmax(tf_output, axis=1).numpy()[0] 55 | index, value = max(enumerate(tf_prediction), key=operator.itemgetter(1)) 56 | 57 | # print(value) 58 | 59 | logits_class, class_expanded = self.__classhelper() 60 
| 61 | return class_expanded[logits_class[str(index)]] 62 | 63 | # test_sentence = "With their homes in ashes, residents share harrowing tales of survival after massive wildfires kill 15" 64 | 65 | 66 | if __name__=='__main__': 67 | A = DialogTag('distilbert-base-uncased') 68 | # z = A.predict_tag("With their homes in ashes, residents share harrowing tales of survival after massive wildfires kill 15") 69 | z = A.predict_tag("Stop talking silly!") 70 | print(z) 71 | 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DialogTag 2 | 3 | Dialogue act classification is the task of classifying an utterance with respect to the function it serves in a dialogue, i.e. the act the speaker is performing. This python library essentially does dialogue act classification on the Switchboard corpus. 4 | 5 | The Switchboard-1 corpus is a telephone speech corpus, consisting of about 2,400 two-sided telephone conversation among 543 speakers with about 70 provided conversation topics. The dataset includes the audio files and the transcription files, as well as information about the speakers and the calls. 6 | A subset of the Switchboard-1 corpus consisting of 1155 conversations was used. The resulting tags include dialogue acts like statement-non-opinion, acknowledge, statement-opinion, agree/accept, etc. 7 | 8 | Annotated example: 9 | >Speaker: A, Dialogue Act: Yes-No-Question, Utterance: So do you go to college right now? 10 | 11 | The original dataset contained around 42 tags but here we brought them down to 38 by removing a few redundant and ad-hoc tags. 
The available tags: 12 | 13 | | TAG | EXAMPLE | 14 | |------------------------------|---------------------------------------------------| 15 | | Statement-non-opinion | *Me, I'm in the legal department.* | 16 | | Acknowledge (Backchannel) | *Uh-huh.* | 17 | | Statement-opinion | *I think it's great* | 18 | | Agree/Accept | *That's exactly it.* | 19 | | Appreciation | *I can imagine.* | 20 | | Yes-No-Question | *Do you have to have any special training?* | 21 | | Yes answers | *Yes.* | 22 | | Conventional-closing | *Well, it's been nice talking to you.* | 23 | | Uninterpretable | *But, uh, yeah* | 24 | | Wh-Question | *Well, how old are you?* | 25 | | No answers | *No.* | 26 | | Response Acknowledgement | *Oh, okay.* | 27 | | Hedge | *I don't know if I'm making any sense or not.* | 28 | | Declarative Yes-No-Question | *So you can afford to get a house?* | 29 | | Other | *Well give me a break, you know.* | 30 | | Backchannel in question form | *Is that right?* | 31 | | Quotation | *You can't be pregnant and have cats* | 32 | | Summarize/reformulate | *Oh, you mean you switched schools for the kids.* | 33 | | Affirmative non-yes answers | *It is.* | 34 | | Action-directive | *Why don't you go first* | 35 | | Collaborative Completion | *Who aren't contributing.* | 36 | | Repeat-phrase | *Oh, fajitas* | 37 | | Open-Question | *How about you?* | 38 | | Rhetorical-Questions | *Who would steal a newspaper?* | 39 | | Hold before answer/agreement | *I'm drawing a blank.* | 40 | | Negative non-no answers | *Uh, not a whole lot.* | 41 | | Signal-non-understanding | *Excuse me?* | 42 | | Conventional-opening | *How are you?* | 43 | | Or-Clause | *or is it more of a company?* | 44 | | Dispreferred answers | *Well, not so much that.* | 45 | | 3rd-party-talk | *My goodness, Diane, get down from there.* | 46 | | Offers, Options Commits | *I'll have to check that out* | 47 | | Self-talk | *What's the word I'm looking for* | 48 | | Downplayer | *That's all right.* | 49 | | 
Maybe/Accept-part | *Something like that* | 50 | | Tag-Question | *Right?* | 51 | | Declarative Wh-Question | *You are what kind of buff?* | 52 | | Apology | *I'm sorry.* | 53 | | Thanking | *Hey thanks a lot* | 54 | 55 | ## Installation 56 | 57 | We recommend Python 3.7 or higher, Tensorflow 2.0.0 or higher and Transformers v3.0.0 or higher. 58 | 59 | ### Install with pip 60 | 61 | Install DialogTag with pip: 62 | 63 | ``` 64 | pip install -U DialogTag 65 | ``` 66 | 67 | ### Install from sources 68 | 69 | Alternatively, you can also clone the latest version from the repository and install it directly from the source code: 70 | 71 | ``` 72 | pip install -e . 73 | ``` 74 | 75 | ## DialogTag in action 76 | 77 | This quick example will show you how to use DialogTag in your code. 78 | ``` 79 | from dialog_tag import DialogTag 80 | 81 | model = DialogTag('distilbert-base-uncased') 82 | 83 | sentence = "I'll probably go to shopping today." 84 | output = model.predict_tag(sentence) 85 | print(output) 86 | # output: 'Statement-non-opinion' 87 | 88 | sentence = "Why are you asking me this question again and again?" 89 | output = model.predict_tag(sentence) 90 | print(output) 91 | # output: 'Wh-Question' 92 | ``` 93 | 94 | ## Available pre-trained models 95 | 96 | Update: We now have bert-base-uncased model available! 97 | 98 | Currently, distilbert-base-uncased and bert-base-uncased are available. We're planning to extend support to roberta-base as well. Stay tuned for further updates! 99 | 100 | ## License 101 | 102 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details --------------------------------------------------------------------------------