├── .gitignore ├── LICENSE ├── Readme.md ├── dataset ├── __init__.py ├── nlu_QnA_converted.yml └── nlu_converted.yml ├── demo ├── __init__.py └── server.py ├── main.py ├── requirements.txt └── src ├── __init__.py ├── config.yml ├── data_reader ├── __init__.py ├── data_reader.py └── dataset.py ├── models ├── __init__.py ├── classifier.py ├── trainer.py └── wrapper.py └── utils └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
98 | __pypackages__/
99 | 
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 | 
104 | # SageMath parsed files
105 | *.sage.py
106 | 
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 | 
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 | 
120 | # Rope project settings
121 | .ropeproject
122 | 
123 | # mkdocs documentation
124 | /site
125 | 
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 | 
131 | # Pyre type checker
132 | .pyre/
133 | 
134 | # pytype static type analyzer
135 | .pytype/
136 | 
137 | # Cython debug symbols
138 | cython_debug/
139 | 
140 | venv/
141 | logs/
142 | results/
143 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Wei
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /Readme.md: --------------------------------------------------------------------------------
1 | # DIETClassifier - Pytorch
2 | 
3 | 
4 | 
5 | 
6 | DIETClassifier stands for Dual Intent and Entity Transformer, a model that performs intent classification and entity recognition at the same time.
7 | 
8 | - Built on the BERT architecture from Hugging Face Transformers
9 | - Wrapped in Python, with helper functions that read datasets from .yml files, build and train the model, and return predictions as dictionaries
10 | 
11 | # Requirements
12 | 
13 | * [transformers] - Library for using transformer models in NLP tasks
14 | * [pytorch] - Deep learning framework for Python
15 | * [fastapi] - Web framework used for the demo server
16 | 
17 | You can install all required packages with:
18 | ```sh
19 | git clone https://github.com/WeiNyn/DIETClassifier-pytorch.git
20 | cd DIETClassifier-pytorch/
21 | pip install -r requirements.txt
22 | ```
23 | 
24 | ### Demo
25 | 
26 | You can use the demo server to serve the model, receive text messages, and predict intents and entities:
27 | 
28 | - Download the pretrained model from [this link](https://drive.google.com/drive/folders/1cAucUHO0FP_I-_atSpbyRwKEiflPPN7v?usp=sharing)
29 | - Extract "latest_model" to "DIETClassifier-pytorch/"
30 | - Run
31 | ```sh
32 | uvicorn demo.server:app
33 | ```
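
Once the server is running, send the message as the `input` query parameter of the `POST /detect` endpoint. A minimal request, assuming uvicorn's default host and port (127.0.0.1:8000):

```sh
curl -X POST "http://127.0.0.1:8000/detect?input=How%20to%20check%20attendance"
```

The response is the prediction dictionary for that sentence (the demo server strips the `intent_ranking` field before returning it).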
34 | 
35 | ### Configuration
36 | 
37 | All project configuration is stored in the [config.yml] file
38 | ```yaml
39 | model:
40 |   model: latest_model
41 |   tokenizer: latest_model
42 |   dataset_folder: dataset
43 |   exclude_file: null
44 |   entities:
45 |     - working_type
46 |     - shift_type
47 |   intents:
48 |     - WorkTimesBreaches
49 |     - WorkingTimeBreachDiscipline
50 |     - HolidaysOff
51 |     - AnnualLeaveApplicationProcess
52 |     - SetWorkingType
53 |     - TemporarySetWorkingType
54 |     - WorkingHours
55 |     - WorkingDay
56 |     - BreakTime
57 |     - Pregnant
58 |     - AttendanceRecord
59 |     - SelectShiftType
60 |     - LaborContract
61 |     - Recruitment
62 |     - SickLeave
63 |     - UnpaidLeave
64 |     - PaidLeaveForFamilyEvent
65 |     - UnusedAnnualLeave
66 |     - RegulatedAnnualLeave
67 |     - rating
68 |   device: cuda
69 | training:
70 |   train_range: 0.95
71 |   num_train_epochs: 100
72 |   per_device_train_batch_size: 4
73 |   per_device_eval_batch_size: 4
74 |   warmup_steps: 500
75 |   weight_decay: 0.01
76 |   logging_dir: logs/
77 |   early_stopping_patience: 10
78 |   early_stopping_threshold: 0.0001
79 |   output_dir: results/
80 | util:
81 |   intent_threshold: 0.7
82 |   entities_threshold: 0.5
83 |   ambiguous_threshold: 0.2
84 | ```
85 | 
86 | | Attribute | Description |
87 | | --------- | ----------- |
88 | | model | name of a pretrained transformers model or path to a local model |
89 | | tokenizer | name of a pretrained transformers tokenizer or path to a local tokenizer |
90 | | dataset_folder | folder that contains the dataset files, in Rasa NLU format |
91 | | exclude_file | files in that folder that will not be used for training |
92 | | entities | list of entities |
93 | | intents | list of intents |
94 | | synonym | mapping from detected entity text to the value it should be replaced with |
95 | | device | device to use ("cpu", "cuda", "cuda:0", etc.) |
96 | | train_range | fraction of the dataset used for training; the rest is used for validation |
97 | | num_train_epochs | number of training epochs |
98 | | per_device_train/eval_batch_size | batch size for training/evaluation |
99 | | logging_dir | directory for log files (TensorBoard supported) |
100 | | early_stopping_patience/threshold | hyperparameters for early stopping |
101 | | output_dir | directory where checkpoints are saved during training |
102 | 
103 | ### Usage
104 | 
105 | You can use DIETClassifierWrapper for loading, training, and prediction in Python code:
106 | ```python
107 | from src.models.wrapper import DIETClassifierWrapper
108 | 
109 | config_file = "src/config.yml"
110 | wrapper = DIETClassifierWrapper(config=config_file)
111 | 
112 | # predict
113 | wrapper.predict(["How to check attendance?"])
114 | 
115 | # train
116 | # after training, the wrapper automatically loads the best model
117 | wrapper.train_model(save_folder="test_model")
118 | ```
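
`predict` returns one dictionary per input sentence. The keys come from `DIETClassifierWrapper.predict`; the values below are only illustrative (actual scores depend on the trained model):

```python
# Rough shape of wrapper.predict(["How to check attendance?"]); example values only
[
    {
        "intent": "AttendanceRecord",  # None if no intent passes the configured thresholds
        "intent_ranking": {"AttendanceRecord": 0.97, "WorkingHours": 0.01},  # one score per configured intent (abbreviated here)
        "entities": [],  # each entity: {"entity_name", "start", "end", "text"}
        "text": "How to check attendance?",
    }
]
```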
attendance?"]) 114 | 115 | #train 116 | #after training, wrapper will load best model automatically 117 | wrapper.train_model(save_folder="test_model") 118 | ``` 119 | 120 | You can also use DIETClassifier in src.models.classifier as huggingface transformers model 121 | ```python 122 | from src.models.classifier import DIETClassifier, DIETClassifierConfig 123 | 124 | config = DIETClassifierConfig(model="BERT-base-uncased", 125 | intents=[str(i) for i in range(10)], 126 | entities=[str(i) for i in range(5)]) 127 | 128 | model = DIETClassifier(config=config) 129 | 130 | ``` 131 | 132 | ### Notice 133 | 134 | * This DIETClassifier using BERT base as the base architect, if you want to change to RoBerta, ALBert, etc. You need to modify the DIETClassifier Class. 135 | * You can also use any BERT base pretrained from Huggingface transformers for creating and fine tune yourself 136 | * Please read the source code to understand how the dataset be created in case that you want to make dataset in another file format. 137 | * If you get the error: AttributeError: """'NoneType' object has no attribute 'detach'""", please check the issue #5 138 | 139 | [//]: # (These are reference links used in the body of this note and get stripped out when the markdown processor does its job. There is no need to format nicely because it shouldn't be seen. Thanks SO - http://stackoverflow.com/questions/4823468/store-comments-in-markdown-syntax) 140 | 141 | 142 | [dill]: 143 | [git-repo-url]: 144 | [john gruber]: 145 | [df1]: 146 | [markdown-it]: 147 | [Ace Editor]: 148 | [node.js]: 149 | [Twitter Bootstrap]: 150 | [jQuery]: 151 | [@tjholowaychuk]: 152 | [express]: 153 | [AngularJS]: 154 | [Gulp]: 155 | 156 | [PlDb]: 157 | [PlGh]: 158 | [PlGd]: 159 | [PlOd]: 160 | [PlMe]: 161 | [PlGa]: 162 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/nlu_QnA_converted.yml: -------------------------------------------------------------------------------- 1 | version: "2.0" 2 | nlu: 3 | - intent: WorkTimesBreaches 4 | examples: | 5 | - What is late in and early out? 6 | - what is late-in? 7 | - what is late-in and early-out? 8 | - what is early-out? 9 | - What is early out? 10 | - How about coming late to work? 11 | - Late in? 12 | - late-in? 13 | - early out? 14 | - early-out? 15 | - intent: WorkingTimeBreachDiscipline 16 | examples: | 17 | - What if I have late-in record? 18 | - What if I have early-out record? 19 | - What if I late-in 20 | - What if I late in 21 | - What if I come late 22 | - What if I early-out 23 | - What if I early out 24 | - What if I come home early 25 | - How many time can I early-out 26 | - How many time can I late-in 27 | - How many time can I early-out and late-in 28 | - How many time can I late in 29 | - intent: HolidaysOff 30 | examples: | 31 | - How many holidays a year 32 | - Can I get information about holidays 33 | - Do we have holidays? 
34 | - What is the time for new year vacation 35 | - How about lunar new year vacation 36 | - lunar new year 37 | - new year vacation 38 | - Hung king day 39 | - holiday 40 | - vacation 41 | - I want to know about holiday of years 42 | - How many holiday in a year 43 | - intent: AnnualLeaveApplicationProcess 44 | examples: | 45 | - How about annual leave? 46 | - annual leave? 47 | - annual leave process 48 | - How about annual leave process 49 | - How about annual leave application process 50 | - intent: SetWorkingType 51 | examples: | 52 | - I work on [office hours](working_type) 53 | - I work on [office hour]{"entity": "working_type", "value": "office hours"} 54 | - I'm working on [office hours](working_type) 55 | - I'm working on [office hour]{"entity": "working_type", "value": "office hours"} 56 | - I work on [shift](working_type) 57 | - I'm working on [shift](working_type) 58 | - I am working on [office hours](working_type) 59 | - I am working on [office hour]{"entity": "working_type", "value": "office hours"} 60 | - on [shift](working_type) 61 | - on [office hours](working_type) 62 | - [shift](working_type) 63 | - [office hours](working_type) 64 | - oh, I work on [office hours](working_type) 65 | - how about working on [shift](working_type) 66 | - today 67 | - intent: TemporarySetWorkingType 68 | examples: | 69 | - How about [office hours](working_type) 70 | - How about [office hour]{"entity": "working_type", "value": "office hours"} 71 | - How about [shift](working_type) 72 | - What if I work on [office hours](working_type) 73 | - What if I work on [office hour]{"entity": "working_type", "value": "office hours"} 74 | - What if I work on [shift](working_type) 75 | - how about work on [shift](working_type) 76 | - how about [office hours](working_type) 77 | - how about [shift](working_type) 78 | - intent: WorkingHours 79 | examples: | 80 | - What are average working hours per day? 81 | - Average working hours per day? 82 | - What are average woring hours per week? 83 | - Average working hours per week? 84 | - What are average woring hours per month? 85 | - Average working hours per month? 86 | - How about working time? 87 | - Working time? 88 | - What is office-working time? 89 | - Can you tell me about working time? 90 | - I want to know about working hours 91 | - About working hours 92 | - Working time in day 93 | - working hours 94 | - how about [office](working_type) working hours 95 | - how about [office hours](working_type) working time ? 96 | - How about working hours 97 | - about working hours 98 | - how about working hours 99 | - so, what is the average working hours 100 | - intent: WorkingDay 101 | examples: | 102 | - How many work day per week? 103 | - About off day of week? 104 | - off day? 105 | - Working day of week? 106 | - intent: BreakTime 107 | examples: | 108 | - What is break time? 109 | - Do we have break time? 110 | - How about break time? 111 | - When do we have lunch 112 | - When can I break 113 | - Break time? 114 | - Can I sleep? 115 | - how about break time ? 116 | - When do I have break time 117 | - about break time 118 | - When do we have break time 119 | - how about break time 120 | - intent: Pregnant 121 | examples: | 122 | - Working time for pregnant employee? 123 | - I'm pregnant? 
124 | - Will this change when I'm pregnant 125 | - but I am pregnant 126 | - How about pregnant employee 127 | - Working time for pregnant employees 128 | - pregnant 129 | - pregnant policy 130 | - Do this thing change when I'm pregnant 131 | - I'm pregnant 132 | - intent: AttendanceRecord 133 | examples: | 134 | - How about attendence record? 135 | - About checking attendence? 136 | - How to check attendence? 137 | - How can I check attendence? 138 | - I mean check attendace 139 | - I want to know about attendance record 140 | - About attendance recorded 141 | - How to check attendance 142 | - intent: SelectShiftType 143 | examples: | 144 | - [day shift](shift_type) 145 | - [afternoon shift](shift_type) 146 | - [night shift](shift_type) 147 | - [night shift](shift_type) please 148 | - [afternoon shift](shift_type) please 149 | - intent: LaborContract 150 | examples: | 151 | - how about labor contract 152 | - labor contract 153 | - time for definite labor contract 154 | - time for indefinite labor contract 155 | - Time for indefinite labor contract 156 | - intent: Recruitment 157 | examples: | 158 | - How about recruitment policy 159 | - recruitment process 160 | - intent: SickLeave 161 | examples: | 162 | - How about sick leave 163 | - Sick leave 164 | - How many day of sick leave 165 | - What if I got a cold 166 | - Will I get paid for sick leave 167 | - How to get paid for sick leave 168 | - intent: UnpaidLeave 169 | examples: | 170 | - How many unpaid leave can I get 171 | - Can I leave after using all of my annual leave day 172 | - Unpaid leave 173 | - How about unpaid leave 174 | - unpaid leave please 175 | - Can I leave after using all of my annual leave day 176 | - Can I leave after using all of my annual leave day 177 | - intent: PaidLeaveForFamilyEvent 178 | examples: | 179 | - Can I leave for family events 180 | - Did I get paid for my family events 181 | - Will I get paid for my family events 182 | - Will I get paid for Marriage 183 | - Will I get paid for the dead of family member 184 | - Can I leave for my Marriage 185 | - Can I get paid for family event off 186 | - intent: UnusedAnnualLeave 187 | examples: | 188 | - What can I do with my unused annual leave day 189 | - What can I do with my unused day 190 | - Can I exchange my annual leave day 191 | - How can I use my unused annual leave day 192 | - what can I do with my unused annual leave day 193 | - What can I do with my unused annual leave 194 | - intent: RegulatedAnnualLeave 195 | examples: | 196 | - How many annual leave for employee? 197 | - How about number of annual leave day? 198 | - How many off day I have each year? 
199 | - Annual leave day in year 200 | - Annual leave days 201 | - how many annual leave day each year 202 | - how many annual leave 203 | - how about annual leave day 204 | - How many annual leave I have 205 | - synonym: office hours 206 | examples: | 207 | - office hour 208 | - office 209 | - intent: rating 210 | examples: | 211 | - [1](number) 212 | - [2](number) 213 | - [3](number) 214 | - [4](number) 215 | - [5](number) 216 | -------------------------------------------------------------------------------- /dataset/nlu_converted.yml: -------------------------------------------------------------------------------- 1 | version: "2.0" 2 | nlu: 3 | - intent: greet 4 | examples: | 5 | - hey 6 | - hello 7 | - hi 8 | - good morning 9 | - good evening 10 | - hey there 11 | - intent: goodbye 12 | examples: | 13 | - bye 14 | - goodbye 15 | - see you around 16 | - see you later 17 | - quit 18 | - go home 19 | - intent: affirm 20 | examples: | 21 | - yes 22 | - indeed 23 | - of course 24 | - that sounds good 25 | - correct 26 | - intent: deny 27 | examples: | 28 | - no 29 | - never 30 | - I don't think so 31 | - don't like that 32 | - no way 33 | - not really 34 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/demo/__init__.py -------------------------------------------------------------------------------- /demo/server.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.middleware.cors import CORSMiddleware 3 | from fastapi.encoders import jsonable_encoder 4 | from fastapi.responses import JSONResponse 5 | 6 | import sys 7 | import os 8 | 9 | sys.path.append(os.getcwd()) 10 | 11 | from src.models.wrapper import DIETClassifierWrapper 12 | 13 | CONFIG_FILE = "src/config.yml" 14 | 15 | wrapper = DIETClassifierWrapper(CONFIG_FILE) 16 | 17 | app = FastAPI() 18 | 19 | origins = ["*"] 20 | 21 | app.add_middleware( 22 | CORSMiddleware, 23 | allow_origins=origins, 24 | allow_credentials=True, 25 | allow_methods=["*"], 26 | allow_headers=["*"], 27 | ) 28 | 29 | 30 | @app.post("/detect") 31 | async def detect(input: str): 32 | output = wrapper.predict([input])[0] 33 | 34 | del output["intent_ranking"] 35 | 36 | response = jsonable_encoder(output) 37 | return JSONResponse(response) 38 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # This is a sample Python script. 2 | 3 | # Press Shift+F10 to execute it or replace it with your code. 4 | # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. 5 | 6 | 7 | def print_hi(name): 8 | # Use a breakpoint in the code line below to debug your script. 9 | print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. 10 | 11 | 12 | # Press the green button in the gutter to run the script. 
13 | if __name__ == '__main__': 14 | print_hi('PyCharm') 15 | 16 | # See PyCharm help at https://www.jetbrains.com/help/pycharm/ 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.63.0 2 | pandas==1.2.0 3 | PyYAML==5.3.1 4 | torch==1.7.0 5 | transformers==4.2.1 6 | uvicorn==0.13.3 7 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/src/__init__.py -------------------------------------------------------------------------------- /src/config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | model: latest_model 3 | tokenizer: latest_model 4 | dataset_folder: dataset 5 | exclude_file: null 6 | entities: 7 | - working_type 8 | - shift_type 9 | intents: 10 | - WorkTimesBreaches 11 | - WorkingTimeBreachDiscipline 12 | - HolidaysOff 13 | - AnnualLeaveApplicationProcess 14 | - SetWorkingType 15 | - TemporarySetWorkingType 16 | - WorkingHours 17 | - WorkingDay 18 | - BreakTime 19 | - Pregnant 20 | - AttendanceRecord 21 | - SelectShiftType 22 | - LaborContract 23 | - Recruitment 24 | - SickLeave 25 | - UnpaidLeave 26 | - PaidLeaveForFamilyEvent 27 | - UnusedAnnualLeave 28 | - RegulatedAnnualLeave 29 | - rating 30 | - greet 31 | - goodbye 32 | - affirm 33 | - deny 34 | device: cuda 35 | synonym: 36 | office hour: office hours 37 | office: office hours 38 | office hours: office hour 39 | training: 40 | train_range: 0.95 41 | num_train_epochs: 100 42 | per_device_train_batch_size: 4 43 | per_device_eval_batch_size: 4 44 | warmup_steps: 500 45 | weight_decay: 0.01 46 | logging_dir: logs/ 47 | early_stopping_patience: 10 48 | early_stopping_threshold: 0.0001 49 | output_dir: results/ 50 | util: 51 | intent_threshold: 0.7 52 | entities_threshold: 0.5 53 | ambiguous_threshold: 0.2 54 | -------------------------------------------------------------------------------- /src/data_reader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/src/data_reader/__init__.py -------------------------------------------------------------------------------- /src/data_reader/data_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import List, Dict, Tuple 4 | 5 | import pandas as pd 6 | import yaml 7 | 8 | """regex filter for preprocessing text data""" 9 | NORMAL_REGEX = "\[[\w+\s*]+\]\([\w+\s*]+\)" 10 | ENTITY_REGEX = "\[[\w+\s*]+\]" 11 | ENTITY_NAME_REGEX = "\([\w+\s*]+\)" 12 | SYNONYM_REGEX = "\[[\w+\s*]+\]\{.+\}" 13 | DATA_REGEX = "\{.+\}" 14 | 15 | 16 | def make_dataframe(files: List[str]) -> Tuple[pd.DataFrame, List[str], List[str], Dict[str, str]]: 17 | """ 18 | Make data frame for DIETClassifier dataset from files list 19 | 20 | :param files: list of files location 21 | :return: tuple(dataframe, list of entities class name, list of intent class name, synonym dictionary) 22 | """ 23 | 24 | data = [] 25 | synonym_dict = {} 26 | for file in files: 27 | data += read_from_yaml(file=file) 28 | 29 | df = pd.DataFrame(columns=["example", "intent", "entities"]) 30 | 31 
| for intent in data: 32 | if not intent.get("intent", None): 33 | if intent.get("synonym", None): 34 | target_entity = intent["synonym"] 35 | 36 | entities_list_text = intent["examples"] 37 | if entities_list_text[:2] == "- ": 38 | entities_list_text = entities_list_text[2:] 39 | if entities_list_text[-1:] == "\n": 40 | entities_list_text = entities_list_text[:-1] 41 | 42 | synonym_entities_list = entities_list_text.split("\n- ") 43 | for entity in synonym_entities_list: 44 | synonym_dict[entity] = target_entity 45 | 46 | continue 47 | 48 | intent_name = intent["intent"] 49 | examples_as_text = intent["examples"] 50 | 51 | if examples_as_text[:2] == "- ": 52 | examples_as_text = examples_as_text[2:] 53 | if examples_as_text[-1:] == "\n": 54 | examples_as_text = examples_as_text[:-1] 55 | 56 | examples = examples_as_text.split("\n- ") 57 | 58 | new_data = dict( 59 | intent=[intent_name for i in range(len(examples))], 60 | example=examples, 61 | entities=[None for i in range(len(examples))] 62 | ) 63 | 64 | df = df.append(pd.DataFrame(data=new_data), ignore_index=True) 65 | 66 | df = get_entity(df=df) 67 | df, updated_synonym_dict = get_entity_with_synonym(df=df) 68 | 69 | synonym_dict.update(updated_synonym_dict) 70 | 71 | entities_list = [] 72 | intents_list = [] 73 | for _, row in df.iterrows(): 74 | entity_data = row["entities"] 75 | if isinstance(entity_data, str): 76 | try: 77 | entity_data = json.loads(entity_data) 78 | except Exception as ex: 79 | raise RuntimeError(f"Cannot convert entity_data to json: {entity_data}") 80 | 81 | for entity in entity_data: 82 | entity_name = entity.get("entity_name") 83 | if entity_name not in entities_list: 84 | entities_list.append(entity_name) 85 | 86 | if row["intent"] not in intents_list: 87 | intents_list.append(row["intent"]) 88 | 89 | return df, entities_list, intents_list, synonym_dict 90 | 91 | 92 | def read_from_yaml(file: str) -> List[Dict[str, str]]: 93 | """ 94 | Read data from .yml file 95 | 96 | :param file: file location (this data file need to follow the rasa nlu annotation format) 97 | :return: list(dict(text, any)) 98 | """ 99 | try: 100 | f = open(file, "r") 101 | except Exception as ex: 102 | raise RuntimeError(f"Cannot read file {file} with error:\t{ex}") 103 | 104 | data = yaml.load(f)["nlu"] 105 | return data 106 | 107 | 108 | def get_entity(df: pd.DataFrame) -> pd.DataFrame: 109 | """ 110 | extract entities in example sentences 111 | 112 | :param df: dataframe to process 113 | :return: precessed dataframe 114 | """ 115 | for _, row in df.iterrows(): 116 | entity_data = row["entities"] 117 | if not entity_data: 118 | entity_data = [] 119 | 120 | while True: 121 | example = row["example"] 122 | x = re.search(NORMAL_REGEX, example) 123 | if x is None: 124 | break 125 | 126 | start, end = x.span() 127 | entity = x.group() 128 | 129 | entity_text = re.search(ENTITY_REGEX, entity).group()[1:-1] 130 | entity_name_text = re.search(ENTITY_NAME_REGEX, entity).group()[1:-1] 131 | 132 | row["example"] = example.replace(entity, entity_text) 133 | 134 | entity_data.append(dict( 135 | entity=entity_text, 136 | entity_name=entity_name_text, 137 | position=(start, end - (len(entity) - len(entity_text))) 138 | )) 139 | 140 | row["entities"] = entity_data 141 | 142 | return df 143 | 144 | 145 | def get_entity_with_synonym(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, str]]: 146 | """ 147 | Extract entities with synonym in dataframe. 
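Synonym annotations in the NLU examples look like [office hour]{"entity": "working_type", "value": "office hours"}: "entity" gives the entity name and "value" gives the synonym value.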
148 | 149 | :param df: dataframe to process 150 | :return: tuple(processed dataframe, synonym dictionary) 151 | """ 152 | synonym_dict = {} 153 | for _, row in df.iterrows(): 154 | entity_data = row["entities"] 155 | if not entity_data: 156 | entity_data = [] 157 | 158 | if isinstance(entity_data, str): 159 | try: 160 | entity_data = json.loads(entity_data) 161 | except Exception as ex: 162 | raise RuntimeError(f"Cannot convert entity_data to json: {entity_data}") 163 | 164 | while True: 165 | example = row["example"] 166 | x = re.search(SYNONYM_REGEX, example) 167 | if x is None: 168 | break 169 | 170 | start, end = x.span() 171 | entity = x.group() 172 | 173 | entity_text = re.search(ENTITY_REGEX, entity).group()[1:-1] 174 | synonym_text = re.search(DATA_REGEX, entity).group() 175 | 176 | try: 177 | synonym_data = json.loads(synonym_text) 178 | except Exception as ex: 179 | raise ValueError(f"Synonym json is incorrect: {synonym_text}") 180 | 181 | entity_name_text = synonym_data.get("entity", None) 182 | synonym_value = synonym_data.get("value", None) 183 | 184 | if entity_name_text is None or synonym_value is None: 185 | raise ValueError(f"synonym data should have 'entity' and 'value' attributes") 186 | 187 | row["example"] = example.replace(entity, entity_text) 188 | 189 | entity_data.append(dict( 190 | entity=entity_text, 191 | entity_name=entity_name_text, 192 | position=(start, end - (len(entity) - len(entity_text))), 193 | synonym=synonym_value 194 | )) 195 | 196 | synonym_dict[synonym_value] = entity_text 197 | 198 | row["entities"] = entity_data 199 | 200 | return df, synonym_dict 201 | 202 | 203 | if __name__ == '__main__': 204 | import os 205 | import sys 206 | 207 | sys.path.append(os.getcwd()) 208 | 209 | files = ["dataset/nlu_QnA_converted.yml", "dataset/nlu_QnA_converted.yml"] 210 | 211 | df, entities_list, intents_list, synonym_dict = make_dataframe(files) 212 | print(df.head(20)) 213 | print(entities_list) 214 | print(intents_list) 215 | print(synonym_dict) 216 | -------------------------------------------------------------------------------- /src/data_reader/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Dict, Text, Any, Union 3 | 4 | import pandas as pd 5 | import torch 6 | 7 | 8 | class DIETClassifierDataset: 9 | def __init__(self, dataframe: pd.DataFrame, tokenizer, entities: List[str], intents: List[str]): 10 | """ 11 | dataset for DIETClassifier 12 | 13 | :param dataframe: dataframe contains ["example", "intent", "entities"] columns 14 | :param tokenizer: tokenizer from transformers 15 | :param entities: list of entities class names 16 | :param intents: list of intents class names 17 | """ 18 | dataframe = dataframe[dataframe["intent"].isin(intents)] 19 | 20 | self.entities = ["O"] + entities 21 | dataframe["entities"] = dataframe["entities"].apply(self._remove_entities) 22 | 23 | self.tokenizer = tokenizer 24 | self.num_entities = len(self.entities) 25 | self.intents = intents 26 | self.num_intents = len(intents) 27 | 28 | sentences = dict( 29 | sentence=[], 30 | entities=[], 31 | intent=[] 32 | ) 33 | 34 | for _, row in dataframe.iterrows(): 35 | sentences["sentence"].append(row["example"]) 36 | sentences["entities"].append(row["entities"]) 37 | sentences["intent"].append(row["intent"]) 38 | 39 | sentences.update(tokenizer(sentences["sentence"], return_tensors="pt", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=512)) 40 | 41 | 
sentences["entities_labels"] = [] 42 | 43 | for index in range(len(sentences["sentence"])): 44 | entities_labels = [] 45 | for offset in sentences["offset_mapping"][index][1:]: 46 | is_label = False 47 | if not (offset[0] == 0 and offset[1] == 0): 48 | for entity in sentences["entities"][index]: 49 | if entity["position"][0] <= offset[0] and entity["position"][1] >= offset[1]: 50 | entities_labels.append(self.entities.index(entity["entity_name"])) 51 | is_label = True 52 | if not is_label: 53 | entities_labels.append(self.entities.index("O")) 54 | 55 | sentences["entities_labels"].append(entities_labels) 56 | 57 | sentences["entities_labels"] = torch.tensor(sentences["entities_labels"]) 58 | sentences["intent_labels"] = torch.tensor([self.intents.index(intent) for intent in sentences["intent"]]) 59 | 60 | self.data = sentences 61 | 62 | def __len__(self) -> int: 63 | return len(self.data["sentence"]) 64 | 65 | def __getitem__(self, index) -> Dict[Text, Any]: 66 | item = dict( 67 | input_ids=self.data["input_ids"][index], 68 | token_type_ids=self.data["token_type_ids"][index], 69 | attention_mask=self.data["attention_mask"][index], 70 | entities_labels=self.data["entities_labels"][index], 71 | intent_labels=self.data["intent_labels"][index] 72 | ) 73 | 74 | return item 75 | 76 | def _remove_entities(self, entities_list: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: 77 | if isinstance(entities_list, str): 78 | try: 79 | entities_list = json.loads(entities_list) 80 | except Exception as ex: 81 | raise RuntimeError(f"Cannot convert entity {entities_list} by error: {ex}") 82 | 83 | entities_list = [entity for entity in entities_list if entity["entity_name"] in self.entities] 84 | 85 | return entities_list 86 | 87 | 88 | if __name__ == "__main__": 89 | import os 90 | import sys 91 | 92 | sys.path.append(os.getcwd()) 93 | 94 | from src.data_reader.data_reader import make_dataframe 95 | from transformers import AutoTokenizer 96 | 97 | files = ["dataset/nlu_QnA_converted.yml", "dataset/nlu_QnA_converted.yml"] 98 | tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") 99 | 100 | df, entities_list, intents_list, synonym_dict = make_dataframe(files) 101 | dataset = DIETClassifierDataset(dataframe=df, tokenizer=tokenizer, entities=entities_list, intents=intents_list) 102 | 103 | print(len(dataset)) 104 | print(dataset[120]) 105 | print(dataset.data["entities"][120]) 106 | print(dataset.data["offset_mapping"][120]) 107 | 108 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | 7 | from transformers import BertForTokenClassification, BertPreTrainedModel, BertModel 8 | from transformers.configuration_utils import PretrainedConfig 9 | 10 | from os import path 11 | import json 12 | 13 | 14 | class DIETClassifierConfig(PretrainedConfig): 15 | def __init__(self, model: str, entities: List[str] = None, intents: List[str] = None): 16 | super().__init__() 17 | self.model = model 18 | self.entities = entities 19 | 
self.intents = intents 20 | self.hidden_dropout_prob = None 21 | self.hidden_size = None 22 | 23 | 24 | class DIETClassifier(BertPreTrainedModel): 25 | def __init__(self, config: DIETClassifierConfig): 26 | """ 27 | Create DIETClassifier model 28 | 29 | :param config: config for model 30 | """ 31 | if path.exists(config.model): 32 | try: 33 | json_config = json.load(open(f"{config.model}/config.json", "r")) 34 | except Exception as ex: 35 | raise RuntimeError(f"Cannot load configuration fil from {config.model} by error: {ex}") 36 | 37 | try: 38 | checkpoint = torch.load(f"{config.model}/pytorch_model.bin") 39 | except Exception as ex: 40 | raise RuntimeError(f"Cannot load model from {config.model} by error: {ex}") 41 | 42 | pretrained_model = None 43 | config = PretrainedConfig.from_dict(json_config) 44 | else: 45 | pretrained_model = BertForTokenClassification.from_pretrained(config.model) 46 | checkpoint = None 47 | if config.intents is None or config.entities is None: 48 | raise ValueError(f"Using pretrained from transformers should specific entities and intents") 49 | pretrained_model.config.update({"model": config.model, "entities": config.entities, "intents": config.intents}) 50 | config = pretrained_model.config 51 | 52 | super().__init__(config) 53 | 54 | self.entities_list = ["O"] + config.entities 55 | self.num_entities = len(self.entities_list) 56 | self.intents_list = config.intents 57 | self.num_intents = len(self.intents_list) 58 | 59 | self.bert = BertModel(config, add_pooling_layer=False) if not pretrained_model else pretrained_model.bert 60 | 61 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 62 | 63 | self.entities_classifier = nn.Linear(config.hidden_size, self.num_entities) 64 | self.intents_classifier = nn.Linear(config.hidden_size, self.num_intents) 65 | 66 | self.init_weights() 67 | 68 | if not pretrained_model: 69 | try: 70 | self.load_state_dict(checkpoint) 71 | except Exception as ex: 72 | raise RuntimeError(f"Cannot load state dict from checkpoint by error: {ex}") 73 | 74 | def forward( 75 | self, 76 | input_ids=None, 77 | attention_mask=None, 78 | token_type_ids=None, 79 | position_ids=None, 80 | head_mask=None, 81 | inputs_embeds=None, 82 | intent_labels=None, 83 | entities_labels=None, 84 | output_attentions=None, 85 | output_hidden_states=None, 86 | return_dict=None, 87 | ): 88 | """ 89 | training model if entities_labels and intent_labels are passed, else inference 90 | 91 | :param input_ids: embedding ids of tokens 92 | :param attention_mask: attention_mask 93 | :param token_type_ids: token_type_ids 94 | :param position_ids: position_ids (optional) 95 | :param head_mask: head_mask (optional) 96 | :param inputs_embeds: inputs_embeds (optional) 97 | :param intent_labels: labels of intent 98 | :param entities_labels: labels of entities 99 | :param output_attentions: return attention weight or not 100 | :param output_hidden_states: return hidden_states or not 101 | :param return_dict: return dictionary or not 102 | :return: 103 | """ 104 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 105 | 106 | outputs = self.bert( 107 | input_ids, 108 | attention_mask=attention_mask, 109 | token_type_ids=token_type_ids, 110 | position_ids=position_ids, 111 | head_mask=head_mask, 112 | inputs_embeds=inputs_embeds, 113 | output_attentions=output_attentions, 114 | output_hidden_states=output_hidden_states, 115 | return_dict=return_dict, 116 | ) 117 | 118 | sequence_output = outputs[0][:, 1:] 119 | sequence_output = 
self.dropout(sequence_output) 120 | 121 | pooled_output = outputs[0][:, :1] 122 | pooled_output = self.dropout(pooled_output) 123 | 124 | entities_logits = self.entities_classifier(sequence_output) 125 | intent_logits = self.intents_classifier(pooled_output) 126 | 127 | entities_loss = None 128 | if entities_labels is not None: 129 | entities_loss_fct = CrossEntropyLoss() 130 | # Only keep active parts of the loss 131 | if attention_mask is not None: 132 | active_loss = attention_mask[:, 1:].reshape(-1) == 1 133 | active_logits = entities_logits.view(-1, self.num_entities) 134 | active_labels = torch.where( 135 | active_loss, entities_labels.view(-1), 136 | torch.tensor(entities_loss_fct.ignore_index).type_as(entities_labels) 137 | ) 138 | entities_loss = entities_loss_fct(active_logits, active_labels) 139 | else: 140 | entities_loss = entities_loss_fct(entities_logits.view(-1, self.num_entities), entities_labels.view(-1)) 141 | 142 | intent_loss = None 143 | if intent_labels is not None: 144 | if self.num_intents == 1: 145 | intent_loss_fct = MSELoss() 146 | intent_loss = intent_loss_fct(intent_logits.view(-1), intent_labels.view(-1)) 147 | else: 148 | intent_loss_fct = CrossEntropyLoss() 149 | intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intents), intent_labels.view(-1)) 150 | 151 | if (entities_labels is not None) and (intent_labels is not None): 152 | loss = entities_loss * 0.1 + intent_loss * 0.9 153 | else: 154 | loss = None 155 | 156 | if self.training: 157 | return_dict = True 158 | 159 | if not return_dict: 160 | output = (loss,) + outputs[2:] 161 | return ((loss,) + output) if ((entities_loss is not None) and (intent_loss is not None)) else output 162 | 163 | return dict( 164 | entities_loss=entities_loss, 165 | intent_loss=intent_loss, 166 | loss=loss, 167 | logits=(entities_logits, intent_logits), 168 | hidden_states=outputs.hidden_states, 169 | attentions=outputs.attentions, 170 | ) 171 | 172 | 173 | if __name__ == '__main__': 174 | import os 175 | import sys 176 | 177 | sys.path.append(os.getcwd()) 178 | 179 | from src.data_reader.data_reader import make_dataframe 180 | from src.data_reader.dataset import DIETClassifierDataset 181 | from transformers import AutoTokenizer 182 | 183 | files = ["dataset/nlu_QnA_converted.yml", "dataset/nlu_QnA_converted.yml"] 184 | tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") 185 | 186 | df, entities_list, intents_list, synonym_dict = make_dataframe(files) 187 | dataset = DIETClassifierDataset(dataframe=df, tokenizer=tokenizer, entities=entities_list, intents=intents_list) 188 | 189 | config = DIETClassifierConfig( 190 | model="dslim/bert-base-NER", 191 | entities=entities_list, 192 | intents=intents_list 193 | ) 194 | model = DIETClassifier(config=config) 195 | 196 | sentences = ["What if I'm late"] 197 | 198 | inputs = tokenizer(sentences, return_tensors="pt", padding="max_length", max_length=512) 199 | outputs = model(**{k: v for k, v in inputs.items()}) 200 | 201 | print(outputs) 202 | -------------------------------------------------------------------------------- /src/models/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import Trainer, TrainingArguments, EarlyStoppingCallback 4 | from transformers.integrations import TensorBoardCallback 5 | from torch.utils.data import random_split 6 | import torch 7 | 8 | 9 | class DIETTrainer: 10 | def __init__(self, model, dataset, train_range: 0.95, output_dir: 
str = "results", num_train_epochs: int = 100, per_device_train_batch_size: int = 4, 11 | per_device_eval_batch_size: int = 4, warmup_steps: int = 500, weight_decay: float = 0.01, 12 | logging_dir: str = "logs", early_stopping_patience: int = 20, early_stopping_threshold: float = 1e-5): 13 | """ 14 | Create DIETTrainer class 15 | 16 | :param model: model to train 17 | :param dataset: dataset (including train and eval) 18 | :param train_range: percentage of training dataset 19 | :param output_dir: model output directory 20 | :param num_train_epochs: number of training epochs 21 | :param per_device_train_batch_size: batch_size of training stage 22 | :param per_device_eval_batch_size: batch_size of evaluating stage 23 | :param warmup_steps: warmup steps 24 | :param weight_decay: weight decay 25 | :param logging_dir: logging directory 26 | """ 27 | self.training_args = TrainingArguments(output_dir=output_dir, 28 | num_train_epochs=num_train_epochs, 29 | per_device_train_batch_size=per_device_train_batch_size, 30 | per_device_eval_batch_size=per_device_eval_batch_size, 31 | warmup_steps=warmup_steps, 32 | weight_decay=weight_decay, 33 | logging_dir=logging_dir, 34 | load_best_model_at_end=True, 35 | metric_for_best_model="loss", 36 | greater_is_better=False, 37 | evaluation_strategy="epoch", 38 | label_names=["entities_labels", "intent_labels"], 39 | save_total_limit=1) 40 | 41 | train_dataset, eval_dataset = random_split(dataset, [int(len(dataset)*train_range), len(dataset) - int(len(dataset)*train_range)], generator=torch.Generator().manual_seed(42)) 42 | 43 | self.trainer = Trainer( 44 | model=model, 45 | args=self.training_args, 46 | train_dataset=train_dataset, 47 | eval_dataset=eval_dataset, 48 | callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience, early_stopping_threshold=early_stopping_threshold), TensorBoardCallback()] 49 | ) 50 | 51 | def train(self): 52 | self.trainer.train() 53 | 54 | 55 | if __name__ == '__main__': 56 | import os 57 | import sys 58 | 59 | sys.path.append(os.getcwd()) 60 | 61 | from src.data_reader.data_reader import make_dataframe 62 | from src.data_reader.dataset import DIETClassifierDataset 63 | from src.models.classifier import DIETClassifier, DIETClassifierConfig 64 | from transformers import AutoTokenizer 65 | 66 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 67 | 68 | files = ["dataset/nlu_QnA_converted.yml", "dataset/nlu_QnA_converted.yml"] 69 | tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") 70 | 71 | df, entities_list, intents_list, synonym_dict = make_dataframe(files) 72 | entities_list = [entity for entity in entities_list if entity != "number"] 73 | print(f"ENTITIES_LIST: {entities_list}") 74 | dataset = DIETClassifierDataset(dataframe=df, tokenizer=tokenizer, entities=entities_list, intents=intents_list) 75 | 76 | config = DIETClassifierConfig(model="dslim/bert-base-NER", entities=entities_list, intents=intents_list) 77 | model = DIETClassifier(config=config) 78 | 79 | sentences = ["What if I'm late"] 80 | 81 | inputs = tokenizer(sentences, return_tensors="pt", padding="max_length", max_length=512) 82 | 83 | outputs = model(**{k: v for k, v in inputs.items()}) 84 | 85 | trainer = DIETTrainer(model=model, dataset=dataset, train_range=0.95) 86 | 87 | trainer.train() 88 | -------------------------------------------------------------------------------- /src/models/wrapper.py: -------------------------------------------------------------------------------- 1 | from os import 
path, listdir 2 | from typing import Union, Dict, List, Any, Tuple 3 | 4 | import torch 5 | import yaml 6 | from transformers import BertTokenizerFast 7 | 8 | import os 9 | import sys 10 | 11 | sys.path.append(os.getcwd()) 12 | 13 | from src.models.classifier import DIETClassifier, DIETClassifierConfig 14 | from src.models.trainer import DIETTrainer 15 | from src.data_reader.dataset import DIETClassifierDataset 16 | from src.data_reader.data_reader import make_dataframe 17 | 18 | 19 | class DIETClassifierWrapper: 20 | """Wrapper for DIETClassifier.""" 21 | def __init__(self, config: Union[Dict[str, Dict[str, Any]], str]): 22 | """ 23 | Create wrapper with configuration. 24 | 25 | :param config: config in dictionary format or path to config file (.yml) 26 | """ 27 | if isinstance(config, str): 28 | try: 29 | f = open(config, "r") 30 | except Exception as ex: 31 | raise RuntimeError(f"Cannot read config file from {config}: {ex}") 32 | self.config_file_path = config 33 | config = yaml.load(f) 34 | 35 | self.config = config 36 | self.util_config = config.get("util", None) 37 | 38 | model_config_dict = config.get("model", None) 39 | if not model_config_dict: 40 | raise ValueError(f"Config file should have 'model' attribute") 41 | 42 | self.dataset_config = model_config_dict 43 | 44 | if model_config_dict["device"] is not None: 45 | self.device = torch.device(model_config_dict["device"]) if torch.cuda.is_available() else torch.device( 46 | "cpu") 47 | 48 | model_config_attributes = ["model", "intents", "entities"] 49 | # model_config_dict = {k: v for k, v in model_config_dict.items() if k in model_config_attributes} 50 | 51 | self.intents = model_config_dict["intents"] 52 | self.entities = ["O"] + model_config_dict["entities"] 53 | 54 | self.model_config = DIETClassifierConfig(**{k: v for k, v in model_config_dict.items() if k in model_config_attributes}) 55 | 56 | training_config_dict = config.get("training", None) 57 | if not training_config_dict: 58 | raise ValueError(f"Config file should have 'training' attribute") 59 | 60 | self.training_config = training_config_dict 61 | self.tokenizer = BertTokenizerFast.from_pretrained(model_config_dict["tokenizer"]) 62 | self.model = DIETClassifier(config=self.model_config) 63 | 64 | self.model.to(self.device) 65 | 66 | self.softmax = torch.nn.Softmax(dim=-1) 67 | 68 | self.synonym_dict = {} if not model_config_dict.get("synonym") else model_config_dict["synonym"] 69 | 70 | def tokenize(self, sentences) -> Tuple[Dict[str, Any], List[List[Tuple[int, int]]]]: 71 | """ 72 | Tokenize sentences using tokenizer. 
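The encoded inputs are moved to the configured device; the offset mapping is returned separately so entity spans can later be mapped back to character positions.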
73 | :param sentences: list of sentences 74 | :return: tuple(tokenized sentences, offset_mapping for sentences) 75 | """ 76 | inputs = self.tokenizer(sentences, return_tensors="pt", return_attention_mask=True, return_token_type_ids=True, 77 | return_offsets_mapping=True, 78 | padding=True, truncation=True) 79 | 80 | offset_mapping = inputs["offset_mapping"] 81 | inputs = {k: v.to(self.device) for k, v in inputs.items() if k != "offset_mapping"} 82 | 83 | return inputs, offset_mapping 84 | 85 | def convert_intent_logits(self, intent_logits: torch.tensor) -> List[Dict[str, float]]: 86 | """ 87 | Convert logits from model to predicted intent, 88 | 89 | :param intent_logits: output from model 90 | :return: dictionary of predicted intent 91 | """ 92 | softmax_intents = self.softmax(intent_logits) 93 | 94 | predicted_intents = [] 95 | 96 | for sentence in softmax_intents: 97 | sentence = sentence[0] 98 | 99 | sorted_sentence = sentence.clone() 100 | sorted_sentence, _ = torch.sort(sorted_sentence) 101 | 102 | if sorted_sentence[-1] >= self.util_config["intent_threshold"] and ( 103 | sorted_sentence[-1] - sorted_sentence[-2]) >= self.util_config["ambiguous_threshold"]: 104 | max_probability = torch.argmax(sentence) 105 | else: 106 | max_probability = -1 107 | 108 | predicted_intents.append({ 109 | "intent": None if max_probability == -1 else self.intents[max_probability], 110 | "intent_ranking": { 111 | intent_name: probability.item() for intent_name, probability in zip(self.intents, sentence) 112 | } 113 | }) 114 | 115 | return predicted_intents 116 | 117 | def convert_entities_logits(self, entities_logits: torch.tensor, offset_mapping: torch.tensor) -> List[ 118 | List[Dict[str, Any]]]: 119 | """ 120 | Convert logits to predicted entities 121 | 122 | :param entities_logits: entities logits from model 123 | :param offset_mapping: offset mapping for sentences 124 | :return: list of predicted entities 125 | """ 126 | softmax_entities = self.softmax(entities_logits) 127 | 128 | predicted_entities = [] 129 | 130 | for sentence, offset in zip(softmax_entities, offset_mapping): 131 | predicted_entities.append([]) 132 | latest_entity = None 133 | for word, token_offset in zip(sentence, offset[1:]): 134 | max_probability = torch.argmax(word) 135 | if word[max_probability] >= self.util_config["entities_threshold"] and max_probability != 0: 136 | if self.entities[max_probability] != latest_entity: 137 | latest_entity = self.entities[max_probability] 138 | predicted_entities[-1].append({ 139 | "entity_name": self.entities[max_probability], 140 | "start": token_offset[0].item(), 141 | "end": token_offset[1].item() 142 | }) 143 | else: 144 | predicted_entities[-1][-1]["end"] = token_offset[1].item() 145 | else: 146 | latest_entity = None 147 | 148 | return predicted_entities 149 | 150 | def predict(self, sentences: List[str]) -> List[Dict[str, Any]]: 151 | """ 152 | Predict intent and entities from sentences. 
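Each prediction is a dict with "intent", "intent_ranking", "entities" (each entity has "entity_name", "start", "end", "text") and "text" keys.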
153 | 154 | :param sentences: list of sentences 155 | :return: list of prediction 156 | """ 157 | inputs, offset_mapping = self.tokenize(sentences=sentences) 158 | outputs = self.model(**inputs) 159 | logits = outputs["logits"] 160 | predicted_intents = self.convert_intent_logits(intent_logits=logits[1]) 161 | predicted_entities = self.convert_entities_logits(entities_logits=logits[0], offset_mapping=offset_mapping) 162 | predicted_outputs = [] 163 | for sentence, intent_sentence, entities_sentence in zip(sentences, predicted_intents, predicted_entities): 164 | predicted_outputs.append({}) 165 | predicted_outputs[-1].update(intent_sentence) 166 | predicted_outputs[-1].update({"entities": entities_sentence}) 167 | for entity in predicted_outputs[-1]["entities"]: 168 | entity["text"] = sentence[entity["start"]: entity["end"]] 169 | 170 | if self.synonym_dict.get(entity["text"], None): 171 | entity["original_text"] = entity["text"] 172 | entity["text"] = self.synonym_dict[entity["text"]] 173 | 174 | predicted_outputs[-1]["text"] = sentence 175 | 176 | return predicted_outputs 177 | 178 | def save_pretrained(self, directory: str): 179 | """ 180 | Save model and tokenizer to directory 181 | 182 | :param directory: path to save folder 183 | :return: None 184 | """ 185 | self.model.save_pretrained(directory) 186 | self.tokenizer.save_pretrained(directory) 187 | 188 | config_file_path = "config.yml" if not self.config_file_path else self.config_file_path 189 | 190 | try: 191 | f = open(config_file_path, "w") 192 | yaml.dump(self.config, f, sort_keys=False) 193 | f.close() 194 | except Exception as ex: 195 | raise RuntimeError(f"Cannot save config to {config_file_path} by error: {ex}") 196 | 197 | def train_model(self, save_folder: str = "latest_model"): 198 | """ 199 | Create trainer, train and save best model to save_folder 200 | :param save_folder: path to save folder 201 | :return: None 202 | """ 203 | dataset_folder = self.dataset_config["dataset_folder"] 204 | if not path.exists(dataset_folder): 205 | raise ValueError(f"Folder {dataset_folder} is not exists") 206 | 207 | files_list = [path.join(dataset_folder, f) for f in listdir(dataset_folder) if path.isfile(path.join(dataset_folder, f)) and f.endswith(".yml")] 208 | 209 | df, _, _, synonym_dict = make_dataframe(files=files_list) 210 | 211 | self.synonym_dict.update(synonym_dict) 212 | self.config["model"]["synonym"] = self.synonym_dict 213 | 214 | dataset = DIETClassifierDataset(dataframe=df, tokenizer=self.tokenizer, entities=self.entities[1:], intents=self.intents) 215 | 216 | trainer = DIETTrainer(model=self.model, dataset=dataset, 217 | train_range=self.training_config["train_range"], 218 | num_train_epochs=self.training_config["num_train_epochs"], 219 | per_device_train_batch_size=self.training_config["per_device_train_batch_size"], 220 | per_device_eval_batch_size=self.training_config["per_device_eval_batch_size"], 221 | warmup_steps=self.training_config["warmup_steps"], 222 | weight_decay=self.training_config["weight_decay"], 223 | logging_dir=self.training_config["logging_dir"], 224 | early_stopping_patience=self.training_config["early_stopping_patience"], 225 | early_stopping_threshold=self.training_config["early_stopping_threshold"], 226 | output_dir=self.training_config["output_dir"]) 227 | 228 | trainer.train() 229 | 230 | self.save_pretrained(directory=save_folder) 231 | 232 | 233 | if __name__ == "__main__": 234 | config_file = "src/config.yml" 235 | 236 | wrapper = DIETClassifierWrapper(config=config_file) 237 | 238 | 
print(wrapper.predict(["I work on office hours"])) 239 | 240 | wrapper.train_model() 241 | 242 | print(wrapper.predict(["What is the average working hours"])) 243 | 244 | 245 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiNyn/DIETClassifier-pytorch/58d0630d74d055b644cc81f736c03c09f4ea4c60/src/utils/__init__.py --------------------------------------------------------------------------------