├── artifex
├── py.typed
├── models
│ ├── reranker
│ │ └── __init__.py
│ ├── classification
│ │ ├── binary_classification
│ │ │ ├── __init__.py
│ │ │ └── guardrail
│ │ │ │ └── __init__.py
│ │ ├── multi_class_classification
│ │ │ ├── __init__.py
│ │ │ ├── intent_classifier
│ │ │ │ └── __init__.py
│ │ │ ├── emotion_detection
│ │ │ │ └── __init__.py
│ │ │ └── sentiment_analysis
│ │ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── named_entity_recognition
│ │ ├── __init__.py
│ │ └── text_anonymization
│ │ │ └── __init__.py
│ └── __init__.py
├── core
│ ├── __init__.py
│ ├── exceptions.py
│ ├── models.py
│ ├── decorators.py
│ └── _hf_patches.py
├── utils.py
├── config.py
└── __init__.py
├── tests
├── __init__.py
├── unit
│ ├── __init__.py
│ ├── guardrail
│ │ ├── test_gr__init__.py
│ │ ├── test_gr_get_data_gen_instr.py
│ │ └── test_gr_train.py
│ ├── emotion_detection
│ │ └── test_ed__init__.py
│ ├── intent_classifier
│ │ └── test_ic__init__.py
│ ├── sentiment_analysis
│ │ └── test_sa__init__.py
│ ├── classification_model
│ │ ├── test_cm__init__.py
│ │ ├── test_cm_load_model.py
│ │ └── test_cm_get_data_gen_instr.py
│ ├── core
│ │ └── decorators
│ │ │ ├── test_should_skip_method.py
│ │ │ └── test_auto_validate_methods.py
│ ├── base_model
│ │ ├── test_bm_sanitize_output_path.py
│ │ └── test_base_model_load.py
│ ├── named_entity_recognition
│ │ ├── test_ner_parse_user_instructions.py
│ │ ├── test_ner_post_process_synthetic_dataset.py
│ │ └── test_ner_get_data_gen_instr.py
│ ├── reranker
│ │ ├── test_rr_get_data_gen_instr.py
│ │ ├── test_rr_parse_user_instructions.py
│ │ └── test_rr__call__.py
│ └── text_anonymization
│ │ ├── test_ta__call__.py
│ │ └── test_ta_train.py
├── integration
│ ├── __init__.py
│ ├── reranker
│ │ ├── test_rr_train_intgr.py
│ │ └── test_rr__call__intgr.py
│ ├── text_classification
│ │ ├── test_tc_train_intgr.py
│ │ └── test_tc__call__intgr.py
│ ├── guardrail
│ │ ├── test_gr_train_intgr.py
│ │ └── test_gr__call__intgr.py
│ ├── text_anonymization
│ │ ├── test_ta_train_intgr.py
│ │ └── test_ta__call__intgr.py
│ ├── intent_classifier
│ │ ├── test_ic_train_intgr.py
│ │ └── test_ic__call__intgr.py
│ ├── emotion_detection
│ │ ├── test_ed_train_intgr.py
│ │ └── test_ed__call__intgr.py
│ ├── named_entity_recognition
│ │ ├── test_ner_train_intgr.py
│ │ └── test_ner__call__intgr.py
│ └── sentiment_analysis
│ │ ├── test_sa__call__.py
│ │ └── test_sa_train_intgr.py
└── conftest.py
├── .github
├── CODEOWNERS
├── pull_request_template.md
├── ISSUE_TEMPLATE
│ ├── feature_request.md
│ └── bug_report.md
└── workflows
│ └── python-publish.yml
├── setup.py
├── assets
├── hero.png
├── banner.png
└── experiment.png
├── pytest.ini
├── .gitignore
├── SECURITY.md
├── LICENSE
├── CLA.md
├── pyproject.toml
├── CODE_OF_CONDUCT.md
├── CHANGELOG.md
├── CONTRIBUTING.md
└── requirements.txt
/artifex/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @rlucatoor
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup


# Legacy-compat shim: called with no arguments so all package metadata is
# read from declarative config (pyproject.toml is present at the repo root).
setup()
--------------------------------------------------------------------------------
/assets/hero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/hero.png
--------------------------------------------------------------------------------
/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/banner.png
--------------------------------------------------------------------------------
/assets/experiment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/experiment.png
--------------------------------------------------------------------------------
/artifex/models/reranker/__init__.py:
--------------------------------------------------------------------------------
from .reranker import Reranker

# Public API of the reranker subpackage.
__all__ = ["Reranker"]
--------------------------------------------------------------------------------
/artifex/models/classification/binary_classification/__init__.py:
--------------------------------------------------------------------------------
from .guardrail import Guardrail

# Public API of the binary-classification subpackage.
__all__ = [
    "Guardrail"
]
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = -s
3 | markers =
4 | unit: marks tests as unit (use with -m "unit")
5 | integration: marks tests as integration (use with -m "integration")
--------------------------------------------------------------------------------
/artifex/models/named_entity_recognition/__init__.py:
--------------------------------------------------------------------------------
from .text_anonymization import TextAnonymization
from .named_entity_recognition import NamedEntityRecognition

# Public API of the named-entity-recognition subpackage.
__all__ = [
    "NamedEntityRecognition",
    "TextAnonymization"
]
--------------------------------------------------------------------------------
/artifex/models/classification/multi_class_classification/__init__.py:
--------------------------------------------------------------------------------
from .emotion_detection import EmotionDetection
from .intent_classifier import IntentClassifier
from .sentiment_analysis import SentimentAnalysis

# Public API of the multi-class-classification subpackage.
__all__ = [
    "EmotionDetection",
    "IntentClassifier",
    "SentimentAnalysis"
]
--------------------------------------------------------------------------------
/artifex/models/classification/__init__.py:
--------------------------------------------------------------------------------
from .classification_model import ClassificationModel
from .binary_classification import Guardrail
from .multi_class_classification import EmotionDetection, IntentClassifier, SentimentAnalysis

# Public API of the classification subpackage: the generic base model plus
# the concrete binary and multi-class classifiers.
__all__ = [
    "ClassificationModel",
    "Guardrail",
    "EmotionDetection",
    "IntentClassifier",
    "SentimentAnalysis",
]
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | # Pull Request
2 |
3 | ## What does this PR do?
4 |
5 |
6 |
7 | ## Checklist
8 |
9 | - [ ] I’ve tested the changes
10 | - [ ] I’ve added or updated docs (if applicable)
11 | - [ ] I’ve added or updated tests (if applicable)
12 |
13 | ## Related Issues
14 |
15 |
16 |
17 | ## Additional Notes
18 |
19 |
20 |
--------------------------------------------------------------------------------
/tests/integration/reranker/test_rr_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `Reranker` class.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
        output_folder (str): Temporary folder that receives the training output.
    """

    rr = artifex.reranker
    rr.train(
        domain="test domain",
        num_samples=40,
        num_epochs=1,
        output_path=output_folder
    )
--------------------------------------------------------------------------------
/artifex/models/__init__.py:
--------------------------------------------------------------------------------
# Re-export every concrete model so callers can import from `artifex.models`
# directly. Import order is kept as-is to preserve submodule initialization.
from .classification import (
    ClassificationModel,
    Guardrail,
    EmotionDetection,
    IntentClassifier,
    SentimentAnalysis,
)
from .base_model import BaseModel
from .named_entity_recognition import NamedEntityRecognition, TextAnonymization
from .reranker import Reranker

# Public API of `artifex.models`.
__all__ = [
    "ClassificationModel",
    "Guardrail",
    "EmotionDetection",
    "IntentClassifier",
    "SentimentAnalysis",
    "BaseModel",
    "NamedEntityRecognition",
    "TextAnonymization",
    "Reranker",
]
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Virtual environment
2 | .venv/
3 | venv/
4 | .env/
5 |
6 | # VS Code stuff
7 | .vscode/
8 |
9 | # Cache
10 | __pycache__/
11 | .*cache
12 |
13 | # Environment variables
14 | .env
15 |
16 | # Tests
17 | test_data/
18 | .coverage
19 | htmlcov/
20 | test.py
21 | test.ipynb
22 | .pytest_env_backup
23 |
24 | # Distribution / packaging
25 | build/
26 | dist/
27 | *.egg-info/
28 |
29 | # Version info
30 | __version__.py
31 |
32 | # Debug Mode
33 | .debug
34 |
35 | # Training output
36 | trainer_output/
37 | artifex_output/
38 |
39 | # Notebooks
40 | .ipynb_checkpoints/
41 | notebooks/
--------------------------------------------------------------------------------
/artifex/core/__init__.py:
--------------------------------------------------------------------------------
# Re-export the core decorators, exceptions and shared pydantic-style models.
from .decorators import auto_validate_methods
from .exceptions import ServerError, ValidationError, BadRequestError
from .models import (
    ClassificationResponse,
    ClassificationClassName,
    NERTagName,
    NEREntity,
    ClassificationInstructions,
    NERInstructions,
)

# Public API of `artifex.core`.
__all__ = [
    "auto_validate_methods",
    "ServerError",
    "ValidationError",
    "BadRequestError",
    "ClassificationResponse",
    "ClassificationClassName",
    "NERTagName",
    "NEREntity",
    "ClassificationInstructions",
    "NERInstructions"
]
--------------------------------------------------------------------------------
/artifex/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from artifex.config import config
3 |
4 |
def get_model_output_path(output_path: str) -> str:
    """
    Get the output path for the trained model based on the provided output path (its parent directory).

    Args:
        output_path (str): Parent directory for training artifacts.

    Returns:
        str: `output_path` joined with the configured model folder name.
    """

    # os.path.join already returns a str; the previous str() wrapper was redundant.
    return os.path.join(output_path, config.SYNTHEX_OUTPUT_MODEL_FOLDER_NAME)
11 |
def get_dataset_output_path(output_path: str) -> str:
    """
    Get the output path for the dataset based on the provided output path (its parent directory).

    Args:
        output_path (str): Parent directory for training artifacts.

    Returns:
        str: `output_path` joined with the configured default dataset name.
    """

    # os.path.join already returns a str; the previous str() wrapper was redundant.
    return os.path.join(output_path, config.DEFAULT_SYNTHEX_DATASET_NAME)
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Supported Versions
4 |
5 | We currently support the latest published version of this project.
6 | Older versions may not receive security updates unless explicitly stated.
7 |
8 | ## Reporting a Vulnerability
9 |
If you discover a security issue in this project, please do not open a public issue or pull request, as doing so would immediately make the problem visible to potentially malicious actors.
11 |
12 | Instead, report it privately by emailing:
13 |
14 | 📬 **info@tanaos.com**
15 |
16 | Please include as much detail as possible so we can verify and address the issue quickly.
17 |
18 | We appreciate responsible disclosure and will do our best to respond promptly.
19 |
--------------------------------------------------------------------------------
/tests/integration/text_classification/test_tc_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `ClassificationModel` class. Verify that:
    - The training process completes without errors.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
        output_folder (str): Temporary folder that receives the training output.
    """

    artifex.text_classification.train(
        domain="test domain",
        classes={
            "class_a": "Description for class A.",
            "class_b": "Description for class B.",
            "class_c": "Description for class C."
        },
        num_samples=40,
        num_epochs=1,
        output_path=output_folder
    )
--------------------------------------------------------------------------------
/artifex/core/exceptions.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 |
class ArtifexError(Exception):
    """
    Base exception for all errors raised by the library.

    Attributes are exposed so callers can inspect the failure programmatically:
    `message` is the short human-readable summary, `details` is optional extra
    context appended to the rendered string.
    """

    def __init__(
        self, message: str, details: Optional[str] = None
    ):
        self.message = message
        self.details = details
        # Pass the rendered text to Exception so str()/args behave normally.
        super().__init__(str(self))

    def __str__(self) -> str:
        if self.details:
            return f"{self.message} Details: {self.details}"
        return self.message
21 |
22 |
class BadRequestError(ArtifexError):
    """Raised when the API request is malformed or invalid."""
26 |
class ServerError(ArtifexError):
    """Raised when the API server returns a 5xx error."""
30 |
class ValidationError(ArtifexError):
    """Raised when the API returns a validation error."""
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import shutil
3 | from typing import Generator
4 |
5 | from artifex import Artifex
6 | from artifex.config import config
7 |
8 |
@pytest.fixture(scope="function")
def artifex() -> Artifex:
    """
    Build a fresh Artifex client for each test, authenticated with the API
    key taken from the environment-backed config.
    Returns:
        Artifex: An instance of the Artifex class initialized with the API key.
    """

    key = config.API_KEY
    if not key:
        pytest.fail("API_KEY not found in environment variables")
    return Artifex(key)
22 |
23 |
@pytest.fixture(scope="function")
def output_folder() -> Generator[str, None, None]:
    """
    Yield a temporary output-folder path and remove the folder on teardown.
    """

    folder = "./output_folder/"
    yield folder
    # Teardown: delete whatever the test wrote there (no-op if absent).
    shutil.rmtree(folder, ignore_errors=True)
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Riccardo Lucato
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/tests/integration/guardrail/test_gr_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `Guardrail` class. Ensure that:
    - The training process completes without errors.
    - The output model's id2label mapping is { 0: "safe", 1: "unsafe" }.
    - The output model's label2id mapping is { "safe": 0, "unsafe": 1 }.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    guardrail = artifex.guardrail
    guardrail.train(
        unsafe_content=["test instructions"],
        num_samples=40,
        num_epochs=1,
        output_path=output_folder,
    )

    # The fine-tuned model must expose the fixed safe/unsafe label scheme.
    model_config = guardrail._model.config
    assert model_config.id2label == { 0: "safe", 1: "unsafe" }
    assert model_config.label2id == { "safe": 0, "unsafe": 1 }
--------------------------------------------------------------------------------
/tests/unit/guardrail/test_gr__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from artifex.models import Guardrail
4 |
5 |
def test_guardrail_init(mocker: MockerFixture):
    """
    Unit test for Guardrail.__init__.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    # Stub out Synthex, the config module, and the parent constructor.
    synthex_stub = mocker.Mock()
    config_stub = mocker.patch("artifex.models.classification.binary_classification.guardrail.config")
    config_stub.GUARDRAIL_HF_BASE_MODEL = "mocked-base-model"
    parent_init = mocker.patch(
        "artifex.models.classification.classification_model.ClassificationModel.__init__",
        return_value=None
    )

    guardrail = Guardrail(synthex_stub)

    # The parent constructor must receive the configured base-model name.
    parent_init.assert_called_once_with(synthex_stub, base_model_name="mocked-base-model")
    # The system data-generation instructions must be a list of strings.
    instructions = guardrail._system_data_gen_instr_val
    assert isinstance(instructions, list)
    assert all(isinstance(entry, str) for entry in instructions)
--------------------------------------------------------------------------------
/tests/integration/text_anonymization/test_ta_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `TextAnonymization` class. Verify that:
    - The training process completes without errors.
    - The output model's id2label mapping is the expected one.
    - The output model's label2id mapping is the expected one.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    ta = artifex.text_anonymization

    # Build the expected BIO label set from the model's PII entities.
    expected_labels = ["O"]
    for entity in ta._pii_entities.keys():
        expected_labels += [f"B-{entity}", f"I-{entity}"]

    ta.train(
        domain="test domain",
        num_samples=40,
        num_epochs=1,
        output_path=output_folder
    )

    # Verify the model's config mappings
    assert ta._model.config.id2label == dict(enumerate(expected_labels))
    assert ta._model.config.label2id == { label: i for i, label in enumerate(expected_labels) }
--------------------------------------------------------------------------------
/tests/integration/intent_classifier/test_ic_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `IntentClassifier` class. Verify that:
    - The training process completes without errors.
    - The output model's id2label mapping is the expected one.
    - The output model's label2id mapping is the expected one.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    intent_classes = {
        "class_a": "Description for class A.",
        "class_b": "Description for class B.",
        "class_c": "Description for class C."
    }

    classifier = artifex.intent_classifier
    classifier.train(
        domain="test domain",
        classes=intent_classes,
        num_samples=40,
        num_epochs=1,
        output_path=output_folder
    )

    # Label mappings must follow the insertion order of `intent_classes`.
    label_names = list(intent_classes)
    assert classifier._model.config.id2label == dict(enumerate(label_names))
    assert classifier._model.config.label2id == { name: i for i, name in enumerate(label_names) }
--------------------------------------------------------------------------------
/tests/integration/text_classification/test_tc__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import ClassificationResponse
5 |
6 |
@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `TextClassification` class when a single input is
    provided. Ensure that it returns a list of ClassificationResponse objects.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    responses = artifex.text_classification("test input")
    assert isinstance(responses, list)
    for response in responses:
        assert isinstance(response, ClassificationResponse)
21 |
22 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `TextClassification` class when multiple inputs are
    provided. Ensure that it returns a list of ClassificationResponse objects.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    inputs = ["test input 1", "test input 2", "test input 3"]
    responses = artifex.text_classification(inputs)
    assert isinstance(responses, list)
    for response in responses:
        assert isinstance(response, ClassificationResponse)
37 |
--------------------------------------------------------------------------------
/tests/unit/emotion_detection/test_ed__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from artifex.models import EmotionDetection
4 |
5 |
def test_emotion_detection_init(mocker: MockerFixture):
    """
    Unit test for EmotionDetection.__init__.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    # Stub out Synthex, the config module, and the parent constructor.
    synthex_stub = mocker.Mock()
    config_stub = mocker.patch("artifex.models.classification.multi_class_classification.emotion_detection.config")
    config_stub.EMOTION_DETECTION_HF_BASE_MODEL = "mocked-base-model"
    parent_init = mocker.patch(
        "artifex.models.classification.classification_model.ClassificationModel.__init__",
        return_value=None
    )

    detector = EmotionDetection(synthex_stub)

    # The parent constructor must receive the configured base-model name.
    parent_init.assert_called_once_with(synthex_stub, base_model_name="mocked-base-model")
    # The system data-generation instructions must be a list of strings.
    instructions = detector._system_data_gen_instr_val
    assert isinstance(instructions, list)
    assert all(isinstance(entry, str) for entry in instructions)
--------------------------------------------------------------------------------
/tests/unit/intent_classifier/test_ic__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from artifex.models import IntentClassifier
4 |
5 |
def test_intent_classifier_init(mocker: MockerFixture):
    """
    Unit test for IntentClassifier.__init__.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    # Stub out Synthex, the config module, and the parent constructor.
    synthex_stub = mocker.Mock()
    config_stub = mocker.patch("artifex.models.classification.multi_class_classification.intent_classifier.config")
    config_stub.INTENT_CLASSIFIER_HF_BASE_MODEL = "mocked-base-model"
    parent_init = mocker.patch(
        "artifex.models.classification.classification_model.ClassificationModel.__init__",
        return_value=None
    )

    classifier = IntentClassifier(synthex_stub)

    # The parent constructor must receive the configured base-model name.
    parent_init.assert_called_once_with(synthex_stub, base_model_name="mocked-base-model")
    # The system data-generation instructions must be a list of strings.
    instructions = classifier._system_data_gen_instr_val
    assert isinstance(instructions, list)
    assert all(isinstance(entry, str) for entry in instructions)
--------------------------------------------------------------------------------
/tests/unit/sentiment_analysis/test_sa__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from artifex.models import SentimentAnalysis
4 |
5 |
def test_sentiment_analysis_init(mocker: MockerFixture):
    """
    Unit test for SentimentAnalysis.__init__.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    # Stub out Synthex, the config module, and the parent constructor.
    synthex_stub = mocker.Mock()
    config_stub = mocker.patch("artifex.models.classification.multi_class_classification.sentiment_analysis.config")
    config_stub.SENTIMENT_ANALYSIS_HF_BASE_MODEL = "mocked-base-model"
    parent_init = mocker.patch(
        "artifex.models.classification.classification_model.ClassificationModel.__init__",
        return_value=None
    )

    analyzer = SentimentAnalysis(synthex_stub)

    # The parent constructor must receive the configured base-model name.
    parent_init.assert_called_once_with(synthex_stub, base_model_name="mocked-base-model")
    # The system data-generation instructions must be a list of strings.
    instructions = analyzer._system_data_gen_instr_val
    assert isinstance(instructions, list)
    assert all(isinstance(entry, str) for entry in instructions)
--------------------------------------------------------------------------------
/CLA.md:
--------------------------------------------------------------------------------
1 | Contributor License Agreement (CLA)
2 |
3 | Thank you for your interest in contributing to Artifex.
4 |
5 | By signing this CLA (electronically via GitHub or other integrated tools), you agree to the following terms for your contributions (past, present, and future):
6 |
7 | 1. **Originality**:
8 | You confirm that your contributions are your own work, or that you have the right to submit them.
9 |
10 | 2. **License**:
11 | You grant Riccardo Lucato an irrevocable, worldwide, non-exclusive, royalty-free license to use, reproduce, modify, distribute, and sublicense your contributions under the MIT License (or any other license used by the project).
12 |
13 | 3. **Employment and third-party rights**:
14 | You confirm that your contributions are not subject to any third-party rights, including those of your employer or clients. If such rights could exist, you confirm that you have obtained a valid and sufficient waiver or permission to contribute the work under this CLA and the project’s open source license. You agree that neither your employer nor any third party will have any claim or rights over your contributions to this project.
15 |
16 | 4. **Disclaimer**:
17 | Contributions are provided "as-is" without warranties or conditions of any kind.
18 |
19 | By electronically signing this CLA, you accept these terms and certify that you have the legal authority to do so.
--------------------------------------------------------------------------------
/tests/integration/emotion_detection/test_ed_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import shutil
3 |
4 | from artifex import Artifex
5 |
6 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `EmotionDetection` class. Ensure that:
    - The training process completes without errors.
    - The output model's id2label mapping is the expected one.
    - The output model's label2id mapping is the expected one.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
        output_folder (str): Temporary output folder; the fixture itself
            removes it on teardown, so no in-test cleanup is needed.
    """

    ed = artifex.emotion_detection

    # NOTE: the previous try/finally shutil.rmtree(output_folder) was removed —
    # it duplicated the `output_folder` fixture's own teardown and was
    # inconsistent with every sibling integration test.
    ed.train(
        domain="test domain",
        classes={
            "happy": "text expressing happiness",
            "sad": "text expressing sadness"
        },
        num_samples=40,
        num_epochs=1,
        output_path=output_folder
    )

    # Verify the model's config mappings
    id2label = ed._model.config.id2label
    label2id = ed._model.config.label2id
    assert id2label == { 0: "happy", 1: "sad" }
    assert label2id == { "happy": 0, "sad": 1 }
43 |
44 |
--------------------------------------------------------------------------------
/tests/integration/named_entity_recognition/test_ner_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `NamedEntityRecognition` class. Verify that:
    - The training process completes without errors.
    - The output model's id2label mapping is the expected one.
    - The output model's label2id mapping is the expected one.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
        output_folder (str): Folder where the trained model artifacts are written.
    """
    import shutil  # local import: only needed for test cleanup

    named_entities = {
        "ENTITY_A": "Description for entity A.",
        "ENTITY_B": "Description for entity B.",
    }

    # Expected BIO label sequence: "O" first, then a B-/I- pair per entity,
    # in the insertion order of `named_entities`.
    bio_labels = ["O"]
    for name in named_entities:
        bio_labels.extend([f"B-{name}", f"I-{name}"])

    ner = artifex.named_entity_recognition

    try:
        ner.train(
            domain="test domain",
            named_entities=named_entities,
            num_samples=40,
            num_epochs=1,
            output_path=output_folder
        )

        # Verify the model's config mappings
        id2label = ner._model.config.id2label
        label2id = ner._model.config.label2id
        assert id2label == { i: label for i, label in enumerate(bio_labels) }
        assert label2id == { label: i for i, label in enumerate(bio_labels) }
    finally:
        # Clean up training artifacts, mirroring the emotion-detection train test.
        shutil.rmtree(output_folder, ignore_errors=True)
--------------------------------------------------------------------------------
/artifex/models/classification/multi_class_classification/intent_classifier/__init__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 |
3 | from ...classification_model import ClassificationModel
4 |
5 | from artifex.core import auto_validate_methods
6 | from artifex.config import config
7 |
8 |
@auto_validate_methods
class IntentClassifier(ClassificationModel):
    """
    Multi-class classifier that assigns one of a set of predefined intent
    categories to a piece of text (e.g. a user message sent to an LLM).
    """

    def __init__(self, synthex: Synthex):
        """
        Build the classifier on top of the configured HuggingFace base model.
        Args:
            synthex (Synthex): Client used to generate the synthetic training data.
        """

        super().__init__(synthex, base_model_name=config.INTENT_CLASSIFIER_HF_BASE_MODEL)
        # System-level instructions steering synthetic data generation.
        # NOTE(review): "{domain}" looks like a placeholder filled in downstream
        # (presumably by ClassificationModel) — confirm against the base class.
        system_instructions = [
            "The 'text' field should contain text that belongs to the following domain(s): {domain}.",
            "The 'text' field should contain text that has a specific intent or objective.",
            "The 'labels' field should contain a label indicating the intent or objective of the 'text'.",
            "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.",
            "This is a list of the allowed 'labels' and their meaning: "
        ]
        self._system_data_gen_instr_val: list[str] = system_instructions
--------------------------------------------------------------------------------
/artifex/models/classification/multi_class_classification/emotion_detection/__init__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 |
3 | from ...classification_model import ClassificationModel
4 |
5 | from artifex.core import auto_validate_methods
6 | from artifex.config import config
7 |
8 |
@auto_validate_methods
class EmotionDetection(ClassificationModel):
    """
    An Emotion Detection Model is used to classify text into different emotional categories. In this
    implementation, we support the following emotions: `joy`, `anger`, `fear`, `sadness`, `surprise`, `disgust`,
    `excitement` and `neutral`.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic
                data used to train the model.
        """

        super().__init__(synthex, base_model_name=config.EMOTION_DETECTION_HF_BASE_MODEL)
        # System-level data-generation instructions. The last item previously read
        # "the allowed 'labels' their meaning"; fixed to "and their meaning" to match
        # the phrasing used by the sibling IntentClassifier model.
        self._system_data_gen_instr_val: list[str] = [
            "The 'text' field should contain text that belongs to the following domain(s): {domain}.",
            "The 'text' field should contain text that may or may not express a certain emotion.",
            "The 'labels' field should contain a label indicating the emotion of the 'text'.",
            "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.",
            "This is a list of the allowed 'labels' and their meaning: "
        ]
--------------------------------------------------------------------------------
/tests/integration/guardrail/test_gr__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import ClassificationResponse
5 |
6 |
expected_labels = ["safe", "unsafe"]

@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `Guardrail` class when a single input is provided.
    Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected guardrail labels ("safe"/"unsafe").
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    out = artifex.guardrail("test input")
    assert isinstance(out, list)
    assert all(isinstance(resp, ClassificationResponse) for resp in out)
    assert all(resp.label in expected_labels for resp in out)
26 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `Guardrail` class when multiple inputs are provided.
    Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected guardrail labels ("safe"/"unsafe").
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    out = artifex.guardrail(["test input 1", "test input 2", "test input 3"])
    assert isinstance(out, list)
    assert all(isinstance(resp, ClassificationResponse) for resp in out)
    assert all(resp.label in expected_labels for resp in out)
--------------------------------------------------------------------------------
/tests/integration/sentiment_analysis/test_sa__call__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import ClassificationResponse
5 |
6 |
expected_labels = ["very_positive", "positive", "negative", "very_negative", "neutral"]


@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `SentimentAnalysis` class. Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected sentiment labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    responses = artifex.sentiment_analysis("test input")

    assert isinstance(responses, list)
    for response in responses:
        # Every prediction must carry a known sentiment label.
        assert response.label in expected_labels
        assert isinstance(response, ClassificationResponse)
26 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `SentimentAnalysis` class, when multiple inputs are
    provided. Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected sentiment labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    inputs = ["test input 1", "test input 2", "test input 3"]
    responses = artifex.sentiment_analysis(inputs)

    assert isinstance(responses, list)
    for response in responses:
        # Every prediction must carry a known sentiment label.
        assert response.label in expected_labels
        assert isinstance(response, ClassificationResponse)
--------------------------------------------------------------------------------
/tests/integration/sentiment_analysis/test_sa_train_intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test_train_success(
    artifex: Artifex,
    output_folder: str
):
    """
    Test the `train` method of the `SentimentAnalysisModel` class. Ensure that:
    - The training process completes without errors.
    - The output model's id2label mapping is { 0: "very_negative", 1: "negative", 2: "neutral", 3: "positive", 4: "very_positive" }.
    - The output model's label2id mapping is { "very_negative": 0, "negative": 1, "neutral": 2, "positive": 3, "very_positive": 4 }.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
        output_folder (str): Folder where the trained model artifacts are written.
    """
    import shutil  # local import: only needed for test cleanup

    sa = artifex.sentiment_analysis

    try:
        sa.train(
            domain="general",
            classes={
                "very_negative": "Text expressing a very negative sentiment.",
                "negative": "Text expressing a negative sentiment.",
                "neutral": "Text expressing a neutral sentiment.",
                "positive": "Text expressing a positive sentiment.",
                "very_positive": "Text expressing a very positive sentiment.",
            },
            num_samples=40,
            num_epochs=1,
            output_path=output_folder
        )

        # Verify the model's config mappings
        id2label = sa._model.config.id2label
        label2id = sa._model.config.label2id
        assert id2label == { 0: "very_negative", 1: "negative", 2: "neutral", 3: "positive", 4: "very_positive" }
        assert label2id == { "very_negative": 0, "negative": 1, "neutral": 2, "positive": 3, "very_positive": 4 }
    finally:
        # Clean up training artifacts, mirroring the emotion-detection train test.
        shutil.rmtree(output_folder, ignore_errors=True)
--------------------------------------------------------------------------------
/tests/integration/emotion_detection/test_ed__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import ClassificationResponse
5 |
6 |
expected_labels = ["joy", "anger", "fear", "sadness", "surprise", "disgust", "excitement", "neutral"]

@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `EmotionDetection` class when a single input is provided.
    Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected emotion labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    out = artifex.emotion_detection("test input")
    assert isinstance(out, list)
    assert all(isinstance(resp, ClassificationResponse) for resp in out)
    assert all(resp.label in expected_labels for resp in out)
26 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `EmotionDetection` class when multiple inputs are provided.
    Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected emotion labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    out = artifex.emotion_detection(["test input 1", "test input 2", "test input 3"])
    assert isinstance(out, list)
    assert all(isinstance(resp, ClassificationResponse) for resp in out)
    assert all(resp.label in expected_labels for resp in out)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61", "setuptools_scm", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "artifex"
7 | dynamic = ["version"]
8 | description = "Create your private AI model with no training data or GPUs 🤖🚀."
9 | authors = [
10 | { name = "Riccardo Lucato", email = "riccardo@tanaos.com" },
11 | { name = "Saurabh Pradhan" }
12 | ]
13 | license = { text = "MIT" }
14 | readme = "README.md"
15 | requires-python = ">=3.10"
16 |
17 | dependencies = [
18 | "aiohttp>=3.12.14",
19 | "datasets>=3.6.0",
20 | "filelock>=3.20.1",
21 | "jupyterlab>=4.4.8",
22 | "protobuf>=6.33.2",
23 | "rich>=14.1.0",
24 | "sentencepiece>=0.2.1",
25 | "synthex>=0.4.2",
26 | "tiktoken>=0.12.0",
27 | "torch>=2.8.0",
28 | "transformers[torch]>=4.53.1",
29 | "tzlocal>=5.3.1",
30 | "urllib3>=2.6.0",
31 | ]
32 |
33 | classifiers = [
34 | "Development Status :: 3 - Alpha",
35 | "Intended Audience :: Developers",
36 | "License :: OSI Approved :: MIT License",
37 | "Programming Language :: Python :: 3.10",
38 | "Topic :: Software Development :: Libraries :: Python Modules",
39 | "Operating System :: OS Independent",
40 | ]
41 |
42 | [project.urls]
43 | homepage = "https://github.com/tanaos/artifex"
44 |
45 | [tool.setuptools]
46 | include-package-data = true
47 | zip-safe = false
48 |
49 | [tool.setuptools.packages.find]
50 | where = ["."]
51 |
52 | [tool.setuptools.package-data]
53 | artifex = ["py.typed"]
54 |
55 | [tool.setuptools_scm]
56 | version_file = "artifex/__version__.py"
57 |
58 | [dependency-groups]
59 | dev = [
60 | "ipykernel>=6.16.2",
61 | "notebook>=6.5.7",
62 | "pytest>=8.4.1",
63 | "pytest-mock>=3.14.1",
64 | ]
65 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Purpose
4 |
5 | This project is open to contributions from anyone interested. To keep things productive and collaborative, we ask that everyone follows a few basic guidelines when participating.
6 |
7 | ## How to Participate
8 |
9 | You’re welcome to:
10 |
11 | - Submit issues, suggestions, and improvements
12 | - Contribute code via pull requests
13 | - Offer constructive feedback or ask questions
14 | - Discuss ideas, implementation details, and priorities
15 |
16 | If you disagree with something, that’s fine — just explain your reasoning and move on if consensus isn’t reached.
17 |
18 | ## Please Avoid
19 |
20 | We want to avoid things that create unnecessary friction, such as:
21 |
22 | - Personal attacks
23 | - Starting arguments for the sake of it
24 | - Name-calling or intentionally rude behavior
25 | - Taking disagreements too far or making them personal
26 |
27 | ## Enforcement
28 |
29 | Instances of disruptive, hostile or otherwise unacceptable behavior may be reported by contacting the project maintainers at info@tanaos.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
30 |
31 | Contributors who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by members of the project's leadership.
32 |
33 | ## Final Note
34 |
35 | This is an engineering project. We value clarity, logic, and constructive feedback. Treat others the way you'd expect to be treated in a professional collaboration.
36 |
--------------------------------------------------------------------------------
/tests/integration/reranker/test_rr__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 |
5 |
@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `Reranker` class when a single input is provided.
    Ensure that:
    - The return type is list[tuple[str, float]].
    - The output tuples only contain the provided input document.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    input_doc = "doc1"

    ranked = artifex.reranker(query="test query", documents=input_doc)

    assert isinstance(ranked, list)
    for entry in ranked:
        # Each entry is a (document, score) pair.
        assert isinstance(entry, tuple)
        assert isinstance(entry[0], str)
        assert isinstance(entry[1], float)
        # Only the provided document may appear in the output.
        assert entry[0] in [input_doc]
30 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `Reranker` class when multiple inputs are provided.
    Ensure that:
    - The return type is list[tuple[str, float]].
    - The output tuples only contain the provided input documents.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    input_docs = ["doc1", "doc2", "doc3"]

    ranked = artifex.reranker(query="test query", documents=input_docs)

    assert isinstance(ranked, list)
    for entry in ranked:
        # Each entry is a (document, score) pair.
        assert isinstance(entry, tuple)
        assert isinstance(entry[0], str)
        assert isinstance(entry[1], float)
        # Only the provided documents may appear in the output.
        assert entry[0] in input_docs
--------------------------------------------------------------------------------
/tests/integration/intent_classifier/test_ic__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import ClassificationResponse
5 |
6 |
expected_labels = [
    "greeting", "farewell", "thank_you", "affirmation", "negation", "small_talk",
    "bot_capabilities", "feedback_positive", "feedback_negative", "clarification",
    "suggestion", "language_change"
]

@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `IntentClassifier` class when a single input is
    provided. Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected intent labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    responses = artifex.intent_classifier("test input")

    assert isinstance(responses, list)
    for response in responses:
        assert isinstance(response, ClassificationResponse)
        # Every prediction must carry a known intent label.
        assert response.label in expected_labels
30 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `IntentClassifier` class when multiple inputs are
    provided. Ensure that:
    - It returns a list of ClassificationResponse objects.
    - The output labels are among the expected intent labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    inputs = ["test input 1", "test input 2", "test input 3"]
    responses = artifex.intent_classifier(inputs)

    assert isinstance(responses, list)
    for response in responses:
        assert isinstance(response, ClassificationResponse)
        # Every prediction must carry a known intent label.
        assert response.label in expected_labels
--------------------------------------------------------------------------------
/artifex/core/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 | from artifex.config import config
4 |
5 |
class ClassificationResponse(BaseModel):
    """A single classification prediction: a label and its confidence score."""
    label: str  # predicted class name
    score: float  # model confidence for the predicted label
9 |
class NEREntity(BaseModel):
    """A single named entity recognized in a piece of text."""
    entity_group: str  # entity tag (e.g. "PERSON")
    word: str  # the matched surface text
    score: float  # model confidence for this entity
    start: int  # start character offset in the input text
    end: int  # end character offset — presumably exclusive (HF convention); confirm at call site
16 |
class ClassificationClassName(str):
    """
    String type for classification class names.

    Instances are validated at construction time: the value must be non-empty,
    must not exceed the configured maximum length, and may not contain spaces.
    """

    # Maximum allowed length, taken from the package-wide configuration.
    max_length = config.CLASSIFICATION_CLASS_NAME_MAX_LENGTH

    def __new__(cls, value: str):
        if not value:
            raise ValueError("ClassName must be a non-empty string")
        if len(value) > cls.max_length:
            raise ValueError(f"ClassName exceeds max length of {cls.max_length}")
        if " " in value:
            raise ValueError("ClassName must not contain spaces")
        return super().__new__(cls, value)
33 |
class NERTagName(str):
    """
    String type for NER tag names.

    Instances are validated at construction time: the value must be non-empty,
    must not exceed the configured maximum length, and may not contain spaces.
    The stored value is normalized to upper case.
    """

    # Maximum allowed length, taken from the package-wide configuration.
    max_length = config.NER_TAGNAME_MAX_LENGTH

    def __new__(cls, value: str):
        if not value:
            raise ValueError("NERTagName must be a non-empty string")
        if len(value) > cls.max_length:
            raise ValueError(f"NERTagName exceeds max length of {cls.max_length}")
        if " " in value:
            raise ValueError("NERTagName must not contain spaces")
        # Tag names are stored upper-cased (validation above runs on the raw input).
        return super().__new__(cls, value.upper())
50 |
# Mapping of class name -> human-readable class description.
NClassClassificationClassesDesc = dict[str, str]

class ClassificationInstructions(BaseModel):
    """Parsed user instructions for training a classification model."""
    classes: NClassClassificationClassesDesc  # class name -> description
    domain: str  # domain(s) the training texts should belong to
56 |
# Mapping of NER tag name -> human-readable tag description.
NERTags = dict[str, str]

class NERInstructions(BaseModel):
    """Parsed user instructions for training a named-entity-recognition model."""
    named_entity_tags: NERTags  # tag name -> description
    domain: str  # domain(s) the training texts should belong to
--------------------------------------------------------------------------------
/tests/integration/named_entity_recognition/test_ner__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.core import NEREntity
5 |
6 |
expected_labels = [
    "O", "PERSON", "ORG", "LOCATION", "DATE", "TIME", "PERCENT", "NUMBER", "FACILITY",
    "PRODUCT", "WORK_OF_ART", "LANGUAGE", "NORP", "ADDRESS", "PHONE_NUMBER"
]


@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `NamedEntityRecognition` class when a single input is
    provided. Ensure that:
    - It returns a list of list of NEREntity objects.
    - The output labels are among the expected named entity labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    result = artifex.named_entity_recognition("His name is John Doe.")

    assert isinstance(result, list)
    for entities in result:
        # One inner list of recognized entities per input text.
        assert isinstance(entities, list)
        for entity in entities:
            assert isinstance(entity, NEREntity)
            assert entity.entity_group in expected_labels
31 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `NamedEntityRecognition` class when multiple inputs are
    provided. Ensure that:
    - It returns a list of list of NEREntity objects.
    - The output labels are among the expected named entity labels.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    result = artifex.named_entity_recognition(["His name is John Does", "His name is Jane Smith."])

    assert isinstance(result, list)
    for entities in result:
        # One inner list of recognized entities per input text.
        assert isinstance(entities, list)
        for entity in entities:
            assert isinstance(entity, NEREntity)
            assert entity.entity_group in expected_labels
--------------------------------------------------------------------------------
/tests/integration/text_anonymization/test_ta__call__intgr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from artifex import Artifex
4 | from artifex.config import config
5 |
6 |
@pytest.mark.integration
def test__call__single_input_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `TextAnonymization` class. Ensure that:
    - It returns a list of strings.
    - All returned strings are identical to the input strings, except for masked entities.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    # Renamed from `input` to avoid shadowing the builtin.
    input_text = "Jonathan Rogers lives in New York City."
    out = artifex.text_anonymization(input_text)
    assert isinstance(out, list)
    assert all(isinstance(item, str) for item in out)
    # BUG FIX: the original compared each output string against itself
    # (`split_input` was derived from `out`), making both assertions tautologies.
    # Compare the output against the actual input instead.
    split_input = input_text.split(" ")
    for anonymized in out:
        split_out = anonymized.split(" ")
        # Anonymization must preserve the word count of the input.
        assert len(split_input) == len(split_out)
        # Every output word is either an original input word or the mask token.
        assert all(
            word in split_input or word == config.DEFAULT_TEXT_ANONYM_MASK for word in split_out
        )
30 |
@pytest.mark.integration
def test__call__multiple_inputs_success(
    artifex: Artifex
):
    """
    Test the `__call__` method of the `TextAnonymization` class, when multiple inputs are
    provided. Ensure that:
    - It returns a list of strings.
    - All returned strings are identical to the input strings, except for masked entities.
    Args:
        artifex (Artifex): The Artifex instance to be used for testing.
    """

    inputs = [
        "John Doe lives in New York City.",
        "Mark Spencer's phone number is 123-456-7890.",
        "Alice was born on January 1, 1990."
    ]
    out = artifex.text_anonymization(inputs)
    assert isinstance(out, list)
    assert all(isinstance(item, str) for item in out)
    # BUG FIX: the original compared each output string against itself
    # (`split_input` was derived from `out`), making both assertions tautologies.
    # Compare each output against its corresponding input instead.
    for original, anonymized in zip(inputs, out):
        split_input = original.split(" ")
        split_out = anonymized.split(" ")
        # Anonymization must preserve the word count of the input.
        assert len(split_input) == len(split_out)
        # Every output word is either an original input word or the mask token.
        assert all(
            word in split_input or word == config.DEFAULT_TEXT_ANONYM_MASK for word in split_out
        )
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## Release v0.4.1 - December 22, 2025
2 |
3 | ### Changed
4 |
5 | - Turned `ClassificationModel` into a concrete class instead of an abstract class.
6 | - Replaced `instructions` argument with `unsafe_content` argument in the `Guardrail.train()`
7 | method.
8 |
9 | ### Fixed
10 |
11 | - Fixed security vulnerabilities by updating dependencies.
12 | - Suppressed a spurious tokenization-related warning in `NamedEntityRecognition.__call__()` method.
13 |
14 | ## Release v0.4.0 - December 4, 2025
15 |
16 | ### Added
17 |
18 | - Added the `Reranker` model.
19 | - Added the `SentimentAnalysis` model.
20 | - Added the `EmotionDetection` model.
21 | - Added the `NamedEntityRecognition` model.
22 | - Added the `TextAnonymization` model.
23 | - Added integration tests.
24 |
25 | ### Fixed
26 |
27 | - Fixed a bug causing the "Generating training data" progress bar to display a wrong progress percentage.
28 |
29 | ### Removed
30 |
31 | - Removed support for Python <= 3.9.
32 | - Removed all `# type: ignore` comments from the codebase.
33 |
34 | ### Changed
35 |
36 | - Updated error message when the `.load()` method is provided with a nonexistent model path or an invalid file format.
37 | - Updated the `IntentClassifier` base model.
38 | - Updated the output model's directory structure for better organization.
39 |
40 | ## Release v0.3.2 - October 2, 2025
41 |
42 | ### Added
43 |
44 | - Added support for Python 3.9.
45 |
46 | ### Fixed
47 |
48 | - Fixed SyntaxError in the `config.py` file, which was causing issues during library initialization on Python versions earlier than 3.13.
49 |
50 | ## Release v0.3.1 - September 16, 2025
51 |
52 | ### Added
53 |
54 | - Added a spinner on library load.
55 | - Added a spinner on model load.
56 | - Added a log message containing the generated model path, once model generation is complete.
57 |
58 | ### Changed
59 |
60 | - Replaced data generation `tqdm` progress bar with `rich` progress bar.
61 | - Replaced data processing `tqdm` progress bar with a `rich` spinner.
62 | - Replaced model training `tqdm` progress bar with `rich` progress bar.
63 | - Moved the output model to a dedicated `output_model` directory within the output directory.
64 |
65 | ### Removed
66 |
67 | - Removed all intermediate model checkpoints.
68 | - Removed unnecessary output files from output directory.
69 |
70 | ## Release v0.3.0 - September 9, 2025
71 |
72 | First release
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# CI/CD pipeline:
# - on pull requests into `master`, run the unit and integration test suites;
# - on a published GitHub release, build the package and upload it to PyPI.
name: Python CI/CD

on:
  pull_request:
    branches: [master]
  release:
    types: [published]

jobs:
  test:
    name: Run Tests
    runs-on: ubuntu-latest
    # Only run the test suite for pull-request events.
    if: github.event_name == 'pull_request'

    steps:
      - name: Check out code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      # Cache the virtual environment, keyed on pyproject.toml, to speed up runs.
      - name: Cache uv environment
        uses: actions/cache@v3
        with:
          path: .venv
          key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-uv-

      - name: Sync dependencies
        run: uv sync

      # Tests read the API key from a .env file (loaded via pydantic settings).
      - name: Create .env file for testing
        run: |
          echo "API_KEY=${{ secrets.ARTIFEX_API_KEY }}" > .env

      - name: Run Unit Tests
        run: uv run pytest tests/unit -x -v --tb=short

      - name: Run Integration Tests
        run: uv run pytest tests/integration -x -v --tb=short

  publish:
    name: Build and Publish to PyPI
    runs-on: ubuntu-latest
    # Only publish when a GitHub release is published.
    if: github.event_name == 'release'

    # NOTE(review): id-token: write enables OIDC "trusted publishing", but the
    # publish step below authenticates with a classic API token — confirm
    # whether both mechanisms are intentionally configured.
    permissions:
      contents: read
      id-token: write

    steps:
      - name: Check out code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install build tools
        run: |
          python -m pip install --upgrade pip
          pip install build twine

      - name: Build the package
        run: python -m build

      # Upload using the PyPI API token stored in repository secrets.
      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: python -m twine upload dist/*
--------------------------------------------------------------------------------
/artifex/config.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable
2 | from pydantic_settings import BaseSettings, SettingsConfigDict
3 | import os
4 | from datetime import datetime
5 | from tzlocal import get_localzone
6 | from pydantic import Field
7 |
8 |
class Config(BaseSettings):
    """
    Package-wide settings, loaded from the environment and a local `.env` file
    (see `model_config` below). All values can be overridden via environment
    variables of the same name.
    """

    # Artifex settings
    API_KEY: Optional[str] = None
    # Factory that yields a fresh, timestamped output directory path per call,
    # so each run writes to its own folder under ./artifex_output/.
    output_path_factory: Callable[[], str] = Field(
        default_factory=lambda:
        lambda: f"{os.getcwd()}/artifex_output/run-{datetime.now(tz=get_localzone()).strftime('%Y%m%d%H%M%S')}/"
    )
    @property
    def DEFAULT_OUTPUT_PATH(self) -> str:
        """Default output directory; note: a NEW timestamped path on every access."""
        return self.output_path_factory()

    # Artifex error messages
    DATA_GENERATION_ERROR: str = "An error occurred while generating training data. This may be due to an intense load on the system. Please try again later."

    # Synthex settings
    DEFAULT_SYNTHEX_DATAPOINT_NUM: int = 500
    DEFAULT_SYNTHEX_DATASET_FORMAT: str = "csv"
    @property
    def DEFAULT_SYNTHEX_DATASET_NAME(self) -> str:
        """Training-data file name, derived from the configured dataset format."""
        return f"train_data.{self.DEFAULT_SYNTHEX_DATASET_FORMAT}"
    # Leave empty to put the output model directly in the output folder (no subfolder)
    SYNTHEX_OUTPUT_MODEL_FOLDER_NAME: str = ""

    # HuggingFace settings
    DEFAULT_HUGGINGFACE_LOGGING_LEVEL: str = "error"

    # Base Model
    DEFAULT_TOKENIZER_MAX_LENGTH: int = 256

    # Classification Model
    CLASSIFICATION_CLASS_NAME_MAX_LENGTH: int = 20
    CLASSIFICATION_HF_BASE_MODEL: str = "microsoft/Multilingual-MiniLM-L12-H384"

    # Guardrail Model
    GUARDRAIL_HF_BASE_MODEL: str = "tanaos/tanaos-guardrail-v1"

    # IntentClassifier Model
    INTENT_CLASSIFIER_HF_BASE_MODEL: str = "tanaos/tanaos-intent-classifier-v1"

    # Reranker Model
    RERANKER_HF_BASE_MODEL: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
    RERANKER_TOKENIZER_MAX_LENGTH: int = 256

    # Sentiment Analysis Model
    SENTIMENT_ANALYSIS_HF_BASE_MODEL: str = "tanaos/tanaos-sentiment-analysis-v1"

    # Emotion Detection Model
    EMOTION_DETECTION_HF_BASE_MODEL: str = "tanaos/tanaos-emotion-detection-v1"

    # Text Anonymization Model
    TEXT_ANONYMIZATION_HF_BASE_MODEL: str = "tanaos/tanaos-text-anonymizer-v1"
    DEFAULT_TEXT_ANONYM_MASK: str = "[MASKED]"

    # Named Entity Recognition Model
    NER_HF_BASE_MODEL: str = "tanaos/tanaos-NER-v1"
    NER_TOKENIZER_MAX_LENGTH: int = 256
    NER_TAGNAME_MAX_LENGTH: int = 20

    # Read settings from a .env file; unknown env vars are allowed ("extra").
    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="",
        extra="allow",
    )


# Singleton settings instance shared across the package.
config = Config()
76 |
--------------------------------------------------------------------------------
/tests/unit/classification_model/test_cm__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from transformers import PreTrainedModel, PreTrainedTokenizerBase
4 | from artifex.models import ClassificationModel
5 | from datasets import ClassLabel
6 |
7 |
def test_classification_model_init(mocker: MockerFixture):
    """
    Unit test for ClassificationModel.__init__: verifies the constructor wires up
    the base model, tokenizer, instructions, schema and labels correctly.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    synthex_stub = mocker.Mock()

    # Stub out the config so the test does not depend on real settings
    config_stub = mocker.patch("artifex.models.classification.classification_model.config")
    config_stub.CLASSIFICATION_HF_BASE_MODEL = "mocked-base-model"

    # Replace Hugging Face model/tokenizer loading with lightweight mocks
    fake_model = mocker.Mock(spec=PreTrainedModel)
    fake_model.config = mocker.Mock(id2label={0: "label"})
    mocker.patch(
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        return_value=fake_model
    )
    fake_tokenizer = mocker.Mock(spec=PreTrainedTokenizerBase)
    mocker.patch(
        "artifex.models.classification.classification_model.AutoTokenizer.from_pretrained",
        return_value=fake_tokenizer
    )
    # Neutralize the parent constructor
    base_init = mocker.patch("artifex.models.base_model.BaseModel.__init__", return_value=None)

    model = ClassificationModel(synthex_stub)

    # Parent constructor received the Synthex instance
    base_init.assert_called_once_with(synthex_stub)
    # System data-generation instructions: a list of strings
    instructions = model._system_data_gen_instr_val
    assert isinstance(instructions, list)
    assert all(isinstance(entry, str) for entry in instructions)
    # Token keys: a single-element list of strings
    token_keys = model._token_keys_val
    assert isinstance(token_keys, list)
    assert isinstance(token_keys[0], str)
    assert len(token_keys) == 1
    # Synthetic data schema: a dict exposing 'text' and 'labels'
    schema = model._synthetic_data_schema_val
    assert isinstance(schema, dict)
    assert "text" in schema
    assert "labels" in schema
    # Base model name comes straight from config
    assert model._base_model_name_val == "mocked-base-model"
    # Model and tokenizer mocks are installed on the instance
    assert isinstance(model._model_val, PreTrainedModel)
    assert isinstance(model._tokenizer_val, PreTrainedTokenizerBase)
    # Labels container was built
    assert isinstance(model._labels_val, ClassLabel)
--------------------------------------------------------------------------------
/artifex/core/decorators.py:
--------------------------------------------------------------------------------
1 | from pydantic import validate_call, ValidationError
2 | from typing import Any, Callable, TypeVar
3 | from functools import wraps
4 | import inspect
5 |
6 |
7 | T = TypeVar("T", bound=type)
8 |
9 |
def should_skip_method(attr: Any, attr_name: str) -> bool:
    """
    Determines whether a class attribute should be skipped by `auto_validate_methods`,
    based on its name and signature.
    This function skips:
    - Dunder (double underscore) methods, except for '__call__'.
    - Attributes that are not callable.
    - Callables whose signature cannot be inspected (e.g. some C-level callables).
    - Callables with no parameters, or whose only parameter is 'self'.
    Args:
        attr (Any): The attribute (typically a function) to check.
        attr_name (str): The name of the attribute to check.
    Returns:
        bool: True if the attribute should be skipped, False otherwise.
    """

    # Skip dunder methods, except the __call__ method
    if attr_name.startswith("__") and attr_name != "__call__":
        return True

    if not callable(attr):
        return True

    # Some builtin/C callables expose no signature; there is nothing to validate.
    try:
        sig = inspect.signature(attr)
    except (ValueError, TypeError):
        return True

    params = list(sig.parameters.values())
    # Bug fix: the original indexed params[0] unconditionally and raised
    # IndexError for zero-parameter callables. Skip those too — with no
    # parameters there is nothing for validate_call to check.
    if not params:
        return True
    # Skip methods that only have 'self' as parameter
    if len(params) == 1 and params[0].name == "self":
        return True

    return False
38 |
def auto_validate_methods(cls: T) -> T:
    """
    A class decorator that combines Pydantic's `validate_call` for input validation
    and automatic handling of validation errors, raising a custom `ArtifexValidationError`.

    Every attribute not rejected by `should_skip_method` (public methods, `__call__`,
    static methods and class methods) is wrapped in place on the class.
    Args:
        cls (T): The class whose methods should be wrapped.
    Returns:
        T: The same class object, with its methods replaced by validating wrappers.
    """

    # Imported lazily to avoid a circular import between artifex.core and this module.
    from artifex.core import ValidationError as ArtifexValidationError

    # NOTE(review): dir(cls) also yields inherited and metaclass attributes;
    # most are dunders and get skipped, but wrapped inherited methods are
    # re-set on `cls` itself — confirm this shadowing is intended.
    for attr_name in dir(cls):
        # Use getattr_static to avoid triggering descriptors
        raw_attr = inspect.getattr_static(cls, attr_name)
        attr = getattr(cls, attr_name)

        is_static = isinstance(raw_attr, staticmethod)
        is_class = isinstance(raw_attr, classmethod)

        # Unwrap only if it's a staticmethod/classmethod object, so validate_call
        # sees the underlying plain function rather than the descriptor.
        if is_static or is_class:
            func = raw_attr.__func__
        else:
            func = attr

        if should_skip_method(func, attr_name):
            continue

        validated = validate_call(config={"arbitrary_types_allowed": True})(func)

        # `__f` is bound as a default argument so each wrapper captures its own
        # validated function (avoids the late-binding closure pitfall in loops).
        @wraps(func)
        def wrapper(*args: Any, __f: Callable[..., Any] = validated, **kwargs: Any) -> Any:
            try:
                return __f(*args, **kwargs)
            except ValidationError as e:
                # Translate pydantic's error into the package's own exception type.
                raise ArtifexValidationError(f"Invalid input: {e}")

        # Re-wrap as staticmethod/classmethod if needed
        if is_static:
            setattr(cls, attr_name, staticmethod(wrapper))
        elif is_class:
            setattr(cls, attr_name, classmethod(wrapper))
        else:
            setattr(cls, attr_name, wrapper)

    return cls
--------------------------------------------------------------------------------
/artifex/core/_hf_patches.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer, TrainerState, TrainingArguments, TrainerCallback, TrainerControl
2 | from transformers.trainer_utils import TrainOutput
3 | from typing import Any, Dict
4 | from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TaskID
5 | from rich.console import Console
6 |
7 | """
8 | Patches for HuggingFace classes to improve user experience.
9 | """
10 |
11 | console = Console()
12 |
class SilentTrainer(Trainer):
    """
    A transformers.Trainer variant that suppresses the final training-summary
    dictionary normally printed to the console (the library currently offers
    no built-in switch to disable it).
    """

    def train(self, *args: Any, **kwargs: Any) -> TrainOutput:
        import builtins

        original_print = builtins.print

        def filtered_print(*p_args: Any, **p_kwargs: Any) -> None:
            # Swallow only the single-argument dict containing 'train_runtime'
            # (the summary); forward everything else untouched.
            is_summary = (
                len(p_args) == 1
                and isinstance(p_args[0], dict)
                and "train_runtime" in p_args[0]
            )
            if is_summary:
                return None
            return original_print(*p_args, **p_kwargs)

        # Temporarily swap the builtin, always restoring it afterwards.
        builtins.print = filtered_print
        try:
            return super().train(*args, **kwargs)
        finally:
            builtins.print = original_print
38 |
39 |
class RichProgressCallback(TrainerCallback):
    """
    A custom TrainerCallback that uses Rich to display a progress bar during training.

    The bar is created when training begins, advanced once per training step,
    and stopped (and, being transient, removed from the console) at the end.
    """

    # Both attributes are created in `on_train_begin`; the other hooks assume
    # the Trainer has already invoked it (earlier access raises AttributeError).
    progress: Progress
    task: TaskID

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the beginning of training. Builds and starts the progress bar
        sized to the total number of training steps.
        """

        # transient=True makes Rich erase the bar once `stop()` is called.
        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeElapsedColumn(),
            transient=True
        )
        self.task = self.progress.add_task("Training model...", total=state.max_steps)
        self.progress.start()

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the end of each training step. Advances the bar to the
        Trainer's current global step.
        """

        # NOTE(review): `self.task` already comes from add_task, so the
        # TaskID(...) cast here is redundant (TaskID is a NewType over int).
        self.progress.update(TaskID(self.task), completed=state.global_step)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the end of training. Stops the bar and prints a completion mark.
        """

        self.progress.stop()
        console.print("[green]✔ Training model[/green]")
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Artifex
2 |
3 | Thanks for your interest in contributing to **Artifex**!
4 | This project is in its early stages, and **we welcome ideas, issues, and pull requests** of all kinds.
5 |
6 | ## Before You Contribute
7 |
8 | All contributions must be made under the terms of [this repository's Contributor License Agreement (CLA)](CLA.md). This ensures that we can safely include your work in the project while keeping it open source under the MIT license. All contributors must sign the CLA. The process is fully automated and handled by [CLA Assistant](https://cla-assistant.io/):
9 |
10 | - When you open your first Pull Request, CLA Assistant will check if you have signed the CLA.
11 | - You will be prompted to click **“I have read the CLA Document and I hereby sign the CLA”**. This is sufficient — no manual signature is required.
12 |
13 | Please ensure that:
14 | - Your contributions are your original work, or that you have the right to submit them.
15 | - You have permission from your employer or any third-party if your contributions are covered by their rights.
16 |
17 | ## What To Contribute On
18 |
19 | We welcome contributions of any kind, **both in the form of [new issues](https://github.com/tanaos/artifex/issues)** and **new code**. Typical contributions are done in one of two ways:
20 |
21 | 1. When using the library you **come across a problem or come up with a possible improvement**. First of all, you should check the [Issues Tab](https://github.com/tanaos/artifex/issues) to see if an issue that addresses the same shortcoming is already present. If that is the case, you can jump to option number 2, otherwise you can either:
    - [Open a new issue](https://github.com/tanaos/artifex/issues/new) for the shortcoming you have identified (opening issues, without necessarily working on them, is a productive way of contributing to open source code!)
23 | - [Open a new issue](https://github.com/tanaos/artifex/issues/new) **and** start working on it. You can indicate that you will take care of the newly opened issue by either assigning it to yourself or stating it in the comments.
24 |
25 | 2. You visit the [Issues Tab](https://github.com/tanaos/artifex/issues) and **look for a known issue you are interested in working on**. Simple issues, perfect for developers who have never / very seldom contributed to open source code, are marked with a badge. Once you have found an issue that you like, indicate that you will be working on it by either assigning it to yourself or stating so in the comments.
26 |
27 | ## How To Contribute New Code
28 |
29 | Direct pushes to the `master` branch are not permitted. In order to contribute new code, please **follow the standard fork --> push --> pull request workflow**:
30 |
31 | 1. Fork the repository
32 | 2. Create a new branch (`git checkout -b feature/my-feature`)
33 | 3. Make your changes
34 | 4. Commit your changes (`git commit -m "Add feature"`)
35 | 5. Push to your fork (`git push origin feature/my-feature`)
36 | 6. Open a pull request
37 |
38 | ## Guidelines
39 |
40 | - Keep your code clean and consistent with existing style.
41 | - Add or update docstrings and comments where helpful.
42 | - If applicable, write or update tests.
43 | - Be constructive in discussions.
44 |
45 | ## Questions?
46 |
47 | Feel free to open an issue or start a discussion if you're not sure where to begin. We're happy to help!
48 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.8.1
2 | aiohappyeyeballs==2.6.1
3 | aiohttp==3.12.13
4 | aiosignal==1.4.0
5 | annotated-types==0.7.0
6 | anyio==4.9.0
7 | argon2-cffi==25.1.0
8 | argon2-cffi-bindings==21.2.0
9 | arrow==1.3.0
10 | asttokens==3.0.0
11 | async-lru==2.0.5
12 | attrs==25.3.0
13 | babel==2.17.0
14 | beautifulsoup4==4.13.4
15 | bleach==6.2.0
16 | certifi==2025.6.15
17 | cffi==1.17.1
18 | charset-normalizer==3.4.2
19 | comm==0.2.2
20 | datasets==3.6.0
21 | debugpy==1.8.14
22 | decorator==5.2.1
23 | defusedxml==0.7.1
24 | dill==0.3.8
25 | executing==2.2.0
26 | fastjsonschema==2.21.1
27 | filelock==3.18.0
28 | fqdn==1.5.1
29 | frozenlist==1.7.0
30 | fsspec==2025.3.0
31 | h11==0.16.0
32 | hf-xet==1.1.5
33 | httpcore==1.0.9
34 | httpx==0.28.1
35 | huggingface-hub==0.33.2
36 | idna==3.10
37 | iniconfig==2.1.0
38 | ipykernel==6.29.5
39 | ipython==9.4.0
40 | ipython-pygments-lexers==1.1.1
41 | isoduration==20.11.0
42 | jedi==0.19.2
43 | jinja2==3.1.6
44 | json5==0.12.0
45 | jsonpointer==3.0.0
46 | jsonschema==4.24.0
47 | jsonschema-specifications==2025.4.1
48 | jupyter-client==8.6.3
49 | jupyter-core==5.8.1
50 | jupyter-events==0.12.0
51 | jupyter-lsp==2.2.5
52 | jupyter-server==2.16.0
53 | jupyter-server-terminals==0.5.3
54 | jupyterlab==4.4.4
55 | jupyterlab-pygments==0.3.0
56 | jupyterlab-server==2.27.3
57 | markupsafe==3.0.2
58 | matplotlib-inline==0.1.7
59 | mistune==3.1.3
60 | mpmath==1.3.0
61 | multidict==6.6.3
62 | multiprocess==0.70.16
63 | nbclient==0.10.2
64 | nbconvert==7.16.6
65 | nbformat==5.10.4
66 | nest-asyncio==1.6.0
67 | networkx==3.5
68 | notebook==7.4.4
69 | notebook-shim==0.2.4
70 | numpy==2.3.1
71 | nvidia-cublas-cu12==12.6.4.1
72 | nvidia-cuda-cupti-cu12==12.6.80
73 | nvidia-cuda-nvrtc-cu12==12.6.77
74 | nvidia-cuda-runtime-cu12==12.6.77
75 | nvidia-cudnn-cu12==9.5.1.17
76 | nvidia-cufft-cu12==11.3.0.4
77 | nvidia-cufile-cu12==1.11.1.6
78 | nvidia-curand-cu12==10.3.7.77
79 | nvidia-cusolver-cu12==11.7.1.2
80 | nvidia-cusparse-cu12==12.5.4.2
81 | nvidia-cusparselt-cu12==0.6.3
82 | nvidia-nccl-cu12==2.26.2
83 | nvidia-nvjitlink-cu12==12.6.85
84 | nvidia-nvtx-cu12==12.6.77
85 | overrides==7.7.0
86 | packaging==25.0
87 | pandas==2.3.0
88 | pandocfilters==1.5.1
89 | parso==0.8.4
90 | pexpect==4.9.0
91 | platformdirs==4.3.8
92 | pluggy==1.6.0
93 | prometheus-client==0.22.1
94 | prompt-toolkit==3.0.51
95 | propcache==0.3.2
96 | psutil==7.0.0
97 | ptyprocess==0.7.0
98 | pure-eval==0.2.3
99 | pyarrow==20.0.0
100 | pycparser==2.22
101 | pydantic==2.11.7
102 | pydantic-core==2.33.2
103 | pydantic-settings==2.10.1
104 | pygments==2.19.2
105 | pytest==8.4.1
106 | pytest-mock==3.14.1
107 | python-dateutil==2.9.0.post0
108 | python-dotenv==1.1.1
109 | python-json-logger==3.3.0
110 | pytz==2025.2
111 | pyyaml==6.0.2
112 | pyzmq==27.0.0
113 | referencing==0.36.2
114 | regex==2024.11.6
115 | requests==2.32.4
116 | responses==0.25.7
117 | rfc3339-validator==0.1.4
118 | rfc3986-validator==0.1.1
119 | rpds-py==0.26.0
120 | safetensors==0.5.3
121 | send2trash==1.8.3
122 | setuptools==80.9.0
123 | six==1.17.0
124 | sniffio==1.3.1
125 | soupsieve==2.7
126 | stack-data==0.6.3
127 | sympy==1.14.0
128 | synthex==0.3.1
129 | terminado==0.18.1
130 | tinycss2==1.4.0
131 | tokenizers==0.21.2
132 | torch==2.7.1
133 | tornado==6.5.1
134 | tqdm==4.67.1
135 | traitlets==5.14.3
136 | transformers==4.53.1
137 | triton==3.3.1
138 | types-python-dateutil==2.9.0.20250516
139 | typing-extensions==4.14.0
140 | typing-inspection==0.4.1
141 | tzdata==2025.2
142 | tzlocal==5.3.1
143 | uri-template==1.3.0
144 | urllib3==2.5.0
145 | wcwidth==0.2.13
146 | webcolors==24.11.1
147 | webencodings==0.5.1
148 | websocket-client==1.8.0
149 | xxhash==3.5.0
150 | yarl==1.20.1
151 |
--------------------------------------------------------------------------------
/artifex/models/classification/multi_class_classification/sentiment_analysis/__init__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | from transformers.trainer_utils import TrainOutput
3 | from typing import Optional
4 |
5 | from ...classification_model import ClassificationModel
6 |
7 | from artifex.core import auto_validate_methods
8 | from artifex.config import config
9 |
10 |
@auto_validate_methods
class SentimentAnalysis(ClassificationModel):
    """
    A Sentiment Analysis Model is used to classify the sentiment of a given text into predefined
    categories, typically `positive`, `negative`, or `neutral`. In this implementation, we
    support two extra sentiment categories: `very_positive` and `very_negative`.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic
                data used to train the model.
        """
        super().__init__(synthex, base_model_name=config.SENTIMENT_ANALYSIS_HF_BASE_MODEL)
        # Templates for the synthetic-data generator; '{domain}' is filled in later.
        self._system_data_gen_instr_val: list[str] = [
            "The 'text' field should contain text that belongs to the following domain(s): {domain}.",
            "The 'text' field should contain text that may or may not express a certain sentiment.",
            "The 'labels' field should contain a label indicating the sentiment of the 'text'.",
            "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.",
            "This is a list of the allowed 'labels' and their meaning: "
        ]

    def train(
        self, domain: str, classes: Optional[dict[str, str]] = None,
        output_path: Optional[str] = None, num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs: int = 3
    ) -> TrainOutput:
        # Bug fix: this docstring was previously an f-string literal, which Python
        # does not treat as a docstring — the method's __doc__ was None.
        """
        Train the Sentiment Analysis model using synthetic data generated by Synthex.

        NOTE: this method overrides `ClassificationModel.train()` to make the `classes`
        parameter optional; when omitted, a default five-level sentiment scale is used.

        Args:
            domain (str): A description of the domain or context for which the model is being trained.
            classes (Optional[dict[str, str]]): A dictionary mapping class names to their descriptions.
                The keys (class names) must be strings with no spaces and a maximum length of
                `config.CLASSIFICATION_CLASS_NAME_MAX_LENGTH` characters. Defaults to a
                five-level scale from `very_negative` to `very_positive`.
            output_path (Optional[str]): The path where the generated synthetic data will be saved.
            num_samples (int): The number of training data samples to generate.
            num_epochs (int): The number of epochs for training the model.
        Returns:
            TrainOutput: The result of the training run.
        """

        if classes is None:
            classes = {
                "very_negative": "Text that expresses a very negative sentiment or strong dissatisfaction.",
                "negative": "Text that expresses a negative sentiment or dissatisfaction.",
                "neutral": "Either a text that does not express any sentiment at all, or a text that expresses a neutral sentiment or lack of strong feelings.",
                "positive": "Text that expresses a positive sentiment or satisfaction.",
                "very_positive": "Text that expresses a very positive sentiment or strong satisfaction."
            }

        return super().train(
            domain=domain, classes=classes, output_path=output_path,
            num_samples=num_samples, num_epochs=num_epochs
        )
69 |
--------------------------------------------------------------------------------
/artifex/models/classification/binary_classification/guardrail/__init__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | from synthex.models import JobOutputSchemaDefinition
3 | from typing import Optional
4 | from transformers.trainer_utils import TrainOutput
5 |
6 | from ...classification_model import ClassificationModel
7 |
8 | from artifex.core import auto_validate_methods
9 | from artifex.config import config
10 |
11 |
@auto_validate_methods
class Guardrail(ClassificationModel):
    """
    A Guardrail Model for LLMs. A Guardrail is a model that can be used to classify the output of a LLM
    as safe or unsafe, depending on a user's definition of what is safe or unsafe.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic data used to train
                the model.
        """

        super().__init__(synthex, base_model_name=config.GUARDRAIL_HF_BASE_MODEL)
        # Templates for the synthetic-data generator; '{unsafe_content}' is filled
        # in by _get_data_gen_instr from the user's descriptions.
        self._system_data_gen_instr_val: list[str] = [
            "the 'text' field should contain two types of text: either user queries or LLM-generated responses to those queries.",
            "the 'labels' field should contain a label indicating whether the 'text' is safe or unsafe.",
            "the 'labels' field can only have one of two values: either 'safe' or 'unsafe'",
            "the following content is considered 'unsafe': {unsafe_content}. Everything else is considered 'safe'.",
            "the dataset should contain an approximately equal number of safe and unsafe 'text'",
            "the dataset should also contain arbitrary 'text', even if not explicitly mentioned in these instructions, but its 'labels' must reflect the actual safety of that text",
        ]

    def _get_data_gen_instr(self, user_instr: list[str]) -> list[str]:
        """
        Overrides `ClassificationModel._get_data_gen_instr` to account for the different structure of
        `Guardrail.train`.
        Args:
            user_instr (list[str]): Descriptions of content that must be classified as unsafe.
                (Docstring fix: the previous version described a trailing "domain string"
                element that this implementation never used.)
        Returns:
            list[str]: The system instructions with the '{unsafe_content}' placeholder
                replaced by the joined user descriptions.
        """

        # Join all user descriptions into one clause and substitute it into every
        # instruction template (templates without the placeholder pass through unchanged).
        unsafe_content = "; ".join(user_instr)
        out = [instr.format(unsafe_content=unsafe_content) for instr in self._system_data_gen_instr_val]
        return out

    def train(
        self, unsafe_content: list[str], output_path: Optional[str] = None,
        num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM, num_epochs: int = 3
    ) -> TrainOutput:
        # Bug fix: this docstring was previously an f-string literal, which Python
        # does not treat as a docstring — the method's __doc__ was None.
        """
        Overrides `ClassificationModel.train` to remove the `domain` and `classes` arguments and
        add the `unsafe_content` argument.
        Args:
            unsafe_content (list[str]): A list of strings describing content that should be
                classified as unsafe by the Guardrail model.
            output_path (Optional[str]): The path where the synthetic training data and the
                output model will be saved.
            num_samples (int): The number of training data samples to generate.
            num_epochs (int): The number of epochs for training the model.
        Returns:
            TrainOutput: The result of the training run.
        """

        output: TrainOutput = self._train_pipeline(
            user_instructions=unsafe_content, output_path=output_path, num_samples=num_samples,
            num_epochs=num_epochs
        )

        return output
--------------------------------------------------------------------------------
/tests/unit/core/decorators/test_should_skip_method.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from artifex.core.decorators import should_skip_method
3 | from typing import Any
4 | from pytest_mock import MockerFixture
5 | import inspect
6 |
7 |
# Standalone dummy callables mirroring the method kinds on `Dummy` below.
# NOTE(review): they are not referenced by the parametrized test in this file —
# possibly dead fixtures; confirm before removing.
def dummy_func(self, x): pass
def dummy_static(x): pass
def dummy_class(cls, x): pass
11 |
class Dummy:
    """Fixture class exposing one attribute of each kind checked by `should_skip_method`."""
    def method(self, x): pass
    def __call__(self, x): pass
    def __str__(self): pass
    @staticmethod
    def static_method(x): pass
    @classmethod
    def class_method(cls, x): pass
    def only_self(self): pass
21 |
# Parameter names the mocked signature should report for each attribute under test.
# Replaces the previous six copy-pasted elif branches that all built the same
# mock structure.
_SIGNATURE_PARAMS: dict[str, list[str]] = {
    "__str__": ["self"],
    "__call__": ["self", "x"],
    "method": ["self", "x"],
    "static_method": ["x"],
    "class_method": ["cls", "x"],
    "only_self": ["self"],
}

@pytest.mark.parametrize(
    "attr_name,attr,expected",
    [
        ("__str__", Dummy.__str__, True), # dunder, skip
        ("__call__", Dummy.__call__, False), # __call__, don't skip
        ("method", Dummy.method, False), # normal method, don't skip
        ("static_method", Dummy.static_method, False), # staticmethod, don't skip
        ("class_method", Dummy.class_method, False), # classmethod, don't skip
        ("only_self", Dummy.only_self, True), # only self param, skip
    ]
)
def test_should_skip_method(
    mocker: MockerFixture,
    attr_name: str,
    attr: Any,
    expected: bool
):
    """
    Unit test for should_skip_method. Mocks inspect.signature and callable checks.
    Args:
        mocker (MockerFixture): pytest-mock fixture for mocking.
        attr_name (str): Attribute name to test.
        attr (Any): Attribute object to test.
        expected (bool): Expected result from should_skip_method.
    """

    # Mock callable so every attribute passes the callability check
    mocker.patch("builtins.callable", return_value=True)

    # Mock inspect.signature: build a fake signature whose parameters carry the
    # names the real attribute would expose.
    mock_params = []
    for param_name in _SIGNATURE_PARAMS[attr_name]:
        mock_param = mocker.Mock()
        # `name` must be set after construction: Mock(name=...) names the mock itself.
        mock_param.name = param_name
        mock_params.append(mock_param)
    mock_sig = mocker.Mock()
    mock_sig.parameters.values.return_value = mock_params
    mocker.patch("inspect.signature", return_value=mock_sig)

    result = should_skip_method(attr, attr_name)
    assert result is expected
--------------------------------------------------------------------------------
/tests/unit/core/decorators/test_auto_validate_methods.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from typing import Any
3 | from pytest_mock import MockerFixture
4 |
5 | from artifex.core.decorators import auto_validate_methods
6 |
7 |
class DummyValidationError(Exception):
    """Stand-in for artifex's ValidationError, used to isolate the decorator under test."""
    pass
10 |
def test_auto_validate_methods_valid_call(mocker: MockerFixture):
    """
    Test that auto_validate_methods correctly validates method input and returns output.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Make validate_call a no-op decorator factory
    validate_call_mock = mocker.patch(
        "artifex.core.decorators.validate_call",
        side_effect=lambda *a, **kw: lambda f: f
    )
    # Swap in the dummy error type
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def foo(self, x: int) -> int:
            return x + 1

    instance = auto_validate_methods(TestClass)()
    assert instance.foo(1) == 2
    validate_call_mock.assert_called()
31 |
def test_auto_validate_methods_raises_on_validation_error(mocker: MockerFixture):
    """
    Test that auto_validate_methods raises ArtifexValidationError on validation error.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # validate_call replacement whose produced wrapper always fails validation
    def failing_decorator(f):
        def inner(*args, **kwargs):
            raise DummyValidationError("fail")
        return inner

    mocker.patch(
        "artifex.core.decorators.validate_call",
        side_effect=lambda *a, **kw: failing_decorator
    )
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def foo(self, x: int) -> int:
            return x + 1

    instance = auto_validate_methods(TestClass)()
    with pytest.raises(DummyValidationError):
        instance.foo("bad_input")
56 |
def test_auto_validate_methods_skips_dunder_methods(mocker: MockerFixture):
    """
    Test that auto_validate_methods skips dunder methods except __call__.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    validate_call_mock = mocker.patch(
        "artifex.core.decorators.validate_call",
        side_effect=lambda *a, **kw: lambda f: f
    )
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def __str__(self) -> str:
            return "test"
        def foo(self, x: int) -> int:
            return x + 1

    instance = auto_validate_methods(TestClass)()
    assert instance.foo(1) == 2
    assert instance.__str__() == "test"
    # Only foo should be validated
    validate_call_mock.assert_any_call(config={"arbitrary_types_allowed": True})
79 |
def test_auto_validate_methods_static_and_class_methods(mocker: MockerFixture):
    """
    Test that auto_validate_methods works for static and class methods.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    validate_call_mock = mocker.patch(
        "artifex.core.decorators.validate_call",
        side_effect=lambda *a, **kw: lambda f: f
    )
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        @staticmethod
        def static(x: int) -> int:
            return x + 2

        @classmethod
        def clsmethod(cls, x: int) -> int:
            return x + 3

    wrapped = auto_validate_methods(TestClass)
    assert wrapped.static(1) == 3
    assert wrapped.clsmethod(1) == 4
    validate_call_mock.assert_any_call(config={"arbitrary_types_allowed": True})
--------------------------------------------------------------------------------
/artifex/models/named_entity_recognition/text_anonymization/__init__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | from typing import Union, Optional
3 | from transformers.trainer_utils import TrainOutput
4 |
5 | from ..named_entity_recognition import NamedEntityRecognition
6 |
7 | from artifex.core import auto_validate_methods
8 | from artifex.config import config
9 |
10 |
@auto_validate_methods
class TextAnonymization(NamedEntityRecognition):
    """
    A Text Anonymization model is a model that removes Personal Identifiable Information (PII) from text.
    This class extends the NamedEntityRecognition model to specifically target and anonymize PII in text data.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic data used
                to train the model.
        """

        super().__init__(synthex)
        # PII entity names mapped to the descriptions used for synthetic data generation.
        self._pii_entities = {
            "PERSON": "Individual people, fictional characters",
            "LOCATION": "Geographical areas",
            "DATE": "Absolute or relative dates, including years, months and/or days",
            "ADDRESS": "full addresses",
            "PHONE_NUMBER": "telephone numbers",
        }
        self._maskable_entities = list(self._pii_entities.keys())

    def __call__(
        self, text: Union[str, list[str]], entities_to_mask: Optional[list[str]] = None,
        mask_token: str = config.DEFAULT_TEXT_ANONYM_MASK
    ) -> list[str]:
        """
        Anonymizes the input text by masking PII entities.
        Args:
            text (Union[str, list[str]]): The input text or list of texts to be anonymized.
            entities_to_mask (Optional[list[str]]): The subset of PII entities to mask. When None,
                all maskable entities are masked.
            mask_token (str): The token that replaces each masked entity span.
        Returns:
            list[str]: A list of anonymized texts.
        Raises:
            ValueError: If any requested entity is not among the maskable entities.
        """

        if entities_to_mask is None:
            entities_to_mask = self._maskable_entities
        else:
            unknown = next(
                (e for e in entities_to_mask if e not in self._maskable_entities), None
            )
            if unknown is not None:
                raise ValueError(f"Entity '{unknown}' cannot be masked. Allowed entities are: {self._maskable_entities}")

        texts = [text] if isinstance(text, str) else text

        predictions = super().__call__(texts)

        anonymized: list[str] = []
        for i, original_text in enumerate(texts):
            masked = original_text
            # Mask entities in reverse order to avoid invalidating the start/end indices
            for entity in reversed(predictions[i]):
                if entity.entity_group in entities_to_mask:
                    masked = masked[:entity.start] + mask_token + masked[entity.end:]
            anonymized.append(masked)

        return anonymized

    def train(
        self, domain: str, output_path: Optional[str] = None,
        num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM, num_epochs: int = 3
    ) -> TrainOutput:
        """
        Trains the Text Anonymization model. This method is identical to the
        NamedEntityRecognition.train method, except that named_entities are set to a predefined
        list of PII entities.
        Args:
            domain (str): The domain for which to train the model.
            output_path (Optional[str]): The path where to save the trained model. If None, a default path is used.
            num_samples (int): The number of synthetic samples to generate for training.
            num_epochs (int): The number of epochs to train the model.
        Returns:
            TrainOutput: The output of the training process.
        """

        # Delegate to the parent trainer with the PII entity set fixed.
        return super().train(
            named_entities=self._pii_entities, domain=domain, output_path=output_path,
            num_samples=num_samples, num_epochs=num_epochs,
            train_datapoint_examples=None
        )
--------------------------------------------------------------------------------
/tests/unit/classification_model/test_cm_load_model.py:
--------------------------------------------------------------------------------
from unittest.mock import Mock

import pytest
from pytest_mock import MockerFixture

from artifex.models import ClassificationModel
4 |
5 |
def make_mock_model(mocker: "MockerFixture", id2label=None) -> Mock:
    """
    Build a mock Hugging Face model exposing ``config.id2label``.

    Args:
        mocker (MockerFixture): The pytest-mock fixture used to create the mocks.
        id2label: Mapping of label ids to label names placed on the mock config,
            or None to simulate a model config without an id2label mapping.

    Returns:
        Mock: A mock model whose ``config.id2label`` equals ``id2label``.
    """

    mock_model = mocker.Mock()
    mock_config = mocker.Mock()
    mock_config.id2label = id2label
    mock_model.config = mock_config
    return mock_model
12 |
13 |
@pytest.fixture
def mock_classification_model(mocker: MockerFixture) -> ClassificationModel:
    """
    Fixture to create a ClassificationModel with all dependencies mocked.
    """

    clf_module = "artifex.models.classification.classification_model"
    synthex_stub = mocker.Mock()
    # Patch Hugging Face model/tokenizer loading for __init__
    model_stub = make_mock_model(mocker, {0: "labelA", 1: "labelB"})
    mocker.patch(
        f"{clf_module}.AutoModelForSequenceClassification.from_pretrained",
        return_value=model_stub
    )
    mocker.patch(
        f"{clf_module}.AutoTokenizer.from_pretrained",
        return_value=mocker.Mock()
    )
    mocker.patch(
        f"{clf_module}.ClassLabel",
        return_value=mocker.Mock(names=["labelA", "labelB"])
    )
    mocker.patch("artifex.models.base_model.BaseModel.__init__", return_value=None)
    return ClassificationModel(synthex_stub)
37 |
38 |
def test_load_model_sets_model_and_labels(
    mocker: MockerFixture, mock_classification_model: ClassificationModel
):
    """
    Test that _load_model sets the model and labels correctly.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        mock_classification_model (ClassificationModel): The mocked ClassificationModel instance.
    """

    clf_module = "artifex.models.classification.classification_model"
    # Prepare a mock model carrying an id2label mapping
    label_map = {0: "foo", 1: "bar"}
    loaded_model = make_mock_model(mocker, label_map)
    classlabel_stub = mocker.Mock(names=["foo", "bar"])
    mocker.patch(
        f"{clf_module}.AutoModelForSequenceClassification.from_pretrained",
        return_value=loaded_model
    )
    classlabel_patch = mocker.patch(
        f"{clf_module}.ClassLabel",
        return_value=classlabel_stub
    )

    mock_classification_model._load_model("dummy_path")

    # Model should be set
    assert mock_classification_model._model is loaded_model
    # ClassLabel should be called with correct names
    classlabel_patch.assert_called_once_with(names=["foo", "bar"])
    # _labels should be set to the stubbed ClassLabel
    assert mock_classification_model._labels is classlabel_stub
70 |
71 |
def test_load_model_raises_if_id2label_missing(
    mocker: MockerFixture, mock_classification_model: ClassificationModel
):
    """
    Test that _load_model raises AssertionError if id2label is missing.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        mock_classification_model (ClassificationModel): The mocked ClassificationModel instance.
    """

    # A model whose config carries no id2label mapping must be rejected.
    model_without_labels = make_mock_model(mocker, id2label=None)
    mocker.patch(
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        return_value=model_without_labels
    )
    with pytest.raises(AssertionError, match="Model config must have id2label mapping."):
        mock_classification_model._load_model("dummy_path")
90 |
91 |
def test_load_model_passes_path(
    mocker: MockerFixture, mock_classification_model: ClassificationModel
):
    """
    Test that _load_model passes the correct path to from_pretrained.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        mock_classification_model (ClassificationModel): The mocked ClassificationModel instance.
    """

    clf_module = "artifex.models.classification.classification_model"
    loaded_model = make_mock_model(mocker, {0: "a", 1: "b"})
    from_pretrained_patch = mocker.patch(
        f"{clf_module}.AutoModelForSequenceClassification.from_pretrained",
        return_value=loaded_model
    )
    mocker.patch(
        f"{clf_module}.ClassLabel",
        return_value=mocker.Mock(names=["a", "b"])
    )

    model_path = "some/model/path"
    mock_classification_model._load_model(model_path)

    # The exact path string must be forwarded untouched.
    from_pretrained_patch.assert_called_once_with(model_path)
--------------------------------------------------------------------------------
/tests/unit/classification_model/test_cm_get_data_gen_instr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from typing import List
3 | from artifex.models.classification import ClassificationModel
4 | from synthex import Synthex
5 |
6 |
class DummyClassificationModel(ClassificationModel):
    """
    Dummy concrete implementation for testing ClassificationModel.
    """

    @property
    def _base_model_name(self) -> str:
        # Fixed placeholder; no real Hugging Face model is loaded in these tests.
        return "dummy-model"

    @property
    def _system_data_gen_instr(self) -> List[str]:
        # Two templates with a {domain} placeholder, mirroring the real models' instructions.
        instructions = [
            "System instruction 1 for {domain}",
            "System instruction 2 for {domain}"
        ]
        return instructions
22 |
23 |
@pytest.fixture
def model(mocker) -> DummyClassificationModel:
    """
    Fixture that returns a DummyClassificationModel instance with mocked Synthex.
    Args:
        mocker: The pytest-mock fixture for mocking.
    Returns:
        DummyClassificationModel: An instance of the dummy model for testing.
    """

    clf_module = "artifex.models.classification.classification_model"
    synthex_stub = mocker.Mock(spec=Synthex)
    # Patch Hugging Face model/tokenizer loading
    hf_model_stub = mocker.Mock()
    hf_model_stub.config = mocker.Mock(id2label={0: "labelA"})
    mocker.patch(
        f"{clf_module}.AutoModelForSequenceClassification.from_pretrained",
        return_value=hf_model_stub
    )
    mocker.patch(
        f"{clf_module}.AutoTokenizer.from_pretrained",
        return_value=mocker.Mock()
    )
    mocker.patch(
        f"{clf_module}.ClassLabel",
        return_value=mocker.Mock(names=["labelA"])
    )
    mocker.patch("artifex.models.base_model.BaseModel.__init__", return_value=None)
    return DummyClassificationModel(synthex=synthex_stub)
52 |
53 |
def test_get_data_gen_instr_basic(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr correctly formats system instructions and combines them
    with user instructions, excluding the domain from user instructions.
    """

    result = model._get_data_gen_instr([
        "classA: descriptionA",
        "classB: descriptionB",
        "test-domain"
    ])
    # System instructions come first (domain substituted), then the class instructions.
    assert result == [
        "System instruction 1 for test-domain",
        "System instruction 2 for test-domain",
        "classA: descriptionA",
        "classB: descriptionB"
    ]
73 |
74 |
def test_get_data_gen_instr_empty_classes(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr works when there are no class instructions, only the domain.
    """

    result = model._get_data_gen_instr(["test-domain"])
    # Only the formatted system instructions remain.
    assert result == [
        "System instruction 1 for test-domain",
        "System instruction 2 for test-domain"
    ]
89 |
90 |
def test_get_data_gen_instr_multiple_classes(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr works with multiple class instructions.
    """

    result = model._get_data_gen_instr([
        "classA: descriptionA",
        "classB: descriptionB",
        "classC: descriptionC",
        "test-domain"
    ])
    # All class instructions are preserved, in order, after the system instructions.
    assert result == [
        "System instruction 1 for test-domain",
        "System instruction 2 for test-domain",
        "classA: descriptionA",
        "classB: descriptionB",
        "classC: descriptionC"
    ]
111 |
112 |
def test_get_data_gen_instr_domain_with_spaces(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr correctly formats instructions when the domain contains spaces.
    """

    result = model._get_data_gen_instr([
        "classA: descriptionA",
        "classB: descriptionB",
        "complex domain name"
    ])
    # Spaces in the domain must pass through formatting unchanged.
    assert result == [
        "System instruction 1 for complex domain name",
        "System instruction 2 for complex domain name",
        "classA: descriptionA",
        "classB: descriptionB"
    ]
131 |
132 |
def test_get_data_gen_instr_no_classes_only_domain(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr returns only system instructions when user_instr contains only the domain.
    """

    result = model._get_data_gen_instr(["domain-only"])
    assert result == [
        "System instruction 1 for domain-only",
        "System instruction 2 for domain-only"
    ]
145 |
146 |
def test_get_data_gen_instr_empty_user_instr(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr raises IndexError when user_instr is empty.
    """

    # The last element is treated as the domain, so an empty list cannot be indexed.
    empty_instructions: List[str] = []
    with pytest.raises(IndexError):
        model._get_data_gen_instr(empty_instructions)
154 |
155 |
def test_get_data_gen_instr_special_characters(model: DummyClassificationModel):
    """
    Test that _get_data_gen_instr works with special characters in class descriptions and domain.
    """

    result = model._get_data_gen_instr([
        "classA: descr!ption@A#",
        "classB: descr$ption%B^",
        "domain*&^%$#@!"
    ])
    # Special characters must survive formatting verbatim.
    assert result == [
        "System instruction 1 for domain*&^%$#@!",
        "System instruction 2 for domain*&^%$#@!",
        "classA: descr!ption@A#",
        "classB: descr$ption%B^"
    ]
--------------------------------------------------------------------------------
/tests/unit/guardrail/test_gr_get_data_gen_instr.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | import pytest
3 | from pytest_mock import MockerFixture
4 |
5 | from artifex.models import Guardrail
6 | from artifex.config import config
7 |
8 |
@pytest.fixture(autouse=True)
def mock_dependencies(mocker: MockerFixture):
    """
    Fixture to mock all external dependencies before any test runs.
    This fixture runs automatically for all tests in this module.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """
    # Mock config - patch before import
    mocker.patch.object(config, "GUARDRAIL_HF_BASE_MODEL", "mock-guardrail-model")

    # Mock AutoTokenizer - must be at transformers module level
    mocker.patch(
        "transformers.AutoTokenizer.from_pretrained",
        return_value=mocker.MagicMock()
    )

    # Mock ClassLabel
    mocker.patch("datasets.ClassLabel", return_value=mocker.MagicMock())

    # Mock AutoModelForSequenceClassification if used by parent class
    model_stub = mocker.MagicMock()
    model_stub.config.id2label.values.return_value = ["safe", "unsafe"]
    mocker.patch(
        "transformers.AutoModelForSequenceClassification.from_pretrained",
        return_value=model_stub
    )

    # Mock Trainer and TrainingArguments if used
    for target in ("transformers.Trainer", "transformers.TrainingArguments"):
        mocker.patch(target)
43 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Fixture to create a mock Synthex instance.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.MagicMock(spec=Synthex)
    return synthex_stub
55 |
56 |
@pytest.fixture
def mock_guardrail(mock_synthex: Synthex) -> Guardrail:
    """
    Fixture to create a Guardrail instance with mocked dependencies.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        Guardrail: An instance of the Guardrail model with mocked dependencies.
    """

    guardrail = Guardrail(mock_synthex)
    return guardrail
68 |
69 |
@pytest.mark.unit
def test_get_data_gen_instr_success(mock_guardrail: Guardrail):
    """
    Test that the _get_data_gen_instr method correctly combines system and user
    instructions into a single list.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    user_instructions = [
        "do not allow profanity",
        "do not allow personal information",
    ]
    system_instructions = mock_guardrail._system_data_gen_instr

    combined_instr = mock_guardrail._get_data_gen_instr(user_instructions)

    # Assert that the combined instructions are a list
    assert isinstance(combined_instr, list)

    # The total length should be that of the system instructions
    assert len(combined_instr) == len(system_instructions)

    # User instructions should be embedded in the fourth system instruction
    expected_fourth = system_instructions[3].format(unsafe_content="; ".join(user_instructions))
    assert combined_instr[3] == expected_fourth
96 |
97 |
@pytest.mark.unit
def test_get_data_gen_instr_empty_user_instructions(mock_guardrail: Guardrail):
    """
    Test that the _get_data_gen_instr method handles empty user instructions list.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    combined_instr = mock_guardrail._get_data_gen_instr([])

    system_instructions = mock_guardrail._system_data_gen_instr
    assert len(combined_instr) == len(system_instructions)
    # With no user instructions the placeholder is filled with an empty string.
    assert combined_instr[3] == system_instructions[3].format(unsafe_content="")
112 |
113 |
@pytest.mark.unit
def test_get_data_gen_instr_single_user_instruction(mock_guardrail: Guardrail):
    """
    Test that the _get_data_gen_instr method handles a single user instruction.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    instruction = "block hate speech"

    combined_instr = mock_guardrail._get_data_gen_instr([instruction])

    # A single instruction is embedded without any "; " separator.
    assert combined_instr[3] == mock_guardrail._system_data_gen_instr[3].format(unsafe_content=instruction)
128 |
129 |
@pytest.mark.unit
def test_get_data_gen_instr_validation_failure(mock_guardrail: Guardrail):
    """
    Test that the _get_data_gen_instr method raises a ValidationError when provided
    with invalid user instructions (not a list).
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    from artifex.core import ValidationError

    # A bare string is not a list of instructions and must be rejected.
    with pytest.raises(ValidationError):
        mock_guardrail._get_data_gen_instr("invalid instructions")
143 |
144 |
@pytest.mark.unit
def test_get_data_gen_instr_does_not_modify_original_lists(mock_guardrail: Guardrail):
    """
    Test that the _get_data_gen_instr method does not modify the original lists.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    user_instructions = ["instruction1", "instruction2"]
    user_snapshot = list(user_instructions)
    system_snapshot = list(mock_guardrail._system_data_gen_instr)

    mock_guardrail._get_data_gen_instr(user_instructions)

    # Verify original lists are unchanged
    assert user_instructions == user_snapshot
    assert mock_guardrail._system_data_gen_instr == system_snapshot
--------------------------------------------------------------------------------
/artifex/__init__.py:
--------------------------------------------------------------------------------
from rich.console import Console

console = Console()

# Heavy imports are wrapped in a status spinner so users get feedback during startup.
with console.status("Initializing Artifex..."):
    from synthex import Synthex
    from typing import Optional
    from transformers import logging as hf_logging
    import datasets

    from .core import auto_validate_methods
    from .models.classification import ClassificationModel, Guardrail, IntentClassifier, \
        SentimentAnalysis, EmotionDetection
    from .models.named_entity_recognition import NamedEntityRecognition, TextAnonymization
    from .models.reranker import Reranker
    from .config import config
# Plain string: the message has no placeholders, so no f-prefix is needed.
console.print("[green]✔ Initializing Artifex[/green]")


# Silence Hugging Face logs unless configured otherwise.
if config.DEFAULT_HUGGINGFACE_LOGGING_LEVEL.lower() == "error":
    hf_logging.set_verbosity_error()

# Disable the progress bar from the datasets library, as it interferes with rich's progress bar.
datasets.disable_progress_bar()
25 |
26 |
@auto_validate_methods
class Artifex:
    """
    Artifex is a library for easily training and using small, task-specific AI models.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initializes Artifex with an API key for authentication.
        Args:
            api_key (Optional[str]): The API key to use for authentication. If not provided, attempts to load
                it from the .env file.
        """

        if not api_key:
            api_key=config.API_KEY
        self._synthex_client = Synthex(api_key=api_key)
        # Model instances are created lazily by the properties below; None means "not loaded yet".
        self._text_classification: Optional[ClassificationModel] = None
        self._guardrail: Optional[Guardrail] = None
        self._intent_classifier: Optional[IntentClassifier] = None
        self._reranker: Optional[Reranker] = None
        self._sentiment_analysis: Optional[SentimentAnalysis] = None
        self._emotion_detection: Optional[EmotionDetection] = None
        self._named_entity_recognition: Optional[NamedEntityRecognition] = None
        self._text_anonymization: Optional[TextAnonymization] = None

    @property
    def text_classification(self) -> ClassificationModel:
        """
        Lazy loads the ClassificationModel instance.
        Returns:
            ClassificationModel: An instance of the ClassificationModel class.
        """

        if self._text_classification is None:
            with console.status("Loading Classification model..."):
                self._text_classification = ClassificationModel(synthex=self._synthex_client)
        return self._text_classification

    @property
    def guardrail(self) -> Guardrail:
        """
        Lazy loads the Guardrail instance.
        Returns:
            Guardrail: An instance of the Guardrail class.
        """

        if self._guardrail is None:
            with console.status("Loading Guardrail model..."):
                self._guardrail = Guardrail(synthex=self._synthex_client)
        return self._guardrail

    @property
    def intent_classifier(self) -> IntentClassifier:
        """
        Lazy loads the IntentClassifier instance.
        Returns:
            IntentClassifier: An instance of the IntentClassifier class.
        """

        if self._intent_classifier is None:
            with console.status("Loading Intent Classifier model..."):
                self._intent_classifier = IntentClassifier(synthex=self._synthex_client)
        return self._intent_classifier

    @property
    def reranker(self) -> Reranker:
        """
        Lazy loads the Reranker instance.
        Returns:
            Reranker: An instance of the Reranker class.
        """

        if self._reranker is None:
            with console.status("Loading Reranker model..."):
                self._reranker = Reranker(synthex=self._synthex_client)
        return self._reranker

    @property
    def sentiment_analysis(self) -> SentimentAnalysis:
        """
        Lazy loads the SentimentAnalysis instance.
        Returns:
            SentimentAnalysis: An instance of the SentimentAnalysis class.
        """

        if self._sentiment_analysis is None:
            with console.status("Loading Sentiment Analysis model..."):
                self._sentiment_analysis = SentimentAnalysis(synthex=self._synthex_client)
        return self._sentiment_analysis

    @property
    def emotion_detection(self) -> EmotionDetection:
        """
        Lazy loads the EmotionDetection instance.
        Returns:
            EmotionDetection: An instance of the EmotionDetection class.
        """

        if self._emotion_detection is None:
            with console.status("Loading Emotion Detection model..."):
                self._emotion_detection = EmotionDetection(synthex=self._synthex_client)
        return self._emotion_detection

    @property
    def named_entity_recognition(self) -> NamedEntityRecognition:
        """
        Lazy loads the NamedEntityRecognition instance.
        Returns:
            NamedEntityRecognition: An instance of the NamedEntityRecognition class.
        """

        if self._named_entity_recognition is None:
            with console.status("Loading Named Entity Recognition model..."):
                self._named_entity_recognition = NamedEntityRecognition(synthex=self._synthex_client)
        return self._named_entity_recognition

    @property
    def text_anonymization(self) -> TextAnonymization:
        """
        Lazy loads the TextAnonymization instance.
        Returns:
            TextAnonymization: An instance of the TextAnonymization class.
        """

        if self._text_anonymization is None:
            with console.status("Loading Text Anonymization model..."):
                self._text_anonymization = TextAnonymization(synthex=self._synthex_client)
        return self._text_anonymization
--------------------------------------------------------------------------------
/tests/unit/base_model/test_bm_sanitize_output_path.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 |
4 | from artifex.core import ValidationError
5 |
6 |
@pytest.fixture(scope="function", autouse=True)
def mock_dependencies(mocker: MockerFixture):
    """
    Fixture to mock all external dependencies before any test runs.
    This fixture runs automatically for all tests in this module.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    from artifex.config import config

    # Patch on the config's class so the property descriptor itself is replaced,
    # forcing a deterministic default output directory for every test here.
    config_cls = type(config)
    mocker.patch.object(
        config_cls,
        "DEFAULT_OUTPUT_PATH",
        new_callable=mocker.PropertyMock,
        return_value="/default/output/"
    )
24 |
25 |
@pytest.mark.unit
def test_sanitize_output_path_with_none_returns_default():
    """
    Test that _sanitize_output_path returns the default path when given None.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path(None) == "/default/output/"
37 |
38 |
@pytest.mark.unit
def test_sanitize_output_path_with_empty_string_returns_default():
    """
    Test that _sanitize_output_path returns the default path when given an empty string.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path("") == "/default/output/"
50 |
51 |
@pytest.mark.unit
def test_sanitize_output_path_with_whitespace_only_returns_default():
    """
    Test that _sanitize_output_path returns the default path when given whitespace only.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path(" ") == "/default/output/"
63 |
64 |
@pytest.mark.unit
def test_sanitize_output_path_with_directory_only():
    """
    Test that _sanitize_output_path correctly handles a directory path without a file.
    """

    from artifex.models import BaseModel

    # The directory path is kept as-is, with a trailing slash appended.
    assert BaseModel._sanitize_output_path("/custom/output/path") == "/custom/output/path/"
77 |
78 |
@pytest.mark.unit
def test_sanitize_output_path_with_file_raises_validation_error():
    """
    Test that _sanitize_output_path raises an error when a file path is provided.
    """

    from artifex.models import BaseModel

    expected_msg = "The output_path parameter must be a directory path, not a file path. Try with: '/custom/output'."
    with pytest.raises(ValidationError) as exc_info:
        BaseModel._sanitize_output_path("/custom/output/model.safetensors")
    assert str(exc_info.value) == expected_msg
90 |
91 |
92 |
@pytest.mark.unit
def test_sanitize_output_path_with_trailing_slash():
    """
    Test that _sanitize_output_path handles paths with trailing slashes correctly.
    """

    from artifex.models import BaseModel

    # An already-terminated path is returned unchanged.
    assert BaseModel._sanitize_output_path("/custom/path/") == "/custom/path/"
104 |
105 |
@pytest.mark.unit
def test_sanitize_output_path_strips_whitespace():
    """
    Test that _sanitize_output_path strips leading and trailing whitespace.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path("  /custom/path  ") == "/custom/path/"
117 |
118 |
@pytest.mark.unit
def test_sanitize_output_path_with_relative_path():
    """
    Test that _sanitize_output_path handles relative paths.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path("./models/output/") == "./models/output/"
130 |
131 |
@pytest.mark.unit
def test_sanitize_output_path_with_parent_directory_notation():
    """
    Test that _sanitize_output_path handles parent directory notation.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path("../output/models") == "../output/models/"
143 |
144 |
@pytest.mark.unit
def test_sanitize_output_path_appends_slash_to_path_missing_trailing_slash():
    """
    Test that _sanitize_output_path correctly appends a slash to a path that misses the ending
    slash.
    """

    from artifex.models import BaseModel

    assert BaseModel._sanitize_output_path("/custom/path") == "/custom/path/"
157 |
158 |
@pytest.mark.unit
def test_sanitize_output_path_with_complex_path():
    """
    Test that _sanitize_output_path handles complex paths correctly.
    """

    from artifex.models import BaseModel

    deep_path = "/home/user/my_models/sentiment_analysis/v2"
    assert BaseModel._sanitize_output_path(deep_path) == deep_path + "/"
170 |
171 |
@pytest.mark.unit
def test_sanitize_output_path_is_static_method():
    """
    Test that _sanitize_output_path can be called as a static method.
    """

    from artifex.models import BaseModel

    # Should not require an instance
    sanitized = BaseModel._sanitize_output_path("/path")

    assert isinstance(sanitized, str)
    assert sanitized.endswith("/")
185 |
186 |
@pytest.mark.unit
def test_sanitize_output_path_validation_failure_with_non_string():
    """
    Test that _sanitize_output_path raises ValidationError with non-string input.
    """

    from artifex.core import ValidationError
    from artifex.models import BaseModel

    with pytest.raises(ValidationError):
        BaseModel._sanitize_output_path(123)
198 |
199 |
@pytest.mark.unit
def test_sanitize_output_path_validation_failure_with_list():
    """
    Test that passing a list to _sanitize_output_path raises ValidationError.
    """

    from artifex.core import ValidationError
    from artifex.models import BaseModel

    with pytest.raises(ValidationError):
        BaseModel._sanitize_output_path(["/path"])
--------------------------------------------------------------------------------
/tests/unit/guardrail/test_gr_train.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from typing import List, Optional
4 | from transformers.trainer_utils import TrainOutput
5 | from synthex import Synthex
6 | from artifex.models.classification.binary_classification.guardrail import Guardrail
7 | from artifex.config import config
8 |
9 |
@pytest.fixture(scope="function", autouse=True)
def mock_hf_and_config(mocker: MockerFixture) -> None:
    """
    Autouse fixture that patches config values and the Hugging Face factory
    functions so these tests never download a real model or tokenizer.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    mocker.patch.object(config, "GUARDRAIL_HF_BASE_MODEL", "mock-guardrail-model")
    mocker.patch.object(config, "DEFAULT_SYNTHEX_DATAPOINT_NUM", 100)
    for target in (
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        "artifex.models.classification.classification_model.AutoTokenizer.from_pretrained",
    ):
        mocker.patch(target, return_value=mocker.MagicMock())
28 |
29 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock that passes isinstance checks for Synthex.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    mock = mocker.MagicMock(spec=Synthex)
    return mock
41 |
42 |
@pytest.fixture
def guardrail(mock_synthex: Synthex) -> Guardrail:
    """
    Provide a Guardrail instance built on a mocked Synthex client.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        Guardrail: An instance of the Guardrail model with mocked dependencies.
    """

    # NOTE: the previously declared `mocker` parameter was unused; pytest
    # resolves fixtures by name, so removing it is backward-compatible.
    return Guardrail(mock_synthex)
55 |
56 |
@pytest.mark.unit
def test_train_calls_train_pipeline_with_required_args(
    guardrail: Guardrail, mocker: MockerFixture
) -> None:
    """
    Test that train() calls _train_pipeline with only required arguments,
    filling in the documented defaults for the optional ones.
    Args:
        guardrail (Guardrail): The Guardrail instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    instructions = ["instruction1", "instruction2"]
    mock_output = TrainOutput(global_step=1, training_loss=0.1, metrics={})
    train_pipeline_mock = mocker.patch.object(
        guardrail, "_train_pipeline", return_value=mock_output
    )

    result = guardrail.train(unsafe_content=instructions)

    # Defaults: no output path, 500 samples, 3 epochs.
    train_pipeline_mock.assert_called_once_with(
        user_instructions=instructions,
        output_path=None,
        num_samples=500,
        num_epochs=3
    )

    assert result is mock_output
83 |
84 |
@pytest.mark.unit
def test_train_calls_train_pipeline_with_all_args(
    guardrail: Guardrail, mocker: MockerFixture
) -> None:
    """
    Test that train() forwards every explicitly provided argument to
    _train_pipeline unchanged.
    Args:
        guardrail (Guardrail): The Guardrail instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    instructions = ["foo", "bar"]
    output_path = "/tmp/guardrail.csv"
    num_samples = 42
    num_epochs = 7
    mock_output = TrainOutput(global_step=2, training_loss=0.2, metrics={})
    train_pipeline_mock = mocker.patch.object(
        guardrail, "_train_pipeline", return_value=mock_output
    )

    result = guardrail.train(
        unsafe_content=instructions,
        output_path=output_path,
        num_samples=num_samples,
        num_epochs=num_epochs
    )

    train_pipeline_mock.assert_called_once_with(
        user_instructions=instructions,
        output_path=output_path,
        num_samples=num_samples,
        num_epochs=num_epochs
    )
    assert result is mock_output
118 |
119 |
@pytest.mark.unit
def test_train_returns_trainoutput(
    guardrail: Guardrail, mocker: MockerFixture
) -> None:
    """
    Test that train() returns the TrainOutput produced by _train_pipeline.
    Args:
        guardrail (Guardrail): The Guardrail instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    instructions = ["baz"]
    mock_output = TrainOutput(global_step=3, training_loss=0.3, metrics={})
    mocker.patch.object(guardrail, "_train_pipeline", return_value=mock_output)

    result = guardrail.train(unsafe_content=instructions)
    assert isinstance(result, TrainOutput)
    # The very same object must be passed through, not a copy.
    assert result is mock_output
137 |
138 |
@pytest.mark.unit
def test_train_with_empty_instructions(
    guardrail: Guardrail, mocker: MockerFixture
) -> None:
    """
    Test that train() accepts an empty instructions list and still forwards
    it (with defaults) to _train_pipeline.
    Args:
        guardrail (Guardrail): The Guardrail instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    instructions: List[str] = []
    mock_output = TrainOutput(global_step=4, training_loss=0.4, metrics={})
    train_pipeline_mock = mocker.patch.object(
        guardrail, "_train_pipeline", return_value=mock_output
    )

    result = guardrail.train(unsafe_content=instructions)
    train_pipeline_mock.assert_called_once_with(
        user_instructions=instructions,
        output_path=None,
        num_samples=500,
        num_epochs=3
    )
    assert result is mock_output
163 |
164 |
@pytest.mark.unit
def test_train_with_none_output_path(
    guardrail: Guardrail, mocker: MockerFixture
) -> None:
    """
    Test that train() passes None for output_path when none is provided.
    Args:
        guardrail (Guardrail): The Guardrail instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    instructions = ["test"]
    mock_output = TrainOutput(global_step=5, training_loss=0.5, metrics={})
    train_pipeline_mock = mocker.patch.object(
        guardrail, "_train_pipeline", return_value=mock_output
    )

    result = guardrail.train(unsafe_content=instructions)
    # Inspect the keyword arguments of the single recorded call.
    call_kwargs = train_pipeline_mock.call_args.kwargs
    assert call_kwargs["output_path"] is None
    assert result is mock_output
--------------------------------------------------------------------------------
/tests/unit/named_entity_recognition/test_ner_parse_user_instructions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from synthex import Synthex
4 |
5 | from artifex.models import NamedEntityRecognition
6 | from artifex.core import NERInstructions
7 |
8 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock constrained to the Synthex interface.
    Args:
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_mock = mocker.MagicMock(spec=Synthex)
    return synthex_mock
20 |
21 |
@pytest.fixture
def ner_instance(
    mocker: MockerFixture, mock_synthex: Synthex
) -> NamedEntityRecognition:
    """
    Build a NamedEntityRecognition instance whose Hugging Face loaders are
    patched out, so construction performs no downloads.
    Args:
        mocker: pytest-mock fixture for creating mocks.
        mock_synthex: Mocked Synthex instance.
    Returns:
        NamedEntityRecognition: An instance with mocked dependencies.
    """

    # Patch both HF factory entry points where the module actually uses them.
    for target in (
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer.from_pretrained",
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification.from_pretrained",
    ):
        mocker.patch(target)

    return NamedEntityRecognition(mock_synthex)
44 |
45 |
@pytest.mark.unit
def test_parse_single_entity_tag(
    ner_instance: NamedEntityRecognition
):
    """
    Test parsing user instructions that contain exactly one named entity tag.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={"PERSON": "A person's name"},
        domain="medical records"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    # One "tag: description" entry followed by the domain.
    assert parsed == ["PERSON: A person's name", "medical records"]
66 |
67 |
@pytest.mark.unit
def test_parse_multiple_entity_tags(
    ner_instance: NamedEntityRecognition
):
    """
    Test parsing user instructions that contain several named entity tags.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={
            "PERSON": "A person's name",
            "LOCATION": "A geographical location",
            "ORGANIZATION": "A company or institution"
        },
        domain="news articles"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    assert len(parsed) == 4
    expected_entries = {
        "PERSON: A person's name",
        "LOCATION: A geographical location",
        "ORGANIZATION: A company or institution",
    }
    assert expected_entries.issubset(parsed)
    assert parsed[-1] == "news articles"
94 |
95 |
@pytest.mark.unit
def test_parse_empty_entity_tags(
    ner_instance: NamedEntityRecognition
):
    """
    Test parsing user instructions with no named entity tags at all.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={},
        domain="general text"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    # With no tags, only the domain survives.
    assert parsed == ["general text"]
115 |
116 |
@pytest.mark.unit
def test_parse_entity_tags_ordering(
    ner_instance: NamedEntityRecognition
):
    """
    Test that the domain is always the final element of the parsed result.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={
            "DATE": "A date or time reference",
            "MONEY": "Monetary amounts"
        },
        domain="financial reports"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    # The domain goes last; everything before it is a "tag: description" pair.
    assert parsed[-1] == "financial reports"
    for entry in parsed[:-1]:
        assert ": " in entry
141 |
142 |
@pytest.mark.unit
def test_parse_entity_tags_with_special_characters(
    ner_instance: NamedEntityRecognition
):
    """
    Test that tags and descriptions containing punctuation and symbols
    are parsed without alteration.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={
            "EMAIL": "An email address (e.g., user@example.com)",
            "PHONE": "A phone number (+1-555-1234)"
        },
        domain="customer support tickets"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    assert len(parsed) == 3
    assert "EMAIL: An email address (e.g., user@example.com)" in parsed
    assert "PHONE: A phone number (+1-555-1234)" in parsed
    assert parsed[-1] == "customer support tickets"
167 |
168 |
@pytest.mark.unit
def test_parse_entity_tags_with_long_descriptions(
    ner_instance: NamedEntityRecognition
):
    """
    Test that a long, multi-clause description is carried through intact.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    long_description = (
        "A product name including brand, model, version, "
        "and any other identifying information"
    )

    instructions = NERInstructions(
        named_entity_tags={"PRODUCT": long_description},
        domain="e-commerce product reviews"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    assert parsed == [
        f"PRODUCT: {long_description}",
        "e-commerce product reviews",
    ]
194 |
195 |
@pytest.mark.unit
def test_parse_preserves_tag_description_format(
    ner_instance: NamedEntityRecognition
):
    """
    Test that every non-domain element follows the 'TAG: description' format.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={
            "SKILL": "A technical or professional skill",
            "EDUCATION": "Educational qualification or institution"
        },
        domain="resume parsing"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    # Each tag entry splits into an upper-case tag and a non-empty description.
    for entry in parsed[:-1]:
        assert ": " in entry
        tag, _, description = entry.partition(": ")
        assert tag.isupper() or tag.replace("_", "").isupper()
        assert len(description) > 0
222 |
223 |
@pytest.mark.unit
def test_parse_with_domain_containing_special_chars(
    ner_instance: NamedEntityRecognition
):
    """
    Test that a domain string containing colons and parentheses is kept
    verbatim as the final element.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={"DRUG": "Pharmaceutical drug name"},
        domain="medical records (patient: John Doe, ID: 12345)"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    assert parsed == [
        "DRUG: Pharmaceutical drug name",
        "medical records (patient: John Doe, ID: 12345)",
    ]
244 |
245 |
@pytest.mark.unit
def test_parse_return_type(
    ner_instance: NamedEntityRecognition
):
    """
    Test that _parse_user_instructions returns a list containing only strings.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = NERInstructions(
        named_entity_tags={"GPE": "Geo-political entity"},
        domain="news corpus"
    )

    parsed = ner_instance._parse_user_instructions(instructions)

    assert isinstance(parsed, list)
    for entry in parsed:
        assert isinstance(entry, str)
--------------------------------------------------------------------------------
/tests/unit/reranker/test_rr_get_data_gen_instr.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | import pytest
3 | from pytest_mock import MockerFixture
4 |
5 | from artifex.models import Reranker
6 | from artifex.config import config
7 |
8 |
@pytest.fixture(scope="function", autouse=True)
def mock_dependencies(mocker: MockerFixture):
    """
    Autouse fixture that patches config values and the Hugging Face factory
    functions used by the reranker module, so no real model or tokenizer is
    ever loaded during these tests.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Config values used by Reranker.__init__.
    mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model")
    mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512)

    # Patch the HF factories at the module where they are looked up.
    for target in (
        "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained",
        "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained",
    ):
        mocker.patch(target, return_value=mocker.MagicMock())
35 |
36 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock constrained to the Synthex interface.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_mock = mocker.MagicMock(spec=Synthex)
    return synthex_mock
48 |
49 |
@pytest.fixture
def mock_reranker(mock_synthex: Synthex) -> Reranker:
    """
    Provide a Reranker built on a mocked Synthex client; HF loading is
    already stubbed by the autouse mock_dependencies fixture.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        Reranker: An instance of the Reranker model with mocked dependencies.
    """

    reranker = Reranker(mock_synthex)
    return reranker
61 |
62 |
@pytest.mark.unit
def test_get_data_gen_instr_success(mock_reranker: Reranker):
    """
    Test that _get_data_gen_instr formats the system instructions with the
    provided domain.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    domain = "scientific research"

    combined = mock_reranker._get_data_gen_instr([domain])

    assert isinstance(combined, list)
    # One output instruction per system instruction.
    assert len(combined) == len(mock_reranker._system_data_gen_instr)
    # The first system instruction embeds the domain.
    assert domain in combined[0]
    assert f"following domain(s): {domain}" in combined[0]
85 |
86 |
@pytest.mark.unit
def test_get_data_gen_instr_formats_all_instructions(mock_reranker: Reranker):
    """
    Test that the {domain} placeholder is substituted in every system
    instruction.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    domain = "e-commerce products"

    combined = mock_reranker._get_data_gen_instr([domain])

    # No instruction should still carry the raw placeholder.
    assert all("{domain}" not in instruction for instruction in combined)
    # The first instruction mentions the actual domain.
    assert domain in combined[0]
106 |
107 |
@pytest.mark.unit
def test_get_data_gen_instr_preserves_instruction_count(mock_reranker: Reranker):
    """
    Test that the result has exactly as many entries as the system
    instruction list.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    combined = mock_reranker._get_data_gen_instr(["customer support"])

    assert len(combined) == len(mock_reranker._system_data_gen_instr)
123 |
124 |
@pytest.mark.unit
def test_get_data_gen_instr_with_different_domains(mock_reranker: Reranker):
    """
    Test that _get_data_gen_instr behaves consistently across a variety of
    domain strings.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    for domain in (
        "medical research",
        "legal documents",
        "news articles",
        "technical documentation",
    ):
        combined = mock_reranker._get_data_gen_instr([domain])

        assert isinstance(combined, list)
        assert domain in combined[0]
        assert len(combined) == len(mock_reranker._system_data_gen_instr)
147 |
148 |
@pytest.mark.unit
def test_get_data_gen_instr_does_not_modify_original_list(mock_reranker: Reranker):
    """
    Test that formatting leaves the stored system instructions untouched.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    snapshot = mock_reranker._system_data_gen_instr.copy()

    mock_reranker._get_data_gen_instr(["financial data"])

    # The stored system instructions must be identical to the snapshot.
    assert mock_reranker._system_data_gen_instr == snapshot
165 |
166 |
@pytest.mark.unit
def test_get_data_gen_instr_validation_failure(mock_reranker: Reranker):
    """
    Test that a bare string (instead of a list) raises ValidationError.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    from artifex.core import ValidationError

    with pytest.raises(ValidationError):
        mock_reranker._get_data_gen_instr("invalid input")
179 |
180 |
@pytest.mark.unit
def test_get_data_gen_instr_with_empty_list(mock_reranker: Reranker):
    """
    Test that an empty instruction list is rejected rather than silently
    accepted.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    from artifex.core import ValidationError

    # Either validation or indexing of the missing domain should fail.
    with pytest.raises((ValidationError, IndexError)):
        mock_reranker._get_data_gen_instr([])
193 |
194 |
@pytest.mark.unit
def test_get_data_gen_instr_only_uses_first_element(mock_reranker: Reranker):
    """
    Test that only the first list element is treated as the domain; any
    trailing elements are ignored.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    domain = "healthcare"
    extras = ["extra1", "extra2"]

    combined = mock_reranker._get_data_gen_instr([domain] + extras)

    # The domain is formatted in; the extra elements are not.
    assert domain in combined[0]
    assert not any(extra in combined[0] for extra in extras)
215 |
216 |
@pytest.mark.unit
def test_get_data_gen_instr_with_special_characters_in_domain(mock_reranker: Reranker):
    """
    Test that a domain containing punctuation and symbols is embedded
    verbatim in the formatted instructions.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    domain = "Q&A for tech support (beta)"

    combined = mock_reranker._get_data_gen_instr([domain])

    assert isinstance(combined, list)
    assert domain in combined[0]
    assert len(combined) == len(mock_reranker._system_data_gen_instr)
234 |
235 |
@pytest.mark.unit
def test_get_data_gen_instr_returns_new_list(mock_reranker: Reranker):
    """
    Test that the returned list is a fresh object: mutating it must not
    leak into the stored system instructions.
    Args:
        mock_reranker (Reranker): The Reranker instance to test.
    """

    combined = mock_reranker._get_data_gen_instr(["travel booking"])

    combined.append("new instruction")

    # The appended entry must not appear in — or change the size of —
    # the instance's system instruction list.
    assert len(combined) != len(mock_reranker._system_data_gen_instr)
    assert "new instruction" not in mock_reranker._system_data_gen_instr
--------------------------------------------------------------------------------
/tests/unit/reranker/test_rr_parse_user_instructions.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | import pytest
3 | from pytest_mock import MockerFixture
4 |
5 | from artifex.models import Reranker
6 | from artifex.config import config
7 |
8 |
@pytest.fixture(scope="function", autouse=True)
def mock_dependencies(mocker: MockerFixture):
    """
    Autouse fixture that patches config values and the Hugging Face factory
    functions used by the reranker module, so no real model or tokenizer is
    ever loaded during these tests.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Config values used by Reranker.__init__.
    mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model")
    mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512)

    # Patch the HF factories at the module where they are looked up.
    for target in (
        "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained",
        "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained",
    ):
        mocker.patch(target, return_value=mocker.MagicMock())
35 |
36 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock constrained to the Synthex interface.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_mock = mocker.MagicMock(spec=Synthex)
    return synthex_mock
48 |
49 |
@pytest.fixture
def mock_reranker(mock_synthex: Synthex) -> Reranker:
    """
    Provide a Reranker built on a mocked Synthex client; HF loading is
    already stubbed by the autouse mock_dependencies fixture.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        Reranker: An instance of the Reranker model with mocked dependencies.
    """

    reranker = Reranker(mock_synthex)
    return reranker
61 |
62 |
@pytest.mark.unit
def test_parse_user_instructions_returns_list(mock_reranker: Reranker):
    """
    Test that _parse_user_instructions returns a list.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    parsed = mock_reranker._parse_user_instructions("scientific research papers")

    assert isinstance(parsed, list)
76 |
77 |
@pytest.mark.unit
def test_parse_user_instructions_single_element(mock_reranker: Reranker):
    """
    Test that the parsed result wraps the input in a one-element list.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    parsed = mock_reranker._parse_user_instructions("medical documents")

    assert len(parsed) == 1
91 |
92 |
@pytest.mark.unit
def test_parse_user_instructions_contains_original_string(mock_reranker: Reranker):
    """
    Test that the first (and only) element is the original input string.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    instructions = "legal documents and contracts"

    parsed = mock_reranker._parse_user_instructions(instructions)

    assert parsed[0] == instructions
106 |
107 |
@pytest.mark.unit
def test_parse_user_instructions_with_empty_string(mock_reranker: Reranker):
    """
    Test that an empty string is accepted and wrapped as-is.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    parsed = mock_reranker._parse_user_instructions("")

    assert isinstance(parsed, list)
    assert parsed == [""]
123 |
124 |
@pytest.mark.unit
def test_parse_user_instructions_with_whitespace(mock_reranker: Reranker):
    """
    Test that leading/trailing whitespace is preserved, not stripped.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    instructions = "  news articles with spaces  "

    parsed = mock_reranker._parse_user_instructions(instructions)

    assert parsed == [instructions]
139 |
140 |
@pytest.mark.unit
def test_parse_user_instructions_with_multiline_string(mock_reranker: Reranker):
    """
    Test that a string spanning multiple lines is preserved verbatim.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    instructions = """technical documentation
    and user manuals"""

    parsed = mock_reranker._parse_user_instructions(instructions)

    assert isinstance(parsed, list)
    assert parsed == [instructions]
157 |
158 |
@pytest.mark.unit
def test_parse_user_instructions_with_special_characters(mock_reranker: Reranker):
    """
    Test that punctuation and symbols pass through unchanged.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    instructions = "Q&A for tech support (beta) - version 2.0!"

    parsed = mock_reranker._parse_user_instructions(instructions)

    assert parsed == [instructions]
173 |
174 |
@pytest.mark.unit
def test_parse_user_instructions_with_long_string(mock_reranker: Reranker):
    """
    Test that a long (1000-char) string passes through without truncation.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    instructions = "A" * 1000

    parsed = mock_reranker._parse_user_instructions(instructions)

    assert len(parsed) == 1
    assert parsed[0] == instructions
    assert len(parsed[0]) == 1000
190 |
191 |
@pytest.mark.unit
def test_parse_user_instructions_validation_failure_with_list(mock_reranker: Reranker):
    """
    Test that passing a list instead of a string raises ValidationError.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    from artifex.core import ValidationError

    with pytest.raises(ValidationError):
        mock_reranker._parse_user_instructions(["not", "a", "string"])
204 |
205 |
@pytest.mark.unit
def test_parse_user_instructions_validation_failure_with_none(mock_reranker: Reranker):
    """
    Test that passing None raises ValidationError.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    from artifex.core import ValidationError

    with pytest.raises(ValidationError):
        mock_reranker._parse_user_instructions(None)
218 |
219 |
@pytest.mark.unit
def test_parse_user_instructions_validation_failure_with_int(mock_reranker: Reranker):
    """
    Test that passing an integer raises ValidationError.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    from artifex.core import ValidationError

    with pytest.raises(ValidationError):
        mock_reranker._parse_user_instructions(123)
232 |
233 |
@pytest.mark.unit
def test_parse_user_instructions_does_not_modify_input(mock_reranker: Reranker):
    """
    Test that _parse_user_instructions leaves the input string unchanged and
    returns it verbatim.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    user_instructions = "customer reviews and feedback"
    original = user_instructions

    result = mock_reranker._parse_user_instructions(user_instructions)

    # Input string should remain unchanged...
    assert user_instructions == original
    # ...and the result was previously unused; also assert it echoes the input.
    assert result[0] == user_instructions
249 |
250 |
@pytest.mark.unit
def test_parse_user_instructions_with_unicode(mock_reranker: Reranker):
    """
    Verify that _parse_user_instructions passes unicode text through intact.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    unicode_instructions = "文档分类 и categorización de documentos"

    parsed = mock_reranker._parse_user_instructions(unicode_instructions)

    # Exactly one entry, containing the original unicode text unchanged.
    assert len(parsed) == 1
    assert parsed[0] == unicode_instructions
--------------------------------------------------------------------------------
/tests/unit/text_anonymization/test_ta__call__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from synthex import Synthex
4 | from typing import List
5 |
6 | from artifex.models import TextAnonymization
7 | from artifex.config import config
8 |
9 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Build a Synthex mock constrained to the real Synthex interface.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.Mock(spec=Synthex)
    return synthex_stub
21 |
@pytest.fixture
def text_anonymization(mock_synthex: Synthex, mocker: MockerFixture) -> TextAnonymization:
    """
    Provide a TextAnonymization instance whose parent initializer is stubbed out.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        TextAnonymization: An instance of TextAnonymization with mocked parent class.
    """

    parent_cls = TextAnonymization.__bases__[0]
    # Stub the parent __init__ so no real model/tokenizer setup runs.
    mocker.patch.object(parent_cls, '__init__', return_value=None)
    return TextAnonymization(mock_synthex)
37 |
38 |
@pytest.mark.unit
def test_call_with_single_string_no_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    A single string with no detected entities should come back unchanged.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(parent_cls, '__call__', return_value=[[]])

    text = "This is a test sentence."
    anonymized = text_anonymization(text)

    # A bare string is wrapped in a list before being forwarded to the parent.
    parent_call.assert_called_once_with([text])
    assert anonymized == [text]
59 |
60 |
@pytest.mark.unit
def test_call_with_single_string_with_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    A detected PII span should be replaced with the default mask token.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # Simulated NER hit covering "John" (chars 11..15).
    person = mocker.Mock(entity_group="PERSON", start=11, end=15)

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(parent_cls, '__call__', return_value=[[person]])

    text = "My name is John and I live in NYC."
    anonymized = text_anonymization(text)

    mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected = f"My name is {mask} and I live in NYC."

    parent_call.assert_called_once_with([text])
    assert anonymized == [expected]
89 |
90 |
@pytest.mark.unit
def test_call_with_list_of_strings(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Each element of a list input should be anonymized independently.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # One entity per input text: "John" in the first, "in London" span in the second.
    person = mocker.Mock(entity_group="PERSON", start=0, end=4)
    place = mocker.Mock(entity_group="LOCATION", start=8, end=15)

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(
        parent_cls, '__call__', return_value=[[person], [place]]
    )

    texts = ["John lives here", "I am in London"]
    anonymized = text_anonymization(texts)

    mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected = [f"{mask} lives here", f"I am in {mask}"]

    # A list input is forwarded to the parent as-is (no wrapping).
    parent_call.assert_called_once_with(texts)
    assert anonymized == expected
125 |
126 |
@pytest.mark.unit
def test_call_with_custom_entities_to_mask(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Only entity groups listed in entities_to_mask should be replaced.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    person = mocker.Mock(entity_group="PERSON", start=0, end=4)
    place = mocker.Mock(entity_group="LOCATION", start=14, end=20)

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(
        parent_cls, '__call__', return_value=[[person, place]]
    )

    text = "John lives in London"
    # Restrict masking to PERSON — the LOCATION span must be left intact.
    anonymized = text_anonymization(text, entities_to_mask=["PERSON"])

    mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected = f"{mask} lives in London"

    parent_call.assert_called_once_with([text])
    assert anonymized == [expected]
161 |
162 |
@pytest.mark.unit
def test_call_with_custom_mask_token(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    A caller-supplied mask_token should be used instead of the default mask.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    person = mocker.Mock(entity_group="PERSON", start=0, end=4)

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(parent_cls, '__call__', return_value=[[person]])

    text = "John is here"
    mask = "[REDACTED]"
    anonymized = text_anonymization(text, mask_token=mask)

    parent_call.assert_called_once_with([text])
    assert anonymized == [f"{mask} is here"]
191 |
192 |
@pytest.mark.unit
def test_call_with_invalid_entity_to_mask(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    An unknown entity name in entities_to_mask should raise ValueError.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    with pytest.raises(ValueError) as exc_info:
        text_anonymization("test text", entities_to_mask=["INVALID_ENTITY"])

    # The error message should name the offending entity and explain why.
    message = str(exc_info.value)
    assert "INVALID_ENTITY" in message
    assert "cannot be masked" in message
209 |
210 |
@pytest.mark.unit
def test_call_with_multiple_entities_same_text(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Several entities in the same text should all be masked.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # Two hits in one text: a name at the start and a phone number at the end.
    person = mocker.Mock(entity_group="PERSON", start=0, end=4)
    phone = mocker.Mock(entity_group="PHONE_NUMBER", start=17, end=29)

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(
        parent_cls, '__call__', return_value=[[person, phone]]
    )

    text = "John's number is 123-456-7890"
    anonymized = text_anonymization(text)

    mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected = f"{mask}'s number is {mask}"

    parent_call.assert_called_once_with([text])
    assert anonymized == [expected]
245 |
246 |
@pytest.mark.unit
def test_call_with_empty_string(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    An empty string input should be returned as an empty string.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    parent_cls = TextAnonymization.__bases__[0]
    parent_call = mocker.patch.object(parent_cls, '__call__', return_value=[[]])

    text = ""
    anonymized = text_anonymization(text)

    parent_call.assert_called_once_with([text])
    assert anonymized == [""]
--------------------------------------------------------------------------------
/tests/unit/named_entity_recognition/test_ner_post_process_synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | from pytest_mock import MockerFixture
4 | from datasets import ClassLabel
5 | from typing import Any
6 |
7 | from artifex.models import NamedEntityRecognition
8 |
9 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Any:
    """
    Provide a plain mock standing in for a Synthex client.
    Args:
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        Mock Synthex instance.
    """

    synthex_stub = mocker.Mock()
    return synthex_stub
21 |
22 |
@pytest.fixture
def ner_instance(mock_synthex: Any, mocker: MockerFixture) -> NamedEntityRecognition:
    """
    Provide a NamedEntityRecognition instance with HF components patched out
    and a fixed PERSON/LOCATION BIO label set.
    Args:
        mock_synthex: Mocked Synthex instance.
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        NamedEntityRecognition instance with mocked components.
    """

    # Patch the heavyweight HF classes before constructing the model.
    for target in (
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer",
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification",
    ):
        mocker.patch(target)

    instance = NamedEntityRecognition(mock_synthex)
    instance._labels = ClassLabel(names=["O", "B-PERSON", "I-PERSON", "B-LOCATION", "I-LOCATION"])
    return instance
42 |
43 |
@pytest.mark.unit
def test_cleanup_removes_invalid_labels(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Rows whose label column is not in the "entity: TAG" format must be dropped.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["John lives in Paris", "Invalid data"],
        "labels": ["John: PERSON, Paris: LOCATION", "not a valid format"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)

    # Only the well-formed row survives.
    assert len(processed) == 1
    assert "John lives in Paris" in processed["text"].values
73 |
74 |
@pytest.mark.unit
def test_cleanup_converts_to_bio_format(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Entity annotations must be rewritten as per-token BIO tags.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    import ast

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["John Smith lives in New York"],
        "labels": ["John Smith: PERSON, New York: LOCATION"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)
    assert len(processed) == 1

    # Labels are stored as a stringified Python list; parse them back.
    bio_tags = ast.literal_eval(processed["labels"].iloc[0])

    # Token-by-token BIO expectation.
    assert bio_tags[0] == "B-PERSON"    # John
    assert bio_tags[1] == "I-PERSON"    # Smith
    assert bio_tags[2] == "O"           # lives
    assert bio_tags[3] == "O"           # in
    assert bio_tags[4] == "B-LOCATION"  # New
    assert bio_tags[5] == "I-LOCATION"  # York
110 |
111 |
@pytest.mark.unit
def test_cleanup_removes_empty_labels(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Rows whose label column is empty must be dropped.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["John lives here", "No entities here"],
        "labels": ["John: PERSON", ""]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)

    # Only the row that actually carries an entity survives.
    assert len(processed) == 1
    assert "John lives here" in processed["text"].values
138 |
139 |
@pytest.mark.unit
def test_cleanup_removes_only_o_tags(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Rows tagged exclusively with 'O' (no entities) must be dropped.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["John lives here", "No entities here"],
        "labels": ["John: PERSON", "['O', 'O', 'O', 'O']"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)

    # Only the row that actually carries an entity survives.
    assert len(processed) == 1
    assert "John lives here" in processed["text"].values
167 |
168 |
@pytest.mark.unit
def test_cleanup_removes_invalid_tags(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Rows referencing tags outside the configured label set must be dropped.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["John works at Google", "Jane lives in Paris"],
        "labels": ["John: PERSON, Google: ORGANIZATION", "Jane: PERSON, Paris: LOCATION"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)

    # ORGANIZATION is not in the fixture's label set, so that row is removed.
    assert len(processed) == 1
    assert "Jane lives in Paris" in processed["text"].values
195 |
196 |
@pytest.mark.unit
def test_cleanup_handles_multi_word_entities(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    A multi-word entity must receive a B- tag on its first token and I- tags after.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    import ast

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["The Eiffel Tower is in Paris"],
        "labels": ["Eiffel Tower: LOCATION, Paris: LOCATION"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)
    assert len(processed) == 1

    bio_tags = ast.literal_eval(processed["labels"].iloc[0])

    # "Eiffel Tower" spans two tokens: B- then I-.
    assert bio_tags[1] == "B-LOCATION"  # Eiffel
    assert bio_tags[2] == "I-LOCATION"  # Tower
227 |
228 |
@pytest.mark.unit
def test_cleanup_handles_punctuation(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that cleanup handles punctuation in entities correctly.
    The row may legitimately be kept or dropped depending on how punctuation
    is tokenized, but any surviving row must carry well-formed BIO tags.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    import ast

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["Dr. John Smith, PhD lives here"],
        "labels": ["Dr. John Smith, PhD: PERSON"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)

    # The original assertion (len >= 0) was tautological and could never fail.
    # Instead: the single input row is either dropped or kept — and if kept,
    # its labels must parse to a list of tags drawn from the fixture label set.
    assert len(result_df) <= 1
    if len(result_df) == 1:
        labels = ast.literal_eval(result_df["labels"].iloc[0])
        assert isinstance(labels, list)
        allowed_tags = {"O", "B-PERSON", "I-PERSON", "B-LOCATION", "I-LOCATION"}
        assert all(tag in allowed_tags for tag in labels)
252 |
253 |
@pytest.mark.unit
def test_cleanup_case_insensitive_matching(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Entity strings must match their occurrence in the text regardless of case.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    import ast

    dataset_path = tmp_path / "test_dataset.csv"
    pd.DataFrame({
        "text": ["JOHN lives in paris"],
        "labels": ["john: PERSON, PARIS: LOCATION"]
    }).to_csv(dataset_path, index=False)

    ner_instance._post_process_synthetic_dataset(str(dataset_path))

    processed = pd.read_csv(dataset_path)
    assert len(processed) == 1

    bio_tags = ast.literal_eval(processed["labels"].iloc[0])
    # "JOHN"/"john" and "paris"/"PARIS" both matched despite differing case.
    assert bio_tags[0] == "B-PERSON"
    assert bio_tags[3] == "B-LOCATION"
--------------------------------------------------------------------------------
/tests/unit/named_entity_recognition/test_ner_get_data_gen_instr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from synthex import Synthex
4 |
5 | from artifex.models import NamedEntityRecognition
6 |
7 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock constrained to the Synthex interface.
    Args:
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.MagicMock(spec=Synthex)
    return synthex_stub
19 |
20 |
@pytest.fixture
def ner_instance(
    mocker: MockerFixture, mock_synthex: Synthex
) -> NamedEntityRecognition:
    """
    Provide a NamedEntityRecognition instance with HF model loading patched out.
    Args:
        mocker: pytest-mock fixture for creating mocks.
        mock_synthex: Mocked Synthex instance.
    Returns:
        NamedEntityRecognition: An instance with mocked dependencies.
    """

    # Patch the pretrained-loading entry points so construction stays offline.
    for target in (
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer.from_pretrained",
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification.from_pretrained",
    ):
        mocker.patch(target)

    return NamedEntityRecognition(mock_synthex)
43 |
44 |
@pytest.mark.unit
def test_get_data_gen_instr_single_entity(
    ner_instance: NamedEntityRecognition
):
    """
    A single entity tag plus a domain should both appear in the formatted
    system instructions.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "PERSON: A person's name",
        "medical records"
    ])

    # One formatted instruction per system template entry.
    assert len(instructions) == len(ner_instance._system_data_gen_instr)

    # Domain substituted into at least one instruction.
    assert any("medical records" in line for line in instructions)

    # Entity tag substituted into at least one instruction.
    assert any("PERSON: A person's name" in line for line in instructions)
70 |
71 |
@pytest.mark.unit
def test_get_data_gen_instr_multiple_entities(
    ner_instance: NamedEntityRecognition
):
    """
    All supplied entity tags and the domain should survive formatting.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "PERSON: A person's name",
        "LOCATION: A geographical location",
        "ORGANIZATION: A company or institution",
        "news articles"
    ])

    assert len(instructions) == len(ner_instance._system_data_gen_instr)

    # Domain substituted somewhere in the output.
    assert any("news articles" in line for line in instructions)

    # Every tag appears in the joined output.
    joined = " ".join(instructions)
    for tag in (
        "PERSON: A person's name",
        "LOCATION: A geographical location",
        "ORGANIZATION: A company or institution",
    ):
        assert tag in joined
101 |
102 |
@pytest.mark.unit
def test_get_data_gen_instr_domain_only(
    ner_instance: NamedEntityRecognition
):
    """
    With only a domain and no entity tags, formatting should still succeed
    and render the tag list as empty.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr(["general text"])

    assert len(instructions) == len(ner_instance._system_data_gen_instr)

    # Domain substituted somewhere in the output.
    assert any("general text" in line for line in instructions)

    # With no tags supplied, the rendered tag list is empty.
    assert any("[]" in line for line in instructions)
124 |
125 |
@pytest.mark.unit
def test_get_data_gen_instr_format_placeholders(
    ner_instance: NamedEntityRecognition
):
    """
    No unexpanded template placeholders may remain after formatting.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "EMAIL: An email address",
        "customer support tickets"
    ])

    # Both template placeholders must have been substituted everywhere.
    for placeholder in ("{domain}", "{named_entity_tags}"):
        assert all(placeholder not in line for line in instructions)
147 |
148 |
@pytest.mark.unit
def test_get_data_gen_instr_preserves_instruction_count(
    ner_instance: NamedEntityRecognition
):
    """
    Formatting must emit exactly one instruction per system template entry.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "PRODUCT: A product name",
        "PRICE: A monetary amount",
        "e-commerce reviews"
    ])

    # Count is preserved 1:1 with the system template.
    assert len(instructions) == len(ner_instance._system_data_gen_instr)
169 |
170 |
@pytest.mark.unit
def test_get_data_gen_instr_domain_extraction(
    ner_instance: NamedEntityRecognition
):
    """
    The last list element is the domain and must appear in the output.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "DATE: A date reference",
        "MONEY: Monetary amounts",
        "TIME: A time reference",
        "financial reports and documents"
    ])

    # The trailing element ("financial reports and documents") is the domain.
    assert any(
        "financial reports and documents" in line
        for line in instructions
    )
196 |
197 |
@pytest.mark.unit
def test_get_data_gen_instr_entity_tags_extraction(
    ner_instance: NamedEntityRecognition
):
    """
    Every element except the last is an entity tag and must appear in the output.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "DRUG: Pharmaceutical drug name",
        "DOSAGE: Drug dosage information",
        "medical prescriptions"
    ])

    joined = " ".join(instructions)

    # Both leading elements were treated as tags and substituted in.
    assert "DRUG: Pharmaceutical drug name" in joined
    assert "DOSAGE: Drug dosage information" in joined
221 |
222 |
@pytest.mark.unit
def test_get_data_gen_instr_return_type(
    ner_instance: NamedEntityRecognition
):
    """
    The formatter must return a list whose every element is a string.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "GPE: Geo-political entity",
        "news corpus"
    ])

    assert isinstance(instructions, list)
    for line in instructions:
        assert isinstance(line, str)
242 |
243 |
@pytest.mark.unit
def test_get_data_gen_instr_special_characters_in_domain(
    ner_instance: NamedEntityRecognition
):
    """
    Special characters in the domain string must pass through formatting intact.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "PERSON: A person",
        "domain with (special) characters & symbols!"
    ])

    # Punctuation and symbols survive substitution unescaped.
    assert any(
        "domain with (special) characters & symbols!" in line
        for line in instructions
    )
266 |
267 |
@pytest.mark.unit
def test_get_data_gen_instr_special_characters_in_tags(
    ner_instance: NamedEntityRecognition
):
    """
    Special characters inside tag descriptions must pass through intact.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr([
        "EMAIL: An email (e.g., user@example.com)",
        "PHONE: A phone number (+1-555-1234)",
        "contact information"
    ])

    joined = " ".join(instructions)

    # @, +, -, and parentheses in the descriptions survive substitution.
    assert "user@example.com" in joined
    assert "+1-555-1234" in joined
291 |
292 |
@pytest.mark.unit
def test_get_data_gen_instr_empty_entity_list(
    ner_instance: NamedEntityRecognition
):
    """
    A lone domain element (zero entity tags) must still format cleanly.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    instructions = ner_instance._get_data_gen_instr(["simple domain"])

    # All system template entries are still produced.
    assert len(instructions) == len(ner_instance._system_data_gen_instr)

    # Domain substituted somewhere in the output.
    assert any("simple domain" in line for line in instructions)
--------------------------------------------------------------------------------
/tests/unit/text_anonymization/test_ta_train.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_mock import MockerFixture
3 | from synthex import Synthex
4 | from transformers.trainer_utils import TrainOutput
5 | from typing import Optional
6 |
7 | from artifex.models import TextAnonymization
8 | from artifex.config import config
9 |
10 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Build a Synthex mock constrained to the real Synthex interface.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.Mock(spec=Synthex)
    return synthex_stub
22 |
23 |
@pytest.fixture
def text_anonymization(mock_synthex: Synthex, mocker: MockerFixture) -> TextAnonymization:
    """
    Provide a TextAnonymization instance whose parent initializer is stubbed out.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        TextAnonymization: An instance of TextAnonymization with mocked parent class.
    """

    parent_cls = TextAnonymization.__bases__[0]
    # Stub the parent __init__ so no real model/tokenizer setup runs.
    mocker.patch.object(parent_cls, '__init__', return_value=None)
    return TextAnonymization(mock_synthex)
39 |
40 |
@pytest.mark.unit
def test_train_with_default_parameters(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    With only a domain supplied, train() should forward the built-in PII
    entity set plus default values for every other parent-train argument.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output = mocker.Mock(spec=TrainOutput)
    parent_cls = TextAnonymization.__bases__[0]
    parent_train = mocker.patch.object(parent_cls, 'train', return_value=train_output)

    domain = "healthcare"
    returned = text_anonymization.train(domain=domain)

    # The fixed PII entity catalogue baked into TextAnonymization.
    expected_pii_entities = {
        "PERSON": "Individual people, fictional characters",
        "LOCATION": "Geographical areas",
        "DATE": "Absolute or relative dates, including years, months and/or days",
        "ADDRESS": "full addresses",
        "PHONE_NUMBER": "telephone numbers",
    }

    parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=None,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert returned == train_output
77 |
78 |
@pytest.mark.unit
def test_train_with_custom_output_path(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    A caller-supplied output_path should be forwarded unchanged to parent train().
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output = mocker.Mock(spec=TrainOutput)
    parent_cls = TextAnonymization.__bases__[0]
    parent_train = mocker.patch.object(parent_cls, 'train', return_value=train_output)

    domain = "finance"
    output_path = "/custom/path/to/model"
    returned = text_anonymization.train(domain=domain, output_path=output_path)

    parent_train.assert_called_once_with(
        named_entities=text_anonymization._pii_entities,
        domain=domain,
        output_path=output_path,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert returned == train_output
110 |
111 |
@pytest.mark.unit
def test_train_with_custom_num_samples(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() forwards a caller-supplied num_samples while keeping
    every other argument at its default.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output_stub = mocker.Mock(spec=TrainOutput)
    parent_train_mock = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=train_output_stub
    )

    requested_samples = 500
    outcome = text_anonymization.train(domain="legal", num_samples=requested_samples)

    parent_train_mock.assert_called_once_with(
        named_entities=text_anonymization._pii_entities,
        domain="legal",
        output_path=None,
        num_samples=requested_samples,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert outcome == train_output_stub
143 |
144 |
@pytest.mark.unit
def test_train_with_custom_num_epochs(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() forwards a caller-supplied num_epochs while keeping
    every other argument at its default.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output_stub = mocker.Mock(spec=TrainOutput)
    parent_train_mock = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=train_output_stub
    )

    requested_epochs = 5
    outcome = text_anonymization.train(
        domain="customer_service", num_epochs=requested_epochs
    )

    parent_train_mock.assert_called_once_with(
        named_entities=text_anonymization._pii_entities,
        domain="customer_service",
        output_path=None,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=requested_epochs,
        train_datapoint_examples=None
    )
    assert outcome == train_output_stub
176 |
177 |
@pytest.mark.unit
def test_train_with_all_custom_parameters(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() forwards every caller-supplied parameter to the
    parent class train() simultaneously.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output_stub = mocker.Mock(spec=TrainOutput)
    parent_train_mock = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=train_output_stub
    )

    # All overridable arguments, supplied together.
    custom_args = {
        "domain": "retail",
        "output_path": "/path/to/retail/model",
        "num_samples": 1000,
        "num_epochs": 10,
    }

    outcome = text_anonymization.train(**custom_args)

    parent_train_mock.assert_called_once_with(
        named_entities=text_anonymization._pii_entities,
        train_datapoint_examples=None,
        **custom_args
    )
    assert outcome == train_output_stub
217 |
218 |
@pytest.mark.unit
def test_train_uses_predefined_pii_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() always passes the model's predefined PII entities as
    the named_entities argument to the parent class train().
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    parent_train_mock = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train',
        return_value=mocker.Mock(spec=TrainOutput)
    )

    text_anonymization.train(domain="education")

    # Inspect the keyword arguments that reached the parent train().
    forwarded = parent_train_mock.call_args.kwargs
    assert "named_entities" in forwarded
    entities = forwarded["named_entities"]
    assert entities == text_anonymization._pii_entities
    for label in ("PERSON", "LOCATION", "DATE", "ADDRESS", "PHONE_NUMBER"):
        assert label in entities
247 |
248 |
@pytest.mark.unit
def test_train_sets_train_datapoint_examples_to_none(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() unconditionally passes train_datapoint_examples=None
    to the parent class train().
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    parent_train_mock = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train',
        return_value=mocker.Mock(spec=TrainOutput)
    )

    text_anonymization.train(domain="technology")

    # The examples argument must be hard-wired to None by the subclass.
    forwarded = parent_train_mock.call_args.kwargs
    assert "train_datapoint_examples" in forwarded
    assert forwarded["train_datapoint_examples"] is None
272 |
273 |
@pytest.mark.unit
def test_train_returns_train_output(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Verify that train() returns the TrainOutput produced by the parent class
    unchanged, attributes included.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    train_output_stub = mocker.Mock(spec=TrainOutput)
    train_output_stub.training_loss = 0.05
    train_output_stub.metrics = {"accuracy": 0.95}

    mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=train_output_stub
    )

    outcome = text_anonymization.train(domain="insurance")

    assert outcome == train_output_stub
    assert outcome.training_loss == 0.05
    assert outcome.metrics == {"accuracy": 0.95}
--------------------------------------------------------------------------------
/tests/unit/reranker/test_rr__call__.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | import pytest
3 | from pytest_mock import MockerFixture
4 | import torch
5 |
6 | from artifex.models import Reranker
7 | from artifex.config import config
8 |
9 |
@pytest.fixture(scope="function", autouse=True)
def mock_dependencies(mocker: MockerFixture):
    """
    Autouse fixture that stubs out every external dependency of the Reranker
    module before each test in this file runs.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Stub the config values the reranker reads at construction time.
    mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model")
    mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512)

    # Replace the HF factory functions at the module where they are used,
    # so instantiating Reranker never touches the network or disk.
    mocker.patch(
        "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained",
        return_value=mocker.MagicMock()
    )
    mocker.patch(
        "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained",
        return_value=mocker.MagicMock()
    )
36 |
37 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a Synthex test double constrained to the real Synthex interface.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.MagicMock(spec=Synthex)
    return synthex_stub
49 |
50 |
@pytest.fixture
def mock_reranker(mocker: MockerFixture, mock_synthex: Synthex) -> Reranker:
    """
    Build a Reranker whose tokenizer and model internals are mocks.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        Reranker: A Reranker instance wired to mocked dependencies.
    """

    reranker = Reranker(mock_synthex)

    # The tokenizer mock yields a minimal, well-formed batch encoding.
    reranker._tokenizer.return_value = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]])
    }

    # The model mock returns a generic object; tests override its logits.
    reranker._model.return_value = mocker.MagicMock()

    return reranker
76 |
77 |
@pytest.mark.unit
def test_call_with_single_document(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that __call__ accepts a single document string and yields a
    one-element list of (document, score) tuples.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    query = "What is machine learning?"
    document = "Machine learning is a subset of artificial intelligence."

    # One document -> one logit row.
    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.85]])
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker(query, document)

    assert isinstance(ranked, list)
    assert len(ranked) == 1
    top = ranked[0]
    assert isinstance(top, tuple)
    assert top[0] == document
    assert isinstance(top[1], float)
106 |
107 |
@pytest.mark.unit
def test_call_with_multiple_documents(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that __call__ scores several documents, returns all of them, and
    orders the result by descending relevance score.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    query = "What is Python?"
    documents = [
        "Python is a programming language.",
        "The python is a snake.",
        "Python was created by Guido van Rossum."
    ]

    # One logit row per document, with distinct scores.
    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.9], [0.2], [0.75]])
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker(query, documents)

    assert isinstance(ranked, list)
    assert len(ranked) == 3

    # Every input document appears in the output exactly once.
    assert {doc for doc, _ in ranked} == set(documents)

    # Scores are monotonically non-increasing.
    scores = [score for _, score in ranked]
    assert scores == sorted(scores, reverse=True)
145 |
146 |
@pytest.mark.unit
def test_call_documents_sorted_by_score(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that __call__ orders documents by relevance score, highest first.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    documents = ["Doc A", "Doc B", "Doc C"]

    # Scores chosen so the sorted order differs from the input order.
    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.3], [0.9], [0.6]])
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker("climate change", documents)

    # Expected order: Doc B (0.9), Doc C (0.6), Doc A (0.3).
    assert ranked[0][0] == "Doc B"
    assert ranked[1][0] == "Doc C"
    assert ranked[2][0] == "Doc A"

    assert ranked[0][1] > ranked[1][1] > ranked[2][1]
175 |
176 |
@pytest.mark.unit
def test_call_tokenizer_called_correctly(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that __call__ invokes the tokenizer exactly once, with the query
    repeated per document, the document list, and the expected keyword args.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    query = "test query"
    documents = ["doc1", "doc2"]

    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.5], [0.7]])
    mock_reranker._model.return_value = model_output

    mock_reranker(query, documents)

    mock_reranker._tokenizer.assert_called_once()
    positional, keywords = mock_reranker._tokenizer.call_args

    # The query must be paired with every document.
    assert positional[0] == [query, query]
    assert positional[1] == documents
    # Tokenizer options expected by the reranker's forward pass.
    assert keywords["return_tensors"] == "pt"
    assert keywords["truncation"] is True
    assert keywords["padding"] is True
211 |
212 |
@pytest.mark.unit
def test_call_model_called_with_tokenizer_output(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that the model is invoked with the tokenizer's output.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    query = "test"
    documents = ["document"]

    # Provide a concrete encoding for the tokenizer mock to hand over.
    mock_reranker._tokenizer.return_value = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]])
    }

    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.5]])
    mock_reranker._model.return_value = model_output

    mock_reranker(query, documents)

    # The model must be called exactly once, receiving the encoding either
    # as keyword arguments or positionally.
    mock_reranker._model.assert_called_once()
    positional, keywords = mock_reranker._model.call_args
    assert "input_ids" in keywords or len(positional) > 0
246 |
247 |
@pytest.mark.unit
def test_call_with_empty_document_list(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that __call__ returns an empty list when given no documents.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Zero documents -> a (0, 1)-shaped logits tensor.
    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([]).reshape(0, 1)
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker("test query", [])

    assert isinstance(ranked, list)
    assert len(ranked) == 0
272 |
273 |
@pytest.mark.unit
def test_call_converts_single_string_to_list(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that a bare document string is wrapped into a list internally and
    still yields a single-element result.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    document = "single document"

    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.8]])
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker("query", document)

    # A plain string behaves exactly like a one-element list.
    assert isinstance(ranked, list)
    assert len(ranked) == 1
    assert ranked[0][0] == document
300 |
301 |
@pytest.mark.unit
def test_call_returns_tuples_with_correct_types(
    mock_reranker: Reranker, mocker: MockerFixture
):
    """
    Verify that every element of the result is a (str, float) pair.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_output = mocker.MagicMock()
    model_output.logits = torch.tensor([[0.5], [0.7], [0.3]])
    mock_reranker._model.return_value = model_output

    ranked = mock_reranker("test", ["doc1", "doc2", "doc3"])

    for entry in ranked:
        assert isinstance(entry, tuple)
        assert len(entry) == 2
        doc, score = entry
        assert isinstance(doc, str)
        assert isinstance(score, float)
--------------------------------------------------------------------------------
/tests/unit/base_model/test_base_model_load.py:
--------------------------------------------------------------------------------
1 | from synthex import Synthex
2 | import pytest
3 | from pytest_mock import MockerFixture
4 | import os
5 |
6 |
class ConcreteBaseModel:
    """
    Minimal stand-in used to exercise BaseModel.load in isolation.

    NOTE(review): despite the original docstring's wording, this class does
    not subclass BaseModel; it borrows the unbound ``load`` method and binds
    it to the instance, so only ``load``'s own logic is under test here.
    """

    def __init__(self, synthex: Synthex):
        # Deferred import keeps module import light at pytest collection time.
        # ``synthex`` mirrors BaseModel's constructor signature; it is not
        # read by ``load`` itself, so it is intentionally not stored.
        from artifex.models import BaseModel
        # Bind BaseModel.load to this instance via the descriptor protocol,
        # producing a bound method exactly as attribute lookup on a real
        # subclass instance would.
        self.load = BaseModel.load.__get__(self, ConcreteBaseModel)
        # No-op loader; tests replace it with mocker.patch.object as needed.
        self._load_model = lambda model_path: None  # Mock implementation
17 |
18 |
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Supply a Synthex double whose interface matches the real client.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    stub = mocker.MagicMock(spec=Synthex)
    return stub
30 |
31 |
@pytest.fixture
def concrete_model(mock_synthex: Synthex) -> ConcreteBaseModel:
    """
    Supply a fresh ConcreteBaseModel wrapping the mocked Synthex client.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
    Returns:
        ConcreteBaseModel: A concrete implementation of BaseModel.
    """

    model = ConcreteBaseModel(mock_synthex)
    return model
43 |
44 |
@pytest.mark.unit
def test_load_with_valid_path(concrete_model: ConcreteBaseModel, mocker: MockerFixture):
    """
    Verify that load() delegates to _load_model once every path check passes.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/fake/model/path"

    # Every filesystem existence check succeeds.
    mocker.patch("os.path.exists", return_value=True)
    load_model_mock = mocker.patch.object(concrete_model, "_load_model")

    concrete_model.load(model_path)

    load_model_mock.assert_called_once_with(model_path)
66 |
67 |
@pytest.mark.unit
def test_load_raises_error_when_path_does_not_exist(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() raises OSError for a missing model directory.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/nonexistent/path"

    # Every existence check fails, including the directory itself.
    mocker.patch("os.path.exists", return_value=False)

    with pytest.raises(OSError) as excinfo:
        concrete_model.load(model_path)

    expected_fragment = f"The specified model path '{model_path}' does not exist"
    assert expected_fragment in str(excinfo.value)
88 |
89 |
@pytest.mark.unit
def test_load_raises_error_when_config_json_missing(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() raises OSError when config.json is absent from an
    otherwise present model directory.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/fake/model/path"
    config_path = os.path.join(model_path, "config.json")

    # Everything exists except config.json.
    mocker.patch("os.path.exists", side_effect=lambda path: path != config_path)

    with pytest.raises(OSError) as excinfo:
        concrete_model.load(model_path)

    assert "missing the required file 'config.json'" in str(excinfo.value)
117 |
118 |
@pytest.mark.unit
def test_load_raises_error_when_model_safetensors_missing(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() raises OSError when model.safetensors is absent even
    though the directory and config.json exist.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/fake/model/path"
    weights_path = os.path.join(model_path, "model.safetensors")

    # Everything exists except model.safetensors.
    mocker.patch("os.path.exists", side_effect=lambda path: path != weights_path)

    with pytest.raises(OSError) as excinfo:
        concrete_model.load(model_path)

    assert "missing the required file 'model.safetensors'" in str(excinfo.value)
148 |
149 |
@pytest.mark.unit
def test_load_checks_all_required_files(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() probes for both config.json and model.safetensors.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/fake/model/path"

    # Record every path the loader probes while reporting it as present.
    probed_paths: list[str] = []

    def record_and_pass(path: str) -> bool:
        probed_paths.append(path)
        return True

    mocker.patch("os.path.exists", side_effect=record_and_pass)
    mocker.patch.object(concrete_model, "_load_model")

    concrete_model.load(model_path)

    assert os.path.join(model_path, "config.json") in probed_paths
    assert os.path.join(model_path, "model.safetensors") in probed_paths
178 |
179 |
@pytest.mark.unit
def test_load_validation_failure_with_non_string_path(concrete_model: ConcreteBaseModel):
    """
    Verify that load() rejects a non-string model_path with ValidationError.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
    """

    from artifex.core import ValidationError

    bad_path = 123
    with pytest.raises(ValidationError):
        concrete_model.load(bad_path)
192 |
193 |
@pytest.mark.unit
def test_load_validation_failure_with_none_path(concrete_model: ConcreteBaseModel):
    """
    Verify that load() rejects a None model_path with ValidationError.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
    """

    from artifex.core import ValidationError

    bad_path = None
    with pytest.raises(ValidationError):
        concrete_model.load(bad_path)
206 |
207 |
@pytest.mark.unit
def test_load_with_relative_path(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() accepts a relative path and forwards it untouched.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    relative_path = "./models/my_model"

    mocker.patch("os.path.exists", return_value=True)
    load_model_mock = mocker.patch.object(concrete_model, "_load_model")

    concrete_model.load(relative_path)

    load_model_mock.assert_called_once_with(relative_path)
227 |
228 |
@pytest.mark.unit
def test_load_with_absolute_path(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() accepts an absolute path and forwards it untouched.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    absolute_path = "/home/user/models/my_model"

    mocker.patch("os.path.exists", return_value=True)
    load_model_mock = mocker.patch.object(concrete_model, "_load_model")

    concrete_model.load(absolute_path)

    load_model_mock.assert_called_once_with(absolute_path)
248 |
249 |
@pytest.mark.unit
def test_load_with_path_containing_spaces(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that load() handles paths containing spaces without mangling them.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    spaced_path = "/path/with spaces/my model"

    mocker.patch("os.path.exists", return_value=True)
    load_model_mock = mocker.patch.object(concrete_model, "_load_model")

    concrete_model.load(spaced_path)

    load_model_mock.assert_called_once_with(spaced_path)
269 |
270 |
@pytest.mark.unit
def test_load_error_message_contains_path(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that the OSError raised for a missing directory embeds the
    offending path, so failures are easy to diagnose.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/custom/model/path"

    mocker.patch("os.path.exists", return_value=False)

    with pytest.raises(OSError) as excinfo:
        concrete_model.load(model_path)

    assert model_path in str(excinfo.value)
290 |
291 |
@pytest.mark.unit
def test_load_calls_load_model_only_after_validation(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that _load_model is never invoked when a required file is missing.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    model_path = "/fake/model/path"
    config_path = os.path.join(model_path, "config.json")

    # Simulate a directory that exists but lacks config.json.
    mocker.patch("os.path.exists", side_effect=lambda path: path != config_path)
    load_model_mock = mocker.patch.object(concrete_model, "_load_model")

    with pytest.raises(OSError):
        concrete_model.load(model_path)

    # Validation failed, so the loader must never have run.
    load_model_mock.assert_not_called()
321 |
322 |
@pytest.mark.unit
def test_load_with_empty_string_path(
    concrete_model: ConcreteBaseModel, mocker: MockerFixture
):
    """
    Verify that an empty-string path is treated as nonexistent and raises.
    Args:
        concrete_model (ConcreteBaseModel): The concrete BaseModel instance.
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    empty_path = ""

    mocker.patch("os.path.exists", return_value=False)

    with pytest.raises(OSError) as excinfo:
        concrete_model.load(empty_path)

    assert "does not exist" in str(excinfo.value)