str:
37 | for size, category in AVAILABLE_SIZE_CATEGORIES.items():
38 | if input_size < size:
39 | return category
40 | return "n>1T"
41 |
42 |
class DistilabelDatasetCard(DatasetCard):
    """A `DatasetCard` subclass that uses the Distilabel template by default.

    The only customisation over `DatasetCard` is `default_template_path`, which
    is set to `TEMPLATE_DISTILABEL_DATASET_CARD_PATH`.
    """

    # NOTE(review): presumably read by `DatasetCard` when no explicit template
    # path is supplied -- confirm against the parent class.
    default_template_path = TEMPLATE_DISTILABEL_DATASET_CARD_PATH
47 |
--------------------------------------------------------------------------------
/src/distilabel/utils/chat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any
16 |
17 |
def is_openai_format(input: Any) -> bool:
    """Tell whether `input` looks like an OpenAI-style chat conversation.

    A conversation is a `list` whose items are all dictionaries holding at
    least the `"role"` and `"content"` keys, e.g.:

    ```python
    [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help you?"},
    ]
    ```

    Args:
        input: The object to inspect.

    Returns:
        `True` if `input` is a list and every item is a dict with both keys
        (an empty list therefore qualifies), `False` otherwise.
    """
    if isinstance(input, list):
        for message in input:
            # Each message must be a mapping exposing both mandatory keys.
            if not isinstance(message, dict):
                return False
            if "role" not in message or "content" not in message:
                return False
        return True
    return False
40 |
--------------------------------------------------------------------------------
/src/distilabel/utils/files.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from typing import Callable, List, Optional
17 |
18 |
def list_files_in_dir(
    dir_path: Path, key: Optional[Callable] = lambda x: int(x.stem)
) -> List[Path]:
    """List all the files (not sub-directories) in a directory, sorted.

    Args:
        dir_path: Path to the directory.
        key: A function used to sort the files. Defaults to sorting by the integer
            value of the file stem, which is useful when loading files from the
            cache, as those names are numbered. Pass `None` to sort by path.

    Returns:
        The paths of the files found directly inside `dir_path`, sorted with `key`.
    """
    # Drop non-file entries *before* sorting so that `key` is only evaluated on
    # the paths that are actually returned; otherwise a sub-directory with a
    # non-numeric name makes the default `int(x.stem)` key raise `ValueError`.
    files = [entry for entry in dir_path.iterdir() if entry.is_file()]
    return sorted(files, key=key)
33 |
--------------------------------------------------------------------------------
/src/distilabel/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from pathlib import Path
17 | from typing import Final
18 |
19 | from huggingface_hub import constants
20 |
21 | HF_TOKEN_ENV_VAR: Final[str] = "HF_TOKEN"
22 |
23 |
def get_hf_token(cls_name: str, token_arg: str) -> str:
    """Resolve the Hugging Face API token.

    The environment variable named by `HF_TOKEN_ENV_VAR` takes precedence. When
    it is not set, the token is read from the file at
    `huggingface_hub.constants.HF_TOKEN_PATH`.

    Args:
        cls_name: Name of the class/function that requires the token; only used
            in the error message.
        token_arg: Argument name to mention in the error message, normally
            "token" or "api_key".

    Raises:
        ValueError: If the environment variable is not set and the token file
            does not exist.

    Returns:
        The token for the Hugging Face API.
    """
    env_token = os.getenv(HF_TOKEN_ENV_VAR)
    if env_token is not None:
        return env_token

    # Fall back to the token file; without it there is no token source left.
    if not Path(constants.HF_TOKEN_PATH).exists():
        raise ValueError(
            f"To use `{cls_name}` an API key must be provided via `{token_arg}`,"
            f" set the environment variable `{HF_TOKEN_ENV_VAR}` or use the"
            " `huggingface-hub` CLI to login with `huggingface-cli login`."
        )

    with open(constants.HF_TOKEN_PATH) as token_file:
        return token_file.read().strip()
53 |
--------------------------------------------------------------------------------
/src/distilabel/utils/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import base64
16 | import io
17 | from typing import TYPE_CHECKING
18 |
19 | if TYPE_CHECKING:
20 | from PIL import Image
21 |
22 |
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
    """Serialize a PIL image and return the bytes as a base64 string.

    Args:
        image: The image to encode; its `save` method is called with an
            in-memory buffer.
        image_format: Format passed to `image.save`. Defaults to "JPEG".

    Returns:
        The base64 representation of the saved image, decoded as UTF-8 text.
    """
    in_memory_file = io.BytesIO()
    image.save(in_memory_file, format=image_format)
    encoded_bytes = base64.b64encode(in_memory_file.getvalue())
    return str(encoded_bytes, "utf-8")
28 |
--------------------------------------------------------------------------------
/src/distilabel/utils/lists.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List
16 |
17 |
def flatten_responses(responses: List[List[str]]) -> List[str]:
    """Collapse a list of lists of strings by keeping the last item of each list.

    Args:
        responses: The list of lists of strings to collapse.

    Returns:
        A single list of strings containing the last item of each inner list,
        in the same order as `responses`.
    """
    last_items = []
    for group in responses:
        last_items.append(group[-1])
    return last_items
28 |
--------------------------------------------------------------------------------
/src/distilabel/utils/mkdocs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - toc
4 | - navigation
5 | ---
6 | # {{ title }}
7 |
8 | {{ description }}
9 |
10 |
11 |
12 | {% for component in components %}
13 | - {% if component.docstring.icon %}{{ component.docstring.icon }}{% else %}{{ default_icon }}{% endif %}{ .lg .middle } __{{ component.name }}__
14 |
15 | ---
16 |
17 | {{ component.docstring.short_description }}
18 |
19 | [:octicons-arrow-right-24: {{ component.name }}]({{ component.name | lower }}.md){ .bottom }
20 | {% endfor %}
21 |
22 |
23 |
--------------------------------------------------------------------------------
/src/distilabel/utils/mkdocs/templates/components-gallery/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | - toc
5 | ---
6 | # Components Gallery
7 |
8 | ??? info "Category Overview"
9 | | Icon | Category | Description |
10 | |----------------------------|------------|-------------------------------------------------------------------|
11 | | :material-step-forward: | Steps | Steps are used for data manipulation. |
12 | | :material-check-outline: | Tasks | Tasks allow performing data generation, annotation, and more. |
13 | | :material-brain: | LLMs | Explore all available Large Language Models integrated with distilabel. |
14 | | :material-vector-line: | Embeddings | Explore all available Embeddings Models integrated with distilabel. |
15 |
16 |
17 |
18 | - :material-step-forward:{ .lg .middle } __Steps__
19 |
20 | ---
21 |
22 | Explore all the available `Step`s that can be used for data manipulation.
23 |
24 | [:octicons-arrow-right-24: Steps](steps/index.md){ .bottom }
25 |
26 | - :material-check-outline:{ .lg .middle } __Tasks__
27 |
28 | ---
29 |
30 | Explore all the available `Task`s that can be used with an `LLM` to perform data generation, annotation, and more.
31 |
32 | [:octicons-arrow-right-24: Tasks](tasks/index.md){ .bottom }
33 |
34 | - :material-brain:{ .lg .middle } __LLMs__
35 |
36 | ---
37 |
38 | Explore all the available `LLM`s integrated with `distilabel`.
39 |
40 | [:octicons-arrow-right-24: LLMs](llms/index.md){ .bottom }
41 |
42 | - :material-image:{ .lg .middle } __ImageGenerationModels__
43 |
44 | ---
45 |
46 | Explore all the available `ImageGenerationModel`s integrated with `distilabel`.
47 |
48 | [:octicons-arrow-right-24: ImageGenerationModels](image_generation/index.md){ .bottom }
49 |
50 | - :material-vector-line:{ .lg .middle } __Embeddings__
51 |
52 | ---
53 |
54 | Explore all the available `Embeddings` models integrated with `distilabel`.
55 |
56 | [:octicons-arrow-right-24: Embeddings](embeddings/index.md){ .bottom }
57 |
58 |
59 |
--------------------------------------------------------------------------------
/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | ---
5 | # {{ llm.name }}
6 |
7 | {% if llm.docstring.short_description %}
8 | {{ llm.docstring.short_description }}
9 | {% endif %}
10 |
11 | {% if llm.docstring.description %}
12 | {{ llm.docstring.description }}
13 | {% endif %}
14 |
15 | {% if llm.docstring.note %}
16 | ### Note
17 | {{ llm.docstring.note }}
18 | {% endif %}
19 |
20 | {% if llm.docstring.attributes %}
21 | ### Attributes
22 | {% for attribute_name, description in llm.docstring.attributes.items() %}
23 | - **{{ attribute_name }}**: {{ description }}
24 | {% endfor %}
25 | {% endif %}
26 |
27 |
28 | {% if llm.docstring.runtime_parameters %}
29 | ### Runtime Parameters
30 | {% for parameter_name, description in llm.docstring.runtime_parameters.items() %}
31 | - **{{ parameter_name }}**: {{ description }}
32 | {% endfor %}
33 | {% endif %}
34 |
35 | {% if llm.docstring.examples %}
36 | ### Examples
37 |
38 | {% for example_title, code in llm.docstring.examples.items() %}
39 | #### {{ example_title }}
40 | ```python
41 | {{ code | replace("\n", "\n") }}
42 | ```
43 | {% endfor %}
44 | {% endif %}
45 |
46 | {% if llm.docstring.references %}
47 | ### References
48 | {% for reference, url in llm.docstring.references.items() %}
49 | - [{{ reference }}]({{ url }})
50 | {% endfor %}
51 | {% endif %}
52 |
--------------------------------------------------------------------------------
/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | ---
5 | # {{ step.name }}
6 | {% if step.docstring.short_description %}
7 | {{ step.docstring.short_description }}
8 | {% endif %}
9 |
10 | {% if step.docstring.description %}
11 | {{ step.docstring.description }}
12 | {% endif %}
13 |
14 | {% if step.docstring.note %}
15 | ### Note
16 | {{ step.docstring.note }}
17 | {% endif %}
18 |
19 | {% if step.docstring.attributes %}
20 | ### Attributes
21 | {% for attribute_name, description in step.docstring.attributes.items() %}
22 | - **{{ attribute_name }}**: {{ description }}
23 | {% endfor %}
24 | {% endif %}
25 |
26 | {% if step.docstring.runtime_parameters %}
27 | ### Runtime Parameters
28 | {% for parameter_name, description in step.docstring.runtime_parameters.items() %}
29 | - **{{ parameter_name }}**: {{ description }}
30 | {% endfor %}
31 | {% endif %}
32 |
33 | ### Input & Output Columns
34 |
35 | ``` mermaid
36 | {{ mermaid_diagram }}
37 | ```
38 |
39 | {% if step.docstring.input_columns %}
40 | #### Inputs
41 |
42 | {% for column_name, value in step.docstring.input_columns.items() %}
43 | - **{{ column_name }}** ({{ value[0] }}): {{ value[1] }}
44 | {% endfor %}
45 | {% endif %}
46 |
47 | {% if step.docstring.output_columns %}
48 | #### Outputs
49 |
50 | {% for column_name, value in step.docstring.output_columns.items() %}
51 | - **{{ column_name }}** ({{ value[0] }}): {{ value[1] }}
52 | {% endfor %}
53 | {% endif %}
54 |
55 |
56 | {% if step.docstring.examples %}
57 | ### Examples
58 |
59 | {% for example_title, code in step.docstring.examples.items() %}
60 | #### {{ example_title }}
61 | ```python
62 | {{ code | replace("\n", "\n") }}
63 | ```
64 | {% endfor %}
65 | {% endif %}
66 |
67 | {% if step.docstring.references %}
68 | ### References
69 | {% for reference, url in step.docstring.references.items() %}
70 | - [{{ reference }}]({{ url }})
71 | {% endfor %}
72 | {% endif %}
73 |
74 |
--------------------------------------------------------------------------------
/src/distilabel/utils/notebook.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
def in_notebook() -> bool:
    """Report whether the code is running inside a Jupyter Notebook kernel.

    This is useful for better handling the `asyncio` events under `nest_asyncio`,
    as Jupyter Notebook runs a separate event loop.

    Returns:
        `True` when IPython is importable and the active shell's config contains
        `"IPKernelApp"`; `False` otherwise.

    References:
        - https://stackoverflow.com/a/22424821
    """
    try:
        from IPython import get_ipython

        # `get_ipython()` returns `None` outside of an IPython session, so the
        # attribute access below raises `AttributeError` in that case.
        kernel_detected = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
    return True if kernel_detected else False
37 |
--------------------------------------------------------------------------------
/src/distilabel/utils/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 |
def script_executed_in_ray_cluster() -> bool:
    """Detect whether the script is running in a Ray cluster.

    The detection relies on the environment variables `RAY_NODE_TYPE_NAME`,
    `RAY_CLUSTER_NAME` and `RAY_ADDRESS` all being present.

    Returns:
        `True` if running on a Ray cluster, `False` otherwise.
    """
    for variable in ("RAY_NODE_TYPE_NAME", "RAY_CLUSTER_NAME", "RAY_ADDRESS"):
        # A single missing variable is enough to rule the cluster out.
        if variable not in os.environ:
            return False
    return True
29 |
--------------------------------------------------------------------------------
/src/distilabel/utils/requirements.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING, Callable, List, TypeVar, Union
16 |
17 | if TYPE_CHECKING:
18 | from distilabel.steps.base import _Step
19 |
S = TypeVar("S", bound="_Step")


def requirements(requirements: Union[List[str]]) -> Callable[[S], S]:
    """Build a decorator that attaches a list of requirements to a Step.

    Useful for custom steps that need additional packages installed: when the
    pipeline is distributed, the requirements travel with the step.

    Args:
        requirements: List of requirements to be added to the step.

    Returns:
        A decorator that sets the `requirements` attribute on the step it
        receives and returns that same step.

    Example:

    ```python
    @requirements(["my_library>=1.0.1"])
    class CustomStep(Step):
        @property
        def inputs(self) -> List[str]:
            return ["instruction"]

        @property
        def outputs(self) -> List[str]:
            return ["response"]

        def process(self, inputs: StepInput) -> StepOutput:  # type: ignore
            for input in inputs:
                input["response"] = "unit test"
            yield inputs
    ```
    """

    def _attach_requirements(step: S) -> S:
        # The step is mutated in place and handed back, so the decorated name
        # still refers to the original object.
        setattr(step, "requirements", requirements)
        return step

    return _attach_requirements
60 |
--------------------------------------------------------------------------------
/src/distilabel/utils/template.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 | from distilabel.errors import DistilabelUserError
18 |
19 |
def check_column_in_template(
    column: str, template: str, page: str = "components-gallery/tasks/textgeneration/"
) -> None:
    """Ensure that `column` is referenced by the Jinja2 `template`.

    The column counts as present when it appears as a whole word inside a
    `{% ... %}` tag, or on its own (optionally surrounded by whitespace) inside
    a `{{ ... }}` expression.

    Args:
        column: The column name to look for in the template.
        template: The template of the Task to be checked, the input from the user.
        page: The page to redirect the user to for help. Defaults to
            "components-gallery/tasks/textgeneration/".

    Raises:
        DistilabelUserError: If the column is not present in the template.
    """
    escaped_column = re.escape(column)
    # `{% ... column ... %}` statement tags.
    statement_tag = r"{%.*?\b" + escaped_column + r"\b.*?%}"
    # `{{ column }}` expressions.
    expression_tag = r"{{\s*" + escaped_column + r"\s*}}"
    pattern = "(?:" + statement_tag + "|" + expression_tag + ")"

    if re.search(pattern, template) is not None:
        return

    raise DistilabelUserError(
        (
            f"You required column name '{column}', but is not present in the template, "
            "ensure the 'columns' match with the 'template' to avoid errors."
        ),
        page=page,
    )
48 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | from typing import TYPE_CHECKING, List
17 |
18 | import pytest
19 |
20 | if TYPE_CHECKING:
21 | from _pytest.config import Config
22 | from _pytest.nodes import Item
23 |
24 |
def pytest_configure(config: "Config") -> None:
    """Register the custom `skip_python_versions` marker on the pytest config."""
    marker_description = (
        "skip_python_versions(versions): "
        "mark test to be skipped on specified Python versions"
    )
    config.addinivalue_line("markers", marker_description)
30 |
31 |
def pytest_collection_modifyitems(config: "Config", items: List["Item"]) -> None:
    """Skip collected tests marked as unsupported on the running Python version.

    Every item carrying a `skip_python_versions` marker whose first argument
    contains the current "<major>.<minor>" version gets a `skip` marker added.

    Args:
        config: The pytest configuration object (unused).
        items: The collected test items; they are modified in place.
    """
    version = sys.version_info
    running_on = f"{version.major}.{version.minor}"

    for test_item in items:
        marker = test_item.get_closest_marker("skip_python_versions")
        if not marker:
            continue
        if running_on not in marker.args[0]:
            continue
        test_item.add_marker(
            pytest.mark.skip(reason=f"Test not supported on Python {running_on}")
        )
41 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/integration/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import tempfile
17 | from typing import Generator
18 |
19 | import pytest
20 |
21 |
@pytest.fixture(autouse=True)
def temp_cache_dir() -> Generator[None, None, None]:
    """Point `DISTILABEL_CACHE_DIR` to a temporary directory for each test.

    The previous state of the environment variable is restored on teardown, so
    it never keeps pointing at the temporary directory once that directory has
    been removed.
    """
    previous_value = os.environ.get("DISTILABEL_CACHE_DIR")
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DISTILABEL_CACHE_DIR"] = tmpdirname
        try:
            yield
        finally:
            # Undo the override before the directory is deleted on `with` exit.
            if previous_value is None:
                os.environ.pop("DISTILABEL_CACHE_DIR", None)
            else:
                os.environ["DISTILABEL_CACHE_DIR"] = previous_value
28 |
--------------------------------------------------------------------------------
/tests/integration/test_branching_missaligmnent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from distilabel.pipeline import Pipeline
18 | from distilabel.steps import GroupColumns, LoadDataFromDicts, StepInput, step
19 |
20 | if TYPE_CHECKING:
21 | from distilabel.steps import StepOutput
22 |
23 |
@step(inputs=["instruction"], outputs=["response"])
def FailAlways(_: StepInput) -> "StepOutput":
    """Step that ignores its input batch and unconditionally raises."""
    raise Exception("This step always fails")
27 |
28 |
@step(inputs=["instruction"], outputs=["response"])
def SucceedAlways(inputs: StepInput) -> "StepOutput":
    """Step that sets a fixed `response` on every row and yields the batch."""
    for row in inputs:
        row["response"] = "This step always succeeds"
    yield inputs
34 |
35 |
def test_branching_missalignment_because_step_fails_processing_batch() -> None:
    """Check the grouped output when one of two parallel branches always fails.

    The failing branch is expected to contribute `None` for every row while the
    succeeding branch contributes its response.
    """
    num_rows = 20

    with Pipeline(name="") as pipeline:
        load_data = LoadDataFromDicts(
            data=[{"instruction": i} for i in range(num_rows)]
        )

        fail = FailAlways()
        succeed = SucceedAlways()
        combine = GroupColumns(columns=["response"])

        load_data >> [fail, succeed] >> combine

    distiset = pipeline.run(use_cache=False)

    expected = [[None, "This step always succeeds"] for _ in range(num_rows)]
    grouped = distiset["default"]["train"]["grouped_response"]
    assert grouped == expected
52 |
--------------------------------------------------------------------------------
/tests/integration/test_cache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING, List
16 |
17 | import numpy as np
18 | import pytest
19 |
20 | from distilabel.pipeline import Pipeline
21 | from distilabel.steps import GeneratorStep, StepInput, step
22 |
23 | if TYPE_CHECKING:
24 | from distilabel.steps import GeneratorStepOutput, StepOutput
25 |
26 |
class NumpyBigArrayGenerator(GeneratorStep):
    """Generator step yielding batches of rows that each hold a random array."""

    # Number of batches produced by `process`.
    num_batches: int

    @property
    def outputs(self) -> List[str]:
        """Names of the columns produced by this step."""
        return ["array"]

    def process(self, offset: int = 0) -> "GeneratorStepOutput":
        """Yield `num_batches` batches of `batch_size` rows each.

        Every row holds a `np.random.randn(256)` array under the `"array"` key.
        The second element of each yielded tuple is `True` only for the last
        batch. `offset` is not used.
        """
        for batch_index in range(self.num_batches):
            rows = [{"array": np.random.randn(256)} for _ in range(self.batch_size)]  # type: ignore
            is_last_batch = batch_index == self.num_batches - 1
            yield rows, is_last_batch  # type: ignore
40 |
41 |
@step(step_type="global")
def ReceiveArrays(inputs: StepInput) -> "StepOutput":
    """Global step that yields the inputs it receives back unchanged."""
    yield inputs
45 |
46 |
@pytest.mark.benchmark
def test_cache_time() -> None:
    """Run a small generator -> global step pipeline with the cache disabled."""
    with Pipeline(name="dummy") as pipeline:
        array_source = NumpyBigArrayGenerator(num_batches=2, batch_size=100)
        array_sink = ReceiveArrays()

        array_source >> array_sink

    pipeline.run(use_cache=False)
57 |
--------------------------------------------------------------------------------
/tests/integration/test_deduplication.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.pipeline import Pipeline
16 | from distilabel.steps import LoadDataFromDicts, MinHashDedup
17 |
18 |
def test_minhash_deduplication() -> None:
    """Run `MinHashDedup` over 1000 repeated rows and expect 4 rows to be kept.

    The dataset is five short documents repeated until it holds `total_rows`
    rows, loaded and deduplicated in batches of `rows_per_batch`.
    """
    total_rows = 1000
    rows_per_batch = 500
    base_documents = [
        {"text": "This is a test document."},
        {"text": "This document is a test."},
        {"text": "Test document for duplication."},
        {"text": "Document for duplication test."},
        {"text": "This is another unique document."},
    ]

    with Pipeline() as pipeline:
        loader = LoadDataFromDicts(
            data=base_documents * (total_rows // 5),
            batch_size=rows_per_batch,
        )
        deduplicator = MinHashDedup(
            tokenizer="ngrams",
            n=2,
            threshold=0.9,
            storage="disk",
            input_batch_size=rows_per_batch,
        )
        loader >> deduplicator

    distiset = pipeline.run(use_cache=False)
    train_split = distiset["default"]["train"]

    # Keep only the rows the dedup step flagged as survivors.
    kept_rows = train_split.filter(
        lambda row: row["keep_row_after_minhash_filtering"]
    )
    assert len(kept_rows) == 4
47 |
48 |
# Script entry point: lets this integration test be launched directly with
# `python`, without going through pytest.
if __name__ == "__main__":
    test_minhash_deduplication()
51 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/cli/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/cli/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/cli/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
# Absolute path of the directory that contains this module.
current_dir = os.path.split(os.path.abspath(__file__))[0]

# Path of the `test_pipeline.yaml` file expected next to this module.
TEST_PIPELINE_PATH = os.path.join(current_dir, "test_pipeline.yaml")
20 |
--------------------------------------------------------------------------------
/tests/unit/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
# Opt-in switch for slow tests, read once at import time from the
# `DISTILABEL_RUN_SLOW_TESTS` environment variable.
#
# `os.getenv` returns the raw string when the variable is set, so a value such
# as "0" or "false" would be truthy if used directly. The value is therefore
# normalised to a real bool: unset, empty, "0", "false", "no" and "off"
# (case-insensitive, surrounding whitespace ignored) disable slow tests; any
# other value enables them.
DISTILABEL_RUN_SLOW_TESTS = os.getenv(
    "DISTILABEL_RUN_SLOW_TESTS", ""
).strip().lower() not in {"", "0", "false", "no", "off"}
18 |
--------------------------------------------------------------------------------
/tests/unit/mixins/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/models/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/models/embeddings/test_sentence_transformers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.models.embeddings.sentence_transformers import (
16 | SentenceTransformerEmbeddings,
17 | )
18 |
19 |
class TestSentenceTransformersEmbeddings:
    """Unit tests for `SentenceTransformerEmbeddings`."""

    def test_model_name(self) -> None:
        """`model_name` reports the identifier passed as `model`."""
        model_id = "sentence-transformers/all-MiniLM-L6-v2"
        embeddings = SentenceTransformerEmbeddings(model=model_id)

        assert embeddings.model_name == "sentence-transformers/all-MiniLM-L6-v2"

    def test_encode(self) -> None:
        """Each encoded sentence is expected to have 384 elements."""
        embeddings = SentenceTransformerEmbeddings(
            model="sentence-transformers/all-MiniLM-L6-v2"
        )
        embeddings.load()

        sentences = [
            "Hello, how are you?",
            "What a nice day!",
            "I hear that llamas are very popular now.",
        ]
        vectors = embeddings.encode(inputs=sentences)

        for vector in vectors:
            assert len(vector) == 384
45 |
--------------------------------------------------------------------------------
/tests/unit/models/embeddings/test_vllm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from unittest.mock import MagicMock, Mock
16 |
17 | from distilabel.models.embeddings.vllm import vLLMEmbeddings
18 |
19 |
20 | # @patch("vllm.entrypoints.LLM")
class TestSentenceTransformersEmbeddings:
    """Unit tests for `vLLMEmbeddings`, run against a mocked model.

    NOTE(review): the class name mentions SentenceTransformers, but every test
    here targets `vLLMEmbeddings` -- it looks like a copy-paste leftover;
    consider renaming if nothing depends on the name.
    """

    # Model identifier shared by all tests in this class.
    model_name = "group/model-name"

    def test_model_name(self) -> None:
        """`model_name` reports the identifier passed as `model`."""
        embeddings = vLLMEmbeddings(model=self.model_name)

        assert embeddings.model_name == self.model_name

    def test_encode(self) -> None:
        """`encode` returns one 10-element embedding per input sentence."""
        embeddings = vLLMEmbeddings(model=self.model_name)

        # `embeddings.load()` would normally run here; the underlying model is
        # replaced by a mock instead, so nothing is actually loaded.
        embeddings._model = MagicMock()

        mocked_response = Mock(outputs=Mock(embedding=[0.1] * 10))

        def fake_encode(prompts):
            # One (shared) mocked response per prompt received.
            return [mocked_response] * len(prompts)

        embeddings._model.encode = Mock(side_effect=fake_encode)

        sentences = [
            "Hello, how are you?",
            "What a nice day!",
            "I hear that llamas are very popular now.",
        ]
        vectors = embeddings.encode(inputs=sentences)

        for vector in vectors:
            assert len(vector) == 10
51 |
--------------------------------------------------------------------------------
/tests/unit/models/image_generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/models/image_generation/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/models/llms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/models/llms/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/models/llms/test_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytest
16 |
17 | from distilabel.errors import DistilabelNotImplementedError
18 | from tests.unit.conftest import DummyLLM
19 |
20 |
class TestLLM:
    """Unit tests for behaviour exercised through `DummyLLM`."""

    def test_offline_batch_generate_raise_distilabel_not_implemented_error(
        self,
    ) -> None:
        """Calling `offline_batch_generate` on `DummyLLM` must raise
        `DistilabelNotImplementedError`."""
        # Build the LLM outside the `raises` block so that only the method call
        # itself is allowed to raise.
        dummy_llm = DummyLLM()

        with pytest.raises(DistilabelNotImplementedError):
            dummy_llm.offline_batch_generate()
29 |
--------------------------------------------------------------------------------
/tests/unit/models/llms/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any
16 |
17 | from pydantic import BaseModel, PrivateAttr
18 |
19 |
# Minimal pydantic model used by the LLM tests: two public fields plus a
# private `_raw_response` slot. Documented with comments rather than a class
# docstring so the model's generated schema is left untouched.
class DummyUserDetail(BaseModel):
    name: str
    age: int
    # Private attribute; populated in `__init__` below.
    _raw_response: Any = PrivateAttr()

    def __init__(self, **data):
        # Let pydantic validate and assign the declared fields first.
        super().__init__(**data)
        # Then store the `_raw_response` keyword if it was supplied, else None.
        if "_raw_response" in data:
            raw_response = data["_raw_response"]
        else:
            raw_response = None
        self._raw_response = raw_response
28 |
--------------------------------------------------------------------------------
/tests/unit/models/mixins/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/argilla/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/clustering/test_dbscan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from distilabel.steps.clustering.dbscan import DBSCAN
17 |
18 |
class TestDBSCAN:
    """Unit tests for the `DBSCAN` clustering step."""

    def test_process(self) -> None:
        """With `eps=0.5` and `min_samples=5`, all nine points get label -1."""
        points = [
            [0.1, -0.4],
            [-0.3, 0.9],
            [0.6, 0.2],
            [-0.2, -0.6],
            [0.9, 0.1],
            [0.4, -0.7],
            [-0.5, 0.3],
            [0.7, 0.5],
            [-0.1, -0.9],
        ]
        rows = [{"projection": point} for point in points]

        step = DBSCAN(n_jobs=1, eps=0.5, min_samples=5)
        step.load()

        # `process` is a generator; only its first yielded batch is checked.
        results = next(step.process(inputs=rows))

        for row in results:
            assert row["cluster_label"] == -1
40 |
--------------------------------------------------------------------------------
/tests/unit/steps/clustering/test_umap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | from distilabel.steps.clustering.umap import UMAP
18 |
19 |
class TestUMAP:
    """Unit tests for the `UMAP` projection step."""

    def test_process(self) -> None:
        """Every output row carries an `np.ndarray` projection of `n_components` items."""
        n_components = 2
        vectors = [
            [0.1, -0.4, 0.7, 0.2],
            [-0.3, 0.9, 0.1, -0.5],
            [0.6, 0.2, -0.1, 0.8],
            [-0.2, -0.6, 0.4, 0.3],
            [0.9, 0.1, -0.3, -0.2],
            [0.4, -0.7, 0.6, 0.1],
            [-0.5, 0.3, -0.2, 0.9],
            [0.7, 0.5, -0.4, -0.1],
            [-0.1, -0.9, 0.8, 0.6],
        ]
        rows = [{"embedding": vector} for vector in vectors]

        step = UMAP(n_jobs=1, n_components=n_components)
        step.load()

        # `process` is a generator; only its first yielded batch is checked.
        results = next(step.process(inputs=rows))

        for row in results:
            assert isinstance(row["projection"], np.ndarray)
        for row in results:
            assert len(row["projection"]) == n_components
43 |
--------------------------------------------------------------------------------
/tests/unit/steps/columns/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/columns/test_combine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.constants import DISTILABEL_METADATA_KEY
16 | from distilabel.steps.columns.combine import CombineOutputs
17 |
18 |
class TestCombineOutputs:
    """Unit tests for the `CombineOutputs` step."""

    def test_process(self) -> None:
        """Two single-row batches are combined into one row.

        The expected row holds the columns of both inputs; under the metadata
        key, `model` becomes a list of both values while `a` and `b` are kept
        as they were.
        """
        first_batch = [
            {
                "a": 1,
                "b": 2,
                DISTILABEL_METADATA_KEY: {"model": "model-1", "a": 1},
            }
        ]
        second_batch = [
            {
                "c": 3,
                "d": 4,
                DISTILABEL_METADATA_KEY: {"model": "model-2", "b": 1},
            }
        ]
        expected_row = {
            "a": 1,
            "b": 2,
            "c": 3,
            "d": 4,
            DISTILABEL_METADATA_KEY: {
                "model": ["model-1", "model-2"],
                "a": 1,
                "b": 1,
            },
        }

        step = CombineOutputs()
        output = next(step.process(first_batch, second_batch))

        assert output == [expected_row]
55 |
--------------------------------------------------------------------------------
/tests/unit/steps/columns/test_keep.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.pipeline.local import Pipeline
16 | from distilabel.steps.columns.keep import KeepColumns
17 |
18 |
class TestKeepColumns:
    """Unit tests for the `KeepColumns` step."""

    def test_init(self) -> None:
        """`inputs` and `outputs` both mirror the configured `columns`."""
        task = KeepColumns(
            name="keep-columns",
            columns=["a", "b"],
            pipeline=Pipeline(name="unit-test-pipeline"),
        )
        assert task.inputs == ["a", "b"]
        assert task.outputs == ["a", "b"]

    def test_process(self) -> None:
        """Columns not listed in `columns` are dropped from each row."""
        step = KeepColumns(
            name="keep-columns",
            columns=["a", "b"],
            pipeline=Pipeline(name="unit-test-pipeline"),
        )
        output = next(step.process([{"a": 1, "b": 2, "c": 3, "d": 4}]))
        assert output == [{"a": 1, "b": 2}]

    def test_process_preserve_order(self) -> None:
        """Output rows list their keys in the order given by `columns`."""
        step = KeepColumns(
            name="keep-columns",
            columns=["b", "a"],
            pipeline=Pipeline(name="unit-test-pipeline"),
        )
        output = next(step.process([{"a": 1, "b": 2, "c": 3, "d": 4}]))
        assert output == [{"b": 2, "a": 1}]
        # Dict equality ignores key order, so the assertion above alone cannot
        # detect an ordering regression; check the key sequence explicitly.
        assert [list(row) for row in output] == [["b", "a"]]
46 |
--------------------------------------------------------------------------------
/tests/unit/steps/columns/test_merge.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional
16 |
17 | import pytest
18 |
19 | from distilabel.steps.columns.merge import MergeColumns
20 |
21 |
class TestMergeColumns:
    """Unit tests for the `MergeColumns` step."""

    @pytest.mark.parametrize(
        ("output_column", "expected"),
        [
            (None, "merged_column"),
            ("queries", "queries"),
        ],
    )
    def test_init(self, output_column: Optional[str], expected: str) -> None:
        """`outputs` is `merged_column` when `output_column` is None, else the given name."""
        source_columns = ["query", "queries"]
        merger = MergeColumns(columns=source_columns, output_column=output_column)

        assert merger.inputs == ["query", "queries"]
        assert merger.outputs == [expected]

    @pytest.mark.parametrize(
        "columns",
        [
            [{"query": 1, "queries": 2}],
            [{"query": 1, "queries": [2]}],
            [{"query": [1], "queries": [2]}],
        ],
    )
    def test_process(self, columns: List[Dict[str, Any]]) -> None:
        """Scalar and list column values all merge into the flat list `[1, 2]`."""
        merger = MergeColumns(columns=["query", "queries"])

        merged_rows: List[Dict[str, Any]] = next(merger.process(columns))

        assert merged_rows == [{"merged_column": [1, 2]}]
50 |
--------------------------------------------------------------------------------
/tests/unit/steps/columns/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.constants import DISTILABEL_METADATA_KEY
16 | from distilabel.steps.columns.utils import merge_distilabel_metadata
17 |
18 |
def test_merge_distilabel_metadata() -> None:
    """Dict metadata from two rows is merged key by key into lists of values."""
    first_row = {DISTILABEL_METADATA_KEY: {"a": 1, "b": 1}}
    second_row = {DISTILABEL_METADATA_KEY: {"a": 2, "b": 2}}

    merged = merge_distilabel_metadata(first_row, second_row)

    assert merged == {"a": [1, 2], "b": [1, 2]}
26 |
27 |
def test_merge_distilabel_metadata_list() -> None:
    """List-valued metadata from every row is concatenated in row order."""
    first_entries = [{"a": value, "b": value} for value in (1.0, 1.1, 1.2)]
    second_entries = [{"a": value, "b": value} for value in (2.0, 2.1, 2.2)]

    merged = merge_distilabel_metadata(
        {DISTILABEL_METADATA_KEY: first_entries},
        {DISTILABEL_METADATA_KEY: second_entries},
    )

    # Built from fresh dicts (not the input objects) so the comparison does not
    # depend on whether the inputs were touched by the call.
    expected = [
        {"a": value, "b": value} for value in (1.0, 1.1, 1.2, 2.0, 2.1, 2.2)
    ]
    assert merged == expected
54 |
--------------------------------------------------------------------------------
/tests/unit/steps/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/embeddings/test_embedding_generation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.models.embeddings.sentence_transformers import (
16 | SentenceTransformerEmbeddings,
17 | )
18 | from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
19 |
20 |
class TestEmbeddingGeneration:
    """Tests for the `EmbeddingGeneration` step."""

    def test_process(self) -> None:
        """Each input row keeps its text and gains an embedding and model name."""
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
        texts = [
            "Hello, how are you?",
            "What a nice day!",
            "I hear that llamas are very popular now.",
        ]

        step = EmbeddingGeneration(
            embeddings=SentenceTransformerEmbeddings(model=model_name)
        )
        step.load()
        results = next(step.process(inputs=[{"text": text} for text in texts]))
        step.unload()

        # `zip` stops at the shortest iterable, so without this check an empty
        # or truncated batch would let the loop below pass vacuously.
        assert len(results) == len(texts)

        for result, text in zip(results, texts):
            assert len(result["embedding"]) == 384
            assert result["text"] == text
            assert result["model_name"] == model_name
54 |
--------------------------------------------------------------------------------
/tests/unit/steps/filtering/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/formatting/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/formatting/test_conversation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.pipeline.local import Pipeline
16 | from distilabel.steps.formatting.conversation import ConversationTemplate
17 |
18 |
class TestConversationTemplate:
    """Tests for the `ConversationTemplate` formatting step."""

    def test_process(self) -> None:
        """The input columns are kept and a user/assistant conversation is added."""
        step = ConversationTemplate(
            name="conversation_template",
            pipeline=Pipeline(name="unit-test"),
        )
        expected_row = {
            "instruction": "Hello",
            "response": "Hi",
            "conversation": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ],
        }

        batch = next(step.process([{"instruction": "Hello", "response": "Hi"}]))

        assert batch == [expected_row]
40 |
--------------------------------------------------------------------------------
/tests/unit/steps/generators/test_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 |
16 | from distilabel.pipeline.local import Pipeline
17 | from distilabel.steps.generators.data import LoadDataFromDicts
18 |
19 |
class TestLoadDataFromDicts:
    """Tests for the `LoadDataFromDicts` generator step."""

    # Ten references to one row dict, shared by the tests below.
    data = [{"instruction": "test"}] * 10

    def test_init(self) -> None:
        """Constructor arguments are readable back from the step unchanged."""
        rows: list[dict[str, str]] = self.data
        step = LoadDataFromDicts(
            name="task",
            pipeline=Pipeline(name="unit-test-pipeline"),
            data=rows,
            batch_size=10,
        )

        assert step.data == rows
        assert step.batch_size == 10

    def test_process(self) -> None:
        """With `batch_size=1` each row is its own batch; only the last is flagged."""
        rows: list[dict[str, str]] = self.data
        batch_size = 1
        step = LoadDataFromDicts(
            name="task",
            pipeline=Pipeline(name="unit-test-pipeline"),
            data=rows,
            batch_size=batch_size,
        )

        batches = step.process()
        *leading_rows, final_row = self.data
        for row in leading_rows:
            # Every batch but the last carries `False` as its second element.
            assert next(batches) == ([row], False)
        assert next(batches) == ([final_row], True)

        # Nothing may be yielded after the batch flagged as last.
        with pytest.raises(StopIteration):
            next(batches)
46 |
--------------------------------------------------------------------------------
/tests/unit/steps/generators/test_data_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List
16 |
17 | import pytest
18 |
19 | from distilabel.steps.generators.data_sampler import DataSampler
20 |
21 |
@pytest.mark.parametrize(
    "samples, size, batch_size, expected",
    [
        (10, 2, 4, [4, 4, 2]),
        (7, 5, 6, [6, 1]),
        (20, 5, 20, [20]),
        (20, 50, 8, [8, 8, 4]),
    ],
)
def test_generator_and_sampler(
    samples: int, size: int, batch_size: int, expected: List[int]
):
    """`DataSampler` yields batches whose lengths match `expected`, in order."""
    pool = [{"sample": f"sample {i}"} for i in range(30)]
    sampler = DataSampler(
        data=pool, size=size, samples=samples, batch_size=batch_size
    )
    sampler.load()

    # Each yielded item is indexable; its first element is the batch of rows.
    batch_lengths = [len(item[0]) for item in sampler.process()]

    assert len(batch_lengths) == len(expected)
    assert batch_lengths[0] == batch_size
    assert batch_lengths == expected
46 |
--------------------------------------------------------------------------------
/tests/unit/steps/generators/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict, List, Union
16 |
17 | import pandas as pd
18 | import pytest
19 | from datasets import Dataset
20 |
21 | from distilabel.pipeline.local import Pipeline
22 | from distilabel.steps.generators.utils import make_generator_step
23 |
# Shared input rows for the tests in this module. The multiplication repeats
# the same dict object ten times rather than creating ten separate copies.
data = [{"instruction": "Tell me a joke."}] * 10
25 |
26 |
@pytest.mark.parametrize("dataset", (data, Dataset.from_list(data), pd.DataFrame(data)))
def test_make_generator_step(
    dataset: Union[Dataset, pd.DataFrame, List[Dict[str, str]]],
) -> None:
    """`make_generator_step` accepts a list of dicts, a `Dataset` or a DataFrame."""
    batch_size = 5
    step = make_generator_step(
        dataset,
        batch_size=batch_size,
        output_mappings={"instruction": "other"},
    )
    step.load()

    first_item = next(step.process())
    assert len(first_item[0]) == batch_size

    # The attribute holding the rows depends on the input type.
    if not isinstance(dataset, (pd.DataFrame, Dataset)):
        assert isinstance(step.data, list)
    else:
        assert isinstance(step._dataset, Dataset)

    assert step.output_mappings == {"instruction": "other"}
44 |
45 |
def test_make_generator_step_with_pipeline() -> None:
    """A pipeline passed explicitly is the one exposed by the created step."""
    explicit_pipeline = Pipeline()
    step = make_generator_step(data, pipeline=explicit_pipeline)
    assert step.pipeline == explicit_pipeline
50 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/apigen/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.

    Args:
        initial_velocity: The initial velocity of the object.
        acceleration: The acceleration of the object.
        time: The time elapsed.

    Returns:
        The final velocity
    """
    # NOTE(review): the declared return type is `int`, yet float arguments
    # produce a float. Signature and docstring are kept verbatim because this
    # looks like a fixture that may be introspected -- confirm before changing.
    velocity_change = acceleration * time
    return initial_velocity + velocity_change
28 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Optional, Tuple
16 |
17 |
def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
    """Gets the value at the specified index in the matrix.

    Args:
        matrix: A list of lists representing the matrix.
        indices: A tuple containing the row and column indices.
    """
    row_index, col_index = indices

    # Negative indices are rejected rather than wrapping around, so any
    # out-of-range row or column yields None.
    if row_index < 0 or row_index >= len(matrix):
        return None

    row = matrix[row_index]
    if col_index < 0 or col_index >= len(row):
        return None

    return row[col_index]
34 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/apigen/_sample_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Optional, Tuple
16 |
17 |
def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.

    Args:
        initial_velocity: The initial velocity of the object.
        acceleration: The acceleration of the object.
        time: The time elapsed.

    Returns:
        The final velocity
    """
    # NOTE(review): the declared return type is `int`, yet float arguments
    # produce a float. Signature and docstring are kept verbatim because this
    # looks like a fixture that may be introspected -- confirm before changing.
    velocity_change = acceleration * time
    return initial_velocity + velocity_change
30 |
31 |
def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
    """Gets the value at the specified index in the matrix.

    Args:
        matrix: A list of lists representing the matrix.
        indices: A tuple containing the row and column indices.
    """
    row_index, col_index = indices

    # Negative indices are rejected rather than wrapping around, so any
    # out-of-range row or column yields None.
    if row_index < 0 or row_index >= len(matrix):
        return None

    row = matrix[row_index]
    if col_index < 0 or col_index >= len(row):
        return None

    return row[col_index]
48 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/evol_instruct/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.models.llms.base import LLM
16 | from distilabel.pipeline.local import Pipeline
17 | from distilabel.steps.tasks.evol_instruct.evol_complexity.base import (
18 | EvolComplexity,
19 | )
20 | from distilabel.steps.tasks.evol_instruct.evol_complexity.utils import (
21 | MUTATION_TEMPLATES,
22 | )
23 |
24 |
class TestEvolComplexity:
    """Tests for the `EvolComplexity` task."""

    def test_mutation_templates(self, dummy_llm: LLM) -> None:
        """The task keeps its arguments and uses the complexity-only templates."""
        task = EvolComplexity(
            name="task",
            llm=dummy_llm,
            num_evolutions=2,
            pipeline=Pipeline(name="unit-test-pipeline"),
        )

        assert task.name == "task"
        assert task.llm is dummy_llm
        assert task.num_evolutions == 2

        templates = task.mutation_templates
        assert templates == MUTATION_TEMPLATES
        # The complexity variant must not include a breadth mutation.
        assert "BREADTH" not in templates
36 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.models.llms.base import LLM
16 | from distilabel.pipeline.local import Pipeline
17 | from distilabel.steps.tasks.evol_instruct.evol_complexity.generator import (
18 | EvolComplexityGenerator,
19 | )
20 | from distilabel.steps.tasks.evol_instruct.evol_complexity.utils import (
21 | GENERATION_MUTATION_TEMPLATES,
22 | )
23 |
24 |
class TestEvolComplexityGenerator:
    """Tests for the `EvolComplexityGenerator` task."""

    def test_mutation_templates(self, dummy_llm: LLM) -> None:
        """The generator keeps its arguments and uses the generation templates."""
        task = EvolComplexityGenerator(
            name="task",
            llm=dummy_llm,
            num_instructions=2,
            pipeline=Pipeline(name="unit-test-pipeline"),
        )

        assert task.name == "task"
        assert task.llm is dummy_llm
        assert task.num_instructions == 2

        templates = task.mutation_templates
        assert templates == GENERATION_MUTATION_TEMPLATES
        # The complexity variant must not include a breadth mutation.
        assert "BREADTH" not in templates
36 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/evol_quality/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/magpie/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/math_shepherd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/structured_outputs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/steps/tasks/test_generate_embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Generator
16 |
17 | import pytest
18 |
19 | from distilabel.models.llms.huggingface.transformers import TransformersLLM
20 | from distilabel.pipeline.local import Pipeline
21 | from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings
22 |
23 |
@pytest.fixture(scope="module")
def transformers_llm() -> Generator[TransformersLLM, None, None]:
    """Yield a loaded `TransformersLLM`, created once per test module.

    `cuda_devices=[]` is passed so that no CUDA device is requested.
    """
    tiny_llm = TransformersLLM(
        model="distilabel-internal-testing/tiny-random-mistral", cuda_devices=[]
    )
    tiny_llm.load()

    yield tiny_llm
33 |
34 |
class TestGenerateEmbeddings:
    """Tests for the `GenerateEmbeddings` task."""

    def test_process(self, transformers_llm: TransformersLLM) -> None:
        """A single text row comes back with an `embedding` of length 128."""
        task = GenerateEmbeddings(
            name="task",
            llm=transformers_llm,
            pipeline=Pipeline(name="unit-test-pipeline"),
        )

        batch = next(task.process([{"text": "Hello, how are you?"}]))
        first_row = batch[0]

        assert "embedding" in first_row
        assert len(first_row["embedding"]) == 128
46 |
--------------------------------------------------------------------------------
/tests/unit/steps/test_truncate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | import pytest
18 |
19 | from distilabel.steps.truncate import TruncateTextColumn
20 |
21 |
@pytest.mark.parametrize(
    "max_length, text, tokenizer, expected",
    [
        (
            10,
            "This is a sample text that is longer than 10 characters",
            None,
            "This is a ",
        ),
        (
            4,
            "This is a sample text that is longer than 10 characters",
            "teknium/OpenHermes-2.5-Mistral-7B",
            "This is a sample",
        ),
    ],
)
def test_truncate_row(
    max_length: int, text: str, tokenizer: Optional[str], expected: str
) -> None:
    """`TruncateTextColumn` shortens the `text` column, with or without a tokenizer."""
    step = TruncateTextColumn(
        column="text", max_length=max_length, tokenizer=tokenizer
    )
    step.load()

    truncated_batch = next(step.process([{"text": text}]))

    assert truncated_batch == [{"text": expected}]
48 |
--------------------------------------------------------------------------------
/tests/unit/test_errors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from distilabel.errors import DistilabelUserError
16 |
17 |
def test_distilabel_user_error() -> None:
    """`str()` gives the bare message, plus a docs URL when `page` is supplied."""
    plain_error = DistilabelUserError("This is an error message.")
    assert str(plain_error) == "This is an error message."

    error_with_page = DistilabelUserError(
        "This is an error message.", page="sections/getting_started/faq/"
    )
    expected_text = "This is an error message.\n\nFor further information visit 'https://distilabel.argilla.io/latest/sections/getting_started/faq/'"
    assert str(error_with_page) == expected_text
28 |
--------------------------------------------------------------------------------
/tests/unit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_files.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tempfile
16 | from pathlib import Path
17 |
18 | from distilabel.utils.files import list_files_in_dir
19 |
20 |
21 | def test_list_files_in_dir() -> None:
22 | with tempfile.TemporaryDirectory() as temp_dir:
23 | temp_dir = Path(temp_dir)
24 |
25 | created_files = []
26 | for i in range(20):
27 | file_path = temp_dir / f"{i}.txt"
28 | created_files.append(file_path)
29 | with open(file_path, "w") as f:
30 | f.write("hello")
31 |
32 | assert list_files_in_dir(Path(temp_dir)) == created_files
33 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_lists.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List
16 |
17 | import pytest
18 |
19 | from distilabel.utils.lists import flatten_responses
20 |
21 |
22 | @pytest.mark.parametrize(
23 | "input, expected",
24 | [
25 | ([["A"], ["B"]], ["A", "B"]),
26 | ([["A", "B"], ["C", "D"]], ["B", "D"]),
27 | ],
28 | )
29 | def test_flatten_responses(input: List[List[str]], expected: List[str]) -> None:
30 | assert flatten_responses(input) == expected
31 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from unittest import mock
17 |
18 | from distilabel.utils.ray import script_executed_in_ray_cluster
19 |
20 |
21 | def test_script_executed_on_ray_cluster() -> None:
22 | assert not script_executed_in_ray_cluster()
23 |
24 | with mock.patch.dict(
25 | os.environ,
26 | {
27 | "RAY_NODE_TYPE_NAME": "headgroup",
28 | "RAY_CLUSTER_NAME": "disticluster",
29 | "RAY_ADDRESS": "127.0.0.1:6379",
30 | },
31 | ):
32 | assert script_executed_in_ray_cluster()
33 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_serialization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pydantic import BaseModel
16 |
17 | from distilabel.utils.serialization import _extra_serializable_fields, _Serializable
18 |
19 |
20 | def test_extra_serializable_fields() -> None:
21 | class DummyAttribute(BaseModel, _Serializable):
22 | pass
23 |
24 | class Dummy(BaseModel, _Serializable):
25 | attr: DummyAttribute
26 |
27 | dummy = Dummy(attr=DummyAttribute())
28 |
29 | assert _extra_serializable_fields(dummy) == [
30 | {
31 | "attr": {
32 | "type_info": {
33 | "module": "tests.unit.utils.test_serialization",
34 | "name": "DummyAttribute",
35 | }
36 | }
37 | }
38 | ]
39 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_typing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present, Argilla, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 |
17 | from typing_extensions import Annotated
18 |
19 | from distilabel.utils.typing_ import is_parameter_annotated_with
20 |
21 |
22 | def test_is_parameter_annotated_with() -> None:
23 | def dummy_function(arg: Annotated[int, "unit-test"], arg2: int) -> None:
24 | pass
25 |
26 | signature = inspect.signature(dummy_function)
27 | arg_parameter = signature.parameters["arg"]
28 | arg2_parameter = signature.parameters["arg2"]
29 |
30 | assert is_parameter_annotated_with(arg_parameter, "hello") is False
31 | assert is_parameter_annotated_with(arg_parameter, "unit-test") is True
32 | assert is_parameter_annotated_with(arg2_parameter, "unit-test") is False
33 |
--------------------------------------------------------------------------------