├── refinery ├── adapter │ ├── __init__.py │ ├── transformers.py │ ├── util.py │ ├── sklearn.py │ ├── torch.py │ └── rasa.py ├── callbacks │ ├── __init__.py │ ├── sklearn.py │ ├── transformers.py │ ├── torch.py │ └── inference.py ├── authentication.py ├── util.py ├── cli.py ├── settings.py ├── api_calls.py ├── exceptions.py └── __init__.py ├── publish.sh ├── requirements.txt ├── example_export.py ├── .github └── dependabot.yml ├── setup.py ├── .gitignore ├── LICENSE └── README.md /refinery/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /refinery/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* 3 | python3 setup.py bdist_wheel --universal 4 | twine upload dist/* -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | requests 4 | boto3 5 | botocore 6 | spacy 7 | wasabi 8 | embedders 9 | datasets -------------------------------------------------------------------------------- /example_export.py: -------------------------------------------------------------------------------- 1 | from refinery import Client 2 | 3 | client = Client.from_secrets_file("secrets.json") 4 | 5 | print("Let's look into project details...") 6 | print(client.get_project_details()) 7 | 8 | print("-" * 10) 9 | print("And these are the first 10 records...") 10 | print(client.get_record_export().head(10)) 11 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem 4 | - package-ecosystem: "pip" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | # default is / which breaks drone 9 | pull-request-branch-name: 10 | separator: "-" 11 | # not created automatically for version updates so only security ones are created 12 | # https://docs.github.com/en/code-security/dependabot/dependabot-security-updates/configuring-dependabot-security-updates#overriding-the-default-behavior-with-a-configuration-file 13 | open-pull-requests-limit: 0 14 | 15 | -------------------------------------------------------------------------------- /refinery/authentication.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from refinery import settings 3 | import requests 4 | 5 | 6 | def create_session_token(user_name: str, password: str) -> str: 7 | headers = {"Accept": "application/json"} 8 | action_url = ( 9 | requests.get(settings.get_authentication_url(), headers=headers) 10 | .json() 11 | .get("ui") 12 | .get("action") 13 | ) 14 | session_token = ( 15 | requests.post( 16 | action_url, 17 | headers=headers, 18 | json={ 19 | "method": "password", 20 | "password": password, 21 | "password_identifier": user_name, 22 | }, 23 | ) 24 | .json() 25 | .get("session_token") 26 | ) 27 | return session_token 28 | -------------------------------------------------------------------------------- /refinery/util.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.client import Config 3 | from typing import List, Dict, Any 4 | 5 | 6 | def s3_upload( 7 | access_key: str, 8 | secret_key: str, 9 | aws_session_token: str, 10 | target_bucket: str, 11 | url: str, 12 | upload_task_id: str, 13 | file_path: str, 14 | file_name: str, 15 | ) -> bool: 16 | """ 17 | Connects to the object storage with temporary credentials generated for the 18 | given user_id, project_id and bucket 19 | """ 20 | s3 = boto3.resource( 21 | "s3", 22 | endpoint_url=url, 23 | aws_access_key_id=access_key, 24 | aws_secret_access_key=secret_key, 25 | aws_session_token=aws_session_token, 26 | config=Config(signature_version="s3v4"), 27 | region_name="us-east-1", 28 | ) 29 | s3_object = s3.Object(target_bucket, f"{upload_task_id}/{file_name}") 30 | with open(file_path, "rb") as file: 31 | s3_object.put(Body=file) 32 | return True 33 | 34 | 35 | def batch(records: List[Dict[str, Any]], batch_size: int): 36 | """Batches records into batches of size `batch_size`. 37 | 38 | Args: 39 | records (List[Dict[str, Any]]): List of records to batch. 40 | batch_size (int): Size of the batches. 41 | 42 | Yields: 43 | List[Dict[str, Any]]: Batches of records. 44 | """ 45 | for i in range(0, len(records), batch_size): 46 | yield records[i : i + batch_size] 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | 5 | from setuptools import setup, find_packages 6 | 7 | this_directory = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(this_directory, "README.md")) as file: 9 | long_description = file.read() 10 | 11 | setup( 12 | name="refinery-python-sdk", 13 | version="1.4.0", 14 | author="jhoetter", 15 | author_email="johannes.hoetter@kern.ai", 16 | description="Official Python SDK for Kern AI refinery.", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/code-kern-ai/refinery-python", 20 | keywords=[ 21 | "Kern AI", 22 | "refinery", 23 | "machine-learning", 24 | "supervised-learning", 25 | "data-centric-ai", 26 | "data-annotation", 27 | "python", 28 | ], 29 | classifiers=[ 30 | "Development Status :: 4 - Beta", 31 | "Programming Language :: Python :: 3", 32 | "License :: OSI Approved :: Apache Software License", 33 | ], 34 | package_dir={"": "."}, 35 | packages=find_packages("."), 36 | install_requires=[ 37 | "numpy", 38 | "pandas", 39 | "requests", 40 | "boto3", 41 | "botocore", 42 | "spacy", 43 | "wasabi", 44 | "embedders", 45 | "datasets", 46 | ], 47 | entry_points={ 48 | "console_scripts": [ 49 | "rsdk=refinery.cli:main", 50 | ], 51 | }, 52 | ) 53 | -------------------------------------------------------------------------------- /refinery/cli.py: -------------------------------------------------------------------------------- 1 | from refinery import Client 2 | import sys 3 | from wasabi import msg 4 | 5 | 6 | def pull(): 7 | client = Client.from_secrets_file("secrets.json") 8 | project_name = client.get_project_details()["name"] 9 | download_to = f"{project_name}.json" 10 | client.get_record_export(download_to=download_to) 11 | 12 | 13 | def push(file_path): 14 | client = Client.from_secrets_file("secrets.json") 15 | client.post_file_import(file_path) 16 | 17 | 18 | def help(): 19 | msg.info( 20 | "With the refinery SDK, you can type commands as `rsdk `. Currently, we provide the following:" 21 | ) 22 | msg.info( 23 | "- rsdk pull: Download the record export of the project defined in `settings.json` to your local storage." 24 | ) 25 | msg.info( 26 | "- rsdk push : Upload a record file to the project defined in `settings.json` from your local storage." 27 | ) 28 | 29 | 30 | def main(): 31 | cli_args = sys.argv[1:] 32 | if len(cli_args) == 0: 33 | msg.fail( 34 | "Please provide some arguments when running the `rsdk` command. Type `rsdk help` for some instructions." 35 | ) 36 | else: 37 | command = cli_args[0] 38 | if command == "pull": 39 | pull() 40 | elif command == "push": 41 | if len(cli_args) != 2: 42 | msg.fail("Please provide a path to a file when running rsdk push.") 43 | else: 44 | file_path = cli_args[1] 45 | push(file_path) 46 | elif command == "help": 47 | help() 48 | else: 49 | msg.fail( 50 | f"Could not understand command `{command}`. Type `rsdk help` for some instructions." 51 | ) 52 | -------------------------------------------------------------------------------- /refinery/callbacks/sklearn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict 2 | from refinery import Client 3 | from refinery.callbacks.inference import ModelCallback 4 | from sklearn.base import BaseEstimator 5 | 6 | 7 | def initialize_fn(inputs, labels, **kwargs): 8 | return {"clf": kwargs["clf"]} 9 | 10 | 11 | def postprocessing_fn(outputs, **kwargs): 12 | named_outputs = [] 13 | for prediction in outputs: 14 | pred_index = prediction.argmax() 15 | label = kwargs["clf"].classes_[pred_index] 16 | confidence = prediction[pred_index] 17 | named_outputs.append([label, confidence]) 18 | return named_outputs 19 | 20 | 21 | class SklearnCallback(ModelCallback): 22 | def __init__( 23 | self, 24 | client: Client, 25 | sklearn_model: BaseEstimator, 26 | labeling_task_name: str, 27 | ) -> None: 28 | """Callback for sklearn models. 29 | 30 | Args: 31 | client (Client): Refinery client 32 | sklearn_model (BaseEstimator): Sklearn model 33 | labeling_task_name (str): Name of the labeling task 34 | """ 35 | 36 | super().__init__( 37 | client, 38 | sklearn_model.__class__.__name__, 39 | labeling_task_name, 40 | inference_fn=sklearn_model.predict_proba, 41 | initialize_fn=initialize_fn, 42 | postprocessing_fn=postprocessing_fn, 43 | ) 44 | self.sklearn_model = sklearn_model 45 | self.initialized = False 46 | self.kwargs = {"clf": self.sklearn_model} 47 | 48 | def run(self, inputs: List[Any], indices: List[Dict[str, Any]]) -> None: 49 | if not self.initialized: 50 | self.initialize(None, None) 51 | self.initialized = True 52 | super().run(inputs, indices) 53 | -------------------------------------------------------------------------------- /refinery/callbacks/transformers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict 2 | from refinery import Client 3 | from refinery.callbacks.inference import ModelCallback 4 | from transformers import pipeline 5 | 6 | 7 | def initialize_fn(inputs, labels, **kwargs): 8 | return {"mapping": kwargs["mapping"]} 9 | 10 | 11 | def postprocessing_fn(outputs, **kwargs): 12 | named_outputs = [] 13 | for prediction in outputs: 14 | label = kwargs["mapping"][prediction["label"]] 15 | confidence = prediction["score"] 16 | named_output = [label, confidence] 17 | named_outputs.append(named_output) 18 | return named_outputs 19 | 20 | 21 | class TransformerCallback(ModelCallback): 22 | def __init__( 23 | self, 24 | client: Client, 25 | transformer_model: pipeline, 26 | labeling_task_name: str, 27 | mapping: Dict[str, str], 28 | ) -> None: 29 | """Callback for sklearn models. 30 | 31 | Args: 32 | client (Client): Refinery client 33 | sklearn_model (BaseEstimator): Sklearn model 34 | labeling_task_name (str): Name of the labeling task 35 | """ 36 | 37 | super().__init__( 38 | client, 39 | transformer_model.__class__.__name__, 40 | labeling_task_name, 41 | inference_fn=transformer_model.__call__, 42 | initialize_fn=initialize_fn, 43 | postprocessing_fn=postprocessing_fn, 44 | ) 45 | self.sklearn_model = transformer_model 46 | self.initialized = False 47 | self.kwargs = {"mapping": mapping} 48 | 49 | def run(self, inputs: List[Any], indices: List[Dict[str, Any]]) -> None: 50 | if not self.initialized: 51 | self.initialize(None, None) 52 | self.initialized = True 53 | super().run(inputs, indices) 54 | -------------------------------------------------------------------------------- /refinery/settings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | BASE_URI: str 3 | DEFAULT_URI: str = "https://app.kern.ai" 4 | 5 | BATCH_SIZE_DEFAULT: int = 1000 6 | 7 | 8 | def set_base_uri(uri: str): 9 | global BASE_URI 10 | BASE_URI = uri 11 | 12 | 13 | def add_query_params(url: str, **kwargs) -> str: 14 | set_question_mark = False 15 | for key, value in kwargs.items(): 16 | if value is not None: 17 | if not set_question_mark: 18 | url = f"{url}?{key}={value}" 19 | set_question_mark = True 20 | else: 21 | url = f"{url}&{key}={value}" 22 | return url 23 | 24 | 25 | def get_authentication_url() -> str: 26 | return f"{BASE_URI}/.ory/kratos/public/self-service/login/api" 27 | 28 | 29 | def get_project_url(project_id: str) -> str: 30 | return f"{BASE_URI}/api/project/{project_id}" 31 | 32 | 33 | def get_lookup_list_url(project_id: str, lookup_list_id: str) -> str: 34 | return f"{get_project_url(project_id)}/lookup_list/{lookup_list_id}" 35 | 36 | 37 | def get_records_url(project_id: str) -> str: 38 | return f"{get_project_url(project_id)}/records" 39 | 40 | 41 | def get_export_url(project_id: str) -> str: 42 | return f"{get_project_url(project_id)}/export" 43 | 44 | 45 | def get_import_file_url(project_id: str) -> str: 46 | return f"{get_project_url(project_id)}/import_file" 47 | 48 | 49 | def get_import_json_url(project_id: str) -> str: 50 | return f"{get_project_url(project_id)}/import_json" 51 | 52 | 53 | def get_associations_url(project_id: str) -> str: 54 | return f"{get_project_url(project_id)}/associations" 55 | 56 | 57 | def get_full_config(project_id: str) -> str: 58 | return f"{get_project_url(project_id)}/import/full_config" 59 | 60 | 61 | def get_task(project_id: str, task_id: str) -> str: 62 | return f"{get_project_url(project_id)}/import/task/{task_id}" 63 | -------------------------------------------------------------------------------- /refinery/callbacks/torch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict 2 | from refinery import Client 3 | from refinery.callbacks.inference import ModelCallback 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from sklearn import preprocessing 7 | 8 | 9 | def initialize_fn(inputs, labels, **kwargs): 10 | return {"encoder": kwargs["encoder"]} 11 | 12 | 13 | def postprocessing_fn(outputs, **kwargs): 14 | named_outputs = [] 15 | pred_argindices = outputs.argmax(axis=1) 16 | for predindex, pred_argindex in enumerate(pred_argindices): 17 | label = kwargs["encoder"].classes_[pred_argindex] 18 | confidence = outputs[predindex][pred_argindex].tolist() 19 | named_outputs.append([label, confidence]) 20 | return named_outputs 21 | 22 | 23 | class TorchCallback(ModelCallback): 24 | def __init__( 25 | self, 26 | client: Client, 27 | torch_model: nn.Module, 28 | labeling_task_name: str, 29 | encoder: preprocessing.LabelEncoder, 30 | ) -> None: 31 | """Callback for sklearn models. 32 | 33 | Args: 34 | client (Client): Refinery client 35 | sklearn_model (BaseEstimator): Sklearn model 36 | labeling_task_name (str): Name of the labeling task 37 | """ 38 | 39 | super().__init__( 40 | client, 41 | torch_model.__class__.__name__, 42 | labeling_task_name, 43 | inference_fn=torch_model.forward, 44 | initialize_fn=initialize_fn, 45 | postprocessing_fn=postprocessing_fn, 46 | ) 47 | self.torch_model = torch_model 48 | self.initialized = False 49 | self.kwargs = {"encoder": encoder} 50 | 51 | def run(self, loader: DataLoader, indices: List[Dict[str, Any]]) -> None: 52 | if not self.initialized: 53 | self.initialize(None, None) 54 | self.initialized = True 55 | super().run(loader.dataset.X, indices) 56 | -------------------------------------------------------------------------------- /refinery/api_calls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | from json.decoder import JSONDecodeError 4 | import pkg_resources 5 | from refinery import exceptions 6 | import requests 7 | from typing import Any, Dict 8 | 9 | try: 10 | version = pkg_resources.get_distribution("refinery-python").version 11 | except pkg_resources.DistributionNotFound: 12 | version = "noversion" 13 | 14 | 15 | def post_request( 16 | url: str, body: Dict[str, Any], session_token: str, project_id: str 17 | ) -> str: 18 | headers = _build_headers(session_token) 19 | response = requests.post(url=url, json=body, headers=headers) 20 | return _handle_response(response, project_id) 21 | 22 | 23 | def get_request(url: str, session_token: str, project_id: str, **query_params) -> str: 24 | headers = _build_headers(session_token) 25 | response = requests.get(url=url, headers=headers, params=query_params) 26 | return _handle_response(response, project_id) 27 | 28 | 29 | def _build_headers(session_token: str) -> Dict[str, str]: 30 | return { 31 | "content-type": "application/json", 32 | "user-agent": f"python-sdk-{version}", 33 | "authorization": f"Bearer {session_token}", 34 | "identifier": session_token, 35 | } 36 | 37 | 38 | def _handle_response(response: requests.Response, project_id: str) -> str: 39 | status_code = response.status_code 40 | if status_code == 200: 41 | json_data = response.json() 42 | if type(json_data) == str: 43 | json_data = json.loads(json_data) 44 | return json_data 45 | else: 46 | try: 47 | json_data = response.json() 48 | error_code = json_data.get("error_code") 49 | error_message = json_data.get("error_message") 50 | except JSONDecodeError: 51 | error_code = 500 52 | error_message = "The server was unable to process the provided data." 53 | 54 | exception = exceptions.get_api_exception_class( 55 | status_code=status_code, 56 | error_code=error_code, 57 | error_message=error_message, 58 | project_id=project_id, 59 | ) 60 | raise exception 61 | -------------------------------------------------------------------------------- /refinery/adapter/transformers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from refinery import Client 4 | from refinery.adapter.util import split_train_test_on_weak_supervision 5 | from datasets import load_dataset 6 | 7 | 8 | def build_classification_dataset( 9 | client: Client, 10 | sentence_input: str, 11 | classification_label: str, 12 | num_train: Optional[int] = 100, 13 | ): 14 | """Build a classification dataset from a refinery client and a config string useable for HuggingFace finetuning. 15 | 16 | Args: 17 | client (Client): Refinery client 18 | sentence_input (str): Name of the column containing the sentence input. 19 | classification_label (str): Name of the label; if this is a task on the full record, enter the string with as "__