├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.ja.md ├── README.md ├── examples.ja.md ├── examples.md ├── examples ├── dice.py ├── echo.py ├── janomeecho.py ├── minette.ini ├── todo.py └── translation.py ├── minette ├── __init__.py ├── adapter │ ├── __init__.py │ ├── base.py │ ├── clovaadapter.py │ └── lineadapter.py ├── config.py ├── core.py ├── datastore │ ├── __init__.py │ ├── azurestoragestores.py │ ├── connectionprovider.py │ ├── contextstore.py │ ├── messagelogstore.py │ ├── mysqlstores.py │ ├── sqlalchemystores.py │ ├── sqldbstores.py │ ├── sqlitestores.py │ ├── storeset.py │ └── userstore.py ├── dialog │ ├── __init__.py │ ├── dependency.py │ ├── router.py │ └── service.py ├── models │ ├── __init__.py │ ├── context.py │ ├── group.py │ ├── message.py │ ├── payload.py │ ├── performance.py │ ├── priority.py │ ├── response.py │ ├── topic.py │ ├── user.py │ └── wordnode.py ├── scheduler │ ├── __init__.py │ └── base.py ├── serializer.py ├── tagger │ ├── __init__.py │ ├── base.py │ ├── janometagger.py │ ├── mecabservice.py │ └── mecabtagger.py ├── testing │ ├── __init__.py │ └── helper.py ├── utils.py └── version.py ├── requirements-dev.txt ├── setup.py └── tests ├── adapter ├── test_adapter_base.py ├── test_clovaadapter.py └── test_lineadapter.py ├── config ├── test_config.ini └── test_config_empty.ini ├── datastore ├── test_connectionprovider.py ├── test_contextstore.py ├── test_messagelogstore.py └── test_userstore.py ├── dialog ├── test_dependency.py ├── test_router.py └── test_service.py ├── models ├── test_context.py ├── test_group.py ├── test_message.py ├── test_models_base.py ├── test_payload.py ├── test_performance.py ├── test_priority.py ├── test_response.py ├── test_topic.py ├── test_user.py └── test_wordnode.py ├── payload.png ├── scheduler └── test_scheduler.py ├── tagger ├── test_janometagger.py ├── test_mecabservice.py ├── test_mecabtagger.py └── test_tagger_base.py ├── test_config.py ├── test_core.py ├── test_testing.py └── test_utils.py 
/.gitignore: -------------------------------------------------------------------------------- 1 | # VS Code 2 | .vscode/ 3 | 4 | # test apps 5 | imoutobot/ 6 | develop/ 7 | examples/ 8 | 9 | # internal test app 10 | develop/ 11 | 12 | # DB/Log/ini 13 | *.db 14 | *.log 15 | tests/config/test_config_datastores.ini 16 | tests/config/test_config_adapter.ini 17 | tests/adapter/request_samples.py 18 | tests/config/private/ 19 | 20 | # .DS_Store 21 | .DS_Store 22 | 23 | # IDE 24 | .vscode/ 25 | 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | env/ 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib/ 44 | lib64/ 45 | parts/ 46 | sdist/ 47 | var/ 48 | wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *,cover 72 | .hypothesis/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # celery beat schedule file 102 | celerybeat-schedule 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # dotenv 108 | .env 109 | 110 | # virtualenv 111 | .venv 112 | venv/ 113 | ENV/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guideline 2 | 3 | ``` 4 | 🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️ 5 | 🙏   Feel free to contribute!!!!!!    🙇‍♀️ 6 | 🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏🙇‍♀️🙏 7 | ``` 8 | 9 | ## 1. Fork this repository 10 | 11 | Fork this repository on the GitHub website and clone it to your local machine. 12 | 13 | ## 2. Install libraries required for development and test 14 | 15 | ```bash 16 | $ pip install -r requirements-dev.txt 17 | ``` 18 | 19 | ## 3. Make changes 20 | 21 | Updates to README / README.ja that accompany your code changes are also welcome. 22 | 23 | ## 4. Run tests 24 | 25 | Run the tests in the `tests` directory using pytest. 26 | 27 | ```bash 28 | $ pytest 29 | ``` 30 | 31 | ## 5.
Create pull request 32 | 33 | If your code passes all tests (some tests may be skipped), create a pull request. 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017-2020 uezo 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /examples.ja.md: -------------------------------------------------------------------------------- 1 | 2 | # サンプルコード 3 | 4 | 同じものが`examples`にも格納されていますので、すぐに動かしたい方はそちらを利用してください。 5 | 6 | # 🎲 さいころBOT 7 | 8 | このサンプルコードは、チャットボットの処理ロジックの実装と、処理結果を利用した応答メッセージの組み立ての方法の例です。 9 | 10 | ```python 11 | import random 12 | from minette import Minette, DialogService 13 | 14 | 15 | # カスタムの対話部品 16 | class DiceDialogService(DialogService): 17 | # ロジックを処理して結果をコンテキストに格納 18 | def process_request(self, request, context, connection): 19 | context.data = { 20 | "dice1": random.randint(1, 6), 21 | "dice2": random.randint(1, 6) 22 | } 23 | 24 | # コンテキスト情報を使って応答データを組み立て 25 | def compose_response(self, request, context, connection): 26 | return "Dice1:{} / Dice2:{}".format( 27 | str(context.data["dice1"]), str(context.data["dice2"])) 28 | 29 | 30 | if __name__ == "__main__": 31 | # BOTの起動 32 | bot = Minette(default_dialog_service=DiceDialogService) 33 | # 対話の開始 34 | while True: 35 | req = input("user> ") 36 | res = bot.chat(req) 37 | for message in
res.messages: 38 | print("minette> " + message.text) 39 | ``` 40 | 41 | 実行結果は以下の通り。 42 | 43 | ``` 44 | $ python dice.py 45 | 46 | user> dice 47 | minette> Dice1:1 / Dice2:2 48 | user> more 49 | minette> Dice1:4 / Dice2:5 50 | user> 51 | minette> Dice1:6 / Dice2:6 52 | ``` 53 | 54 | 55 | # ✅ Todo bot 56 | 57 | SQLAlchemy(0.4.1で実験的サポート)を使ってTodoリスト管理BOTを作るサンプルです。`Session`のインスタンスがリクエスト毎に生成され、DialogServiceの中で利用することができます。 58 | 59 | ```python 60 | from minette import Minette, DialogService 61 | from minette.datastore.sqlalchemystores import SQLAlchemyStores, Base 62 | from datetime import datetime 63 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 64 | 65 | # Define datamodel 66 | class TodoModel(Base): 67 | __tablename__ = "todolist" 68 | id = Column("id", Integer, primary_key=True, autoincrement=True) 69 | created_at = Column("created_at", DateTime, default=datetime.utcnow) 70 | text = Column("title", String(255)) 71 | is_closed = Column("is_closed", Boolean, default=False) 72 | 73 | # TodoDialog 74 | class TodoDialogService(DialogService): 75 | def process_request(self, request, context, connection): 76 | 77 | # Note: Session of SQLAlchemy is provided as argument `connection` 78 | 79 | # Register new item 80 | if request.text.lower().startswith("todo:"): 81 | item = TodoModel() 82 | item.text = request.text[5:].strip() 83 | connection.add(item) 84 | connection.commit() 85 | context.data["item"] = item 86 | context.topic.status = "item_added" 87 | 88 | # Close item 89 | elif request.text.lower().startswith("close:"): 90 | item_id = int(request.text[6:]) 91 | item = connection.query(TodoModel).filter(TodoModel.id==item_id).first() 92 | if item: 93 | item.is_closed = True 94 | connection.commit() 95 | context.data["item"] = item 96 | context.topic.status = "item_closed" 97 | else: 98 | context.data["item_id"] = item_id 99 | context.topic.status = "item_not_found" 100 | 101 | # Get item list 102 | elif request.text.lower().startswith("list") or
request.text.lower().startswith("show"): 103 | if "all" in request.text.lower(): 104 | items = connection.query(TodoModel).all() 105 | else: 106 | items = connection.query(TodoModel).filter(TodoModel.is_closed==0).all() 107 | if items: 108 | context.data["items"] = items 109 | context.topic.status = "item_listed" 110 | else: 111 | context.topic.status = "no_items" 112 | 113 | # Return reply message to user 114 | def compose_response(self, request, context, connection): 115 | if context.topic.status == "item_added": 116 | return "New item created: □ #{} {}".format(context.data["item"].id, context.data["item"].text) 117 | elif context.topic.status == "item_closed": 118 | return "Item closed: ✅#{} {}".format(context.data["item"].id, context.data["item"].text) 119 | elif context.topic.status == "item_not_found": 120 | return "Item not found: #{}".format(context.data["item_id"]) 121 | elif context.topic.status == "item_listed": 122 | text = "Todo:" 123 | for item in context.data["items"]: 124 | text += "\n{}#{} {}".format("□ " if item.is_closed == 0 else "✅", item.id, item.text) 125 | return text 126 | elif context.topic.status == "no_items": 127 | return "No todo item registered" 128 | else: 129 | return "Something wrong :(" 130 | 131 | # Create an instance of Minette with TodoDialogService and SQLAlchemyStores 132 | bot = Minette( 133 | default_dialog_service=TodoDialogService, 134 | data_stores=SQLAlchemyStores, 135 | connection_str="sqlite:///todo.db", 136 | db_echo=False) 137 | 138 | # Create table(s) using engine 139 | Base.metadata.create_all(bind=bot.connection_provider.engine) 140 | 141 | # Send and receive messages 142 | while True: 143 | req = input("user> ") 144 | res = bot.chat(req) 145 | for message in res.messages: 146 | print("minette> " + message.text) 147 | ``` 148 | 149 | Run it. 
150 | 151 | ```bash 152 | $ python todo.py 153 | 154 | user> todo: Buy beer 155 | minette> New item created: □ #1 Buy beer 156 | user> todo: Take a bath 157 | minette> New item created: □ #2 Take a bath 158 | user> todo: Watch anime 159 | minette> New item created: □ #3 Watch anime 160 | user> close: 2 161 | minette> Item closed: ✅#2 Take a bath 162 | user> list 163 | minette> Todo: 164 | □ #1 Buy beer 165 | □ #3 Watch anime 166 | user> list all 167 | minette> Todo: 168 | □ #1 Buy beer 169 | ✅#2 Take a bath 170 | □ #3 Watch anime 171 | ``` 172 | 173 | 174 | # 🇬🇧 翻訳BOT 175 | 176 | このサンプルコードは以下の方法について解説するものです。 177 | 178 | - コンテキスト情報を利用した継続的な対話(文脈のある対話) 179 | - インテントの識別とそれに応じた適切な対話部品へのルーティング 180 | - 設定ファイル(minette.ini)へのAPIキーの設定 181 | 182 | ```python 183 | """ 184 | Translation Bot 185 | 186 | Notes 187 | Signup Microsoft Cognitive Services and get API Key for Translator Text API 188 | https://azure.microsoft.com/ja-jp/services/cognitive-services/ 189 | 190 | """ 191 | from datetime import datetime 192 | import requests 193 | from minette import ( 194 | Minette, 195 | DialogRouter, 196 | DialogService, 197 | EchoDialogService # 組み込みのおうむ返し部品 198 | ) 199 | 200 | class TranslationDialogService(DialogService): 201 | # ロジックを処理して結果をコンテキストに格納 202 | def process_request(self, request, context, connection): 203 | # 翻訳処理の開始・終了時には`topic.status`を更新のみ行う 204 | if context.topic.is_new: 205 | context.topic.status = "start_translation" 206 | 207 | elif request.text == "stop": 208 | context.topic.status = "end_translation" 209 | 210 | # 日本語への翻訳処理 211 | else: 212 | # Azure Cognitive Servicesを用いて翻訳 213 | api_url = "https://api.cognitive.microsofttranslator.com/translate?api-version=3.0&to=ja" 214 | headers = { 215 | # 事前に `translation_api_key` を `minette.ini` の `minette` セクションに追加しておきます 216 | # 217 | # [minette] 218 | # translation_api_key=YOUR_TRANSLATION_API_KEY 219 | "Ocp-Apim-Subscription-Key": self.config.get("translation_api_key"), 220 | "Content-type": "application/json" 221 | } 222 
| data = [{"text": request.text}] 223 | api_result = requests.post(api_url, headers=headers, json=data).json() 224 | # 翻訳語の文章をコンテキストデータに保存 225 | context.data["translated_text"] = api_result[0]["translations"][0]["text"] 226 | context.topic.status = "process_translation" 227 | 228 | # コンテキスト情報を使って応答データを組み立て 229 | def compose_response(self, request, context, connection): 230 | if context.topic.status == "start_translation": 231 | context.topic.keep_on = True 232 | return "Input words to translate into Japanese" 233 | elif context.topic.status == "end_translation": 234 | return "Translation finished" 235 | elif context.topic.status == "process_translation": 236 | context.topic.keep_on = True 237 | return request.text + " in Japanese: " + context.data["translated_text"] 238 | 239 | 240 | class MyDialogRouter(DialogRouter): 241 | # intent->dialog のルーティングテーブルを定義 242 | def register_intents(self): 243 | self.intent_resolver = { 244 | # インテントが"TranslationIntent"のとき、`TranslationDialogService`を利用する 245 | "TranslationIntent": TranslationDialogService, 246 | "EchoIntent": EchoDialogService 247 | } 248 | 249 | # インテントの抽出ロジックを定義 250 | def extract_intent(self, request, context, connection): 251 | # リクエスト本文に「translat」が含まれる時 `TranslationIntent` と解釈する 252 | if "translat" in request.text.lower(): 253 | return "TranslationIntent" 254 | 255 | # リクエスト本文が「ignore」でないとき `EchoIntent` と解釈する 256 | # この場合「ignore」のときはインテントが抽出されないため、BOTは何も応答メッセージを返さない 257 | elif request.text.lower() != "ignore": 258 | return "EchoIntent" 259 | 260 | 261 | if __name__ == "__main__": 262 | # BOTの起動 263 | bot = Minette(dialog_router=MyDialogRouter) 264 | 265 | # 対話の開始 266 | while True: 267 | req = input("user> ") 268 | res = bot.chat(req) 269 | for message in res.messages: 270 | print("minette> " + message.text) 271 | ``` 272 | 273 | チャットボットとお話してみましょう。 274 | 275 | ``` 276 | $ python translation.py 277 | 278 | user> hello 279 | minette> You said: hello 280 | user> ignore 281 | user> okay 282 | minette> You said: okay 
283 | user> translate 284 | minette> Input words to translate into Japanese 285 | user> I'm feeling happy 286 | minette> I'm feeling happy in Japanese: 幸せな気分だ 287 | user> My favorite food is soba 288 | minette> My favorite food is soba in Japanese: 私の好きな食べ物はそばです。 289 | user> stop 290 | minette> Translation finished 291 | user> thank you 292 | minette> You said: thank you 293 | ``` 294 | -------------------------------------------------------------------------------- /examples/dice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dicebot 3 | 4 | This example shows how to implement your logic 5 | and build the reply message using the result of logic 6 | 7 | 8 | Sample conversation 9 | $ python dice.py 10 | 11 | user> dice 12 | minette> Dice1:1 / Dice2:2 13 | user> more 14 | minette> Dice1:4 / Dice2:5 15 | user> 16 | minette> Dice1:6 / Dice2:6 17 | 18 | """ 19 | import random 20 | from minette import Minette, DialogService 21 | 22 | 23 | # Custom dialog service 24 | class DiceDialogService(DialogService): 25 | # Process logic and build context data 26 | def process_request(self, request, context, connection): 27 | context.data = { 28 | "dice1": random.randint(1, 6), 29 | "dice2": random.randint(1, 6) 30 | } 31 | 32 | # Compose response message using context data 33 | def compose_response(self, request, context, connection): 34 | return "Dice1:{} / Dice2:{}".format( 35 | str(context.data["dice1"]), str(context.data["dice2"])) 36 | 37 | 38 | if __name__ == "__main__": 39 | # Create bot 40 | bot = Minette(default_dialog_service=DiceDialogService) 41 | # Start conversation 42 | while True: 43 | req = input("user> ") 44 | res = bot.chat(req) 45 | for message in res.messages: 46 | print("minette> " + message.text) 47 | -------------------------------------------------------------------------------- /examples/echo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Echobot 3 | 4 | 
This example shows the simple echo-bot 5 | 6 | 7 | Sample conversation 8 | $ python echo.py 9 | 10 | user> hello 11 | minette> You said: hello 12 | user> I love soba 13 | minette> You said: I love soba 14 | 15 | """ 16 | from minette import Minette, DialogService 17 | 18 | 19 | # EchoDialog 20 | class EchoDialogService(DialogService): 21 | # Return reply message to user 22 | def compose_response(self, request, context, connection): 23 | return "You said: {}".format(request.text) 24 | 25 | 26 | # Create an instance of Minette with EchoDialogService 27 | bot = Minette(default_dialog_service=EchoDialogService) 28 | 29 | # Send and receive messages 30 | while True: 31 | req = input("user> ") 32 | res = bot.chat(req) 33 | for message in res.messages: 34 | print("minette> " + message.text) 35 | -------------------------------------------------------------------------------- /examples/janomeecho.py: -------------------------------------------------------------------------------- 1 | """ 2 | Janome Japanese morphological analysis echo-bot 3 | 4 | This example shows the echo-bot that returns the response analyzed by Janome, 5 | pure Python Japanese morphological analysis engine. 6 | 7 | 8 | Sample conversation 9 | $ python janomeecho.py 10 | 11 | user> 今日も暑くなりそうですね 12 | minette> 今日(名詞), も(助詞), 暑く(形容詞), なり(動詞), そう(名詞), です(助動詞), ね(助詞) 13 | user> もしハワイに行ったらパンケーキをたくさん食べます 14 | minette> 固有名詞あり: ハワイ 15 | 16 | Using user dictionary 17 | To use user dictionary, pass the path to user dictionary as `user_dic` argument. 18 | 19 | user> 新しい魔法少女リリカルなのはの映画を観ましたか? 
20 | 21 | minette without udic> 新しい(形容詞), 魔法(名詞), 少女(名詞), リリカル(名詞), な(助動詞), の(名詞), は(助詞), の(助詞), 映画(名詞), を(助詞), 観(動詞), まし(助動詞), た(助動詞), か(助詞), ?(記号) 22 | minette with udic> 固有名詞あり: 魔法少女リリカルなのは 23 | """ 24 | from minette import Minette, DialogService 25 | from minette.tagger.janometagger import JanomeTagger 26 | 27 | 28 | # Custom dialog service 29 | class DiceDialogService(DialogService): 30 | def process_request(self, request, context, connection): 31 | # Text processing using the result of Janome 32 | context.data["proper_nouns"] = \ 33 | [w.surface for w in request.words if w.part_detail1 == "固有名詞"] 34 | 35 | def compose_response(self, request, context, connection): 36 | if context.data.get("proper_nouns"): 37 | # Echo extracted proper nouns when the request contains 38 | return "固有名詞あり: " + ", ".join(context.data.get("proper_nouns")) 39 | else: 40 | # Echo with analysis result 41 | return ", ".join(["{}({})".format(w.surface, w.part) for w in request.words]) 42 | 43 | 44 | if __name__ == "__main__": 45 | # Create bot with Janome Tagger 46 | bot = Minette( 47 | default_dialog_service=DiceDialogService, 48 | tagger=JanomeTagger, 49 | # user_dic="/path/to/userdict" # <= Uncomment when you use user dict 50 | ) 51 | # Start conversation 52 | while True: 53 | req = input("user> ") 54 | res = bot.chat(req) 55 | for message in res.messages: 56 | print("minette> " + message.text) 57 | -------------------------------------------------------------------------------- /examples/minette.ini: -------------------------------------------------------------------------------- 1 | [minette] 2 | translation_api_key=YOUR_TRANSLATION_API_KEY 3 | -------------------------------------------------------------------------------- /examples/todo.py: -------------------------------------------------------------------------------- 1 | """ 2 | TodoBot 3 | 4 | This example for using SQLAlchemy 5 | 6 | 7 | Sample conversation 8 | $ python todo.py 9 | 10 | user> todo: Buy beer 11 | minette> New 
item created: □ #1 Buy beer 12 | user> todo: Take a bath 13 | minette> New item created: □ #2 Take a bath 14 | user> todo: Watch anime 15 | minette> New item created: □ #3 Watch anime 16 | user> close: 2 17 | minette> Item closed: ✅#2 Take a bath 18 | user> list 19 | minette> Todo: 20 | □ #1 Buy beer 21 | □ #3 Watch anime 22 | user> list all 23 | minette> Todo: 24 | □ #1 Buy beer 25 | ✅#2 Take a bath 26 | □ #3 Watch anime 27 | """ 28 | 29 | from minette import Minette, DialogService 30 | from minette.datastore.sqlalchemystores import SQLAlchemyStores, Base 31 | 32 | from datetime import datetime 33 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 34 | 35 | 36 | # Define datamodel 37 | class TodoModel(Base): 38 | __tablename__ = "todolist" 39 | id = Column("id", Integer, primary_key=True, autoincrement=True) 40 | created_at = Column("created_at", DateTime, default=datetime.utcnow()) 41 | text = Column("title", String(255)) 42 | is_closed = Column("is_closed", Boolean, default=False) 43 | 44 | 45 | # TodoDialog 46 | class TodoDialogService(DialogService): 47 | def process_request(self, request, context, connection): 48 | 49 | # Note: Session of SQLAlchemy is provided as argument `connection` 50 | 51 | # Register new item 52 | if request.text.lower().startswith("todo:"): 53 | item = TodoModel() 54 | item.text = request.text[5:].strip() 55 | connection.add(item) 56 | connection.commit() 57 | context.data["item"] = item 58 | context.topic.status = "item_added" 59 | 60 | # Close item 61 | elif request.text.lower().startswith("close:"): 62 | item_id = int(request.text[6:]) 63 | item = connection.query(TodoModel).filter(TodoModel.id==item_id).first() 64 | if item: 65 | item.is_closed = True 66 | connection.commit() 67 | context.data["item"] = item 68 | context.topic.status = "item_closed" 69 | else: 70 | context.data["item_id"] = item_id 71 | context.topic.status = "item_not_found" 72 | 73 | # Get item list 74 | elif 
request.text.lower().startswith("list") or request.text.lower().startswith("show"): 75 | if "all" in request.text.lower(): 76 | items = connection.query(TodoModel).all() 77 | else: 78 | items = connection.query(TodoModel).filter(TodoModel.is_closed==0).all() 79 | if items: 80 | context.data["items"] = items 81 | context.topic.status = "item_listed" 82 | else: 83 | context.topic.status = "no_items" 84 | 85 | # Return reply message to user 86 | def compose_response(self, request, context, connection): 87 | if context.topic.status == "item_added": 88 | return "New item created: □ #{} {}".format(context.data["item"].id, context.data["item"].text) 89 | elif context.topic.status == "item_closed": 90 | return "Item closed: ✅#{} {}".format(context.data["item"].id, context.data["item"].text) 91 | elif context.topic.status == "item_not_found": 92 | return "Item not found: #{}".format(context.data["item_id"]) 93 | elif context.topic.status == "item_listed": 94 | text = "Todo:" 95 | for item in context.data["items"]: 96 | text += "\n{}#{} {}".format("□ " if item.is_closed == 0 else "✅", item.id, item.text) 97 | return text 98 | elif context.topic.status == "no_items": 99 | return "No todo item registered" 100 | else: 101 | return "Something wrong :(" 102 | 103 | 104 | # Create an instance of Minette with TodoDialogService and SQLAlchemyStores 105 | bot = Minette( 106 | default_dialog_service=TodoDialogService, 107 | data_stores=SQLAlchemyStores, 108 | connection_str="sqlite:///todo.db", 109 | db_echo=False) 110 | 111 | # Create table(s) using engine 112 | Base.metadata.create_all(bind=bot.connection_provider.engine) 113 | 114 | # Send and receive messages 115 | while True: 116 | req = input("user> ") 117 | res = bot.chat(req) 118 | for message in res.messages: 119 | print("minette> " + message.text) 120 | -------------------------------------------------------------------------------- /examples/translation.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Translation Bot 3 | 4 | This example shows; 5 | - how to make the successive conversation using context 6 | - how to extract intent from what user is saying and route the proper DialogService 7 | - how to configure API Key using configuration file (minette.ini) 8 | 9 | Notes 10 | Signup Microsoft Cognitive Services and get API Key for Translator Text API 11 | https://azure.microsoft.com/ja-jp/services/cognitive-services/ 12 | 13 | 14 | Sample conversation 15 | $ python translation.py 16 | 17 | user> hello 18 | minette> You said: hello 19 | user> ignore 20 | user> okay 21 | minette> You said: okay 22 | user> translate 23 | minette> Input words to translate into Japanese 24 | user> I'm feeling happy 25 | minette> I'm feeling happy in Japanese: 幸せな気分だ 26 | user> My favorite food is soba 27 | minette> My favorite food is soba in Japanese: 私の好きな食べ物はそばです。 28 | user> stop 29 | minette> Translation finished 30 | user> thank you 31 | minette> You said: thank you 32 | 33 | """ 34 | from datetime import datetime 35 | import requests 36 | from minette import ( 37 | Minette, 38 | DialogRouter, 39 | DialogService, 40 | EchoDialogService # built-in EchoDialog 41 | ) 42 | 43 | 44 | class TranslationDialogService(DialogService): 45 | # Process logic and build context data 46 | def process_request(self, request, context, connection): 47 | # Just set the topic.status at the start and the end of translation dialog 48 | if context.topic.is_new: 49 | context.topic.status = "start_translation" 50 | 51 | elif request.text == "stop": 52 | context.topic.status = "end_translation" 53 | 54 | # Translate to Japanese 55 | else: 56 | # translate using Azure Cognitive Services 57 | api_url = "https://api.cognitive.microsofttranslator.com/translate?api-version=3.0&to=ja" 58 | headers = { 59 | # set `translation_api_key` at the `minette` section in `minette.ini` 60 | # 61 | # [minette] 62 | # 
translation_api_key=YOUR_TRANSLATION_API_KEY 63 | "Ocp-Apim-Subscription-Key": self.config.get("translation_api_key"), 64 | "Content-type": "application/json" 65 | } 66 | data = [{"text": request.text}] 67 | api_result = requests.post(api_url, headers=headers, json=data).json() 68 | # set translated text to context 69 | context.data["translated_text"] = api_result[0]["translations"][0]["text"] 70 | context.topic.status = "process_translation" 71 | 72 | # Compose response message 73 | def compose_response(self, request, context, connection): 74 | if context.topic.status == "start_translation": 75 | context.topic.keep_on = True 76 | return "Input words to translate into Japanese" 77 | elif context.topic.status == "end_translation": 78 | return "Translation finished" 79 | elif context.topic.status == "process_translation": 80 | context.topic.keep_on = True 81 | return request.text + " in Japanese: " + context.data["translated_text"] 82 | 83 | 84 | class MyDialogRouter(DialogRouter): 85 | # Configure intent->dialog routing table 86 | def register_intents(self): 87 | self.intent_resolver = { 88 | # If the intent is "TranslationIntent" then use TranslationDialogService 89 | "TranslationIntent": TranslationDialogService, 90 | "EchoIntent": EchoDialogService 91 | } 92 | 93 | # Implement the intent extraction logic 94 | def extract_intent(self, request, context, connection): 95 | # Return TranslationIntent if request contains "translat" 96 | if "translat" in request.text.lower(): 97 | return "TranslationIntent" 98 | 99 | # Return EchoIntent if request is not "ignore" 100 | # If "ignore", chatbot doesn't return reply message. 
101 | elif request.text.lower() != "ignore": 102 | return "EchoIntent" 103 | 104 | 105 | if __name__ == "__main__": 106 | # Create bot 107 | bot = Minette(dialog_router=MyDialogRouter) 108 | 109 | # Start conversation 110 | while True: 111 | req = input("user> ") 112 | res = bot.chat(req) 113 | for message in res.messages: 114 | print("minette> " + message.text) 115 | -------------------------------------------------------------------------------- /minette/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .config import Config 4 | from .core import Minette 5 | from .datastore import ( 6 | ConnectionProvider, 7 | ContextStore, 8 | UserStore, 9 | MessageLogStore, 10 | StoreSet, 11 | SQLiteConnectionProvider, 12 | SQLiteContextStore, 13 | SQLiteUserStore, 14 | SQLiteMessageLogStore, 15 | SQLiteStores 16 | ) 17 | from .dialog import ( 18 | DialogService, 19 | EchoDialogService, 20 | ErrorDialogService, 21 | DialogRouter, 22 | DependencyContainer 23 | ) 24 | from .models import * 25 | from .tagger import Tagger 26 | from .adapter import Adapter 27 | from .scheduler import Task, Scheduler 28 | -------------------------------------------------------------------------------- /minette/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Adapter 2 | -------------------------------------------------------------------------------- /minette/adapter/base.py: -------------------------------------------------------------------------------- 1 | """ Base class for channel adapters """ 2 | from abc import ABC, abstractmethod 3 | import traceback 4 | from logging import Logger 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from ..core import Minette 8 | 9 | 10 | class Adapter(ABC): 11 | """ 12 | Base class for channel adapters 13 | 14 | Attributes 15 | ---------- 16 | bot : minette.Minette 17 | Instance of 
Minette 18 | config : minette.Config 19 | Configuration 20 | timezone : pytz.timezone 21 | Timezone 22 | logger : logging.Logger 23 | Logger 24 | threads : int or None 25 | Number of worker threads to process requests (0 means events are handled in the main thread) 26 | executor : ThreadPoolExecutor or None 27 | Thread pool of workers, or None when `threads` is 0 28 | debug : bool 29 | Debug mode 30 | """ 31 | 32 | def __init__(self, bot=None, *, threads=None, debug=False, **kwargs): 33 | """ 34 | Parameters 35 | ---------- 36 | bot : minette.Minette, default None 37 | Instance of Minette. 38 | If None, create new instance of Minette by using `**kwargs` 39 | threads : int or None, default None 40 | Number of worker threads to process requests. If 0, no thread pool is created and events are handled in the main thread; if None, ThreadPoolExecutor chooses its default worker count 41 | debug : bool, default False 42 | Debug mode 43 | """ 44 | self.bot = bot or Minette(**kwargs) 45 | self.config = self.bot.config 46 | self.timezone = self.bot.timezone 47 | self.logger = self.bot.logger 48 | self.threads = threads 49 | if self.threads != 0: 50 | self.logger.info("Use worker threads to handle events") 51 | self.executor = ThreadPoolExecutor( 52 | max_workers=self.threads, thread_name_prefix="AdapterThread") 53 | else: 54 | self.logger.info("Use main thread to handle events") 55 | self.executor = None 56 | self.debug = debug 57 | 58 | def handle_event(self, event): 59 | """ 60 | Handle event from channel 61 | 62 | Parameters 63 | ---------- 64 | event : object 65 | Event data from channel 66 | 67 | Returns 68 | ------- 69 | channel_messages : list (messages in channel specific format, converted by `_to_channel_message`) 70 | token : str or object (token extracted from `event` by `_extract_token`) 71 | """ 72 | if self.debug: 73 | self.logger.info(event) 74 | token = self._extract_token(event) 75 | message = self._to_minette_message(event) 76 | response = self.bot.chat(message) 77 | channel_messages = [ 78 | self._to_channel_message(m) for m in response.messages] 79 | return channel_messages, token 80 | 81 | def _extract_token(self, event): 82 | """ 83 | Extract token from event 84 | 85 | Parameters 86 | ---------- 87 | event : object 88 | Event data from channel 89 | 90 | Returns 91 | ------- 92 | token :
str or object 93 | Token data for channel 94 | """ 95 | return "" 96 | 97 | @staticmethod 98 | @abstractmethod 99 | def _to_minette_message(event): 100 | """ 101 | Convert channel event into internal Message object 102 | 103 | Parameters 104 | ---------- 105 | event : object 106 | Event data from channel 107 | 108 | Returns 109 | ------- 110 | message : minette.Message 111 | Request message 112 | """ 113 | pass 114 | 115 | @staticmethod 116 | @abstractmethod 117 | def _to_channel_message(message): 118 | """ 119 | Convert internal Message object into channel formatted message 120 | 121 | Parameters 122 | ---------- 123 | message : minette.Message 124 | Response message 125 | 126 | Returns 127 | ------- 128 | channel_message : object 129 | Channel formatted message 130 | """ 131 | pass 132 | -------------------------------------------------------------------------------- /minette/adapter/clovaadapter.py: -------------------------------------------------------------------------------- 1 | """ Adapter for Clova Extensions Kit """ 2 | import traceback 3 | 4 | from cek import ( 5 | Clova, 6 | URL, 7 | IntentRequest 8 | ) 9 | 10 | from ..serializer import dumps 11 | from .base import Adapter 12 | from ..models import Message 13 | 14 | 15 | class ClovaAdapter(Adapter): 16 | """ 17 | Adapter for Clova Extensions Kit 18 | 19 | Attributes 20 | ---------- 21 | bot : minette.Minette 22 | Instance of Minette 23 | application_id : str 24 | Application ID of Clova Skill 25 | default_language : str 26 | Default language of Clova Skill 27 | clova : Clova 28 | Clova Extensions Kit API 29 | config : minette.Config 30 | Configuration 31 | timezone : pytz.timezone 32 | Timezone 33 | logger : logging.Logger 34 | Logger 35 | debug : bool 36 | Debug mode 37 | """ 38 | 39 | def __init__(self, bot=None, *, debug=False, 40 | application_id=None, default_language=None, **kwargs): 41 | """ 42 | Parameters 43 | ---------- 44 | bot : minette.Minette, default None 45 | Instance of Minette. 
46 | If None, create new instance of Minette by using `**kwargs` 47 | application_id : str or None, default None 48 | Application ID for your Clova Skill 49 | default_language : str or None, default None 50 | Default language. ("en" / "ja" / "ko") 51 | If None, "ja" is set to Clova Extensions Kit API object 52 | debug : bool, default False 53 | Debug mode 54 | """ 55 | super().__init__(bot=bot, threads=0, debug=debug, **kwargs) 56 | self.application_id = application_id or \ 57 | self.config.get(section="clova_cek", key="application_id") 58 | self.default_language = default_language or \ 59 | self.config.get(section="clova_cek", key="default_language") or "ja" 60 | self.clova = Clova(application_id=self.application_id, 61 | default_language=self.default_language, 62 | debug_mode=debug) 63 | 64 | # handler for all types of request 65 | @self.clova.handle.default 66 | def default(clova_request): 67 | return clova_request 68 | 69 | def handle_http_request(self, request_data, request_headers): 70 | """ 71 | Interface to chat with Clova Skill 72 | 73 | Parameters 74 | ---------- 75 | request_data : bytes 76 | Request data from Clova as bytes 77 | request_headers : dict 78 | Request headers from Clova as dict 79 | 80 | Returns 81 | ------- 82 | response : Response 83 | Response from chatbot. 
Send back `json` attribute to Clova API 84 | """ 85 | clova_request = self.clova.route(request_data, request_headers) 86 | return self.handle_event(clova_request) 87 | 88 | def handle_event(self, clova_request): 89 | # execute bot 90 | channel_messages, _ = super().handle_event(clova_request) 91 | 92 | # print response for debug 93 | for msg in channel_messages: 94 | if self.debug: 95 | self.logger.info(msg) 96 | else: 97 | self.logger.info("Minette> {}".format(msg["speech_value"])) 98 | 99 | # build response message 100 | speech_values = [msg["speech_value"] for msg in channel_messages] 101 | end_session = channel_messages[-1]["end_session"] 102 | reprompt = channel_messages[-1]["reprompt"] 103 | if len(speech_values) == 1: 104 | return dumps(self.clova.response( 105 | speech_values[0], end_session=end_session, reprompt=reprompt)) 106 | else: 107 | return dumps(self.clova.response( 108 | speech_values, end_session=end_session, reprompt=reprompt)) 109 | 110 | @staticmethod 111 | def _to_minette_message(clova_request): 112 | """ 113 | Convert ClovaRequest object to Minette Message object 114 | 115 | Parameters 116 | ---------- 117 | clova_request : cek.Request 118 | Request from clova 119 | 120 | Returns 121 | ------- 122 | message : minette.Message 123 | Request converted into Message object 124 | """ 125 | msg = Message( 126 | type=clova_request.type, 127 | channel="LINE", 128 | channel_detail="Clova", 129 | channel_user_id=clova_request.session.user.id if clova_request.session._session else "", 130 | channel_message=clova_request 131 | ) 132 | 133 | # Set intent and entities when IntentRequest 134 | if isinstance(clova_request, IntentRequest): 135 | msg.intent = clova_request.name 136 | # if clova_request.slots: <- Error occures when no slot values 137 | if clova_request._request["intent"]["slots"]: 138 | msg.entities = clova_request.slots 139 | return msg 140 | 141 | @staticmethod 142 | def _to_channel_message(message): 143 | """ 144 | Convert Minette Message 
object to LINE SendMessage object 145 | 146 | Parameters 147 | ---------- 148 | response : Message 149 | Response message object 150 | 151 | Returns 152 | ------- 153 | response : SendMessage 154 | SendMessage object for LINE Messaging API 155 | """ 156 | return { 157 | "speech_value": URL(message.text) if message.type == "url" else message.text, 158 | "end_session": message.entities.get("end_session", True), 159 | "reprompt": message.entities.get("reprompt", None) 160 | } 161 | -------------------------------------------------------------------------------- /minette/config.py: -------------------------------------------------------------------------------- 1 | """ Configuration management """ 2 | import os 3 | from configparser import ConfigParser 4 | 5 | 6 | class Config: 7 | """ 8 | Configuration management 9 | 10 | Attributes 11 | ---------- 12 | confg_parser : ConfigParser 13 | ConfigParser used internally 14 | """ 15 | 16 | def __init__(self, config_file): 17 | """ 18 | Parameters 19 | ---------- 20 | config_file : str 21 | Path to configuration file 22 | """ 23 | self.confg_parser = ConfigParser() 24 | if config_file and len(self.confg_parser.read(config_file)) == 0: 25 | print("Can't read/find configuration file: {}".format(config_file)) 26 | print("Initialize with default configuration instead") 27 | 28 | if not self.confg_parser.has_section("minette"): 29 | self.confg_parser.add_section("minette") 30 | self.confg_parser.set("minette", "timezone", "UTC") 31 | self.confg_parser.set( 32 | "minette", "log_file", "ENV::MINETTE_LOGFILE") 33 | self.confg_parser.set( 34 | "minette", "connection_str", "ENV::MINETTE_CONNECTION_STR") 35 | 36 | def get(self, key, default=None, section="minette"): 37 | if section in self.confg_parser.sections(): 38 | ret = self.confg_parser[section].get(key, default) 39 | else: 40 | ret = default 41 | if str(ret).startswith("ENV::"): 42 | ret = os.environ.get(ret[5:], default) 43 | return ret 44 | 
-------------------------------------------------------------------------------- /minette/datastore/__init__.py: -------------------------------------------------------------------------------- 1 | from .connectionprovider import ConnectionProvider 2 | from .contextstore import ContextStore 3 | from .userstore import UserStore 4 | from .messagelogstore import MessageLogStore 5 | from .storeset import StoreSet 6 | 7 | from .sqlitestores import ( 8 | SQLiteConnectionProvider, 9 | SQLiteContextStore, 10 | SQLiteUserStore, 11 | SQLiteMessageLogStore, 12 | SQLiteStores 13 | ) 14 | -------------------------------------------------------------------------------- /minette/datastore/connectionprovider.py: -------------------------------------------------------------------------------- 1 | """ Base class for ConnectionProvider """ 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class ConnectionProvider(ABC): 6 | """ 7 | Base class for ConnectionProvider 8 | 9 | Attributes 10 | ---------- 11 | connection_str : str 12 | Connection string 13 | """ 14 | 15 | def __init__(self, connection_str, **kwargs): 16 | """ 17 | Parameters 18 | ---------- 19 | connection_str : str 20 | Connection string 21 | """ 22 | self.connection_str = connection_str 23 | 24 | @abstractmethod 25 | def get_connection(self): 26 | """ 27 | Get connection 28 | 29 | Returns 30 | ------- 31 | connection : Connection 32 | Database connection 33 | """ 34 | pass 35 | 36 | def get_prepare_params(self): 37 | """ 38 | Get parameters for preparing tables 39 | 40 | Returns 41 | ------- 42 | prepare_params : tuple or None 43 | Parameters for preparing tables 44 | """ 45 | return None 46 | -------------------------------------------------------------------------------- /minette/datastore/contextstore.py: -------------------------------------------------------------------------------- 1 | """ Base class for ContextStore """ 2 | from abc import ABC, abstractmethod 3 | import traceback 4 | from logging import getLogger 
5 | from datetime import datetime 6 | from pytz import timezone as tz 7 | 8 | from ..serializer import dumps, loads 9 | from ..models import Context, Topic 10 | 11 | 12 | class ContextStore(ABC): 13 | """ 14 | Base class for ContextStore to enable successive conversation 15 | 16 | Attributes 17 | ---------- 18 | config : minette.Config 19 | Configuration 20 | timezone : pytz.timezone 21 | Timezone 22 | logger : logging.Logger 23 | Logger 24 | table_name : str 25 | Database table name for read/write context data 26 | sqls : dict 27 | SQLs used in ContextStore 28 | timeout : int 29 | Context timeout (Seconds) 30 | """ 31 | 32 | def __init__(self, config=None, timezone=None, logger=None, 33 | table_name="context", *, timeout=300, **kwargs): 34 | """ 35 | Parameters 36 | ---------- 37 | config : minette.Config, default None 38 | Configuration 39 | timezone : pytz.timezone, default None 40 | Timezone 41 | logger : logging.Logger, default None 42 | Logger 43 | table_name : str, default "context" 44 | Database table name for read/write context data 45 | timeout : int, default 300 46 | Context timeout (Seconds) 47 | """ 48 | self.config = config 49 | self.timezone = timezone or ( 50 | tz(config.get("timezone", default="UTC")) if config else tz("UTC")) 51 | self.logger = logger if logger else getLogger(__name__) 52 | self.table_name = table_name 53 | self.timeout = timeout 54 | self.sqls = self.get_sqls() 55 | 56 | @abstractmethod 57 | def get_sqls(self): 58 | pass 59 | 60 | def prepare_table(self, connection, prepare_params=None): 61 | """ 62 | Check and create table if not exist 63 | 64 | Parameters 65 | ---------- 66 | connection : Connection 67 | Connection for prepare 68 | 69 | query_params : tuple, default tuple() 70 | Query parameters for checking table 71 | 72 | Returns 73 | ------- 74 | created : bool 75 | Return True when created new table 76 | """ 77 | cursor = connection.cursor() 78 | cursor.execute(self.sqls["prepare_check"], prepare_params or tuple()) 79 | if 
not cursor.fetchone(): 80 | cursor.execute(self.sqls["prepare_create"]) 81 | connection.commit() 82 | return True 83 | else: 84 | return False 85 | 86 | def get(self, channel, channel_user_id, connection): 87 | """ 88 | Get context by channel and channel_user_id 89 | 90 | Parameters 91 | ---------- 92 | channel : str 93 | Channel 94 | channel_user_id : str 95 | Channel user ID 96 | connection : Connection 97 | Connection 98 | 99 | Returns 100 | ------- 101 | context : minette.Context 102 | Context for channel and channel_user_id 103 | """ 104 | context = Context(channel, channel_user_id) 105 | context.timestamp = datetime.now(self.timezone) 106 | if not channel_user_id: 107 | return context 108 | try: 109 | cursor = connection.cursor() 110 | cursor.execute( 111 | self.sqls["get_context"], (channel, channel_user_id)) 112 | row = cursor.fetchone() 113 | if row is not None: 114 | # convert to dict 115 | if isinstance(row, dict): 116 | record = row 117 | else: 118 | record = dict( 119 | zip([column[0] for column in cursor.description], row)) 120 | # convert type 121 | record["topic_previous"] = \ 122 | loads(record["topic_previous"]) 123 | record["data"] = loads(record["data"]) 124 | # check context timeout 125 | if record["timestamp"].tzinfo: 126 | last_access = record["timestamp"].astimezone(self.timezone) 127 | else: 128 | last_access = self.timezone.localize(record["timestamp"]) 129 | gap = datetime.now(self.timezone) - last_access 130 | if gap.total_seconds() <= self.timeout: 131 | # restore context if not timeout 132 | context.topic.name = record["topic_name"] 133 | context.topic.status = record["topic_status"] 134 | context.topic.priority = record["topic_priority"] 135 | context.topic.previous = Topic.from_dict( 136 | record["topic_previous"]) \ 137 | if record["topic_previous"] else None 138 | context.data = record["data"] if record["data"] else {} 139 | context.is_new = False 140 | except Exception as ex: 141 | self.logger.error( 142 | "Error occured in 
restoring context from database: " 143 | + str(ex) + "\n" + traceback.format_exc()) 144 | return context 145 | 146 | def save(self, context, connection): 147 | """ 148 | Save context 149 | 150 | Parameters 151 | ---------- 152 | context : minette.Context 153 | Context to save 154 | connection : Connection 155 | Connection 156 | """ 157 | if not context.channel_user_id: 158 | return 159 | # serialize some elements 160 | context_dict = context.to_dict() 161 | serialized_previous_topic = \ 162 | dumps(context_dict["topic"]["previous"]) 163 | serialized_data = dumps(context_dict["data"]) 164 | # save 165 | cursor = connection.cursor() 166 | cursor.execute(self.sqls["save_context"], ( 167 | context.channel, context.channel_user_id, context.timestamp, 168 | context.topic.name, context.topic.status, 169 | serialized_previous_topic, context.topic.priority, 170 | serialized_data)) 171 | connection.commit() 172 | -------------------------------------------------------------------------------- /minette/datastore/messagelogstore.py: -------------------------------------------------------------------------------- 1 | """ Base class for MessageLogStore """ 2 | from abc import ABC, abstractmethod 3 | from logging import getLogger 4 | from pytz import timezone as tz 5 | 6 | from ..serializer import dumps 7 | 8 | 9 | class MessageLogStore(ABC): 10 | """ 11 | Base class for MessageLogStore to analyze the communication 12 | 13 | Attributes 14 | ---------- 15 | config : minette.Config 16 | Configuration 17 | timezone : pytz.timezone 18 | Timezone 19 | logger : logging.Logger 20 | Logger 21 | table_name : str 22 | Database table name for read/write message log data 23 | sqls : dict 24 | SQLs used in ContextStore 25 | """ 26 | def __init__(self, config=None, timezone=None, logger=None, 27 | table_name="messagelog", **kwargs): 28 | """ 29 | Parameters 30 | ---------- 31 | config : minette.Config, default None 32 | Configuration 33 | timezone : pytz.timezone, default None 34 | Timezone 35 
| logger : logging.Logger, default None 36 | Logger 37 | table_name : str, default "messagelog" 38 | Database table name for read/write message log data 39 | """ 40 | self.config = config 41 | self.timezone = timezone or ( 42 | tz(config.get("timezone", default="UTC")) if config else tz("UTC")) 43 | self.logger = logger if logger else getLogger(__name__) 44 | self.table_name = table_name 45 | self.sqls = self.get_sqls() 46 | 47 | @abstractmethod 48 | def get_sqls(self): 49 | pass 50 | 51 | def prepare_table(self, connection, prepare_params=None): 52 | """ 53 | Check and create table if not exist 54 | 55 | Parameters 56 | ---------- 57 | connection : Connection 58 | Connection for prepare 59 | 60 | query_params : tuple, default tuple() 61 | Query parameters for checking table 62 | """ 63 | cursor = connection.cursor() 64 | cursor.execute(self.sqls["prepare_check"], prepare_params or tuple()) 65 | if not cursor.fetchone(): 66 | cursor.execute(self.sqls["prepare_create"]) 67 | connection.commit() 68 | return True 69 | else: 70 | return False 71 | 72 | def _flatten(self, request, response, context): 73 | return { 74 | # request 75 | "channel": request.channel, 76 | "channel_detail": request.channel_detail, 77 | "channel_user_id": request.channel_user_id, 78 | "request_timestamp": request.timestamp, 79 | "request_id": request.id, 80 | "request_type": request.type, 81 | "request_text": request.text, 82 | "request_payloads": dumps( 83 | [p.to_dict() for p in request.payloads]), 84 | "request_intent": request.intent, 85 | "request_is_adhoc": request.is_adhoc, 86 | # response 87 | "response_type": response.messages[0].type if response.messages else "", 88 | "response_text": response.messages[0].text if response.messages else "", 89 | "response_payloads": dumps([p.to_dict() for p in response.messages[0].payloads]) if response.messages else "", 90 | "response_milliseconds": response.performance.milliseconds, 91 | # context 92 | "context_is_new": context.is_new, 93 | 
"context_topic_name": context.topic.name, 94 | "context_topic_status": context.topic.status, 95 | "context_topic_is_new": context.topic.is_new, 96 | "context_topic_keep_on": context.topic.keep_on, 97 | "context_topic_priority": context.topic.priority, 98 | "context_error": dumps(context.error), 99 | } 100 | 101 | def save(self, request, response, context, connection): 102 | """ 103 | Write message log 104 | 105 | Parameters 106 | ---------- 107 | request : minette.Message 108 | Request to chatbot 109 | response : minette.Response 110 | Response from chatbot 111 | context : minette.Context 112 | Context 113 | connection : Connection 114 | Connection 115 | """ 116 | f = self._flatten(request, response, context) 117 | cursor = connection.cursor() 118 | cursor.execute(self.sqls["write"], ( 119 | f["channel"], 120 | f["channel_detail"], 121 | f["channel_user_id"], 122 | f["request_timestamp"], 123 | f["request_id"], 124 | f["request_type"], 125 | f["request_text"], 126 | f["request_payloads"], 127 | f["request_intent"], 128 | f["request_is_adhoc"], 129 | f["response_type"], 130 | f["response_text"], 131 | f["response_payloads"], 132 | f["response_milliseconds"], 133 | f["context_is_new"], 134 | f["context_topic_name"], 135 | f["context_topic_status"], 136 | f["context_topic_is_new"], 137 | f["context_topic_keep_on"], 138 | f["context_topic_priority"], 139 | f["context_error"], 140 | request.to_json(), 141 | response.to_json(), 142 | context.to_json()) 143 | ) 144 | connection.commit() 145 | -------------------------------------------------------------------------------- /minette/datastore/mysqlstores.py: -------------------------------------------------------------------------------- 1 | import MySQLdb 2 | from MySQLdb.cursors import DictCursor 3 | from MySQLdb.connections import Connection 4 | 5 | from .connectionprovider import ConnectionProvider 6 | from .contextstore import ContextStore 7 | from .userstore import UserStore 8 | from .messagelogstore import 
MessageLogStore 9 | from .storeset import StoreSet 10 | 11 | 12 | class MySQLConnection(Connection): 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def __enter__(self): 17 | return self 18 | 19 | def __exit__(self, exc_type, exc_value, traceback): 20 | self.close() 21 | 22 | 23 | class MySQLConnectionProvider(ConnectionProvider): 24 | """ 25 | Connection provider for MySQL 26 | 27 | Attributes 28 | ---------- 29 | connection_str : str 30 | Connection string 31 | connection_params : dict 32 | Parameters for connection 33 | """ 34 | def __init__(self, connection_str, **kwargs): 35 | """ 36 | Parameters 37 | ---------- 38 | connection_str : str 39 | Connection string 40 | """ 41 | self.connection_str = connection_str 42 | self.connection_params = {"cursorclass": DictCursor, "charset": "utf8"} 43 | param_values = self.connection_str.split(";") 44 | for pv in param_values: 45 | if "=" in pv: 46 | p, v = list(map(str.strip, pv.split("="))) 47 | self.connection_params[p] = v 48 | 49 | def get_connection(self): 50 | """ 51 | Get connection 52 | 53 | Returns 54 | ------- 55 | connection : Connection 56 | Database connection 57 | """ 58 | return MySQLConnection(**self.connection_params) 59 | 60 | def get_prepare_params(self): 61 | """ 62 | Get parameters for preparing tables 63 | 64 | Returns 65 | ------- 66 | prepare_params : tuple or None 67 | Parameters for preparing tables 68 | """ 69 | return (self.connection_params["db"], ) 70 | 71 | 72 | class MySQLContextStore(ContextStore): 73 | def get_sqls(self): 74 | """ 75 | Get SQLs used in ContextStore 76 | 77 | Returns 78 | ------- 79 | sqls : dict 80 | SQLs used in SessionStore 81 | """ 82 | return { 83 | "prepare_check": "select * from information_schema.TABLES where TABLE_NAME='{0}' and TABLE_SCHEMA=%s".format(self.table_name), 84 | "prepare_create": "create table {0} (channel VARCHAR(20), channel_user_id VARCHAR(100), timestamp DATETIME, topic_name VARCHAR(100), topic_status VARCHAR(100), 
topic_previous VARCHAR(500), topic_priority INT, data JSON, primary key(channel, channel_user_id))".format(self.table_name), 85 | "get_context": "select channel, channel_user_id, timestamp, topic_name, topic_status, topic_previous, topic_priority, data from {0} where channel=%s and channel_user_id=%s limit 1".format(self.table_name), 86 | "save_context": "replace into {0} (channel, channel_user_id, timestamp, topic_name, topic_status, topic_previous, topic_priority, data) values (%s,%s,%s,%s,%s,%s,%s,%s)".format(self.table_name), 87 | } 88 | 89 | 90 | class MySQLUserStore(UserStore): 91 | def get_sqls(self): 92 | """ 93 | Get SQLs used in UserStore 94 | 95 | Returns 96 | ------- 97 | sqls : dict 98 | SQLs used in UserRepository 99 | """ 100 | return { 101 | "prepare_check": "select * from information_schema.TABLES where TABLE_NAME='{0}' and TABLE_SCHEMA=%s".format(self.table_name), 102 | "prepare_create": "create table {0} (channel VARCHAR(20), channel_user_id VARCHAR(100), user_id VARCHAR(100), timestamp DATETIME, name VARCHAR(100), nickname VARCHAR(100), profile_image_url VARCHAR(500), data JSON, primary key(channel, channel_user_id))".format(self.table_name), 103 | "get_user": "select channel, channel_user_id, user_id, timestamp, name, nickname, profile_image_url, data from {0} where channel=%s and channel_user_id=%s limit 1".format(self.table_name), 104 | "add_user": "insert into {0} (channel, channel_user_id, user_id, timestamp, name, nickname, profile_image_url, data) values (%s,%s,%s,%s,%s,%s,%s,%s)".format(self.table_name), 105 | "save_user": "update {0} set timestamp=%s, name=%s, nickname=%s, profile_image_url=%s, data=%s where channel=%s and channel_user_id=%s".format(self.table_name), 106 | } 107 | 108 | 109 | class MySQLMessageLogStore(MessageLogStore): 110 | def get_sqls(self): 111 | """ 112 | Get SQLs used in MessageLogStore 113 | 114 | Returns 115 | ------- 116 | sqls : dict 117 | SQLs used in MessageLogger 118 | """ 119 | return { 120 | 
"prepare_check": "select * from information_schema.TABLES where TABLE_NAME='{0}' and TABLE_SCHEMA=%s".format(self.table_name), 121 | "prepare_create": """ 122 | create table {0} ( 123 | id INT PRIMARY KEY AUTO_INCREMENT, 124 | channel VARCHAR(20), 125 | channel_detail VARCHAR(100), 126 | channel_user_id VARCHAR(100), 127 | request_timestamp DATETIME, 128 | request_id VARCHAR(100), 129 | request_type VARCHAR(100), 130 | request_text VARCHAR(4000), 131 | request_payloads JSON, 132 | request_intent VARCHAR(100), 133 | request_is_adhoc BOOLEAN, 134 | response_type VARCHAR(100), 135 | response_text VARCHAR(4000), 136 | response_payloads JSON, 137 | response_milliseconds INT, 138 | context_is_new BOOLEAN, 139 | context_topic_name TEXT, 140 | context_topic_status TEXT, 141 | context_topic_is_new BOOLEAN, 142 | context_topic_keep_on BOOLEAN, 143 | context_topic_priority INT, 144 | context_error JSON, 145 | request_json JSON, 146 | response_json JSON, 147 | context_json JSON) 148 | """.format(self.table_name), 149 | "write": """ 150 | insert into {0} ( 151 | channel, 152 | channel_detail, 153 | channel_user_id, 154 | request_timestamp, 155 | request_id, 156 | request_type, 157 | request_text, 158 | request_payloads, 159 | request_intent, 160 | request_is_adhoc, 161 | response_type, 162 | response_text, 163 | response_payloads, 164 | response_milliseconds, 165 | context_is_new, 166 | context_topic_name, 167 | context_topic_status, 168 | context_topic_is_new, 169 | context_topic_keep_on, 170 | context_topic_priority, 171 | context_error, 172 | request_json, response_json, context_json) 173 | values ( 174 | %s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) 175 | """.format(self.table_name), 176 | } 177 | 178 | 179 | class MySQLStores(StoreSet): 180 | connection_provider = MySQLConnectionProvider 181 | context_store = MySQLContextStore 182 | user_store = MySQLUserStore 183 | messagelog_store = MySQLMessageLogStore 184 | 
-------------------------------------------------------------------------------- /minette/datastore/sqldbstores.py: -------------------------------------------------------------------------------- 1 | import pyodbc 2 | 3 | from .connectionprovider import ConnectionProvider 4 | from .contextstore import ContextStore 5 | from .userstore import UserStore 6 | from .messagelogstore import MessageLogStore 7 | from .storeset import StoreSet 8 | 9 | 10 | class SQLDBConnectionProvider(ConnectionProvider): 11 | """ 12 | Connection provider for Azure SQL Database (connects with pyodbc) 13 | 14 | Attributes 15 | ---------- 16 | connection_str : str 17 | ODBC connection string passed to pyodbc.connect as is 18 | """ 19 | def get_connection(self): 20 | """ 21 | Get connection 22 | 23 | Returns 24 | ------- 25 | connection : Connection 26 | Database connection 27 | """ 28 | return pyodbc.connect(self.connection_str) 29 | 30 | 31 | class SQLDBContextStore(ContextStore): 32 | """ 33 | Context store for Azure SQL Database to enable successive conversation 34 | 35 | """ 36 | def get_sqls(self): 37 | """ 38 | Get SQLs used in ContextStore 39 | 40 | Returns 41 | ------- 42 | sqls : dict 43 | SQLs used in ContextStore 44 | """ 45 | return { 46 | "prepare_check": "select id from dbo.sysobjects where id = object_id('{0}')".format(self.table_name), 47 | "prepare_create": "create table {0} (channel NVARCHAR(20), channel_user_id NVARCHAR(100), timestamp DATETIME2, topic_name NVARCHAR(100), topic_status NVARCHAR(100), topic_previous NVARCHAR(4000), topic_priority INT, data NVARCHAR(MAX), primary key(channel, channel_user_id))".format(self.table_name), 48 | "get_context": "select top 1 * from {0} where channel=? and channel_user_id=?".format(self.table_name), 49 | "save_context": """ 50 | merge into {0} as A 51 | using (select ? as channel, ? as channel_user_id, ? as timestamp, ? as topic_name, ? as topic_status, ? as topic_previous, ? as topic_priority, ?
as data) as B 52 | on (A.channel = B.channel and A.channel_user_id = B.channel_user_id) 53 | when matched then 54 | update set timestamp=B.timestamp, topic_name=B.topic_name, topic_status=B.topic_status, topic_previous=B.topic_previous, topic_priority=B.topic_priority, data=B.data 55 | when not matched then 56 | insert (channel, channel_user_id, timestamp, topic_name, topic_status, topic_previous, topic_priority, data) values (B.channel, B.channel_user_id, B.timestamp, B.topic_name, B.topic_status, B.topic_previous, B.topic_priority, B.data); 57 | """.format(self.table_name) 58 | } 59 | 60 | 61 | class SQLDBUserStore(UserStore): 62 | """ 63 | User store for Azure SQL Database 64 | 65 | """ 66 | def get_sqls(self): 67 | """ 68 | Get SQLs used in UserStore 69 | 70 | Returns 71 | ------- 72 | sqls : dict 73 | SQLs used in UserStore 74 | """ 75 | return { 76 | "prepare_check": "select id from dbo.sysobjects where id = object_id('{0}')".format(self.table_name), 77 | "prepare_create": "create table {0} (channel NVARCHAR(20), channel_user_id NVARCHAR(100), user_id NVARCHAR(100), timestamp DATETIME2, name NVARCHAR(100), nickname NVARCHAR(100), profile_image_url NVARCHAR(500), data NVARCHAR(MAX), primary key(channel, channel_user_id))".format(self.table_name), 78 | "get_user": "select top 1 channel, channel_user_id, user_id, timestamp, name, nickname, profile_image_url, data from {0} where channel=? and channel_user_id=?".format(self.table_name), 79 | "add_user": "insert into {0} (channel, channel_user_id, user_id, timestamp, name, nickname, profile_image_url, data) values (?,?,?,?,?,?,?,?)".format(self.table_name), 80 | "save_user": "update {0} set timestamp=?, name=?, nickname=?, profile_image_url=?, data=? where channel=?
and channel_user_id=?".format(self.table_name), 81 | } 82 | 83 | 84 | class SQLDBMessageLogStore(MessageLogStore): 85 | """ 86 | Message log store for Azure SQL Database 87 | 88 | """ 89 | def get_sqls(self): 90 | """ 91 | Get SQLs used in MessageLogStore 92 | 93 | Returns 94 | ------- 95 | sqls : dict 96 | SQLs used in MessageLogStore 97 | """ 98 | return { 99 | "prepare_check": "select id from dbo.sysobjects where id = object_id('{0}')".format(self.table_name), 100 | "prepare_create": """ 101 | create table {0} ( 102 | id INT primary key identity, 103 | channel NVARCHAR(20), 104 | channel_detail NVARCHAR(100), 105 | channel_user_id NVARCHAR(100), 106 | request_timestamp DATETIME2, 107 | request_id NVARCHAR(100), 108 | request_type NVARCHAR(100), 109 | request_text NVARCHAR(MAX), 110 | request_payloads NVARCHAR(MAX), 111 | request_intent NVARCHAR(100), 112 | request_is_adhoc BIT, 113 | response_type NVARCHAR(100), 114 | response_text NVARCHAR(MAX), 115 | response_payloads NVARCHAR(MAX), 116 | response_milliseconds INT, 117 | context_is_new BIT, 118 | context_topic_name TEXT, 119 | context_topic_status TEXT, 120 | context_topic_is_new BIT, 121 | context_topic_keep_on BIT, 122 | context_topic_priority INT, 123 | context_error NVARCHAR(MAX), 124 | request_json NVARCHAR(MAX), 125 | response_json NVARCHAR(MAX), 126 | context_json NVARCHAR(MAX)) 127 | """.format(self.table_name), 128 | 129 | "write": """ 130 | insert into {0} ( 131 | channel, 132 | channel_detail, 133 | channel_user_id, 134 | request_timestamp, 135 | request_id, 136 | request_type, 137 | request_text, 138 | request_payloads, 139 | request_intent, 140 | request_is_adhoc, 141 | response_type, 142 | response_text, 143 | response_payloads, 144 | response_milliseconds, 145 | context_is_new, 146 | context_topic_name, 147 | context_topic_status, 148 | context_topic_is_new, 149 | context_topic_keep_on, 150 | context_topic_priority, 151 | context_error, 152 | request_json, response_json, context_json) 153 |
values ( 154 | ?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) 155 | """.format(self.table_name), 156 | } 157 | 158 | 159 | class SQLDBStores(StoreSet):  # set of data stores and connection provider using Azure SQL Database 160 | connection_provider = SQLDBConnectionProvider 161 | context_store = SQLDBContextStore 162 | user_store = SQLDBUserStore 163 | messagelog_store = SQLDBMessageLogStore 164 | -------------------------------------------------------------------------------- /minette/datastore/sqlitestores.py: -------------------------------------------------------------------------------- 1 | """ Set of data stores and connection provider using SQLite """ 2 | import sqlite3 3 | 4 | from .connectionprovider import ConnectionProvider 5 | from .contextstore import ContextStore 6 | from .userstore import UserStore 7 | from .messagelogstore import MessageLogStore 8 | from .storeset import StoreSet 9 | 10 | 11 | class SQLiteConnectionProvider(ConnectionProvider): 12 | """ 13 | Connection provider for SQLite 14 | 15 | Attributes 16 | ---------- 17 | connection_str : str 18 | Connection string (passed to sqlite3.connect as the database) 19 | """ 20 | def get_connection(self): 21 | """ 22 | Get connection 23 | 24 | Returns 25 | ------- 26 | connection : Connection 27 | Database connection 28 | """ 29 | connection = sqlite3.connect( 30 | self.connection_str, detect_types=sqlite3.PARSE_DECLTYPES)  # PARSE_DECLTYPES lets columns declared as TIMESTAMP come back as datetime 31 | connection.row_factory = sqlite3.Row  # rows can also be accessed by column name 32 | return connection 33 | 34 | 35 | class SQLiteContextStore(ContextStore): 36 | """ 37 | ContextStore using SQLite 38 | 39 | """ 40 | def get_sqls(self): 41 | """ 42 | Get SQLs used in ContextStore 43 | 44 | Returns 45 | ------- 46 | sqls : dict 47 | SQLs used in ContextStore 48 | """ 49 | return { 50 | "prepare_check": """ 51 | select * from sqlite_master where type='table' and name='{0}' 52 | """.format(self.table_name), 53 | "prepare_create": """ 54 | create table {0} ( 55 | channel TEXT, channel_user_id TEXT, timestamp TIMESTAMP, 56 | topic_name TEXT, topic_status TEXT, topic_previous TEXT, 57 | topic_priority INTEGER, data TEXT, primary
key(channel, 58 | channel_user_id))""".format(self.table_name), 59 | "get_context": """ 60 | select 61 | channel, channel_user_id, timestamp, topic_name, 62 | topic_status, topic_previous, topic_priority, data 63 | from {0} 64 | where 65 | channel=? and channel_user_id=? limit 1 66 | """.format(self.table_name), 67 | "save_context": """ 68 | replace into {0} ( 69 | channel, channel_user_id, timestamp, topic_name, 70 | topic_status, topic_previous, topic_priority, data) 71 | values ( 72 | ?,?,?,?,?,?,?,?) 73 | """.format(self.table_name), 74 | } 75 | 76 | 77 | class SQLiteUserStore(UserStore): 78 | """ 79 | UserStore using SQLite 80 | 81 | """ 82 | def get_sqls(self): 83 | """ 84 | Get SQLs used in UserStore 85 | 86 | Returns 87 | ------- 88 | sqls : dict 89 | SQLs used in UserStore 90 | """ 91 | return { 92 | "prepare_check": """ 93 | select * from sqlite_master where type='table' and name='{0}' 94 | """.format(self.table_name), 95 | "prepare_create": """ 96 | create table {0} ( 97 | channel TEXT, channel_user_id TEXT, user_id TEXT, 98 | timestamp TIMESTAMP, name TEXT, nickname TEXT, 99 | profile_image_url TEXT, data TEXT, 100 | primary key(channel, channel_user_id)) 101 | """.format(self.table_name), 102 | "get_user": """ 103 | select 104 | channel, channel_user_id, user_id,timestamp, name, 105 | nickname, profile_image_url, data 106 | from {0} 107 | where 108 | channel=? and channel_user_id=? limit 1 109 | """.format(self.table_name), 110 | "add_user": """ 111 | insert into {0} ( 112 | channel, channel_user_id, user_id, timestamp, name, 113 | nickname, profile_image_url, data) 114 | values ( 115 | ?,?,?,?,?,?,?,?) 116 | """.format(self.table_name), 117 | "save_user": """ 118 | update {0} 119 | set 120 | timestamp=?, name=?, nickname=?, profile_image_url=?, 121 | data=? 122 | where 123 | channel=? and channel_user_id=?
124 | """.format(self.table_name), 125 | } 126 | 127 | 128 | class SQLiteMessageLogStore(MessageLogStore): 129 | """ 130 | MessageLogStore using SQLite 131 | 132 | """ 133 | def get_sqls(self): 134 | """ 135 | Get SQLs used in MessageLogStore 136 | 137 | Returns 138 | ------- 139 | sqls : dict 140 | SQLs used in MessageLogStore 141 | """ 142 | return { 143 | "prepare_check": """ 144 | select * from sqlite_master where type='table' and name='{0}' 145 | """.format(self.table_name), 146 | "prepare_create": """ 147 | create table {0} ( 148 | id INTEGER PRIMARY KEY, 149 | channel TEXT, 150 | channel_detail TEXT, 151 | channel_user_id TEXT, 152 | request_timestamp TIMESTAMP, 153 | request_id TEXT, 154 | request_type TEXT, 155 | request_text TEXT, 156 | request_payloads TEXT, 157 | request_intent TEXT, 158 | request_is_adhoc BOOLEAN, 159 | response_type TEXT, 160 | response_text TEXT, 161 | response_payloads TEXT, 162 | response_milliseconds INT, 163 | context_is_new BOOLEAN, 164 | context_topic_name TEXT, 165 | context_topic_status TEXT, 166 | context_topic_is_new BOOLEAN, 167 | context_topic_keep_on BOOLEAN, 168 | context_topic_priority INTEGER, 169 | context_error TEXT, 170 | request_json TEXT, 171 | response_json TEXT, 172 | context_json TEXT) 173 | """.format(self.table_name), 174 | "write": """ 175 | insert into {0} ( 176 | channel, 177 | channel_detail, 178 | channel_user_id, 179 | request_timestamp, 180 | request_id, 181 | request_type, 182 | request_text, 183 | request_payloads, 184 | request_intent, 185 | request_is_adhoc, 186 | response_type, 187 | response_text, 188 | response_payloads, 189 | response_milliseconds, 190 | context_is_new, 191 | context_topic_name, 192 | context_topic_status, 193 | context_topic_is_new, 194 | context_topic_keep_on, 195 | context_topic_priority, 196 | context_error, 197 | request_json, response_json, context_json) 198 | values ( 199 | ?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
200 | """.format(self.table_name), 201 | } 202 | 203 | 204 | class SQLiteStores(StoreSet): 205 | """ 206 | Set of data stores and connection provider using SQLite 207 | 208 | """ 209 | connection_provider = SQLiteConnectionProvider 210 | context_store = SQLiteContextStore 211 | user_store = SQLiteUserStore 212 | messagelog_store = SQLiteMessageLogStore 213 | -------------------------------------------------------------------------------- /minette/datastore/storeset.py: -------------------------------------------------------------------------------- 1 | """ Base class for set of data stores and connection provider for them """ 2 | 3 | 4 | class StoreSet:  # subclasses set these class attributes to concrete classes (e.g. SQLiteStores) 5 | connection_provider = None  # ConnectionProvider subclass 6 | context_store = None  # ContextStore subclass 7 | user_store = None  # UserStore subclass 8 | messagelog_store = None  # MessageLogStore subclass 9 | -------------------------------------------------------------------------------- /minette/datastore/userstore.py: -------------------------------------------------------------------------------- 1 | """ Base class for UserStore """ 2 | from abc import ABC, abstractmethod 3 | import traceback 4 | from logging import getLogger 5 | from datetime import datetime 6 | from pytz import timezone as tz 7 | 8 | from ..serializer import dumps, loads 9 | from ..models import User 10 | 11 | 12 | class UserStore(ABC): 13 | """ 14 | Base class for UserStore to read/write user information 15 | 16 | Attributes 17 | ---------- 18 | config : minette.Config 19 | Configuration 20 | timezone : pytz.timezone 21 | Timezone 22 | logger : logging.Logger 23 | Logger 24 | table_name : str 25 | Database table name for read/write user data 26 | sqls : dict 27 | SQLs used in ContextStore 28 | """ 29 | 30 | def __init__(self, config=None, timezone=None, logger=None, 31 | table_name="user", **kwargs): 32 | """ 33 | Parameters 34 | ---------- 35 | config : minette.Config, default None 36 | Configuration 37 | timezone : pytz.timezone, default None 38 | Timezone 39 | logger : logging.Logger, default None 40 | Logger 41 | table_name : str, 
default "user" 42 | Database table name for read/write user data 43 | """ 44 | self.config = config 45 | self.timezone = timezone or ( 46 | tz(config.get("timezone", default="UTC")) if config else tz("UTC")) 47 | self.logger = logger if logger else getLogger(__name__) 48 | self.table_name = table_name 49 | self.sqls = self.get_sqls() 50 | 51 | @abstractmethod 52 | def get_sqls(self): 53 | pass 54 | 55 | def prepare_table(self, connection, prepare_params=None): 56 | """ 57 | Check and create table if not exist 58 | 59 | Parameters 60 | ---------- 61 | connection : Connection 62 | Connection for prepare 63 | 64 | query_params : tuple, default tuple() 65 | Query parameters for checking table 66 | """ 67 | cursor = connection.cursor() 68 | cursor.execute(self.sqls["prepare_check"], prepare_params or tuple()) 69 | if not cursor.fetchone(): 70 | cursor.execute(self.sqls["prepare_create"]) 71 | connection.commit() 72 | return True 73 | else: 74 | return False 75 | 76 | def get(self, channel, channel_user_id, connection): 77 | """ 78 | Get user by channel and channel_user_id 79 | 80 | Parameters 81 | ---------- 82 | channel : str 83 | Channel 84 | channel_user_id : str 85 | Channel user ID 86 | connection : Connection 87 | Connection 88 | 89 | Returns 90 | ------- 91 | user : minette.User 92 | User 93 | """ 94 | user = User(channel=channel, channel_user_id=channel_user_id) 95 | if not channel_user_id: 96 | return user 97 | try: 98 | cursor = connection.cursor() 99 | cursor.execute(self.sqls["get_user"], (channel, channel_user_id)) 100 | row = cursor.fetchone() 101 | if row is not None: 102 | # convert to dict 103 | if isinstance(row, dict): 104 | record = row 105 | else: 106 | record = dict( 107 | zip([column[0] for column in cursor.description], row)) 108 | # convert type 109 | record["data"] = loads(record["data"]) 110 | # restore user 111 | user.id = record["user_id"] 112 | user.name = record["name"] 113 | user.nickname = record["nickname"] 114 | user.profile_image_url 
= record["profile_image_url"] 115 | user.data = record["data"] if record["data"] else {} 116 | else: 117 | cursor.execute(self.sqls["add_user"], ( 118 | channel, channel_user_id, user.id, 119 | datetime.now(self.timezone), user.name, user.nickname, 120 | user.profile_image_url, None) 121 | ) 122 | connection.commit() 123 | except Exception as ex: 124 | self.logger.error( 125 | "Error occured in restoring user from database: " 126 | + str(ex) + "\n" + traceback.format_exc()) 127 | return user 128 | 129 | def save(self, user, connection): 130 | """ 131 | Save user 132 | 133 | Parameters 134 | ---------- 135 | user : minette.User 136 | User to save 137 | connection : Connection 138 | Connection 139 | """ 140 | user_dict = user.to_dict() 141 | serialized_data = dumps(user_dict["data"]) 142 | cursor = connection.cursor() 143 | cursor.execute(self.sqls["save_user"], ( 144 | datetime.now(self.timezone), user.name, user.nickname, 145 | user.profile_image_url, serialized_data, user.channel, 146 | user.channel_user_id) 147 | ) 148 | connection.commit() 149 | -------------------------------------------------------------------------------- /minette/dialog/__init__.py: -------------------------------------------------------------------------------- 1 | from .service import ( 2 | DialogService, 3 | EchoDialogService, 4 | ErrorDialogService, 5 | ) 6 | from .router import DialogRouter 7 | from .dependency import DependencyContainer 8 | -------------------------------------------------------------------------------- /minette/dialog/dependency.py: -------------------------------------------------------------------------------- 1 | """ Container class for components that DialogRouter/DialogServices depend """ 2 | 3 | 4 | class DependencyContainer: 5 | def __init__(self, dialog, dependency_rules=None, **defaults): 6 | # set default dependencies 7 | for k, v in defaults.items(): 8 | setattr(self, k, v) 9 | # set dialog specific dependencies 10 | if dependency_rules: 11 | 
dialog_dependencies = dependency_rules.get(type(dialog)) 12 | if dialog_dependencies: 13 | for k, v in dialog_dependencies.items(): 14 | setattr(self, k, v) 15 | -------------------------------------------------------------------------------- /minette/dialog/router.py: -------------------------------------------------------------------------------- 1 | """ Base class for DialogRouter that routes proper dialog for the intent """ 2 | from abc import ABC, abstractmethod 3 | import traceback 4 | from logging import Logger, getLogger 5 | 6 | from ..models import Message, Priority 7 | from .service import DialogService, ErrorDialogService 8 | from .dependency import DependencyContainer 9 | 10 | 11 | class DialogRouter: 12 | """ 13 | Base class for DialogRouter 14 | 15 | Attributes 16 | ---------- 17 | config : minette.Config 18 | Configuration 19 | timezone : timezone 20 | Timezone 21 | logger : logging.Logger 22 | Logger 23 | default_dialog_service : DialogService 24 | Dialog service used when intent is not clear 25 | dependency_rules : dict 26 | Rules that defines on which components each DialogService/Router depends 27 | default_dependencies : dict 28 | Dependency components for all DialogServices/Router 29 | dependencies : DependencyContainer 30 | Container to attach objects DialogRouter depends 31 | intent_resolver : dict 32 | Resolver for intent to dialog 33 | topic_resolver : dict 34 | Resolver for topic to dialog for successive chatting 35 | """ 36 | 37 | def __init__(self, config=None, timezone=None, logger=None, 38 | default_dialog_service=None, intent_resolver=None, **kwargs): 39 | """ 40 | Parameters 41 | ---------- 42 | config : minette.Config, default None 43 | Configuration 44 | timezone : pytz.timezone, default None 45 | Timezone 46 | logger : logging.Logger, default None 47 | Logger 48 | default_dialog_service : minette.DialogService or type, default None 49 | Dialog service used when intent is not clear. 
50 | """ 51 | self.config = config 52 | self.timezone = timezone 53 | self.logger = logger or getLogger(__name__) 54 | self.default_dialog_service = default_dialog_service or DialogService 55 | self.dependency_rules = None 56 | self.default_dependencies = None or {} # empty dict is required to unpack 57 | self.dependencies = None 58 | # set up intent_resolver 59 | self.intent_resolver = intent_resolver or {} 60 | self.register_intents() 61 | # set up topic_resolver 62 | self.topic_resolver = { 63 | v.topic_name(): v for v in self.intent_resolver.values() if v} 64 | self.topic_resolver[self.default_dialog_service.topic_name()] = \ 65 | self.default_dialog_service 66 | 67 | def register_intents(self): 68 | """ 69 | Register intents and the dialog services to process the intents 70 | 71 | >>> self.intent_resolver = { 72 | "PizzaOrderIntent": PizzaDialogService, 73 | "ChangeAddressIntent": ChangeAddressDialogService, 74 | } 75 | """ 76 | pass 77 | 78 | def execute(self, request, context, connection, performance): 79 | """ 80 | Main logic of DialogRouter 81 | 82 | Parameters 83 | ---------- 84 | request : minette.Message 85 | Request message 86 | context : minette.Context 87 | Context 88 | connection : Connection 89 | Connection 90 | performance : minette.PerformanceInfo 91 | Performance information 92 | 93 | Returns 94 | ------- 95 | dialog_service : minette.DialogService 96 | DialogService to process request message 97 | """ 98 | try: 99 | # extract intent and entities 100 | extracted = self.extract_intent( 101 | request=request, context=context, connection=connection) 102 | if isinstance(extracted, tuple): 103 | request.intent = extracted[0] 104 | request.entities = extracted[1] 105 | if len(extracted) > 2: 106 | request.intent_priority = extracted[2] 107 | elif isinstance(extracted, str): 108 | request.intent = extracted 109 | performance.append("dialog_router.extract_intent") 110 | # preprocess before route 111 | self.before_route(request, context, connection) 112 
| performance.append("dialog_router.before_route") 113 | # route dialog 114 | dialog_service = self.route(request, context, connection) 115 | if issubclass(dialog_service, DialogService): 116 | dialog_service = dialog_service( 117 | config=self.config, timezone=self.timezone, 118 | logger=self.logger 119 | ) 120 | dialog_service.dependencies = DependencyContainer( 121 | dialog_service, 122 | self.dependency_rules, 123 | **self.default_dependencies) 124 | performance.append("dialog_router.route") 125 | except Exception as ex: 126 | self.logger.error( 127 | "Error occured in dialog_router: " 128 | + str(ex) + "\n" + traceback.format_exc()) 129 | dialog_service = \ 130 | self.handle_exception(request, context, ex, connection) 131 | 132 | return dialog_service 133 | 134 | def extract_intent(self, request, context, connection): 135 | """ 136 | Extract intent and entities from request message 137 | 138 | Parameters 139 | ---------- 140 | request : minette.Message 141 | Request message 142 | context : minette.Context 143 | Context 144 | connection : Connection 145 | Connection 146 | 147 | Returns 148 | ------- 149 | response : tuple of (str, dict) 150 | Intent and entities 151 | """ 152 | return request.intent, request.entities 153 | 154 | def before_route(self, request, context, connection): 155 | """ 156 | Preprocessing for all requests before routing 157 | 158 | Parameters 159 | ---------- 160 | request : minette.Message 161 | Request message 162 | context : minette.Context 163 | Context 164 | connection : Connection 165 | Connection 166 | """ 167 | pass 168 | 169 | def route(self, request, context, connection): 170 | """ 171 | Return proper DialogService for intent or topic 172 | 173 | Parameters 174 | ---------- 175 | request : minette.Message 176 | Request message 177 | context : minette.Context 178 | Context 179 | connection : Connection 180 | Connection 181 | 182 | 183 | Returns 184 | ------- 185 | dialog_service : minette.DialogService 186 | Dialog service proper 
for intent or topic 187 | """ 188 | # update 189 | if request.intent in self.intent_resolver and ( 190 | request.intent_priority > context.topic.priority or 191 | not context.topic.name): 192 | dialog_service = self.intent_resolver[request.intent] 193 | # update topic if request is not adhoc 194 | if dialog_service and not request.is_adhoc: 195 | context.topic.name = dialog_service.topic_name() 196 | context.topic.status = "" 197 | if request.intent_priority >= Priority.Highest: 198 | # set slightly lower priority to enable to update Highest priority intent 199 | context.topic.priority = Priority.Highest - 1 200 | else: 201 | context.topic.priority = request.intent_priority 202 | context.topic.is_new = True 203 | # do not update topic when request is adhoc or ds is None 204 | else: 205 | dialog_service = dialog_service or DialogService 206 | if context.topic.name: 207 | context.topic.keep_on = True 208 | 209 | # continue 210 | elif context.topic.name: 211 | dialog_service = self.topic_resolver[context.topic.name] 212 | 213 | # default (intent not extracted or unknown) 214 | else: 215 | dialog_service = self.default_dialog_service 216 | context.topic.name = dialog_service.topic_name() 217 | context.topic.status = "" 218 | context.topic.is_new = True 219 | return dialog_service 220 | 221 | def handle_exception(self, request, context, exception, connection): 222 | """ 223 | Handle exception and return ErrorDialogService 224 | 225 | Parameters 226 | ---------- 227 | request : minette.Message 228 | Request message 229 | context : minette.Context 230 | Context 231 | exception : Exception 232 | Exception 233 | connection : Connection 234 | Connection 235 | 236 | Returns 237 | ------- 238 | response : minette.ErrorDialogService 239 | Dialog service for error occured in chatting 240 | """ 241 | context.set_error(exception) 242 | return ErrorDialogService( 243 | config=self.config, timezone=self.timezone, logger=self.logger) 244 | 
-------------------------------------------------------------------------------- /minette/dialog/service.py: -------------------------------------------------------------------------------- 1 | """ Base class for DialogService for processing each dialogs """ 2 | import traceback 3 | from logging import Logger, getLogger 4 | 5 | from ..models import ( 6 | Message, 7 | Response, 8 | Context, 9 | PerformanceInfo 10 | ) 11 | 12 | 13 | class DialogService: 14 | """ 15 | Base class for DialogService 16 | 17 | Attributes 18 | ---------- 19 | config : minette.Config 20 | Configuration 21 | timezone : timezone 22 | Timezone 23 | logger : logging.Logger 24 | Logger 25 | dependencies : DependencyContainer 26 | Container to attach objects DialogRouter depends 27 | """ 28 | 29 | @classmethod 30 | def topic_name(cls): 31 | """ 32 | Topic name of this dialog service 33 | 34 | Returns 35 | ------- 36 | topic_name : str 37 | Topic name of this dialog service 38 | """ 39 | cls_name = cls.__name__.lower() 40 | if cls_name.endswith("dialogservice"): 41 | cls_name = cls_name[:-13] 42 | elif cls_name.endswith("dialog"): 43 | cls_name = cls_name[:-6] 44 | return cls_name 45 | 46 | def __init__(self, config=None, timezone=None, logger=None): 47 | """ 48 | Parameters 49 | ---------- 50 | config : minette.Config, default None 51 | Configuration 52 | timezone : pytz.timezone, default None 53 | Timezone 54 | logger : logging.Logger, default None 55 | Logger 56 | """ 57 | self.config = config 58 | self.timezone = timezone 59 | self.logger = logger or getLogger(__name__) 60 | self.dependencies = None 61 | 62 | def execute(self, request, context, connection, performance): 63 | """ 64 | Main logic of DialogService 65 | 66 | Parameters 67 | ---------- 68 | request : minette.Message 69 | Request message 70 | context : minette.Context 71 | Context 72 | connection : Connection 73 | Connection 74 | performance : minette.PerformanceInfo 75 | Performance information 76 | 77 | Returns 78 | ------- 79 | 
response : minette.Response 80 | Response from chatbot 81 | """ 82 | try: 83 | # extract entities 84 | for k, v in self.extract_entities( 85 | request, context, connection).items(): 86 | if not request.entities.get(k, ""): 87 | request.entities[k] = v 88 | performance.append("dialog_service.extract_entities") 89 | 90 | # initialize context data 91 | if context.topic.is_new: 92 | context.data = self.get_slots(request, context, connection) 93 | performance.append("dialog_service.get_slots") 94 | 95 | # process request 96 | self.process_request(request, context, connection) 97 | performance.append("dialog_service.process_request") 98 | 99 | # compose response 100 | response_messages = \ 101 | self.compose_response(request, context, connection) 102 | if not response_messages: 103 | self.logger.info("No response") 104 | response_messages = [] 105 | elif not isinstance(response_messages, list): 106 | response_messages = [response_messages] 107 | response = Response() 108 | for rm in response_messages: 109 | if isinstance(rm, Message): 110 | response.messages.append(rm) 111 | elif isinstance(rm, str): 112 | response.messages.append(request.to_reply(text=rm)) 113 | performance.append("dialog_service.compose_response") 114 | 115 | except Exception as ex: 116 | self.logger.error( 117 | "Error occured in dialog_service: " 118 | + str(ex) + "\n" + traceback.format_exc()) 119 | response = Response(messages=[ 120 | self.handle_exception(request, context, ex, connection)]) 121 | 122 | return response 123 | 124 | def extract_entities(self, request, context, connection): 125 | """ 126 | Extract entities from request message 127 | 128 | Parameters 129 | ---------- 130 | request : minette.Message 131 | Request message 132 | context : minette.Context 133 | Context 134 | connection : Connection 135 | Connection 136 | 137 | Returns 138 | ------- 139 | entities : dict 140 | Entities extracted from request message 141 | """ 142 | return {} 143 | 144 | def get_slots(self, request, context, 
connection): 145 | """ 146 | Get initial context.data at the start of this dialog 147 | 148 | Parameters 149 | ---------- 150 | request : minette.Message 151 | Request message 152 | context : minette.Context 153 | Context 154 | connection : Connection 155 | Connection 156 | 157 | Returns 158 | ------- 159 | slots : dict 160 | Initial context.data 161 | """ 162 | return {} 163 | 164 | def process_request(self, request, context, connection): 165 | """ 166 | Process your chatbot's functions/skills and setup context data 167 | 168 | Parameters 169 | ---------- 170 | request : minette.Message 171 | Request message 172 | context : minette.Context 173 | Context 174 | connection : Connection 175 | Connection 176 | """ 177 | pass 178 | 179 | def compose_response(self, request, context, connection): 180 | """ 181 | Compose response messages using context data 182 | 183 | Parameters 184 | ---------- 185 | request : minette.Message 186 | Request message 187 | context : minette.Context 188 | Context 189 | connection : Connection 190 | Connection 191 | 192 | Returns 193 | ------- 194 | response : minette.Response 195 | Response from chatbot 196 | """ 197 | return "" 198 | 199 | def handle_exception(self, request, context, exception, connection): 200 | """ 201 | Handle exception and return error response message 202 | 203 | Parameters 204 | ---------- 205 | request : minette.Message 206 | Request message 207 | context : minette.Context 208 | Context 209 | exception : Exception 210 | Exception 211 | connection : Connection 212 | Connection 213 | 214 | Returns 215 | ------- 216 | response : minette.Response 217 | Error response from chatbot 218 | """ 219 | context.set_error(exception) 220 | context.topic.keep_on = False 221 | return request.to_reply(text="?") 222 | 223 | 224 | class EchoDialogService(DialogService): 225 | """ 226 | Simple echo dialog service for tutorial 227 | 228 | """ 229 | def compose_response(self, request, context, connection=None): 230 | return 
request.to_reply(text="You said: {}".format(request.text)) 231 | 232 | 233 | class ErrorDialogService(DialogService): 234 | """ 235 | Dialog service for error in chatting 236 | 237 | """ 238 | def compose_response(self, request, context, connection=None): 239 | return request.to_reply(text="?") 240 | -------------------------------------------------------------------------------- /minette/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .user import User 2 | from .priority import Priority 3 | from .topic import Topic 4 | from .context import Context 5 | from .group import Group 6 | from .payload import Payload 7 | from .message import Message 8 | from .performance import PerformanceInfo 9 | from .response import Response 10 | from .wordnode import WordNode 11 | -------------------------------------------------------------------------------- /minette/models/context.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from copy import deepcopy 3 | 4 | from ..serializer import Serializable 5 | from .priority import Priority 6 | from .topic import Topic 7 | 8 | 9 | class Context(Serializable): 10 | """ 11 | Context 12 | 13 | Attributes 14 | ---------- 15 | channel : str 16 | Channel 17 | channel_user_id : str 18 | Channel user ID 19 | timestamp : datetime 20 | Timestamp 21 | is_new : bool 22 | True if context is created at this turn 23 | topic : Topic 24 | Current topic 25 | data : dict 26 | Data slots 27 | error : dict 28 | Error info 29 | """ 30 | def __init__(self, channel=None, channel_user_id=None): 31 | """ 32 | Parameters 33 | ---------- 34 | channel : str, default None 35 | Channel 36 | channel_user_id : str, default None 37 | Channel user ID 38 | """ 39 | self.channel = channel 40 | self.channel_user_id = \ 41 | channel_user_id if isinstance(channel_user_id, str) else "" 42 | self.timestamp = None 43 | self.is_new = True 44 | self.topic = 
Topic() 45 | self.data = {} 46 | self.error = {} 47 | 48 | def reset(self, keep_data=False): 49 | """ 50 | Backup to previous topic and remove data 51 | 52 | Parameters 53 | ---------- 54 | keep_data : bool, default False 55 | Keep context data to next turn 56 | """ 57 | # backup previous topic 58 | self.topic.previous = None 59 | self.topic.previous = deepcopy(self.topic) 60 | # remove data if topic not keep_on 61 | if not self.topic.keep_on: 62 | self.topic.name = "" 63 | self.topic.status = "" 64 | self.topic.priority = Priority.Normal 65 | self.data = self.data if keep_data else {} 66 | self.error = {} 67 | 68 | def set_error(self, ex, info=None): 69 | """ 70 | Set error info 71 | 72 | Parameters 73 | ---------- 74 | ex : Exception 75 | Exception 76 | info : dict, default None 77 | More information for debugging 78 | """ 79 | self.error = { 80 | "exception": str(ex), 81 | "traceback": traceback.format_exc(), 82 | "info": info if info else {}} 83 | 84 | @classmethod 85 | def _types(cls): 86 | return { 87 | "topic": Topic 88 | } 89 | -------------------------------------------------------------------------------- /minette/models/group.py: -------------------------------------------------------------------------------- 1 | from ..serializer import Serializable 2 | 3 | 4 | class Group(Serializable): 5 | """ 6 | Group 7 | 8 | Attributes 9 | ---------- 10 | id : str 11 | ID of group 12 | type : str 13 | Type of group 14 | """ 15 | def __init__(self, id=None, type=None): 16 | """ 17 | Parameters 18 | ---------- 19 | id : str, default None 20 | ID of group 21 | type : str, default None 22 | Type of group 23 | """ 24 | self.id = id 25 | self.type = type 26 | -------------------------------------------------------------------------------- /minette/models/message.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pytz import timezone as tz 3 | from copy import copy 4 | 5 | from ..serializer 
import Serializable 6 | from .payload import Payload 7 | from .priority import Priority 8 | from .user import User 9 | 10 | 11 | class Message(Serializable): 12 | """ 13 | Message 14 | 15 | Attributes 16 | ---------- 17 | id : str 18 | Message ID 19 | type : str 20 | Message type 21 | timestamp : datetime 22 | Timestamp 23 | channel : str 24 | Channel 25 | channel_detail : str 26 | Detail information of channel 27 | channel_user_id : str 28 | User ID for channel 29 | channel_message : Any 30 | Original message from channel 31 | token : Any 32 | Token to do actions to the channel 33 | user : User 34 | User 35 | group : Group 36 | Group 37 | text : str 38 | Body of message 39 | words : list of minette.WordNode 40 | Word nodes parsed by tagger 41 | payloads : list of minette.Payload 42 | Payloads 43 | intent : str 44 | Intent extracted from message 45 | intent_priority : int 46 | Priority for processing intent 47 | entities : dict 48 | Entities extracted from message 49 | is_adhoc : bool 50 | Process adhoc or not 51 | """ 52 | def __init__(self, id=None, type="text", channel="console", 53 | channel_detail=None, channel_user_id="anonymous", 54 | timestamp=None, channel_message=None, token=None, 55 | user=None, group=None, text=None, payloads=None, intent=None, 56 | intent_priority=None, entities=None, is_adhoc=False): 57 | """ 58 | Parameters 59 | ---------- 60 | id : str, default None 61 | Message ID 62 | type : str default "text" 63 | Message type 64 | channel : str, default "console" 65 | Channel 66 | channel_detail : str, default None 67 | Detail information of channel 68 | channel_user_id : str, default "anonymous" 69 | User ID for channel 70 | channel_message : Any, default None 71 | Original message from channel 72 | token : Any, default None 73 | Token to do actions to the channel 74 | text : str, default None 75 | Body of message 76 | payloads : list of minette.Payload, default None 77 | Payloads 78 | """ 79 | self.id = id or "" 80 | self.type = type 81 | 
self.timestamp = timestamp or datetime.now(tz("UTC")) 82 | self.channel = channel 83 | self.channel_detail = channel_detail or "" 84 | self.channel_user_id = channel_user_id 85 | self.channel_message = channel_message 86 | self.token = token or "" 87 | self.user = user 88 | self.group = group 89 | self.text = text or "" 90 | self.words = [] 91 | self.payloads = payloads or [] 92 | self.intent = intent or "" 93 | self.intent_priority = intent_priority or Priority.Normal 94 | self.entities = entities or {} 95 | self.is_adhoc = is_adhoc 96 | 97 | def to_reply(self, text=None, payloads=None, type="text"): 98 | """ 99 | Get reply message for this message 100 | 101 | Parameters 102 | ---------- 103 | text : str, default None 104 | Body of reply message 105 | payloads : list of minette.Payload, default None 106 | Payloads 107 | type : str default "text" 108 | Message type 109 | 110 | Returns 111 | ------- 112 | reply_message : minette.Message 113 | Reply message for this message 114 | """ 115 | message = copy(self) 116 | message.timestamp = datetime.now(message.timestamp.tzinfo) 117 | message.channel_message = None 118 | message.type = type 119 | message.text = text 120 | message.words = [] 121 | message.payloads = payloads if payloads else [] 122 | message.intent = "" 123 | message.entities = {} 124 | message.is_adhoc = False 125 | return message 126 | 127 | def reply(self, text=None, payloads=None, type="text"): 128 | print("WARNING: `reply` is deprecated and will be deleted at version 0.5. 
Use `to_reply` instead.") 129 | return self.to_reply(text=text, payloads=payloads, type=type) 130 | 131 | def to_dict(self): 132 | """ 133 | Convert Message object to dict 134 | 135 | Returns 136 | ------- 137 | message_dict : dict 138 | Message dictionary 139 | """ 140 | message_dict = super().to_dict() 141 | # channel_message is not JSON serializable 142 | message_dict["channel_message"] = str(message_dict["channel_message"]) 143 | return message_dict 144 | 145 | @classmethod 146 | def _types(cls): 147 | return { 148 | "user": User, 149 | "payloads": Payload 150 | } 151 | -------------------------------------------------------------------------------- /minette/models/payload.py: -------------------------------------------------------------------------------- 1 | from ..serializer import Serializable 2 | 3 | 4 | class Payload(Serializable): 5 | """ 6 | Payload 7 | 8 | Attributes 9 | ---------- 10 | content_type : str 11 | Content type of payload 12 | content : Any 13 | Content data 14 | headers : dict 15 | Headers of content or headers to get content 16 | url : str 17 | Url to get content 18 | thumb : str 19 | URL to get thumbnail image 20 | """ 21 | def __init__(self, *, content_type="image", content=None, headers=None, 22 | url=None, thumb=None,): 23 | """ 24 | Parameters 25 | ---------- 26 | content_type : str, default "image" 27 | Content type of payload 28 | content : Any, default None 29 | Content data 30 | headers : dict, default None 31 | Headers of content or headers to get content 32 | url : str, default None 33 | Url to get content 34 | thumb : str, default None 35 | URL to get thumbnail image 36 | """ 37 | self.content_type = content_type 38 | self.content = content 39 | self.headers = headers or {} 40 | self.url = url 41 | self.thumb = thumb if thumb is not None else url 42 | -------------------------------------------------------------------------------- /minette/models/performance.py: 
from time import time
from ..serializer import Serializable


class PerformanceInfo(Serializable):
    """
    Timing information recorded at each processing step

    Attributes
    ----------
    start_time : float
        Unix epoch seconds when this object was created
    ticks : list of tuple
        (comment, seconds elapsed since start_time) for each step
    milliseconds : int
        Elapsed time at the most recent tick, in milliseconds
    """
    def __init__(self):
        self.start_time = time()
        self.ticks = []
        self.milliseconds = 0

    def append(self, comment):
        """
        Record the elapsed time for the current step

        Parameters
        ----------
        comment : str
            Label identifying the step
        """
        elapsed = time() - self.start_time
        self.ticks.append((comment, elapsed))
        # the running total is simply the latest tick
        self.milliseconds = int(elapsed * 1000)


# ==== /minette/models/priority.py ====
class Priority:
    """
    Priority levels of a topic

    Attributes
    ----------
    Highest : int
        100
    High : int
        75
    Normal : int
        50
    Low : int
        25
    Ignore : int
        0
    """
    Highest = 100
    High = 75
    Normal = 50
    Low = 25
    Ignore = 0


# ==== /minette/models/response.py ====
from ..serializer import Serializable
from .performance import PerformanceInfo
from .message import Message


class Response(Serializable):
    """
    Response returned by the chatbot

    Attributes
    ----------
    messages : list of minette.Message
        Messages to send back
    headers : dict
        Response headers
    performance : minette.PerformanceInfo
        Timing information of each step in chat()
    """
    def __init__(self, messages=None, headers=None, performance=None):
        """
        Parameters
        ----------
        messages : list of minette.Message, default None
            Messages to send back. Falsy values become `[]`.
        headers : dict, default None
            Response headers. Falsy values become `{}`.
        performance : minette.PerformanceInfo, default None
            Timing information. Falsy values become a new PerformanceInfo.
        """
        self.messages = messages if messages else []
        self.headers = headers if headers else {}
        self.performance = performance if performance else PerformanceInfo()

    @classmethod
    def _types(cls):
        return dict(messages=Message, performance=PerformanceInfo)


# ==== /minette/models/topic.py ====
from ..serializer import Serializable
from .priority import Priority


class Topic(Serializable):
    """
    Conversation topic

    Attributes
    ----------
    name : str
        Name of the topic
    status : str
        Status of the topic
    is_new : bool
        True if the topic started at this turn
    keep_on : bool
        True to keep this topic at the next turn
    previous : minette.Topic
        Topic of the previous turn
    priority : int
        Priority of the topic
    is_changed : bool
        (read-only) True if the topic changed at this turn
    """
    def __init__(self):
        self.name = ""
        self.status = ""
        self.is_new = False
        self.keep_on = False
        self.previous = None
        self.priority = Priority.Normal

    @property
    def is_changed(self):
        # unchanged only when a previous topic exists with the same name
        return not (self.previous and self.previous.name == self.name)


# ==== /minette/models/user.py ====
from uuid import uuid4
from ..serializer import Serializable


class User(Serializable):
    """
    User of the chatbot

    Attributes
    ----------
    id : str
        User ID (random UUID4 string)
    name : str
        User name
    nickname : str
        Nickname
    channel : str
        Channel
    channel_user_id : str
        User ID on the channel
    profile_image_url : str
        URL of the profile image
    data : dict
        Arbitrary user data
    """
    def __init__(self, channel=None, channel_user_id=None):
        """
        Parameters
        ----------
        channel : str, default None
            Channel
        channel_user_id : str, default None
            User ID on the channel. Non-str values are replaced with "".
        """
        self.id = str(uuid4())
        self.name = ""
        self.nickname = ""
        self.channel = channel
        if isinstance(channel_user_id, str):
            self.channel_user_id = channel_user_id
        else:
            self.channel_user_id = ""
        self.profile_image_url = ""
        self.data = {}


# ==== /minette/models/wordnode.py ====
from abc import ABC, abstractmethod
from ..serializer import Serializable


class WordNode(ABC, Serializable):
    """
    Base class of a word produced by a tagger

    Attributes
    ----------
    surface : str
        Surface of word
    part : str
        Part of the word
    part_detail1 : str
        Detail1 of part
    part_detail2 : str
        Detail2 of part
    part_detail3 : str
        Detail3 of part
    stem_type : str
        Stem type
    stem_form : str
        Stem form
    word : str
        Word itself
    kana : str
        Japanese kana of the word
    pronunciation : str
        Pronunciation of the word
    """
    def __init__(self, surface, part, part_detail1, part_detail2, part_detail3,
                 stem_type, stem_form, word, kana, pronunciation):
        self.surface = surface
        self.part = part
        self.part_detail1 = part_detail1
        self.part_detail2 = part_detail2
        self.part_detail3 = part_detail3
        self.stem_type = stem_type
        self.stem_form = stem_form
        self.word = word
        self.kana = kana
        self.pronunciation = pronunciation

    @classmethod
    @abstractmethod
    def create(cls, surface, features):
        """Build a node from the surface and tagger-specific features"""
        pass


# ==== /minette/scheduler/__init__.py ====
from .base import Task, Scheduler


# ==== /minette/scheduler/base.py ====
""" Scheduler for periodic tasks """
import traceback
from logging import getLogger
import schedule
import time
from concurrent.futures import ThreadPoolExecutor


class Task:
    """
    Base class of periodic tasks

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    connection_provider : minette.ConnectionProvider
        Connection provider for database access inside the task
    """
    def __init__(self, config=None, timezone=None, logger=None,
                 connection_provider=None, **kwargs):
        """
        Parameters
        ----------
        config : minette.Config, default None
            Configuration
        timezone : pytz.timezone, default None
            Timezone
        logger : logging.Logger, default None
            Logger. If falsy, the module logger is used.
        connection_provider : minette.ConnectionProvider
            Connection provider for database access inside the task
        """
        self.config = config
        self.timezone = timezone
        self.logger = logger if logger else getLogger(__name__)
        self.connection_provider = connection_provider

    def do(self, **kwargs):
        """
        Body of the periodic task; override this in subclasses
        """
        self.logger.error("Task is not implemented")
class Scheduler:
    """
    Task scheduler for periodic tasks

    Examples
    --------
    To start doing scheduled tasks, create a `Scheduler` instance,
    register task(s), then call `start()`

    >>> my_scheduler = MyScheduler()
    >>> my_scheduler.every_minutes(MyTask)
    >>> my_scheduler.start()

    To register tasks, this class provides shortcut methods.
    Each task runs on a worker thread.

    >>> my_scheduler.every_minutes(MyTask)
    >>> my_scheduler.every_seconds(MyTask, seconds=5)
    >>> my_scheduler.every_seconds(MyTask, seconds=5, arg1="val1", arg2="val2")

    You can also use the internal `schedule` to register tasks;
    those tasks run on the main thread.

    >>> my_scheduler.schedule.every().minute.do(my_scheduler.create_task(MyTask))
    >>> my_scheduler.schedule.every().hour.do(my_scheduler.create_task(YourTask))

    Notes
    -----
    How to execute jobs in parallel?
    https://schedule.readthedocs.io/en/stable/faq.html#how-to-execute-jobs-in-parallel

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    threads : int
        Number of worker threads to process tasks
    connection_provider : minette.ConnectionProvider
        Connection provider to use database in each task
    schedule : schedule
        schedule module
    executor : concurrent.futures.ThreadPoolExecutor
        Executor to run tasks on worker threads
    """

    def __init__(self, config=None, timezone=None, logger=None, threads=None,
                 connection_provider=None, **kwargs):
        """
        Parameters
        ----------
        config : minette.Config, default None
            Configuration
        timezone : pytz.timezone, default None
            Timezone
        logger : logging.Logger, default None
            Logger
        threads : int, default None
            Number of worker threads to process tasks
        connection_provider : ConnectionProvider, default None
            Connection provider to use database in each task
        """
        self.config = config
        self.timezone = timezone
        self.logger = logger or getLogger(__name__)
        self.threads = threads
        self.connection_provider = connection_provider
        self.schedule = schedule
        self.executor = ThreadPoolExecutor(
            max_workers=self.threads, thread_name_prefix="SchedulerThread")
        self._is_running = False

    @property
    def is_running(self):
        """True while `start()` is looping"""
        return self._is_running

    def create_task(self, task_class, **kwargs):
        """
        Create and return callable function of task

        Parameters
        ----------
        task_class : type or callable
            Subclass of minette.Task, or any callable
        kwargs : dict
            Extra keyword arguments for the Task constructor

        Returns
        -------
        task_method : callable
            Callable interface of task

        Raises
        ------
        TypeError
            If task_class is neither a Task subclass nor callable
        """
        if isinstance(task_class, type):
            if issubclass(task_class, Task):
                return task_class(
                    config=self.config,
                    timezone=self.timezone,
                    logger=self.logger,
                    connection_provider=self.connection_provider,
                    **kwargs).do
            else:
                raise TypeError(
                    "task_class should be a subclass of minette.Task " +
                    "or callable, not {}".format(task_class.__name__))

        elif callable(task_class):
            return task_class

        else:
            raise TypeError(
                "task_class should be a subclass of minette.Task " +
                "or callable, not the instance of {}".format(
                    task_class.__class__.__name__))

    def _submit(self, task_method, *args, **kwargs):
        """
        Run task_method on a worker thread and log any exception it raises
        """
        future = self.executor.submit(task_method, *args, **kwargs)
        # without this callback the exception stays inside the Future
        # and is never reported
        future.add_done_callback(self._log_task_error)
        return future

    def _log_task_error(self, future):
        """Log the exception of a finished task future, if any"""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            self.logger.error(
                "Error in scheduled task: {}".format(error), exc_info=error)

    def every_seconds(self, task, seconds=1, *args, **kwargs):
        """Run task on a worker thread every `seconds` seconds"""
        self.schedule.every(seconds).seconds.do(
            self._submit, self.create_task(task), *args, **kwargs)

    def every_minutes(self, task, minutes=1, *args, **kwargs):
        """Run task on a worker thread every `minutes` minutes"""
        self.schedule.every(minutes).minutes.do(
            self._submit, self.create_task(task), *args, **kwargs)

    def every_hours(self, task, hours=1, *args, **kwargs):
        """Run task on a worker thread every `hours` hours"""
        self.schedule.every(hours).hours.do(
            self._submit, self.create_task(task), *args, **kwargs)

    def every_days(self, task, days=1, *args, **kwargs):
        """Run task on a worker thread every `days` days"""
        self.schedule.every(days).days.do(
            self._submit, self.create_task(task), *args, **kwargs)

    def start(self):
        """
        Start scheduler (blocks until `stop()` is called)
        """
        self.logger.info("Task scheduler started")
        self._is_running = True
        while self._is_running:
            self.schedule.run_pending()
            time.sleep(1)
        self.logger.info("Task scheduler stopped")

    def stop(self):
        """
        Stop scheduler after the current loop iteration
        """
        self._is_running = False


# ==== /minette/serializer.py ====
import json
from datetime import datetime
import re
from .utils import date_to_str, str_to_date

# strings starting with "YYYY-MM-DDTHH:MM:SS" are candidates for datetime
_DATESTRING_PATTERN = re.compile(
    r"(\d{4})-(\d{2})-(\d{2})T(\d{2})\:(\d{2})\:(\d{2})")


def _is_datestring(s):
    """Return a truthy value if s is a str starting with a datetime"""
    return isinstance(s, str) and _DATESTRING_PATTERN.match(s)


def _encode_datetime(obj):
    """
    `default` hook for json.dumps: encode datetime as str.
    Other unsupported objects fall through and are encoded as null.
    """
    if isinstance(obj, datetime):
        return date_to_str(obj, obj.tzinfo is not None)


def _to_datetime_if_possible(value):
    """
    Convert value to datetime if it is a parseable datetime string,
    otherwise return value unchanged
    """
    if not _is_datestring(value):
        return value
    try:
        return str_to_date(value)
    except ValueError:
        # only the prefix looks like a datetime (e.g. free text that starts
        # with a timestamp); keep it as a plain string
        return value


def _decode_datetime(d):
    """`object_hook` for json.loads: restore datetime values in d"""
    for k, v in d.items():
        if isinstance(v, list):
            for i, item in enumerate(v):
                v[i] = _to_datetime_if_possible(item)
        else:
            d[k] = _to_datetime_if_possible(v)
    return d


def _dump_value(value):
    """Convert one member value to a dict-friendly value"""
    if hasattr(value, "to_dict"):
        return value.to_dict()
    if hasattr(value, "__dict__"):
        return dumpd(value)
    return value


def dumpd(obj):
    """
    Convert object to dict

    Parameters
    ----------
    obj : object
        Object to convert

    Returns
    -------
    d : dict
        Object as dict (list of dicts if obj is list-like)
    """
    # return input directly if it is already dict
    if isinstance(obj, dict):
        return obj
    # return list of dict
    elif isinstance(obj, (list, tuple, set)):
        return [dumpd(o) for o in obj]
    # convert public members to dict
    data = {}
    for key in obj.__dict__.keys():
        if key.startswith("_"):
            continue
        value = getattr(obj, key, None)
        if isinstance(value, (list, tuple, set)):
            data[key] = [_dump_value(v) for v in value]
        elif isinstance(value, dict):
            data[key] = {k: _dump_value(v) for k, v in value.items()}
        else:
            data[key] = _dump_value(value)
    return data


def loadd(d, obj_cls):
    """
    Convert dict to object

    Parameters
    ----------
    d : dict
        Dictionary to convert
    obj_cls : type
        Class of object to convert

    Returns
    -------
    obj : object
        Instance of obj_cls (list of instances if d is list, None if d is None)
    """
    # return None when input is None
    if d is None:
        return None
    # return the list of objects when input is list
    if isinstance(d, list):
        return [loadd(di, obj_cls) for di in d]
    # use `create_object` instead of its constructor
    if hasattr(obj_cls, "create_object"):
        obj = obj_cls.create_object(d)
    else:
        obj = obj_cls()
    # get member's type info
    types = obj_cls._types() if getattr(obj_cls, "_types", None) else {}
    # set values to object
    for k, v in d.items():
        if k in types:
            if hasattr(types[k], "from_dict"):
                setattr(obj, k, types[k].from_dict(v))
            else:
                setattr(obj, k, loadd(v, types[k]))
        else:
            setattr(obj, k, v)
    return obj


def dumps(obj, **kwargs):
    """
    Encode object/dict to JSON

    Parameters
    ----------
    obj : object
        Object to encode

    Returns
    -------
    s : str
        JSON string ("" if obj is None)
    """
    if obj is None:
        return ""
    d = dumpd(obj)
    return json.dumps(d, default=_encode_datetime, **kwargs)


def loads(s, obj_cls=None, **kwargs):
    """
    Decode JSON to dict/object

    Parameters
    ----------
    s : str
        JSON string to decode
    obj_cls : type, default None
        Class of object to convert. If None, convert to dict

    Returns
    -------
    obj : object
        Instance of obj_cls (None if s is None or "")
    """
    if s is None or s == "":
        return None
    d = json.loads(s, object_hook=_decode_datetime, **kwargs)
    if obj_cls is None:
        return d
    else:
        return loadd(d, obj_cls)


class Serializable:
    """
    Base class for serializable object

    """

    @classmethod
    def _types(cls):
        """
        Override this method to create instance of specific class for members.
        Configure like below then instance of `Foo` will be set to `self.foo`
        and `Bar` to `self.bar`
        ```
        return {
            "foo": Foo,
            "bar": Bar
        }
        ```
        """
        return {}

    def __repr__(self):
        return "<{} at {}>\n{}".format(
            self.__class__.__name__,
            hex(id(self)),
            self.to_json(indent=2, ensure_ascii=False))

    @classmethod
    def create_object(obj_cls, d):
        """Create an empty instance to be filled by `loadd`"""
        return obj_cls()

    def to_dict(self):
        """
        Convert this object to dict

        Returns
        -------
        d : dict
            Object as dict
        """
        return dumpd(self)

    def to_json(self, **kwargs):
        """
        Convert this object to JSON

        Returns
        -------
        s : str
            Object as JSON string
        """
        return dumps(self, **kwargs)

    @classmethod
    def from_dict(cls, d):
        """
        Create object from dict

        Parameters
        ----------
        d : dict
            Dictionary of this object

        Returns
        -------
        obj : Serializable
            Instance of this class
        """
        return loadd(d, cls)

    @classmethod
    def from_dict_dict(cls, dict_dict):
        """
        Create dictionary of this objects from dictionaries of dictionaries

        Parameters
        ----------
        dict_dict : dict
            Dictionary of dictionaries

        Returns
        -------
        dict_of_this_obj : dict
            Dictionary of this objects
        """
        return {k: cls.from_dict(v) for k, v in dict_dict.items()}

    @classmethod
    def from_json(cls, s, **kwargs):
        """
        Create this object from JSON string

        Parameters
        ----------
        s : str
            JSON string of this object

        Returns
        -------
        obj : Serializable
            Instance of this class
        """
        return loads(s, cls, **kwargs)
# ==== /minette/tagger/__init__.py ====
from .base import Tagger


# ==== /minette/tagger/base.py ====
""" Base for Taggers """
from logging import getLogger


class Tagger:
    """
    Base class for word taggers

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    max_length : int
        Default max length of the text to parse
    """
    MAX_LENGTH = 1000

    def __init__(self, config=None, timezone=None, logger=None, *,
                 max_length=MAX_LENGTH, **kwargs):
        """
        Parameters
        ----------
        config : Config, default None
            Configuration
        timezone : timezone, default None
            Timezone
        logger : Logger, default None
            Logger
        max_length : int, default 1000
            Max length of the text to parse
        """
        self.config = config
        self.timezone = timezone
        self.logger = logger or getLogger(__name__)
        self.max_length = max_length

    def validate(self, text, max_length=None):
        """
        Check whether the text should be parsed

        Parameters
        ----------
        text : str
            Text to check
        max_length : int, default None
            Max length of the text. If None, `self.max_length` is used.

        Returns
        -------
        is_valid : bool
            False if text is empty or longer than the max length
        """
        if not text:
            return False
        limit = self.max_length if max_length is None else max_length
        return len(text) <= limit

    def parse_as_generator(self, text, max_length=None):
        """
        Analyze and parse text, returns Generator

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse

        Returns
        -------
        words : Generator of minette.WordNode (empty)
            Word nodes
        """
        yield from ()

    def parse(self, text, max_length=None):
        """
        Analyze and parse text

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse

        Returns
        -------
        words : list of minette.WordNode (empty)
            Word nodes
        """
        return list(self.parse_as_generator(text, max_length))


# ==== /minette/tagger/janometagger.py ====
""" tagger using janome """
import traceback
from janome.tokenizer import Tokenizer

from ..models import WordNode
from .base import Tagger


class JanomeNode(WordNode):
    """
    Parsed word node by Janome

    Attributes
    ----------
    surface : str
        Surface of word
    part : str
        Part of the word
    part_detail1 : str
        Detail1 of part
    part_detail2 : str
        Detail2 of part
    part_detail3 : str
        Detail3 of part
    stem_type : str
        Stem type
    stem_form : str
        Stem form
    word : str
        Word itself
    kana : str
        Japanese kana of the word
    pronunciation : str
        Pronunciation of the word
    """

    @classmethod
    def create(cls, surface, features):
        """
        Create instance of JanomeNode

        Parameters
        ----------
        surface : str
            Surface of the word
        features : janome.Token
            Features analyzed by Janome
        """
        ps = features.part_of_speech.split(",")

        def part_at(i):
            # ps[i] exists only when len(ps) > i; "*" means "no value"
            return ps[i] if len(ps) > i and ps[i] != "*" else ""

        def value_of(v):
            return v if v != "*" else ""

        return cls(
            surface=surface,
            part=part_at(0),
            part_detail1=part_at(1),
            part_detail2=part_at(2),
            part_detail3=part_at(3),
            stem_type=value_of(features.infl_type),
            stem_form=value_of(features.infl_form),
            word=value_of(features.base_form),
            kana=value_of(features.reading),
            pronunciation=value_of(features.phonetic)
        )


class JanomeTagger(Tagger):
    """
    Tagger using Janome

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    user_dic : str
        Path to user dictionary, or None
    tokenizer : janome.tokenizer.Tokenizer
        Janome tokenizer
    """

    def __init__(self, config=None, timezone=None, logger=None, *,
                 max_length=Tagger.MAX_LENGTH, user_dic=None, **kwargs):
        """
        Parameters
        ----------
        config : Config, default None
            Configuration
        timezone : timezone, default None
            Timezone
        logger : Logger, default None
            Logger
        max_length : int, default 1000
            Max length of the text to parse
        user_dic : str, default None
            Path to user dictionary (MeCab IPADIC format).
            If None, `janome_userdic` in config is used (when config is given).
        """
        super().__init__(logger=logger, config=config, timezone=timezone, max_length=max_length)
        # explicit user_dic wins even when config is None
        self.user_dic = user_dic or (
            config.get("janome_userdic") if config else None)
        if self.user_dic:
            self.tokenizer = Tokenizer(self.user_dic, udic_enc="utf8")
        else:
            self.tokenizer = Tokenizer()

    def parse_as_generator(self, text, max_length=None):
        """
        Parse and annotate using Janome, returns Generator

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse. If None, `self.max_length` is used.

        Returns
        -------
        words : Generator of minette.tagger.janometagger.JanomeNode
            Janome nodes (nothing is yielded for invalid text)
        """
        if not self.validate(text, max_length):
            return

        try:
            for token in self.tokenizer.tokenize(text):
                yield JanomeNode.create(token.surface, token)
        except Exception as ex:
            self.logger.error(
                "Janome parsing error: "
                + str(ex) + "\n" + traceback.format_exc())
# ==== /minette/tagger/mecabservice.py ====
""" Tagger using mecab-service """
import traceback
import requests

from ..models import WordNode
from .base import Tagger


class MeCabServiceNode(WordNode):
    """
    Parsed word node by MeCabServiceTagger

    Attributes
    ----------
    surface : str
        Surface of word
    part : str
        Part of the word
    part_detail1 : str
        Detail1 of part
    part_detail2 : str
        Detail2 of part
    part_detail3 : str
        Detail3 of part
    stem_type : str
        Stem type
    stem_form : str
        Stem form
    word : str
        Word itself
    kana : str
        Japanese kana of the word
    pronunciation : str
        Pronunciation of the word
    """

    @classmethod
    def create(cls, surface, features):
        """
        Create instance of MeCabServiceNode

        Parameters
        ----------
        surface : str
            Surface of the word
        features : dict
            Features analyzed by MeCabService
        """
        return cls(
            surface=surface,
            part=features["part"],
            part_detail1=features["part_detail1"],
            part_detail2=features["part_detail2"],
            part_detail3=features["part_detail3"],
            stem_type=features["stem_type"],
            stem_form=features["stem_form"],
            word=features["word"],
            kana=features["kana"],
            pronunciation=features["pronunciation"]
        )


class MeCabServiceTagger(Tagger):
    """
    Tagger using mecab-service

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    api_url : str
        URL for MeCabService API
    """

    def __init__(self, config=None, timezone=None, logger=None, *,
                 max_length=Tagger.MAX_LENGTH, api_url=None, **kwargs):
        """
        Parameters
        ----------
        config : Config, default None
            Configuration
        timezone : timezone, default None
            Timezone
        logger : Logger, default None
            Logger
        max_length : int, default 1000
            Max length of the text to parse
        api_url : str, default None
            URL for MeCabService API.
            If None trial URL is used.
        """
        super().__init__(config=config, timezone=timezone, logger=logger,
                         max_length=max_length)
        if not api_url:
            self.api_url = "https://api.uezo.net/mecab/parse"
            self.logger.warning(
                "Do not use default API URL for the production environment. "
                "This is for trial use only. "
                "Install MeCab and use MeCabTagger instead.")
        else:
            self.api_url = api_url

    def parse_as_generator(self, text, max_length=None):
        """
        Parse and annotate using MeCab Service, returns Generator

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse. If None, `self.max_length` is used.

        Returns
        -------
        words : Generator of minette.MeCabServiceNode
            MeCabService nodes (nothing is yielded on invalid text or error)
        """
        if not self.validate(text, max_length):
            return
        try:
            parsed_json = requests.post(
                self.api_url, headers={"content-type": "application/json"},
                json={"text": text}, timeout=10).json()
            nodes = [MeCabServiceNode.create(
                n["surface"], n["features"]) for n in parsed_json["nodes"]]
        except Exception as ex:
            self.logger.error(
                "MeCab Service parsing error: "
                + str(ex) + "\n" + traceback.format_exc())
            return
        # yield only after the whole response is parsed (all-or-nothing)
        yield from nodes

    def parse(self, text, max_length=None):
        """
        Parse and annotate using MeCab Service

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse. If None, `self.max_length` is used.

        Returns
        -------
        words : list of minette.MeCabServiceNode
            MeCabService nodes
        """
        return list(self.parse_as_generator(text, max_length))


# ==== /minette/tagger/mecabtagger.py ====
""" tagger using mecab """
import traceback
import MeCab

from ..models import WordNode
from .base import Tagger


class MeCabNode(WordNode):
    """
    Parsed word node by MeCab

    Attributes
    ----------
    surface : str
        Surface of word
    part : str
        Part of the word
    part_detail1 : str
        Detail1 of part
    part_detail2 : str
        Detail2 of part
    part_detail3 : str
        Detail3 of part
    stem_type : str
        Stem type
    stem_form : str
        Stem form
    word : str
        Word itself
    kana : str
        Japanese kana of the word
    pronunciation : str
        Pronunciation of the word
    """

    @classmethod
    def create(cls, surface, features):
        """
        Create instance of MeCabNode

        Parameters
        ----------
        surface : str
            Surface of the word
        features : list
            Features analyzed by MeCab. Missing fields and "*" become "".
        """
        def feature_at(i):
            # unknown words may carry fewer fields than known words
            return features[i] if len(features) > i and features[i] != "*" else ""

        return cls(
            surface=surface,
            part=feature_at(0),
            part_detail1=feature_at(1),
            part_detail2=feature_at(2),
            part_detail3=feature_at(3),
            stem_type=feature_at(4),
            stem_form=feature_at(5),
            word=feature_at(6),
            kana=feature_at(7),
            pronunciation=feature_at(8)
        )


class MeCabTagger(Tagger):
    """
    Tagger using MeCab

    Attributes
    ----------
    config : minette.Config
        Configuration
    timezone : pytz.timezone
        Timezone
    logger : logging.Logger
        Logger
    """

    def parse_as_generator(self, text, max_length=None):
        """
        Analyze and parse text using MeCab, returns Generator

        Parameters
        ----------
        text : str
            Text to analyze
        max_length : int, default None
            Max length of the text to parse. If None, `self.max_length` is used.

        Returns
        -------
        words : Generator of minette.tagger.mecabtagger.MeCabNode
            MeCab word nodes
        """
        if not self.validate(text, max_length):
            return

        try:
            m = MeCab.Tagger("-Ochasen")
            # m.parse("") before m.parseToNode(text) against the bug that node.surface is not set
            m.parse("")
            node = m.parseToNode(text)
            while node:
                features = node.feature.split(",")
                if features[0] != "BOS/EOS":
                    yield MeCabNode.create(node.surface, features)
                node = node.next
        except Exception as ex:
            self.logger.error(
                "MeCab parsing error: "
                + str(ex) + "\n" + traceback.format_exc())


# ==== /minette/testing/__init__.py ====
from .helper import MinetteForTest


# ==== /minette/testing/helper.py ====
from time import time

from ..core import Minette
from ..models import Message


class MinetteForTest(Minette):
    """
    Minette with helpers for test cases: each instance gets a unique case ID
    that is used as the default channel_user_id
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.default_channel = kwargs.get("default_channel", "")
        self.case_id = str(int(time() * 10000000))

    def chat(self, request, **kwargs):
        """
        Chat and return the response with `text` set to the first message text
        """
        self.logger.info("start testcase: " + self.case_id)
        # convert to Message
        if isinstance(request, str):
            request = Message(text=request, **kwargs)
        # set channel and channel_user_id for this test case
        if request.channel == "console":
            request.channel = self.default_channel or request.channel
        if request.channel_user_id == "anonymous":
            request.channel_user_id = "user" + self.case_id
        # chat and return response
        response = super().chat(request)
        if response.messages:
            response.text = response.messages[0].text
        else:
            response.text = ""
        self.logger.info("end testcase: " + self.case_id)
        return response


# ==== /minette/utils.py ====
""" Utilities functions for minette """
from datetime import datetime
import calendar


def date_to_str(dt, with_timezone=False):
    """
    Convert datetime to str

    Parameters
    ----------
    dt : datetime
        datetime to convert
    with_timezone : bool, default False
        Include timezone or not

    Returns
    -------
    datetime_str : str
        Datetime string
    """
    # utcoffset() can be None even when tzinfo is set; "%z" would then be ""
    # and the slicing below would corrupt the microseconds
    if with_timezone and dt.utcoffset() is not None:
        dtstr = dt.strftime("%Y-%m-%dT%H:%M:%S.%f%z")
        return dtstr[:-2] + ":" + dtstr[-2:]
    else:
        return dt.strftime("%Y-%m-%dT%H:%M:%S.%f")


def str_to_date(dtstr):
    """
    Convert str to datetime

    Parameters
    ----------
    dtstr : str
        str to convert

    Returns
    -------
    datetime : datetime
        datetime

    Raises
    ------
    ValueError
        If dtstr is not in a supported format
    """
    if len(dtstr) > 19 and dtstr[-3:-2] == ":":
        dtstr = dtstr[:-3] + dtstr[-2:]
    fmt = "%Y-%m-%dT%H:%M:%S"
    if "." in dtstr:
        fmt += ".%f"
    if dtstr[-5] == "+" or dtstr[-5] == "-":
        fmt += "%z"
    return datetime.strptime(dtstr, fmt)


def date_to_unixtime(dt):
    """
    Convert datetime to unixtime

    Parameters
    ----------
    dt : datetime
        datetime to convert

    Returns
    -------
    unixtime : int
        Unixtime
    """
    return calendar.timegm(dt.utctimetuple())


def unixtime_to_date(unixtime, tz=None):
    """
    Convert unixtime to datetime

    Parameters
    ----------
    unixtime : int
        unixtime to convert
    tz : timezone
        timezone to set

    Returns
    -------
    datetime : datetime
        datetime
    """
    return datetime.fromtimestamp(unixtime, tz=tz)


# ==== /minette/version.py ====
__version__ = "0.4.3"
-------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytz==2020.1 2 | schedule==0.6.0 3 | pytest==6.0.1 4 | Janome==0.4.0 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("./minette/version.py") as f: 4 | exec(f.read()) 5 | 6 | setup( 7 | name="minette", 8 | version=__version__, 9 | url="https://github.com/uezo/minette-python", 10 | author="uezo", 11 | author_email="uezo@uezo.net", 12 | maintainer="uezo", 13 | maintainer_email="uezo@uezo.net", 14 | description="Minette is a minimal and extensible chatbot framework. It is extremely easy to create chatbot and also enables you to make your chatbot more sophisticated and multi-skills, with preventing to be spaghetti code.", 15 | long_description=open("README.md").read(), 16 | long_description_content_type="text/markdown", 17 | packages=find_packages(exclude=["examples*", "develop*", "tests*"]), 18 | install_requires=["pytz", "schedule"], 19 | license="Apache v2", 20 | classifiers=[ 21 | "Programming Language :: Python :: 3" 22 | ] 23 | ) 24 | -------------------------------------------------------------------------------- /tests/adapter/test_adapter_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | from minette import Adapter, Message, DialogService 6 | 7 | 8 | class ChannelEvent: 9 | def __init__(self, text): 10 | self.text = text 11 | 12 | 13 | class CustomAdapter(Adapter): 14 | @staticmethod 15 | def _to_minette_message(event): 16 | return Message(text=event.text) 17 | 18 | @staticmethod 19 | def _to_channel_message(message): 20 | return 
message.text 21 | 22 | 23 | class MyDialog(DialogService): 24 | def compose_response(self, request, context, connection): 25 | return "res:" + request.text 26 | 27 | 28 | def test_init(): 29 | adapter = CustomAdapter( 30 | timezone=timezone("Asia/Tokyo"), prepare_table=True) 31 | assert adapter.timezone == timezone("Asia/Tokyo") 32 | assert adapter.bot.timezone == timezone("Asia/Tokyo") 33 | assert isinstance(adapter.executor, ThreadPoolExecutor) 34 | # run in main thread 35 | adapter = CustomAdapter( 36 | timezone=timezone("Asia/Tokyo"), prepare_table=True, threads=0) 37 | assert adapter.executor is None 38 | 39 | 40 | def test_extract_token(): 41 | adapter = CustomAdapter(prepare_table=True) 42 | token = adapter._extract_token(ChannelEvent("hello")) 43 | assert token == "" 44 | 45 | 46 | def test_handle_event(): 47 | adapter = CustomAdapter( 48 | default_dialog_service=MyDialog, debug=True, prepare_table=True) 49 | channel_messages, token = adapter.handle_event(ChannelEvent("hello")) 50 | assert channel_messages[0] == "res:hello" 51 | assert token == "" 52 | -------------------------------------------------------------------------------- /tests/adapter/test_clovaadapter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | try: 4 | from cek import ( 5 | Clova, 6 | URL 7 | ) 8 | from minette.adapter.clovaadapter import ClovaAdapter 9 | import request_samples as rs 10 | except Exception: 11 | # Skip if import dependencies not found 12 | pytestmark = pytest.mark.skip 13 | 14 | from minette import ( 15 | DialogService, 16 | Message, 17 | Config 18 | ) 19 | from minette.serializer import loads 20 | 21 | clovaconfig = Config("config/test_config_adapter.ini") 22 | application_id = clovaconfig.get("application_id", section="clova_cek") 23 | default_language = clovaconfig.get("default_language", section="clova_cek") 24 | 25 | # Skip if application_id is not provided 26 | if not application_id: 27 | pytestmark = 
pytest.mark.skip 28 | 29 | 30 | class MyDialog(DialogService): 31 | def compose_response(self, request, context, connection): 32 | if request.intent == "TurnOff": 33 | return ["Handled {}".format(request.type), "Finish turning off"] 34 | else: 35 | return "Handled {}".format(request.type) 36 | 37 | 38 | def test_init(): 39 | adapter = ClovaAdapter( 40 | application_id=application_id, 41 | default_language=default_language, prepare_table=True) 42 | assert adapter.application_id is not None 43 | assert adapter.application_id == clovaconfig.get("application_id", section="clova_cek") 44 | assert adapter.default_language is not None 45 | assert adapter.default_language == clovaconfig.get("default_language", section="clova_cek") 46 | assert isinstance(adapter.clova, Clova) 47 | 48 | 49 | def test_to_channel_message(): 50 | # text messages 51 | message = ClovaAdapter._to_channel_message( 52 | Message(text="hello", entities={"end_session": True, "reprompt": None})) 53 | assert message["speech_value"] == "hello" 54 | assert message["end_session"] is True 55 | assert message["reprompt"] is None 56 | 57 | # url messages 58 | message = ClovaAdapter._to_channel_message( 59 | Message(text="http://uezo.net", type="url", 60 | entities={"end_session": True, "reprompt": None})) 61 | assert message["speech_value"].value == URL("http://uezo.net").value 62 | assert message["end_session"] is True 63 | assert message["reprompt"] is None 64 | 65 | # end_session and reprompt 66 | message = ClovaAdapter._to_channel_message( 67 | Message(text="hello", entities={"end_session": False, "reprompt": "are you okay?"})) 68 | assert message["speech_value"] == "hello" 69 | assert message["end_session"] is False 70 | assert message["reprompt"] == "are you okay?" 
71 | 72 | 73 | def test_handle_intent_request(): 74 | adapter = ClovaAdapter( 75 | application_id="com.line.myApplication", 76 | default_dialog_service=MyDialog, 77 | debug=True, 78 | prepare_table=True) 79 | request_headers = { 80 | "Signaturecek": rs.REQUEST_SIGNATURE, 81 | "Content-Type": "application/json;charset=UTF-8", 82 | "Content-Length": 578, 83 | "Host": "host.name.local", 84 | "Accept": "*/*", 85 | } 86 | 87 | # launch request 88 | response = loads(adapter.handle_http_request(rs.LAUNCH_REQUEST_BODY, request_headers)) 89 | assert response["response"]["outputSpeech"]["values"]["value"] == "Handled LaunchRequest" 90 | 91 | # intent request 92 | response = loads(adapter.handle_http_request(rs.INTENT_REQUEST_BODY, request_headers)) 93 | assert response["response"]["outputSpeech"]["values"]["value"] == "Handled IntentRequest" 94 | 95 | # intent request (multiple response message) 96 | response = loads(adapter.handle_http_request(rs.INTENT_REQUEST_TURN_OFF, request_headers)) 97 | assert response["response"]["outputSpeech"]["values"][0]["value"] == "Handled IntentRequest" 98 | assert response["response"]["outputSpeech"]["values"][1]["value"] == "Finish turning off" 99 | 100 | # end request 101 | response = loads(adapter.handle_http_request(rs.END_REQUEST_BODY, request_headers)) 102 | assert response["response"]["outputSpeech"]["values"]["value"] == "Handled SessionEndedRequest" 103 | 104 | # event request 105 | response = loads(adapter.handle_http_request(rs.EVENT_REQUEST_BODY, request_headers)) 106 | assert response["response"]["outputSpeech"]["values"]["value"] == "Handled EventRequest" 107 | 108 | # EventFromSkillStore request 109 | response = loads(adapter.handle_http_request(rs.EVENT_REQUEST_BODY_FROM_SKILL_STORE, request_headers)) 110 | assert response["response"]["outputSpeech"]["values"]["value"] == "Handled EventRequest" 111 | -------------------------------------------------------------------------------- /tests/config/test_config.ini: 
-------------------------------------------------------------------------------- 1 | [minette] 2 | key1 = value1 3 | key2 = ENV::VALUE2 4 | 5 | timezone = Asia/Tokyo 6 | log_file = configured_log_file.log 7 | logger_name = configured_logger_name 8 | connection_str = test_core.db 9 | context_timeout = 100 10 | context_table = test_core_context 11 | user_table = test_core_user 12 | messagelog_table = test_core_messagelog 13 | -------------------------------------------------------------------------------- /tests/config/test_config_empty.ini: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uezo/minette-python/dd8cd7d244b6e6e4133c8e73d637ded8a8c6846f/tests/config/test_config_empty.ini -------------------------------------------------------------------------------- /tests/datastore/test_connectionprovider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlite3 3 | 4 | # SQLite 5 | from minette import ( 6 | SQLiteConnectionProvider, 7 | Config 8 | ) 9 | 10 | # SQLDatabase 11 | SQLDBConnection = None 12 | SQLDBConnectionProvider = None 13 | try: 14 | import pyodbc 15 | SQLDBConnection = pyodbc.Connection 16 | from minette.datastore.sqldbstores import SQLDBConnectionProvider 17 | except Exception: 18 | pass 19 | 20 | # AzureTableStorage 21 | AzureTableConnection = None 22 | AzureTableConnectionProvider = None 23 | try: 24 | from minette.datastore.azurestoragestores import ( 25 | AzureTableConnection, 26 | AzureTableConnectionProvider 27 | ) 28 | except Exception: 29 | pass 30 | 31 | # MySQL 32 | MySQLConnection = None 33 | MySQLConnectionProvider = None 34 | try: 35 | from minette.datastore.mysqlstores import ( 36 | MySQLConnection, 37 | MySQLConnectionProvider 38 | ) 39 | except Exception: 40 | pass 41 | 42 | # SQLAlchemy 43 | SQLAlchemyConnection = None 44 | SQLAlchemyConnectionProvider = None 45 | try: 46 | from 
minette.datastore.sqlalchemystores import ( 47 | SQLAlchemyConnection, 48 | SQLAlchemyConnectionProvider 49 | ) 50 | except Exception: 51 | pass 52 | 53 | dbconfig = Config("config/test_config_datastores.ini") 54 | 55 | datastore_params = [ 56 | (sqlite3.Connection, SQLiteConnectionProvider, "test.db"), 57 | (SQLDBConnection, SQLDBConnectionProvider, dbconfig.get("sqldb_connection_str")), 58 | (AzureTableConnection, AzureTableConnectionProvider, dbconfig.get("table_connection_str")), 59 | (MySQLConnection, MySQLConnectionProvider, dbconfig.get("mysql_connection_str")), 60 | (SQLAlchemyConnection, SQLAlchemyConnectionProvider, dbconfig.get("sqlalchemy_sqlite_connection_str")), 61 | (SQLAlchemyConnection, SQLAlchemyConnectionProvider, dbconfig.get("sqlalchemy_sqldb_connection_str")), 62 | (SQLAlchemyConnection, SQLAlchemyConnectionProvider, dbconfig.get("sqlalchemy_mysql_connection_str")), 63 | ] 64 | 65 | 66 | @pytest.mark.parametrize("connection_class, connection_provider_class, connection_str", datastore_params) 67 | def test_get_connection(connection_class, connection_provider_class, connection_str): 68 | if not connection_class or not connection_provider_class: 69 | pytest.skip("Dependencies are not found") 70 | if not connection_str: 71 | pytest.skip( 72 | "Connection string for {} is not provided" 73 | .format(connection_provider_class.__name__)) 74 | 75 | cp = connection_provider_class(connection_str) 76 | with cp.get_connection() as connection: 77 | connection = cp.get_connection() 78 | assert isinstance(connection, connection_class) 79 | -------------------------------------------------------------------------------- /tests/datastore/test_contextstore.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datetime import datetime 3 | from pytz import timezone 4 | from time import sleep 5 | 6 | from minette import ( 7 | SQLiteStores, 8 | Context, 9 | Config 10 | ) 11 | 12 | # SQLDatabase 13 | 
SQLDBStores = None 14 | try: 15 | from minette.datastore.sqldbstores import SQLDBStores 16 | except Exception: 17 | pass 18 | 19 | # AzureTableStorage 20 | AzureTableStores = None 21 | try: 22 | from minette.datastore.azurestoragestores import AzureTableStores 23 | except Exception: 24 | pass 25 | 26 | # MySQL 27 | MySQLStores = None 28 | try: 29 | from minette.datastore.mysqlstores import MySQLStores 30 | except Exception: 31 | pass 32 | 33 | # SQLAlchemy 34 | SQLAlchemyStores = None 35 | SQLAlchemyConnectionProvider = None 36 | try: 37 | from minette.datastore.sqlalchemystores import ( 38 | SQLAlchemyStores, 39 | SQLAlchemyConnectionProvider 40 | ) 41 | except Exception: 42 | pass 43 | 44 | from minette.utils import date_to_unixtime 45 | 46 | now = datetime.now(tz=timezone("Asia/Tokyo")) 47 | table_name = "context" + str(date_to_unixtime(now)) 48 | user_id = "user_id" + str(date_to_unixtime(now)) 49 | print("context_table: {}".format(table_name)) 50 | print("user_id: {}".format(user_id)) 51 | 52 | dbconfig = Config("config/test_config_datastores.ini") 53 | 54 | datastore_params = [ 55 | ( 56 | SQLiteStores, 57 | "test.db", 58 | ), 59 | ( 60 | SQLDBStores, 61 | dbconfig.get("sqldb_connection_str"), 62 | ), 63 | ( 64 | AzureTableStores, 65 | dbconfig.get("table_connection_str"), 66 | ), 67 | ( 68 | MySQLStores, 69 | dbconfig.get("mysql_connection_str"), 70 | ), 71 | ( 72 | SQLAlchemyStores, 73 | dbconfig.get("sqlalchemy_sqlite_connection_str"), 74 | ), 75 | ( 76 | SQLAlchemyStores, 77 | dbconfig.get("sqlalchemy_sqldb_connection_str"), 78 | ), 79 | ( 80 | SQLAlchemyStores, 81 | dbconfig.get("sqlalchemy_mysql_connection_str"), 82 | ), 83 | ] 84 | 85 | 86 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 87 | def test_prepare(datastore_class, connection_str): 88 | if not datastore_class: 89 | pytest.skip("Unable to import DataStoreSet") 90 | if not connection_str: 91 | pytest.skip( 92 | "Connection string for {} is not provided" 93 | 
.format(datastore_class.connection_provider.__name__)) 94 | 95 | cs = datastore_class.context_store( 96 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 97 | cp = datastore_class.connection_provider(connection_str) 98 | with cp.get_connection() as connection: 99 | prepare_params = cp.get_prepare_params() 100 | if SQLAlchemyConnectionProvider and isinstance(cp, SQLAlchemyConnectionProvider): 101 | assert cs.prepare_table(connection, prepare_params) is False 102 | else: 103 | assert cs.prepare_table(connection, prepare_params) is True 104 | assert cs.prepare_table(connection, prepare_params) is False 105 | 106 | 107 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 108 | def test_get(datastore_class, connection_str): 109 | if not datastore_class: 110 | pytest.skip("Unable to import DataStoreSet") 111 | if not connection_str: 112 | pytest.skip( 113 | "Connection string for {} is not provided" 114 | .format(datastore_class.connection_provider.__name__)) 115 | 116 | cs = datastore_class.context_store( 117 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 118 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 119 | ctx = cs.get("TEST", user_id, connection) 120 | assert ctx.channel == "TEST" 121 | assert ctx.channel_user_id == user_id 122 | assert ctx.is_new is True 123 | assert ctx.data == {} 124 | 125 | # get without user_id 126 | ctx_without_user = cs.get("TEST", None, connection) 127 | assert ctx_without_user.channel == "TEST" 128 | assert ctx_without_user.channel_user_id == "" 129 | 130 | 131 | @pytest.mark.parametrize("datastore_class,connection_str", datastore_params) 132 | def test_get_error(datastore_class, connection_str): 133 | if not datastore_class: 134 | pytest.skip("Unable to import DataStoreSet") 135 | if not connection_str: 136 | pytest.skip( 137 | "Connection string for {} is not provided" 138 | .format(datastore_class.connection_provider.__name__)) 139 | 140 | cs = 
datastore_class.context_store( 141 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 142 | cs.sqls["get_context"] = "" 143 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 144 | ctx = cs.get("TEST", user_id, connection) 145 | assert ctx.channel == "TEST" 146 | assert ctx.channel_user_id == user_id 147 | assert ctx.is_new is True 148 | assert ctx.data == {} 149 | 150 | 151 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 152 | def test_save(datastore_class, connection_str): 153 | if not datastore_class: 154 | pytest.skip("Unable to import DataStoreSet") 155 | if not connection_str: 156 | pytest.skip( 157 | "Connection string for {} is not provided" 158 | .format(datastore_class.connection_provider.__name__)) 159 | 160 | cs = datastore_class.context_store( 161 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 162 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 163 | ctx = cs.get("TEST", user_id, connection) 164 | ctx.data["strvalue"] = "value1" 165 | ctx.data["intvalue"] = 2 166 | ctx.data["dtvalue"] = now 167 | ctx.data["dictvalue"] = { 168 | "k1": "v1", 169 | "k2": 2, 170 | } 171 | cs.save(ctx, connection) 172 | ctx = cs.get("TEST", user_id, connection) 173 | assert ctx.channel == "TEST" 174 | assert ctx.channel_user_id == user_id 175 | assert ctx.is_new is False 176 | assert ctx.data == { 177 | "strvalue": "value1", 178 | "intvalue": 2, 179 | "dtvalue": now, 180 | "dictvalue": { 181 | "k1": "v1", 182 | "k2": 2, 183 | } 184 | } 185 | 186 | # save (not saved) 187 | cs.save(Context(), connection) 188 | 189 | # timeout 190 | cs_timeout = datastore_class.context_store( 191 | table_name=table_name, timezone=timezone("Asia/Tokyo"), timeout=3) 192 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 193 | ctx = cs_timeout.get("TEST", user_id + "_to", connection) 194 | assert ctx.is_new is True 195 | 
ctx.data["strvalue"] = "value1" 196 | cs.save(ctx, connection) 197 | sleep(1) # shorter than timeout 198 | ctx = cs_timeout.get("TEST", user_id + "_to", connection) 199 | assert ctx.is_new is False 200 | assert ctx.data["strvalue"] == "value1" 201 | sleep(5) # longer than timeout 202 | ctx = cs_timeout.get("TEST", user_id + "_to", connection) 203 | assert ctx.is_new is True 204 | assert ctx.data == {} 205 | -------------------------------------------------------------------------------- /tests/datastore/test_messagelogstore.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datetime import datetime 3 | from pytz import timezone 4 | 5 | from minette import ( 6 | SQLiteStores, 7 | Message, 8 | Response, 9 | Context, 10 | Config 11 | ) 12 | from minette.serializer import dumpd 13 | 14 | SQLDBStores = None 15 | try: 16 | from minette.datastore.sqldbstores import SQLDBStores 17 | except Exception: 18 | pass 19 | 20 | AzureTableStores = None 21 | AzureTableConnection = None 22 | try: 23 | from minette.datastore.azurestoragestores import ( 24 | AzureTableStores, 25 | AzureTableConnection, 26 | ) 27 | except Exception: 28 | pass 29 | 30 | MySQLStores = None 31 | MySQLConnection = None 32 | try: 33 | from minette.datastore.mysqlstores import ( 34 | MySQLStores, 35 | MySQLConnection, 36 | ) 37 | except Exception: 38 | pass 39 | 40 | SQLAlchemyStores = None 41 | SQLAlchemyConnection = None 42 | SQLAlchemyMessageLog = None 43 | SQLAlchemyMessageLogStore = None 44 | try: 45 | from minette.datastore.sqlalchemystores import ( 46 | SQLAlchemyStores, 47 | SQLAlchemyConnection, 48 | SQLAlchemyMessageLog, 49 | SQLAlchemyMessageLogStore 50 | ) 51 | except Exception: 52 | pass 53 | 54 | from minette.utils import date_to_unixtime, date_to_str 55 | 56 | now = datetime.now() 57 | table_name = "messagelog" + str(date_to_unixtime(now)) 58 | user_id = "user_id" + str(date_to_unixtime(now)) 59 | print("messagelog_table: 
{}".format(table_name)) 60 | print("user_id: {}".format(user_id)) 61 | 62 | dbconfig = Config("config/test_config_datastores.ini") 63 | 64 | datastore_params = [ 65 | ( 66 | SQLiteStores, 67 | "test.db", 68 | ), 69 | ( 70 | SQLDBStores, 71 | dbconfig.get("sqldb_connection_str"), 72 | ), 73 | ( 74 | AzureTableStores, 75 | dbconfig.get("table_connection_str"), 76 | ), 77 | ( 78 | MySQLStores, 79 | dbconfig.get("mysql_connection_str"), 80 | ), 81 | ( 82 | SQLAlchemyStores, 83 | dbconfig.get("sqlalchemy_sqlite_connection_str"), 84 | ), 85 | ( 86 | SQLAlchemyStores, 87 | dbconfig.get("sqlalchemy_sqldb_connection_str"), 88 | ), 89 | ( 90 | SQLAlchemyStores, 91 | dbconfig.get("sqlalchemy_mysql_connection_str"), 92 | ), 93 | ] 94 | 95 | 96 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 97 | def test_prepare(datastore_class, connection_str): 98 | if not datastore_class: 99 | pytest.skip("Unable to import DataStoreSet") 100 | if not connection_str: 101 | pytest.skip( 102 | "Connection string for {} is not provided" 103 | .format(datastore_class.connection_provider.__name__)) 104 | 105 | ms = datastore_class.messagelog_store( 106 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 107 | cp = datastore_class.connection_provider(connection_str) 108 | with cp.get_connection() as connection: 109 | prepare_params = cp.get_prepare_params() 110 | if SQLAlchemyMessageLogStore and isinstance(ms, SQLAlchemyMessageLogStore): 111 | assert ms.prepare_table(connection, prepare_params) is False 112 | else: 113 | assert ms.prepare_table(connection, prepare_params) is True 114 | assert ms.prepare_table(connection, prepare_params) is False 115 | 116 | 117 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 118 | def test_save(datastore_class, connection_str): 119 | if not datastore_class: 120 | pytest.skip("Unable to import DataStoreSet") 121 | if not connection_str: 122 | pytest.skip( 123 | "Connection string for {} is not 
provided" 124 | .format(datastore_class.connection_provider.__name__)) 125 | 126 | ms = datastore_class.messagelog_store( 127 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 128 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 129 | # request 130 | request = Message( 131 | id=str(date_to_unixtime(now)), 132 | channel="TEST", channel_user_id=user_id, 133 | text="request message {}".format(str(date_to_unixtime(now)))) 134 | # response 135 | response = Response(messages=[Message(channel="TEST", channel_user_id=user_id, text="response message {}".format(str(date_to_unixtime(now))))]) 136 | # context 137 | context = Context("TEST", user_id) 138 | context.data = { 139 | "strvalue": "value1", 140 | "intvalue": 2, 141 | "dtvalue": date_to_str(now), 142 | "dictvalue": { 143 | "k1": "v1", 144 | "k2": 2, 145 | } 146 | } 147 | 148 | # save 149 | save_res = ms.save(request, response, context, connection) 150 | 151 | # check 152 | if AzureTableConnection and isinstance(connection, AzureTableConnection): 153 | record = connection.get_entity(table_name, user_id, save_res) 154 | elif SQLAlchemyConnection and isinstance(connection, SQLAlchemyConnection): 155 | record = connection.query(SQLAlchemyMessageLog).filter( 156 | SQLAlchemyMessageLog.request_id == str(date_to_unixtime(now)) 157 | ).first() 158 | record = dumpd(record) 159 | else: 160 | cursor = connection.cursor() 161 | if MySQLConnection and isinstance(connection, MySQLConnection): 162 | sql = "select * from {} where request_id = %s" 163 | else: 164 | sql = "select * from {} where request_id = ?" 
165 | cursor.execute(sql.format(table_name), (str(date_to_unixtime(now)), )) 166 | row = cursor.fetchone() 167 | if isinstance(row, dict): 168 | record = row 169 | else: 170 | record = dict(zip([column[0] for column in cursor.description], row)) 171 | 172 | assert record["request_text"] == "request message {}".format(str(date_to_unixtime(now))) 173 | assert record["response_text"] == "response message {}".format(str(date_to_unixtime(now))) 174 | -------------------------------------------------------------------------------- /tests/datastore/test_userstore.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datetime import datetime 3 | from pytz import timezone 4 | 5 | from minette import ( 6 | SQLiteStores, 7 | Config 8 | ) 9 | 10 | SQLDBStores = None 11 | try: 12 | from minette.datastore.sqldbstores import SQLDBStores 13 | except Exception: 14 | pass 15 | 16 | AzureTableStores = None 17 | try: 18 | from minette.datastore.azurestoragestores import AzureTableStores 19 | except Exception: 20 | pass 21 | 22 | MySQLStores = None 23 | try: 24 | from minette.datastore.mysqlstores import MySQLStores 25 | except Exception: 26 | pass 27 | 28 | SQLAlchemyStores = None 29 | SQLAlchemyUserStore = None 30 | try: 31 | from minette.datastore.sqlalchemystores import ( 32 | SQLAlchemyStores, 33 | SQLAlchemyUserStore 34 | ) 35 | except Exception: 36 | pass 37 | 38 | from minette.utils import date_to_unixtime 39 | 40 | now = datetime.now() 41 | table_name = "user" + str(date_to_unixtime(now)) 42 | user_id = "user_id" + str(date_to_unixtime(now)) 43 | print("user_table: {}".format(table_name)) 44 | print("user_id: {}".format(user_id)) 45 | 46 | dbconfig = Config("config/test_config_datastores.ini") 47 | 48 | datastore_params = [ 49 | ( 50 | SQLiteStores, 51 | "test.db", 52 | ), 53 | ( 54 | SQLDBStores, 55 | dbconfig.get("sqldb_connection_str"), 56 | ), 57 | ( 58 | AzureTableStores, 59 | dbconfig.get("table_connection_str"), 
60 | ), 61 | ( 62 | MySQLStores, 63 | dbconfig.get("mysql_connection_str"), 64 | ), 65 | ( 66 | SQLAlchemyStores, 67 | dbconfig.get("sqlalchemy_sqlite_connection_str"), 68 | ), 69 | ( 70 | SQLAlchemyStores, 71 | dbconfig.get("sqlalchemy_sqldb_connection_str"), 72 | ), 73 | ( 74 | SQLAlchemyStores, 75 | dbconfig.get("sqlalchemy_mysql_connection_str"), 76 | ), 77 | ] 78 | 79 | 80 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 81 | def test_prepare(datastore_class, connection_str): 82 | if not datastore_class: 83 | pytest.skip("Unable to import DataStoreSet") 84 | if not connection_str: 85 | pytest.skip( 86 | "Connection string for {} is not provided" 87 | .format(datastore_class.connection_provider.__name__)) 88 | 89 | us = datastore_class.user_store( 90 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 91 | cp = datastore_class.connection_provider(connection_str) 92 | with cp.get_connection() as connection: 93 | prepare_params = cp.get_prepare_params() 94 | if SQLAlchemyUserStore and isinstance(us, SQLAlchemyUserStore): 95 | assert us.prepare_table(connection, prepare_params) is False 96 | else: 97 | assert us.prepare_table(connection, prepare_params) is True 98 | assert us.prepare_table(connection, prepare_params) is False 99 | 100 | 101 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 102 | def test_get(datastore_class, connection_str): 103 | if not datastore_class: 104 | pytest.skip("Unable to import DataStoreSet") 105 | if not connection_str: 106 | pytest.skip( 107 | "Connection string for {} is not provided" 108 | .format(datastore_class.connection_provider.__name__)) 109 | 110 | us = datastore_class.user_store( 111 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 112 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 113 | user = us.get("TEST", user_id, connection) 114 | assert user.channel == "TEST" 115 | assert user.channel_user_id == 
user_id 116 | assert user.data == {} 117 | # get(without user_id) 118 | user_without_user_id = us.get("TEST", None, connection) 119 | assert user_without_user_id.channel == "TEST" 120 | assert user_without_user_id.channel_user_id == "" 121 | 122 | 123 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 124 | def test_get_error(datastore_class, connection_str): 125 | if not datastore_class: 126 | pytest.skip("Unable to import DataStoreSet") 127 | if not connection_str: 128 | pytest.skip( 129 | "Connection string for {} is not provided" 130 | .format(datastore_class.connection_provider.__name__)) 131 | 132 | us = datastore_class.user_store( 133 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 134 | us.sqls["get_user"] = "" 135 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 136 | user = us.get("TEST", user_id, connection) 137 | assert user.channel == "TEST" 138 | assert user.channel_user_id == user_id 139 | assert user.data == {} 140 | # table doesn't exist 141 | us.table_name = "notexisttable" 142 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 143 | user = us.get("TEST", user_id, connection) 144 | assert user.channel == "TEST" 145 | assert user.channel_user_id == user_id 146 | assert user.data == {} 147 | # invalid table name 148 | us.table_name = "_#_#_" 149 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 150 | user = us.get("TEST", user_id, connection) 151 | assert user.channel == "TEST" 152 | assert user.channel_user_id == user_id 153 | assert user.data == {} 154 | 155 | 156 | @pytest.mark.parametrize("datastore_class, connection_str", datastore_params) 157 | def test_save(datastore_class, connection_str): 158 | if not datastore_class: 159 | pytest.skip("Unable to import DataStoreSet") 160 | if not connection_str: 161 | pytest.skip( 162 | "Connection string for {} is not provided" 163 | 
.format(datastore_class.connection_provider.__name__)) 164 | 165 | us = datastore_class.user_store( 166 | table_name=table_name, timezone=timezone("Asia/Tokyo")) 167 | with datastore_class.connection_provider(connection_str).get_connection() as connection: 168 | # save 169 | user = us.get("TEST", user_id, connection) 170 | user.data["strvalue"] = "value1" 171 | user.data["intvalue"] = 2 172 | user.data["dtvalue"] = now 173 | user.data["dictvalue"] = { 174 | "k1": "v1", 175 | "k2": 2, 176 | } 177 | us.save(user, connection) 178 | 179 | # check 180 | user = us.get("TEST", user_id, connection) 181 | assert user.channel == "TEST" 182 | assert user.channel_user_id == user_id 183 | assert user.data == { 184 | "strvalue": "value1", 185 | "intvalue": 2, 186 | "dtvalue": now, 187 | "dictvalue": { 188 | "k1": "v1", 189 | "k2": 2, 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /tests/dialog/test_dependency.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.pardir) 4 | 5 | from minette.dialog import ( 6 | DialogService, 7 | DialogRouter, 8 | DependencyContainer 9 | ) 10 | 11 | 12 | class SobaDialogService(DialogService): 13 | pass 14 | 15 | 16 | class UdonDialogService(DialogService): 17 | pass 18 | 19 | 20 | class RamenDialogService(DialogService): 21 | pass 22 | 23 | 24 | class MenDialogRouter(DialogRouter): 25 | pass 26 | 27 | 28 | def test_dependency(): 29 | # dependencies 30 | d1 = 1 31 | d2 = 2 32 | d3 = 3 33 | d4 = 4 34 | d5 = 5 35 | d6 = 6 36 | d7 = 7 37 | 38 | # define rules 39 | dependency_rules = { 40 | SobaDialogService: {"d1": d1, "d2": d2}, 41 | UdonDialogService: {"d2": d2, "d3": d3}, 42 | RamenDialogService: {"d3": d3, "d4": d4}, 43 | MenDialogRouter: {"d4": d4, "d5": d5} 44 | } 45 | 46 | # dialog service 47 | soba_dep = DependencyContainer(SobaDialogService(), dependency_rules, d6=d6, d7=d7) 48 | # dependencies for soba 
49 | assert soba_dep.d1 == 1 50 | assert soba_dep.d2 == 2 51 | # dependencies for all 52 | assert soba_dep.d6 == 6 53 | assert soba_dep.d7 == 7 54 | # dependencies not for soba 55 | assert hasattr(soba_dep, "d3") is False 56 | 57 | # dialog router 58 | men_dep = DependencyContainer(MenDialogRouter(), dependency_rules, d6=d6, d7=d7) 59 | # dependencies for men 60 | assert men_dep.d4 == 4 61 | assert men_dep.d5 == 5 62 | # dependencies for all 63 | assert men_dep.d6 == 6 64 | assert men_dep.d7 == 7 65 | # dependencies not for men 66 | assert hasattr(men_dep, "d1") is False 67 | -------------------------------------------------------------------------------- /tests/dialog/test_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | 4 | from minette import DialogService, EchoDialogService, ErrorDialogService 5 | from minette import ( 6 | Message, 7 | Response, 8 | Context, 9 | PerformanceInfo 10 | ) 11 | 12 | 13 | class MyDialog(DialogService): 14 | pass 15 | 16 | 17 | class PizzaDialogService(DialogService): 18 | def extract_entities(self, request, context, connection): 19 | if "seafood" in request.text: 20 | return {"pizza_name": "Seafood Pizza"} 21 | else: 22 | return {"pizza_name": ""} 23 | 24 | def get_slots(self, request, context, connection): 25 | return { 26 | "pizza_name": "", 27 | "pizza_count": 0 28 | } 29 | 30 | def process_request(self, request, context, connection): 31 | if request.text == "error": 32 | # raise runtime error 33 | 1 / 0 34 | 35 | if request.text == "no message": 36 | context.topic.status = "no_message" 37 | return 38 | 39 | # confirmation 40 | if context.topic.status == "confirmation": 41 | if request.text == "yes": 42 | context.topic.status = "confirmed" 43 | return 44 | 45 | # get order 46 | context.data["pizza_name"] = context.data["pizza_name"] or request.entities.get("pizza_name", "") 47 | if context.data["pizza_name"]: 48 | context.topic.status = 
"confirmation" 49 | else: 50 | context.topic.status = "requireorder" 51 | 52 | def compose_response(self, request, context, connection): 53 | if context.topic.status == "no_message": 54 | return 55 | 56 | elif context.topic.status == "requireorder": 57 | context.topic.keep_on = True 58 | return "Which pizza?" 59 | 60 | elif context.topic.status == "confirmation": 61 | context.topic.keep_on = True 62 | return "Your order is {}?".format(context.data["pizza_name"]) 63 | 64 | elif context.topic.status == "confirmed": 65 | messages = [ 66 | request.to_reply(text="Thank you!"), 67 | "We will deliver {} in 30min.".format(context.data["pizza_name"]) 68 | ] 69 | return messages 70 | 71 | 72 | def test_init_base(): 73 | ds = DialogService(timezone=timezone("Asia/Tokyo")) 74 | assert ds.timezone == timezone("Asia/Tokyo") 75 | 76 | 77 | def test_topic_name(): 78 | my_ds = MyDialog() 79 | assert my_ds.topic_name() == "my" 80 | pizza_ds = PizzaDialogService() 81 | assert pizza_ds.topic_name() == "pizza" 82 | 83 | 84 | def test_handle_exception(): 85 | ds = MyDialog() 86 | context = Context("TEST", "test_user") 87 | request = Message(channel="TEST", channel_user_id="test_user", text="Hello") 88 | message = ds.handle_exception(request, context, ValueError("test error"), None) 89 | assert message.text == "?" 90 | assert context.error["exception"] == "test error" 91 | 92 | 93 | def test_execute(): 94 | ds = PizzaDialogService(timezone=timezone("Asia/Tokyo")) 95 | performance = PerformanceInfo() 96 | 97 | # first contact 98 | context = Context("TEST", "test_user") 99 | context.topic.is_new = True 100 | request = Message(channel="TEST", channel_user_id="test_user", text="Give me pizza") 101 | response = ds.execute(request, context, None, performance) 102 | assert response.messages[0].text == "Which pizza?" 
103 | assert request.entities == {"pizza_name": ""} 104 | assert context.data == { 105 | "pizza_name": "", 106 | "pizza_count": 0 107 | } 108 | # say pizza name 109 | context.is_new = False 110 | context.topic.is_new = False 111 | request = Message(channel="TEST", channel_user_id="test_user", text="seafood pizza") 112 | response = ds.execute(request, context, None, performance) 113 | assert response.messages[0].text == "Your order is Seafood Pizza?" 114 | 115 | # confirmation 116 | request = Message(channel="TEST", channel_user_id="test_user", text="yes") 117 | response = ds.execute(request, context, None, performance) 118 | assert response.messages[0].text == "Thank you!" 119 | assert response.messages[1].text == "We will deliver Seafood Pizza in 30min." 120 | 121 | # raise error 122 | request = Message(channel="TEST", channel_user_id="test_user", text="error") 123 | response = ds.execute(request, context, None, performance) 124 | assert response.messages[0].text == "?" 125 | assert context.error["exception"] == "division by zero" 126 | 127 | # no response messages 128 | request = Message(channel="TEST", channel_user_id="test_user", text="no message") 129 | response = ds.execute(request, context, None, performance) 130 | assert response.messages == [] 131 | 132 | 133 | def test_execute_default(): 134 | ds = MyDialog() 135 | performance = PerformanceInfo() 136 | context = Context("TEST", "test_user") 137 | context.topic.is_new = True 138 | request = Message(channel="TEST", channel_user_id="test_user", text="hello") 139 | response = ds.execute(request, context, None, performance) 140 | assert response.messages == [] 141 | assert request.entities == {} 142 | assert context.data == {} 143 | 144 | 145 | def test_execute_echo(): 146 | ds = EchoDialogService(timezone=timezone("Asia/Tokyo")) 147 | performance = PerformanceInfo() 148 | context = Context("TEST", "test_user") 149 | context.topic.is_new = True 150 | request = Message(channel="TEST", 
channel_user_id="test_user", text="hello") 151 | response = ds.execute(request, context, None, performance) 152 | assert response.messages[0].text == "You said: hello" 153 | 154 | 155 | def test_execute_error(): 156 | ds = ErrorDialogService(timezone=timezone("Asia/Tokyo")) 157 | performance = PerformanceInfo() 158 | context = Context("TEST", "test_user") 159 | context.topic.is_new = True 160 | request = Message(channel="TEST", channel_user_id="test_user", text="hello") 161 | response = ds.execute(request, context, None, performance) 162 | assert response.messages[0].text == "?" 163 | -------------------------------------------------------------------------------- /tests/models/test_context.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | from datetime import datetime 4 | 5 | from minette import Context, Topic 6 | from minette.utils import date_to_str, str_to_date 7 | 8 | 9 | def test_init(): 10 | context = Context(channel="TEST", channel_user_id="user_id") 11 | assert context.channel == "TEST" 12 | assert context.channel_user_id == "user_id" 13 | assert context.timestamp is None 14 | assert context.is_new is True 15 | assert isinstance(context.topic, Topic) 16 | assert context.data == {} 17 | assert context.error == {} 18 | 19 | 20 | def test_reset(): 21 | # setup context 22 | context = Context(channel="TEST", channel_user_id="user_id") 23 | context.data = {"fruit": "Apple"} 24 | context.topic.name = "favorite_fruit" 25 | context.topic.status = "continue" 26 | context.topic.keep_on = True 27 | # reset 28 | context.reset() 29 | assert context.topic.name == "favorite_fruit" 30 | assert context.topic.status == "continue" 31 | assert context.topic.previous.name == "favorite_fruit" 32 | assert context.topic.previous.status == "continue" 33 | assert context.data == {"fruit": "Apple"} 34 | # update 35 | context.topic.status = "finish" 36 | context.topic.keep_on = False 37 | # reset 38 
| context.reset() 39 | assert context.topic.name == "" 40 | assert context.topic.status == "" 41 | assert context.topic.previous.name == "favorite_fruit" 42 | assert context.topic.previous.status == "finish" 43 | assert context.data == {} 44 | 45 | 46 | def test_set_error(): 47 | context = Context(channel="TEST", channel_user_id="user_id") 48 | try: 49 | 1 / 0 50 | except Exception as ex: 51 | context.set_error(ex, info="this is test error") 52 | assert context.error["exception"] == "division by zero" 53 | assert "Traceback" in context.error["traceback"] 54 | assert context.error["info"] == "this is test error" 55 | 56 | 57 | def test_from_dict(): 58 | context_dict = { 59 | "channel": "TEST", 60 | "channel_user_id": "user_id", 61 | "data": {"fruit": "Apple"}, 62 | "topic": { 63 | "name": "favorite_fruit", 64 | "status": "continue" 65 | } 66 | } 67 | context = Context.from_dict(context_dict) 68 | assert context.channel == "TEST" 69 | assert context.channel_user_id == "user_id" 70 | assert context.timestamp is None 71 | assert context.is_new is True 72 | assert context.topic.name == "favorite_fruit" 73 | assert context.topic.status == "continue" 74 | assert context.data == {"fruit": "Apple"} 75 | assert context.error == {} 76 | -------------------------------------------------------------------------------- /tests/models/test_group.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from minette import Group 4 | 5 | 6 | def test_init(): 7 | group = Group(id="group_id", type="room") 8 | assert group.id == "group_id" 9 | assert group.type == "room" 10 | -------------------------------------------------------------------------------- /tests/models/test_message.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | from datetime import datetime 4 | 5 | from minette import Message, Group, Payload, Priority 6 | 7 | 8 | def 
test_init(): 9 | now = datetime.now(timezone("Asia/Tokyo")) 10 | message = Message( 11 | id="mid123456789", 12 | type="hogehoge", 13 | timestamp=now, 14 | channel="TEST", 15 | channel_detail="messaging", 16 | channel_user_id="user_id", 17 | channel_message={ 18 | "text": "hello" 19 | }, 20 | text="hello", 21 | token="token123456789", 22 | payloads=[Payload(content_type="image", url="https://image")], 23 | intent="OrderPan", 24 | intent_priority=Priority.High, 25 | entities={"menu": "Yakisoba Pan", "count": 20}, 26 | is_adhoc=True 27 | ) 28 | assert message.id == "mid123456789" 29 | assert message.type == "hogehoge" 30 | assert message.timestamp == now 31 | assert message.channel == "TEST" 32 | assert message.channel_detail == "messaging" 33 | assert message.channel_user_id == "user_id" 34 | assert message.channel_message == {"text": "hello"} 35 | assert message.text == "hello" 36 | assert message.token == "token123456789" 37 | assert message.payloads[0].url == "https://image" 38 | assert message.intent == "OrderPan" 39 | assert message.intent_priority == Priority.High 40 | assert message.entities == {"menu": "Yakisoba Pan", "count": 20} 41 | assert message.is_adhoc is True 42 | 43 | 44 | def test_init_default(): 45 | message = Message() 46 | assert message.id == "" 47 | assert message.type == "text" 48 | assert isinstance(message.timestamp, datetime) 49 | assert message.channel == "console" 50 | assert message.channel_detail == "" 51 | assert message.channel_user_id == "anonymous" 52 | assert message.text == "" 53 | assert message.token == "" 54 | assert message.payloads == [] 55 | assert message.channel_message is None 56 | 57 | 58 | def test_to_reply(): 59 | now = datetime.now(timezone("Asia/Tokyo")) 60 | message = Message( 61 | id="mid123456789", 62 | type="hogehoge", 63 | timestamp=now, 64 | channel="TEST", 65 | channel_detail="messaging", 66 | channel_user_id="user_id", 67 | text="hello", 68 | token="token123456789", 69 | channel_message={ 70 | "text": "hello" 
71 | } 72 | ).to_reply(text="nice talking to you", payloads=[Payload(content_type="image", url="https://image")], type="image") 73 | assert message.id == "mid123456789" 74 | assert message.type == "image" 75 | assert isinstance(message.timestamp, datetime) 76 | assert str(message.timestamp.tzinfo) == "Asia/Tokyo" 77 | assert message.channel == "TEST" 78 | assert message.channel_detail == "messaging" 79 | assert message.channel_user_id == "user_id" 80 | assert message.text == "nice talking to you" 81 | assert message.token == "token123456789" 82 | assert message.payloads[0].url == "https://image" 83 | assert message.channel_message is None 84 | 85 | 86 | def test_to_dict(): 87 | now = datetime.now(timezone("Asia/Tokyo")) 88 | message = Message( 89 | id="mid123456789", 90 | type="hogehoge", 91 | timestamp=now, 92 | channel="TEST", 93 | channel_detail="messaging", 94 | channel_user_id="user_id", 95 | text="hello", 96 | token="token123456789", 97 | payloads=[Payload(content_type="image", url="https://image")], 98 | channel_message={ 99 | "text": "hello" 100 | } 101 | ) 102 | msg_dict = message.to_dict() 103 | assert msg_dict["id"] == "mid123456789" 104 | assert msg_dict["timestamp"] == now 105 | 106 | 107 | def test_from_dict(): 108 | now = datetime.now(timezone("Asia/Tokyo")) 109 | message = Message( 110 | id="mid123456789", 111 | type="hogehoge", 112 | timestamp=now, 113 | channel="TEST", 114 | channel_detail="messaging", 115 | channel_user_id="user_id", 116 | text="hello", 117 | token="token123456789", 118 | payloads=[Payload(content_type="image", url="https://image")], 119 | channel_message={ 120 | "text": "hello" 121 | } 122 | ) 123 | msg_dict = message.to_dict() 124 | message = Message.from_dict(msg_dict) 125 | assert message.id == "mid123456789" 126 | assert message.type == "hogehoge" 127 | assert message.timestamp == now 128 | assert message.channel == "TEST" 129 | assert message.channel_detail == "messaging" 130 | assert message.channel_user_id == "user_id" 
131 | assert message.text == "hello" 132 | assert message.token == "token123456789" 133 | assert message.payloads[0].url == "https://image" 134 | assert message.channel_message == str({"text": "hello"}) 135 | -------------------------------------------------------------------------------- /tests/models/test_models_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | from datetime import datetime 4 | 5 | from minette.serializer import Serializable 6 | from minette.utils import date_to_str 7 | 8 | 9 | class CustomClass(Serializable): 10 | def __init__(self, strvalue=None, intvalue=None, dtvalue=None, listvalue=None, dictvalue=None, objvalue=None): 11 | super().__init__() 12 | self.strvalue = strvalue 13 | self.intvalue = intvalue 14 | self.dtvalue = dtvalue 15 | self.listvalue = listvalue 16 | self.dictvalue = dictvalue 17 | self.objvalue = objvalue 18 | 19 | 20 | class SubClass(CustomClass): 21 | pass 22 | 23 | 24 | def test_init(): 25 | now = datetime.now(timezone("Asia/Tokyo")) 26 | cc = CustomClass(strvalue="str", intvalue=1, dtvalue=now, listvalue=[1, 2, 3], dictvalue={"key1": "value1", "key2": 2}) 27 | assert cc.strvalue == "str" 28 | assert cc.intvalue == 1 29 | assert cc.dtvalue == now 30 | assert cc.listvalue == [1, 2, 3] 31 | assert cc.dictvalue == {"key1": "value1", "key2": 2} 32 | 33 | 34 | def test_to_dict(): 35 | now = datetime.now(timezone("Asia/Tokyo")) 36 | sc1 = SubClass(strvalue="sub_str_1") 37 | sc2 = SubClass(strvalue="sub_str_2") 38 | sc3 = SubClass(strvalue="sub_str_3") 39 | sc4 = SubClass(strvalue="sub_str_4") 40 | cc = CustomClass(strvalue="str", intvalue=1, dtvalue=now, listvalue=[sc1, sc2, "list_str_1"], dictvalue={"key1": "value1", "key2": 1, "key3": sc3}, objvalue=sc4) 41 | cc_dict = cc.to_dict() 42 | # CustomClass 43 | assert cc_dict["strvalue"] == "str" 44 | assert cc_dict["intvalue"] == 1 45 | assert cc_dict["dtvalue"] == now 46 | assert 
cc_dict["listvalue"] == [sc1.to_dict(), sc2.to_dict(), "list_str_1"] 47 | assert cc_dict["dictvalue"]["key1"] == "value1" 48 | assert cc_dict["dictvalue"]["key2"] == 1 49 | assert cc_dict["dictvalue"]["key3"] == sc3.to_dict() 50 | assert cc_dict["objvalue"] == sc4.to_dict() 51 | 52 | 53 | def test_to_json(): 54 | now = datetime.now(timezone("Asia/Tokyo")) 55 | cc = CustomClass(strvalue="str", intvalue=1, dtvalue=now, dictvalue={"key1": "value1", "key2": 2}) 56 | cc_json = cc.to_json() 57 | assert cc_json == '{"strvalue": "str", "intvalue": 1, "dtvalue": "' + date_to_str(now, with_timezone=True) + '", "listvalue": null, "dictvalue": {"key1": "value1", "key2": 2}, "objvalue": null}' 58 | 59 | 60 | def test_from_dict(): 61 | now = datetime.now(timezone("Asia/Tokyo")) 62 | cc_dict = { 63 | "strvalue": "str", 64 | "intvalue": 1, 65 | "dtvalue": now, 66 | "listvalue": [1, 2, 3], 67 | "dictvalue": {"key1": "value1", "key2": 2}, 68 | } 69 | cc = CustomClass.from_dict(cc_dict) 70 | assert cc.strvalue == "str" 71 | assert cc.intvalue == 1 72 | assert cc.dtvalue == now 73 | assert cc.listvalue == [1, 2, 3] 74 | assert cc.dictvalue == {"key1": "value1", "key2": 2} 75 | 76 | 77 | def test_from_dict_dict(): 78 | cc_dict_dict = { 79 | "cc1": { 80 | "strvalue": "str_1", 81 | }, 82 | "cc2": { 83 | "strvalue": "str_2", 84 | }, 85 | } 86 | cc = CustomClass.from_dict_dict(cc_dict_dict) 87 | assert cc["cc1"].strvalue == "str_1" 88 | assert cc["cc1"].to_dict() == CustomClass.from_dict({"strvalue": "str_1"}).to_dict() 89 | assert cc["cc2"].strvalue == "str_2" 90 | assert cc["cc2"].to_dict() == CustomClass.from_dict({"strvalue": "str_2"}).to_dict() 91 | 92 | 93 | def test_from_dict_list(): 94 | cc_dict_list = [ 95 | { 96 | "strvalue": "str_1", 97 | }, 98 | { 99 | "strvalue": "str_2", 100 | }, 101 | ] 102 | cc = CustomClass.from_dict(cc_dict_list) 103 | assert cc[0].strvalue == "str_1" 104 | assert cc[0].to_dict() == CustomClass.from_dict({"strvalue": "str_1"}).to_dict() 105 | assert 
cc[1].strvalue == "str_2" 106 | assert cc[1].to_dict() == CustomClass.from_dict({"strvalue": "str_2"}).to_dict() 107 | 108 | 109 | def test_from_json(): 110 | now = datetime.now(timezone("Asia/Tokyo")) 111 | cc_json = '{"strvalue": "str", "intvalue": 1, "dtvalue": "' + date_to_str(now, with_timezone=True) + '", "listvalue": [1, 2, 3], "dictvalue": {"key1": "value1", "key2": 2}, "objvalue": null}' 112 | cc = CustomClass.from_json(cc_json) 113 | assert cc.strvalue == "str" 114 | assert cc.intvalue == 1 115 | assert cc.dtvalue == now 116 | assert cc.listvalue == [1, 2, 3] 117 | assert cc.dictvalue == {"key1": "value1", "key2": 2} 118 | assert cc.objvalue is None 119 | -------------------------------------------------------------------------------- /tests/models/test_payload.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from minette import Payload 4 | 5 | 6 | def test_init(): 7 | payload = Payload(content_type="image/jpeg", url="http://uezo.net/img/minette_architecture.png", thumb="https://thumb") 8 | assert payload.content_type == "image/jpeg" 9 | assert payload.url == "http://uezo.net/img/minette_architecture.png" 10 | assert payload.thumb == "https://thumb" 11 | assert payload.headers == {} 12 | assert payload.content is None 13 | -------------------------------------------------------------------------------- /tests/models/test_performance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from time import sleep 3 | 4 | from minette import PerformanceInfo 5 | 6 | 7 | def test_init(): 8 | performance = PerformanceInfo() 9 | assert isinstance(performance.start_time, float) 10 | assert performance.ticks == [] 11 | assert performance.milliseconds == 0 12 | 13 | 14 | def test_append(): 15 | performance = PerformanceInfo() 16 | sleep(2) 17 | performance.append("operation") 18 | assert performance.ticks[0][1] > 1 19 | assert performance.milliseconds > 0 
# ==== /tests/models/test_priority.py ====
# Tests for the Priority constants.
import pytest

from minette import Priority


def test_class():
    expected = (
        ("Highest", 100),
        ("High", 75),
        ("Normal", 50),
        ("Low", 25),
        ("Ignore", 0),
    )
    for name, value in expected:
        assert getattr(Priority, name) == value, name


# ==== /tests/models/test_response.py ====
# Tests for the Response model.
import pytest

from minette import Response, Message, PerformanceInfo


def test_init():
    first, second = (Message(text="message {}".format(n)) for n in (1, 2))
    headers = {"key1": "value1", "key2": 1}
    response = Response(messages=[first, second], headers=headers, performance=PerformanceInfo())
    assert response.messages == [first, second]
    assert response.headers == headers
    assert isinstance(response.performance, PerformanceInfo)


def test_from_dict():
    first, second = (Message(text="message {}".format(n)) for n in (1, 2))
    headers = {"key1": "value1", "key2": 1}
    # serialize a response, then rebuild it from the resulting dict
    serialized = Response(messages=[first, second], headers=headers, performance=PerformanceInfo()).to_dict()
    restored = Response.from_dict(serialized)
    assert [m.text for m in restored.messages[:2]] == [first.text, second.text]
    assert restored.headers == headers
    assert isinstance(restored.performance, PerformanceInfo)


# ==== /tests/models/test_topic.py ====
import pytest

from copy import deepcopy
from minette import Topic, Priority


def test_init():
    # Default state of a freshly created Topic.
    topic = Topic()
    assert topic.name == ""
    assert topic.status == ""
    assert topic.is_new is False
    assert topic.keep_on is False
    assert topic.previous is None
    assert topic.priority == Priority.Normal


def test_is_changed():
    # is_changed compares the current topic name against topic.previous.
    # create topic
    topic = Topic()
    topic.name = "favorite_fruit"
    topic.previous = deepcopy(topic)
    # keep topic
    topic.name = "favorite_fruit"
    assert topic.is_changed is False
    # change topic
    topic.name = "favorite_sushi"
    assert topic.is_changed is True


# ==== /tests/models/test_user.py ====
import pytest

from minette import User


def test_user():
    # A new User gets a generated id (longer than 10 chars) and empty profile fields.
    user = User(channel="TEST", channel_user_id="user_id")
    assert len(user.id) > 10
    assert user.name == ""
    assert user.nickname == ""
    assert user.channel == "TEST"
    assert user.channel_user_id == "user_id"
    assert user.profile_image_url == ""
    assert user.data == {}


# ==== /tests/models/test_wordnode.py ====
import pytest

from minette import WordNode


class CustomNode(WordNode):
    # Minimal concrete WordNode: create() is a no-op because only the
    # constructor is exercised here.
    @classmethod
    def create(cls, surface, features):
        pass


def test_init():
    # Positional constructor arguments map onto the attributes in this order.
    node = CustomNode("surface", "part", "part_detail1", "part_detail2", "part_detail3", "stem_type", "stem_form", "word", "kana", "pronunciation")
    assert node.surface == "surface"
    assert node.part == "part"
    assert node.part_detail1 == "part_detail1"
    assert node.part_detail2 == "part_detail2"
    assert node.part_detail3 == "part_detail3"
    assert node.stem_type == "stem_type"
    assert node.stem_form == "stem_form"
    assert node.word == "word"
    assert node.kana == "kana"
    assert node.pronunciation == "pronunciation"


# ==== /tests/payload.png ====
# Binary image fixture; the dump links to it instead of inlining it:
# https://raw.githubusercontent.com/uezo/minette-python/dd8cd7d244b6e6e4133c8e73d637ded8a8c6846f/tests/payload.png


# ==== /tests/scheduler/test_scheduler.py ====
import pytest
from pytz import timezone
from concurrent.futures import ThreadPoolExecutor

from minette import Task, Scheduler


class MyTask(Task):
    # Task subclass whose do() takes two arguments and prints them.
    def do(self, arg1, arg2):
        print(arg1, arg2)


class StopTask(Task):
    # Task that stops the scheduler it is given.
    def do(self, sc):
        sc.stop()


class MyClass:
    # Plain class (not a Task) with a do() method.
    def do(self, arg1):
        print(arg1)


def print_something():
    print("something")


def test_init_task():
    task = Task(timezone=timezone("Asia/Tokyo"))
    assert task.timezone == timezone("Asia/Tokyo")


def test_do_task():
    # The base Task.do() must be callable without arguments.
    task = Task()
    task.do()


def test_init_scheduler():
    sc = Scheduler(timezone=timezone("Asia/Tokyo"), threads=2)
    assert sc.timezone == timezone("Asia/Tokyo")
    assert sc.threads == 2
    assert isinstance(sc.executor, ThreadPoolExecutor)
    assert sc.is_running is False


def test_scheduler_create_task():
    # subclass of Task: wrapped into a callable that is not the class itself
    sc = Scheduler(timezone=timezone("Asia/Tokyo"))
    task = sc.create_task(MyTask)
    assert callable(task)
    assert not isinstance(task, MyTask)
    assert task is not MyTask

    # a class that is not a Task subclass is rejected
    with pytest.raises(TypeError):
        task = sc.create_task(MyClass)

    # a plain function is returned unchanged
    task = sc.create_task(print_something)
    assert task is print_something

    # a bound method is also accepted as-is
    mc = MyClass()
    task = sc.create_task(mc.do)
    assert task == mc.do

    # a non-callable instance is rejected
    with pytest.raises(TypeError):
        task = sc.create_task(MyClass())


def test_scheduler_register_task():
    # Each registration helper accepts both functions and Task subclasses,
    # with positional and keyword task arguments.
    sc = Scheduler()
    sc.every_seconds(print_something)
    sc.every_minutes(print_something, 2)
    sc.every_hours(MyTask, arg1="val1", arg2="val2")
    sc.every_days(MyTask, 4, "val1", "val2")


def test_scheduler_start_stop():
    # NOTE(review): presumably start() blocks until StopTask (scheduled with
    # interval 3) calls sc.stop(); this makes the test take a few seconds.
    sc = Scheduler()
    sc.every_seconds(MyTask, arg1="val1", arg2="val2")
    sc.every_seconds(StopTask, 3, sc=sc)
    sc.start()
    assert sc.is_running is False


# ==== /tests/tagger/test_janometagger.py ====
import sys
import os
sys.path.append(os.pardir)
import pytest
from pytz import timezone
from types import GeneratorType

try:
    from minette.tagger.janometagger import JanomeTagger, JanomeNode
except Exception:
    # Skip if import dependencies not found
    pytestmark = pytest.mark.skip


def test_init():
    tagger = JanomeTagger(timezone=timezone("Asia/Tokyo"))
    assert tagger.timezone == timezone("Asia/Tokyo")


def test_parse():
    tagger = JanomeTagger()
    # empty string
    assert tagger.parse("") == []
    # non-empty sentence
    words = tagger.parse("今日は良い天気です")
    assert isinstance(words[0], JanomeNode)
    # first token ("today")
    assert words[0].surface == "今日"
    assert words[0].part == "名詞"
    assert words[0].part_detail1 == "副詞可能"
    assert words[0].word == "今日"
    assert words[0].kana == "キョウ"
    assert words[0].pronunciation == "キョー"
    # third token ("good")
    assert words[2].surface == "良い"
    assert words[2].part == "形容詞"
    assert words[2].part_detail1 == "自立"
    assert words[2].stem_type == "形容詞・アウオ段"
    assert words[2].stem_form == "基本形"
    assert words[2].word == "良い"
    assert words[2].kana == "ヨイ"
    assert words[2].pronunciation == "ヨイ"


def test_parse_as_generator():
    tagger = JanomeTagger()
    # empty string: still a generator, yielding nothing
    empty_words_gen = tagger.parse_as_generator("")
    assert isinstance(empty_words_gen, GeneratorType)
    empty_words = [ew for ew in empty_words_gen]
    assert empty_words == []
    # non-empty sentence
    words = tagger.parse_as_generator("今日は良い天気です")
    assert isinstance(words, GeneratorType)
    for i, w in enumerate(words):
        if i == 0:
            # first token ("today")
            assert w.surface == "今日"
            assert w.part == "名詞"
            assert w.part_detail1 == "副詞可能"
            assert w.word == "今日"
            assert w.kana == "キョウ"
            assert w.pronunciation == "キョー"
        elif i == 2:
            # third token ("good")
            assert w.surface == "良い"
            assert w.part == "形容詞"
            assert w.part_detail1 == "自立"
            assert w.stem_type == "形容詞・アウオ段"
            assert w.stem_form == "基本形"
            assert w.word == "良い"
            assert w.kana == "ヨイ"
            assert w.pronunciation == "ヨイ"


def test_parse_with_max():
    # Variable names encode <input length>_max<inline max_length>.
    tagger = JanomeTagger(max_length=8)
    # 9 chars: over the instance max_length
    words_9 = tagger.parse("今日は良い天気です")
    assert words_9 == []
    # 9 chars: over the instance max_length but under the inline one
    words_9_max10 = tagger.parse("今日は良い天気です", max_length=10)
    assert words_9_max10[0].surface == "今日"
    # 7 chars: under the instance max_length but over the inline one
    words_7_max6 = tagger.parse("今日は良い天気", max_length=6)
    assert words_7_max6 == []


def test_parse_gen_with_max():
    # Same cases as test_parse_with_max, through the generator API.
    tagger = JanomeTagger(max_length=8)
    # 9 chars: over the instance max_length
    words_9 = [w for w in tagger.parse_as_generator("今日は良い天気です")]
    assert words_9 == []
    # 9 chars: over the instance max_length but under the inline one
    words_9_max10 = [w for w in tagger.parse_as_generator("今日は良い天気です", max_length=10)]
    assert words_9_max10[0].surface == "今日"
    # 7 chars: under the instance max_length but over the inline one
    words_7_max6 = [w for w in tagger.parse_as_generator("今日は良い天気", max_length=6)]
    assert words_7_max6 == []


def test_error():
    # Non-string input is rejected with TypeError.
    tagger = JanomeTagger()
    with pytest.raises(TypeError):
        tagger.parse(object())


# ==== /tests/tagger/test_mecabservice.py ====
import pytest
from pytz import timezone

try:
    from minette.tagger.mecabservice import (
        MeCabServiceTagger,
        MeCabServiceNode
    )
except Exception:
    # Skip if import dependencies not found
    pytestmark = pytest.mark.skip


def test_init():
    tagger = MeCabServiceTagger(timezone=timezone("Asia/Tokyo"))
    assert tagger.timezone == timezone("Asia/Tokyo")


def test_parse():
    tagger = MeCabServiceTagger()
    # empty string
    assert tagger.parse("") == []
    # non-empty sentence
    words = tagger.parse("今日は良い天気です")
    assert isinstance(words[0], MeCabServiceNode)
    # first token ("today")
    assert words[0].surface == "今日"
    assert words[0].part == "名詞"
    assert words[0].part_detail1 == "副詞可能"
    assert words[0].word == "今日"
    assert words[0].kana == "キョウ"
    assert words[0].pronunciation == "キョー"
    # third token ("good")
    assert words[2].surface == "良い"
    assert words[2].part == "形容詞"
    assert words[2].part_detail1 == "自立"
    assert words[2].stem_type == "形容詞・アウオ段"
    assert words[2].stem_form == "基本形"
    assert words[2].word == "良い"
    assert words[2].kana == "ヨイ"
    assert words[2].pronunciation == "ヨイ"


def test_error():
    # With an unusable API URL, parse() returns an empty list instead of raising.
    tagger = MeCabServiceTagger(api_url="https://")
    assert tagger.api_url == "https://"
    assert tagger.parse("今日は良い天気です") == []


# ==== /tests/tagger/test_mecabtagger.py ====
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.pardir) 4 | import pytest 5 | from pytz import timezone 6 | from types import GeneratorType 7 | 8 | try: 9 | from minette.tagger.mecabtagger import MeCabTagger, MeCabNode 10 | except Exception: 11 | # Skip if import dependencies not found 12 | pytestmark = pytest.mark.skip 13 | 14 | 15 | def test_init(): 16 | tagger = MeCabTagger(timezone=timezone("Asia/Tokyo")) 17 | assert tagger.timezone == timezone("Asia/Tokyo") 18 | 19 | 20 | def test_parse(): 21 | tagger = MeCabTagger() 22 | # 空文字列 23 | assert tagger.parse("") == [] 24 | # センテンスあり 25 | words = tagger.parse("今日は良い天気です") 26 | assert isinstance(words[0], MeCabNode) 27 | # 今日 28 | assert words[0].surface == "今日" 29 | assert words[0].part == "名詞" 30 | assert words[0].part_detail1 == "副詞可能" 31 | assert words[0].word == "今日" 32 | assert words[0].kana == "キョウ" 33 | assert words[0].pronunciation == "キョー" 34 | # 良い 35 | assert words[2].surface == "良い" 36 | assert words[2].part == "形容詞" 37 | assert words[2].part_detail1 == "自立" 38 | assert words[2].stem_type == "形容詞・アウオ段" 39 | assert words[2].stem_form == "基本形" 40 | assert words[2].word == "良い" 41 | assert words[2].kana == "ヨイ" 42 | assert words[2].pronunciation == "ヨイ" 43 | 44 | 45 | def test_parse_as_generator(): 46 | tagger = MeCabTagger() 47 | # 空文字列 48 | empty_words_gen = tagger.parse_as_generator("") 49 | assert isinstance(empty_words_gen, GeneratorType) 50 | empty_words = [ew for ew in empty_words_gen] 51 | assert empty_words == [] 52 | # センテンスあり 53 | words = tagger.parse_as_generator("今日は良い天気です") 54 | assert isinstance(words, GeneratorType) 55 | i = 0 56 | for w in words: 57 | if i == 0: 58 | assert w.surface == "今日" 59 | assert w.part == "名詞" 60 | assert w.part_detail1 == "副詞可能" 61 | assert w.word == "今日" 62 | assert w.kana == "キョウ" 63 | assert w.pronunciation == "キョー" 64 | elif i == 2: 65 | assert w.surface == "良い" 66 | assert 
w.part == "形容詞" 67 | assert w.part_detail1 == "自立" 68 | assert w.stem_type == "形容詞・アウオ段" 69 | assert w.stem_form == "基本形" 70 | assert w.word == "良い" 71 | assert w.kana == "ヨイ" 72 | assert w.pronunciation == "ヨイ" 73 | i += 1 74 | 75 | 76 | def test_parse_with_max(): 77 | tagger = MeCabTagger(max_length=8) 78 | # over instance max_length 79 | words_9 = tagger.parse("今日は良い天気です") 80 | assert words_9 == [] 81 | # over instance max_length but under inline 82 | words_9_max10 = tagger.parse("今日は良い天気です", max_length=10) 83 | assert words_9_max10[0].surface == "今日" 84 | # under instance max_length but over inline 85 | words_9_max7 = tagger.parse("今日は良い天気", max_length=6) 86 | assert words_9_max7 == [] 87 | 88 | 89 | def test_parse_gen_with_max(): 90 | tagger = MeCabTagger(max_length=8) 91 | # over instance max_length 92 | words_9 = [w for w in tagger.parse_as_generator("今日は良い天気です")] 93 | assert words_9 == [] 94 | # over instance max_length but under inline 95 | words_9_max10 = [w for w in tagger.parse_as_generator("今日は良い天気です", max_length=10)] 96 | assert words_9_max10[0].surface == "今日" 97 | # under instance max_length but over inline 98 | words_9_max7 = [w for w in tagger.parse_as_generator("今日は良い天気", max_length=6)] 99 | assert words_9_max7 == [] 100 | 101 | 102 | def test_error(): 103 | tagger = MeCabTagger() 104 | with pytest.raises(TypeError): 105 | tagger.parse(object()) 106 | -------------------------------------------------------------------------------- /tests/tagger/test_tagger_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.pardir) 4 | import pytest 5 | from pytz import timezone 6 | from types import GeneratorType 7 | 8 | from minette import Tagger 9 | 10 | 11 | def test_init(): 12 | tagger = Tagger(timezone=timezone("Asia/Tokyo")) 13 | assert tagger.timezone == timezone("Asia/Tokyo") 14 | 15 | 16 | def test_parse(): 17 | tagger = Tagger() 18 | assert 
tagger.parse("今日は良い天気です") == [] 19 | 20 | 21 | def test_parse_as_generator(): 22 | tagger = Tagger() 23 | assert isinstance(tagger.parse_as_generator("今日は良い天気です"), GeneratorType) 24 | 25 | 26 | def test_validate(): 27 | tagger = Tagger(max_length=5) 28 | assert tagger.validate("こんにちは") is True 29 | assert tagger.validate("ごきげんよう") is False 30 | assert tagger.validate("こんにちは", max_length=4) is False 31 | assert tagger.validate("ごきげんよう", max_length=6) is True 32 | with pytest.raises(TypeError): 33 | tagger.validate(object()) 34 | with pytest.raises(TypeError): 35 | tagger.validate(1) 36 | 37 | tagger_nomax = Tagger() 38 | assert tagger_nomax.max_length == 1000 39 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.pardir) 4 | import pytest 5 | 6 | from minette.config import Config 7 | 8 | 9 | def test_init_without_file(): 10 | config = Config(None) 11 | assert config.get("timezone", section="minette") == "UTC" 12 | 13 | 14 | def test_init_file_not_exist(): 15 | config = Config(None) 16 | assert config.get("timezone", section="minette") == "UTC" 17 | 18 | 19 | def test_get(): 20 | config = Config("config/test_config.ini") 21 | assert config.get("key1") == "value1" 22 | assert config.get("key2") is None 23 | assert config.get("key3", section="invalid_section", default="default_value") == "default_value" 24 | 25 | 26 | def test_get_without_section(): 27 | config = Config("config/test_config_empty.ini") 28 | assert config.get("timezone") == "UTC" 29 | -------------------------------------------------------------------------------- /tests/test_testing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytz import timezone 3 | from logging import Logger 4 | from datetime import datetime 5 | 6 | from minette import ( 7 | 
DialogService, SQLiteConnectionProvider, 8 | SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore, 9 | Tagger, Message 10 | ) 11 | from minette.testing.helper import MinetteForTest 12 | from minette.utils import date_to_unixtime 13 | 14 | now = datetime.now() 15 | user_id = "user_id" + str(date_to_unixtime(now)) 16 | print("user_id: {}".format(user_id)) 17 | 18 | 19 | class FooDialog(DialogService): 20 | def compose_response(self, request, context, connetion): 21 | return "foo" 22 | 23 | 24 | class BarDialog(DialogService): 25 | def compose_response(self, request, context, connetion): 26 | return "bar" 27 | 28 | 29 | class MessageDialog(DialogService): 30 | def compose_response(self, request, context, connetion): 31 | return request.channel + "_" + request.channel_user_id 32 | 33 | 34 | def test_init(): 35 | # without config 36 | bot = MinetteForTest( 37 | default_channel="TEST", 38 | intent_resolver={ 39 | "FooIntent": FooDialog, 40 | "BarIntent": BarDialog, 41 | }, 42 | ) 43 | assert bot.config.get("timezone") == "UTC" 44 | assert bot.timezone == timezone("UTC") 45 | assert isinstance(bot.logger, Logger) 46 | assert bot.logger.name == "minette" 47 | assert isinstance(bot.connection_provider, SQLiteConnectionProvider) 48 | assert isinstance(bot.context_store, SQLiteContextStore) 49 | assert isinstance(bot.user_store, SQLiteUserStore) 50 | assert isinstance(bot.messagelog_store, SQLiteMessageLogStore) 51 | assert bot.default_dialog_service is None 52 | assert isinstance(bot.tagger, Tagger) 53 | assert len(bot.case_id) > 10 54 | assert bot.default_channel == "TEST" 55 | assert bot.dialog_router.intent_resolver["FooIntent"] == FooDialog 56 | assert bot.dialog_router.intent_resolver["BarIntent"] == BarDialog 57 | 58 | 59 | def test_chat(): 60 | bot = MinetteForTest( 61 | default_channel="TEST", 62 | intent_resolver={ 63 | "FooIntent": FooDialog, 64 | "BarIntent": BarDialog, 65 | }, 66 | ) 67 | assert bot.chat("hello", intent="FooIntent").messages[0].text == 
"foo" 68 | assert bot.chat("hello", intent="BarIntent").messages[0].text == "bar" 69 | assert bot.chat("hello").messages == [] 70 | 71 | 72 | def test_chat_message(): 73 | bot = MinetteForTest( 74 | intent_resolver={ 75 | "MessageIntent": MessageDialog, 76 | }, 77 | ) 78 | assert bot.chat(Message( 79 | intent="MessageIntent", 80 | channel="test_channel", 81 | channel_user_id="test_user" 82 | )).messages[0].text == "test_channel_test_user" 83 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datetime import datetime 3 | from pytz import timezone 4 | 5 | from minette.serializer import dumps, loads 6 | import minette.utils as u 7 | 8 | naive_dt = datetime(2019, 1, 2, 3, 4, 5) 9 | aware_dt = timezone("Asia/Tokyo").localize(naive_dt) 10 | 11 | 12 | def test_date_to_str(): 13 | assert u.date_to_str(naive_dt) == "2019-01-02T03:04:05.000000" 14 | assert u.date_to_str(aware_dt) == "2019-01-02T03:04:05.000000" 15 | assert u.date_to_str(aware_dt, with_timezone=True) == "2019-01-02T03:04:05.000000+09:00" 16 | 17 | 18 | def test_str_to_date(): 19 | assert u.str_to_date("2019-01-02T03:04:05") == naive_dt 20 | assert u.str_to_date("2019-01-02T03:04:05+09:00") == aware_dt 21 | 22 | 23 | def test_date_to_unixtime(): 24 | assert u.date_to_unixtime(aware_dt) == 1546365845 25 | 26 | 27 | def test_unixtime_to_date(): 28 | assert u.unixtime_to_date(1546365845) == naive_dt 29 | 30 | 31 | def test_encode_json(): 32 | obj = { 33 | "key1": "value1", 34 | "key2": 2, 35 | "key3": naive_dt, 36 | "key4": aware_dt 37 | } 38 | assert dumps(obj) == '{"key1": "value1", "key2": 2, "key3": "2019-01-02T03:04:05.000000", "key4": "2019-01-02T03:04:05.000000+09:00"}' 39 | assert dumps(None) == "" 40 | with pytest.raises(AttributeError): 41 | dumps(object()) 42 | 43 | 44 | def test_decode_json(): 45 | obj = loads('{"key1": "value1", "key2": 2, 
"key3": "2019-01-02T03:04:05+09:00"}') 46 | assert obj["key1"] == "value1" 47 | assert obj["key2"] == 2 48 | assert obj["key3"] == aware_dt 49 | assert loads("") is None 50 | --------------------------------------------------------------------------------