├── .github └── dependabot.yml ├── .gitignore ├── LICENSE ├── README.md ├── example_export.py ├── publish.sh ├── refinery ├── __init__.py ├── adapter │ ├── __init__.py │ ├── rasa.py │ ├── sklearn.py │ ├── torch.py │ ├── transformers.py │ └── util.py ├── api_calls.py ├── authentication.py ├── callbacks │ ├── __init__.py │ ├── inference.py │ ├── sklearn.py │ ├── torch.py │ └── transformers.py ├── cli.py ├── exceptions.py ├── settings.py └── util.py ├── requirements.txt └── setup.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem 4 | - package-ecosystem: "pip" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | # default is / which breaks drone 9 | pull-request-branch-name: 10 | separator: "-" 11 | # not created automatically for version updates so only security ones are created 12 | # https://docs.github.com/en/code-security/dependabot/dependabot-security-updates/configuring-dependabot-security-updates#overriding-the-default-behavior-with-a-configuration-file 13 | open-pull-requests-limit: 0 14 | 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .vscode/ 3 | secrets.json 4 | 5 | # Jupyter 6 | *.ipynb 7 | 8 | # MacOS 9 | .DS_Store 10 | 11 | # JetBrains 12 | .idea/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | .ipynb 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Kern AI GmbH 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![refinery repository](https://uploads-ssl.webflow.com/61e47fafb12bd56b40022a49/62cf1c3cb8272b1e9c01127e_refinery%20sdk%20banner.png)](https://github.com/code-kern-ai/refinery) 2 | [![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-390/) 3 | [![pypi 1.4.0](https://img.shields.io/badge/pypi-1.4.0-yellow.svg)](https://pypi.org/project/refinery-python-sdk/1.4.0/) 4 | 5 | This is the official Python SDK for [*refinery*](https://github.com/code-kern-ai/refinery), the **open-source** data-centric IDE for NLP. 6 | 7 | **Table of Contents** 8 | - [Installation](#installation) 9 | - [Usage](#usage) 10 | - [Creating a `Client` object](#creating-a-client-object) 11 | - [Fetching labeled data](#fetching-labeled-data) 12 | - [Fetching lookup lists](#fetching-lookup-lists) 13 | - [Upload files](#upload-files) 14 | - [Adapters](#adapters) 15 | - [Sklearn](#sklearn-adapter) 16 | - [PyTorch](#pytorch-adapter) 17 | - [HuggingFace](#hugging-face-adapter) 18 | - [Rasa](#rasa-adapter) 19 | - [Callbacks](#callbacks) 20 | - [Sklearn](#sklearn-callback) 21 | - [PyTorch](#pytorch-callback) 22 | - [HuggingFace](#hugging-face-callback) 23 | - [Contributing](#contributing) 24 | - [License](#license) 25 | - [Contact](#contact) 26 | 27 | If you like what we're working on, please leave a ⭐! 28 | 29 | ## Installation 30 | 31 | You can set up this SDK either via running `$ pip install refinery-python-sdk`, or by cloning this repository and running `$ pip install -r requirements.txt`. 32 | 33 | ## Usage 34 | 35 | ### Creating a `Client` object 36 | Once you installed the package, you can create a `Client` object from any Python terminal as follows: 37 | 38 | ```python 39 | from refinery import Client 40 | 41 | user_name = "your-username" # this is the email you log in with 42 | password = "your-password" 43 | project_id = "your-project-id" # can be found in the URL of the web application 44 | 45 | client = Client(user_name, password, project_id) 46 | # if you run the application locally, please use the following instead: 47 | # client = Client(user_name, password, project_id, uri="http://localhost:4455") 48 | ``` 49 | 50 | The `project_id` can be found in your browser, e.g. if you run the app on your localhost: `http://localhost:4455/app/projects/{project_id}/overview` 51 | 52 | Alternatively, you can provide a `secrets.json` file in your directory where you want to run the SDK, looking as follows: 53 | ```json 54 | { 55 | "user_name": "your-username", 56 | "password": "your-password", 57 | "project_id": "your-project-id" 58 | } 59 | ``` 60 | 61 | Again, if you run on your localhost, you should also provide `"uri": "http://localhost:4455"`. 
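For a local setup, the complete file would then contain the same keys as above plus the optional `uri`:

```json
{
    "user_name": "your-username",
    "password": "your-password",
    "project_id": "your-project-id",
    "uri": "http://localhost:4455"
}
```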
Afterwards, you can access the client like this: 62 | 63 | ```python 64 | client = Client.from_secrets_file("secrets.json") 65 | ``` 66 | 67 | With the `Client`, you easily integrate your data into any kind of system; may it be a custom implementation, an AutoML system or a plain data analytics framework 🚀 68 | 69 | ### Fetching labeled data 70 | 71 | Now, you can easily fetch the data from your project: 72 | ```python 73 | df = client.get_record_export(tokenize=False) 74 | # if you set tokenize=True (default), the project-specific 75 | # spaCy tokenizer will process your textual data 76 | ``` 77 | 78 | Alternatively, you can also just run `rsdk pull` in your CLI given that you have provided the `secrets.json` file in the same directory. 79 | 80 | The `df` contains both your originally uploaded data (e.g. `headline` and `running_id` if you uploaded records like `{"headline": "some text", "running_id": 1234}`), and a triplet for each labeling task you create. This triplet consists of the manual labels, the weakly supervised labels, and their confidence. For extraction tasks, this data is on token-level. 81 | 82 | An example export file looks like this: 83 | ```json 84 | [ 85 | { 86 | "running_id": "0", 87 | "Headline": "T. Rowe Price (TROW) Dips More Than Broader Markets", 88 | "Date": "Jun-30-22 06:00PM\u00a0\u00a0", 89 | "Headline__Sentiment Label__MANUAL": null, 90 | "Headline__Sentiment Label__WEAK_SUPERVISION": "Negative", 91 | "Headline__Sentiment Label__WEAK_SUPERVISION__confidence": "0.6220" 92 | } 93 | ] 94 | ``` 95 | 96 | In this example, there is no manual label, but a weakly supervised label `"Negative"` has been set with 62.2% confidence. 97 | 98 | ### Fetching lookup lists 99 | In your project, you can create lookup lists to implement distant supervision heuristics. To fetch your lookup list(s), you can either get all or fetch one by its list id. 100 | ```python 101 | list_id = "your-list-id" 102 | lookup_list = client.get_lookup_list(list_id) 103 | ``` 104 | 105 | The list id can be found in your browser URL when you're on the details page of a lookup list, e.g. when you run on localhost: `http://localhost:4455/app/projects/{project_id}/knowledge-base/{list_id}`. 106 | 107 | Alternatively, you can pull all lookup lists: 108 | ```python 109 | lookup_lists = client.get_lookup_lists() 110 | ``` 111 | 112 | ### Upload files 113 | You can import files directly from your machine to your application: 114 | 115 | ```python 116 | file_path = "my/file/path/data.json" 117 | upload_was_successful = client.post_file_import(file_path) 118 | ``` 119 | 120 | We use Pandas to process the data you upload, so you can also provide `import_file_options` for the file type you use. Currently, you need to provide them as a `\n`-separated string (e.g. `"quoting=1\nsep=';'"`). We'll adapt this in the future to work with dictionaries instead. 121 | 122 | Alternatively, you can `rsdk push ` via CLI, given that you have provided the `secrets.json` file in the same directory. 123 | 124 | **Make sure that you've selected the correct project beforehand, and fit the data schema of existing records in your project!** 125 | 126 | ### Adapters 127 | 128 | #### Sklearn Adapter 129 | You can use *refinery* to directly pull data into a format you can apply for building [sklearn](https://github.com/scikit-learn/scikit-learn) models. 
This can look as follows: 130 | 131 | ```python 132 | from refinery.adapter.sklearn import build_classification_dataset 133 | from sklearn.tree import DecisionTreeClassifier 134 | 135 | data = build_classification_dataset(client, "headline", "__clickbait", "distilbert-base-uncased") 136 | 137 | clf = DecisionTreeClassifier() 138 | clf.fit(data["train"]["inputs"], data["train"]["labels"]) 139 | 140 | pred_test = clf.predict(data["test"]["inputs"]) 141 | accuracy = (pred_test == data["test"]["labels"]).mean() 142 | ``` 143 | 144 | By the way, we can highly recommend to combine this with [Truss](https://github.com/basetenlabs/truss) for easy model serving! 145 | 146 | #### PyTorch Adapter 147 | If you want to build a [PyTorch](https://github.com/pytorch/pytorch) network, you can build the `train_loader` and `test_loader` as follows: 148 | 149 | ```python 150 | from refinery.adapter.torch import build_classification_dataset 151 | train_loader, test_loader, encoder, index = build_classification_dataset( 152 | client, "headline", "__clickbait", "distilbert-base-uncased" 153 | ) 154 | ``` 155 | 156 | #### Hugging Face Adapter 157 | Transformers are great, but often times, you want to finetune them for your downstream task. With *refinery*, you can do so easily by letting the SDK build the dataset for you that you can use as a plug-and-play base for your training: 158 | 159 | ```python 160 | from refinery.adapter import transformers 161 | dataset, mapping = transformers.build_dataset(client, "headline", "__clickbait") 162 | ``` 163 | 164 | From here, you can follow the [finetuning example](https://huggingface.co/docs/transformers/training) provided in the official Hugging Face documentation. A next step could look as follows: 165 | 166 | ```python 167 | small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000)) 168 | small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(1000)) 169 | 170 | from transformers import ( 171 | AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer 172 | ) 173 | import numpy as np 174 | from datasets import load_metric 175 | 176 | tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") 177 | 178 | def tokenize_function(examples): 179 | return tokenizer(examples["headline"], padding="max_length", truncation=True) 180 | 181 | tokenized_datasets = dataset.map(tokenize_function, batched=True) 182 | model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2) 183 | training_args = TrainingArguments(output_dir="test_trainer") 184 | metric = load_metric("accuracy") 185 | 186 | def compute_metrics(eval_pred): 187 | logits, labels = eval_pred 188 | predictions = np.argmax(logits, axis=-1) 189 | return metric.compute(predictions=predictions, references=labels) 190 | 191 | training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch") 192 | 193 | small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000)) 194 | small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000)) 195 | 196 | trainer = Trainer( 197 | model=model, 198 | args=training_args, 199 | train_dataset=small_train_dataset, 200 | eval_dataset=small_eval_dataset, 201 | compute_metrics=compute_metrics, 202 | ) 203 | 204 | trainer.train() 205 | 206 | trainer.save_model("path/to/model") 207 | ``` 208 | 209 | #### Rasa Adapter 210 | *refinery* is perfect to be used for building chatbots with [Rasa](https://github.com/RasaHQ/rasa). 
We've built an adapter with which you can easily create the required Rasa training data directly from *refinery*. 211 | 212 | To do so, do the following: 213 | 214 | ```python 215 | from refinery.adapter import rasa 216 | 217 | rasa.build_intent_yaml( 218 | client, 219 | "text", 220 | "__intent__WEAK_SUPERVISION" 221 | ) 222 | ``` 223 | 224 | This will create a `.yml` file looking as follows: 225 | 226 | ```yml 227 | nlu: 228 | - intent: check_balance 229 | examples: | 230 | - how much do I have on my savings account 231 | - how much money is in my checking account 232 | - What's the balance on my credit card account 233 | ``` 234 | 235 | If you want to provide a metadata-level label (such as sentiment), you can provide the optional argument `metadata_label_task`: 236 | 237 | ```python 238 | from refinery.adapter import rasa 239 | 240 | rasa.build_intent_yaml( 241 | client, 242 | "text", 243 | "__intent__WEAK_SUPERVISION", 244 | metadata_label_task="__sentiment__WEAK_SUPERVISION" 245 | ) 246 | ``` 247 | 248 | This will create a file like this: 249 | ```yml 250 | nlu: 251 | - intent: check_balance 252 | metadata: 253 | sentiment: neutral 254 | examples: | 255 | - how much do I have on my savings account 256 | - how much money is in my checking account 257 | - What's the balance on my credit card account 258 | ``` 259 | 260 | And if you have entities in your texts which you'd like to recognize, simply add the `tokenized_label_task` argument: 261 | 262 | ```python 263 | from refinery.adapter import rasa 264 | 265 | rasa.build_intent_yaml( 266 | client, 267 | "text", 268 | "__intent__WEAK_SUPERVISION", 269 | metadata_label_task="__sentiment__WEAK_SUPERVISION", 270 | tokenized_label_task="text__entities__WEAK_SUPERVISION" 271 | ) 272 | ``` 273 | 274 | This will not only inject the label names on token-level, but also creates lookup lists for your chatbot: 275 | 276 | ```yml 277 | nlu: 278 | - intent: check_balance 279 | metadata: 280 | sentiment: neutral 281 | examples: | 282 | - how much do I have on my [savings](account) account 283 | - how much money is in my [checking](account) account 284 | - What's the balance on my [credit card account](account) 285 | - lookup: account 286 | examples: | 287 | - savings 288 | - checking 289 | - credit card account 290 | ``` 291 | 292 | Please make sure to also create the further necessary files (`domain.yml`, `data/stories.yml` and `data/rules.yml`) if you want to train your Rasa chatbot. For further reference, see their [documentation](https://rasa.com/docs/rasa). 293 | 294 | 295 | ### Callbacks 296 | If you want to feed your production model's predictions back into *refinery*, you can do so with any version greater than [1.2.1](https://github.com/code-kern-ai/refinery/releases/tag/v1.2.1). 297 | 298 | To do so, we have a generalistic interface and framework-specific classes. 
299 | 300 | #### Sklearn Callback 301 | If you want to train a scikit-learn model and feed its outputs back into *refinery*, you can do so easily as follows: 302 | 303 | ```python 304 | from sklearn.linear_model import LogisticRegression 305 | clf = LogisticRegression() # we use this as an example, but you can use any model implementing predict_proba 306 | 307 | from refinery.adapter.sklearn import build_classification_dataset 308 | data = build_classification_dataset(client, "headline", "__clickbait", "distilbert-base-uncased") 309 | clf.fit(data["train"]["inputs"], data["train"]["labels"]) 310 | 311 | from refinery.callbacks.sklearn import SklearnCallback 312 | callback = SklearnCallback( 313 | client, 314 | clf, 315 | "clickbait", 316 | ) 317 | 318 | # executing this will call the refinery API with batches of size 32, so your data is pushed to the app 319 | callback.run(data["train"]["inputs"], data["train"]["index"]) 320 | callback.run(data["test"]["inputs"], data["test"]["index"]) 321 | ``` 322 | 323 | #### PyTorch Callback 324 | For PyTorch, the procedure is very similar. You can do as follows: 325 | 326 | ```python 327 | from refinery.adapter.torch import build_classification_dataset 328 | train_loader, test_loader, encoder, index = build_classification_dataset( 329 | client, "headline", "__clickbait", "distilbert-base-uncased" 330 | ) 331 | 332 | # build your custom model and train it here - example: 333 | import torch.nn as nn 334 | import numpy as np 335 | import torch 336 | 337 | # number of features (len of X cols) 338 | input_dim = 768 339 | # number of hidden layers 340 | hidden_layers = 20 341 | # number of classes (unique of y) 342 | output_dim = 2 343 | class Network(nn.Module): 344 | def __init__(self): 345 | super(Network, self).__init__() 346 | self.linear1 = nn.Linear(input_dim, output_dim) 347 | 348 | def forward(self, x): 349 | x = torch.sigmoid(self.linear1(x)) 350 | return x 351 | 352 | clf = Network() 353 | criterion = nn.CrossEntropyLoss() 354 | optimizer = torch.optim.SGD(clf.parameters(), lr=0.1) 355 | 356 | epochs = 2 357 | for epoch in range(epochs): 358 | running_loss = 0.0 359 | for i, data in enumerate(train_loader, 0): 360 | inputs, labels = data 361 | # set optimizer to zero grad to remove previous epoch gradients 362 | optimizer.zero_grad() 363 | # forward propagation 364 | outputs = clf(inputs) 365 | loss = criterion(outputs, labels) 366 | # backward propagation 367 | loss.backward() 368 | # optimize 369 | optimizer.step() 370 | running_loss += loss.item() 371 | # display statistics 372 | print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.5f}') 373 | running_loss = 0.0 374 | 375 | # with this model trained, you can use the callback 376 | from refinery.callbacks.torch import TorchCallback 377 | callback = TorchCallback( 378 | client, 379 | clf, 380 | "clickbait", 381 | encoder 382 | ) 383 | 384 | # and just execute this 385 | callback.run(train_loader, index["train"]) 386 | callback.run(test_loader, index["test"]) 387 | ``` 388 | 389 | #### HuggingFace Callback 390 | Collect the dataset and train your custom transformer model as follows: 391 | 392 | ```python 393 | from refinery.adapter import transformers 394 | dataset, mapping, index = transformers.build_classification_dataset(client, "headline", "__clickbait") 395 | 396 | # train a model here, we're simplifying this by just using an existing model w/o retraining 397 | from transformers import pipeline 398 | pipe = pipeline("text-classification", model="distilbert-base-uncased") 399 |
400 | # if you're interested to see what a training run looks like, look into the above HuggingFace adapter 401 | 402 | # you can now apply the callback 403 | from refinery.callbacks.transformers import TransformerCallback 404 | callback = TransformerCallback( 405 | client, 406 | pipe, 407 | "clickbait", 408 | mapping 409 | ) 410 | 411 | callback.run(dataset["train"]["headline"], index["train"]) 412 | callback.run(dataset["test"]["headline"], index["test"]) 413 | ``` 414 | 415 | #### Generic Callback 416 | This one is your fallback if you have a very custom solution; other than that, we recommend you look into the framework-specific classes. 417 | 418 | ```python 419 | from refinery.callbacks.inference import ModelCallback 420 | from refinery.adapter.sklearn import build_classification_dataset 421 | from sklearn.linear_model import LogisticRegression 422 | 423 | data = build_classification_dataset(client, "headline", "__clickbait", "distilbert-base-uncased") 424 | clf = LogisticRegression() 425 | clf.fit(data["train"]["inputs"], data["train"]["labels"]) 426 | 427 | # you can build initialization functions that set states of objects you use in the pipeline 428 | def initialize_fn(inputs, labels, **kwargs): 429 | return {"clf": kwargs["clf"]} 430 | 431 | # postprocessing shifts the model outputs into a format accepted by our API 432 | def postprocessing_fn(outputs, **kwargs): 433 | named_outputs = [] 434 | for prediction in outputs: 435 | pred_index = prediction.argmax() 436 | label = kwargs["clf"].classes_[pred_index] 437 | confidence = prediction[pred_index] 438 | named_outputs.append([label, confidence]) 439 | return named_outputs 440 | 441 | callback = ModelCallback( 442 | client, 443 | "my-custom-regression", 444 | "clickbait", 445 | inference_fn=clf.predict_proba, 446 | initialize_fn=initialize_fn, 447 | postprocessing_fn=postprocessing_fn 448 | ) 449 | 450 | # executing this will call the refinery API with batches of size 32 451 | callback.initialize_and_run(data["train"]["inputs"], data["train"]["index"]) 452 | callback.run(data["test"]["inputs"], data["test"]["index"]) 453 | ``` 454 | 455 | 456 | ## Contributing 457 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. 458 | 459 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". 460 | 461 | 1. Fork the Project 462 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) 463 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`) 464 | 4. Push to the Branch (`git push origin feature/AmazingFeature`) 465 | 5. Open a Pull Request 466 | 467 | And please don't forget to leave a ⭐ if you like the work! 468 | 469 | ## License 470 | Distributed under the Apache 2.0 License. See LICENSE for more information. 471 | 472 | ## Contact 473 | This library is developed and maintained by [Kern AI](https://github.com/code-kern-ai). If you want to provide us with feedback or have some questions, don't hesitate to contact us.
We're super happy to help ✌️ 474 | -------------------------------------------------------------------------------- /example_export.py: -------------------------------------------------------------------------------- 1 | from refinery import Client 2 | 3 | client = Client.from_secrets_file("secrets.json") 4 | 5 | print("Let's look into project details...") 6 | print(client.get_project_details()) 7 | 8 | print("-" * 10) 9 | print("And these are the first 10 records...") 10 | print(client.get_record_export().head(10)) 11 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* 3 | python3 setup.py bdist_wheel --universal 4 | twine upload dist/* -------------------------------------------------------------------------------- /refinery/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from uuid import uuid4 4 | from wasabi import msg 5 | import pandas as pd 6 | from refinery import authentication, api_calls, settings, exceptions, util 7 | from typing import List, Optional, Dict, Any 8 | import json 9 | import os.path 10 | from tqdm import tqdm 11 | import spacy 12 | import time 13 | from refinery import settings 14 | 15 | 16 | class Client: 17 | """Client object which can be used to directly address the Kern AI refinery API. 18 | 19 | Args: 20 | user_name (str): Your username (email) for the application. 21 | password (str): The respective password. Do not share this! 22 | project_id (str): The link to your project. This can be found in the URL in an active project. 23 | uri (str, optional): Link to the host of the application. Defaults to "https://app.kern.ai". 24 | 25 | Raises: 26 | exceptions.get_api_exception_class: If your credentials are incorrect, an exception is raised. 27 | """ 28 | 29 | def __init__( 30 | self, user_name: str, password: str, project_id: str, uri=settings.DEFAULT_URI 31 | ): 32 | msg.info(f"Connecting to {uri}") 33 | settings.set_base_uri(uri) 34 | self.session_token = authentication.create_session_token( 35 | user_name=user_name, password=password 36 | ) 37 | if self.session_token is not None: 38 | msg.good("Logged in to system.") 39 | else: 40 | msg.fail(f"Could not log in at {uri}. Please check username and password.") 41 | raise exceptions.get_api_exception_class(401) 42 | self.project_id = project_id 43 | 44 | self.get_project_details() 45 | 46 | @classmethod 47 | def from_secrets_file(cls, path_to_file: str, project_id: Optional[str] = None): 48 | """Creates a Client object from a secrets file. 49 | 50 | Args: 51 | path_to_file (str): Path to the secrets file. 52 | project_id (Optional[str], optional): The link to your project. This can be found in the URL in an active project. Defaults to None. In that case, it will read the project id from the file 53 | 54 | Returns: 55 | refinery.Client: Client object. 
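        Example (a minimal usage sketch mirroring the README; assumes a secrets.json file in your working directory):
            client = Client.from_secrets_file("secrets.json")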
56 | """ 57 | with open(path_to_file, "r") as file: 58 | content = json.load(file) 59 | 60 | uri = content.get("uri") 61 | if uri is None: 62 | uri = settings.DEFAULT_URI 63 | 64 | if project_id is None: 65 | project_id = content["project_id"] 66 | 67 | return cls( 68 | user_name=content["user_name"], 69 | password=content["password"], 70 | project_id=project_id, 71 | uri=uri, 72 | ) 73 | 74 | def get_project_details(self) -> Dict[str, str]: 75 | """Collect high-level information about your project: name, description, and tokenizer 76 | 77 | Returns: 78 | Dict[str, str]: dictionary containing the above information 79 | """ 80 | url = settings.get_project_url(self.project_id) 81 | api_response = api_calls.get_request( 82 | url, 83 | self.session_token, 84 | self.project_id, 85 | ) 86 | return api_response 87 | 88 | def get_primary_keys(self) -> List[str]: 89 | """Fetches the primary keys of your current project. 90 | 91 | Returns: 92 | List[str]: Containing the primary keys of your project. 93 | """ 94 | project_details = self.get_project_details() 95 | project_attributes = project_details["attributes"] 96 | 97 | primary_keys = [] 98 | for attribute in project_attributes: 99 | if attribute["is_primary_key"]: 100 | primary_keys.append(attribute["name"]) 101 | return primary_keys 102 | 103 | def get_lookup_list(self, list_id: str) -> Dict[str, str]: 104 | """Fetches a lookup list of your current project. 105 | 106 | Args: 107 | list_id (str): The ID of the lookup list. 108 | 109 | Returns: 110 | Dict[str, str]: Containing the specified lookup list of your project. 111 | """ 112 | url = settings.get_lookup_list_url(self.project_id, list_id) 113 | api_response = api_calls.get_request( 114 | url, 115 | self.session_token, 116 | self.project_id, 117 | ) 118 | return api_response 119 | 120 | def get_lookup_lists(self) -> List[Dict[str, str]]: 121 | """Fetches all lookup lists of your current project 122 | 123 | Returns: 124 | List[Dict[str, str]]: Containing the lookups lists of your project. 125 | """ 126 | lookup_lists = [] 127 | for lookup_list_id in self.get_project_details()["knowledge_base_ids"]: 128 | lookup_list = self.get_lookup_list(lookup_list_id) 129 | lookup_lists.append(lookup_list) 130 | return lookup_lists 131 | 132 | def get_record_export( 133 | self, 134 | num_samples: Optional[int] = None, 135 | download_to: Optional[str] = None, 136 | tokenize: Optional[bool] = True, 137 | keep_attributes: Optional[List[str]] = None, 138 | dropna: Optional[bool] = False, 139 | ) -> pd.DataFrame: 140 | """Collects the export data of your project (i.e. the same data if you would export in the web app). 141 | 142 | Args: 143 | num_samples (Optional[int], optional): If set, only the first `num_samples` records are collected. Defaults to None. 144 | 145 | Returns: 146 | pd.DataFrame: DataFrame containing your record data. 
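        Example (usage sketch from the README; pass tokenize=False to skip the local spaCy tokenization):
            df = client.get_record_export(tokenize=False)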
147 | """ 148 | url = settings.get_export_url(self.project_id) 149 | api_response = api_calls.get_request( 150 | url, self.session_token, self.project_id, **{"num_samples": num_samples} 151 | ) 152 | df = pd.DataFrame(api_response) 153 | 154 | if tokenize: 155 | tokenize_attributes = [] 156 | for attribute in self.get_project_details()["attributes"]: 157 | if attribute["data_type"] == "TEXT": 158 | tokenize_attributes.append(attribute["name"]) 159 | 160 | if len(tokenize_attributes) > 0: 161 | tokenizer_package = self.get_project_details()["tokenizer"] 162 | if not spacy.util.is_package(tokenizer_package): 163 | spacy.cli.download(tokenizer_package) 164 | 165 | nlp = spacy.load(tokenizer_package) 166 | 167 | msg.info(f"Tokenizing data with spaCy '{tokenizer_package}'.") 168 | msg.info( 169 | "This will be provided from the server in future versions of refinery." 170 | ) 171 | 172 | tqdm.pandas(desc="Applying tokenization locally") 173 | for attribute in tokenize_attributes: 174 | df[f"{attribute}__tokenized"] = df[attribute].progress_apply( 175 | lambda x: nlp(x) 176 | ) 177 | 178 | else: 179 | msg.warn( 180 | "There are no attributes that can be tokenized in this project." 181 | ) 182 | 183 | if keep_attributes is not None: 184 | df = df[keep_attributes] 185 | 186 | if dropna: 187 | df = df.dropna() 188 | 189 | if download_to is not None: 190 | df.to_json(download_to, orient="records") 191 | msg.good(f"Downloaded export to {download_to}") 192 | return df 193 | 194 | def post_associations( 195 | self, 196 | associations, 197 | indices, 198 | name, 199 | label_task_name, 200 | source_type: Optional[str] = "heuristic", 201 | ): 202 | """Posts associations to the server. 203 | 204 | Args: 205 | associations (List[Dict[str, str]]): List of associations to post. 206 | indices (List[str]): List of indices to post to. 207 | name (str): Name of the association set. 208 | label_task_name (str): Name of the label task. 209 | source_type (Optional[str], optional): Source type of the associations. Defaults to "heuristic". 210 | """ 211 | url = settings.get_associations_url(self.project_id) 212 | api_response = api_calls.post_request( 213 | url, 214 | { 215 | "associations": associations, 216 | "indices": indices, 217 | "name": name, 218 | "label_task_name": label_task_name, 219 | "source_type": source_type, 220 | }, 221 | self.session_token, 222 | self.project_id, 223 | ) 224 | return api_response 225 | 226 | def post_records(self, records: List[Dict[str, Any]]): 227 | """Posts records to the server. 228 | 229 | Args: 230 | records (List[Dict[str, str]]): List of records to post. 231 | """ 232 | request_uuid = str(uuid4()) 233 | url = settings.get_import_json_url(self.project_id) 234 | 235 | batch_responses = [] 236 | for records_batch in util.batch(records, settings.BATCH_SIZE_DEFAULT): 237 | api_response = api_calls.post_request( 238 | url, 239 | { 240 | "request_uuid": request_uuid, 241 | "records": records_batch, 242 | "is_last": False, 243 | }, 244 | self.session_token, 245 | self.project_id, 246 | ) 247 | batch_responses.append(api_response) 248 | time.sleep(0.5) # wait half a second to avoid server overload 249 | api_calls.post_request( 250 | url, 251 | {"request_uuid": request_uuid, "records": [], "is_last": True}, 252 | self.session_token, 253 | self.project_id, 254 | ) 255 | return batch_responses 256 | 257 | def post_df(self, df: pd.DataFrame): 258 | """Posts a DataFrame to the server. 259 | 260 | Args: 261 | df (pd.DataFrame): DataFrame to post. 
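        Example (a minimal sketch with a hypothetical record; the columns must fit the data schema of your project):
            df = pd.DataFrame([{"headline": "some text", "running_id": 1234}])
            client.post_df(df)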
262 | """ 263 | records = df.to_dict(orient="records") 264 | return self.post_records(records) 265 | 266 | def post_file_import( 267 | self, path: str, import_file_options: Optional[str] = "" 268 | ) -> bool: 269 | """Imports a file into your project. 270 | 271 | Args: 272 | path (str): Path to the file to import. 273 | import_file_options (Optional[str], optional): Options for the Pandas import. Defaults to None. 274 | 275 | Raises: 276 | FileImportError: If the file could not be imported, an exception is raised. 277 | 278 | Returns: 279 | bool: True if the file was imported successfully, False otherwise. 280 | """ 281 | if not os.path.exists(path): 282 | raise exceptions.FileImportError( 283 | f"Given filepath is not valid. Path: {path}" 284 | ) 285 | last_path_part = path.split("/")[-1] 286 | file_name = f"{last_path_part}_SCALE" 287 | 288 | FILE_TYPE = "records" 289 | # config 290 | config_url = settings.get_full_config(self.project_id) 291 | config_api_response = api_calls.get_request( 292 | config_url, 293 | self.session_token, 294 | self.project_id, 295 | ) 296 | endpoint = config_api_response.get("KERN_S3_ENDPOINT") 297 | 298 | # credentials 299 | credentials_url = settings.get_import_file_url(self.project_id) 300 | credentials_api_response = api_calls.post_request( 301 | credentials_url, 302 | { 303 | "file_name": file_name, 304 | "file_type": FILE_TYPE, 305 | "import_file_options": import_file_options, 306 | }, 307 | self.session_token, 308 | self.project_id, 309 | ) 310 | credentials = credentials_api_response["Credentials"] 311 | access_key = credentials["AccessKeyId"] 312 | secret_key = credentials["SecretAccessKey"] 313 | session_token = credentials["SessionToken"] 314 | upload_task_id = credentials_api_response["uploadTaskId"] 315 | bucket = credentials_api_response["bucket"] 316 | success = util.s3_upload( 317 | access_key, 318 | secret_key, 319 | session_token, 320 | bucket, 321 | endpoint, 322 | upload_task_id, 323 | path, 324 | file_name, 325 | ) 326 | if success: 327 | msg.good(f"Uploaded {path} to object storage.") 328 | upload_task_id = ( 329 | upload_task_id.split("/")[-1] 330 | if "/" in upload_task_id 331 | else upload_task_id 332 | ) 333 | self.__monitor_task(upload_task_id) 334 | 335 | else: 336 | msg_text = f"Could not upload {path} to your project." 337 | msg.fail(msg_text) 338 | raise exceptions.FileImportError(msg_text) 339 | 340 | def __monitor_task(self, upload_task_id: str) -> None: 341 | do_monitoring = True 342 | idx = 0 343 | last_progress = 0.0 344 | print_success_message = False 345 | with tqdm( 346 | total=100.00, 347 | colour="green", 348 | bar_format="{desc}: {percentage:.2f}%|{bar:10}| {n:.2f}/{total_fmt}", 349 | ) as pbar: 350 | pbar.set_description_str(desc="PENDING", refresh=True) 351 | while do_monitoring: 352 | idx += 1 353 | task = self.__get_task(upload_task_id) 354 | task_progress = task.get("progress") if task.get("progress") else 0.0 355 | task_state = task.get("state") if task.get("state") else "FAILED" 356 | progress = task_progress - last_progress 357 | last_progress = task_progress 358 | pbar.update(progress) 359 | pbar.set_description_str(desc=task_state, refresh=True) 360 | if task_state == "DONE" or task_state == "FAILED": 361 | print_success_message = task_state == "DONE" 362 | do_monitoring = False 363 | if idx >= 100: 364 | raise exceptions.FileImportError( 365 | "Timeout while upload, please check the upload progress in the UI." 
366 | ) 367 | time.sleep(0.5) 368 | if print_success_message: 369 | msg.good("File upload successful.") 370 | else: 371 | msg.fail( 372 | "Upload failed. Please look into the UI notification center for more details." 373 | ) 374 | 375 | def __get_task(self, upload_task_id: str) -> Dict[str, Any]: 376 | api_response = api_calls.get_request( 377 | settings.get_task(self.project_id, upload_task_id), 378 | self.session_token, 379 | self.project_id, 380 | ) 381 | return api_response 382 | -------------------------------------------------------------------------------- /refinery/adapter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-kern-ai/refinery-python-sdk/088add69f2cb365142f9386ab1dd2919d921131b/refinery/adapter/__init__.py -------------------------------------------------------------------------------- /refinery/adapter/rasa.py: -------------------------------------------------------------------------------- 1 | import os 2 | from wasabi import msg 3 | from typing import Any, List, Optional 4 | import pandas as pd 5 | import yaml 6 | from refinery import Client, exceptions 7 | from collections import OrderedDict 8 | 9 | # https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data 10 | class literal(str): 11 | pass 12 | 13 | 14 | def literal_presenter(dumper, data): 15 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") 16 | 17 | 18 | yaml.add_representer(literal, literal_presenter) 19 | 20 | 21 | def ordered_dict_presenter(dumper, data): 22 | return dumper.represent_dict(data.items()) 23 | 24 | 25 | yaml.add_representer(OrderedDict, ordered_dict_presenter) 26 | 27 | CONSTANT_OUTSIDE = "OUTSIDE" 28 | CONSTANT_LABEL_BEGIN = "B-" 29 | CONSTANT_LABEL_INTERMEDIATE = "I-" 30 | 31 | 32 | def build_literal_from_iterable(iterable: List[Any]) -> str: 33 | """Builds a Rasa-conform yaml string from an iterable. 34 | 35 | Args: 36 | iterable (List[Any]): List with values to be converted to a literal block. 37 | 38 | Returns: 39 | str: literal block 40 | """ 41 | return "\n".join([f"- {value}" for value in iterable]) + "\n" 42 | 43 | 44 | def inject_label_in_text( 45 | row: pd.Series, text_name: str, tokenized_label_task: str, constant_outside: str 46 | ) -> str: 47 | """Insert token labels into text. 48 | E.g. 
"Hello, my name is Johannes Hötter" -> "Hello, my name is [Johannes Hötter](person)" 49 | 50 | Args: 51 | row (pd.Series): row of the record export dataframe 52 | text_name (str): name of the text/chat field 53 | tokenized_label_task (str): name of the label task containing token-level labels 54 | constant_outside (str): constant to be used for outside labels 55 | 56 | Returns: 57 | str: injected text 58 | """ 59 | string = "" 60 | token_list = row[f"{text_name}__tokenized"] 61 | 62 | close_multitoken_label = False 63 | multitoken_label = False 64 | for idx, token in enumerate(token_list): 65 | 66 | if idx < len(token_list) - 1: 67 | token_next = token_list[idx + 1] 68 | label_next = row[tokenized_label_task][idx + 1] 69 | if label_next.startswith(CONSTANT_LABEL_INTERMEDIATE): 70 | multitoken_label = True 71 | else: 72 | if multitoken_label: 73 | close_multitoken_label = True 74 | multitoken_label = False 75 | num_whitespaces = token_next.idx - (token.idx + len(token)) 76 | else: 77 | num_whitespaces = 0 78 | whitespaces = " " * num_whitespaces 79 | 80 | label = row[tokenized_label_task][idx] 81 | if label != constant_outside: 82 | if multitoken_label: 83 | if label.startswith(CONSTANT_LABEL_BEGIN): 84 | string += f"[{token.text}{whitespaces}" 85 | else: 86 | string += f"{token.text}{whitespaces}" 87 | else: 88 | label_trimmed = label[2:] # remove B- and I- 89 | if close_multitoken_label: 90 | string += f"{token.text}]({label_trimmed}){whitespaces}" 91 | close_multitoken_label = False 92 | else: 93 | string += f"[{token.text}]({label_trimmed}){whitespaces}" 94 | else: 95 | string += f"{token.text}{whitespaces}" 96 | return string 97 | 98 | 99 | def build_intent_yaml( 100 | client: Client, 101 | text_name: str, 102 | intent_label_task: str, 103 | metadata_label_task: Optional[str] = None, 104 | tokenized_label_task: Optional[str] = None, 105 | dir_name: str = "data", 106 | file_name: str = "nlu.yml", 107 | constant_outside: str = CONSTANT_OUTSIDE, 108 | version: str = "3.1", 109 | ) -> None: 110 | """builds a Rasa NLU yaml file from your project data via the client object. 111 | 112 | Args: 113 | client (Client): connected Client object for your project 114 | text_name (str): name of the text/chat field 115 | intent_label_task (str): name of the classification label with the intents 116 | metadata_label_task (Optional[str], optional): if you have a metadata task (e.g. sentiment), you can list it here. Currently, only one is possible to provide. Defaults to None. 117 | tokenized_label_task (Optional[str], optional): if you have a token-level task (e.g. for entities), you can list it here. Currently, only one is possible to provide. Defaults to None. 118 | dir_name (str, optional): name of your rasa data directory. Defaults to "data". 119 | file_name (str, optional): name of the file you want to store the data to. Defaults to "nlu.yml". 120 | constant_outside (str, optional): constant to be used for outside labels in token-level tasks. Defaults to CONSTANT_OUTSIDE. 121 | version (str, optional): Rasa version. Defaults to "3.1". 122 | 123 | Raises: 124 | exceptions.UnknownItemError: if the item you are looking for is not found. 
125 | """ 126 | msg.info("Building training data for Rasa") 127 | msg.warn("If you haven't done so yet, please install rasa and run `rasa init`") 128 | df = client.get_record_export(tokenize=(tokenized_label_task is not None)) 129 | 130 | for attribute in [text_name, intent_label_task, metadata_label_task, tokenized_label_task]: 131 | if attribute is not None and attribute not in df.columns: 132 | raise exceptions.UnknownItemError(f"Can't find argument '{attribute}' in the existing export schema: {df.columns.tolist()}") 133 | 134 | if tokenized_label_task is not None: 135 | text_name_injected = f"{text_name}__injected" 136 | df[text_name_injected] = df.apply( 137 | lambda x: inject_label_in_text( 138 | x, text_name, tokenized_label_task, constant_outside 139 | ), 140 | axis=1, 141 | ) 142 | text_name = text_name_injected 143 | 144 | nlu_list = [] 145 | for label, df_sub_label in df.groupby(intent_label_task): 146 | 147 | if metadata_label_task is not None: 148 | metadata_label_name = metadata_label_task.split("__")[1] 149 | for metadata_label, df_sub_label_sub_metadata_label in df_sub_label.groupby( 150 | metadata_label_task 151 | ): 152 | literal_string = build_literal_from_iterable( 153 | df_sub_label_sub_metadata_label[text_name].tolist() 154 | ) 155 | nlu_list.append( 156 | OrderedDict( 157 | intent=label, 158 | metadata=OrderedDict(**{metadata_label_name: metadata_label}), 159 | examples=literal(literal_string), 160 | ) 161 | ) 162 | else: 163 | literal_string = build_literal_from_iterable( 164 | df_sub_label[text_name].tolist() 165 | ) 166 | nlu_list.append(OrderedDict(intent=label, examples=literal(literal_string))) 167 | 168 | if tokenized_label_task is not None: 169 | 170 | def flatten(xss): 171 | return [x for xs in xss for x in xs] 172 | 173 | labels = set(flatten(df[tokenized_label_task].tolist())) 174 | lookup_list_names = [] 175 | for label in labels: 176 | if label.startswith(CONSTANT_LABEL_BEGIN): 177 | label_trimmed = label[2:] # remove B- 178 | lookup_list_names.append(label_trimmed) 179 | 180 | for lookup_list in client.get_lookup_lists(): 181 | if lookup_list["name"] in lookup_list_names: 182 | values = [entry["value"] for entry in lookup_list["terms"]] 183 | literal_string = build_literal_from_iterable(values) 184 | nlu_list.append( 185 | OrderedDict( 186 | lookup=lookup_list["name"], examples=literal(literal_string) 187 | ) 188 | ) 189 | 190 | nlu_dict = OrderedDict(version=version, nlu=nlu_list) 191 | 192 | if dir_name is not None and not os.path.isdir(dir_name): 193 | os.mkdir(dir_name) 194 | 195 | file_path = os.path.join(dir_name, file_name) 196 | 197 | with open(file_path, "w") as f: 198 | yaml.dump(nlu_dict, f, allow_unicode=True) 199 | msg.good(f"Saved training data to {file_path}! 🚀") 200 | msg.warn( 201 | f"Please make sure to add the project-specific files domain.yml, {os.path.join(dir_name, 'rules.yml')} and {os.path.join(dir_name, 'stories.yml')}." 202 | ) 203 | msg.info("More information about these files can be found here:") 204 | msg.info(" - Domain: https://rasa.com/docs/rasa/domain") 205 | msg.info(" - Rules: https://rasa.com/docs/rasa/rules") 206 | msg.info(" - Stories: https://rasa.com/docs/rasa/stories") 207 | msg.good( 208 | "You're all set, and can now start building your conversational AI via `rasa train`! 
🎉" 209 | ) 210 | -------------------------------------------------------------------------------- /refinery/adapter/sklearn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from embedders.classification.contextual import TransformerSentenceEmbedder 3 | from refinery import Client 4 | from refinery.adapter.util import split_train_test_on_weak_supervision 5 | 6 | 7 | def build_classification_dataset( 8 | client: Client, 9 | sentence_input: str, 10 | classification_label: str, 11 | config_string: Optional[str] = None, 12 | num_train: Optional[int] = None, 13 | ) -> Dict[str, Dict[str, Any]]: 14 | """ 15 | Builds a classification dataset from a refinery client and a config string. 16 | 17 | Args: 18 | client (Client): Refinery client 19 | sentence_input (str): Name of the column containing the sentence input. 20 | classification_label (str): Name of the label; if this is a task on the full record, enter the string with as "__