├── tests └── __init__.py ├── LICENSE_HEADER ├── src └── dataset_viber │ ├── __init__.py │ ├── _gradio │ ├── __init__.py │ ├── _mixins │ │ ├── _argilla.py │ │ ├── _task_config.py │ │ └── _import_export.py │ ├── _flagging.py │ ├── annotator.py │ └── collector.py │ ├── examples │ ├── fn_next_input_synthesizer.py │ ├── fn_next_input_synthesizer_distilabel.py │ ├── log_to_csv.py │ ├── interactive_components.py │ ├── fn_model.py │ ├── log_to_hub.py │ ├── fn_next_input_chat_preference.py │ └── fn_next_input_image_preference.py │ ├── _utils.py │ ├── _constants.py │ ├── synthesizer.py │ └── bulk.py ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── .pre-commit-config.yaml ├── .gitignore ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LICENSE_HEADER: -------------------------------------------------------------------------------- 1 | Copyright 2024-present, David Berenstein, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /src/dataset_viber/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataset_viber._gradio import CollectorInterface # noqa 16 | from dataset_viber._gradio import AnnotatorInterFace # noqa 17 | -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
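# Usage sketch (illustrative; mirrors the scripts under `examples/`, so the
# inputs below are sample values, not fixed API requirements):
#
#   from dataset_viber import AnnotatorInterFace
#
#   AnnotatorInterFace.for_text_classification(
#       texts=["Anthony Bourdain was an amazing chef!"],
#       labels=["positive", "negative"],
#   ).launch()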
14 | 15 | from dataset_viber._gradio.collector import CollectorInterface # noqa 16 | from dataset_viber._gradio.annotator import AnnotatorInterFace # noqa 17 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/fn_next_input_synthesizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataset_viber import AnnotatorInterFace 16 | from dataset_viber.synthesizer import Synthesizer 17 | 18 | synthesizer = Synthesizer.for_chat_generation( 19 | prompt_context="A phone company customer support expert" 20 | ) 21 | 22 | interface = AnnotatorInterFace.for_chat_generation(fn_next_input=synthesizer) 23 | interface.launch() 24 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/fn_next_input_synthesizer_distilabel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataset_viber import AnnotatorInterFace 16 | from dataset_viber.synthesizer import Synthesizer 17 | 18 | synthesizer = Synthesizer.for_text_generation( 19 | prompt_context="An expert in the field of AI" 20 | ) 21 | 22 | interface = AnnotatorInterFace.for_text_generation(fn_next_input=synthesizer) 23 | interface.launch() 24 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/log_to_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
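# With csv_logger=True the interface logs through Gradio's CSVLogger, which
# in Gradio 4.x appends the annotations to ./flagged/log.csv by default; the
# `flagged` entry at the bottom of this repo's .gitignore accounts for that.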
14 | 15 | from dataset_viber import AnnotatorInterFace 16 | 17 | texts = [ 18 | "Anthony Bourdain was an amazing chef!", 19 | "Anthony Bourdain was a terrible tv persona!", 20 | ] 21 | labels = ["positive", "negative"] 22 | 23 | interface = AnnotatorInterFace.for_text_classification( 24 | texts=texts, 25 | labels=labels, 26 | csv_logger=True, # log each annotation to a local CSV file 27 | ) 28 | interface.launch() 29 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/interactive_components.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataset_viber import AnnotatorInterFace 16 | 17 | texts = [ 18 | "Anthony Bourdain was an amazing chef!", 19 | "Anthony Bourdain was a terrible tv persona!", 20 | ] 21 | labels = ["positive", "negative"] 22 | 23 | interface = AnnotatorInterFace.for_text_classification( 24 | texts=texts, 25 | labels=labels, 26 | interactive=[False, True], # keep the input text fixed; only the label output stays editable 27 | ) 28 | interface.launch() 29 | -------------------------------------------------------------------------------- /src/dataset_viber/_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | 17 | 18 | def _get_init_arg_names(cls) -> list[str]: 19 | init_signature = inspect.signature(cls.__init__) 20 | return [param.name for param in init_signature.parameters.values()] 21 | 22 | 23 | def _get_init_payload(cls) -> dict: 24 | payload = dict(cls.__dict__) # copy, so the object's __dict__ is not mutated below 25 | payload["inputs"] = payload["input_components"] 26 | payload["outputs"] = payload["output_components"] 27 | return { 28 | key: value for key, value in payload.items() if key in _get_init_arg_names(cls) 29 | } 30 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/fn_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataset_viber import AnnotatorInterFace 16 | from transformers import pipeline 17 | 18 | texts = [ 19 | "Anthony Bourdain was an amazing chef!", 20 | "Anthony Bourdain was a terrible tv persona!", 21 | ] 22 | labels = ["positive", "negative"] 23 | 24 | interface = AnnotatorInterFace.for_text_classification( 25 | texts=texts, 26 | labels=labels, 27 | fn_model=pipeline( 28 | "sentiment-analysis" 29 | ), # a callable (e.g. a function or a transformers pipeline) that returns label predictions for the text 30 | ) 31 | interface.launch() 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dataset-viber" 3 | version = "0.3.1" 4 | description = "Dataset Viber is your chill repo for data collection, annotation and vibe checks." 5 | authors = [ 6 | {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"}, 7 | ] 8 | requires-python = ">=3.9,<3.13" 9 | readme = "README.md" 10 | license = {text = "Apache 2.0"} 11 | dependencies = [ 12 | "gradio[oauth]>=4.38,<5", 13 | "datasets>=2,<3", 14 | "argilla>=2,<3", 15 | "gradio-huggingfacehub-search>=0.0.7" 16 | ] 17 | 18 | [project.optional-dependencies] 19 | bulk = [ 20 | "fast-sentence-transformers[gpu]>=0.5", 21 | "tabulate>=0.9.0", 22 | "umap-learn>=0.5,<1", 23 | "plotly>=5,<6", 24 | "dash>=2.11,<3", 25 | "dash-bootstrap-components>=1.6.0" 26 | ] 27 | synthesizer = [ 28 | "distilabel[hf-inference-endpoints]>=1.3.2", 29 | ] 30 | 31 | [build-system] 32 | requires = ["pdm-backend"] 33 | build-backend = "pdm.backend" 34 | 35 | [tool.ruff] 36 | line-length = 88 37 | 38 | [tool.black] 39 | line-length = 88 40 | 41 | [tool.pdm] 42 | distribution = true 43 | 44 | [tool.pdm.dev-dependencies] 45 | dev = [ 46 | "pre-commit>=3.8.0", 47 | "ruff>=0.5,<1", 48 | "pytest>=8,<9", 49 | "black>=24,<25", 50 | "openpyxl>=3,<4", 51 | ] 52 | 53 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/log_to_hub.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
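# The values below are placeholders. A more typical setup reads the token
# from the environment instead of hard-coding it, for example:
#
#   import os
#
#   hf_token = os.getenv("HF_TOKEN")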
14 | 15 | from dataset_viber import AnnotatorInterFace 16 | 17 | texts = [ 18 | "Anthony Bourdain was an amazing chef!", 19 | "Anthony Bourdain was a terrible tv persona!", 20 | ] 21 | labels = ["positive", "negative"] 22 | 23 | interface = AnnotatorInterFace.for_text_classification( 24 | texts=texts, 25 | labels=labels, 26 | dataset_name="username/my_dataset", # a "{username}/{dataset_name}" repo if you want to log to the Hub 27 | hf_token="HF_TOKEN", # your Hugging Face token; if not set, HF_TOKEN is read from the environment 28 | private=False, # True if you want to keep the dataset private 29 | ) 30 | interface.launch() 31 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package to PyPI when a Release is Created 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: Publish release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/p/dataset-viber 14 | permissions: 15 | id-token: write 16 | steps: 17 | - name: Checkout Code 🛎 18 | uses: actions/checkout@v4 19 | - name: Setup PDM 20 | uses: pdm-project/setup-pdm@v4 21 | with: 22 | cache: true 23 | python-version-file: pyproject.toml 24 | cache-dependency-path: pdm.lock 25 | - name: Read package info 26 | run: | 27 | PACKAGE_VERSION=$(pdm show --version) 28 | PACKAGE_NAME=$(pdm show --name) 29 | echo "PACKAGE_VERSION=$PACKAGE_VERSION" >> $GITHUB_ENV 30 | echo "PACKAGE_NAME=$PACKAGE_NAME" >> $GITHUB_ENV 31 | echo "$PACKAGE_NAME==$PACKAGE_VERSION" 32 | - name: Publish Package to PyPI test environment 🥪 33 | run: pdm publish --no-build --repository testpypi 34 | continue-on-error: true 35 | - name: Test Installing 🍿 36 | continue-on-error: true 37 | run: | 38 | pip3 install --index-url https://test.pypi.org/simple --no-deps $PACKAGE_NAME==$PACKAGE_VERSION 39 | - name: Publish Package to PyPI 🥩 40 | if: startsWith(github.ref, 'refs/tags/') # release events run on tag refs, never on refs/heads/main 41 | run: pdm publish --no-build -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/charliermarsh/ruff-pre-commit 3 | rev: v0.4.8 4 | hooks: 5 | - id: ruff-format 6 | 7 | - repo: https://github.com/charliermarsh/ruff-pre-commit 8 | rev: v0.4.8 9 | hooks: 10 | - id: ruff 11 | files: 'src/dataset_viber/.*\.py$' 12 | args: 13 | - --fix 14 | - repo: https://github.com/Lucas-C/pre-commit-hooks 15 | rev: v1.5.5 16 | hooks: 17 | - id: insert-license 18 | name: "Insert license header in Python source files" 19 | files: '^src/dataset_viber/.*\.py$' 20 | args: 21 | - --license-filepath 22 | - LICENSE_HEADER 23 | - --fuzzy-match-generates-todo 24 | - repo: https://github.com/kynan/nbstripout 25 | rev: 0.7.1 26 | hooks: 27 | - id: nbstripout 28 | files: '^src/dataset_viber/.*\.ipynb$' 29 | args: 30 | - --keep-count 31 | - --keep-output 32 | # - --keep-prompt-number 33 | # - --keep-cell-ids 34 | # - --keep-markdown 35 | # - --keep-output-timestamp 36 | # - --keep-execution-count 37 | # - --keep-metadata 38 | # - --keep-version 39 | 40 | ci: 41 | autofix_commit_msg: | 42 | [pre-commit.ci] auto fixes from pre-commit.com hooks 43 | for more information, see https://pre-commit.ci 44 | autofix_prs: true 45 | autoupdate_branch: "" 46 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" 47 | autoupdate_schedule: weekly 48 | skip: [] 49
| submodules: false 50 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/fn_next_input_chat_preference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import random 17 | 18 | from datasets import load_dataset 19 | from huggingface_hub import InferenceClient 20 | 21 | from dataset_viber import AnnotatorInterFace 22 | 23 | # https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending 24 | MODEL_IDS = [ 25 | "meta-llama/Meta-Llama-3.1-8B-Instruct", 26 | "microsoft/Phi-3-mini-4k-instruct", 27 | "mistralai/Mistral-7B-Instruct-v0.2", 28 | ] 29 | CLIENTS = [ 30 | InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS 31 | ] 32 | 33 | dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train") 34 | 35 | 36 | def _get_response(messages): 37 | client = random.choice(CLIENTS) 38 | message = client.chat_completion(messages=messages, stream=False, max_tokens=2000) 39 | return message.choices[0].message.content 40 | 41 | 42 | def next_input(_prompt, _completion_a, _completion_b): 43 | new_dataset = dataset.shuffle() 44 | row = new_dataset[0] 45 | messages = row["messages"][:-1] 46 | completions = [row["response"]] 47 | completions.append(_get_response(messages)) 48 | completions.append(_get_response(messages)) 49 | random.shuffle(completions) 50 | return messages, completions.pop(), completions.pop() 51 | 52 | 53 | if __name__ == "__main__": 54 | interface = AnnotatorInterFace.for_chat_generation_preference( 55 | fn_next_input=next_input, 56 | interactive=[False, True, True], 57 | ) 58 | interface.launch() 59 | -------------------------------------------------------------------------------- /src/dataset_viber/examples/fn_next_input_image_preference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
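# retrieve_sample below assumes the datasets-server /rows endpoint returns
# JSON shaped roughly like this (abridged; real responses carry more keys):
#
#   {"rows": [{"row": {"image": {"src": "https://..."}, "prompt": "..."}}]}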
14 | 15 | import concurrent.futures 16 | import io 17 | import os 18 | import random 19 | import time 20 | 21 | import requests 22 | from PIL import Image 23 | 24 | from dataset_viber import AnnotatorInterFace 25 | 26 | HF_TOKEN = os.environ["HF_TOKEN"] 27 | HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} 28 | DATASET_SERVER_URL = "https://datasets-server.huggingface.co" 29 | DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train" 30 | MODEL_URL = ( 31 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 32 | ) 33 | 34 | 35 | def retrieve_sample(idx): 36 | api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1" 37 | response = requests.get(api_url, headers=HEADERS) 38 | data = response.json() 39 | img_url = data["rows"][0]["row"]["image"]["src"] 40 | prompt = data["rows"][0]["row"]["prompt"] 41 | return img_url, prompt 42 | 43 | 44 | def get_rows(): 45 | api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}" 46 | response = requests.get(api_url, headers=HEADERS) 47 | num_rows = response.json()["size"]["config"]["num_rows"] 48 | return num_rows 49 | 50 | 51 | def generate_response(prompt): 52 | def _get_response(prompt): 53 | payload = { 54 | "inputs": prompt, 55 | } 56 | response = requests.post(MODEL_URL, headers=HEADERS, json=payload) 57 | if response.status_code != 200: 58 | time.sleep(5) # the model may still be loading; wait and retry 59 | return _get_response(prompt) 60 | return response 61 | 62 | response = _get_response(prompt) 63 | image = Image.open(io.BytesIO(response.content)) 64 | return image 65 | 66 | 67 | def next_input(_prompt, _completion_a, _completion_b): 68 | with concurrent.futures.ThreadPoolExecutor() as executor: 69 | random_idx = random.randint(0, get_rows() - 1) 70 | future = executor.submit(retrieve_sample, random_idx) 71 | img_url, prompt = future.result() 72 | generated_image = generate_response(prompt) 73 | return (prompt, img_url, generated_image) 74 | 75 | 76 | if __name__ == "__main__": 77 | interface = AnnotatorInterFace.for_image_generation_preference( 78 | interactive=False, fn_next_input=next_input 79 | ) 80 | interface.launch() 81 | -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/_mixins/_argilla.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
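# Sketch of pushing one of these Settings objects to a live server (assumes
# an Argilla 2.x deployment; the URL, key, and dataset name are placeholders):
#
#   import argilla as rg
#
#   client = rg.Argilla(api_url="https://my-argilla.example", api_key="...")
#   dataset = rg.Dataset(name="my-dataset", settings=settings, client=client)
#   dataset.create()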
14 | 15 | import argilla as rg 16 | 17 | 18 | class ArgillaMixin: 19 | def get_argilla_dataset(self): 20 | return rg.Dataset( 21 | name="fake-dataset", 22 | settings=self._get_argilla_settings(), 23 | ) 24 | 25 | def _get_argilla_settings(self): 26 | if self.task == "text-classification": 27 | return rg.Settings( 28 | fields=[rg.TextField(name="text")], 29 | questions=[rg.LabelQuestion(name="label", labels=self.labels)], 30 | ) 31 | elif self.task == "text-classification-multi-label": 32 | return rg.Settings( 33 | fields=[rg.TextField(name="text")], 34 | questions=[rg.MultiLabelQuestion(name="label", labels=self.labels)], 35 | ) 36 | elif self.task == "token-classification": 37 | raise NotImplementedError 38 | elif self.task == "question-answering": 39 | raise NotImplementedError 40 | elif self.task == "text-generation": 41 | return rg.Settings( 42 | fields=[rg.TextField(name="prompt")], 43 | questions=[rg.TextQuestion(name="completion")], 44 | ) 45 | elif self.task == "text-generation-preference": 46 | return rg.Settings( 47 | fields=[ 48 | rg.TextField(name="prompt"), 49 | rg.TextField(name="chosen"), 50 | rg.TextField(name="rejected"), 51 | ], 52 | questions=[ 53 | rg.LabelQuestion(name="flag", labels=["A", "B", "tie"]), 54 | rg.TextQuestion(name="reason", required=False), 55 | ], 56 | ) 57 | elif self.task == "chat-classification": 58 | raise NotImplementedError 59 | elif self.task == "chat-classification-multi-label": 60 | raise NotImplementedError 61 | elif self.task == "chat-generation": 62 | raise NotImplementedError 63 | elif self.task == "chat-generation-preference": 64 | raise NotImplementedError 65 | elif self.task == "image-classification": 66 | raise NotImplementedError 67 | elif self.task == "image-classification-multi-label": 68 | raise NotImplementedError 69 | elif self.task == "image-generation": 70 | raise NotImplementedError 71 | elif self.task == "image-description": 72 | raise NotImplementedError 73 | elif self.task == "image-generation-preference": 74 | raise NotImplementedError 75 | elif self.task == "image-question-answering": 76 | raise NotImplementedError 77 | else: 78 | raise NotImplementedError 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | flagged -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/_mixins/_task_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gradio 16 | 17 | 18 | class TaskConfigMixin: 19 | def _set_text_classification_config(self, inputs): 20 | if self.task in [ 21 | "text-classification", 22 | "text-classification-multi-label", 23 | "chat-classification", 24 | "chat-classification-multi-label", 25 | "image-classification", 26 | "image-classification-multi-label", 27 | ]: 28 | with gradio.Tab("Label selector"): 29 | with gradio.Column(): 30 | label_selector = gradio.Dropdown( 31 | choices=[], 32 | label="label", 33 | allow_custom_value=True, 34 | multiselect=True, 35 | ) 36 | 37 | def update_labels(_label_selector): 38 | self.labels = _label_selector 39 | _kwargs = { 40 | "choices": _label_selector, 41 | "label": "label", 42 | } 43 | return ( 44 | gradio.CheckboxGroup(**_kwargs) 45 | if "multi-label" in self.task 46 | else gradio.Radio(**_kwargs) 47 | ) 48 | 49 | def get_label_from_dataframe(_input_data_component, _label_selector): 50 | if "suggestion" in _input_data_component.columns: 51 | unique_labels = ( 52 | _input_data_component["suggestion"].unique().tolist() 53 | ) 54 | if _label_selector: 55 | unique_labels = _label_selector + unique_labels 56 | else: 57 | unique_labels = _label_selector 58 | if unique_labels is None: 59 | return gradio.Dropdown( 60 | choices=_label_selector, 61 | label="label", 62 | allow_custom_value=True, 63 | multiselect=True, 64 | ) 65 | else: 66 | labels = [str(label) for label in unique_labels] 67 | labels = set(labels) 68 | labels = sorted(list(labels)) 69 | return gradio.Dropdown( 70 | choices=labels, 71 | value=labels, 72 | label="label", 73 | allow_custom_value=True, 74 | multiselect=True, 75 | ) 76 | 77 | self.input_data_component.change( 78 | fn=get_label_from_dataframe, 79 | inputs=[self.input_data_component, label_selector], 80 | outputs=[label_selector], 81 | ) 82 | 83 | label_selector.change( 84 | fn=update_labels, 85 | inputs=[label_selector], 86 | outputs=[ 87 | input 88 | for input in inputs 89 | if isinstance(input, (gradio.Radio, gradio.CheckboxGroup)) 90 | ], 91 | ) 92 | -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/_flagging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import sys 18 | from collections import OrderedDict 19 | from pathlib import Path 20 | from typing import Any 21 | 22 | import gradio as gr 23 | import huggingface_hub 24 | from gradio import utils 25 | from gradio.flagging import HuggingFaceDatasetSaver 26 | from gradio_client import utils as client_utils 27 | 28 | if sys.version_info >= (3, 12): 29 | from typing import override 30 | else: 31 | from typing_extensions import override 32 | 33 | 34 | class FixedHubDatasetSaver(HuggingFaceDatasetSaver): 35 | @override 36 | def _deserialize_components( 37 | self, 38 | data_dir: Path, 39 | flag_data: list[Any], 40 | flag_option: str = "", 41 | username: str = "", 42 | ) -> tuple[dict[Any, Any], list[Any]]: 43 | """Deserialize components and return the corresponding row for the flagged sample. 44 | 45 | Images/audio are saved to disk as individual files. 46 | """ 47 | # Components that can have a preview on dataset repos 48 | file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"} 49 | 50 | # Generate the row corresponding to the flagged sample 51 | features = OrderedDict() 52 | row = [] 53 | for component, sample in zip(self.components, flag_data): 54 | # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-) 55 | label = component.label or "" 56 | save_dir = data_dir / client_utils.strip_invalid_filename_characters(label) 57 | save_dir.mkdir(exist_ok=True, parents=True) 58 | deserialized = utils.simplify_file_data_in_str( 59 | component.flag(sample, save_dir) 60 | ) 61 | 62 | # Add deserialized object to row 63 | features[label] = {"dtype": "string", "_type": "Value"} 64 | try: 65 | deserialized_path = Path(deserialized) 66 | if not deserialized_path.exists(): 67 | raise FileNotFoundError(f"File {deserialized} not found") 68 | row.append(str(deserialized_path.relative_to(self.dataset_dir))) 69 | except (FileNotFoundError, TypeError, ValueError, OSError): 70 | deserialized = "" if deserialized is None else str(deserialized) 71 | row.append(deserialized) 72 | 73 | # If component is eligible for a preview, add the URL of the file 74 | # Be mindful that images and audio can be None 75 | if isinstance(component, tuple(file_preview_types)): # type: ignore 76 | for _component, _type in file_preview_types.items(): 77 | if isinstance(component, _component): 78 | features[label + " file"] = {"_type": _type} 79 | break 80 | if deserialized: 81 | path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL 82 | Path(deserialized).relative_to(self.dataset_dir) 83 | ).replace("\\", "/") 84 | row.append( 85 | huggingface_hub.hf_hub_url( 86 | repo_id=self.dataset_id, 87 | filename=path_in_repo, 88 | repo_type="dataset", 89 | ) 90 | ) 91 | else: 92 | row.append("") 93 | features["flag"] = {"dtype": "string", "_type": "Value"} 94 | features["username"] = {"dtype": "string", "_type": "Value"} 95 | row.append(flag_option) 96 | row.append(username) 97 | return features, row 98 | 
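# Usage sketch (illustrative; CollectorInterface wires this up internally when
# a dataset_name is passed, so the token and repo names here are placeholders):
#
#   saver = FixedHubDatasetSaver(
#       hf_token="hf_...", dataset_name="username/my-flags", private=True
#   )
#   interface = gr.Interface(
#       fn=lambda x: x, inputs="text", outputs="text",
#       allow_flagging="manual", flagging_callback=saver,
#   )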
-------------------------------------------------------------------------------- /src/dataset_viber/_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, List, Tuple, Union 16 | 17 | import argilla as rg 18 | import gradio 19 | import numpy as np 20 | import PIL.Image 21 | 22 | 23 | class MockClient: 24 | def __init__(self): 25 | self.api = MockApi() 26 | self.workspaces = lambda _: rg.Workspace(name="dataset-viber", client=self) 27 | 28 | 29 | class MockApi: 30 | def __init__(self): 31 | self.fields = "" 32 | self.datasets = "" 33 | self.questions = "" 34 | self.records = "" 35 | 36 | 37 | DEFAULT_EMBEDDING_MODEL = "Snowflake/snowflake-arctic-embed-xs" 38 | COLORS = [ 39 | "#a6cee3", 40 | "#1f78b4", 41 | "#b2df8a", 42 | "#33a02c", 43 | "#fb9a99", 44 | "#e31a1c", 45 | "#fdbf6f", 46 | "#ff7f00", 47 | "#cab2d6", 48 | "#6a3d9a", 49 | "#ffff99", 50 | "#b15928", 51 | ] 52 | FAKE_UUID = "00000000-0000-0000-0000-000000000000" 53 | DEFAULT_DATASET_CONFIG = { 54 | "id": FAKE_UUID, 55 | "inserted_at": "2024-07-30T18:54:05.550199", 56 | "updated_at": "2024-07-30T18:54:05.748298", 57 | "name": "dataset-viber", 58 | "status": "ready", 59 | "guidelines": None, 60 | "allow_extra_metadata": True, 61 | "distribution": {"strategy": "overlap", "min_submitted": 1}, 62 | "workspace_id": FAKE_UUID, 63 | "last_activity_at": FAKE_UUID, 64 | } 65 | 66 | TASK_MAPPING = { 67 | "text-classification": { 68 | "input_columns": ["text", "suggestion"], 69 | "output_columns": ["text", "label"], 70 | "fn_model_output": List[Dict[str, Union[str, float]]], 71 | "fn_next_input_output": Tuple[str, str], 72 | "components": [gradio.Textbox, gradio.Radio], 73 | "autotrain": { 74 | "hub": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/text_classification/hub_dataset.yml", 75 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/text_classification/local_dataset.yml", 76 | }, 77 | }, 78 | "text-classification-multi-label": { 79 | "input_columns": ["text", "suggestion"], 80 | "output_columns": ["text", "label"], 81 | "fn_model_output": List[Dict[str, Union[str, float]]], 82 | "fn_next_input_output": Tuple[str, List[str]], 83 | "components": [gradio.Textbox, gradio.CheckboxGroup], 84 | "autotrain": { 85 | "hub": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/text_classification/hub_dataset.yml", 86 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/text_classification/local_dataset.yml", 87 | }, 88 | }, 89 | "token-classification": { 90 | "input_columns": ["text"], 91 | "output_columns": ["text", "spans"], 92 | "fn_model_output": List[Tuple[str, str]], 93 | "fn_next_input_output": Tuple[str, List[Tuple[str, str]]], 94 | "components": [gradio.Textbox, gradio.HighlightedText], 95 | "autotrain": { 96 | "hub": 
"https://github.com/huggingface/autotrain-advanced/blob/main/configs/token_classification/hub_dataset.yml", 97 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/token_classification/local_dataset.yml", 98 | }, 99 | }, 100 | "question-answering": { 101 | "input_columns": ["question", "context"], 102 | "output_columns": ["question", "context"], 103 | "fn_model_output": List[Tuple[str, str]], 104 | "fn_next_input_output": Tuple[str, Union[str, List[Tuple[str, str]]]], 105 | "components": [gradio.Textbox, gradio.HighlightedText], 106 | "autotrain": { 107 | "hub": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/extractive_question_answering/hub_dataset.yml", 108 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/extractive_question_answering/local_dataset.yml", 109 | }, 110 | }, 111 | "text-generation": { 112 | "input_columns": ["prompt", "completion"], 113 | "output_columns": ["prompt", "completion"], 114 | "fn_model_output": str, 115 | "fn_next_input_output": Tuple[str, str], 116 | "components": [gradio.Textbox, gradio.Textbox], 117 | "autotrain": { 118 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/chat_generation/local.yml" 119 | }, 120 | }, 121 | "text-generation-preference": { 122 | "input_columns": ["prompt", "completion_a", "completion_b"], 123 | "output_columns": ["prompt", "completion_a", "completion_b", "flag"], 124 | "fn_model_output": str, 125 | "fn_next_input_output": Tuple[str, str, str], 126 | "components": [gradio.Textbox, gradio.Textbox, gradio.Textbox], 127 | "autotrain": { 128 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/llm_finetuning/llama3-8b-orpo.yml" 129 | }, 130 | }, 131 | "chat-classification": { 132 | "input_columns": ["prompt", "suggestion"], 133 | "output_columns": ["prompt", "label"], 134 | "fn_model_output": List[Dict[str, Union[str, float]]], 135 | "fn_next_input_output": Tuple[List[gradio.ChatMessage], str], 136 | "components": [gradio.Chatbot, gradio.Radio], 137 | }, 138 | "chat-classification-multi-label": { 139 | "input_columns": ["prompt", "suggestion"], 140 | "output_columns": ["prompt", "label"], 141 | "fn_model_output": List[Dict[str, Union[str, float]]], 142 | "fn_next_input_output": Tuple[List[gradio.ChatMessage], List[str]], 143 | "components": [gradio.Chatbot, gradio.CheckboxGroup], 144 | }, 145 | "chat-generation": { 146 | "input_columns": ["prompt", "completion"], 147 | "output_columns": ["prompt", "completion"], 148 | "fn_model_output": str, 149 | "fn_next_input_output": Tuple[List[gradio.ChatMessage], str], 150 | "components": [gradio.Chatbot, gradio.Textbox], 151 | "autotrain": { 152 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/chat_generation/local.yml" 153 | }, 154 | }, 155 | "chat-generation-preference": { 156 | "input_columns": ["prompt", "completion_a", "completion_b"], 157 | "output_columns": ["prompt", "completion_a", "completion_b", "flag"], 158 | "fn_model_output": str, 159 | "fn_next_input_output": Tuple[List[gradio.ChatMessage], str, str], 160 | "components": [gradio.Chatbot, gradio.Textbox, gradio.Textbox], 161 | "autotrain": { 162 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/llm_finetuning/llama3-8b-orpo.yml" 163 | }, 164 | }, 165 | "image-classification": { 166 | "input_columns": ["image", "suggestion"], 167 | "output_columns": ["image", "label"], 168 | "fn_model_output": Union[str, List[Dict[str, Union[str, 
float]]]], 169 | "fn_next_input_output": Tuple[PIL.Image.Image, str], 170 | "components": [gradio.Image, gradio.Radio], 171 | "autotrain": { 172 | "hub": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/image_classification/hub_dataset.yml", 173 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/image_classification/local.yml", 174 | }, 175 | }, 176 | "image-classification-multi-label": { 177 | "input_columns": ["image", "suggestion"], 178 | "output_columns": ["image", "label"], 179 | "fn_model_output": List[Dict[str, Union[str, float]]], 180 | "fn_next_input_output": Tuple[PIL.Image.Image, List[str]], 181 | "components": [gradio.Image, gradio.CheckboxGroup], 182 | "autotrain": { 183 | "hub": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/image_classification/hub_dataset.yml", 184 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/image_classification/local.yml", 185 | }, 186 | }, 187 | "image-generation": { 188 | "input_columns": ["prompt", "completion"], 189 | "output_columns": ["prompt", "completion"], 190 | "fn_model_output": Union[np.ndarray, PIL.Image.Image, str], 191 | "fn_next_input_output": Tuple[str, PIL.Image.Image], 192 | "components": [gradio.Textbox, gradio.Image], 193 | }, 194 | "image-description": { 195 | "input_columns": ["image", "description"], 196 | "output_columns": ["image", "description"], 197 | "fn_model_output": str, 198 | "fn_next_input_output": Tuple[PIL.Image.Image, str], 199 | "components": [gradio.Image, gradio.Textbox], 200 | }, 201 | "image-generation-preference": { 202 | "input_columns": ["prompt", "completion_a", "completion_b"], 203 | "output_columns": ["prompt", "completion_a", "completion_b", "flag"], 204 | "fn_model_output": PIL.Image.Image, 205 | "fn_next_input_output": Tuple[str, PIL.Image.Image, PIL.Image.Image], 206 | "components": [gradio.Textbox, gradio.Image, gradio.Image], 207 | }, 208 | "image-question-answering": { 209 | "input_columns": ["image", "question", "answer"], 210 | "output_columns": ["image", "question", "answer"], 211 | "fn_model_output": str, 212 | "fn_next_input_output": Tuple[PIL.Image.Image, str, str], 213 | "components": [gradio.Image, gradio.Textbox, gradio.Textbox], 214 | "autotrain": { 215 | "local": "https://github.com/huggingface/autotrain-advanced/blob/main/configs/vlm/paligemma_vqa.yml" 216 | }, 217 | }, 218 | } 219 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/collector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | from typing import TYPE_CHECKING, Callable, List, Optional, Sequence 17 | 18 | import gradio 19 | import huggingface_hub 20 | from gradio.components import Component 21 | 22 | from dataset_viber._gradio._flagging import FixedHubDatasetSaver 23 | from dataset_viber._utils import _get_init_payload 24 | 25 | if TYPE_CHECKING: 26 | from transformers.pipelines import Pipeline 27 | 28 | 29 | class CollectorInterface(gradio.Interface): 30 | def __init__( 31 | self, 32 | fn: Callable, 33 | inputs: str | Component | Sequence[str | Component] | None, 34 | outputs: str | Component | Sequence[str | Component] | None, 35 | *, 36 | csv_logger: Optional[bool] = False, 37 | dataset_name: str = None, 38 | hf_token: Optional[str] = None, 39 | private: Optional[bool] = False, 40 | allow_flagging: Optional[str] = "auto", 41 | flagging_options: Optional[List[str]] = None, 42 | show_embedded_viewer: Optional[bool] = True, 43 | **kwargs, 44 | ) -> None: 45 | """ 46 | Load a CollectorInterface with data logging capabilities. 47 | 48 | Parameters: 49 | fn: the function to run 50 | inputs: the input component(s) 51 | outputs: the output component(s) 52 | csv_logger: whether or not to log the data to a CSV file. 53 | dataset_name: the "org/dataset" to which the data needs to be logged 54 | hf_token: optional token to pass, otherwise will default to env var HF_TOKEN 55 | private: whether or not to create a private repo 56 | allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged, along with the generated output. If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual". 57 | flagging_options: If provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". 
Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will be ["Flag as X", "Flag as Y"], etc. 58 |
59 | Return:
60 | an initialized CollectorInterface
61 | """
62 | self.csv_logger = csv_logger
63 | self._validate_flagging_options(
64 | allow_flagging=allow_flagging, flagging_options=flagging_options
65 | )
66 | flagging_callback = kwargs.pop("flagging_callback", None)
67 |
68 | if dataset_name and flagging_callback is None:
69 | flagging_callback = self._get_flagging_callback(
70 | dataset_name=dataset_name, hf_token=hf_token, private=private
71 | )
72 | if flagging_callback:
73 | pass  # an explicit callback or Hub dataset saver takes precedence
74 | elif csv_logger:
75 | flagging_callback = None  # defer to gradio's default CSV logging
76 | else:  # no logging requested: stub out the default logger with no-ops
77 | flagging_callback = gradio.CSVLogger()
78 | flagging_callback.setup = lambda *args, **kwargs: None
79 | flagging_callback.flag = lambda *args, **kwargs: 0
80 |
81 | kwargs.update(
82 | {
83 | "flagging_callback": flagging_callback,
84 | "allow_flagging": allow_flagging,
85 | "flagging_options": flagging_options,
86 | }
87 | )
88 | super().__init__(fn=fn, inputs=inputs, outputs=outputs, **kwargs)
89 | self = self._add_html_component_with_viewer(
90 | self, flagging_callback, show_embedded_viewer
91 | )
92 |
93 | @classmethod
94 | def from_pipeline(
95 | cls,
96 | pipeline: "Pipeline",
97 | *,
98 | csv_logger: Optional[bool] = False,
99 | dataset_name: Optional[str] = None,
100 | hf_token: Optional[str] = None,
101 | private: Optional[bool] = False,
102 | allow_flagging: Optional[str] = "auto",
103 | flagging_options: Optional[List[str]] = None,
104 | show_embedded_viewer: Optional[bool] = True,
105 | **kwargs,
106 | ) -> gradio.Interface:
107 | """
108 | Load an existing transformers.pipeline into a CollectorInterface with data logging capabilities.
109 |
110 | Parameters:
111 | pipeline: an initialized transformers.pipeline
112 | csv_logger: whether or not to log the data to a CSV file.
113 | dataset_name: the "org/dataset" to which the data needs to be logged
114 | hf_token: optional token to pass, otherwise will default to env var HF_TOKEN
115 | private: whether or not to create a private repo
116 | allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag, and both the input and outputs are flagged when the user clicks it. If "auto", every input the user submits will be automatically flagged, along with the generated output. This parameter can be set with the environment variable GRADIO_ALLOW_FLAGGING; otherwise it defaults to "manual".
117 | flagging_options: If provided, allows the user to select from a list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will be ["Flag as X", "Flag as Y"], etc.
118 |
119 | Return:
120 | an initialized CollectorInterface
121 | """
122 | return cls.from_interface(
123 | interface=gradio.Interface.from_pipeline(pipeline=pipeline),
124 | csv_logger=csv_logger, dataset_name=dataset_name,
125 | hf_token=hf_token,
126 | private=private,
127 | allow_flagging=allow_flagging,
128 | flagging_options=flagging_options,
129 | show_embedded_viewer=show_embedded_viewer,
130 | **kwargs,
131 | )
132 |
133 | @classmethod
134 | def from_interface(
135 | cls,
136 | interface: gradio.Interface,
137 | *,
138 | csv_logger: Optional[bool] = False,
139 | dataset_name: Optional[str] = None,
140 | hf_token: Optional[str] = None,
141 | private: Optional[bool] = False,
142 | allow_flagging: Optional[str] = "auto",
143 | flagging_options: Optional[List[str]] = None,
144 | show_embedded_viewer: Optional[bool] = True,
145 | **kwargs,
146 | ) -> gradio.Interface:
147 | """
148 | Load an existing gradio.Interface into a CollectorInterface with data logging capabilities.
149 |
150 | Parameters:
151 | interface: any initialized gradio.Interface
152 | csv_logger: whether or not to log the data to a CSV file.
153 | dataset_name: the "org/dataset" to which the data needs to be logged
154 | hf_token: optional token to pass, otherwise will default to env var HF_TOKEN
155 | private: whether or not to create a private repo
156 | allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag, and both the input and outputs are flagged when the user clicks it. If "auto", every input the user submits will be automatically flagged, along with the generated output. This parameter can be set with the environment variable GRADIO_ALLOW_FLAGGING; otherwise it defaults to "manual".
157 | flagging_options: If provided, allows the user to select from a list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will be ["Flag as X", "Flag as Y"], etc.
158 |
159 | Return:
160 | an initialized CollectorInterface
161 | """
162 | flagging_callback = kwargs.pop("flagging_callback", None)
163 | if dataset_name and not flagging_callback:
164 | flagging_callback = cls._get_flagging_callback(
165 | dataset_name=dataset_name, hf_token=hf_token, private=private
166 | )
167 | payload = _get_init_payload(interface)
168 | payload.update(**kwargs)
169 | payload.update(
170 | {
171 | "flagging_callback": flagging_callback, "csv_logger": csv_logger,
172 | "allow_flagging": allow_flagging,
173 | "flagging_options": flagging_options,
174 | "show_embedded_viewer": show_embedded_viewer,
175 | }
176 | )
177 | return cls(**payload)
178 |
179 | @staticmethod
180 | def _validate_flagging_options(allow_flagging, flagging_options) -> None:
181 | if allow_flagging == "auto" and flagging_options:
182 | raise ValueError(
183 | "automatic flagging cannot be combined with 'flagging_options', set `allow_flagging='manual'` instead"
184 | )
185 | if allow_flagging == "never":
186 | warnings.warn("You are using a data collector, but flagging is disabled, so nothing will be logged.")
187 |
188 | @staticmethod
189 | def _get_flagging_callback(
190 | dataset_name: str,
191 | hf_token: str,
192 | private: bool = False,
193 | ) -> gradio.HuggingFaceDatasetSaver:
194 | return FixedHubDatasetSaver(
195 | hf_token=hf_token,
196 | dataset_name=dataset_name,
197 | private=private,
198 | info_filename="dataset_info.json",
199 | separate_dirs=True,
200 | )
201 |
202 | @staticmethod
203 | def _get_repo_url_from_repo_id(repo_id: str) -> str:
204 | return f"https://huggingface.co/datasets/{repo_id}"
205 |
206 | @staticmethod
207 | def _get_repo_url_from_dataset_saver(
208 | flagging_callback: gradio.HuggingFaceDatasetSaver,
209 | ) -> huggingface_hub.RepoUrl:
210 | return f"""https://huggingface.co/datasets/{huggingface_hub.create_repo(
211 | repo_id=flagging_callback.dataset_id,
212 | token=flagging_callback.hf_token,
213 | private=flagging_callback.dataset_private,
214 | repo_type="dataset",
215 | exist_ok=True,
216 | ).repo_id}"""
217 |
218 | @staticmethod
219 | def _get_embedded_dataset_viewer(repo_url: str) -> str:
220 | return f"""
221 | <iframe
222 | src="{repo_url}/embed/viewer"
223 | frameborder="0"
224 | width="100%"
225 | height="560px"
226 | ></iframe>
227 | """
228 |
229 | @classmethod
230 | def _add_html_component_with_viewer(
231 | cls,
232 | instance: gradio.Interface,
233 | flagging_callback: Optional[gradio.HuggingFaceDatasetSaver] = None,
234 | show_embedded_viewer: bool = True,
235 | ) -> gradio.Interface:
236 | if isinstance(flagging_callback, gradio.HuggingFaceDatasetSaver):
237 | repo_url = cls._get_repo_url_from_dataset_saver(flagging_callback)
238 | formatted_repo_url = (
239 | f"Data is being written to [a dataset on the Hub]({repo_url})."
240 | )
241 | with instance:
242 | with gradio.Accordion("Data is synced to Hugging Face Hub", open=False):
243 | with gradio.Row(equal_height=False):
244 | gradio.Markdown(formatted_repo_url)
245 | if show_embedded_viewer and not flagging_callback.dataset_private:
246 | with gradio.Row():
247 | with gradio.Accordion(
248 | "dataset viewer - do an (empty) search to refresh",
249 | open=False,
250 | ):
251 | gradio.HTML(cls._get_embedded_dataset_viewer(repo_url))
252 | return instance
253 | -------------------------------------------------------------------------------- /src/dataset_viber/_gradio/_mixins/_import_export.py: --------------------------------------------------------------------------------
1 | # Copyright 2024-present, David Berenstein, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import io
16 | import json
17 | import os
18 | import threading
19 | import time
20 | import uuid
21 | from pathlib import Path
22 |
23 | import gradio
24 | import numpy as np
25 | import pandas as pd
26 | from datasets import Dataset, load_dataset
27 | from gradio.oauth import (
28 | OAUTH_CLIENT_ID,
29 | OAUTH_CLIENT_SECRET,
30 | OAUTH_SCOPES,
31 | OPENID_PROVIDER_URL,
32 | get_space,
33 | )
34 | from gradio_huggingfacehub_search import HuggingfaceHubSearch
35 | from PIL import Image
36 |
37 | from dataset_viber._gradio._mixins._argilla import ArgillaMixin
38 |
39 | if (
40 | all(
41 | [
42 | OAUTH_CLIENT_ID,
43 | OAUTH_CLIENT_SECRET,
44 | OAUTH_SCOPES,
45 | OPENID_PROVIDER_URL,
46 | ]
47 | )
48 | or get_space() is None
49 | ):
50 | from gradio.oauth import OAuthToken
51 | else:
52 | OAuthToken = str
53 |
54 | CODE_KWARGS = {
55 | "language": "json",
56 | "interactive": True,
57 | "label": "Column Mapping",
58 | "lines": 1,
59 | }
60 |
61 |
62 | class ImportExportMixin(ArgillaMixin):
63 | def _override_block_init_method(self, **kwargs):
64 | # Initialize the parent class
65 | gradio.Blocks.__init__(
66 | self,
67 | analytics_enabled=kwargs.get("analytics_enabled", True),
68 | mode="interface",
69 | css=kwargs.get("css", None),
70 | title=kwargs.get("title", "Gradio"),
71 | theme=kwargs.get("theme", None),
72 | js=kwargs.get("js", None),
73 | head=kwargs.get("head", None),
74 | delete_cache=kwargs.get("delete_cache", False),
75 | fill_width=kwargs.get("fill_width", False),
76 | # **kwargs,
77 | )
78 | # Override the __init__ method of the parent class to avoid the re-creation of the blocks
79 | gradio.Blocks.__init__ = lambda *args, **kwargs: None
80 |
81 | def _configure_import(self):
82 | with gradio.Tab("Import data"):
83 | with gradio.Tab("Import from Hugging Face Hub"):
84 | search_in = HuggingfaceHubSearch(
85 | label="Search the Hugging Face Hub",
86 | placeholder="Search for datasets on the Hugging Face Hub",
87 | search_type="dataset",
88 | sumbit_on_select=True,  # (sic) parameter name as spelled in the component's API
89 | )
90 | dataset_viewer = gradio.HTML(label="Dataset Viewer")
91 | search_in.submit(
92 | fn=lambda x: self._get_embedded_dataset_viewer(
93 | self._get_repo_url_from_repo_id(x)
94 | ),
95 | inputs=[search_in],
96 | outputs=[dataset_viewer],
97 | )
98 | column_mapping_hf_upload = gradio.Code(
99 | value=json.dumps(dict.fromkeys(self.input_columns, ""), indent=2),
100 | **CODE_KWARGS,
101 | )
102 | start_btn_hf_upload = gradio.Button("Start Annotating")
103 | start_btn_hf_upload.click(
104 | fn=self._set_data_hf_upload,
105 | inputs=[search_in, column_mapping_hf_upload],
106 | outputs=self.input_data_component,
107 | )
108 | with gradio.Tab(label="Import from file"):
109 | upload_button = gradio.UploadButton(
110 | "Upload",
111 | label="Select a file (CSV or Excel)",
112 | file_types=["csv", "xls", "xlsx"],
113 | )
114 | df_upload = gradio.Dataframe(interactive=True)
115 | upload_button.upload(
116 | fn=self.upload_file,
inputs=upload_button, outputs=df_upload
117 | )
118 | column_mapping_file_upload = gradio.Code(
119 | value=json.dumps(dict.fromkeys(self.input_columns, ""), indent=2),
120 | **CODE_KWARGS,
121 | )
122 | start_btn_file_upload = gradio.Button("Start Annotating")
123 | start_btn_file_upload.click(
124 | fn=self._set_data,
125 | inputs=[df_upload, column_mapping_file_upload],
126 | outputs=self.input_data_component,
127 | )
128 |
129 | def _configure_export(self):
130 | with gradio.Tab("Export data"):
131 | with gradio.Tab("Export to Hugging Face Hub"):
132 | with gradio.Row():
133 | dataset_name = gradio.Textbox(
134 | placeholder="Dataset Name", label="Dataset Name"
135 | )
136 | with gradio.Row():
137 | gradio.Info(
138 | "Ensure the HF_TOKEN env var is set or that gradio allows login through OAuth."
139 | )
140 | export_button_hf = gradio.Button("Export")
141 | if (
142 | all(
143 | [
144 | OAUTH_CLIENT_ID,
145 | OAUTH_CLIENT_SECRET,
146 | OAUTH_SCOPES,
147 | OPENID_PROVIDER_URL,
148 | ]
149 | )
150 | or get_space() is None
151 | ):
152 | export_button_hf.click(
153 | fn=self._export_data_hf,
154 | inputs=dataset_name,
155 | outputs=dataset_name,
156 | )
157 | else:
158 | token = gradio.Textbox(
159 | value=os.getenv("HF_TOKEN"),
160 | type="password",
161 | label=f"OAuth Token token_present={'HF_TOKEN' in os.environ}",
162 | interactive=True,
163 | )
164 | export_button_hf.click(
165 | fn=self._export_data_hf,
166 | inputs=[dataset_name, token],
167 | outputs=dataset_name,
168 | )
169 |
170 | with gradio.Tab("Export to file"):
171 | with gradio.Column():
172 | export_button = gradio.Button("Export")
173 | self.file = gradio.File(interactive=False, visible=False)
174 | export_button.click(
175 | fn=self._export_data,
176 | outputs=self.file,
177 | )
178 |
179 | def _set_data_hf_upload(self, repo_id, column_mapping, split="train"):
180 | gradio.Info("Started loading the dataset. This might take a while.")
181 | try:
182 | column_mapping = self._json_to_dict(column_mapping)
183 | dataset = load_dataset(repo_id, split=split)
184 | for key, value in column_mapping.items():
185 | if key != value:
186 | if value in dataset.column_names:
187 | dataset = dataset.rename_column(value, key)
188 | dataset = dataset.select_columns(
189 | [
190 | col
191 | for col in list(column_mapping.keys())
192 | if col in dataset.column_names
193 | ]
194 | )
195 | # add images before converting to bytes
196 | for column in column_mapping.keys():
197 | if column in dataset.column_names:
198 | self.input_data[column].extend(
199 | [self.process_image_input(entry) for entry in dataset[column]]
200 | )
201 | self._set_equal_length_input_data()
202 | dataframe = pd.DataFrame.from_dict(self.input_data)
203 | self.start = len(dataframe)
204 | except Exception as e:
205 | raise gradio.Error(f"An error occurred: {e}")
206 | gradio.Info(
207 | "Data loaded successfully. Showing first 100 examples in 'remaining data' tab. Click on \"⏭️ Next\" to get the next record."
208 | )
209 | return dataframe.head(100)
210 |
211 | def _set_data(self, dataframe, column_mapping):
212 | gradio.Info("Started loading the dataset. This might take a while.")
213 | try:
214 | column_mapping = self._json_to_dict(column_mapping)
215 | dataframe = dataframe[list(column_mapping.values())]
216 | dataframe.columns = list(column_mapping.keys())
217 | for column in column_mapping.keys():
218 | if column in dataframe.columns:
219 | self.input_data[column].extend(dataframe[column].tolist())
220 | self._set_equal_length_input_data()
221 | dataframe = pd.DataFrame.from_dict(self.input_data)
222 | self.start = len(dataframe)
223 | except Exception as e:
224 | raise gradio.Error(f"An error occurred: {e}")
225 | gradio.Info(
226 | "Data loaded successfully. Showing first 100 examples in 'remaining data' tab. Click on 🗑️ discard to get the next record."
227 | )
228 | return dataframe.head(100)
229 |
230 | def _export_data_hf(self, dataset_name, oauth_token: OAuthToken | None):
231 | gradio.Info("Started exporting the dataset. This may take a while.")
232 | if not isinstance(oauth_token, str):
233 | oauth_token = oauth_token.token
234 | Dataset.from_dict(self.output_data).push_to_hub(dataset_name, token=oauth_token)
235 | gradio.Info(f"Exported the dataset to Hugging Face Hub as {dataset_name}.")
236 |
237 | def delete_file_after_delay(self, file_path, delay=30):
238 | def delete_file():
239 | time.sleep(delay)
240 | Path(Path(file_path).name).unlink()
241 |
242 | thread = threading.Thread(target=delete_file)
243 | thread.start()
244 |
245 | def _export_data(self):
246 | id = uuid.uuid4()
247 | if "image" in self.task:
248 | filename = f"{id}.parquet"
249 | Dataset.from_dict(self.output_data).to_parquet(filename)
250 | else:
251 | filename = f"{id}.csv"
252 | Dataset.from_dict(self.output_data).to_csv(filename)
253 | self.delete_file_after_delay(filename, 20)
254 | gradio.Info(
255 | f"Exported the dataset to {filename}. It will be deleted in 20 seconds."
256 | )
257 | return gradio.File(value=filename, visible=True)
258 |
259 | def _delete_file(self, _file):
260 | return gradio.File(interactive=False, visible=False)
261 |
262 | @staticmethod
263 | def _json_to_dict(json_str):
264 | return json.loads(json_str)
265 |
266 | @staticmethod
267 | def upload_file(file):
268 | # Determine the file type and load accordingly
269 | if file.name.endswith(".csv"):
270 | df = pd.read_csv(file.name)
271 | elif file.name.endswith(".xls") or file.name.endswith(".xlsx"):
272 | df = pd.read_excel(file.name)
273 | else:
274 | return "Unsupported file type. Please upload a CSV or Excel file."
275 | return df
276 |
277 | def process_image_input(self, input_data):
278 | if input_data is None:
279 | return None
280 | elif isinstance(input_data, dict):
281 | if "bytes" in input_data and input_data["bytes"]:
282 | # Case: bytes in a dictionary
283 | return Image.open(io.BytesIO(input_data["bytes"]))
284 | elif "path" in input_data and input_data["path"]:
285 | # Case: path in a dictionary
286 | return input_data["path"]
287 | elif isinstance(input_data, Image.Image):
288 | # Case: PIL Image
289 | return input_data
290 | elif isinstance(input_data, str):
291 | # Case: URL or file path as string
292 | return input_data
293 | elif isinstance(input_data, (np.ndarray, list)):
294 | # Case: numpy array or list
295 | return Image.fromarray(np.array(input_data))
296 | else:
297 | return input_data
298 |
299 | def _set_equal_length_input_data(self):
300 | # ensure all columns of self.input_data have equal length, padding with "" where needed
301 | max_column_len = max(
302 | [len(self.input_data[column]) for column in self.input_data.keys()]
303 | )
304 | for column in self.input_data.keys():
305 | if len(self.input_data[column]) < max_column_len:
306 | self.input_data[column].extend(
307 | [""] * (max_column_len - len(self.input_data[column]))
308 | )
309 |
310 | def _set_equal_length_output_data(self):
311 | # ensure all columns of self.output_data have equal length, padding with "" where needed
312 | max_column_len = max(
313 | [len(self.output_data[column]) for column in self.output_data.keys()]
314 | )
315 | for column in self.output_data.keys():
316 | if len(self.output_data[column]) < max_column_len:
317 | self.output_data[column].extend(
318 | [""] * (max_column_len - len(self.output_data[column]))
319 | )
320 | -------------------------------------------------------------------------------- /src/dataset_viber/synthesizer.py: --------------------------------------------------------------------------------
1 | # Copyright 2024-present, David Berenstein, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
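# Usage sketch (illustrative only; the prompt_context below is made up):
#
#     from dataset_viber.synthesizer import Synthesizer
#
#     synthesizer = Synthesizer.for_text_classification(
#         prompt_context="customer support tickets for a phone company"
#     )
#     # one placeholder argument per input column (text, label, prompt_context)
#     text, label, context = synthesizer(None, None, None)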
14 | import io 15 | import os 16 | import time 17 | import uuid 18 | import warnings 19 | from typing import Any, Optional 20 | 21 | import requests 22 | from distilabel.llms import LLM, InferenceEndpointsLLM 23 | from distilabel.steps.tasks import GenerateTextClassificationData, Magpie 24 | from PIL import Image 25 | 26 | from dataset_viber._constants import TASK_MAPPING 27 | 28 | _DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct" 29 | _GENERATION_KWARGS = {"max_new_tokens": 4000, "temperature": 1, "do_sample": True} 30 | _DEFAULT_LLM = InferenceEndpointsLLM( 31 | model_id=_DEFAULT_MODEL_ID, 32 | tokenizer_id=_DEFAULT_MODEL_ID, 33 | magpie_pre_query_template="llama3", 34 | generation_kwargs=_GENERATION_KWARGS, 35 | ) 36 | 37 | _DEFAULT_MODEL_URL = ( 38 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 39 | ) 40 | 41 | 42 | class _ImageGeneration: 43 | """ 44 | A class for generating images based on text prompts using a specified model. 45 | """ 46 | 47 | def __init__(self, llm: Optional[str] = None): 48 | """ 49 | Initialize the _ImageGeneration class. 50 | 51 | Args: 52 | llm (Optional[str]): The URL of the image generation model. If None, uses the default model. 53 | """ 54 | self.model_url = llm or _DEFAULT_MODEL_URL 55 | self.headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"} 56 | 57 | def process(self, prompt): 58 | """ 59 | Generate an image based on the given prompt. 60 | 61 | Args: 62 | prompt (str): The text prompt for image generation. 63 | 64 | Returns: 65 | PIL.Image.Image: The generated image. 66 | """ 67 | 68 | def _get_response(prompt): 69 | payload = {"inputs": prompt, "_id": uuid.uuid4().hex} 70 | response = requests.post(self.model_url, headers=self.headers, json=payload) 71 | if response.status_code != 200: 72 | warnings.warn( 73 | f"Failed to get response from model. Status code: {response.status_code}. Response: {response.text}" 74 | ) 75 | time.sleep(5) 76 | return _get_response(prompt) 77 | return response 78 | 79 | response = _get_response(prompt) 80 | image = Image.open(io.BytesIO(response.content)) 81 | return image 82 | 83 | def load(self): 84 | pass 85 | 86 | 87 | class Synthesizer: 88 | """ 89 | A class for synthesizing data for various AI tasks. 90 | """ 91 | 92 | def __init__(self, next_input: callable, prompt_context: str): 93 | """ 94 | Initialize the Synthesizer class. 95 | 96 | Args: 97 | next_input (callable): A function to generate the next input. 98 | prompt_context (str): The context for the prompt. 99 | """ 100 | self.next_input = next_input 101 | self.prompt_context = prompt_context 102 | 103 | def __call__(self, *args: Any, **kwds: Any) -> Any: 104 | """ 105 | Call the next_input function with the given arguments. 106 | 107 | Args: 108 | *args: Positional arguments to pass to next_input. 109 | **kwds: Keyword arguments to pass to next_input. 110 | 111 | Returns: 112 | Any: The result of calling next_input. 113 | """ 114 | return self.next_input(*args, **kwds) 115 | 116 | def batch_synthesize(self, n: int): 117 | """ 118 | Synthesize a batch of inputs. 119 | 120 | Args: 121 | n (int): The number of inputs to synthesize. 122 | 123 | Returns: 124 | list: A list of synthesized inputs. 
125 | """ 126 | batch = [self.next_input(*self.input_columns) for _ in range(n)] 127 | return list(map(list, zip(*batch))) 128 | 129 | @classmethod 130 | def _create_synthesizer( 131 | cls, task_type: str, prompt_context: str, llm: Optional[LLM] = None, **kwargs 132 | ): 133 | """ 134 | Create a Synthesizer instance for a specific task type. 135 | 136 | Args: 137 | task_type (str): The type of task for which to create the synthesizer. 138 | prompt_context (str): The context for the prompt. 139 | llm (Optional[LLM]): The language model to use. If None, uses the default model. 140 | **kwargs: Additional keyword arguments for task configuration. 141 | 142 | Returns: 143 | Synthesizer: An instance of the Synthesizer class configured for the specified task. 144 | """ 145 | if llm: 146 | warnings.warn( 147 | "custom LLM passed, make sure to set do_sample=True for generation_kwargs within the llm" 148 | ) 149 | 150 | task_config = TASK_MAPPING[task_type] 151 | cls.input_columns = task_config["input_columns"] + ["prompt_context"] 152 | cls.output_columns = task_config["output_columns"] 153 | 154 | task_generator = cls._get_task_generator( 155 | task_type, llm or _DEFAULT_LLM, **kwargs 156 | ) 157 | next_input = cls._get_next_input_function( 158 | task_type, prompt_context, task_generator 159 | ) 160 | 161 | return cls(next_input, prompt_context) 162 | 163 | @staticmethod 164 | def _get_task_generator(task_type: str, llm: LLM, **kwargs): 165 | """ 166 | Get the appropriate task generator based on the task type. 167 | 168 | Args: 169 | task_type (str): The type of task. 170 | llm (LLM): The language model to use. 171 | **kwargs: Additional keyword arguments for task configuration. 172 | 173 | Returns: 174 | Any: An instance of the appropriate task generator. 175 | 176 | Raises: 177 | ValueError: If an unknown task type is provided. 178 | """ 179 | if task_type == "text-classification": 180 | task_generator = GenerateTextClassificationData(llm=llm, **kwargs) 181 | elif "image" in task_type: 182 | task_generator = _ImageGeneration() 183 | else: 184 | task_generator = Magpie(llm=llm) 185 | task_generator.set_runtime_parameters(kwargs.get("runtime_parameters", {})) 186 | task_generator.load() 187 | return task_generator 188 | 189 | @staticmethod 190 | def _get_next_input_function(task_type: str, prompt_context: str, task_generator): 191 | """ 192 | Get the appropriate next_input function based on the task type. 193 | 194 | Args: 195 | task_type (str): The type of task. 196 | prompt_context (str): The context for the prompt. 197 | task_generator: The task generator instance. 198 | 199 | Returns: 200 | callable: A function that generates the next input for the specified task type. 201 | 202 | Raises: 203 | ValueError: If an unknown task type is provided. 
204 | """ 205 | if task_type == "text-classification": 206 | 207 | def next_input(_text, _label, _prompt_context): 208 | _prompt_context = _prompt_context or prompt_context 209 | inputs = [{"task": _prompt_context}] 210 | data = next(task_generator.process(inputs))[0] 211 | return data["input_text"], None, _prompt_context 212 | elif task_type in ["text-generation", "chat-generation"]: 213 | 214 | def next_input(_instruction, _response, _prompt_context): 215 | _prompt_context = _prompt_context or prompt_context 216 | data = next( 217 | task_generator.process([{"system_prompt": _prompt_context}]) 218 | )[0] 219 | if task_type == "text-generation": 220 | return data["instruction"], data["response"], _prompt_context 221 | else: 222 | conversation = data["conversation"][:-1] 223 | response = data["conversation"][-1]["content"] 224 | return conversation, response, _prompt_context 225 | elif task_type in ["text-generation-preference", "chat-generation-preference"]: 226 | 227 | def next_input(_conversation, _response_1, _response_2, _prompt_context): 228 | _prompt_context = _prompt_context or prompt_context 229 | data = next( 230 | task_generator.process([{"system_prompt": _prompt_context}]) 231 | )[0] 232 | if task_type == "text-generation-preference": 233 | response_2 = task_generator.llm.generate( 234 | inputs=[[{"role": "user", "content": data["instruction"]}]], 235 | **_GENERATION_KWARGS, 236 | )[0][0] 237 | return ( 238 | data["instruction"], 239 | data["response"], 240 | response_2, 241 | _prompt_context, 242 | ) 243 | else: 244 | conversation = data["conversation"][:-1] 245 | response_1 = data["conversation"][-1]["content"] 246 | response_2 = task_generator.llm.generate( 247 | inputs=[conversation], **_GENERATION_KWARGS 248 | )[0][0] 249 | return conversation, response_1, response_2, _prompt_context 250 | elif task_type == "chat-classification": 251 | 252 | def next_input(_conversation, _label, _prompt_context): 253 | _prompt_context = _prompt_context or prompt_context 254 | data = next( 255 | task_generator.process([{"system_prompt": _prompt_context}]) 256 | )[0] 257 | return data["conversation"], None, _prompt_context 258 | elif task_type == "image-classification": 259 | 260 | def next_input(_image, _label, _prompt_context): 261 | _prompt_context = _prompt_context or prompt_context 262 | image = task_generator.process(_prompt_context) 263 | return image, None, _prompt_context 264 | elif task_type == "image-generation": 265 | 266 | def next_input( 267 | _prompt, _image, _prompt_context 268 | ): # -> tuple[Any, Any, Any | str]:# -> tuple[Any, Any, Any | str]: 269 | _prompt_context = _prompt or _prompt_context or prompt_context 270 | image = task_generator.process(_prompt_context) 271 | return _prompt_context, image, _prompt_context 272 | 273 | elif task_type == "image-description": 274 | 275 | def next_input(_image, _description, _prompt_context): 276 | _prompt_context = _prompt_context or prompt_context 277 | image = task_generator.process(_prompt_context) 278 | return image, None, _prompt_context 279 | elif task_type == "image-generation-preference": 280 | 281 | def next_input(_prompt, _image_1, _image_2, _prompt_context): 282 | _prompt_context = _prompt or _prompt_context or prompt_context 283 | image_1 = task_generator.process(_prompt_context) 284 | image_2 = task_generator.process(_prompt_context) 285 | return _prompt_context, image_1, image_2, _prompt_context 286 | elif task_type == "image-question-answering": 287 | 288 | def next_input(_image, _question, _answer, _prompt_context): 
289 | _prompt_context = _prompt_context or prompt_context
290 | image = task_generator.process(_prompt_context)
291 | return image, None, None, _prompt_context
292 | else:
293 | raise ValueError(f"Unknown task type: {task_type}")
294 |
295 | return next_input
296 |
297 | @classmethod
298 | def for_text_classification(
299 | cls, prompt_context: str, llm: Optional[LLM] = None, **kwargs
300 | ) -> "Synthesizer":
301 | """
302 | Create a Synthesizer for text classification tasks.
303 |
304 | Args:
305 | prompt_context (str): The context for the prompt.
306 | llm (Optional[LLM]): The language model to use. If None, uses the default model.
307 | **kwargs: Additional keyword arguments for task configuration.
308 |
309 | Returns:
310 | Synthesizer: An instance of the Synthesizer class configured for text classification.
311 | """
312 | return cls._create_synthesizer(
313 | "text-classification", prompt_context, llm, **kwargs
314 | )
315 |
316 | @classmethod
317 | def for_text_generation(
318 | cls, prompt_context: str, llm: Optional[LLM] = None
319 | ) -> "Synthesizer":
320 | """
321 | Create a Synthesizer for text generation tasks.
322 |
323 | Args:
324 | prompt_context (str): The context for the prompt.
325 | llm (Optional[LLM]): The language model to use. If None, uses the default model.
326 |
327 | Returns:
328 | Synthesizer: An instance of the Synthesizer class configured for text generation.
329 | """
330 | return cls._create_synthesizer(
331 | "text-generation",
332 | prompt_context,
333 | llm,
334 | runtime_parameters={"n_turns": 1, "end_with_user": False},
335 | )
336 |
337 | @classmethod
338 | def for_question_answering(
339 | cls, prompt_context: str, llm: Optional[LLM] = None
340 | ) -> "Synthesizer":
341 | raise NotImplementedError
342 | @classmethod
343 | def for_token_classification(
344 | cls, prompt_context: str, llm: Optional[LLM] = None
345 | ) -> "Synthesizer":
346 | raise NotImplementedError
347 |
348 | @classmethod
349 | def for_text_generation_preference(
350 | cls, prompt_context: str, llm: Optional[LLM] = None
351 | ) -> "Synthesizer":
352 | """
353 | Create a Synthesizer for text generation preference tasks.
354 |
355 | Args:
356 | prompt_context (str): The context for the prompt.
357 | llm (Optional[LLM]): The language model to use. If None, uses the default model.
358 |
359 | Returns:
360 | Synthesizer: An instance of the Synthesizer class configured for text generation preference.
361 | """
362 | return cls._create_synthesizer(
363 | "text-generation-preference",
364 | prompt_context,
365 | llm,
366 | runtime_parameters={"n_turns": 1, "end_with_user": False},
367 | )
368 |
369 | @classmethod
370 | def for_chat_generation(
371 | cls, prompt_context: str, llm: Optional[LLM] = None, n_turns: int = 2
372 | ) -> "Synthesizer":
373 | """
374 | Create a Synthesizer for chat generation tasks.
375 |
376 | Args:
377 | prompt_context (str): The context for the prompt.
378 | llm (Optional[LLM]): The language model to use. If None, uses the default model.
379 | n_turns (int): The number of turns in the conversation.
380 |
381 | Returns:
382 | Synthesizer: An instance of the Synthesizer class configured for chat generation.
383 | """ 384 | assert n_turns > 1, "n_turns must be greater than 1" 385 | return cls._create_synthesizer( 386 | "chat-generation", 387 | prompt_context, 388 | llm, 389 | runtime_parameters={"n_turns": n_turns, "end_with_user": False}, 390 | ) 391 | 392 | @classmethod 393 | def for_chat_classification( 394 | cls, prompt_context: str, llm: Optional[LLM] = None, n_turns: int = 2 395 | ) -> "Synthesizer": 396 | """ 397 | Create a Synthesizer for chat classification tasks. 398 | 399 | Args: 400 | prompt_context (str): The context for the prompt. 401 | llm (Optional[LLM]): The language model to use. If None, uses the default model. 402 | n_turns (int): The number of turns in the conversation. 403 | 404 | Returns: 405 | Synthesizer: An instance of the Synthesizer class configured for chat classification. 406 | """ 407 | assert n_turns > 1, "n_turns must be greater than 1" 408 | return cls._create_synthesizer( 409 | "chat-classification", 410 | prompt_context, 411 | llm, 412 | runtime_parameters={"n_turns": n_turns, "end_with_user": False}, 413 | ) 414 | 415 | @classmethod 416 | def for_chat_generation_preference( 417 | cls, prompt_context: str, llm: Optional[LLM] = None, n_turns: int = 2 418 | ) -> "Synthesizer": 419 | """ 420 | Create a Synthesizer for chat generation preference tasks. 421 | 422 | Args: 423 | prompt_context (str): The context for the prompt. 424 | llm (Optional[LLM]): The language model to use. If None, uses the default model. 425 | n_turns (int): The number of turns in the conversation. 426 | 427 | Returns: 428 | Synthesizer: An instance of the Synthesizer class configured for chat generation preference. 429 | """ 430 | assert n_turns > 1, "n_turns must be greater than 1" 431 | return cls._create_synthesizer( 432 | "chat-generation-preference", 433 | prompt_context, 434 | llm, 435 | runtime_parameters={"n_turns": n_turns, "end_with_user": False}, 436 | ) 437 | 438 | @classmethod 439 | def for_image_classification( 440 | cls, prompt_context: str, llm: Optional[str] = None, **kwargs 441 | ) -> "Synthesizer": 442 | """ 443 | Create a Synthesizer for image classification tasks. 444 | 445 | Args: 446 | prompt_context (str): The context for the prompt. 447 | llm (Optional[str]): The Hugging Face URL of the image generation model. If None, uses the default model. 448 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 449 | **kwargs: Additional keyword arguments for task configuration. 450 | 451 | Returns: 452 | Synthesizer: An instance of the Synthesizer class configured for image classification. 453 | """ 454 | return cls._create_synthesizer( 455 | "image-classification", prompt_context, llm, **kwargs 456 | ) 457 | 458 | @classmethod 459 | def for_image_generation( 460 | cls, prompt_context: str, llm: Optional[str] = None, **kwargs 461 | ) -> "Synthesizer": 462 | """ 463 | Create a Synthesizer for image generation tasks. 464 | 465 | Args: 466 | prompt_context (str): The context for the prompt. 467 | llm (Optional[str]): The Hugging Face URL of the image generation model. If None, uses the default model. 468 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 469 | **kwargs: Additional keyword arguments for task configuration. 470 | 471 | Returns: 472 | Synthesizer: An instance of the Synthesizer class configured for image generation. 
473 | """ 474 | return cls._create_synthesizer( 475 | "image-generation", prompt_context, llm, **kwargs 476 | ) 477 | 478 | @classmethod 479 | def for_image_description( 480 | cls, prompt_context: str, llm: Optional[str] = None, **kwargs 481 | ) -> "Synthesizer": 482 | """ 483 | Create a Synthesizer for image description tasks. 484 | 485 | Args: 486 | prompt_context (str): The context for the prompt. 487 | llm (Optional[str]): The Hugging Face URL of the image generation model. If None, uses the default model. 488 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 489 | **kwargs: Additional keyword arguments for task configuration. 490 | 491 | Returns: 492 | Synthesizer: An instance of the Synthesizer class configured for image description. 493 | """ 494 | return cls._create_synthesizer( 495 | "image-description", prompt_context, llm, **kwargs 496 | ) 497 | 498 | @classmethod 499 | def for_image_generation_preference( 500 | cls, prompt_context: str, llm: Optional[str] = None, **kwargs 501 | ) -> "Synthesizer": 502 | """ 503 | Create a Synthesizer for image generation preference tasks. 504 | 505 | Args: 506 | prompt_context (str): The context for the prompt. 507 | llm (Optional[str]): The Hugging Face URL of the image generation model. If None, uses the default model. 508 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 509 | **kwargs: Additional keyword arguments for task configuration. 510 | 511 | Returns: 512 | Synthesizer: An instance of the Synthesizer class configured for image generation preference. 513 | """ 514 | return cls._create_synthesizer( 515 | "image-generation-preference", prompt_context, llm, **kwargs 516 | ) 517 | 518 | @classmethod 519 | def for_image_question_answering( 520 | cls, prompt_context: str, llm: Optional[str] = None, **kwargs 521 | ) -> "Synthesizer": 522 | """ 523 | Create a Synthesizer for image question answering tasks. 524 | 525 | Args: 526 | prompt_context (str): The context for the prompt. 527 | llm (Optional[str]): The Hugging Face URL of the image generation model. If None, uses the default model. 528 | "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" 529 | **kwargs: Additional keyword arguments for task configuration. 530 | 531 | Returns: 532 | Synthesizer: An instance of the Synthesizer class configured for image question answering. 533 | """ 534 | return cls._create_synthesizer( 535 | "image-question-answering", prompt_context, llm, **kwargs 536 | ) 537 | -------------------------------------------------------------------------------- /src/dataset_viber/bulk.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-present, David Berenstein, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 |
15 | import importlib.util
16 | import warnings
17 | from collections import defaultdict
18 | from typing import TYPE_CHECKING, List, Optional, Union
19 |
20 | import dash
21 | import dash_bootstrap_components as dbc
22 | import pandas as pd
23 | import plotly.express as px
24 | import umap
25 | from dash import dash_table, dcc, html
26 | from dash.dependencies import Input, Output, State
27 | from plotly.graph_objs._figure import Figure
28 |
29 | from dataset_viber._constants import DEFAULT_EMBEDDING_MODEL
30 |
31 | if TYPE_CHECKING:
32 | from sentence_transformers import SentenceTransformer
33 |
34 |
35 | class BulkInterface:
36 | def __init__(
37 | self,
38 | dataframe: pd.DataFrame,
39 | content_column: str,
40 | *,
41 | label_column: str = None,
42 | embedding_model: Optional[
43 | Union["SentenceTransformer", str]
44 | ] = DEFAULT_EMBEDDING_MODEL,
45 | umap_kwargs: dict = {},
46 | labels: list[str] = None,
47 | dataset_name: Optional[str] = None,
48 | hf_token: Optional[str] = None,
49 | private: Optional[bool] = False,
50 | content_format: Optional[str] = "text",
51 | ):
52 | self.content_format = content_format
53 | self.content_column = content_column
54 | self.label_column = label_column
55 | self.labels = labels
56 | if label_column and labels:
57 | if not all([label in dataframe[label_column].unique() for label in labels]):
58 | # seed each missing label into the first empty rows of label_column
59 | empty_rows = dataframe[dataframe[label_column] == ""].index
60 | for label, row in zip(labels, empty_rows):
61 | dataframe.loc[row, label_column] = label
62 | warnings.warn(
63 | "Not all labels were found in the label_column; the missing labels were applied to the first empty rows."
64 | )
65 |
66 | contents = dataframe[content_column].tolist()
67 |
68 | # Apply embedding reduction
69 | component_columns: list[str] = ["x", "y"]
70 | if all([col in dataframe.columns for col in component_columns]):
71 | umap_df = dataframe[component_columns]
72 | else:
73 | if "embeddings" in dataframe.columns:
74 | embeddings = dataframe["embeddings"].tolist()
75 | else:
76 | self._set_embedding_model(embedding_model)
77 | embeddings = self.embed_content(contents)
78 | reducer = umap.UMAP(n_components=2, **umap_kwargs)
79 | umap_embeddings = reducer.fit_transform(embeddings)
80 | # Create a DataFrame for plotting
81 | umap_df = pd.DataFrame(
82 | umap_embeddings,
83 | columns=component_columns,
84 | )
85 |
86 | umap_df["index"] = dataframe.index
87 | for col in dataframe.columns:
88 | umap_df[col] = dataframe[col]
89 |
90 | self.umap_df = umap_df
91 | app = dash.Dash(__name__, external_stylesheets=[dbc.themes.SANDSTONE])
92 | figure = self._get_initial_figure(umap_df)
93 | app.layout = self._get_app_layout(figure, umap_df, labels, hf_token)
94 |
95 | if labels is not None:
96 |
97 | @app.callback(
98 | [
99 | Output("scatter-plot", "figure", allow_duplicate=True),
100 | Output("data-table", "data", allow_duplicate=True),
101 | Output("data-table", "tooltip_data", allow_duplicate=True),
102 | ],
103 | [Input("update-button", "n_clicks")],
104 | [
105 | State("scatter-plot", "selectedData"),
106 | State("label-dropdown", "value"),
107 | State("scatter-plot", "figure"),
108 | ],
109 | prevent_initial_call=True,
110 | )
111 | def update_labels(n_clicks, selectedData, new_label, current_figure):
112 | ctx = dash.callback_context
113 | if not ctx.triggered or new_label is None:
114 | return current_figure, self.umap_df.to_dict("records"), dash.no_update
115 |
116 | hidden_traces = []
117 | for trace in figure["data"]:
118 | if trace["visible"] == "legendonly":
119 | hidden_traces.append(trace)
120 |
121 | if selectedData and selectedData["points"]:
122 | selected_indices = [
123 | point["customdata"][0] for point in selectedData["points"]
124 | ]
125 | self.umap_df.loc[selected_indices, self.label_column] = new_label
126 | updated_traces = []
127 | points_to_move = defaultdict(list)
128 | for trace in current_figure["data"]:
129 | if trace not in hidden_traces:
130 | if new_label != trace["name"]:
131 | points_to_keep = defaultdict(list)
132 | for idx, point in enumerate(trace["customdata"]):
133 | if point[0] not in selected_indices:
134 | points_to_keep["customdata"].append(point)
135 | points_to_keep["x"].append(trace["x"][idx])
136 | points_to_keep["y"].append(trace["y"][idx])
137 | else:
138 | points_to_move["customdata"].append(point)
139 | points_to_move["x"].append(trace["x"][idx])
140 | points_to_move["y"].append(trace["y"][idx])
141 | trace["customdata"] = points_to_keep["customdata"]
142 | trace["x"] = points_to_keep["x"]
143 | trace["y"] = points_to_keep["y"]
144 | trace["selectedpoints"] = []
145 | updated_traces.append(trace)
146 | for trace in current_figure["data"]:
147 | if trace["name"] == new_label:
148 | trace["customdata"] += points_to_move["customdata"]
149 | trace["x"] += points_to_move["x"]
150 | trace["y"] += points_to_move["y"]
151 | trace["selectedpoints"] = []
152 | updated_traces.append(trace)
153 | current_figure["data"] = updated_traces
154 |
155 | local_dataframe = self.umap_df.copy()
156 | tooltip_data = self.get_tooltip(local_dataframe)
157 | if self.content_format == "chat":
158 | local_dataframe[self.content_column] = local_dataframe[
159 | self.content_column
160 | ].apply(lambda x: x[0]["content"])
161 | return current_figure, local_dataframe.to_dict("records"), tooltip_data
162 |
163 | # Callback to print the dataframe
164 | @app.callback(
165 | Output("data-table", "data", allow_duplicate=True),
166 | [Input("upload-button", "n_clicks")],
167 | prevent_initial_call=True,
168 | )
169 | def print_dataframe(n_clicks):
170 | if n_clicks > 0:
171 | print(self.umap_df)  # This will print the dataframe to the console
172 | return self.umap_df.to_dict(
173 | "records"
174 | )  # Return the data to avoid updating the table
175 |
176 | @app.callback(
177 | Output("download-text", "data"),
178 | Input("btn-download-txt", "n_clicks"),
179 | prevent_initial_call=True,
180 | )
181 | def func(n_clicks):
182 | return dcc.send_data_frame(self.umap_df.to_csv, "data.csv")
183 |
184 | # Update the existing update_selection callback
185 | @app.callback(
186 | [
187 | Output("scatter-plot", "figure", allow_duplicate=True),
188 | Output("data-table", "data", allow_duplicate=True),
189 | Output("data-table", "tooltip_data", allow_duplicate=True),
190 | ],
191 | [Input("scatter-plot", "selectedData")],
192 | [State("scatter-plot", "figure")],
193 | prevent_initial_call=True,
194 | )
195 | def update_selection(selectedData, figure):
196 | ctx = dash.callback_context
197 | if not ctx.triggered:
198 | return figure, self.umap_df.to_dict("records"), dash.no_update
199 |
200 | hidden_traces = []
201 | for trace in figure["data"]:
202 | if trace.get("visible") == "legendonly":
203 | hidden_traces.append(trace["name"])  # store names so they can be matched against the label column
204 |
205 | if selectedData and selectedData["points"]:
206 | selected_indices = [
207 | point["customdata"][0] for point in selectedData["points"]
208 | ]
209 | filtered_df = self.umap_df.iloc[selected_indices]
210 | else:
211 | filtered_df = self.umap_df
212 | selected_indices = None
213 |
214 | if hidden_traces:
215 | filtered_df = filtered_df[
216 | ~filtered_df[self.label_column].isin(hidden_traces)
217 | ]
218 |
219 | local_dataframe = filtered_df.copy()
220 | tooltip_data = self.get_tooltip(local_dataframe)
221 | if self.content_format == "chat":
222 | local_dataframe[self.content_column] = local_dataframe[
223 | self.content_column
224 | ].apply(lambda x: x[0]["content"])
225 | return figure, local_dataframe.to_dict("records"), tooltip_data
226 |
227 | self.app = app
228 |
229 | def _set_embedding_model(self, embedding_model: str):
230 | import torch
231 | from sentence_transformers import SentenceTransformer
232 |
233 | if isinstance(embedding_model, SentenceTransformer):
234 | self.embedding_model = embedding_model
235 | elif isinstance(embedding_model, str):
236 | device = "cpu"
237 | if torch.backends.mps.is_available():
238 | device = "mps"
239 | elif torch.cuda.is_available():
240 | device = "cuda"
241 | if importlib.util.find_spec("fast_sentence_transformers") is not None:
242 | from fast_sentence_transformers import FastSentenceTransformer
243 |
244 | self.embedding_model = FastSentenceTransformer(
245 | model_id=embedding_model, device=device
246 | )
247 | else:
248 | self.embedding_model = SentenceTransformer(
249 | model_name_or_path=embedding_model, device=device
250 | )
251 | else:
252 | raise ValueError(
253 | "Embedding model should be of type `str` or `SentenceTransformer`"
254 | )
255 |
256 | @classmethod
257 | def for_text_visualization(
258 | cls,
259 | dataframe: pd.DataFrame,
260 | content_column: str,
261 | *,
262 | label_column: str = None,
263 | embedding_model: Optional[
264 | Union["SentenceTransformer", str]
265 | ] = DEFAULT_EMBEDDING_MODEL,
266 | umap_kwargs: dict = {},
267 | ):
268 | return cls(
269 | dataframe=dataframe,
270 | content_column=content_column,
271 | label_column=label_column,
272 | embedding_model=embedding_model,
273 | umap_kwargs=umap_kwargs,
274 | content_format="text",
275 | )
276 |
277 | @classmethod
278 | def for_text_classification(
279 | cls,
280 | dataframe: pd.DataFrame,
281 | content_column: str,
282 | labels: list[str],
283 | *,
284 | label_column: str = None,
285 | embedding_model: Optional[
286 | Union["SentenceTransformer", str]
287 | ] = DEFAULT_EMBEDDING_MODEL,
288 | umap_kwargs: dict = {},
289 | dataset_name: Optional[str] = None,
290 | hf_token: Optional[str] = None,
291 | private: Optional[bool] = False,
292 | ):
293 | if not label_column:
294 | dataframe["label"] = ""
295 | label_column = "label"
296 | return cls(
297 | dataframe=dataframe,
298 | content_column=content_column,
299 | label_column=label_column,
300 | embedding_model=embedding_model,
301 | umap_kwargs=umap_kwargs,
302 | labels=labels,
303 | dataset_name=dataset_name,
304 | hf_token=hf_token,
305 | private=private,
306 | content_format="text",
307 | )
308 |
309 | @classmethod
310 | def for_chat_visualization(
311 | cls,
312 | dataframe: pd.DataFrame,
313 | chat_column: str,  # name of the column holding the list of chat messages
314 | *,
315 | label_column: str = None,
316 | embedding_model: Optional[
317 | Union["SentenceTransformer", str]
318 | ] = DEFAULT_EMBEDDING_MODEL,
319 | umap_kwargs: dict = {},
320 | ):
321 | return cls(
322 | dataframe=dataframe,
323 | content_column=chat_column,
324 | label_column=label_column,
325 | embedding_model=embedding_model,
326 | umap_kwargs=umap_kwargs,
327 | content_format="chat",
328 | )
329 |
330 | @classmethod
331 | def for_chat_classification(
332 | cls,
333 | dataframe: pd.DataFrame,
334 | chat_column: str,  # name of the column holding the list of chat messages
335 | labels: list[str],
336 | *,
337 | label_column: str = None,
338 | embedding_model: Optional[
339 | Union["SentenceTransformer", str]
340 | ] = DEFAULT_EMBEDDING_MODEL,
341 | umap_kwargs: dict = {},
342 | dataset_name: Optional[str] = None,
343 | hf_token: Optional[str] = None,
344 | private: Optional[bool] = False,
345 | ):
346 | if not label_column:
347 | dataframe["label"] = ""
348 | label_column = "label"
349 | return cls(
350 | dataframe=dataframe,
351 | content_column=chat_column,
352 | embedding_model=embedding_model,
353 | umap_kwargs=umap_kwargs,
354 | labels=labels,
355 | dataset_name=dataset_name,
356 | label_column=label_column,
357 | hf_token=hf_token,
358 | private=private,
359 | content_format="chat",
360 | )
361 |
362 | def launch(self, **kwargs):
363 | self.app.run_server(**kwargs)
364 |
365 | def _get_initial_figure(self, dataframe) -> Figure:
366 | # color_map = {label: color for label, color in zip(self.labels, _COLORS)}
367 | dataframe[f"wrapped_hover_{self.content_column}"] = dataframe[
368 | self.content_column
369 | ].apply(lambda x: self.format_content(x, content_format=self.content_format))
370 | custom_data = ["index"] + [
371 | col
372 | for col in dataframe.columns
373 | if col not in ["x", "y", "index", self.content_column]
374 | ]
375 | hovertemplate = ""
376 | df_custom = dataframe[custom_data]
377 | for col in df_custom.columns:
378 | if col in ["index", self.label_column]:
379 | continue
380 | idx = df_custom.columns.get_loc(col)
381 | hovertemplate += f"{col.replace('wrapped_hover_', '')}:<br>%{{customdata[{idx}]}}<br>
" 382 | fig = px.scatter( 383 | dataframe, 384 | x="x", 385 | y="y", 386 | color=self.label_column if self.label_column in dataframe.columns else None, 387 | height=800, 388 | custom_data=custom_data, 389 | ) 390 | fig.update_traces(hovertemplate=str(hovertemplate)) 391 | fig.update_layout( 392 | xaxis_title=None, 393 | yaxis_title=None, 394 | xaxis=dict(showticklabels=False), 395 | yaxis=dict(showticklabels=False), 396 | margin=dict(l=0, r=0, t=0, b=0), 397 | dragmode="lasso", 398 | legend=dict( 399 | orientation="h", # Horizontal legend 400 | yanchor="top", # Anchor the legend to the top of the container 401 | y=-0.01, # Position the legend below the plot 402 | xanchor="center", # Center the legend horizontally 403 | x=0.5, # Center the legend at the bottom 404 | ), 405 | hoverlabel=dict( 406 | font_size=10, 407 | font_family="monospace", 408 | ), 409 | ) 410 | return fig 411 | 412 | def _get_app_layout(self, figure, dataframe, labels, hf_token): 413 | local_dataframe = dataframe.copy() 414 | buttons = [] 415 | if labels is not None: 416 | buttons.extend( 417 | [ 418 | buttons.append( 419 | dcc.Dropdown( 420 | id="label-dropdown", 421 | options=[ 422 | {"label": label, "value": label} for label in labels 423 | ], 424 | value=labels[0], 425 | clearable=True, 426 | style={ 427 | "width": "200px", 428 | "marginBottom": "-13px", 429 | "display": "inline-block", 430 | }, 431 | ) 432 | ), 433 | buttons.append( 434 | dbc.Button( 435 | "Update Labels", 436 | id="update-button", 437 | n_clicks=0, 438 | ) 439 | ), 440 | dbc.Button( 441 | "Upload to Hub", 442 | id="upload-button", 443 | n_clicks=0, 444 | ), 445 | dbc.Button("Download Text", id="btn-download-txt"), 446 | dcc.Download(id="download-text"), 447 | ] 448 | ) 449 | if self.content_format == "chat": 450 | tooltip_data = self.get_tooltip(local_dataframe) 451 | local_dataframe[self.content_column] = local_dataframe[ 452 | self.content_column 453 | ].apply(lambda x: x[0]["content"]) 454 | columns = local_dataframe.columns 455 | elif self.content_format == "text": 456 | tooltip_data = None 457 | columns = local_dataframe.columns 458 | else: 459 | raise ValueError( 460 | "content_format should be either 'text' or 'chat' but got {self.content_format}" 461 | ) 462 | 463 | layout = html.Div( 464 | [ 465 | html.H1("BulkInterface"), 466 | # Scatter plot 467 | html.Div( 468 | [ 469 | dcc.Graph(id="scatter-plot", figure=figure), 470 | html.Div([*buttons]), 471 | ], 472 | style={ 473 | "width": "49%", 474 | "display": "inline-block", 475 | "vertical-align": "top", 476 | "marginRight": "1%", 477 | }, 478 | ), 479 | html.Div( 480 | [ 481 | dash_table.DataTable( 482 | id="data-table", 483 | columns=[{"name": i, "id": i} for i in columns], 484 | data=local_dataframe[columns].to_dict("records"), 485 | hidden_columns=[ 486 | "x", 487 | "y", 488 | "index", 489 | f"wrapped_hover_{self.content_column}", 490 | ], 491 | tooltip_data=tooltip_data, 492 | column_selectable=False, 493 | page_size=20, 494 | fill_width=True, 495 | css=[ 496 | {"selector": ".show-hide", "rule": "display: none"}, 497 | { 498 | "selector": ".dash-table-tooltip", 499 | "rule": """ 500 | background-color: grey; 501 | font-family: monospace; 502 | color: white; 503 | max-width: 100vw !important; 504 | max-height: 80vh !important; 505 | overflow: auto; 506 | font-size: 10px; 507 | position: fixed; 508 | top: 50%; 509 | left: 50%; 510 | transform: translate(-50%, -50%); 511 | z-index: 1000; 512 | """, 513 | }, 514 | ], 515 | style_cell={ 516 | "whiteSpace": "normal", 517 | "height": "auto", 
518 | "textAlign": "left", 519 | "font-size": "10px", 520 | "overflow": "auto", # Enable scrolling 521 | }, 522 | style_data={ 523 | "whiteSpace": "normal", 524 | "height": "auto", 525 | }, 526 | style_table={"overflowX": "auto"}, 527 | tooltip_duration=None, 528 | ) 529 | ], 530 | style={ 531 | "width": "50%", 532 | "display": "inline-block", 533 | "vertical-align": "top", 534 | }, 535 | ), 536 | ] 537 | ) 538 | return layout 539 | 540 | def embed_content(self, content: List[str]): 541 | if self.content_format == "text": 542 | return self.embedding_model.encode(content, convert_to_numpy=True) 543 | elif self.content_format == "chat": 544 | content = [ 545 | " ".join([turn["content"] for turn in conversation]) 546 | for conversation in content 547 | ] 548 | return self.embedding_model.encode(content, convert_to_numpy=True) 549 | 550 | def format_content(self, content, max_length=120, content_format="text"): 551 | wrapped_text = "" 552 | if content_format == "text": 553 | words = content.split(" ") 554 | line = "" 555 | 556 | for word in words: 557 | if len(line) + len(word) + 1 > max_length: 558 | if line: 559 | wrapped_text += line + "
" 560 | line = word 561 | else: 562 | if line: 563 | line += " " + word 564 | else: 565 | line = word 566 | 567 | wrapped_text += line 568 | return wrapped_text 569 | elif content_format == "chat": 570 | wrapped_text = "First 2 turns:

" 571 | for turn in content[:3]: 572 | wrapped_text += f"{turn['role']}:
{self.format_content(turn['content'])}

" 573 | return wrapped_text 574 | 575 | def get_tooltip(self, dataframe): 576 | if self.content_format == "text": 577 | return None 578 | return [ 579 | { 580 | self.content_column: { 581 | "value": pd.DataFrame.from_records(value)[ 582 | ["role", "content"] 583 | ].to_markdown(index=False, tablefmt="pipe"), 584 | "type": "markdown", 585 | } 586 | } 587 | for value in dataframe[self.content_column].tolist() 588 | ] 589 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # Dataset Viber
3 |
4 |
5 |
6 |
7 |
8 | > Avoid the hype, check the vibe!
9 | 10 | I've cooked up Dataset Viber, a cool set of tools to make your life easier when dealing with data for AI models. Dataset Viber is all about making your data prep journey smooth and fun. It's **not for team collaboration or production**, and it's not trying to be all fancy and formal - just a bunch of **cool tools to help you collect feedback and do vibe-checks** as an AI engineer or AI enthusiast. Want to see it in action? Just plug it in and start vibing with your data. It's that easy! 11 | 12 | - **CollectorInterface**: Lazily collect data from model interactions without human annotation. 13 | - **AnnotatorInterface**: Walk through your data and annotate it with models in the loop. 14 | - **Synthesizer**: Synthesize data with `distilabel` in the loop. 15 | - **BulkInterface**: Explore your data distribution and annotate in bulk. 16 | 17 | Need any tweaks or want to hear more about a specific tool? Just [open an issue](https://github.com/davidberenstein1957/dataset-viber/issues/new) or give me a shout! 18 | 19 | > [!NOTE] 20 | > 21 | > - Data is logged to a local CSV or directly to the Hugging Face Hub. 22 | > - All tools also run in `.ipynb` notebooks. 23 | > - Put models in the loop through the `fn_model` argument. 24 | > - Feed inputs with custom data streamers or pre-built `Synthesizer` classes through the `fn_next_input` argument. 25 | > - It supports various tasks for `text`, `chat` and `image` modalities. 26 | > - Import and export from the Hugging Face Hub or CSV files. 27 | 28 | > [!TIP] 29 | > 30 | > - Code examples: [src/dataset_viber/examples](https://github.com/davidberenstein1957/dataset-viber/tree/main/src/dataset_viber/examples). 31 | > - Hub examples: [https://huggingface.co/dataset-viber](https://huggingface.co/dataset-viber). 32 | 33 | ## Installation 34 | 35 | You can install the package via pip: 36 | 37 | ```bash 38 | pip install dataset-viber 39 | ``` 40 | 41 | Or install the `Synthesizer` dependencies. Note that the `Synthesizer` relies on `distilabel[hf-inference-endpoints]`, but you can use other [LLMs available to distilabel](https://distilabel.argilla.io) too, for example `distilabel[ollama]`. 42 | 43 | ```bash 44 | pip install dataset-viber[synthesizer] 45 | ``` 46 | 47 | Or install the `BulkInterface` dependencies: 48 | 49 | ```bash 50 | pip install dataset-viber[bulk] 51 | ``` 52 | 53 | ## How are we vibing? 54 | 55 | ### CollectorInterface 56 | 57 | > Built on top of `gr.Interface` and `gr.ChatInterface` to lazily collect interaction data automatically. 58 | 59 | 60 | 61 | [Hub dataset](https://huggingface.co/datasets/davidberenstein1957/dataset-viber-token-classification) 62 | 63 |
64 | CollectorInterface 65 | 66 | ```python 67 | import gradio as gr 68 | from dataset_viber import CollectorInterface 69 | 70 | def calculator(num1, operation, num2): 71 | if operation == "add": 72 | return num1 + num2 73 | elif operation == "subtract": 74 | return num1 - num2 75 | elif operation == "multiply": 76 | return num1 * num2 77 | elif operation == "divide": 78 | return num1 / num2 79 | 80 | inputs = ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"] 81 | outputs = "number" 82 | 83 | interface = CollectorInterface( 84 | fn=calculator, 85 | inputs=inputs, 86 | outputs=outputs, 87 | csv_logger=False, # True if you want to log to a CSV 88 | dataset_name="<my_hf_org>/<my_dataset>" 89 | ) 90 | interface.launch() 91 | ``` 92 | 93 |
94 | 95 |
96 | CollectorInterface.from_interface 97 | 98 | ```python 99 | interface = gr.Interface( 100 | fn=calculator, 101 | inputs=inputs, 102 | outputs=outputs 103 | ) 104 | interface = CollectorInterface.from_interface( 105 | interface=interface, 106 | csv_logger=False, # True if you want to log to a CSV 107 | dataset_name="<my_hf_org>/<my_dataset>" 108 | ) 109 | interface.launch() 110 | ``` 111 | 112 |
113 | 114 |
115 | CollectorInterface.from_pipeline 116 | 117 | ```python 118 | from transformers import pipeline 119 | from dataset_viber import CollectorInterface 120 | 121 | pipeline = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection") 122 | interface = CollectorInterface.from_pipeline( 123 | pipeline=pipeline, 124 | csv_logger=False, # True if you want to log to a CSV 125 | dataset_name="<my_hf_org>/<my_dataset>" 126 | ) 127 | interface.launch() 128 | ``` 129 | 130 |
131 | 132 | ### AnnotatorInterface 133 | 134 | > Built on top of the `CollectorInterface` to collect and annotate data and log it to the Hub. 135 | 136 | 137 | #### Text 138 | 139 | https://github.com/user-attachments/assets/d1abda66-9972-4c60-89d2-7626f5654f15 140 | 141 | [Hub dataset](https://huggingface.co/datasets/davidberenstein1957/dataset-viber-text-classification) 142 | 143 |
144 | text-classification/multi-label-text-classification 145 | 146 | ```python 147 | from dataset_viber import AnnotatorInterFace 148 | 149 | texts = [ 150 | "Anthony Bourdain was an amazing chef!", 151 | "Anthony Bourdain was a terrible tv persona!" 152 | ] 153 | labels = ["positive", "negative"] 154 | 155 | interface = AnnotatorInterFace.for_text_classification( 156 | texts=texts, 157 | labels=labels, 158 | multi_label=False, # True if you have multi-label data 159 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 160 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 161 | csv_logger=False, # True if you want to log to a CSV 162 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 163 | ) 164 | interface.launch() 165 | ``` 166 | 167 |
168 | 169 |
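For reference, here is a minimal sketch of what a `fn_model` callable could look like for the text-classification example above. This is an illustration, not the library's documented API: it assumes `fn_model` receives the current text and returns the suggested label as a `str`, as the comment in the example describes.

```python
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="mrm8488/bert-tiny-finetuned-sms-spam-detection",
)

def fn_model(text: str) -> str:
    # return the top predicted label as a pre-annotation suggestion
    return classifier(text)[0]["label"]
```

Pass it as `fn_model=fn_model` to pre-fill predictions while you annotate.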
170 | token-classification 171 | 172 | ```python 173 | from dataset_viber import AnnotatorInterFace 174 | 175 | texts = ["Anthony Bourdain was an amazing chef in New York."] 176 | labels = ["NAME", "LOC"] 177 | 178 | interface = AnnotatorInterFace.for_token_classification( 179 | texts=texts, 180 | labels=labels, 181 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 182 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 183 | csv_logger=False, # True if you want to log to a CSV 184 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 185 | ) 186 | interface.launch() 187 | ``` 188 | 189 |
190 | 191 |
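Similarly, here is a hypothetical sketch of a custom `fn_next_input` data streamer. The exact signature and return shape depend on the task's gradio components, so treat this as an illustration only and see [src/dataset_viber/examples](https://github.com/davidberenstein1957/dataset-viber/tree/main/src/dataset_viber/examples) for working versions.

```python
import random

candidate_texts = [
    "Anthony Bourdain was an amazing chef!",
    "Anthony Bourdain was a terrible tv persona!",
]

def fn_next_input(*args):
    # naively stream a random candidate as the next input to annotate
    return random.choice(candidate_texts)
```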
192 | extractive-question-answering 193 | 194 | ```python 195 | from dataset_viber import AnnotatorInterFace 196 | 197 | questions = ["Where was Anthony Bourdain located?"] 198 | contexts = ["Anthony Bourdain was an amazing chef in New York."] 199 | 200 | interface = AnnotatorInterFace.for_question_answering( 201 | questions=questions, 202 | contexts=contexts, 203 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 204 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 205 | csv_logger=False, # True if you want to log to a CSV 206 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 207 | ) 208 | interface.launch() 209 | ``` 210 | 211 |
212 | 213 |
214 | text-generation/translation/completion 215 | 216 | ```python 217 | from dataset_viber import AnnotatorInterFace 218 | 219 | prompts = ["Tell me something about Anthony Bourdain."] 220 | completions = ["Anthony Michael Bourdain was an American celebrity chef, author, and travel documentarian."] 221 | 222 | interface = AnnotatorInterFace.for_text_generation( 223 | prompts=prompts, # source 224 | completions=completions, # optional to show initial completion / target 225 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 226 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 227 | csv_logger=False, # True if you want to log to a CSV 228 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 229 | ) 230 | interface.launch() 231 | ``` 232 | 233 |
234 | 235 |
236 | text-generation-preference 237 | 238 | ```python 239 | from dataset_viber import AnnotatorInterFace 240 | 241 | prompts = ["Tell me something about Anthony Bourdain."] 242 | completions_a = ["Anthony Michael Bourdain was an American celebrity chef, author, and travel documentarian."] 243 | completions_b = ["Anthony Michael Bourdain was a cool guy that knew how to cook."] 244 | 245 | interface = AnnotatorInterFace.for_text_generation_preference( 246 | prompts=prompts, 247 | completions_a=completions_a, 248 | completions_b=completions_b, 249 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 250 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 251 | csv_logger=False, # True if you want to log to a CSV 252 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 253 | ) 254 | interface.launch() 255 | ``` 256 | 257 |
258 | 259 | #### Chat and multi-modal chat 260 | 261 | https://github.com/user-attachments/assets/fe7f0139-95a3-40e8-bc03-e37667d4f7a9 262 | 263 | [Hub dataset](https://huggingface.co/datasets/davidberenstein1957/dataset-viber-chat-generation-preference) 264 | 265 | > [!TIP] 266 | > I recommend uploading the files to cloud storage and using the remote URLs to avoid any issues. This can be done [using Hugging Face Datasets](https://huggingface.co/docs/datasets/en/image_load#local-files), as shown in [utils](#utils). Additionally, [GradioChatbot](https://www.gradio.app/docs/gradio/chatbot#behavior) shows how to use the chatbot interface for multi-modal content. 267 | 268 |
269 | chat-classification 270 | 271 | ```python 272 | from dataset_viber import AnnotatorInterFace 273 | 274 | prompts = [ 275 | [ 276 | { 277 | "role": "user", 278 | "content": "Tell me something about Anthony Bourdain." 279 | }, 280 | { 281 | "role": "assistant", 282 | "content": "Anthony Michael Bourdain was an American celebrity chef, author, and travel documentarian." 283 | } 284 | ] 285 | ] 286 | 287 | interface = AnnotatorInterFace.for_chat_classification( 288 | prompts=prompts, 289 | labels=["toxic", "non-toxic"], 290 | multi_label=False, # True if you have multi-label data 291 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 292 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 293 | csv_logger=False, # True if you want to log to a CSV 294 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 295 | ) 296 | interface.launch() 297 | ``` 298 | 299 |
300 | 301 |
302 | chat-generation 303 | 304 | ```python 305 | from dataset_viber import AnnotatorInterFace 306 | 307 | prompts = [ 308 | [ 309 | { 310 | "role": "user", 311 | "content": "Tell me something about Anthony Bourdain." 312 | } 313 | ] 314 | ] 315 | 316 | completions = [ 317 | "Anthony Michael Bourdain was an American celebrity chef, author, and travel documentarian.", 318 | ] 319 | 320 | interface = AnnotatorInterFace.for_chat_generation( 321 | prompts=prompts, 322 | completions=completions, 323 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 324 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 325 | csv_logger=False, # True if you want to log to a CSV 326 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 327 | ) 328 | interface.launch() 329 | ``` 330 | 331 |
332 | 333 |
334 | chat-generation-preference 335 | 336 | ```python 337 | from dataset_viber import AnnotatorInterFace 338 | 339 | prompts = [ 340 | [ 341 | { 342 | "role": "user", 343 | "content": "Tell me something about Anthony Bourdain." 344 | } 345 | ] 346 | ] 347 | completions_a = [ 348 | "Anthony Michael Bourdain was an American celebrity chef, author, and travel documentarian.", 349 | ] 350 | completions_b = [ 351 | "Anthony Michael Bourdain was a cool guy that knew how to cook." 352 | ] 353 | 354 | interface = AnnotatorInterFace.for_chat_generation_preference( 355 | prompts=prompts, 356 | completions_a=completions_a, 357 | completions_b=completions_b, 358 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 359 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 360 | csv_logger=False, # True if you want to log to a CSV 361 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 362 | ) 363 | interface.launch() 364 | ``` 365 | 366 |
367 | 368 | #### Image and multi-modal 369 | 370 | 371 | 372 | [Hub dataset](https://huggingface.co/datasets/davidberenstein1957/dataset-viber-image-question-answering) 373 | 374 | > [!TIP] 375 | > I recommend uploading the files to cloud storage and using the remote URLs to avoid any issues. This can be done [using Hugging Face Datasets](https://huggingface.co/docs/datasets/en/image_load#local-files), as shown in [utils](#utils). 376 | 377 |
378 | image-classification/multi-label-image-classification 379 | 380 | ```python 381 | from dataset_viber import AnnotatorInterFace 382 | 383 | images = [ 384 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 385 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg" 386 | ] 387 | labels = ["anthony-bourdain", "not-anthony-bourdain"] 388 | 389 | interface = AnnotatorInterFace.for_image_classification( 390 | images=images, 391 | labels=labels, 392 | multi_label=False, # True if you have multi-label data 393 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 394 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 395 | csv_logger=False, # True if you want to log to a CSV 396 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 397 | ) 398 | interface.launch() 399 | ``` 400 | 401 |
402 | 403 |
404 | image-generation 405 | 406 | ```python 407 | from dataset_viber import AnnotatorInterFace 408 | 409 | prompts = [ 410 | "Anthony Bourdain laughing", 411 | "David Chang wearing a suit" 412 | ] 413 | images = [ 414 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg", 415 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 416 | ] 417 | 418 | interface = AnnotatorInterFace.for_image_generation( 419 | prompts=prompts, 420 | completions=images, 421 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 422 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 423 | csv_logger=False, # True if you want to log to a CSV 424 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 425 | ) 426 | 427 | interface.launch() 428 | ``` 429 | 430 |
431 | 432 |
433 | image-description 434 | 435 | ```python 436 | from dataset_viber import AnnotatorInterFace 437 | 438 | images = [ 439 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 440 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg" 441 | ] 442 | descriptions = ["Anthony Bourdain laughing", "David Chang wearing a suit"] 443 | 444 | interface = AnnotatorInterFace.for_image_description( 445 | images=images, 446 | descriptions=descriptions, # optional to show initial descriptions 447 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 448 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 449 | csv_logger=False, # True if you want to log to a CSV 450 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 451 | ) 452 | interface.launch() 453 | ``` 454 | 455 |
456 | 457 |
458 | image-question-answering/visual-question-answering 459 | 460 | ```python 461 | from dataset_viber import AnnotatorInterFace 462 | 463 | images = [ 464 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 465 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg" 466 | ] 467 | questions = ["Who is this?", "What is he wearing?"] 468 | answers = ["Anthony Bourdain", "a suit"] 469 | 470 | interface = AnnotatorInterFace.for_image_question_answering( 471 | images=images, 472 | questions=questions, # optional to show initial questions 473 | answers=answers, # optional to show initial answers 474 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 475 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 476 | csv_logger=False, # True if you want to log to a CSV 477 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 478 | ) 479 | interface.launch() 480 | ``` 481 | 482 |
483 | 484 |
485 | image-generation-preference 486 | 487 | ```python 488 | from dataset_viber import AnnotatorInterFace 489 | 490 | prompts = [ 491 | "Anthony Bourdain laughing", 492 | "David Chang wearing a suit" 493 | ] 494 | 495 | images_a = [ 496 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg", 497 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 498 | ] 499 | 500 | images_b = [ 501 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Anthony_Bourdain_Peabody_2014b.jpg/440px-Anthony_Bourdain_Peabody_2014b.jpg", 502 | "https://upload.wikimedia.org/wikipedia/commons/8/85/David_Chang_David_Shankbone_2010.jpg" 503 | ] 504 | 505 | interface = AnnotatorInterFace.for_image_generation_preference( 506 | prompts=prompts, 507 | completions_a=images_a, 508 | completions_b=images_b, 509 | fn_model=None, # a callable (e.g. a function or a transformers pipeline) that returns `str` 510 | fn_next_input=None, # a function that actively feeds the gradio components with the next input 511 | csv_logger=False, # True if you want to log to a CSV 512 | dataset_name=None # "<my_hf_org>/<my_dataset>" if you want to log to the hub 513 | ) 514 | interface.launch() 515 | ``` 516 | 517 |
518 | 519 | ### Synthesizer 520 | 521 | > Built on top of `distilabel` to synthesize data with models in the loop. 522 | 523 | > [!TIP] 524 | > You can also call the synthesizer directly to generate data: use `synthesizer() -> Tuple` or `Synthesizer.batch_synthesize(n) -> List[Tuple]` to get inputs for the various tasks, as sketched below. 525 | 526 |
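For example, a quick sketch that only uses the two calls named in the tip above (the exact tuple contents depend on the task you pick):

```python
from dataset_viber.synthesizer import Synthesizer

synthesizer = Synthesizer.for_text_generation(
    prompt_context="Phone company customer support."
)

single_input = synthesizer()  # one input, as a Tuple
batch_of_inputs = synthesizer.batch_synthesize(5)  # five inputs, as a List[Tuple]
```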
527 | text-classification 528 | 529 | ```python 530 | from dataset_viber import AnnotatorInterFace 531 | from dataset_viber.synthesizer import Synthesizer 532 | 533 | synthesizer = Synthesizer.for_text_classification( 534 | prompt_context="IMDB movie reviews" 535 | ) 536 | 537 | interface = AnnotatorInterFace.for_text_classification( 538 | fn_next_input=synthesizer, 539 | labels=["positive", "negative"] 540 | ) 541 | interface.launch() 542 | ``` 543 | 544 |
545 | 546 |
547 | text-generation 548 | 549 | ```python 550 | from dataset_viber import AnnotatorInterFace 551 | from dataset_viber.synthesizer import Synthesizer 552 | 553 | synthesizer = Synthesizer.for_text_generation( 554 | prompt_context="Phone company customer support." 555 | ) 556 | 557 | interface = AnnotatorInterFace.for_text_generation( 558 | fn_next_input=synthesizer 559 | ) 560 | interface.launch() 561 | ``` 562 | 563 |
564 | 565 |
566 | chat-classification 567 | 568 | ```python 569 | from dataset_viber import AnnotatorInterFace 570 | from dataset_viber.synthesizer import Synthesizer 571 | 572 | synthesizer = Synthesizer.for_chat_classification( 573 | prompt_context="Phone company customer support." 574 | ) 575 | 576 | interface = AnnotatorInterFace.for_chat_classification( 577 | fn_next_input=synthesizer, 578 | labels=["positive", "negative"] 579 | ) 580 | interface.launch() 581 | ``` 582 | 583 |
584 | 585 |
586 | chat-generation 587 | 588 | ```python 589 | from dataset_viber import AnnotatorInterFace 590 | from dataset_viber.synthesizer import Synthesizer 591 | 592 | synthesizer = Synthesizer.for_chat_generation( 593 | prompt_context="Phone company customer support." 594 | ) 595 | 596 | interface = AnnotatorInterFace.for_chat_generation( 597 | fn_next_input=synthesizer 598 | ) 599 | interface.launch() 600 | ``` 601 | 602 |
603 | 604 |
605 | chat-generation-preference 606 | 607 | ```python 608 | from dataset_viber import AnnotatorInterFace 609 | from dataset_viber.synthesizer import Synthesizer 610 | 611 | synthesizer = Synthesizer.for_chat_generation_preference( 612 | prompt_context="Phone company customer support." 613 | ) 614 | 615 | interface = AnnotatorInterFace.for_chat_generation_preference( 616 | fn_next_input=synthesizer 617 | ) 618 | interface.launch() 619 | ``` 620 | 621 |
622 | 623 |
624 | image-classification 625 | 626 | ```python 627 | from dataset_viber import AnnotatorInterFace 628 | from dataset_viber.synthesizer import Synthesizer 629 | 630 | synthesizer = Synthesizer.for_image_classification( 631 | prompt_context="Phone company customer support." 632 | ) 633 | 634 | interface = AnnotatorInterFace.for_image_classification( 635 | fn_next_input=synthesizer, 636 | labels=["positive", "negative"] 637 | ) 638 | interface.launch() 639 | ``` 640 | 641 |
642 | 643 |
644 | 645 | image-generation 646 | 647 | ```python 648 | from dataset_viber import AnnotatorInterFace 649 | from dataset_viber.synthesizer import Synthesizer 650 | 651 | synthesizer = Synthesizer.for_image_generation( 652 | prompt_context="Phone company customer support." 653 | ) 654 | 655 | interface = AnnotatorInterFace.for_image_generation( 656 | fn_next_input=synthesizer 657 | ) 658 | interface.launch() 659 | ``` 660 | 661 |
662 | 663 |
664 | 665 | image-description 666 | 667 | ```python 668 | from dataset_viber import AnnotatorInterFace 669 | from dataset_viber.synthesizer import Synthesizer 670 | 671 | synthesizer = Synthesizer.for_image_description( 672 | prompt_context="Phone company customer support." 673 | ) 674 | 675 | interface = AnnotatorInterFace.for_image_description( 676 | fn_next_input=synthesizer 677 | ) 678 | interface.launch() 679 | ``` 680 | 681 |
682 | 683 |
684 | 685 | image-question-answering 686 | 687 | ```python 688 | from dataset_viber import AnnotatorInterFace 689 | from dataset_viber.synthesizer import Synthesizer 690 | 691 | synthesizer = Synthesizer.for_image_question_answering( 692 | prompt_context="Phone company customer support." 693 | ) 694 | 695 | interface = AnnotatorInterFace.for_image_question_answering( 696 | fn_next_input=synthesizer 697 | ) 698 | interface.launch() 699 | ``` 700 | 701 |
702 | 703 |
704 | 705 | image-generation-preference 706 | 707 | ```python 708 | from dataset_viber import AnnotatorInterFace 709 | from dataset_viber.synthesizer import Synthesizer 710 | 711 | synthesizer = Synthesizer.for_image_generation_preference( 712 | prompt_context="Phone company customer support." 713 | ) 714 | 715 | interface = AnnotatorInterFace.for_image_generation_preference( 716 | fn_next_input=synthesizer 717 | ) 718 | interface.launch() 719 | ``` 720 | 721 |
722 | 723 | ### BulkInterface 724 | 725 | > Built on top of `Dash`, `plotly-express`, `umap-learn`, and `fast-sentence-transformers` to embed your data, understand its distribution, and annotate it in bulk. 726 | 727 | https://github.com/user-attachments/assets/5e96c06d-e37f-45a0-9633-1a8e714d71ed 728 | 729 | [Hub dataset](https://huggingface.co/datasets/SetFit/ag_news) 730 | 731 |
732 | text-visualization 733 | 734 | ```python 735 | from dataset_viber import BulkInterface 736 | from datasets import load_dataset 737 | 738 | ds = load_dataset("SetFit/ag_news", split="train[:2000]") 739 | 740 | interface: BulkInterface = BulkInterface.for_text_visualization( 741 | ds.to_pandas()[["text", "label_text"]], 742 | content_column='text', 743 | label_column='label_text', 744 | ) 745 | interface.launch() 746 | ``` 747 | 748 |
749 | 750 |
751 | text-classification 752 | 753 | ```python 754 | from dataset_viber import BulkInterface 755 | from datasets import load_dataset 756 | 757 | ds = load_dataset("SetFit/ag_news", split="train[:2000]") 758 | df = ds.to_pandas()[["text", "label_text"]] 759 | 760 | interface = BulkInterface.for_text_classification( 761 | dataframe=df, 762 | content_column='text', 763 | label_column='label_text', 764 | labels=df['label_text'].unique().tolist() 765 | ) 766 | interface.launch() 767 | ``` 768 | 769 |
770 | 771 |
772 | chat-visualization 773 | 774 | ```python 775 | from dataset_viber.bulk import BulkInterface 776 | from datasets import load_dataset 777 | 778 | ds = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train[:1000]") 779 | df = ds.to_pandas()[["chosen"]] 780 | 781 | interface = BulkInterface.for_chat_visualization( 782 | dataframe=df, 783 | chat_column='chosen', 784 | ) 785 | interface.launch() 786 | ``` 787 | 788 |
789 | 790 |
791 | chat-classification 792 | 793 | ```python 794 | from dataset_viber.bulk import BulkInterface 795 | from datasets import load_dataset 796 | 797 | ds = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train[:1000]") 798 | df = ds.to_pandas()[["chosen"]] 799 | 800 | interface = BulkInterface.for_chat_classification( 801 | dataframe=df, 802 | chat_column='chosen', 803 | labels=["math", "science", "history", "question seeking"], 804 | ) 805 | interface.launch() 806 | ``` 807 | 808 |
809 | 810 | ### Utils 811 | 812 |
813 | Shuffle inputs in the same order 814 | 815 | When working with multiple inputs, you might want to shuffle them all in the same order. 816 | 817 | ```python 818 | import random 819 | 820 | def shuffle_lists(*lists): 821 | if not lists: 822 | return [] 823 | 824 | # Check that all lists have the same length 825 | length = len(lists[0]) 826 | if not all(len(lst) == length for lst in lists): 827 | raise ValueError("All input lists must have the same length") 828 | 829 | # Create a list of indices and shuffle it 830 | indices = list(range(length)) 831 | random.shuffle(indices) 832 | 833 | # Reorder each list based on the shuffled indices 834 | return [ 835 | [lst[i] for i in indices] 836 | for lst in lists 837 | ] 838 | ``` 839 | 840 |
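For instance, with hypothetical prompt/completion lists:

```python
prompts = ["prompt 1", "prompt 2", "prompt 3"]
completions = ["completion 1", "completion 2", "completion 3"]

# both lists are reordered with the same shuffled indices
prompts, completions = shuffle_lists(prompts, completions)
```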
841 | 842 |
843 | Random swap to randomize completions 844 | 845 | When working with multiple lists of completions, you might want to randomize which list each completion ends up in: at every index, the completions are randomly swapped across the lists. This is useful for preference learning. 846 | 847 | ```python 848 | import random 849 | 850 | def swap_completions(*lists): 851 | # Check that all lists have the same length 852 | length = len(lists[0]) 853 | if not all(len(lst) == length for lst in lists): 854 | raise ValueError("All input lists must have the same length") 855 | 856 | # Convert the input lists (which may be tuples) to a list of lists 857 | lists = [list(lst) for lst in lists] 858 | 859 | # Iterate over each index 860 | for i in range(length): 861 | # Get the elements at index i from all lists 862 | elements = [lst[i] for lst in lists] 863 | 864 | # Randomly shuffle the elements 865 | random.shuffle(elements) 866 | 867 | # Assign the shuffled elements back to the lists 868 | for j, lst in enumerate(lists): 869 | lst[i] = elements[j] 870 | 871 | return lists 872 | ``` 873 | 874 |
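For instance, with hypothetical completion lists:

```python
completions_a = ["completion a1", "completion a2"]
completions_b = ["completion b1", "completion b2"]

# each index now holds a random ordering of the original pair
completions_a, completions_b = swap_completions(completions_a, completions_b)
```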
875 | 876 |
877 | Load remote image URLs from Hugging Face Hub 878 | 879 | When working with images, you might want to load remote URLs from the Hugging Face Hub. 880 | 881 | ```python 882 | from datasets import Image, load_dataset 883 | 884 | dataset = load_dataset( 885 | "my_hf_org/my_image_dataset", split="train" 886 | ).cast_column("my_image_column", Image(decode=False)) 887 | dataset[0]["my_image_column"] 888 | # {'bytes': None, 'path': 'path_to_image.jpg'} 889 | ``` 890 | 891 |
892 | 893 | ## Contribute and development setup 894 | 895 | First, [install PDM](https://pdm-project.org/latest/#installation). 896 | 897 | Then, install the environment; this will automatically create a `.venv` virtual environment and install the dev dependencies. 898 | 899 | ```bash 900 | pdm install 901 | ``` 902 | 903 | Lastly, install the pre-commit hooks to run formatting on every commit. 904 | 905 | ```bash 906 | pre-commit install 907 | ``` 908 | 909 | Follow this [guide on making first contributions](https://github.com/firstcontributions/first-contributions?tab=readme-ov-file#first-contributions). 910 | 911 | ## References 912 | 913 | ### Logo 914 | 915 | Keyboard icons created by srip - Flaticon 916 | 917 | ### Inspirations 918 | 919 | - 920 | - 921 | - 922 | - 923 | --------------------------------------------------------------------------------