├── artifex ├── py.typed ├── models │ ├── reranker │ │ └── __init__.py │ ├── classification │ │ ├── binary_classification │ │ │ ├── __init__.py │ │ │ └── guardrail │ │ │ │ └── __init__.py │ │ ├── multi_class_classification │ │ │ ├── __init__.py │ │ │ ├── intent_classifier │ │ │ │ └── __init__.py │ │ │ ├── emotion_detection │ │ │ │ └── __init__.py │ │ │ └── sentiment_analysis │ │ │ │ └── __init__.py │ │ └── __init__.py │ ├── named_entity_recognition │ │ ├── __init__.py │ │ └── text_anonymization │ │ │ └── __init__.py │ └── __init__.py ├── core │ ├── __init__.py │ ├── exceptions.py │ ├── models.py │ ├── decorators.py │ └── _hf_patches.py ├── utils.py ├── config.py └── __init__.py ├── tests ├── __init__.py ├── unit │ ├── __init__.py │ ├── guardrail │ │ ├── test_gr__init__.py │ │ ├── test_gr_get_data_gen_instr.py │ │ └── test_gr_train.py │ ├── emotion_detection │ │ └── test_ed__init__.py │ ├── intent_classifier │ │ └── test_ic__init__.py │ ├── sentiment_analysis │ │ └── test_sa__init__.py │ ├── classification_model │ │ ├── test_cm__init__.py │ │ ├── test_cm_load_model.py │ │ └── test_cm_get_data_gen_instr.py │ ├── core │ │ └── decorators │ │ │ ├── test_should_skip_method.py │ │ │ └── test_auto_validate_methods.py │ ├── base_model │ │ ├── test_bm_sanitize_output_path.py │ │ └── test_base_model_load.py │ ├── named_entity_recognition │ │ ├── test_ner_parse_user_instructions.py │ │ ├── test_ner_post_process_synthetic_dataset.py │ │ └── test_ner_get_data_gen_instr.py │ ├── reranker │ │ ├── test_rr_get_data_gen_instr.py │ │ ├── test_rr_parse_user_instructions.py │ │ └── test_rr__call__.py │ └── text_anonymization │ │ ├── test_ta__call__.py │ │ └── test_ta_train.py ├── integration │ ├── __init__.py │ ├── reranker │ │ ├── test_rr_train_intgr.py │ │ └── test_rr__call__intgr.py │ ├── text_classification │ │ ├── test_tc_train_intgr.py │ │ └── test_tc__call__intgr.py │ ├── guardrail │ │ ├── test_gr_train_intgr.py │ │ └── test_gr__call__intgr.py │ ├── text_anonymization │ │ ├── 
test_ta_train_intgr.py │ │ └── test_ta__call__intgr.py │ ├── intent_classifier │ │ ├── test_ic_train_intgr.py │ │ └── test_ic__call__intgr.py │ ├── emotion_detection │ │ ├── test_ed_train_intgr.py │ │ └── test_ed__call__intgr.py │ ├── named_entity_recognition │ │ ├── test_ner_train_intgr.py │ │ └── test_ner__call__intgr.py │ └── sentiment_analysis │ │ ├── test_sa__call__.py │ │ └── test_sa_train_intgr.py └── conftest.py ├── .github ├── CODEOWNERS ├── pull_request_template.md ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── python-publish.yml ├── setup.py ├── assets ├── hero.png ├── banner.png └── experiment.png ├── pytest.ini ├── .gitignore ├── SECURITY.md ├── LICENSE ├── CLA.md ├── pyproject.toml ├── CODE_OF_CONDUCT.md ├── CHANGELOG.md ├── CONTRIBUTING.md └── requirements.txt /artifex/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @rlucatoor -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup() 
-------------------------------------------------------------------------------- /assets/hero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/hero.png -------------------------------------------------------------------------------- /assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/banner.png -------------------------------------------------------------------------------- /assets/experiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanaos/artifex/HEAD/assets/experiment.png -------------------------------------------------------------------------------- /artifex/models/reranker/__init__.py: -------------------------------------------------------------------------------- 1 | from .reranker import Reranker 2 | 3 | __all__ = ["Reranker"] -------------------------------------------------------------------------------- /artifex/models/classification/binary_classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .guardrail import Guardrail 2 | 3 | __all__ = [ 4 | "Guardrail" 5 | ] -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -s 3 | markers = 4 | unit: marks tests as unit (use with -m "unit") 5 | integration: marks tests as integration (use with -m "integration") -------------------------------------------------------------------------------- /artifex/models/named_entity_recognition/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_anonymization import TextAnonymization 2 | from .named_entity_recognition import 
NamedEntityRecognition 3 | 4 | __all__ = [ 5 | "NamedEntityRecognition", 6 | "TextAnonymization" 7 | ] -------------------------------------------------------------------------------- /artifex/models/classification/multi_class_classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .emotion_detection import EmotionDetection 2 | from .intent_classifier import IntentClassifier 3 | from .sentiment_analysis import SentimentAnalysis 4 | 5 | __all__ = [ 6 | "EmotionDetection", 7 | "IntentClassifier", 8 | "SentimentAnalysis" 9 | ] -------------------------------------------------------------------------------- /artifex/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification_model import ClassificationModel 2 | from .binary_classification import Guardrail 3 | from .multi_class_classification import EmotionDetection, IntentClassifier, SentimentAnalysis 4 | 5 | __all__ = [ 6 | "ClassificationModel", 7 | "Guardrail", 8 | "EmotionDetection", 9 | "IntentClassifier", 10 | "SentimentAnalysis", 11 | ] -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Pull Request 2 | 3 | ## What does this PR do? 
4 | 5 | 6 | 7 | ## Checklist 8 | 9 | - [ ] I’ve tested the changes 10 | - [ ] I’ve added or updated docs (if applicable) 11 | - [ ] I’ve added or updated tests (if applicable) 12 | 13 | ## Related Issues 14 | 15 | 16 | 17 | ## Additional Notes 18 | 19 | 20 | -------------------------------------------------------------------------------- /tests/integration/reranker/test_rr_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `Reranker` class. 13 | Args: 14 | artifex (Artifex): The Artifex instance to be used for testing. 15 | """ 16 | 17 | artifex.reranker.train( 18 | domain="test domain", 19 | num_samples=40, 20 | num_epochs=1, 21 | output_path=output_folder 22 | ) -------------------------------------------------------------------------------- /artifex/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationModel, Guardrail, EmotionDetection, IntentClassifier, \ 2 | SentimentAnalysis 3 | 4 | from .base_model import BaseModel 5 | 6 | from .named_entity_recognition import NamedEntityRecognition, TextAnonymization 7 | 8 | from .reranker import Reranker 9 | 10 | __all__ = [ 11 | "ClassificationModel", 12 | "Guardrail", 13 | "EmotionDetection", 14 | "IntentClassifier", 15 | "SentimentAnalysis", 16 | "BaseModel", 17 | "NamedEntityRecognition", 18 | "TextAnonymization", 19 | "Reranker", 20 | ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Virtual environment 2 | .venv/ 3 | venv/ 4 | .env/ 5 | 6 | # VS Code stuff 7 | .vscode/ 8 | 9 | # Cache 10 | __pycache__/ 11 | .*cache 12 | 13 | # 
Environment variables 14 | .env 15 | 16 | # Tests 17 | test_data/ 18 | .coverage 19 | htmlcov/ 20 | test.py 21 | test.ipynb 22 | .pytest_env_backup 23 | 24 | # Distribution / packaging 25 | build/ 26 | dist/ 27 | *.egg-info/ 28 | 29 | # Version info 30 | __version__.py 31 | 32 | # Debug Mode 33 | .debug 34 | 35 | # Training output 36 | trainer_output/ 37 | artifex_output/ 38 | 39 | # Notebooks 40 | .ipynb_checkpoints/ 41 | notebooks/ -------------------------------------------------------------------------------- /artifex/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .decorators import auto_validate_methods 2 | from .exceptions import ServerError, ValidationError, BadRequestError 3 | from .models import ClassificationResponse, ClassificationClassName, NERTagName, NEREntity, \ 4 | ClassificationInstructions, NERInstructions 5 | 6 | 7 | __all__ = [ 8 | "auto_validate_methods", 9 | "ServerError", 10 | "ValidationError", 11 | "BadRequestError", 12 | "ClassificationResponse", 13 | "ClassificationClassName", 14 | "NERTagName", 15 | "NEREntity", 16 | "ClassificationInstructions", 17 | "NERInstructions" 18 | ] -------------------------------------------------------------------------------- /artifex/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from artifex.config import config 3 | 4 | 5 | def get_model_output_path(output_path: str) -> str: 6 | """ 7 | Get the output path for the trained model based on the provided output path (its parent directory). 8 | """ 9 | 10 | return str(os.path.join(output_path, config.SYNTHEX_OUTPUT_MODEL_FOLDER_NAME)) 11 | 12 | def get_dataset_output_path(output_path: str) -> str: 13 | """ 14 | Get the output path for the dataset based on the provided output path (its parent directory). 
15 |     """ 16 | 17 |     return str(os.path.join(output_path, config.DEFAULT_SYNTHEX_DATASET_NAME)) 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | We currently support the latest published version of this project. 6 | Older versions may not receive security updates unless explicitly stated. 7 | 8 | ## Reporting a Vulnerability 9 | 10 | If you discover a security issue in this project, please do not open a public issue or pull request, as doing so would immediately make the problem visible to potentially malicious actors. 11 | 12 | Instead, report it privately by emailing: 13 | 14 | 📬 **info@tanaos.com** 15 | 16 | Please include as much detail as possible so we can verify and address the issue quickly. 17 | 18 | We appreciate responsible disclosure and will do our best to respond promptly. 
19 | -------------------------------------------------------------------------------- /tests/integration/text_classification/test_tc_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `ClassificationModel` class. Verify that: 13 | - The training process completes without errors. 14 | Args: 15 | artifex (Artifex): The Artifex instance to be used for testing. 16 | """ 17 | 18 | classes = { 19 | "class_a": "Description for class A.", 20 | "class_b": "Description for class B.", 21 | "class_c": "Description for class C." 22 | } 23 | 24 | tc = artifex.text_classification 25 | 26 | tc.train( 27 | domain="test domain", 28 | classes=classes, 29 | num_samples=40, 30 | num_epochs=1, 31 | output_path=output_folder 32 | ) -------------------------------------------------------------------------------- /artifex/core/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class ArtifexError(Exception): 5 | """ 6 | Base exception for all errors raised by the library. 
7 | """ 8 | 9 | def __init__( 10 | self, message: str, details: Optional[str] = None 11 | ): 12 | self.message = message 13 | self.details = details 14 | super().__init__(self.__str__()) 15 | 16 | def __str__(self): 17 | parts = [f"{self.message}"] 18 | if self.details: 19 | parts.append(f"Details: {self.details}") 20 | return " ".join(parts) 21 | 22 | 23 | class BadRequestError(ArtifexError): 24 | """Raised when the API request is malformed or invalid.""" 25 | pass 26 | 27 | class ServerError(ArtifexError): 28 | """Raised when the API server returns a 5xx error.""" 29 | pass 30 | 31 | class ValidationError(ArtifexError): 32 | """Raised when the API returns a validation error.""" 33 | pass -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import shutil 3 | from typing import Generator 4 | 5 | from artifex import Artifex 6 | from artifex.config import config 7 | 8 | 9 | @pytest.fixture(scope="function") 10 | def artifex() -> Artifex: 11 | """ 12 | Creates and returns an instance of the Artifex class using the API key 13 | from the environment variables. 14 | Returns: 15 | Artifex: An instance of the Artifex class initialized with the API key. 16 | """ 17 | 18 | api_key = config.API_KEY 19 | if not api_key: 20 | pytest.fail("API_KEY not found in environment variables") 21 | return Artifex(api_key) 22 | 23 | 24 | @pytest.fixture(scope="function") 25 | def output_folder() -> Generator[str, None, None]: 26 | """ 27 | Provides a temporary output folder path and cleans it up after the test. 
28 | """ 29 | folder_path = "./output_folder/" 30 | yield folder_path 31 | # Cleanup after test 32 | shutil.rmtree(folder_path, ignore_errors=True) -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Riccardo Lucato 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /tests/integration/guardrail/test_gr_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `Guardrail` class. Ensure that: 13 | - The training process completes without errors. 14 | - The output model's id2label mapping is { 0: "safe", 1: "unsafe" }. 15 | - The output model's label2id mapping is { "safe": 0, "unsafe": 1 }. 
16 | Args: 17 | artifex (Artifex): The Artifex instance to be used for testing. 18 | """ 19 | 20 | gr = artifex.guardrail 21 | 22 | gr.train( 23 | unsafe_content=["test instructions"], 24 | num_samples=40, 25 | num_epochs=1, 26 | output_path=output_folder, 27 | ) 28 | 29 | # Verify the model's config mappings 30 | id2label = gr._model.config.id2label 31 | label2id = gr._model.config.label2id 32 | assert id2label == { 0: "safe", 1: "unsafe" } 33 | assert label2id == { "safe": 0, "unsafe": 1 } -------------------------------------------------------------------------------- /tests/unit/guardrail/test_gr__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | from artifex.models import Guardrail 4 | 5 | 6 | def test_guardrail_init(mocker: MockerFixture): 7 | """ 8 | Unit test for Guardrail.__init__. 9 | Args: 10 | mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies. 
11 | """ 12 | 13 | # Mock Synthex 14 | mock_synthex = mocker.Mock() 15 | # Mock config 16 | mock_config = mocker.patch("artifex.models.classification.binary_classification.guardrail.config") 17 | mock_config.GUARDRAIL_HF_BASE_MODEL = "mocked-base-model" 18 | # Mock ClassificationModel.__init__ 19 | mock_super_init = mocker.patch( 20 | "artifex.models.classification.classification_model.ClassificationModel.__init__", 21 | return_value=None 22 | ) 23 | 24 | # Instantiate Guardrail 25 | model = Guardrail(mock_synthex) 26 | 27 | # Assert ClassificationModel.__init__ was called with correct args 28 | mock_super_init.assert_called_once_with(mock_synthex, base_model_name="mocked-base-model") 29 | # Assert _system_data_gen_instr_val is set correctly 30 | assert isinstance(model._system_data_gen_instr_val, list) 31 | assert all(isinstance(item, str) for item in model._system_data_gen_instr_val) -------------------------------------------------------------------------------- /tests/integration/text_anonymization/test_ta_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `TextAnonymization` class. Verify that: 13 | - The training process completes without errors. 14 | - The output model's id2label mapping is the expected one. 15 | - The output model's label2id mapping is the expected one. 16 | Args: 17 | artifex (Artifex): The Artifex instance to be used for testing. 
18 | """ 19 | 20 | named_entities = artifex.text_anonymization._pii_entities 21 | 22 | bio_labels = ["O"] 23 | for name in named_entities.keys(): 24 | bio_labels.extend([f"B-{name}", f"I-{name}"]) 25 | 26 | ta = artifex.text_anonymization 27 | 28 | ta.train( 29 | domain="test domain", 30 | num_samples=40, 31 | num_epochs=1, 32 | output_path=output_folder 33 | ) 34 | 35 | # Verify the model's config mappings 36 | id2label = ta._model.config.id2label 37 | label2id = ta._model.config.label2id 38 | assert id2label == { i: label for i, label in enumerate(bio_labels) } 39 | assert label2id == { label: i for i, label in enumerate(bio_labels) } -------------------------------------------------------------------------------- /tests/integration/intent_classifier/test_ic_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `IntentClassifier` class. Verify that: 13 | - The training process completes without errors. 14 | - The output model's id2label mapping is the expected one. 15 | - The output model's label2id mapping is the expected one. 16 | Args: 17 | artifex (Artifex): The Artifex instance to be used for testing. 18 | """ 19 | 20 | classes = { 21 | "class_a": "Description for class A.", 22 | "class_b": "Description for class B.", 23 | "class_c": "Description for class C." 
24 | } 25 | 26 | ic = artifex.intent_classifier 27 | 28 | ic.train( 29 | domain="test domain", 30 | classes=classes, 31 | num_samples=40, 32 | num_epochs=1, 33 | output_path=output_folder 34 | ) 35 | 36 | # Verify the model's config mappings 37 | id2label = ic._model.config.id2label 38 | label2id = ic._model.config.label2id 39 | assert id2label == { i: label for i, label in enumerate(classes.keys()) } 40 | assert label2id == { label: i for i, label in enumerate(classes.keys()) } -------------------------------------------------------------------------------- /tests/integration/text_classification/test_tc__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import ClassificationResponse 5 | 6 | 7 | @pytest.mark.integration 8 | def test__call__single_input_success( 9 | artifex: Artifex 10 | ): 11 | """ 12 | Test the `__call__` method of the `TextClassification` class when a single input is 13 | provided. Ensure that it returns a list of ClassificationResponse objects. 14 | Args: 15 | artifex (Artifex): The Artifex instance to be used for testing. 16 | """ 17 | 18 | out = artifex.text_classification("test input") 19 | assert isinstance(out, list) 20 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 21 | 22 | 23 | @pytest.mark.integration 24 | def test__call__multiple_inputs_success( 25 | artifex: Artifex 26 | ): 27 | """ 28 | Test the `__call__` method of the `TextClassification` class when multiple inputs are 29 | provided. Ensure that it returns a list of ClassificationResponse objects. 30 | Args: 31 | artifex (Artifex): The Artifex instance to be used for testing. 
32 | """ 33 | 34 | out = artifex.text_classification(["test input 1", "test input 2", "test input 3"]) 35 | assert isinstance(out, list) 36 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 37 | -------------------------------------------------------------------------------- /tests/unit/emotion_detection/test_ed__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | from artifex.models import EmotionDetection 4 | 5 | 6 | def test_emotion_detection_init(mocker: MockerFixture): 7 | """ 8 | Unit test for EmotionDetection.__init__. 9 | Args: 10 | mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies. 11 | """ 12 | 13 | # Mock Synthex 14 | mock_synthex = mocker.Mock() 15 | # Mock config 16 | mock_config = mocker.patch("artifex.models.classification.multi_class_classification.emotion_detection.config") 17 | mock_config.EMOTION_DETECTION_HF_BASE_MODEL = "mocked-base-model" 18 | # Mock ClassificationModel.__init__ 19 | mock_super_init = mocker.patch( 20 | "artifex.models.classification.classification_model.ClassificationModel.__init__", 21 | return_value=None 22 | ) 23 | 24 | # Instantiate EmotionDetection 25 | model = EmotionDetection(mock_synthex) 26 | 27 | # Assert ClassificationModel.__init__ was called with correct args 28 | mock_super_init.assert_called_once_with(mock_synthex, base_model_name="mocked-base-model") 29 | # Assert _system_data_gen_instr_val is set correctly 30 | assert isinstance(model._system_data_gen_instr_val, list) 31 | assert all(isinstance(item, str) for item in model._system_data_gen_instr_val) -------------------------------------------------------------------------------- /tests/unit/intent_classifier/test_ic__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | from artifex.models import 
IntentClassifier 4 | 5 | 6 | def test_intent_classifier_init(mocker: MockerFixture): 7 | """ 8 | Unit test for IntentClassifier.__init__. 9 | Args: 10 | mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies. 11 | """ 12 | 13 | # Mock Synthex 14 | mock_synthex = mocker.Mock() 15 | # Mock config 16 | mock_config = mocker.patch("artifex.models.classification.multi_class_classification.intent_classifier.config") 17 | mock_config.INTENT_CLASSIFIER_HF_BASE_MODEL = "mocked-base-model" 18 | # Mock ClassificationModel.__init__ 19 | mock_super_init = mocker.patch( 20 | "artifex.models.classification.classification_model.ClassificationModel.__init__", 21 | return_value=None 22 | ) 23 | 24 | # Instantiate IntentClassifier 25 | model = IntentClassifier(mock_synthex) 26 | 27 | # Assert ClassificationModel.__init__ was called with correct args 28 | mock_super_init.assert_called_once_with(mock_synthex, base_model_name="mocked-base-model") 29 | # Assert _system_data_gen_instr_val is set correctly 30 | assert isinstance(model._system_data_gen_instr_val, list) 31 | assert all(isinstance(item, str) for item in model._system_data_gen_instr_val) -------------------------------------------------------------------------------- /tests/unit/sentiment_analysis/test_sa__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | from artifex.models import SentimentAnalysis 4 | 5 | 6 | def test_sentiment_analysis_init(mocker: MockerFixture): 7 | """ 8 | Unit test for SentimentAnalysis.__init__. 9 | Args: 10 | mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies. 
11 | """ 12 | 13 | # Mock Synthex 14 | mock_synthex = mocker.Mock() 15 | # Mock config 16 | mock_config = mocker.patch("artifex.models.classification.multi_class_classification.sentiment_analysis.config") 17 | mock_config.SENTIMENT_ANALYSIS_HF_BASE_MODEL = "mocked-base-model" 18 | # Mock ClassificationModel.__init__ 19 | mock_super_init = mocker.patch( 20 | "artifex.models.classification.classification_model.ClassificationModel.__init__", 21 | return_value=None 22 | ) 23 | 24 | # Instantiate SentimentAnalysis 25 | model = SentimentAnalysis(mock_synthex) 26 | 27 | # Assert ClassificationModel.__init__ was called with correct args 28 | mock_super_init.assert_called_once_with(mock_synthex, base_model_name="mocked-base-model") 29 | # Assert _system_data_gen_instr_val is set correctly 30 | assert isinstance(model._system_data_gen_instr_val, list) 31 | assert all(isinstance(item, str) for item in model._system_data_gen_instr_val) -------------------------------------------------------------------------------- /CLA.md: -------------------------------------------------------------------------------- 1 | Contributor License Agreement (CLA) 2 | 3 | Thank you for your interest in contributing to Artifex. 4 | 5 | By signing this CLA (electronically via GitHub or other integrated tools), you agree to the following terms for your contributions (past, present, and future): 6 | 7 | 1. **Originality**: 8 | You confirm that your contributions are your own work, or that you have the right to submit them. 9 | 10 | 2. **License**: 11 | You grant Riccardo Lucato an irrevocable, worldwide, non-exclusive, royalty-free license to use, reproduce, modify, distribute, and sublicense your contributions under the MIT License (or any other license used by the project). 12 | 13 | 3. **Employment and third-party rights**: 14 | You confirm that your contributions are not subject to any third-party rights, including those of your employer or clients. 
If such rights could exist, you confirm that you have obtained a valid and sufficient waiver or permission to contribute the work under this CLA and the project’s open source license. You agree that neither your employer nor any third party will have any claim or rights over your contributions to this project. 15 | 16 | 4. **Disclaimer**: 17 | Contributions are provided "as-is" without warranties or conditions of any kind. 18 | 19 | By electronically signing this CLA, you accept these terms and certify that you have the legal authority to do so. -------------------------------------------------------------------------------- /tests/integration/emotion_detection/test_ed_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import shutil 3 | 4 | from artifex import Artifex 5 | 6 | 7 | @pytest.mark.integration 8 | def test_train_success( 9 | artifex: Artifex, 10 | output_folder: str 11 | ): 12 | """ 13 | Test the `train` method of the `EmotionDetection` class. Ensure that: 14 | - The training process completes without errors. 15 | - The output model's id2label mapping is the expected one. 16 | - The output model's label2id mapping is the expected one. 17 | Args: 18 | artifex (Artifex): The Artifex instance to be used for testing. 
19 | """ 20 | 21 | ed = artifex.emotion_detection 22 | 23 | try: 24 | ed.train( 25 | domain="test domain", 26 | classes={ 27 | "happy": "text expressing happiness", 28 | "sad": "text expressing sadness" 29 | }, 30 | num_samples=40, 31 | num_epochs=1, 32 | output_path=output_folder 33 | ) 34 | 35 | # Verify the model's config mappings 36 | id2label = ed._model.config.id2label 37 | label2id = ed._model.config.label2id 38 | assert id2label == { 0: "happy", 1: "sad" } 39 | assert label2id == { "happy": 0, "sad": 1 } 40 | finally: 41 | # Clean up the output folder 42 | shutil.rmtree(output_folder, ignore_errors=True) 43 | 44 | -------------------------------------------------------------------------------- /tests/integration/named_entity_recognition/test_ner_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `NamedEntityRecognition` class. Verify that: 13 | - The training process completes without errors. 14 | - The output model's id2label mapping is the expected one. 15 | - The output model's label2id mapping is the expected one. 16 | Args: 17 | artifex (Artifex): The Artifex instance to be used for testing. 
18 | """ 19 | 20 | named_entities = { 21 | "ENTITY_A": "Description for entity A.", 22 | "ENTITY_B": "Description for entity B.", 23 | } 24 | 25 | bio_labels = ["O"] 26 | for name in named_entities.keys(): 27 | bio_labels.extend([f"B-{name}", f"I-{name}"]) 28 | 29 | ner = artifex.named_entity_recognition 30 | 31 | ner.train( 32 | domain="test domain", 33 | named_entities=named_entities, 34 | num_samples=40, 35 | num_epochs=1, 36 | output_path=output_folder 37 | ) 38 | 39 | # Verify the model's config mappings 40 | id2label = ner._model.config.id2label 41 | label2id = ner._model.config.label2id 42 | assert id2label == { i: label for i, label in enumerate(bio_labels) } 43 | assert label2id == { label: i for i, label in enumerate(bio_labels) } -------------------------------------------------------------------------------- /artifex/models/classification/multi_class_classification/intent_classifier/__init__.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | 3 | from ...classification_model import ClassificationModel 4 | 5 | from artifex.core import auto_validate_methods 6 | from artifex.config import config 7 | 8 | 9 | @auto_validate_methods 10 | class IntentClassifier(ClassificationModel): 11 | """ 12 | An Intent Classifier Model for LLMs. This model is used to classify a text's intent or objective into 13 | predefined categories. 14 | """ 15 | 16 | def __init__(self, synthex: Synthex): 17 | """ 18 | Initializes the class with a Synthex instance. 19 | Args: 20 | synthex (Synthex): An instance of the Synthex class to generate the synthetic data used to train the model. 
21 | """ 22 | 23 | super().__init__(synthex, base_model_name=config.INTENT_CLASSIFIER_HF_BASE_MODEL) 24 | self._system_data_gen_instr_val: list[str] = [ 25 | "The 'text' field should contain text that belongs to the following domain(s): {domain}.", 26 | "The 'text' field should contain text that has a specific intent or objective.", 27 | "The 'labels' field should contain a label indicating the intent or objective of the 'text'.", 28 | "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.", 29 | "This is a list of the allowed 'labels' and their meaning: " 30 | ] -------------------------------------------------------------------------------- /artifex/models/classification/multi_class_classification/emotion_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | 3 | from ...classification_model import ClassificationModel 4 | 5 | from artifex.core import auto_validate_methods 6 | from artifex.config import config 7 | 8 | 9 | @auto_validate_methods 10 | class EmotionDetection(ClassificationModel): 11 | """ 12 | An Emotion Detection Model is used to classify text into different emotional categories. In this 13 | implementation, we support the following emotions: `joy`, `anger`, `fear`, `sadness`, `surprise`, `disgust`, 14 | `excitement` and `neutral`. 15 | """ 16 | 17 | def __init__(self, synthex: Synthex): 18 | """ 19 | Initializes the class with a Synthex instance. 20 | Args: 21 | synthex (Synthex): An instance of the Synthex class to generate the synthetic 22 | data used to train the model. 
23 | """ 24 | 25 | super().__init__(synthex, base_model_name=config.EMOTION_DETECTION_HF_BASE_MODEL) 26 | self._system_data_gen_instr_val: list[str] = [ 27 | "The 'text' field should contain text that belongs to the following domain(s): {domain}.", 28 | "The 'text' field should contain text that may or may not express a certain emotion.", 29 | "The 'labels' field should contain a label indicating the emotion of the 'text'.", 30 | "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.", 31 | "This is a list of the allowed 'labels' their meaning: " 32 | ] -------------------------------------------------------------------------------- /tests/integration/guardrail/test_gr__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import ClassificationResponse 5 | 6 | 7 | expected_labels = ["safe", "unsafe"] 8 | 9 | @pytest.mark.integration 10 | def test__call__single_input_success( 11 | artifex: Artifex 12 | ): 13 | """ 14 | Test the `__call__` method of the `Guardrail` class when a single input is provided. 15 | Ensure that: 16 | - It returns a list of ClassificationResponse objects. 17 | - The output labels are among the expected intent labels. 18 | Args: 19 | artifex (Artifex): The Artifex instance to be used for testing. 20 | """ 21 | 22 | out = artifex.guardrail("test input") 23 | assert isinstance(out, list) 24 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 25 | assert all(resp.label in expected_labels for resp in out) 26 | 27 | @pytest.mark.integration 28 | def test__call__multiple_inputs_success( 29 | artifex: Artifex 30 | ): 31 | """ 32 | Test the `__call__` method of the `Guardrail` class when multiple inputs are provided. 33 | Ensure that: 34 | - It returns a list of ClassificationResponse objects. 35 | - The output labels are among the expected intent labels. 
36 | Args: 37 | artifex (Artifex): The Artifex instance to be used for testing. 38 | """ 39 | 40 | out = artifex.guardrail(["test input 1", "test input 2", "test input 3"]) 41 | assert isinstance(out, list) 42 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 43 | assert all(resp.label in expected_labels for resp in out) -------------------------------------------------------------------------------- /tests/integration/sentiment_analysis/test_sa__call__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import ClassificationResponse 5 | 6 | 7 | expected_labels = ["very_positive", "positive", "negative", "very_negative", "neutral"] 8 | 9 | 10 | @pytest.mark.integration 11 | def test__call__single_input_success( 12 | artifex: Artifex 13 | ): 14 | """ 15 | Test the `__call__` method of the `SentimentAnalysis` class. Ensure that: 16 | - It returns a list of ClassificationResponse objects. 17 | - The output labels are among the expected sentiment labels. 18 | Args: 19 | artifex (Artifex): The Artifex instance to be used for testing. 20 | """ 21 | 22 | out = artifex.sentiment_analysis("test input") 23 | assert isinstance(out, list) 24 | assert all(resp.label in expected_labels for resp in out) 25 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 26 | 27 | @pytest.mark.integration 28 | def test__call__multiple_inputs_success( 29 | artifex: Artifex 30 | ): 31 | """ 32 | Test the `__call__` method of the `SentimentAnalysis` class, when multiple inputs are 33 | provided. Ensure that: 34 | - It returns a list of ClassificationResponse objects. 35 | - The output labels are among the expected sentiment labels. 36 | Args: 37 | artifex (Artifex): The Artifex instance to be used for testing. 
38 | """ 39 | 40 | out = artifex.sentiment_analysis(["test input 1", "test input 2", "test input 3"]) 41 | assert isinstance(out, list) 42 | assert all(resp.label in expected_labels for resp in out) 43 | assert all(isinstance(resp, ClassificationResponse) for resp in out) -------------------------------------------------------------------------------- /tests/integration/sentiment_analysis/test_sa_train_intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test_train_success( 8 | artifex: Artifex, 9 | output_folder: str 10 | ): 11 | """ 12 | Test the `train` method of the `SentimentAnalysisModel` class. Ensure that: 13 | - The training process completes without errors. 14 | - The output model's id2label mapping is { 0: "very_negative", 1: "negative", 2: "neutral", 3: "positive", 4: "very_positive" }. 15 | - The output model's label2id mapping is { "very_negative": 0, "negative": 1, "neutral": 2, "positive": 3, "very_positive": 4 }. 16 | Args: 17 | artifex (Artifex): The Artifex instance to be used for testing. 
18 | """ 19 | 20 | sa = artifex.sentiment_analysis 21 | 22 | sa.train( 23 | domain="general", 24 | classes={ 25 | "very_negative": "Text expressing a very negative sentiment.", 26 | "negative": "Text expressing a negative sentiment.", 27 | "neutral": "Text expressing a neutral sentiment.", 28 | "positive": "Text expressing a positive sentiment.", 29 | "very_positive": "Text expressing a very positive sentiment.", 30 | }, 31 | num_samples=40, 32 | num_epochs=1, 33 | output_path=output_folder 34 | ) 35 | 36 | # Verify the model's config mappings 37 | id2label = sa._model.config.id2label 38 | label2id = sa._model.config.label2id 39 | assert id2label == { 0: "very_negative", 1: "negative", 2: "neutral", 3: "positive", 4: "very_positive" } 40 | assert label2id == { "very_negative": 0, "negative": 1, "neutral": 2, "positive": 3, "very_positive": 4 } -------------------------------------------------------------------------------- /tests/integration/emotion_detection/test_ed__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import ClassificationResponse 5 | 6 | 7 | expected_labels = ["joy", "anger", "fear", "sadness", "surprise", "disgust", "excitement", "neutral"] 8 | 9 | @pytest.mark.integration 10 | def test__call__single_input_success( 11 | artifex: Artifex 12 | ): 13 | """ 14 | Test the `__call__` method of the `EmotionDetection` class when a single input is provided. 15 | Ensure that: 16 | - It returns a list of ClassificationResponse objects. 17 | - The output labels are among the expected intent labels. 18 | Args: 19 | artifex (Artifex): The Artifex instance to be used for testing. 
20 | """ 21 | 22 | out = artifex.emotion_detection("test input") 23 | assert isinstance(out, list) 24 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 25 | assert all(resp.label in expected_labels for resp in out) 26 | 27 | @pytest.mark.integration 28 | def test__call__multiple_inputs_success( 29 | artifex: Artifex 30 | ): 31 | """ 32 | Test the `__call__` method of the `EmotionDetection` class when multiple inputs are provided. 33 | Ensure that: 34 | - It returns a list of ClassificationResponse objects. 35 | - The output labels are among the expected intent labels. 36 | Args: 37 | artifex (Artifex): The Artifex instance to be used for testing. 38 | """ 39 | 40 | out = artifex.emotion_detection(["test input 1", "test input 2", "test input 3"]) 41 | assert isinstance(out, list) 42 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 43 | assert all(resp.label in expected_labels for resp in out) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "setuptools_scm", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "artifex" 7 | dynamic = ["version"] 8 | description = "Create your private AI model with no training data or GPUs 🤖🚀." 
9 | authors = [ 10 | { name = "Riccardo Lucato", email = "riccardo@tanaos.com" }, 11 | { name = "Saurabh Pradhan" } 12 | ] 13 | license = { text = "MIT" } 14 | readme = "README.md" 15 | requires-python = ">=3.10" 16 | 17 | dependencies = [ 18 | "aiohttp>=3.12.14", 19 | "datasets>=3.6.0", 20 | "filelock>=3.20.1", 21 | "jupyterlab>=4.4.8", 22 | "protobuf>=6.33.2", 23 | "rich>=14.1.0", 24 | "sentencepiece>=0.2.1", 25 | "synthex>=0.4.2", 26 | "tiktoken>=0.12.0", 27 | "torch>=2.8.0", 28 | "transformers[torch]>=4.53.1", 29 | "tzlocal>=5.3.1", 30 | "urllib3>=2.6.0", 31 | ] 32 | 33 | classifiers = [ 34 | "Development Status :: 3 - Alpha", 35 | "Intended Audience :: Developers", 36 | "License :: OSI Approved :: MIT License", 37 | "Programming Language :: Python :: 3.10", 38 | "Topic :: Software Development :: Libraries :: Python Modules", 39 | "Operating System :: OS Independent", 40 | ] 41 | 42 | [project.urls] 43 | homepage = "https://github.com/tanaos/artifex" 44 | 45 | [tool.setuptools] 46 | include-package-data = true 47 | zip-safe = false 48 | 49 | [tool.setuptools.packages.find] 50 | where = ["."] 51 | 52 | [tool.setuptools.package-data] 53 | artifex = ["py.typed"] 54 | 55 | [tool.setuptools_scm] 56 | version_file = "artifex/__version__.py" 57 | 58 | [dependency-groups] 59 | dev = [ 60 | "ipykernel>=6.16.2", 61 | "notebook>=6.5.7", 62 | "pytest>=8.4.1", 63 | "pytest-mock>=3.14.1", 64 | ] 65 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Purpose 4 | 5 | This project is open to contributions from anyone interested. To keep things productive and collaborative, we ask that everyone follows a few basic guidelines when participating. 
6 | 7 | ## How to Participate 8 | 9 | You’re welcome to: 10 | 11 | - Submit issues, suggestions, and improvements 12 | - Contribute code via pull requests 13 | - Offer constructive feedback or ask questions 14 | - Discuss ideas, implementation details, and priorities 15 | 16 | If you disagree with something, that’s fine — just explain your reasoning and move on if consensus isn’t reached. 17 | 18 | ## Please Avoid 19 | 20 | We want to avoid things that create unnecessary friction, such as: 21 | 22 | - Personal attacks 23 | - Starting arguments for the sake of it 24 | - Name-calling or intentionally rude behavior 25 | - Taking disagreements too far or making them personal 26 | 27 | ## Enforcement 28 | 29 | Instances of disruptive, hostile or otherwise unacceptable behavior may be reported by contacting the project maintainers at info@tanaos.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 30 | 31 | Contributors who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by members of the project's leadership. 32 | 33 | ## Final Note 34 | 35 | This is an engineering project. We value clarity, logic, and constructive feedback. Treat others the way you'd expect to be treated in a professional collaboration. 
36 | -------------------------------------------------------------------------------- /tests/integration/reranker/test_rr__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | 5 | 6 | @pytest.mark.integration 7 | def test__call__single_input_success( 8 | artifex: Artifex 9 | ): 10 | """ 11 | Test the `__call__` method of the `Reranker` class when a single input is provided. 12 | Ensure that: 13 | - The return type is list[tuple[str, float]]. 14 | - The output tuples only contain the provided input documents 15 | Args: 16 | artifex (Artifex): The Artifex instance to be used for testing. 17 | """ 18 | 19 | input_doc = "doc1" 20 | 21 | out = artifex.reranker(query="test query", documents=input_doc) 22 | assert isinstance(out, list) 23 | assert all( 24 | isinstance(resp, tuple) and 25 | isinstance(resp[0], str) and 26 | isinstance(resp[1], float) 27 | for resp in out 28 | ) 29 | assert all(resp[0] in [input_doc] for resp in out) 30 | 31 | @pytest.mark.integration 32 | def test__call__multiple_inputs_success( 33 | artifex: Artifex 34 | ): 35 | """ 36 | Test the `__call__` method of the `Reranker` class when multiple inputs are provided. 37 | Ensure that: 38 | - The return type is list[tuple[str, float]]. 39 | - The output tuples only contain the provided input documents. 40 | Args: 41 | artifex (Artifex): The Artifex instance to be used for testing. 
42 | """ 43 | 44 | input_docs = ["doc1", "doc2", "doc3"] 45 | 46 | out = artifex.reranker(query="test query", documents=input_docs) 47 | assert isinstance(out, list) 48 | assert all( 49 | isinstance(resp, tuple) and 50 | isinstance(resp[0], str) and 51 | isinstance(resp[1], float) 52 | for resp in out 53 | ) 54 | assert all(resp[0] in input_docs for resp in out) -------------------------------------------------------------------------------- /tests/integration/intent_classifier/test_ic__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import ClassificationResponse 5 | 6 | 7 | expected_labels = [ 8 | "greeting", "farewell", "thank_you", "affirmation", "negation", "small_talk", 9 | "bot_capabilities", "feedback_positive", "feedback_negative", "clarification", 10 | "suggestion", "language_change" 11 | ] 12 | 13 | @pytest.mark.integration 14 | def test__call__single_input_success( 15 | artifex: Artifex 16 | ): 17 | """ 18 | Test the `__call__` method of the `IntentClassifier` class when a single input is 19 | provided. Ensure that: 20 | - It returns a list of ClassificationResponse objects. 21 | - The output labels are among the expected intent labels. 22 | Args: 23 | artifex (Artifex): The Artifex instance to be used for testing. 24 | """ 25 | 26 | out = artifex.intent_classifier("test input") 27 | assert isinstance(out, list) 28 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 29 | assert all(resp.label in expected_labels for resp in out) 30 | 31 | @pytest.mark.integration 32 | def test__call__multiple_inputs_success( 33 | artifex: Artifex 34 | ): 35 | """ 36 | Test the `__call__` method of the `IntentClassifier` class when multiple inputs are 37 | provided. Ensure that: 38 | - It returns a list of ClassificationResponse objects. 39 | - The output labels are among the expected intent labels. 
40 | Args: 41 | artifex (Artifex): The Artifex instance to be used for testing. 42 | """ 43 | 44 | out = artifex.intent_classifier(["test input 1", "test input 2", "test input 3"]) 45 | assert isinstance(out, list) 46 | assert all(isinstance(resp, ClassificationResponse) for resp in out) 47 | assert all(resp.label in expected_labels for resp in out) -------------------------------------------------------------------------------- /artifex/core/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from artifex.config import config 4 | 5 | 6 | class ClassificationResponse(BaseModel): 7 | label: str 8 | score: float 9 | 10 | class NEREntity(BaseModel): 11 | entity_group: str 12 | word: str 13 | score: float 14 | start: int 15 | end: int 16 | 17 | class ClassificationClassName(str): 18 | """ 19 | A string subclass that enforces a maximum length and disallows spaces for classification 20 | class names. 21 | """ 22 | 23 | max_length = config.CLASSIFICATION_CLASS_NAME_MAX_LENGTH 24 | 25 | def __new__(cls, value: str): 26 | if not value: 27 | raise ValueError("ClassName must be a non-empty string") 28 | if len(value) > cls.max_length: 29 | raise ValueError(f"ClassName exceeds max length of {cls.max_length}") 30 | if ' ' in value: 31 | raise ValueError("ClassName must not contain spaces") 32 | return str.__new__(cls, value) 33 | 34 | class NERTagName(str): 35 | """ 36 | A string subclass that enforces a maximum length, requires the string to be all caps and 37 | disallows spaces for NER tag names. 
38 | """ 39 | 40 | max_length = config.NER_TAGNAME_MAX_LENGTH 41 | 42 | def __new__(cls, value: str): 43 | if not value: 44 | raise ValueError("NERTagName must be a non-empty string") 45 | if len(value) > cls.max_length: 46 | raise ValueError(f"NERTagName exceeds max length of {cls.max_length}") 47 | if ' ' in value: 48 | raise ValueError("NERTagName must not contain spaces") 49 | return str.__new__(cls, value.upper()) 50 | 51 | NClassClassificationClassesDesc = dict[str, str] 52 | 53 | class ClassificationInstructions(BaseModel): 54 | classes: NClassClassificationClassesDesc 55 | domain: str 56 | 57 | NERTags = dict[str, str] 58 | 59 | class NERInstructions(BaseModel): 60 | named_entity_tags: NERTags 61 | domain: str -------------------------------------------------------------------------------- /tests/integration/named_entity_recognition/test_ner__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.core import NEREntity 5 | 6 | 7 | expected_labels = [ 8 | "O", "PERSON", "ORG", "LOCATION", "DATE", "TIME", "PERCENT", "NUMBER", "FACILITY", 9 | "PRODUCT", "WORK_OF_ART", "LANGUAGE", "NORP", "ADDRESS", "PHONE_NUMBER" 10 | ] 11 | 12 | 13 | @pytest.mark.integration 14 | def test__call__single_input_success( 15 | artifex: Artifex 16 | ): 17 | """ 18 | Test the `__call__` method of the `NamedEntityRecognition` class when a single input is 19 | provided. Ensure that: 20 | - It returns a list of list of NEREntity objects. 21 | - The output labels are among the expected named entity labels. 22 | Args: 23 | artifex (Artifex): The Artifex instance to be used for testing. 
24 | """ 25 | 26 | out = artifex.named_entity_recognition("His name is John Doe.") 27 | assert isinstance(out, list) 28 | assert all(isinstance(resp, list) for resp in out) 29 | assert all(all(isinstance(entity, NEREntity) for entity in resp) for resp in out) 30 | assert all(all(entity.entity_group in expected_labels for entity in resp) for resp in out) 31 | 32 | @pytest.mark.integration 33 | def test__call__multiple_inputs_success( 34 | artifex: Artifex 35 | ): 36 | """ 37 | Test the `__call__` method of the `NamedEntityRecognition` class when multiple inputs are 38 | provided. Ensure that: 39 | - It returns a list of list of NEREntity objects. 40 | - The output labels are among the expected named entity labels. 41 | Args: 42 | artifex (Artifex): The Artifex instance to be used for testing. 43 | """ 44 | 45 | out = artifex.named_entity_recognition(["His name is John Does", "His name is Jane Smith."]) 46 | assert isinstance(out, list) 47 | assert all(isinstance(resp, list) for resp in out) 48 | assert all(all(isinstance(entity, NEREntity) for entity in resp) for resp in out) 49 | assert all(all(entity.entity_group in expected_labels for entity in resp) for resp in out) -------------------------------------------------------------------------------- /tests/integration/text_anonymization/test_ta__call__intgr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from artifex import Artifex 4 | from artifex.config import config 5 | 6 | 7 | @pytest.mark.integration 8 | def test__call__single_input_success( 9 | artifex: Artifex 10 | ): 11 | """ 12 | Test the `__call__` method of the `TextAnonymization` class. Ensure that: 13 | - It returns a list of strings. 14 | - All returned strings are identical to the input strings, except for masked entities. 15 | Args: 16 | artifex (Artifex): The Artifex instance to be used for testing. 17 | """ 18 | 19 | input = "Jonathan Rogers lives in New York City." 
20 | out = artifex.text_anonymization(input) 21 | assert isinstance(out, list) 22 | assert all(isinstance(item, str) for item in out) 23 | for idx, text in enumerate(out): 24 | split_input = text.split(" ") 25 | split_out = out[idx].split(" ") 26 | assert len(split_input) == len(split_out) 27 | assert all( 28 | word in split_input or word == config.DEFAULT_TEXT_ANONYM_MASK for word in split_out 29 | ) 30 | 31 | @pytest.mark.integration 32 | def test__call__multiple_inputs_success( 33 | artifex: Artifex 34 | ): 35 | """ 36 | Test the `__call__` method of the `TextAnonymization` class, when multiple inputs are 37 | provided. Ensure that: 38 | - It returns a list of strings. 39 | - All returned strings are identical to the input strings, except for masked entities. 40 | Args: 41 | artifex (Artifex): The Artifex instance to be used for testing. 42 | """ 43 | 44 | out = artifex.text_anonymization([ 45 | "John Doe lives in New York City.", 46 | "Mark Spencer's phone number is 123-456-7890.", 47 | "Alice was born on January 1, 1990." 48 | ]) 49 | assert isinstance(out, list) 50 | assert all(isinstance(item, str) for item in out) 51 | for idx, text in enumerate(out): 52 | split_input = text.split(" ") 53 | split_out = out[idx].split(" ") 54 | assert len(split_input) == len(split_out) 55 | assert all( 56 | word in split_input or word == config.DEFAULT_TEXT_ANONYM_MASK for word in split_out 57 | ) -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Release v0.4.1 - December 22, 2025 2 | 3 | ### Changed 4 | 5 | - Turned `ClassificationModel` into a concrete class instead of an abstract class. 6 | - Replaced `instructions` argument with `unsafe_content` argument in the `Guardrail.train()` 7 | method. 8 | 9 | ### Fixed 10 | 11 | - Fixed security vulnerabilities by updating dependencies. 
12 | - Suppressed annoying tokenization-related warning in `NamedEntityRecognition.__call__()` method. 13 | 14 | ## Release v0.4.0 - December 4, 2025 15 | 16 | ### Added 17 | 18 | - Added the `Reranker` model. 19 | - Added the `SentimentAnalysis` model. 20 | - Added the `EmotionDetection` model. 21 | - Added the `NamedEntityRecognition` model. 22 | - Added the `TextAnonymization` model. 23 | - Added integration tests. 24 | 25 | ### Fixed 26 | 27 | - Fixed a bug causing the "Generating training data" progress bar to display a wrong progress percentage. 28 | 29 | ### Removed 30 | 31 | - Removed support for Python <= 3.9. 32 | - Removed all `# type: ignore` comments from the codebase. 33 | 34 | ### Changed 35 | 36 | - Updated error message when the `.load()` method is provided with a nonexistent model path or an invalid file format. 37 | - Updated the `IntentClassifier` base model. 38 | - Updated the output model's directory structure for better organization. 39 | 40 | ## Release v0.3.2 - October 2, 2025 41 | 42 | ### Added 43 | 44 | - Added support for Python 3.9. 45 | 46 | ### Fixed 47 | 48 | - Fixed SyntaxError in the `config.py` file, which was causing issues during library initialization on Python versions earlier than 3.13. 49 | 50 | ## Release v0.3.1 - September 16, 2025 51 | 52 | ### Added 53 | 54 | - Added a spinner on library load. 55 | - Added a spinner on model load. 56 | - Added a log message containing the generated model path, once model generation is complete. 57 | 58 | ### Changed 59 | 60 | - Replaced data generation `tqdm` progress bar with `rich` progress bar. 61 | - Replaced data processing `tqdm` progress bar with a `rich` spinner. 62 | - Replaced model training `tqdm` progress bar with `rich` progress bar. 63 | - Moved the output model to a dedicated `output_model` directory within the output directory. 64 | 65 | ### Removed 66 | 67 | - Removed all intermediate model checkpoints. 68 | - Removed unnecessary output files from output directory. 
69 | 70 | ## Release v0.3.0 - September 9, 2025 71 | 72 | First release -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Python CI/CD 2 | 3 | on: 4 | pull_request: 5 | branches: [master] 6 | release: 7 | types: [published] 8 | 9 | jobs: 10 | test: 11 | name: Run Tests 12 | runs-on: ubuntu-latest 13 | if: github.event_name == 'pull_request' 14 | 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.12" 23 | 24 | - name: Install uv 25 | run: | 26 | curl -LsSf https://astral.sh/uv/install.sh | sh 27 | echo "$HOME/.local/bin" >> $GITHUB_PATH 28 | 29 | - name: Cache uv environment 30 | uses: actions/cache@v3 31 | with: 32 | path: .venv 33 | key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml') }} 34 | restore-keys: | 35 | ${{ runner.os }}-uv- 36 | 37 | - name: Sync dependencies 38 | run: uv sync 39 | 40 | - name: Create .env file for testing 41 | run: | 42 | echo "API_KEY=${{ secrets.ARTIFEX_API_KEY }}" > .env 43 | 44 | - name: Run Unit Tests 45 | run: uv run pytest tests/unit -x -v --tb=short 46 | 47 | - name: Run Integration Tests 48 | run: uv run pytest tests/integration -x -v --tb=short 49 | 50 | publish: 51 | name: Build and Publish to PyPI 52 | runs-on: ubuntu-latest 53 | if: github.event_name == 'release' 54 | 55 | permissions: 56 | contents: read 57 | id-token: write 58 | 59 | steps: 60 | - name: Check out code 61 | uses: actions/checkout@v4 62 | 63 | - name: Set up Python 64 | uses: actions/setup-python@v5 65 | with: 66 | python-version: "3.12" 67 | 68 | - name: Install build tools 69 | run: | 70 | python -m pip install --upgrade pip 71 | pip install build twine 72 | 73 | - name: Build the package 74 | run: python -m build 75 | 76 | - name: Publish to PyPI 77 | env: 78 | 
from typing import Optional, Callable
from pydantic_settings import BaseSettings, SettingsConfigDict
import os
from datetime import datetime
from tzlocal import get_localzone
from pydantic import Field


class Config(BaseSettings):
    """
    Centralised, environment-aware settings for Artifex.

    Values are read from environment variables and from the project's `.env` file
    (see `model_config` below); unknown keys are tolerated via `extra="allow"`.
    """

    # Artifex settings
    # Synthex API key; Optional so the object can be built without credentials
    # (e.g. in offline unit tests).
    API_KEY: Optional[str] = None
    # Zero-arg factory that yields a fresh, timestamped output directory path.
    # The nested lambda is intentional: `default_factory` must return the factory
    # itself (a zero-arg callable), not the path string it produces.
    output_path_factory: Callable[[], str] = Field(
        default_factory=lambda:
        lambda: f"{os.getcwd()}/artifex_output/run-{datetime.now(tz=get_localzone()).strftime('%Y%m%d%H%M%S')}/"
    )
    @property
    def DEFAULT_OUTPUT_PATH(self) -> str:
        # NOTE: a *new* timestamped path is generated on every access; callers that
        # need a stable location for one run must read this once and reuse the value.
        return self.output_path_factory()

    # Artifex error messages
    DATA_GENERATION_ERROR: str = "An error occurred while generating training data. This may be due to an intense load on the system. Please try again later."

    # Synthex settings
    DEFAULT_SYNTHEX_DATAPOINT_NUM: int = 500
    DEFAULT_SYNTHEX_DATASET_FORMAT: str = "csv"
    @property
    def DEFAULT_SYNTHEX_DATASET_NAME(self) -> str:
        # Training-dataset file name, derived from the configured format above.
        return f"train_data.{self.DEFAULT_SYNTHEX_DATASET_FORMAT}"
    # Leave empty to put the output model directly in the output folder (no subfolder)
    SYNTHEX_OUTPUT_MODEL_FOLDER_NAME: str = ""

    # HuggingFace settings
    DEFAULT_HUGGINGFACE_LOGGING_LEVEL: str = "error"

    # Base Model
    DEFAULT_TOKENIZER_MAX_LENGTH: int = 256

    # Classification Model
    CLASSIFICATION_CLASS_NAME_MAX_LENGTH: int = 20
    CLASSIFICATION_HF_BASE_MODEL: str = "microsoft/Multilingual-MiniLM-L12-H384"

    # Guardrail Model
    GUARDRAIL_HF_BASE_MODEL: str = "tanaos/tanaos-guardrail-v1"

    # IntentClassifier Model
    INTENT_CLASSIFIER_HF_BASE_MODEL: str = "tanaos/tanaos-intent-classifier-v1"

    # Reranker Model
    RERANKER_HF_BASE_MODEL: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
    RERANKER_TOKENIZER_MAX_LENGTH: int = 256

    # Sentiment Analysis Model
    SENTIMENT_ANALYSIS_HF_BASE_MODEL: str = "tanaos/tanaos-sentiment-analysis-v1"

    # Emotion Detection Model
    EMOTION_DETECTION_HF_BASE_MODEL: str = "tanaos/tanaos-emotion-detection-v1"

    # Text Anonymization Model
    TEXT_ANONYMIZATION_HF_BASE_MODEL: str = "tanaos/tanaos-text-anonymizer-v1"
    DEFAULT_TEXT_ANONYM_MASK: str = "[MASKED]"

    # Named Entity Recognition Model
    NER_HF_BASE_MODEL: str = "tanaos/tanaos-NER-v1"
    NER_TOKENIZER_MAX_LENGTH: int = 256
    NER_TAGNAME_MAX_LENGTH: int = 20

    model_config = SettingsConfigDict(
        env_file=".env",    # load variables from the local .env file
        env_prefix="",      # environment variable names map 1:1 to field names
        extra="allow",
    )


# Module-level singleton imported throughout the code base.
config = Config()
import pytest
from pytest_mock import MockerFixture
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from artifex.models import ClassificationModel
from datasets import ClassLabel


def test_classification_model_init(mocker: MockerFixture) -> None:
    """
    Unit test for ClassificationModel.__init__: verifies that construction wires up
    the base class, the data-generation instructions, the synthetic-data schema and
    the Hugging Face model/tokenizer, without touching the network or disk.
    Args:
        mocker (pytest_mock.MockerFixture): The pytest-mock fixture for mocking dependencies.
    """

    # Mock Synthex
    mock_synthex = mocker.Mock()
    # Mock config
    mock_config = mocker.patch("artifex.models.classification.classification_model.config")
    mock_config.CLASSIFICATION_HF_BASE_MODEL = "mocked-base-model"

    # Patch Hugging Face model/tokenizer loading at the correct import path
    mock_model = mocker.Mock(spec=PreTrainedModel)
    # `id2label` is read during init to build the ClassLabel mapping.
    mock_model.config = mocker.Mock(id2label={0: "label"})
    mocker.patch(
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        return_value=mock_model
    )
    mock_tokenizer = mocker.Mock(spec=PreTrainedTokenizerBase)
    mocker.patch(
        "artifex.models.classification.classification_model.AutoTokenizer.from_pretrained",
        return_value=mock_tokenizer
    )
    # Patch BaseModel.__init__ so it doesn't do anything
    mock_super_init = mocker.patch("artifex.models.base_model.BaseModel.__init__", return_value=None)

    # Instantiate ClassificationModel
    model = ClassificationModel(mock_synthex)

    # Assert BaseModel.__init__ was called with correct args
    mock_super_init.assert_called_once_with(mock_synthex)
    # Assert _system_data_gen_instr is set correctly
    assert isinstance(model._system_data_gen_instr_val, list)
    assert all(isinstance(item, str) for item in model._system_data_gen_instr_val)
    # Assert _token_keys_val is set correctly
    assert isinstance(model._token_keys_val, list) and isinstance(model._token_keys_val[0], str)
    assert len(model._token_keys_val) == 1
    # Assert _synthetic_data_schema_val is set correctly
    assert isinstance(model._synthetic_data_schema_val, dict)
    assert "text" in model._synthetic_data_schema_val
    assert "labels" in model._synthetic_data_schema_val
    # Assert that _base_model_name_val is set correctly
    assert model._base_model_name_val == "mocked-base-model"
    # Assert that _model_val and _tokenizer_val are initialized correctly
    assert isinstance(model._model_val, PreTrainedModel)
    assert isinstance(model._tokenizer_val, PreTrainedTokenizerBase)
    # Assert that _labels_val is initialized correctly
    assert isinstance(model._labels_val, ClassLabel)
from pydantic import validate_call, ValidationError
from typing import Any, Callable, TypeVar
from functools import wraps
import inspect


T = TypeVar("T", bound=type)


def should_skip_method(attr: Any, attr_name: str) -> bool:
    """
    Determines whether a class attribute should be skipped by `auto_validate_methods`,
    based on its name and signature.
    This function skips:
    - Dunder (double underscore) methods, except for '__call__'.
    - Attributes that are not callable.
    - Callables with no parameters at all, or whose only parameter is 'self'.
    - Callables whose signature cannot be introspected.
    Args:
        attr (Any): The attribute (function object, after unwrapping any
            staticmethod/classmethod) to check.
        attr_name (str): The name of the attribute to check.
    Returns:
        bool: True if the attribute should be skipped, False otherwise.
    """

    # Skip dunder methods, except the __call__ method
    if attr_name.startswith("__") and attr_name != "__call__":
        return True

    if not callable(attr):
        return True

    # Some callables (e.g. builtins implemented in C) expose no signature;
    # there is nothing to validate on them, so skip rather than crash.
    try:
        sig = inspect.signature(attr)
    except (ValueError, TypeError):
        return True

    params = list(sig.parameters.values())
    # BUGFIX: the original `len(params) <= 1 and params[0].name == "self"` raised
    # IndexError for zero-parameter callables (e.g. a no-arg staticmethod).
    # A callable with no parameters has no input to validate, so skip it too.
    if not params:
        return True
    if len(params) == 1 and params[0].name == "self":
        return True

    return False

def auto_validate_methods(cls: T) -> T:
    """
    A class decorator that combines Pydantic's `validate_call` for input validation
    and automatic handling of validation errors, raising a custom `ArtifexValidationError`.
    Args:
        cls (T): The class whose eligible methods should be wrapped with validation.
    Returns:
        T: The same class, with eligible methods replaced in place by validating wrappers.
    """

    # Imported lazily to avoid a circular import (artifex.core imports this module).
    from artifex.core import ValidationError as ArtifexValidationError

    for attr_name in dir(cls):
        # Use getattr_static to avoid triggering descriptors
        raw_attr = inspect.getattr_static(cls, attr_name)
        attr = getattr(cls, attr_name)

        is_static = isinstance(raw_attr, staticmethod)
        is_class = isinstance(raw_attr, classmethod)

        # Unwrap only if it's a staticmethod/classmethod object
        if is_static or is_class:
            func = raw_attr.__func__
        else:
            func = attr

        if should_skip_method(func, attr_name):
            continue

        validated = validate_call(config={"arbitrary_types_allowed": True})(func)

        @wraps(func)
        def wrapper(*args: Any, __f: Callable[..., Any] = validated, **kwargs: Any) -> Any:
            # `__f` is bound as a default argument so each wrapper captures its own
            # validated callable (avoids the late-binding closure pitfall).
            try:
                return __f(*args, **kwargs)
            except ValidationError as e:
                # Chain the original pydantic error for debuggability.
                raise ArtifexValidationError(f"Invalid input: {e}") from e

        # Re-wrap as staticmethod/classmethod if needed
        if is_static:
            setattr(cls, attr_name, staticmethod(wrapper))
        elif is_class:
            setattr(cls, attr_name, classmethod(wrapper))
        else:
            setattr(cls, attr_name, wrapper)

    return cls
from transformers import Trainer, TrainerState, TrainingArguments, TrainerCallback, TrainerControl
from transformers.trainer_utils import TrainOutput
from typing import Any, Dict
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TaskID
from rich.console import Console

"""
Patches for HuggingFace classes to improve user experience.
"""

# Shared Rich console used for end-of-training status messages.
console = Console()

class SilentTrainer(Trainer):
    """
    A regular transformers.Trainer which prevents the tedious final training summary dictionary
    from being printed to the console (since as of now, there is no built-in way to disable it).
    """

    def train(self, *args: Any, **kwargs: Any) -> TrainOutput:
        """
        Run `Trainer.train` while suppressing the final summary-dictionary print.

        Temporarily replaces `builtins.print` with a filtering wrapper for the
        duration of the call and restores it in a `finally` block.
        NOTE(review): swapping `builtins.print` is process-global and not
        thread-safe; any concurrent print of a dict containing 'train_runtime'
        would also be suppressed — confirm single-threaded training is assumed.
        """
        import builtins
        orig_print = builtins.print

        def silent_print(*a: Any, **k: Any) -> None:
            # Only suppress the summary dictionary
            # (a single positional dict argument containing 'train_runtime').
            if (
                len(a) == 1
                and isinstance(a[0], dict)
                and "train_runtime" in a[0]
            ):
                return
            return orig_print(*a, **k)

        builtins.print = silent_print
        try:
            return super().train(*args, **kwargs)
        finally:
            # Always restore the real print, even if training raises.
            builtins.print = orig_print


class RichProgressCallback(TrainerCallback):
    """
    A custom TrainerCallback that uses Rich to display a progress bar during training.
    """

    # Both attributes are set in `on_train_begin`; the other callbacks
    # assume it has run first.
    progress: Progress
    task: TaskID

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the beginning of training. Creates and starts a transient
        progress bar sized to the trainer's total step count.
        """

        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeElapsedColumn(),
            transient=True  # remove the bar from the console once training ends
        )
        self.task = self.progress.add_task("Training model...", total=state.max_steps)
        self.progress.start()

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the end of each training step. Advances the bar to the
        trainer's current global step.
        """

        self.progress.update(TaskID(self.task), completed=state.global_step)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Dict[str, Any]
    ) -> None:
        """
        Called at the end of training. Stops the bar and prints a final
        confirmation message.
        """

        self.progress.stop()
        console.print("[green]✔ Training model[/green]")
Please ensure that:
- Your contributions are your original work, or that you have the right to submit them.
- You have permission from your employer or any third party if your contributions are covered by their rights.

## What To Contribute On

We welcome contributions of any kind, **both in the form of [new issues](https://github.com/tanaos/artifex/issues)** and **new code**. Typical contributions are done in one of two ways:

1. When using the library you **come across a problem or come up with a possible improvement**. First of all, you should check the [Issues Tab](https://github.com/tanaos/artifex/issues) to see if an issue that addresses the same shortcoming is already present. If that is the case, you can jump to option number 2; otherwise you can either:
    - [Open a new issue](https://github.com/tanaos/artifex/issues/new) for the shortcoming you have identified (opening issues, without necessarily working on them, is a productive way of contributing to open-source code!)
    - [Open a new issue](https://github.com/tanaos/artifex/issues/new) **and** start working on it. You can indicate that you will take care of the newly opened issue by either assigning it to yourself or stating so in the comments.

2. You visit the [Issues Tab](https://github.com/tanaos/artifex/issues) and **look for a known issue you are interested in working on**. Simple issues, perfect for developers who have never / very seldom contributed to open-source code, are marked with a badge. Once you have found an issue that you like, indicate that you will be working on it by either assigning it to yourself or stating so in the comments.

## How To Contribute New Code

Direct pushes to the `master` branch are not permitted. In order to contribute new code, please **follow the standard fork --> push --> pull request workflow**:

1. Fork the repository
2.
Create a new branch (`git checkout -b feature/my-feature`) 33 | 3. Make your changes 34 | 4. Commit your changes (`git commit -m "Add feature"`) 35 | 5. Push to your fork (`git push origin feature/my-feature`) 36 | 6. Open a pull request 37 | 38 | ## Guidelines 39 | 40 | - Keep your code clean and consistent with existing style. 41 | - Add or update docstrings and comments where helpful. 42 | - If applicable, write or update tests. 43 | - Be constructive in discussions. 44 | 45 | ## Questions? 46 | 47 | Feel free to open an issue or start a discussion if you're not sure where to begin. We're happy to help! 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.8.1 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.12.13 4 | aiosignal==1.4.0 5 | annotated-types==0.7.0 6 | anyio==4.9.0 7 | argon2-cffi==25.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==3.0.0 11 | async-lru==2.0.5 12 | attrs==25.3.0 13 | babel==2.17.0 14 | beautifulsoup4==4.13.4 15 | bleach==6.2.0 16 | certifi==2025.6.15 17 | cffi==1.17.1 18 | charset-normalizer==3.4.2 19 | comm==0.2.2 20 | datasets==3.6.0 21 | debugpy==1.8.14 22 | decorator==5.2.1 23 | defusedxml==0.7.1 24 | dill==0.3.8 25 | executing==2.2.0 26 | fastjsonschema==2.21.1 27 | filelock==3.18.0 28 | fqdn==1.5.1 29 | frozenlist==1.7.0 30 | fsspec==2025.3.0 31 | h11==0.16.0 32 | hf-xet==1.1.5 33 | httpcore==1.0.9 34 | httpx==0.28.1 35 | huggingface-hub==0.33.2 36 | idna==3.10 37 | iniconfig==2.1.0 38 | ipykernel==6.29.5 39 | ipython==9.4.0 40 | ipython-pygments-lexers==1.1.1 41 | isoduration==20.11.0 42 | jedi==0.19.2 43 | jinja2==3.1.6 44 | json5==0.12.0 45 | jsonpointer==3.0.0 46 | jsonschema==4.24.0 47 | jsonschema-specifications==2025.4.1 48 | jupyter-client==8.6.3 49 | jupyter-core==5.8.1 50 | jupyter-events==0.12.0 51 | jupyter-lsp==2.2.5 52 | jupyter-server==2.16.0 53 | 
jupyter-server-terminals==0.5.3 54 | jupyterlab==4.4.4 55 | jupyterlab-pygments==0.3.0 56 | jupyterlab-server==2.27.3 57 | markupsafe==3.0.2 58 | matplotlib-inline==0.1.7 59 | mistune==3.1.3 60 | mpmath==1.3.0 61 | multidict==6.6.3 62 | multiprocess==0.70.16 63 | nbclient==0.10.2 64 | nbconvert==7.16.6 65 | nbformat==5.10.4 66 | nest-asyncio==1.6.0 67 | networkx==3.5 68 | notebook==7.4.4 69 | notebook-shim==0.2.4 70 | numpy==2.3.1 71 | nvidia-cublas-cu12==12.6.4.1 72 | nvidia-cuda-cupti-cu12==12.6.80 73 | nvidia-cuda-nvrtc-cu12==12.6.77 74 | nvidia-cuda-runtime-cu12==12.6.77 75 | nvidia-cudnn-cu12==9.5.1.17 76 | nvidia-cufft-cu12==11.3.0.4 77 | nvidia-cufile-cu12==1.11.1.6 78 | nvidia-curand-cu12==10.3.7.77 79 | nvidia-cusolver-cu12==11.7.1.2 80 | nvidia-cusparse-cu12==12.5.4.2 81 | nvidia-cusparselt-cu12==0.6.3 82 | nvidia-nccl-cu12==2.26.2 83 | nvidia-nvjitlink-cu12==12.6.85 84 | nvidia-nvtx-cu12==12.6.77 85 | overrides==7.7.0 86 | packaging==25.0 87 | pandas==2.3.0 88 | pandocfilters==1.5.1 89 | parso==0.8.4 90 | pexpect==4.9.0 91 | platformdirs==4.3.8 92 | pluggy==1.6.0 93 | prometheus-client==0.22.1 94 | prompt-toolkit==3.0.51 95 | propcache==0.3.2 96 | psutil==7.0.0 97 | ptyprocess==0.7.0 98 | pure-eval==0.2.3 99 | pyarrow==20.0.0 100 | pycparser==2.22 101 | pydantic==2.11.7 102 | pydantic-core==2.33.2 103 | pydantic-settings==2.10.1 104 | pygments==2.19.2 105 | pytest==8.4.1 106 | pytest-mock==3.14.1 107 | python-dateutil==2.9.0.post0 108 | python-dotenv==1.1.1 109 | python-json-logger==3.3.0 110 | pytz==2025.2 111 | pyyaml==6.0.2 112 | pyzmq==27.0.0 113 | referencing==0.36.2 114 | regex==2024.11.6 115 | requests==2.32.4 116 | responses==0.25.7 117 | rfc3339-validator==0.1.4 118 | rfc3986-validator==0.1.1 119 | rpds-py==0.26.0 120 | safetensors==0.5.3 121 | send2trash==1.8.3 122 | setuptools==80.9.0 123 | six==1.17.0 124 | sniffio==1.3.1 125 | soupsieve==2.7 126 | stack-data==0.6.3 127 | sympy==1.14.0 128 | synthex==0.3.1 129 | terminado==0.18.1 130 | 
from synthex import Synthex
from transformers.trainer_utils import TrainOutput
from typing import Optional

from ...classification_model import ClassificationModel

from artifex.core import auto_validate_methods
from artifex.config import config


@auto_validate_methods
class SentimentAnalysis(ClassificationModel):
    """
    A Sentiment Analysis Model is used to classify the sentiment of a given text into predefined
    categories, typically `positive`, `negative`, or `neutral`. In this implementation, we
    support two extra sentiment categories: `very_positive` and `very_negative`.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic
                data used to train the model.
        """
        super().__init__(synthex, base_model_name=config.SENTIMENT_ANALYSIS_HF_BASE_MODEL)
        # System-level instructions handed to Synthex when generating the synthetic
        # training dataset; `{domain}` is interpolated later by `_get_data_gen_instr`.
        self._system_data_gen_instr_val: list[str] = [
            "The 'text' field should contain text that belongs to the following domain(s): {domain}.",
            "The 'text' field should contain text that may or may not express a certain sentiment.",
            "The 'labels' field should contain a label indicating the sentiment of the 'text'.",
            "'labels' must only contain one of the provided labels; under no circumstances should it contain arbitrary text.",
            "This is a list of the allowed 'labels' and their meaning: "
        ]

    def train(
        self, domain: str, classes: Optional[dict[str, str]] = None,
        output_path: Optional[str] = None, num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs: int = 3
    ) -> TrainOutput:
        """
        Train the Sentiment Analysis model using synthetic data generated by Synthex.

        NOTE: this method overrides `ClassificationModel.train()` to make the `classes`
        parameter optional: when omitted, a default five-level sentiment label set
        (very_negative ... very_positive) is used.

        Args:
            domain (str): A description of the domain or context for which the model is being trained.
            classes (Optional[dict[str, str]]): A dictionary mapping class names to their descriptions.
                The keys (class names) must be strings with no spaces and a maximum length of
                `config.CLASSIFICATION_CLASS_NAME_MAX_LENGTH` characters. If None, the default
                sentiment classes are used.
            output_path (Optional[str]): The path where the generated synthetic data will be saved.
            num_samples (int): The number of training data samples to generate.
            num_epochs (int): The number of epochs for training the model.
        Returns:
            TrainOutput: The output of the training process.
        """
        # BUGFIX: this docstring used to be an f-string literal, which Python does
        # NOT treat as a docstring — the method had no __doc__ and the f-string was
        # re-evaluated (and discarded) on every call.

        if classes is None:
            classes = {
                "very_negative": "Text that expresses a very negative sentiment or strong dissatisfaction.",
                "negative": "Text that expresses a negative sentiment or dissatisfaction.",
                "neutral": "Either a text that does not express any sentiment at all, or a text that expresses a neutral sentiment or lack of strong feelings.",
                "positive": "Text that expresses a positive sentiment or satisfaction.",
                "very_positive": "Text that expresses a very positive sentiment or strong satisfaction."
            }

        return super().train(
            domain=domain, classes=classes, output_path=output_path,
            num_samples=num_samples, num_epochs=num_epochs
        )
@auto_validate_methods
class Guardrail(ClassificationModel):
    """
    A Guardrail Model for LLMs. A Guardrail is a model that can be used to classify the output of a LLM
    as safe or unsafe, depending on a user's definition of what is safe or unsafe.
    """

    def __init__(self, synthex: Synthex):
        """
        Initializes the class with a Synthex instance.
        Args:
            synthex (Synthex): An instance of the Synthex class to generate the synthetic data used to train
                the model.
        """

        super().__init__(synthex, base_model_name=config.GUARDRAIL_HF_BASE_MODEL)
        # System instructions for Synthex; `{unsafe_content}` is substituted in
        # `_get_data_gen_instr` with the user's joined unsafe-content descriptions.
        self._system_data_gen_instr_val: list[str] = [
            "the 'text' field should contain two types of text: either user queries or LLM-generated responses to those queries.",
            "the 'labels' field should contain a label indicating whether the 'text' is safe or unsafe.",
            "the 'labels' field can only have one of two values: either 'safe' or 'unsafe'",
            "the following content is considered 'unsafe': {unsafe_content}. Everything else is considered 'safe'.",
            "the dataset should contain an approximately equal number of safe and unsafe 'text'",
            "the dataset should also contain arbitrary 'text', even if not explicitly mentioned in these instructions, but its 'labels' must reflect the actual safety of that text",
        ]

    def _get_data_gen_instr(self, user_instr: list[str]) -> list[str]:
        """
        Overrides `ClassificationModel._get_data_gen_instr` to account for the different structure of
        `Guardrail.train`.
        Args:
            user_instr (list[str]): The `unsafe_content` descriptions supplied by the user to
                `Guardrail.train` (NOT domain/class instructions as in the parent class).
        Returns:
            list[str]: The system instructions, with the semicolon-joined unsafe-content
                descriptions substituted for the `{unsafe_content}` placeholder.
        """
        # BUGFIX(docs): the previous docstring was copy-pasted from the parent class
        # and wrongly described `user_instr` as domain/class instructions.

        unsafe_content = "; ".join(user_instr)
        out = [instr.format(unsafe_content=unsafe_content) for instr in self._system_data_gen_instr_val]
        return out

    def train(
        self, unsafe_content: list[str], output_path: Optional[str] = None,
        num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM, num_epochs: int = 3
    ) -> TrainOutput:
        """
        Overrides `ClassificationModel.train` to remove the `domain` and `classes` arguments and
        add the `unsafe_content` argument.
        Args:
            unsafe_content (list[str]): A list of strings describing content that should be
                classified as unsafe by the Guardrail model.
            output_path (Optional[str]): The path where the synthetic training data and the
                output model will be saved.
            num_samples (int): The number of training data samples to generate.
            num_epochs (int): The number of epochs for training the model.
        Returns:
            TrainOutput: The output of the training process.
        """
        # BUGFIX: this docstring used to be an f-string literal (with no placeholders),
        # which Python does not treat as a docstring — the method had no __doc__.

        output: TrainOutput = self._train_pipeline(
            user_instructions=unsafe_content, output_path=output_path, num_samples=num_samples,
            num_epochs=num_epochs
        )

        return output
@pytest.mark.parametrize(
    "attr_name,attr,expected",
    [
        ("__str__", Dummy.__str__, True),                # dunder, skip
        ("__call__", Dummy.__call__, False),             # __call__, don't skip
        ("method", Dummy.method, False),                 # normal method, don't skip
        ("static_method", Dummy.static_method, False),   # staticmethod, don't skip
        ("class_method", Dummy.class_method, False),     # classmethod, don't skip
        ("only_self", Dummy.only_self, True),            # only self param, skip
    ]
)
def test_should_skip_method(
    mocker: MockerFixture,
    attr_name: str,
    attr: Any,
    expected: bool
):
    """
    Unit test for should_skip_method. Mocks inspect.signature and the callable
    check so that only the name/signature-based skip logic is exercised.
    Args:
        mocker (MockerFixture): pytest-mock fixture for mocking.
        attr_name (str): Attribute name to test.
        attr (Any): Attribute object to test.
        expected (bool): Expected result from should_skip_method.
    """

    # Treat every attribute as callable so the callable() branch never short-circuits.
    mocker.patch("builtins.callable", return_value=True)

    # Parameter names the mocked signature of each attribute should expose;
    # a single lookup table replaces the per-name if/elif mocking chain.
    signature_params = {
        "__str__": ["self"],
        "__call__": ["self", "x"],
        "method": ["self", "x"],
        "static_method": ["x"],
        "class_method": ["cls", "x"],
        "only_self": ["self"],
    }

    mocked_params = []
    for param_name in signature_params[attr_name]:
        mocked_param = mocker.Mock()
        mocked_param.name = param_name
        mocked_params.append(mocked_param)

    mocked_signature = mocker.Mock()
    mocked_signature.parameters.values.return_value = mocked_params
    mocker.patch("inspect.signature", return_value=mocked_signature)

    assert should_skip_method(attr, attr_name) is expected
import pytest
from typing import Any
from pytest_mock import MockerFixture

from artifex.core.decorators import auto_validate_methods


class DummyValidationError(Exception):
    """Stand-in for the validation error type, so tests don't depend on pydantic internals."""
    pass

def test_auto_validate_methods_valid_call(mocker: MockerFixture) -> None:
    """
    Test that auto_validate_methods correctly validates method input and returns output.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Mock validate_call to just return the function itself (identity decorator)
    mock_validate_call = mocker.patch("artifex.core.decorators.validate_call", side_effect=lambda *a, **kw: lambda f: f)
    # Mock ArtifexValidationError
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def foo(self, x: int) -> int:
            return x + 1

    decorated = auto_validate_methods(TestClass)
    obj = decorated()
    assert obj.foo(1) == 2
    mock_validate_call.assert_called()

def test_auto_validate_methods_raises_on_validation_error(mocker: MockerFixture) -> None:
    """
    Test that auto_validate_methods raises ArtifexValidationError on validation error.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Mock validate_call so the wrapped function always raises the dummy error,
    # simulating a pydantic validation failure.
    def raise_validation_error(f):
        def wrapper(*args, **kwargs):
            raise DummyValidationError("fail")
        return wrapper

    mocker.patch("artifex.core.decorators.validate_call", side_effect=lambda *a, **kw: raise_validation_error)
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def foo(self, x: int) -> int:
            return x + 1

    decorated = auto_validate_methods(TestClass)
    obj = decorated()
    with pytest.raises(DummyValidationError):
        obj.foo("bad_input")

def test_auto_validate_methods_skips_dunder_methods(mocker: MockerFixture) -> None:
    """
    Test that auto_validate_methods skips dunder methods except __call__.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    mock_validate_call = mocker.patch("artifex.core.decorators.validate_call", side_effect=lambda *a, **kw: lambda f: f)
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        def __str__(self) -> str:
            return "test"
        def foo(self, x: int) -> int:
            return x + 1

    decorated = auto_validate_methods(TestClass)
    obj = decorated()
    assert obj.foo(1) == 2
    assert obj.__str__() == "test"
    # Only foo should be validated
    mock_validate_call.assert_any_call(config={"arbitrary_types_allowed": True})

def test_auto_validate_methods_static_and_class_methods(mocker: MockerFixture) -> None:
    """
    Test that auto_validate_methods works for static and class methods
    (i.e. they are unwrapped, validated and re-wrapped correctly).
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    mock_validate_call = mocker.patch("artifex.core.decorators.validate_call", side_effect=lambda *a, **kw: lambda f: f)
    mocker.patch("artifex.core.ValidationError", DummyValidationError)

    class TestClass:
        @staticmethod
        def static(x: int) -> int:
            return x + 2

        @classmethod
        def clsmethod(cls, x: int) -> int:
            return x + 3

    decorated = auto_validate_methods(TestClass)
    assert decorated.static(1) == 3
    assert decorated.clsmethod(1) == 4
    mock_validate_call.assert_any_call(config={"arbitrary_types_allowed": True})
85 | """ 86 | 87 | mock_validate_call = mocker.patch("artifex.core.decorators.validate_call", side_effect=lambda *a, **kw: lambda f: f) 88 | mocker.patch("artifex.core.ValidationError", DummyValidationError) 89 | 90 | class TestClass: 91 | @staticmethod 92 | def static(x: int) -> int: 93 | return x + 2 94 | 95 | @classmethod 96 | def clsmethod(cls, x: int) -> int: 97 | return x + 3 98 | 99 | decorated = auto_validate_methods(TestClass) 100 | assert decorated.static(1) == 3 101 | assert decorated.clsmethod(1) == 4 102 | mock_validate_call.assert_any_call(config={"arbitrary_types_allowed": True}) -------------------------------------------------------------------------------- /artifex/models/named_entity_recognition/text_anonymization/__init__.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | from typing import Union, Optional 3 | from transformers.trainer_utils import TrainOutput 4 | 5 | from ..named_entity_recognition import NamedEntityRecognition 6 | 7 | from artifex.core import auto_validate_methods 8 | from artifex.config import config 9 | 10 | 11 | @auto_validate_methods 12 | class TextAnonymization(NamedEntityRecognition): 13 | """ 14 | A Text Anonymization model is a model that removes Personal Identifiable Information (PII) from text. 15 | This class extends the NamedEntityRecognition model to specifically target and anonymize PII in text data. 16 | """ 17 | 18 | def __init__(self, synthex: Synthex): 19 | """ 20 | Initializes the class with a Synthex instance. 21 | Args: 22 | synthex (Synthex): An instance of the Synthex class to generate the synthetic data used 23 | to train the model. 
24 | """ 25 | 26 | super().__init__(synthex) 27 | self._pii_entities = { 28 | "PERSON": "Individual people, fictional characters", 29 | "LOCATION": "Geographical areas", 30 | "DATE": "Absolute or relative dates, including years, months and/or days", 31 | "ADDRESS": "full addresses", 32 | "PHONE_NUMBER": "telephone numbers", 33 | } 34 | self._maskable_entities = list(self._pii_entities.keys()) 35 | 36 | def __call__( 37 | self, text: Union[str, list[str]], entities_to_mask: Optional[list[str]] = None, 38 | mask_token: str = config.DEFAULT_TEXT_ANONYM_MASK 39 | ) -> list[str]: 40 | """ 41 | Anonymizes the input text by masking PII entities. 42 | Args: 43 | text (Union[str, list[str]]): The input text or list of texts to be anonymized. 44 | Returns: 45 | list[str]: A list of anonymized texts. 46 | """ 47 | 48 | if entities_to_mask is None: 49 | entities_to_mask = self._maskable_entities 50 | else: 51 | for entity in entities_to_mask: 52 | if entity not in self._maskable_entities: 53 | raise ValueError(f"Entity '{entity}' cannot be masked. Allowed entities are: {self._maskable_entities}") 54 | 55 | if isinstance(text, str): 56 | text = [text] 57 | 58 | out: list[str] = [] 59 | 60 | named_entities = super().__call__(text) 61 | for idx, input_text in enumerate(text): 62 | anonymized_text = input_text 63 | # Mask entities in reverse order to avoid invalidating the start/end indices 64 | for entities in reversed(named_entities[idx]): 65 | if entities.entity_group in entities_to_mask: 66 | start, end = entities.start, entities.end 67 | anonymized_text = ( 68 | anonymized_text[:start] + mask_token + anonymized_text[end:] 69 | ) 70 | out.append(anonymized_text) 71 | 72 | return out 73 | 74 | def train( 75 | self, domain: str, output_path: Optional[str] = None, 76 | num_samples: int = config.DEFAULT_SYNTHEX_DATAPOINT_NUM, num_epochs: int = 3 77 | ) -> TrainOutput: 78 | """ 79 | Trains the Text Anonymization model. 
def make_mock_model(mocker: MockerFixture, id2label=None):
    """
    Build a mock Hugging Face model whose config carries the given id2label mapping.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        id2label: Mapping from label ids to label names stored on the mock
            config; defaults to None to simulate a config without labels.
    Returns:
        A mock model object exposing a `.config.id2label` attribute.
    """

    # NOTE(review): the original return annotation was `-> MockerFixture`,
    # but the function actually returns a Mock; the wrong annotation was dropped.
    mock_model = mocker.Mock()
    mock_config = mocker.Mock()
    mock_config.id2label = id2label
    mock_model.config = mock_config
    return mock_model
18 | """ 19 | 20 | mock_synthex = mocker.Mock() 21 | # Patch Hugging Face model/tokenizer loading for __init__ 22 | mock_model = make_mock_model(mocker, {0: "labelA", 1: "labelB"}) 23 | mocker.patch( 24 | "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained", 25 | return_value=mock_model 26 | ) 27 | mocker.patch( 28 | "artifex.models.classification.classification_model.AutoTokenizer.from_pretrained", 29 | return_value=mocker.Mock() 30 | ) 31 | mocker.patch( 32 | "artifex.models.classification.classification_model.ClassLabel", 33 | return_value=mocker.Mock(names=["labelA", "labelB"]) 34 | ) 35 | mocker.patch("artifex.models.base_model.BaseModel.__init__", return_value=None) 36 | return ClassificationModel(mock_synthex) 37 | 38 | 39 | def test_load_model_sets_model_and_labels( 40 | mocker: MockerFixture, mock_classification_model: ClassificationModel 41 | ): 42 | """ 43 | Test that _load_model sets the model and labels correctly. 44 | Args: 45 | mocker (MockerFixture): The pytest-mock fixture for mocking. 46 | mock_classification_model (ClassificationModel): The mocked ClassificationModel instance. 
def test_load_model_raises_if_id2label_missing(
    mocker: MockerFixture, mock_classification_model: ClassificationModel
):
    """
    Verify that _load_model raises an AssertionError when the loaded model's
    config lacks an id2label mapping.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
        mock_classification_model (ClassificationModel): The mocked ClassificationModel instance.
    """

    # A model whose config carries no label mapping at all
    model_without_labels = make_mock_model(mocker, id2label=None)
    mocker.patch(
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        return_value=model_without_labels
    )

    with pytest.raises(AssertionError, match="Model config must have id2label mapping."):
        mock_classification_model._load_model("dummy_path")
class DummyClassificationModel(ClassificationModel):
    """
    Minimal concrete ClassificationModel subclass used as a test double.
    """

    @property
    def _system_data_gen_instr(self) -> List[str]:
        # Templated system instructions; `{domain}` is filled in later.
        return [
            "System instruction 1 for {domain}",
            "System instruction 2 for {domain}"
        ]

    @property
    def _base_model_name(self) -> str:
        # Placeholder model identifier; never resolved against Hugging Face.
        return "dummy-model"
def test_get_data_gen_instr_basic(model: DummyClassificationModel):
    """
    Verify that _get_data_gen_instr prepends the domain-formatted system
    instructions to the class instructions, dropping the trailing domain entry.
    """

    class_instructions = ["classA: descriptionA", "classB: descriptionB"]
    domain = "test-domain"

    result = model._get_data_gen_instr(class_instructions + [domain])

    expected = [
        f"System instruction 1 for {domain}",
        f"System instruction 2 for {domain}",
    ] + class_instructions
    assert result == expected
def test_get_data_gen_instr_multiple_classes(model: DummyClassificationModel):
    """
    Verify that _get_data_gen_instr preserves every class instruction when
    several are supplied alongside the domain.
    """

    class_instructions = [
        "classA: descriptionA",
        "classB: descriptionB",
        "classC: descriptionC",
    ]
    domain = "test-domain"

    result = model._get_data_gen_instr(class_instructions + [domain])

    expected = [
        f"System instruction 1 for {domain}",
        f"System instruction 2 for {domain}",
    ] + class_instructions
    assert result == expected
def test_get_data_gen_instr_empty_user_instr(model: DummyClassificationModel):
    """
    Verify that an empty user-instruction list makes _get_data_gen_instr
    raise IndexError, since the trailing domain entry is mandatory.
    """

    empty_instructions: List[str] = []
    with pytest.raises(IndexError):
        model._get_data_gen_instr(empty_instructions)
@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Provide a MagicMock standing in for a Synthex client.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    synthex_stub = mocker.MagicMock(spec=Synthex)
    return synthex_stub
def test_get_data_gen_instr_empty_user_instructions(mock_guardrail: Guardrail):
    """
    Verify that _get_data_gen_instr still yields one entry per system
    instruction when no user instructions are supplied.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    system_instr = mock_guardrail._system_data_gen_instr

    result = mock_guardrail._get_data_gen_instr([])

    assert len(result) == len(system_instr)
    # The unsafe-content slot is filled with an empty string
    assert result[3] == system_instr[3].format(unsafe_content="")
@pytest.mark.unit
def test_get_data_gen_instr_validation_failure(mock_guardrail: Guardrail):
    """
    Verify that passing a non-list argument to _get_data_gen_instr triggers
    a ValidationError.
    Args:
        mock_guardrail (Guardrail): The Guardrail instance to test.
    """

    from artifex.core import ValidationError

    bad_input = "invalid instructions"
    with pytest.raises(ValidationError):
        mock_guardrail._get_data_gen_instr(bad_input)
151 | """ 152 | 153 | user_instructions = ["instruction1", "instruction2"] 154 | original_user_instr = user_instructions.copy() 155 | original_system_instr = mock_guardrail._system_data_gen_instr.copy() 156 | 157 | mock_guardrail._get_data_gen_instr(user_instructions) 158 | 159 | # Verify original lists are unchanged 160 | assert user_instructions == original_user_instr 161 | assert mock_guardrail._system_data_gen_instr == original_system_instr -------------------------------------------------------------------------------- /artifex/__init__.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | 3 | console = Console() 4 | 5 | with console.status("Initializing Artifex..."): 6 | from synthex import Synthex 7 | from typing import Optional 8 | from transformers import logging as hf_logging 9 | import datasets 10 | 11 | from .core import auto_validate_methods 12 | from .models.classification import ClassificationModel, Guardrail, IntentClassifier, \ 13 | SentimentAnalysis, EmotionDetection 14 | from .models.named_entity_recognition import NamedEntityRecognition, TextAnonymization 15 | from .models.reranker import Reranker 16 | from .config import config 17 | console.print(f"[green]✔ Initializing Artifex[/green]") 18 | 19 | 20 | if config.DEFAULT_HUGGINGFACE_LOGGING_LEVEL.lower() == "error": 21 | hf_logging.set_verbosity_error() 22 | 23 | # Disable the progress bar from the datasets library, as it interferes with rich's progress bar. 24 | datasets.disable_progress_bar() 25 | 26 | 27 | @auto_validate_methods 28 | class Artifex: 29 | """ 30 | Artifex is a library for easily training and using small, task-specific AI models. 31 | """ 32 | 33 | def __init__(self, api_key: Optional[str] = None): 34 | """ 35 | Initializes Artifex with an API key for authentication. 36 | Args: 37 | api_key (Optional[str]): The API key to use for authentication. If not provided, attempts to load 38 | it from the .env file. 
39 | """ 40 | 41 | if not api_key: 42 | api_key=config.API_KEY 43 | self._synthex_client = Synthex(api_key=api_key) 44 | self._text_classification = None 45 | self._guardrail = None 46 | self._intent_classifier = None 47 | self._reranker = None 48 | self._sentiment_analysis = None 49 | self._emotion_detection = None 50 | self._named_entity_recognition = None 51 | self._text_anonymization = None 52 | 53 | @property 54 | def text_classification(self) -> ClassificationModel: 55 | """ 56 | Lazy loads the ClassificationModel instance. 57 | Returns: 58 | ClassificationModel: An instance of the ClassificationModel class. 59 | """ 60 | 61 | if self._text_classification is None: 62 | with console.status("Loading Classification model..."): 63 | self._text_classification = ClassificationModel(synthex=self._synthex_client) 64 | return self._text_classification 65 | 66 | @property 67 | def guardrail(self) -> Guardrail: 68 | """ 69 | Lazy loads the Guardrail instance. 70 | Returns: 71 | Guardrail: An instance of the Guardrail class. 72 | """ 73 | 74 | if self._guardrail is None: 75 | with console.status("Loading Guardrail model..."): 76 | self._guardrail = Guardrail(synthex=self._synthex_client) 77 | return self._guardrail 78 | 79 | @property 80 | def intent_classifier(self) -> IntentClassifier: 81 | """ 82 | Lazy loads the IntentClassifier instance. 83 | Returns: 84 | IntentClassifier: An instance of the IntentClassifier class. 85 | """ 86 | 87 | if self._intent_classifier is None: 88 | with console.status("Loading Intent Classifier model..."): 89 | self._intent_classifier = IntentClassifier(synthex=self._synthex_client) 90 | return self._intent_classifier 91 | 92 | @property 93 | def reranker(self) -> Reranker: 94 | """ 95 | Lazy loads the Reranker instance. 96 | Returns: 97 | Reranker: An instance of the Reranker class. 
98 | """ 99 | 100 | if self._reranker is None: 101 | with console.status("Loading Reranker model..."): 102 | self._reranker = Reranker(synthex=self._synthex_client) 103 | return self._reranker 104 | 105 | @property 106 | def sentiment_analysis(self) -> SentimentAnalysis: 107 | """ 108 | Lazy loads the SentimentAnalysis instance. 109 | Returns: 110 | SentimentAnalysis: An instance of the SentimentAnalysis class. 111 | """ 112 | 113 | if self._sentiment_analysis is None: 114 | with console.status("Loading Sentiment Analysis model..."): 115 | self._sentiment_analysis = SentimentAnalysis(synthex=self._synthex_client) 116 | return self._sentiment_analysis 117 | 118 | @property 119 | def emotion_detection(self) -> EmotionDetection: 120 | """ 121 | Lazy loads the EmotionDetection instance. 122 | Returns: 123 | EmotionDetection: An instance of the EmotionDetection class. 124 | """ 125 | 126 | if self._emotion_detection is None: 127 | with console.status("Loading Emotion Detection model..."): 128 | self._emotion_detection = EmotionDetection(synthex=self._synthex_client) 129 | return self._emotion_detection 130 | 131 | @property 132 | def named_entity_recognition(self) -> NamedEntityRecognition: 133 | """ 134 | Lazy loads the NamedEntityRecognition instance. 135 | Returns: 136 | NamedEntityRecognition: An instance of the NamedEntityRecognition class. 137 | """ 138 | 139 | if self._named_entity_recognition is None: 140 | with console.status("Loading Named Entity Recognition model..."): 141 | self._named_entity_recognition = NamedEntityRecognition(synthex=self._synthex_client) 142 | return self._named_entity_recognition 143 | 144 | @property 145 | def text_anonymization(self) -> TextAnonymization: 146 | """ 147 | Lazy loads the TextAnonymization instance. 148 | Returns: 149 | TextAnonymization: An instance of the TextAnonymization class. 
150 | """ 151 | 152 | if self._text_anonymization is None: 153 | with console.status("Loading Text Anonymization model..."): 154 | self._text_anonymization = TextAnonymization(synthex=self._synthex_client) 155 | return self._text_anonymization -------------------------------------------------------------------------------- /tests/unit/base_model/test_bm_sanitize_output_path.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | 4 | from artifex.core import ValidationError 5 | 6 | 7 | @pytest.fixture(scope="function", autouse=True) 8 | def mock_dependencies(mocker: MockerFixture): 9 | """ 10 | Fixture to mock all external dependencies before any test runs. 11 | This fixture runs automatically for all tests in this module. 12 | Args: 13 | mocker (MockerFixture): The pytest-mock fixture for mocking. 14 | """ 15 | 16 | from artifex.config import config 17 | 18 | mocker.patch.object( 19 | type(config), # Get the class of the config instance 20 | 'DEFAULT_OUTPUT_PATH', 21 | new_callable=mocker.PropertyMock, 22 | return_value="/default/output/" 23 | ) 24 | 25 | 26 | @pytest.mark.unit 27 | def test_sanitize_output_path_with_none_returns_default(): 28 | """ 29 | Test that _sanitize_output_path returns the default path when given None. 30 | """ 31 | 32 | from artifex.models import BaseModel 33 | 34 | result = BaseModel._sanitize_output_path(None) 35 | 36 | assert result == "/default/output/" 37 | 38 | 39 | @pytest.mark.unit 40 | def test_sanitize_output_path_with_empty_string_returns_default(): 41 | """ 42 | Test that _sanitize_output_path returns the default path when given an empty string. 
@pytest.mark.unit
def test_sanitize_output_path_with_directory_only():
    """
    Test that _sanitize_output_path correctly handles a directory path without a file.
    """

    from artifex.models import BaseModel

    result = BaseModel._sanitize_output_path("/custom/output/path")

    # A trailing slash is appended to the bare directory path.
    assert result == "/custom/output/path/"
@pytest.mark.unit
def test_sanitize_output_path_strips_whitespace():
    """
    Verify that surrounding whitespace is removed from the supplied path
    before the trailing slash is added.
    """

    from artifex.models import BaseModel

    padded_path = " /custom/path "
    assert BaseModel._sanitize_output_path(padded_path) == "/custom/path/"
@pytest.mark.unit
def test_sanitize_output_path_is_static_method():
    """
    Verify that _sanitize_output_path is callable on the class itself,
    without constructing an instance.
    """

    from artifex.models import BaseModel

    sanitized = BaseModel._sanitize_output_path("/path")

    assert isinstance(sanitized, str)
    assert sanitized.endswith("/")
@pytest.fixture(scope="function", autouse=True)
def mock_hf_and_config(mocker: MockerFixture) -> None:
    """
    Fixture to mock Hugging Face model/tokenizer loading and config values.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for mocking.
    """

    # Point the Guardrail at a dummy base-model name so no checkpoint is fetched.
    mocker.patch.object(config, "GUARDRAIL_HF_BASE_MODEL", "mock-guardrail-model")
    # NOTE(review): patching DEFAULT_SYNTHEX_DATAPOINT_NUM here likely has no
    # effect on defaults already bound into method signatures at import time —
    # the tests in this module assert num_samples=500, not 100. Confirm intent.
    mocker.patch.object(config, "DEFAULT_SYNTHEX_DATAPOINT_NUM", 100)
    # Replace HF model/tokenizer loading with MagicMocks so __init__ stays offline.
    mocker.patch(
        "artifex.models.classification.classification_model.AutoModelForSequenceClassification.from_pretrained",
        return_value=mocker.MagicMock()
    )
    mocker.patch(
        "artifex.models.classification.classification_model.AutoTokenizer.from_pretrained",
        return_value=mocker.MagicMock()
    )
38 | """ 39 | 40 | return mocker.MagicMock(spec=Synthex) 41 | 42 | 43 | @pytest.fixture 44 | def guardrail(mocker: MockerFixture, mock_synthex: Synthex) -> Guardrail: 45 | """ 46 | Fixture to create a Guardrail instance with mocked dependencies. 47 | Args: 48 | mocker (MockerFixture): The pytest-mock fixture for mocking. 49 | mock_synthex (Synthex): A mocked Synthex instance. 50 | Returns: 51 | Guardrail: An instance of the Guardrail model with mocked dependencies. 52 | """ 53 | 54 | return Guardrail(mock_synthex) 55 | 56 | 57 | def test_train_calls_train_pipeline_with_required_args( 58 | guardrail: Guardrail, mocker: MockerFixture 59 | ) -> None: 60 | """ 61 | Test that train() calls _train_pipeline with only required arguments. 62 | Args: 63 | guardrail (Guardrail): The Guardrail instance. 64 | mocker (MockerFixture): The pytest-mock fixture for mocking. 65 | """ 66 | 67 | instructions = ["instruction1", "instruction2"] 68 | mock_output = TrainOutput(global_step=1, training_loss=0.1, metrics={}) 69 | train_pipeline_mock = mocker.patch.object( 70 | guardrail, "_train_pipeline", return_value=mock_output 71 | ) 72 | 73 | result = guardrail.train(unsafe_content=instructions) 74 | 75 | train_pipeline_mock.assert_called_once_with( 76 | user_instructions=instructions, 77 | output_path=None, 78 | num_samples=500, 79 | num_epochs=3 80 | ) 81 | 82 | assert result is mock_output 83 | 84 | 85 | def test_train_calls_train_pipeline_with_all_args( 86 | guardrail: Guardrail, mocker: MockerFixture 87 | ) -> None: 88 | """ 89 | Test that train() calls _train_pipeline with all arguments provided. 90 | Args: 91 | guardrail (Guardrail): The Guardrail instance. 92 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
93 | """ 94 | 95 | instructions = ["foo", "bar"] 96 | output_path = "/tmp/guardrail.csv" 97 | num_samples = 42 98 | num_epochs = 7 99 | mock_output = TrainOutput(global_step=2, training_loss=0.2, metrics={}) 100 | train_pipeline_mock = mocker.patch.object( 101 | guardrail, "_train_pipeline", return_value=mock_output 102 | ) 103 | 104 | result = guardrail.train( 105 | unsafe_content=instructions, 106 | output_path=output_path, 107 | num_samples=num_samples, 108 | num_epochs=num_epochs 109 | ) 110 | 111 | train_pipeline_mock.assert_called_once_with( 112 | user_instructions=instructions, 113 | output_path=output_path, 114 | num_samples=num_samples, 115 | num_epochs=num_epochs 116 | ) 117 | assert result is mock_output 118 | 119 | 120 | def test_train_returns_trainoutput( 121 | guardrail: Guardrail, mocker: MockerFixture 122 | ) -> None: 123 | """ 124 | Test that train() returns the TrainOutput from _train_pipeline. 125 | Args: 126 | guardrail (Guardrail): The Guardrail instance. 127 | mocker (MockerFixture): The pytest-mock fixture for mocking. 128 | """ 129 | 130 | instructions = ["baz"] 131 | mock_output = TrainOutput(global_step=3, training_loss=0.3, metrics={}) 132 | mocker.patch.object(guardrail, "_train_pipeline", return_value=mock_output) 133 | 134 | result = guardrail.train(unsafe_content=instructions) 135 | assert isinstance(result, TrainOutput) 136 | assert result is mock_output 137 | 138 | 139 | def test_train_with_empty_instructions( 140 | guardrail: Guardrail, mocker: MockerFixture 141 | ) -> None: 142 | """ 143 | Test that train() works with an empty instructions list. 144 | Args: 145 | guardrail (Guardrail): The Guardrail instance. 146 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
147 | """ 148 | 149 | instructions: List[str] = [] 150 | mock_output = TrainOutput(global_step=4, training_loss=0.4, metrics={}) 151 | train_pipeline_mock = mocker.patch.object( 152 | guardrail, "_train_pipeline", return_value=mock_output 153 | ) 154 | 155 | result = guardrail.train(unsafe_content=instructions) 156 | train_pipeline_mock.assert_called_once_with( 157 | user_instructions=instructions, 158 | output_path=None, 159 | num_samples=500, 160 | num_epochs=3 161 | ) 162 | assert result is mock_output 163 | 164 | 165 | def test_train_with_none_output_path( 166 | guardrail: Guardrail, mocker: MockerFixture 167 | ) -> None: 168 | """ 169 | Test that train() passes None for output_path if not provided. 170 | Args: 171 | guardrail (Guardrail): The Guardrail instance. 172 | mocker (MockerFixture): The pytest-mock fixture for mocking. 173 | """ 174 | 175 | instructions = ["test"] 176 | mock_output = TrainOutput(global_step=5, training_loss=0.5, metrics={}) 177 | train_pipeline_mock = mocker.patch.object( 178 | guardrail, "_train_pipeline", return_value=mock_output 179 | ) 180 | 181 | result = guardrail.train(unsafe_content=instructions) 182 | call_kwargs = train_pipeline_mock.call_args.kwargs 183 | assert call_kwargs["output_path"] is None 184 | assert result is mock_output -------------------------------------------------------------------------------- /tests/unit/named_entity_recognition/test_ner_parse_user_instructions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | from synthex import Synthex 4 | 5 | from artifex.models import NamedEntityRecognition 6 | from artifex.core import NERInstructions 7 | 8 | 9 | @pytest.fixture 10 | def mock_synthex(mocker: MockerFixture) -> Synthex: 11 | """ 12 | Create a mock Synthex instance. 13 | Args: 14 | mocker: pytest-mock fixture for creating mocks. 15 | Returns: 16 | Synthex: A mocked Synthex instance. 
17 | """ 18 | 19 | return mocker.MagicMock(spec=Synthex) 20 | 21 | 22 | @pytest.fixture 23 | def ner_instance( 24 | mocker: MockerFixture, mock_synthex: Synthex 25 | ) -> NamedEntityRecognition: 26 | """ 27 | Create a NamedEntityRecognition instance with mocked dependencies. 28 | Args: 29 | mocker: pytest-mock fixture for creating mocks. 30 | mock_synthex: Mocked Synthex instance. 31 | Returns: 32 | NamedEntityRecognition: An instance with mocked dependencies. 33 | """ 34 | 35 | # Mock AutoTokenizer and AutoModelForTokenClassification imports 36 | mocker.patch( 37 | "artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer.from_pretrained" 38 | ) 39 | mocker.patch( 40 | "artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification.from_pretrained" 41 | ) 42 | 43 | return NamedEntityRecognition(mock_synthex) 44 | 45 | 46 | @pytest.mark.unit 47 | def test_parse_single_entity_tag( 48 | ner_instance: NamedEntityRecognition 49 | ): 50 | """ 51 | Test parsing user instructions with a single named entity tag. 52 | Args: 53 | ner_instance: Fixture providing NamedEntityRecognition instance. 54 | """ 55 | 56 | user_instructions = NERInstructions( 57 | named_entity_tags={"PERSON": "A person's name"}, 58 | domain="medical records" 59 | ) 60 | 61 | result = ner_instance._parse_user_instructions(user_instructions) 62 | 63 | assert len(result) == 2 64 | assert result[0] == "PERSON: A person's name" 65 | assert result[1] == "medical records" 66 | 67 | 68 | @pytest.mark.unit 69 | def test_parse_multiple_entity_tags( 70 | ner_instance: NamedEntityRecognition 71 | ): 72 | """ 73 | Test parsing user instructions with multiple named entity tags. 74 | Args: 75 | ner_instance: Fixture providing NamedEntityRecognition instance. 
76 | """ 77 | 78 | user_instructions = NERInstructions( 79 | named_entity_tags={ 80 | "PERSON": "A person's name", 81 | "LOCATION": "A geographical location", 82 | "ORGANIZATION": "A company or institution" 83 | }, 84 | domain="news articles" 85 | ) 86 | 87 | result = ner_instance._parse_user_instructions(user_instructions) 88 | 89 | assert len(result) == 4 90 | assert "PERSON: A person's name" in result 91 | assert "LOCATION: A geographical location" in result 92 | assert "ORGANIZATION: A company or institution" in result 93 | assert result[-1] == "news articles" 94 | 95 | 96 | @pytest.mark.unit 97 | def test_parse_empty_entity_tags( 98 | ner_instance: NamedEntityRecognition 99 | ): 100 | """ 101 | Test parsing user instructions with no named entity tags: 102 | Args: 103 | ner_instance: Fixture providing NamedEntityRecognition instance. 104 | """ 105 | 106 | user_instructions = NERInstructions( 107 | named_entity_tags={}, 108 | domain="general text" 109 | ) 110 | 111 | result = ner_instance._parse_user_instructions(user_instructions) 112 | 113 | assert len(result) == 1 114 | assert result[0] == "general text" 115 | 116 | 117 | @pytest.mark.unit 118 | def test_parse_entity_tags_ordering( 119 | ner_instance: NamedEntityRecognition 120 | ): 121 | """ 122 | Test that domain is always the last element in the result. 123 | Args: 124 | ner_instance: Fixture providing NamedEntityRecognition instance. 
125 | """ 126 | 127 | user_instructions = NERInstructions( 128 | named_entity_tags={ 129 | "DATE": "A date or time reference", 130 | "MONEY": "Monetary amounts" 131 | }, 132 | domain="financial reports" 133 | ) 134 | 135 | result = ner_instance._parse_user_instructions(user_instructions) 136 | 137 | # Domain should always be last 138 | assert result[-1] == "financial reports" 139 | # All other elements should be entity tags 140 | assert all(": " in item for item in result[:-1]) 141 | 142 | 143 | @pytest.mark.unit 144 | def test_parse_entity_tags_with_special_characters( 145 | ner_instance: NamedEntityRecognition 146 | ): 147 | """ 148 | Test parsing entity tags and descriptions with special characters. 149 | Args: 150 | ner_instance: Fixture providing NamedEntityRecognition instance. 151 | """ 152 | 153 | user_instructions = NERInstructions( 154 | named_entity_tags={ 155 | "EMAIL": "An email address (e.g., user@example.com)", 156 | "PHONE": "A phone number (+1-555-1234)" 157 | }, 158 | domain="customer support tickets" 159 | ) 160 | 161 | result = ner_instance._parse_user_instructions(user_instructions) 162 | 163 | assert len(result) == 3 164 | assert "EMAIL: An email address (e.g., user@example.com)" in result 165 | assert "PHONE: A phone number (+1-555-1234)" in result 166 | assert result[-1] == "customer support tickets" 167 | 168 | 169 | @pytest.mark.unit 170 | def test_parse_entity_tags_with_long_descriptions( 171 | ner_instance: NamedEntityRecognition 172 | ): 173 | """ 174 | Test parsing entity tags with long, detailed descriptions. 175 | Args: 176 | ner_instance: Fixture providing NamedEntityRecognition instance. 
177 | """ 178 | 179 | long_description = ( 180 | "A product name including brand, model, version, " 181 | "and any other identifying information" 182 | ) 183 | 184 | user_instructions = NERInstructions( 185 | named_entity_tags={"PRODUCT": long_description}, 186 | domain="e-commerce product reviews" 187 | ) 188 | 189 | result = ner_instance._parse_user_instructions(user_instructions) 190 | 191 | assert len(result) == 2 192 | assert result[0] == f"PRODUCT: {long_description}" 193 | assert result[-1] == "e-commerce product reviews" 194 | 195 | 196 | @pytest.mark.unit 197 | def test_parse_preserves_tag_description_format( 198 | ner_instance: NamedEntityRecognition 199 | ): 200 | """ 201 | Test that the method preserves the 'tag: description' format. 202 | Args: 203 | ner_instance: Fixture providing NamedEntityRecognition instance. 204 | """ 205 | 206 | user_instructions = NERInstructions( 207 | named_entity_tags={ 208 | "SKILL": "A technical or professional skill", 209 | "EDUCATION": "Educational qualification or institution" 210 | }, 211 | domain="resume parsing" 212 | ) 213 | 214 | result = ner_instance._parse_user_instructions(user_instructions) 215 | 216 | # Check format of tag-description pairs 217 | for item in result[:-1]: 218 | assert ": " in item 219 | tag, description = item.split(": ", 1) 220 | assert tag.isupper() or tag.replace("_", "").isupper() 221 | assert len(description) > 0 222 | 223 | 224 | @pytest.mark.unit 225 | def test_parse_with_domain_containing_special_chars( 226 | ner_instance: NamedEntityRecognition 227 | ): 228 | """ 229 | Test parsing when domain contains special characters. 230 | Args: 231 | ner_instance: Fixture providing NamedEntityRecognition instance. 
232 | """ 233 | 234 | user_instructions = NERInstructions( 235 | named_entity_tags={"DRUG": "Pharmaceutical drug name"}, 236 | domain="medical records (patient: John Doe, ID: 12345)" 237 | ) 238 | 239 | result = ner_instance._parse_user_instructions(user_instructions) 240 | 241 | assert len(result) == 2 242 | assert result[0] == "DRUG: Pharmaceutical drug name" 243 | assert result[1] == "medical records (patient: John Doe, ID: 12345)" 244 | 245 | 246 | @pytest.mark.unit 247 | def test_parse_return_type( 248 | ner_instance: NamedEntityRecognition 249 | ): 250 | """ 251 | Test that the method returns a list of strings. 252 | Args: 253 | ner_instance: Fixture providing NamedEntityRecognition instance. 254 | """ 255 | 256 | user_instructions = NERInstructions( 257 | named_entity_tags={"GPE": "Geo-political entity"}, 258 | domain="news corpus" 259 | ) 260 | 261 | result = ner_instance._parse_user_instructions(user_instructions) 262 | 263 | assert isinstance(result, list) 264 | assert all(isinstance(item, str) for item in result) -------------------------------------------------------------------------------- /tests/unit/reranker/test_rr_get_data_gen_instr.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | import pytest 3 | from pytest_mock import MockerFixture 4 | 5 | from artifex.models import Reranker 6 | from artifex.config import config 7 | 8 | 9 | @pytest.fixture(scope="function", autouse=True) 10 | def mock_dependencies(mocker: MockerFixture): 11 | """ 12 | Fixture to mock all external dependencies before any test runs. 13 | This fixture runs automatically for all tests in this module. 14 | Args: 15 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
16 | """ 17 | 18 | # Mock config 19 | mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model") 20 | mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512) 21 | 22 | # Mock AutoTokenizer at the module where it's used 23 | mock_tokenizer = mocker.MagicMock() 24 | mocker.patch( 25 | "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained", 26 | return_value=mock_tokenizer 27 | ) 28 | 29 | # Mock AutoModelForSequenceClassification at the module where it's used 30 | mock_model = mocker.MagicMock() 31 | mocker.patch( 32 | "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained", 33 | return_value=mock_model 34 | ) 35 | 36 | 37 | @pytest.fixture 38 | def mock_synthex(mocker: MockerFixture) -> Synthex: 39 | """ 40 | Fixture to create a mock Synthex instance. 41 | Args: 42 | mocker (MockerFixture): The pytest-mock fixture for mocking. 43 | Returns: 44 | Synthex: A mocked Synthex instance. 45 | """ 46 | 47 | return mocker.MagicMock(spec=Synthex) 48 | 49 | 50 | @pytest.fixture 51 | def mock_reranker(mock_synthex: Synthex) -> Reranker: 52 | """ 53 | Fixture to create a Reranker instance with mocked dependencies. 54 | Args: 55 | mock_synthex (Synthex): A mocked Synthex instance. 56 | Returns: 57 | Reranker: An instance of the Reranker model with mocked dependencies. 58 | """ 59 | 60 | return Reranker(mock_synthex) 61 | 62 | 63 | @pytest.mark.unit 64 | def test_get_data_gen_instr_success(mock_reranker: Reranker): 65 | """ 66 | Test that _get_data_gen_instr correctly formats system instructions with the domain. 67 | Args: 68 | mock_reranker (Reranker): The Reranker instance to test. 
69 | """ 70 | 71 | domain = "scientific research" 72 | user_instructions = [domain] 73 | 74 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 75 | 76 | # Assert that the result is a list 77 | assert isinstance(combined_instr, list) 78 | 79 | # The length should equal the number of system instructions 80 | assert len(combined_instr) == len(mock_reranker._system_data_gen_instr) 81 | 82 | # The domain should be formatted into the first system instruction 83 | assert domain in combined_instr[0] 84 | assert f"following domain(s): {domain}" in combined_instr[0] 85 | 86 | 87 | @pytest.mark.unit 88 | def test_get_data_gen_instr_formats_all_instructions(mock_reranker: Reranker): 89 | """ 90 | Test that all system instructions are properly formatted with the domain. 91 | Args: 92 | mock_reranker (Reranker): The Reranker instance to test. 93 | """ 94 | 95 | domain = "e-commerce products" 96 | user_instructions = [domain] 97 | 98 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 99 | 100 | # Verify that {domain} placeholder is replaced in all instructions 101 | for instr in combined_instr: 102 | assert "{domain}" not in instr 103 | 104 | # Check that domain appears in the first instruction 105 | assert domain in combined_instr[0] 106 | 107 | 108 | @pytest.mark.unit 109 | def test_get_data_gen_instr_preserves_instruction_count(mock_reranker: Reranker): 110 | """ 111 | Test that the number of instructions matches the system instructions. 112 | Args: 113 | mock_reranker (Reranker): The Reranker instance to test. 
114 | """ 115 | 116 | domain = "customer support" 117 | user_instructions = [domain] 118 | 119 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 120 | 121 | # Should have exactly the same number as system instructions 122 | assert len(combined_instr) == len(mock_reranker._system_data_gen_instr) 123 | 124 | 125 | @pytest.mark.unit 126 | def test_get_data_gen_instr_with_different_domains(mock_reranker: Reranker): 127 | """ 128 | Test that _get_data_gen_instr works with different domain strings. 129 | Args: 130 | mock_reranker (Reranker): The Reranker instance to test. 131 | """ 132 | 133 | domains = [ 134 | "medical research", 135 | "legal documents", 136 | "news articles", 137 | "technical documentation" 138 | ] 139 | 140 | for domain in domains: 141 | user_instructions = [domain] 142 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 143 | 144 | assert isinstance(combined_instr, list) 145 | assert domain in combined_instr[0] 146 | assert len(combined_instr) == len(mock_reranker._system_data_gen_instr) 147 | 148 | 149 | @pytest.mark.unit 150 | def test_get_data_gen_instr_does_not_modify_original_list(mock_reranker: Reranker): 151 | """ 152 | Test that _get_data_gen_instr does not modify the original system instructions. 153 | Args: 154 | mock_reranker (Reranker): The Reranker instance to test. 155 | """ 156 | 157 | domain = "financial data" 158 | user_instructions = [domain] 159 | original_system_instr = mock_reranker._system_data_gen_instr.copy() 160 | 161 | mock_reranker._get_data_gen_instr(user_instructions) 162 | 163 | # Verify original system instructions are unchanged 164 | assert mock_reranker._system_data_gen_instr == original_system_instr 165 | 166 | 167 | @pytest.mark.unit 168 | def test_get_data_gen_instr_validation_failure(mock_reranker: Reranker): 169 | """ 170 | Test that _get_data_gen_instr raises ValidationError with invalid input. 171 | Args: 172 | mock_reranker (Reranker): The Reranker instance to test. 
173 | """ 174 | 175 | from artifex.core import ValidationError 176 | 177 | with pytest.raises(ValidationError): 178 | mock_reranker._get_data_gen_instr("invalid input") 179 | 180 | 181 | @pytest.mark.unit 182 | def test_get_data_gen_instr_with_empty_list(mock_reranker: Reranker): 183 | """ 184 | Test that _get_data_gen_instr handles empty list. 185 | Args: 186 | mock_reranker (Reranker): The Reranker instance to test. 187 | """ 188 | 189 | from artifex.core import ValidationError 190 | 191 | with pytest.raises((ValidationError, IndexError)): 192 | mock_reranker._get_data_gen_instr([]) 193 | 194 | 195 | @pytest.mark.unit 196 | def test_get_data_gen_instr_only_uses_first_element(mock_reranker: Reranker): 197 | """ 198 | Test that _get_data_gen_instr only uses the first element as domain, 199 | ignoring any additional elements. 200 | Args: 201 | mock_reranker (Reranker): The Reranker instance to test. 202 | """ 203 | 204 | domain = "healthcare" 205 | extra_data = ["extra1", "extra2"] 206 | user_instructions = [domain] + extra_data 207 | 208 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 209 | 210 | # Should only format with the first element (domain) 211 | assert domain in combined_instr[0] 212 | # Extra elements should not appear in the instructions 213 | for extra in extra_data: 214 | assert extra not in combined_instr[0] 215 | 216 | 217 | @pytest.mark.unit 218 | def test_get_data_gen_instr_with_special_characters_in_domain(mock_reranker: Reranker): 219 | """ 220 | Test that _get_data_gen_instr correctly handles domains with special characters. 221 | Args: 222 | mock_reranker (Reranker): The Reranker instance to test. 
223 | """ 224 | 225 | domain = "Q&A for tech support (beta)" 226 | user_instructions = [domain] 227 | 228 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 229 | 230 | # Domain with special characters should be properly included 231 | assert domain in combined_instr[0] 232 | assert isinstance(combined_instr, list) 233 | assert len(combined_instr) == len(mock_reranker._system_data_gen_instr) 234 | 235 | 236 | @pytest.mark.unit 237 | def test_get_data_gen_instr_returns_new_list(mock_reranker: Reranker): 238 | """ 239 | Test that _get_data_gen_instr returns a new list, not the original. 240 | Args: 241 | mock_reranker (Reranker): The Reranker instance to test. 242 | """ 243 | 244 | domain = "travel booking" 245 | user_instructions = [domain] 246 | 247 | combined_instr = mock_reranker._get_data_gen_instr(user_instructions) 248 | 249 | # Modifying the result should not affect system instructions 250 | combined_instr.append("new instruction") 251 | 252 | assert len(mock_reranker._system_data_gen_instr) != len(combined_instr) 253 | assert "new instruction" not in mock_reranker._system_data_gen_instr -------------------------------------------------------------------------------- /tests/unit/reranker/test_rr_parse_user_instructions.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | import pytest 3 | from pytest_mock import MockerFixture 4 | 5 | from artifex.models import Reranker 6 | from artifex.config import config 7 | 8 | 9 | @pytest.fixture(scope="function", autouse=True) 10 | def mock_dependencies(mocker: MockerFixture): 11 | """ 12 | Fixture to mock all external dependencies before any test runs. 13 | This fixture runs automatically for all tests in this module. 14 | Args: 15 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
16 | """ 17 | 18 | # Mock config 19 | mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model") 20 | mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512) 21 | 22 | # Mock AutoTokenizer at the module where it's used 23 | mock_tokenizer = mocker.MagicMock() 24 | mocker.patch( 25 | "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained", 26 | return_value=mock_tokenizer 27 | ) 28 | 29 | # Mock AutoModelForSequenceClassification at the module where it's used 30 | mock_model = mocker.MagicMock() 31 | mocker.patch( 32 | "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained", 33 | return_value=mock_model 34 | ) 35 | 36 | 37 | @pytest.fixture 38 | def mock_synthex(mocker: MockerFixture) -> Synthex: 39 | """ 40 | Fixture to create a mock Synthex instance. 41 | Args: 42 | mocker (MockerFixture): The pytest-mock fixture for mocking. 43 | Returns: 44 | Synthex: A mocked Synthex instance. 45 | """ 46 | 47 | return mocker.MagicMock(spec=Synthex) 48 | 49 | 50 | @pytest.fixture 51 | def mock_reranker(mock_synthex: Synthex) -> Reranker: 52 | """ 53 | Fixture to create a Reranker instance with mocked dependencies. 54 | Args: 55 | mock_synthex (Synthex): A mocked Synthex instance. 56 | Returns: 57 | Reranker: An instance of the Reranker model with mocked dependencies. 58 | """ 59 | 60 | return Reranker(mock_synthex) 61 | 62 | 63 | @pytest.mark.unit 64 | def test_parse_user_instructions_returns_list(mock_reranker: Reranker): 65 | """ 66 | Test that _parse_user_instructions returns a list. 67 | Args: 68 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 
69 | """ 70 | 71 | user_instructions = "scientific research papers" 72 | 73 | result = mock_reranker._parse_user_instructions(user_instructions) 74 | 75 | assert isinstance(result, list) 76 | 77 | 78 | @pytest.mark.unit 79 | def test_parse_user_instructions_single_element(mock_reranker: Reranker): 80 | """ 81 | Test that _parse_user_instructions returns a list with a single element. 82 | Args: 83 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 84 | """ 85 | 86 | user_instructions = "medical documents" 87 | 88 | result = mock_reranker._parse_user_instructions(user_instructions) 89 | 90 | assert len(result) == 1 91 | 92 | 93 | @pytest.mark.unit 94 | def test_parse_user_instructions_contains_original_string(mock_reranker: Reranker): 95 | """ 96 | Test that the returned list contains the original user instructions string. 97 | Args: 98 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 99 | """ 100 | 101 | user_instructions = "legal documents and contracts" 102 | 103 | result = mock_reranker._parse_user_instructions(user_instructions) 104 | 105 | assert result[0] == user_instructions 106 | 107 | 108 | @pytest.mark.unit 109 | def test_parse_user_instructions_with_empty_string(mock_reranker: Reranker): 110 | """ 111 | Test that _parse_user_instructions handles an empty string. 112 | Args: 113 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 114 | """ 115 | 116 | user_instructions = "" 117 | 118 | result = mock_reranker._parse_user_instructions(user_instructions) 119 | 120 | assert isinstance(result, list) 121 | assert len(result) == 1 122 | assert result[0] == "" 123 | 124 | 125 | @pytest.mark.unit 126 | def test_parse_user_instructions_with_whitespace(mock_reranker: Reranker): 127 | """ 128 | Test that _parse_user_instructions preserves whitespace in the string. 129 | Args: 130 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 
131 | """ 132 | 133 | user_instructions = " news articles with spaces " 134 | 135 | result = mock_reranker._parse_user_instructions(user_instructions) 136 | 137 | assert len(result) == 1 138 | assert result[0] == user_instructions 139 | 140 | 141 | @pytest.mark.unit 142 | def test_parse_user_instructions_with_multiline_string(mock_reranker: Reranker): 143 | """ 144 | Test that _parse_user_instructions handles multiline strings. 145 | Args: 146 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 147 | """ 148 | 149 | user_instructions = """technical documentation 150 | and user manuals""" 151 | 152 | result = mock_reranker._parse_user_instructions(user_instructions) 153 | 154 | assert isinstance(result, list) 155 | assert len(result) == 1 156 | assert result[0] == user_instructions 157 | 158 | 159 | @pytest.mark.unit 160 | def test_parse_user_instructions_with_special_characters(mock_reranker: Reranker): 161 | """ 162 | Test that _parse_user_instructions handles special characters. 163 | Args: 164 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 165 | """ 166 | 167 | user_instructions = "Q&A for tech support (beta) - version 2.0!" 168 | 169 | result = mock_reranker._parse_user_instructions(user_instructions) 170 | 171 | assert len(result) == 1 172 | assert result[0] == user_instructions 173 | 174 | 175 | @pytest.mark.unit 176 | def test_parse_user_instructions_with_long_string(mock_reranker: Reranker): 177 | """ 178 | Test that _parse_user_instructions handles long strings. 179 | Args: 180 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 
181 | """ 182 | 183 | user_instructions = "A" * 1000 184 | 185 | result = mock_reranker._parse_user_instructions(user_instructions) 186 | 187 | assert len(result) == 1 188 | assert result[0] == user_instructions 189 | assert len(result[0]) == 1000 190 | 191 | 192 | @pytest.mark.unit 193 | def test_parse_user_instructions_validation_failure_with_list(mock_reranker: Reranker): 194 | """ 195 | Test that _parse_user_instructions raises ValidationError when given a list instead of string. 196 | Args: 197 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 198 | """ 199 | 200 | from artifex.core import ValidationError 201 | 202 | with pytest.raises(ValidationError): 203 | mock_reranker._parse_user_instructions(["not", "a", "string"]) 204 | 205 | 206 | @pytest.mark.unit 207 | def test_parse_user_instructions_validation_failure_with_none(mock_reranker: Reranker): 208 | """ 209 | Test that _parse_user_instructions raises ValidationError when given None. 210 | Args: 211 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 212 | """ 213 | 214 | from artifex.core import ValidationError 215 | 216 | with pytest.raises(ValidationError): 217 | mock_reranker._parse_user_instructions(None) 218 | 219 | 220 | @pytest.mark.unit 221 | def test_parse_user_instructions_validation_failure_with_int(mock_reranker: Reranker): 222 | """ 223 | Test that _parse_user_instructions raises ValidationError when given an integer. 224 | Args: 225 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 226 | """ 227 | 228 | from artifex.core import ValidationError 229 | 230 | with pytest.raises(ValidationError): 231 | mock_reranker._parse_user_instructions(123) 232 | 233 | 234 | @pytest.mark.unit 235 | def test_parse_user_instructions_does_not_modify_input(mock_reranker: Reranker): 236 | """ 237 | Test that _parse_user_instructions does not modify the input string. 
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    user_instructions = "customer reviews and feedback"
    original = user_instructions

    # Return value intentionally unused: this test only checks for mutation.
    result = mock_reranker._parse_user_instructions(user_instructions)

    # Input string should remain unchanged
    assert user_instructions == original


@pytest.mark.unit
def test_parse_user_instructions_with_unicode(mock_reranker: Reranker):
    """
    Test that _parse_user_instructions handles unicode characters.
    Args:
        mock_reranker (Reranker): The Reranker instance with mocked dependencies.
    """

    user_instructions = "文档分类 и categorización de documentos"

    result = mock_reranker._parse_user_instructions(user_instructions)

    # Non-ASCII text must pass through untouched as a single instruction.
    assert len(result) == 1
    assert result[0] == user_instructions


# --------------------------------------------------------------------------------
# /tests/unit/text_anonymization/test_ta__call__.py:
# Unit tests for TextAnonymization.__call__ — masking of detected PII spans.
# --------------------------------------------------------------------------------

import pytest
from pytest_mock import MockerFixture
from synthex import Synthex
from typing import List  # NOTE(review): unused in this module — candidate for removal

from artifex.models import TextAnonymization
from artifex.config import config


@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Creates a mock Synthex instance.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    return mocker.Mock(spec=Synthex)

@pytest.fixture
def text_anonymization(mock_synthex: Synthex, mocker: MockerFixture) -> TextAnonymization:
    """
    Creates a TextAnonymization instance with mocked dependencies.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        TextAnonymization: An instance of TextAnonymization with mocked parent class.
    """

    # Mock the parent class __init__ to avoid initialization issues.
    # TextAnonymization.__bases__[0] is the NER parent class; its real
    # __init__ would load tokenizer/model weights.
    mocker.patch.object(TextAnonymization.__bases__[0], '__init__', return_value=None)
    instance = TextAnonymization(mock_synthex)
    return instance


@pytest.mark.unit
def test_call_with_single_string_no_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with a single string input when no entities are detected.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__', return_value=[[]]
    )

    input_text = "This is a test sentence."
    result = text_anonymization(input_text)

    # A single string is wrapped in a list before delegating to the parent,
    # and the result is returned unchanged when no entities are found.
    mock_parent_call.assert_called_once_with([input_text])
    assert result == [input_text]


@pytest.mark.unit
def test_call_with_single_string_with_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with a single string input containing PII entities.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # "John" occupies character offsets 11..15 of the input below.
    mock_entity = mocker.Mock()
    mock_entity.entity_group = "PERSON"
    mock_entity.start = 11
    mock_entity.end = 15

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__', return_value=[[mock_entity]]
    )

    input_text = "My name is John and I live in NYC."
    result = text_anonymization(input_text)

    expected_mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected_output = f"My name is {expected_mask} and I live in NYC."

    mock_parent_call.assert_called_once_with([input_text])
    assert result == [expected_output]


@pytest.mark.unit
def test_call_with_list_of_strings(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with a list of strings as input.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # One detected entity per input text: "John" (0..4) and "London" (8..15).
    mock_entity1 = mocker.Mock()
    mock_entity1.entity_group = "PERSON"
    mock_entity1.start = 0
    mock_entity1.end = 4

    mock_entity2 = mocker.Mock()
    mock_entity2.entity_group = "LOCATION"
    mock_entity2.start = 8
    mock_entity2.end = 15

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__',
        return_value=[[mock_entity1], [mock_entity2]]
    )

    input_texts = ["John lives here", "I am in London"]
    result = text_anonymization(input_texts)

    expected_mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected_outputs = [f"{expected_mask} lives here", f"I am in {expected_mask}"]

    # Lists are passed through to the parent without re-wrapping.
    mock_parent_call.assert_called_once_with(input_texts)
    assert result == expected_outputs


@pytest.mark.unit
def test_call_with_custom_entities_to_mask(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with custom entities_to_mask parameter.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_person = mocker.Mock()
    mock_person.entity_group = "PERSON"
    mock_person.start = 0
    mock_person.end = 4

    mock_location = mocker.Mock()
    mock_location.entity_group = "LOCATION"
    mock_location.start = 14
    mock_location.end = 20

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__',
        return_value=[[mock_person, mock_location]]
    )

    input_text = "John lives in London"
    # Only PERSON is masked; the LOCATION detection must be left in place.
    result = text_anonymization(input_text, entities_to_mask=["PERSON"])

    expected_mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected_output = f"{expected_mask} lives in London"

    mock_parent_call.assert_called_once_with([input_text])
    assert result == [expected_output]


@pytest.mark.unit
def test_call_with_custom_mask_token(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with a custom mask_token parameter.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_entity = mocker.Mock()
    mock_entity.entity_group = "PERSON"
    mock_entity.start = 0
    mock_entity.end = 4

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__', return_value=[[mock_entity]]
    )

    input_text = "John is here"
    custom_mask = "[REDACTED]"
    # The caller-supplied token replaces config.DEFAULT_TEXT_ANONYM_MASK.
    result = text_anonymization(input_text, mask_token=custom_mask)

    expected_output = f"{custom_mask} is here"

    mock_parent_call.assert_called_once_with([input_text])
    assert result == [expected_output]


@pytest.mark.unit
def test_call_with_invalid_entity_to_mask(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ raises ValueError when invalid entity is in entities_to_mask.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # Validation is expected to fail before any parent-class call happens,
    # so no __call__ patch is needed here.
    with pytest.raises(ValueError) as exc_info:
        text_anonymization("test text", entities_to_mask=["INVALID_ENTITY"])

    assert "INVALID_ENTITY" in str(exc_info.value)
    assert "cannot be masked" in str(exc_info.value)


@pytest.mark.unit
def test_call_with_multiple_entities_same_text(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with multiple entities in the same text.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    # "John" at 0..4 and the phone number at 17..29 of the input below.
    mock_entity1 = mocker.Mock()
    mock_entity1.entity_group = "PERSON"
    mock_entity1.start = 0
    mock_entity1.end = 4

    mock_entity2 = mocker.Mock()
    mock_entity2.entity_group = "PHONE_NUMBER"
    mock_entity2.start = 17
    mock_entity2.end = 29

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__',
        return_value=[[mock_entity1, mock_entity2]]
    )

    input_text = "John's number is 123-456-7890"
    result = text_anonymization(input_text)

    expected_mask = config.DEFAULT_TEXT_ANONYM_MASK
    expected_output = f"{expected_mask}'s number is {expected_mask}"

    mock_parent_call.assert_called_once_with([input_text])
    assert result == [expected_output]


@pytest.mark.unit
def test_call_with_empty_string(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests __call__ with an empty string input.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_parent_call = mocker.patch.object(
        TextAnonymization.__bases__[0], '__call__', return_value=[[]]
    )

    input_text = ""
    result = text_anonymization(input_text)

    # Empty input should round-trip unchanged (as a one-element list).
    mock_parent_call.assert_called_once_with([input_text])
    assert result == [""]


# --------------------------------------------------------------------------------
# /tests/unit/named_entity_recognition/test_ner_post_process_synthetic_dataset.py:
# Unit tests for NamedEntityRecognition._post_process_synthetic_dataset.
# --------------------------------------------------------------------------------

import pytest
import pandas as pd
from pytest_mock import MockerFixture
from datasets import ClassLabel
from typing import Any

from artifex.models import NamedEntityRecognition


@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Any:
    """
    Create a mock Synthex instance.
    Args:
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        Mock Synthex instance.
    """

    return mocker.Mock()


@pytest.fixture
def ner_instance(mock_synthex: Any, mocker: MockerFixture) -> NamedEntityRecognition:
    """
    Create a NamedEntityRecognition instance with mocked dependencies.
    Args:
        mock_synthex: Mocked Synthex instance.
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        NamedEntityRecognition instance with mocked components.
    """

    # Mock AutoTokenizer and AutoModelForTokenClassification so no real
    # HuggingFace weights are downloaded during unit tests.
    mocker.patch("artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer")
    mocker.patch("artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification")

    ner = NamedEntityRecognition(mock_synthex)
    # Set up labels with typical NER tags (BIO scheme: PERSON and LOCATION).
    ner._labels = ClassLabel(names=["O", "B-PERSON", "I-PERSON", "B-LOCATION", "I-LOCATION"])
    return ner


@pytest.mark.unit
def test_cleanup_removes_invalid_labels(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that rows with invalid labels are removed.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    # Create test CSV with one well-formed and one malformed labels cell.
    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["John lives in Paris", "Invalid data"],
        "labels": ["John: PERSON, Paris: LOCATION", "not a valid format"]
    })
    df.to_csv(test_csv, index=False)

    # Run cleanup (post-processing rewrites the CSV in place).
    ner_instance._post_process_synthetic_dataset(str(test_csv))

    # Load result
    result_df = pd.read_csv(test_csv)

    # Should have removed the invalid row
    assert len(result_df) == 1
    assert "John lives in Paris" in result_df["text"].values


@pytest.mark.unit
def test_cleanup_converts_to_bio_format(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that labels are correctly converted to BIO format.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["John Smith lives in New York"],
        "labels": ["John Smith: PERSON, New York: LOCATION"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)
    assert len(result_df) == 1

    # Parse the labels back from string representation
    # (the CSV stores the per-token tag list as a repr'd Python list).
    import ast
    labels = ast.literal_eval(result_df["labels"].iloc[0])

    # Check BIO tags: B- opens an entity, I- continues it, O is outside.
    assert labels[0] == "B-PERSON"    # John
    assert labels[1] == "I-PERSON"    # Smith
    assert labels[2] == "O"           # lives
    assert labels[3] == "O"           # in
    assert labels[4] == "B-LOCATION"  # New
    assert labels[5] == "I-LOCATION"  # York


@pytest.mark.unit
def test_cleanup_removes_empty_labels(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that rows with empty labels are removed.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["John lives here", "No entities here"],
        "labels": ["John: PERSON", ""]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)

    # Should only keep the row with actual entities
    assert len(result_df) == 1
    assert "John lives here" in result_df["text"].values


@pytest.mark.unit
def test_cleanup_removes_only_o_tags(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that rows with only 'O' tags are removed.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["John lives here", "No entities here"],
        "labels": ["John: PERSON", "['O', 'O', 'O', 'O']"]
    })

    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)

    # Should only keep the row with actual entities
    assert len(result_df) == 1
    assert "John lives here" in result_df["text"].values


@pytest.mark.unit
def test_cleanup_removes_invalid_tags(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that rows with invalid named entity tags are removed.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["John works at Google", "Jane lives in Paris"],
        "labels": ["John: PERSON, Google: ORGANIZATION", "Jane: PERSON, Paris: LOCATION"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)

    # Should remove row with ORGANIZATION tag (not in the fixture's ClassLabel set)
    assert len(result_df) == 1
    assert "Jane lives in Paris" in result_df["text"].values


@pytest.mark.unit
def test_cleanup_handles_multi_word_entities(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that multi-word entities are correctly handled.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["The Eiffel Tower is in Paris"],
        "labels": ["Eiffel Tower: LOCATION, Paris: LOCATION"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)
    assert len(result_df) == 1

    import ast
    labels = ast.literal_eval(result_df["labels"].iloc[0])

    # Check that multi-word entity is tagged correctly
    assert labels[1] == "B-LOCATION"  # Eiffel
    assert labels[2] == "I-LOCATION"  # Tower


@pytest.mark.unit
def test_cleanup_handles_punctuation(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that cleanup handles punctuation in entities correctly.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    df = pd.DataFrame({
        "text": ["Dr. John Smith, PhD lives here"],
        "labels": ["Dr. John Smith, PhD: PERSON"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)
    # NOTE(review): `len(...) >= 0` is always true, so this only asserts the
    # call completed without raising. Consider pinning the expected outcome.
    assert len(result_df) >= 0  # Should either process or remove based on punctuation handling


@pytest.mark.unit
def test_cleanup_case_insensitive_matching(
    ner_instance: NamedEntityRecognition, mocker: MockerFixture, tmp_path: Any
):
    """
    Test that entity matching is case-insensitive.
    Args:
        ner_instance: NamedEntityRecognition instance.
        mocker: pytest-mock fixture.
        tmp_path: pytest temporary directory fixture.
    """

    test_csv = tmp_path / "test_dataset.csv"
    # Entity labels deliberately differ in case from the text tokens.
    df = pd.DataFrame({
        "text": ["JOHN lives in paris"],
        "labels": ["john: PERSON, PARIS: LOCATION"]
    })
    df.to_csv(test_csv, index=False)

    ner_instance._post_process_synthetic_dataset(str(test_csv))

    result_df = pd.read_csv(test_csv)
    assert len(result_df) == 1

    import ast
    labels = ast.literal_eval(result_df["labels"].iloc[0])
    assert labels[0] == "B-PERSON"
    assert labels[3] == "B-LOCATION"


# --------------------------------------------------------------------------------
# /tests/unit/named_entity_recognition/test_ner_get_data_gen_instr.py:
# Unit tests for NamedEntityRecognition._get_data_gen_instr (prompt templating).
# --------------------------------------------------------------------------------

import pytest
from pytest_mock import MockerFixture
from synthex import Synthex

from artifex.models import NamedEntityRecognition


@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Create a mock Synthex instance.
    Args:
        mocker: pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    return mocker.MagicMock(spec=Synthex)


@pytest.fixture
def ner_instance(
    mocker: MockerFixture, mock_synthex: Synthex
) -> NamedEntityRecognition:
    """
    Create a NamedEntityRecognition instance with mocked dependencies.
    Args:
        mocker: pytest-mock fixture for creating mocks.
        mock_synthex: Mocked Synthex instance.
    Returns:
        NamedEntityRecognition: An instance with mocked dependencies.
    """

    # Mock AutoTokenizer and AutoModelForTokenClassification imports so the
    # constructor does not download pretrained weights.
    mocker.patch(
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoTokenizer.from_pretrained"
    )
    mocker.patch(
        "artifex.models.named_entity_recognition.named_entity_recognition.AutoModelForTokenClassification.from_pretrained"
    )

    return NamedEntityRecognition(mock_synthex)


@pytest.mark.unit
def test_get_data_gen_instr_single_entity(
    ner_instance: NamedEntityRecognition
):
    """
    Test data generation instruction formatting with a single entity tag.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    # Convention: all elements but the last are entity tags; the last is the domain.
    user_instr = [
        "PERSON: A person's name",
        "medical records"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    # Should return formatted system instructions
    assert len(result) == len(ner_instance._system_data_gen_instr)

    # Check that domain was properly formatted
    assert any("medical records" in instr for instr in result)

    # Check that named entity tags were properly formatted
    assert any("PERSON: A person's name" in instr for instr in result)


@pytest.mark.unit
def test_get_data_gen_instr_multiple_entities(
    ner_instance: NamedEntityRecognition
):
    """
    Test data generation instruction formatting with multiple entity tags.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "PERSON: A person's name",
        "LOCATION: A geographical location",
        "ORGANIZATION: A company or institution",
        "news articles"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    assert len(result) == len(ner_instance._system_data_gen_instr)

    # Check domain formatting
    assert any("news articles" in instr for instr in result)

    # Check all entity tags are included
    formatted_tags_str = " ".join(result)
    assert "PERSON: A person's name" in formatted_tags_str
    assert "LOCATION: A geographical location" in formatted_tags_str
    assert "ORGANIZATION: A company or institution" in formatted_tags_str


@pytest.mark.unit
def test_get_data_gen_instr_domain_only(
    ner_instance: NamedEntityRecognition
):
    """
    Test data generation instruction formatting with only domain (no entity tags).
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = ["general text"]

    result = ner_instance._get_data_gen_instr(user_instr)

    assert len(result) == len(ner_instance._system_data_gen_instr)

    # Check domain was formatted
    assert any("general text" in instr for instr in result)

    # Named entity tags should be an empty list
    assert any("[]" in instr for instr in result)


@pytest.mark.unit
def test_get_data_gen_instr_format_placeholders(
    ner_instance: NamedEntityRecognition
):
    """
    Test that all placeholders in system instructions are properly replaced.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "EMAIL: An email address",
        "customer support tickets"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    # No placeholders should remain in the result
    for instr in result:
        assert "{domain}" not in instr
        assert "{named_entity_tags}" not in instr


@pytest.mark.unit
def test_get_data_gen_instr_preserves_instruction_count(
    ner_instance: NamedEntityRecognition
):
    """
    Test that the number of instructions matches the system instruction count.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "PRODUCT: A product name",
        "PRICE: A monetary amount",
        "e-commerce reviews"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    # Should have exactly the same number of instructions as system template
    assert len(result) == len(ner_instance._system_data_gen_instr)


@pytest.mark.unit
def test_get_data_gen_instr_domain_extraction(
    ner_instance: NamedEntityRecognition
):
    """
    Test that domain is correctly extracted from the last element.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "DATE: A date reference",
        "MONEY: Monetary amounts",
        "TIME: A time reference",
        "financial reports and documents"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    # Domain should be in the formatted instructions
    domain_present = any(
        "financial reports and documents" in instr
        for instr in result
    )
    assert domain_present


@pytest.mark.unit
def test_get_data_gen_instr_entity_tags_extraction(
    ner_instance: NamedEntityRecognition
):
    """
    Test that entity tags are correctly extracted from all elements except the last.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "DRUG: Pharmaceutical drug name",
        "DOSAGE: Drug dosage information",
        "medical prescriptions"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    result_str = " ".join(result)

    # Both entity tags should be present
    assert "DRUG: Pharmaceutical drug name" in result_str
    assert "DOSAGE: Drug dosage information" in result_str


@pytest.mark.unit
def test_get_data_gen_instr_return_type(
    ner_instance: NamedEntityRecognition
):
    """
    Test that the method returns a list of strings.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "GPE: Geo-political entity",
        "news corpus"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    assert isinstance(result, list)
    assert all(isinstance(instr, str) for instr in result)


@pytest.mark.unit
def test_get_data_gen_instr_special_characters_in_domain(
    ner_instance: NamedEntityRecognition
):
    """
    Test handling of special characters in domain.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "PERSON: A person",
        "domain with (special) characters & symbols!"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    # Special characters should be preserved
    assert any(
        "domain with (special) characters & symbols!" in instr
        for instr in result
    )


@pytest.mark.unit
def test_get_data_gen_instr_special_characters_in_tags(
    ner_instance: NamedEntityRecognition
):
    """
    Test handling of special characters in entity tag descriptions.
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = [
        "EMAIL: An email (e.g., user@example.com)",
        "PHONE: A phone number (+1-555-1234)",
        "contact information"
    ]

    result = ner_instance._get_data_gen_instr(user_instr)

    result_str = " ".join(result)

    # Special characters in descriptions should be preserved
    assert "user@example.com" in result_str
    assert "+1-555-1234" in result_str


@pytest.mark.unit
def test_get_data_gen_instr_empty_entity_list(
    ner_instance: NamedEntityRecognition
):
    """
    Test with empty entity list (only domain provided).
    Args:
        ner_instance: Fixture providing NamedEntityRecognition instance.
    """

    user_instr = ["simple domain"]

    result = ner_instance._get_data_gen_instr(user_instr)

    # Should still return all system instructions
    assert len(result) == len(ner_instance._system_data_gen_instr)

    # Domain should be present
    assert any("simple domain" in instr for instr in result)


# --------------------------------------------------------------------------------
# /tests/unit/text_anonymization/test_ta_train.py:
# Unit tests for TextAnonymization.train — delegation to the NER parent's train().
# --------------------------------------------------------------------------------

import pytest
from pytest_mock import MockerFixture
from synthex import Synthex
from transformers.trainer_utils import TrainOutput
from typing import Optional  # NOTE(review): unused in this module — candidate for removal

from artifex.models import TextAnonymization
from artifex.config import config


@pytest.fixture
def mock_synthex(mocker: MockerFixture) -> Synthex:
    """
    Creates a mock Synthex instance.
    Args:
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        Synthex: A mocked Synthex instance.
    """

    return mocker.Mock(spec=Synthex)


@pytest.fixture
def text_anonymization(mock_synthex: Synthex, mocker: MockerFixture) -> TextAnonymization:
    """
    Creates a TextAnonymization instance with mocked dependencies.
    Args:
        mock_synthex (Synthex): A mocked Synthex instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    Returns:
        TextAnonymization: An instance of TextAnonymization with mocked parent class.
    """

    # Mock the parent class __init__ to avoid initialization issues
    # (avoids loading real model weights in unit tests).
    mocker.patch.object(TextAnonymization.__bases__[0], '__init__', return_value=None)
    instance = TextAnonymization(mock_synthex)
    return instance


@pytest.mark.unit
def test_train_with_default_parameters(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests train() with default parameters.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "healthcare"
    result = text_anonymization.train(domain=domain)

    # The fixed PII entity set that TextAnonymization always trains on.
    expected_pii_entities = {
        "PERSON": "Individual people, fictional characters",
        "LOCATION": "Geographical areas",
        "DATE": "Absolute or relative dates, including years, months and/or days",
        "ADDRESS": "full addresses",
        "PHONE_NUMBER": "telephone numbers",
    }

    mock_parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=None,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert result == mock_train_output


@pytest.mark.unit
def test_train_with_custom_output_path(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests train() with a custom output_path parameter.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "finance"
    output_path = "/custom/path/to/model"
    result = text_anonymization.train(domain=domain, output_path=output_path)

    expected_pii_entities = text_anonymization._pii_entities

    mock_parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=output_path,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert result == mock_train_output


@pytest.mark.unit
def test_train_with_custom_num_samples(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests train() with a custom num_samples parameter.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "legal"
    num_samples = 500
    result = text_anonymization.train(domain=domain, num_samples=num_samples)

    expected_pii_entities = text_anonymization._pii_entities

    mock_parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=None,
        num_samples=num_samples,
        num_epochs=3,
        train_datapoint_examples=None
    )
    assert result == mock_train_output


@pytest.mark.unit
def test_train_with_custom_num_epochs(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests train() with a custom num_epochs parameter.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "customer_service"
    num_epochs = 5
    result = text_anonymization.train(domain=domain, num_epochs=num_epochs)

    expected_pii_entities = text_anonymization._pii_entities

    mock_parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=None,
        num_samples=config.DEFAULT_SYNTHEX_DATAPOINT_NUM,
        num_epochs=num_epochs,
        train_datapoint_examples=None
    )
    assert result == mock_train_output


@pytest.mark.unit
def test_train_with_all_custom_parameters(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests train() with all custom parameters.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "retail"
    output_path = "/path/to/retail/model"
    num_samples = 1000
    num_epochs = 10

    result = text_anonymization.train(
        domain=domain,
        output_path=output_path,
        num_samples=num_samples,
        num_epochs=num_epochs
    )

    expected_pii_entities = text_anonymization._pii_entities

    mock_parent_train.assert_called_once_with(
        named_entities=expected_pii_entities,
        domain=domain,
        output_path=output_path,
        num_samples=num_samples,
        num_epochs=num_epochs,
        train_datapoint_examples=None
    )
    assert result == mock_train_output


@pytest.mark.unit
def test_train_uses_predefined_pii_entities(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests that train() always uses the predefined PII entities.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "education"
    text_anonymization.train(domain=domain)

    # Verify that named_entities is the PII entities dict
    call_kwargs = mock_parent_train.call_args.kwargs
    assert "named_entities" in call_kwargs
    assert call_kwargs["named_entities"] == text_anonymization._pii_entities
    assert "PERSON" in call_kwargs["named_entities"]
    assert "LOCATION" in call_kwargs["named_entities"]
    assert "DATE" in call_kwargs["named_entities"]
    assert "ADDRESS" in call_kwargs["named_entities"]
    assert "PHONE_NUMBER" in call_kwargs["named_entities"]


@pytest.mark.unit
def test_train_sets_train_datapoint_examples_to_none(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests that train() always sets train_datapoint_examples to None.
    Args:
        text_anonymization (TextAnonymization): The TextAnonymization instance.
        mocker (MockerFixture): The pytest-mock fixture for creating mocks.
    """

    mock_train_output = mocker.Mock(spec=TrainOutput)
    mock_parent_train = mocker.patch.object(
        TextAnonymization.__bases__[0], 'train', return_value=mock_train_output
    )

    domain = "technology"
    text_anonymization.train(domain=domain)

    # Verify that train_datapoint_examples is always None
    call_kwargs = mock_parent_train.call_args.kwargs
    assert "train_datapoint_examples" in call_kwargs
    assert call_kwargs["train_datapoint_examples"] is None


@pytest.mark.unit
def test_train_returns_train_output(
    text_anonymization: TextAnonymization, mocker: MockerFixture
):
    """
    Tests that train() returns the TrainOutput from parent class.
280 | Args: 281 | text_anonymization (TextAnonymization): The TextAnonymization instance. 282 | mocker (MockerFixture): The pytest-mock fixture for creating mocks. 283 | """ 284 | 285 | mock_train_output = mocker.Mock(spec=TrainOutput) 286 | mock_train_output.training_loss = 0.05 287 | mock_train_output.metrics = {"accuracy": 0.95} 288 | 289 | mocker.patch.object( 290 | TextAnonymization.__bases__[0], 'train', return_value=mock_train_output 291 | ) 292 | 293 | domain = "insurance" 294 | result = text_anonymization.train(domain=domain) 295 | 296 | assert result == mock_train_output 297 | assert result.training_loss == 0.05 298 | assert result.metrics == {"accuracy": 0.95} -------------------------------------------------------------------------------- /tests/unit/reranker/test_rr__call__.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | import pytest 3 | from pytest_mock import MockerFixture 4 | import torch 5 | 6 | from artifex.models import Reranker 7 | from artifex.config import config 8 | 9 | 10 | @pytest.fixture(scope="function", autouse=True) 11 | def mock_dependencies(mocker: MockerFixture): 12 | """ 13 | Fixture to mock all external dependencies before any test runs. 14 | This fixture runs automatically for all tests in this module. 15 | Args: 16 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
17 | """ 18 | 19 | # Mock config 20 | mocker.patch.object(config, "RERANKER_HF_BASE_MODEL", "mock-reranker-model") 21 | mocker.patch.object(config, "RERANKER_TOKENIZER_MAX_LENGTH", 512) 22 | 23 | # Mock AutoTokenizer at the module where it's used 24 | mock_tokenizer = mocker.MagicMock() 25 | mocker.patch( 26 | "artifex.models.reranker.reranker.AutoTokenizer.from_pretrained", 27 | return_value=mock_tokenizer 28 | ) 29 | 30 | # Mock AutoModelForSequenceClassification at the module where it's used 31 | mock_model = mocker.MagicMock() 32 | mocker.patch( 33 | "artifex.models.reranker.reranker.AutoModelForSequenceClassification.from_pretrained", 34 | return_value=mock_model 35 | ) 36 | 37 | 38 | @pytest.fixture 39 | def mock_synthex(mocker: MockerFixture) -> Synthex: 40 | """ 41 | Fixture to create a mock Synthex instance. 42 | Args: 43 | mocker (MockerFixture): The pytest-mock fixture for mocking. 44 | Returns: 45 | Synthex: A mocked Synthex instance. 46 | """ 47 | 48 | return mocker.MagicMock(spec=Synthex) 49 | 50 | 51 | @pytest.fixture 52 | def mock_reranker(mocker: MockerFixture, mock_synthex: Synthex) -> Reranker: 53 | """ 54 | Fixture to create a Reranker instance with mocked dependencies. 55 | Args: 56 | mocker (MockerFixture): The pytest-mock fixture for mocking. 57 | mock_synthex (Synthex): A mocked Synthex instance. 58 | Returns: 59 | Reranker: An instance of the Reranker model with mocked dependencies. 
60 | """ 61 | 62 | reranker = Reranker(mock_synthex) 63 | 64 | # Mock the tokenizer to return proper inputs 65 | mock_tokenizer_output = { 66 | "input_ids": torch.tensor([[1, 2, 3]]), 67 | "attention_mask": torch.tensor([[1, 1, 1]]) 68 | } 69 | reranker._tokenizer.return_value = mock_tokenizer_output 70 | 71 | # Mock the model output 72 | mock_model_output = mocker.MagicMock() 73 | reranker._model.return_value = mock_model_output 74 | 75 | return reranker 76 | 77 | 78 | @pytest.mark.unit 79 | def test_call_with_single_document( 80 | mock_reranker: Reranker, mocker: MockerFixture 81 | ): 82 | """ 83 | Test that __call__ works correctly with a single document string. 84 | Args: 85 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 86 | mocker (MockerFixture): The pytest-mock fixture for mocking. 87 | """ 88 | 89 | query = "What is machine learning?" 90 | document = "Machine learning is a subset of artificial intelligence." 91 | 92 | # Mock model output 93 | mock_logits = torch.tensor([[0.85]]) 94 | mock_output = mocker.MagicMock() 95 | mock_output.logits = mock_logits 96 | mock_reranker._model.return_value = mock_output 97 | 98 | result = mock_reranker(query, document) 99 | 100 | # Should return a list with one tuple 101 | assert isinstance(result, list) 102 | assert len(result) == 1 103 | assert isinstance(result[0], tuple) 104 | assert result[0][0] == document 105 | assert isinstance(result[0][1], float) 106 | 107 | 108 | @pytest.mark.unit 109 | def test_call_with_multiple_documents( 110 | mock_reranker: Reranker, mocker: MockerFixture 111 | ): 112 | """ 113 | Test that __call__ works correctly with multiple documents. 114 | Args: 115 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 116 | mocker (MockerFixture): The pytest-mock fixture for mocking. 117 | """ 118 | 119 | query = "What is Python?" 
120 | documents = [ 121 | "Python is a programming language.", 122 | "The python is a snake.", 123 | "Python was created by Guido van Rossum." 124 | ] 125 | 126 | # Mock model output with different scores 127 | mock_logits = torch.tensor([[0.9], [0.2], [0.75]]) 128 | mock_output = mocker.MagicMock() 129 | mock_output.logits = mock_logits 130 | mock_reranker._model.return_value = mock_output 131 | 132 | result = mock_reranker(query, documents) 133 | 134 | # Should return a list with three tuples 135 | assert isinstance(result, list) 136 | assert len(result) == 3 137 | 138 | # Check that all documents are present 139 | result_docs = [doc for doc, _ in result] 140 | assert set(result_docs) == set(documents) 141 | 142 | # Check that results are sorted by score (descending) 143 | scores = [score for _, score in result] 144 | assert scores == sorted(scores, reverse=True) 145 | 146 | 147 | @pytest.mark.unit 148 | def test_call_documents_sorted_by_score( 149 | mock_reranker: Reranker, mocker: MockerFixture 150 | ): 151 | """ 152 | Test that documents are correctly sorted by relevance score in descending order. 153 | Args: 154 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 155 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
156 | """ 157 | 158 | query = "climate change" 159 | documents = ["Doc A", "Doc B", "Doc C"] 160 | 161 | # Mock model output with specific scores 162 | mock_logits = torch.tensor([[0.3], [0.9], [0.6]]) 163 | mock_output = mocker.MagicMock() 164 | mock_output.logits = mock_logits 165 | mock_reranker._model.return_value = mock_output 166 | 167 | result = mock_reranker(query, documents) 168 | 169 | # Check ordering: Doc B (0.9), Doc C (0.6), Doc A (0.3) 170 | assert result[0][0] == "Doc B" 171 | assert result[1][0] == "Doc C" 172 | assert result[2][0] == "Doc A" 173 | 174 | assert result[0][1] > result[1][1] > result[2][1] 175 | 176 | 177 | @pytest.mark.unit 178 | def test_call_tokenizer_called_correctly( 179 | mock_reranker: Reranker, mocker: MockerFixture 180 | ): 181 | """ 182 | Test that the tokenizer is called with correct parameters. 183 | Args: 184 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 185 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
186 | """ 187 | 188 | query = "test query" 189 | documents = ["doc1", "doc2"] 190 | 191 | # Mock model output 192 | mock_logits = torch.tensor([[0.5], [0.7]]) 193 | mock_output = mocker.MagicMock() 194 | mock_output.logits = mock_logits 195 | mock_reranker._model.return_value = mock_output 196 | 197 | mock_reranker(query, documents) 198 | 199 | # Verify tokenizer was called with correct arguments 200 | mock_reranker._tokenizer.assert_called_once() 201 | call_args = mock_reranker._tokenizer.call_args 202 | 203 | # First argument should be list of queries 204 | assert call_args[0][0] == [query, query] 205 | # Second argument should be list of documents 206 | assert call_args[0][1] == documents 207 | # Check keyword arguments 208 | assert call_args[1]["return_tensors"] == "pt" 209 | assert call_args[1]["truncation"] is True 210 | assert call_args[1]["padding"] is True 211 | 212 | 213 | @pytest.mark.unit 214 | def test_call_model_called_with_tokenizer_output( 215 | mock_reranker: Reranker, mocker: MockerFixture 216 | ): 217 | """ 218 | Test that the model is called with the tokenizer"s output. 219 | Args: 220 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 221 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
222 | """ 223 | 224 | query = "test" 225 | documents = ["document"] 226 | 227 | # Set up tokenizer mock output 228 | tokenizer_output = { 229 | "input_ids": torch.tensor([[1, 2, 3]]), 230 | "attention_mask": torch.tensor([[1, 1, 1]]) 231 | } 232 | mock_reranker._tokenizer.return_value = tokenizer_output 233 | 234 | # Mock model output 235 | mock_logits = torch.tensor([[0.5]]) 236 | mock_output = mocker.MagicMock() 237 | mock_output.logits = mock_logits 238 | mock_reranker._model.return_value = mock_output 239 | 240 | mock_reranker(query, documents) 241 | 242 | # Verify model was called with tokenizer output 243 | mock_reranker._model.assert_called_once() 244 | call_kwargs = mock_reranker._model.call_args[1] 245 | assert "input_ids" in call_kwargs or len(mock_reranker._model.call_args[0]) > 0 246 | 247 | 248 | @pytest.mark.unit 249 | def test_call_with_empty_document_list( 250 | mock_reranker: Reranker, mocker: MockerFixture 251 | ): 252 | """ 253 | Test that __call__ handles an empty document list correctly. 254 | Args: 255 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 256 | mocker (MockerFixture): The pytest-mock fixture for mocking. 257 | """ 258 | 259 | query = "test query" 260 | documents = [] 261 | 262 | # Mock model output for empty list 263 | mock_logits = torch.tensor([]).reshape(0, 1) 264 | mock_output = mocker.MagicMock() 265 | mock_output.logits = mock_logits 266 | mock_reranker._model.return_value = mock_output 267 | 268 | result = mock_reranker(query, documents) 269 | 270 | assert isinstance(result, list) 271 | assert len(result) == 0 272 | 273 | 274 | @pytest.mark.unit 275 | def test_call_converts_single_string_to_list( 276 | mock_reranker: Reranker, mocker: MockerFixture 277 | ): 278 | """ 279 | Test that a single document string is converted to a list internally. 280 | Args: 281 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 282 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
283 | """ 284 | 285 | query = "query" 286 | document = "single document" 287 | 288 | # Mock model output 289 | mock_logits = torch.tensor([[0.8]]) 290 | mock_output = mocker.MagicMock() 291 | mock_output.logits = mock_logits 292 | mock_reranker._model.return_value = mock_output 293 | 294 | result = mock_reranker(query, document) 295 | 296 | # Result should be a list with one element 297 | assert isinstance(result, list) 298 | assert len(result) == 1 299 | assert result[0][0] == document 300 | 301 | 302 | @pytest.mark.unit 303 | def test_call_returns_tuples_with_correct_types( 304 | mock_reranker: Reranker, mocker: MockerFixture 305 | ): 306 | """ 307 | Test that the return value contains tuples of (str, float). 308 | Args: 309 | mock_reranker (Reranker): The Reranker instance with mocked dependencies. 310 | mocker (MockerFixture): The pytest-mock fixture for mocking. 311 | """ 312 | query = "test" 313 | documents = ["doc1", "doc2", "doc3"] 314 | 315 | # Mock model output 316 | mock_logits = torch.tensor([[0.5], [0.7], [0.3]]) 317 | mock_output = mocker.MagicMock() 318 | mock_output.logits = mock_logits 319 | mock_reranker._model.return_value = mock_output 320 | 321 | result = mock_reranker(query, documents) 322 | 323 | for item in result: 324 | assert isinstance(item, tuple) 325 | assert len(item) == 2 326 | assert isinstance(item[0], str) 327 | assert isinstance(item[1], float) -------------------------------------------------------------------------------- /tests/unit/base_model/test_base_model_load.py: -------------------------------------------------------------------------------- 1 | from synthex import Synthex 2 | import pytest 3 | from pytest_mock import MockerFixture 4 | import os 5 | 6 | 7 | class ConcreteBaseModel: 8 | """ 9 | Concrete implementation of BaseModel for testing purposes. 
10 | """ 11 | 12 | def __init__(self, synthex: Synthex): 13 | from artifex.models import BaseModel 14 | # Copy the load method to this class 15 | self.load = BaseModel.load.__get__(self, ConcreteBaseModel) 16 | self._load_model = lambda model_path: None # Mock implementation 17 | 18 | 19 | @pytest.fixture 20 | def mock_synthex(mocker: MockerFixture) -> Synthex: 21 | """ 22 | Fixture to create a mock Synthex instance. 23 | Args: 24 | mocker (MockerFixture): The pytest-mock fixture for mocking. 25 | Returns: 26 | Synthex: A mocked Synthex instance. 27 | """ 28 | 29 | return mocker.MagicMock(spec=Synthex) 30 | 31 | 32 | @pytest.fixture 33 | def concrete_model(mock_synthex: Synthex) -> ConcreteBaseModel: 34 | """ 35 | Fixture to create a concrete BaseModel instance for testing. 36 | Args: 37 | mock_synthex (Synthex): A mocked Synthex instance. 38 | Returns: 39 | ConcreteBaseModel: A concrete implementation of BaseModel. 40 | """ 41 | 42 | return ConcreteBaseModel(mock_synthex) 43 | 44 | 45 | @pytest.mark.unit 46 | def test_load_with_valid_path(concrete_model: ConcreteBaseModel, mocker: MockerFixture): 47 | """ 48 | Test that load() successfully loads a model from a valid path. 49 | Args: 50 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 51 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
52 | """ 53 | 54 | model_path = "/fake/model/path" 55 | 56 | # Mock os.path.exists to return True 57 | mocker.patch("os.path.exists", return_value=True) 58 | 59 | # Mock _load_model 60 | mock_load_model = mocker.patch.object(concrete_model, "_load_model") 61 | 62 | concrete_model.load(model_path) 63 | 64 | # Verify _load_model was called with the correct path 65 | mock_load_model.assert_called_once_with(model_path) 66 | 67 | 68 | @pytest.mark.unit 69 | def test_load_raises_error_when_path_does_not_exist( 70 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 71 | ): 72 | """ 73 | Test that load() raises OSError when the model path does not exist. 74 | Args: 75 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 76 | mocker (MockerFixture): The pytest-mock fixture for mocking. 77 | """ 78 | 79 | model_path = "/nonexistent/path" 80 | 81 | # Mock os.path.exists to return False for the directory 82 | mocker.patch("os.path.exists", return_value=False) 83 | 84 | with pytest.raises(OSError) as exc_info: 85 | concrete_model.load(model_path) 86 | 87 | assert f"The specified model path '{model_path}' does not exist" in str(exc_info.value) 88 | 89 | 90 | @pytest.mark.unit 91 | def test_load_raises_error_when_config_json_missing( 92 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 93 | ): 94 | """ 95 | Test that load() raises OSError when config.json is missing. 96 | Args: 97 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 98 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
99 | """ 100 | 101 | model_path = "/fake/model/path" 102 | 103 | # Mock os.path.exists - directory exists, but config.json doesn"t 104 | def exists_side_effect(path: str) -> bool: 105 | if path == model_path: 106 | return True 107 | if path == os.path.join(model_path, "config.json"): 108 | return False 109 | return True 110 | 111 | mocker.patch("os.path.exists", side_effect=exists_side_effect) 112 | 113 | with pytest.raises(OSError) as exc_info: 114 | concrete_model.load(model_path) 115 | 116 | assert "missing the required file 'config.json'" in str(exc_info.value) 117 | 118 | 119 | @pytest.mark.unit 120 | def test_load_raises_error_when_model_safetensors_missing( 121 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 122 | ): 123 | """ 124 | Test that load() raises OSError when model.safetensors is missing. 125 | Args: 126 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 127 | mocker (MockerFixture): The pytest-mock fixture for mocking. 128 | """ 129 | 130 | model_path = "/fake/model/path" 131 | 132 | # Mock os.path.exists - directory and config.json exist, but model.safetensors doesn"t 133 | def exists_side_effect(path: str) -> bool: 134 | if path == model_path: 135 | return True 136 | if path == os.path.join(model_path, "config.json"): 137 | return True 138 | if path == os.path.join(model_path, "model.safetensors"): 139 | return False 140 | return True 141 | 142 | mocker.patch("os.path.exists", side_effect=exists_side_effect) 143 | 144 | with pytest.raises(OSError) as exc_info: 145 | concrete_model.load(model_path) 146 | 147 | assert "missing the required file 'model.safetensors'" in str(exc_info.value) 148 | 149 | 150 | @pytest.mark.unit 151 | def test_load_checks_all_required_files( 152 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 153 | ): 154 | """ 155 | Test that load() checks for all required files (config.json and model.safetensors). 
156 | Args: 157 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 158 | mocker (MockerFixture): The pytest-mock fixture for mocking. 159 | """ 160 | 161 | model_path = "/fake/model/path" 162 | 163 | # Track all paths that os.path.exists is called with 164 | checked_paths: list[str] = [] 165 | 166 | def exists_side_effect(path: str) -> bool: 167 | checked_paths.append(path) 168 | return True 169 | 170 | mocker.patch("os.path.exists", side_effect=exists_side_effect) 171 | mocker.patch.object(concrete_model, "_load_model") 172 | 173 | concrete_model.load(model_path) 174 | 175 | # Verify that both required files were checked 176 | assert os.path.join(model_path, "config.json") in checked_paths 177 | assert os.path.join(model_path, "model.safetensors") in checked_paths 178 | 179 | 180 | @pytest.mark.unit 181 | def test_load_validation_failure_with_non_string_path(concrete_model: ConcreteBaseModel): 182 | """ 183 | Test that load() raises ValidationError when model_path is not a string. 184 | Args: 185 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 186 | """ 187 | 188 | from artifex.core import ValidationError 189 | 190 | with pytest.raises(ValidationError): 191 | concrete_model.load(123) 192 | 193 | 194 | @pytest.mark.unit 195 | def test_load_validation_failure_with_none_path(concrete_model: ConcreteBaseModel): 196 | """ 197 | Test that load() raises ValidationError when model_path is None. 198 | Args: 199 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 200 | """ 201 | 202 | from artifex.core import ValidationError 203 | 204 | with pytest.raises(ValidationError): 205 | concrete_model.load(None) 206 | 207 | 208 | @pytest.mark.unit 209 | def test_load_with_relative_path( 210 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 211 | ): 212 | """ 213 | Test that load() works with relative paths. 214 | Args: 215 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 
216 | mocker (MockerFixture): The pytest-mock fixture for mocking. 217 | """ 218 | 219 | model_path = "./models/my_model" 220 | 221 | mocker.patch("os.path.exists", return_value=True) 222 | mock_load_model = mocker.patch.object(concrete_model, "_load_model") 223 | 224 | concrete_model.load(model_path) 225 | 226 | mock_load_model.assert_called_once_with(model_path) 227 | 228 | 229 | @pytest.mark.unit 230 | def test_load_with_absolute_path( 231 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 232 | ): 233 | """ 234 | Test that load() works with absolute paths. 235 | Args: 236 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 237 | mocker (MockerFixture): The pytest-mock fixture for mocking. 238 | """ 239 | 240 | model_path = "/home/user/models/my_model" 241 | 242 | mocker.patch("os.path.exists", return_value=True) 243 | mock_load_model = mocker.patch.object(concrete_model, "_load_model") 244 | 245 | concrete_model.load(model_path) 246 | 247 | mock_load_model.assert_called_once_with(model_path) 248 | 249 | 250 | @pytest.mark.unit 251 | def test_load_with_path_containing_spaces( 252 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 253 | ): 254 | """ 255 | Test that load() works with paths containing spaces. 256 | Args: 257 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 258 | mocker (MockerFixture): The pytest-mock fixture for mocking. 259 | """ 260 | 261 | model_path = "/path/with spaces/my model" 262 | 263 | mocker.patch("os.path.exists", return_value=True) 264 | mock_load_model = mocker.patch.object(concrete_model, "_load_model") 265 | 266 | concrete_model.load(model_path) 267 | 268 | mock_load_model.assert_called_once_with(model_path) 269 | 270 | 271 | @pytest.mark.unit 272 | def test_load_error_message_contains_path( 273 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 274 | ): 275 | """ 276 | Test that OSError messages contain the specified path for debugging. 
277 | Args: 278 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 279 | mocker (MockerFixture): The pytest-mock fixture for mocking. 280 | """ 281 | 282 | model_path = "/custom/model/path" 283 | 284 | mocker.patch("os.path.exists", return_value=False) 285 | 286 | with pytest.raises(OSError) as exc_info: 287 | concrete_model.load(model_path) 288 | 289 | assert model_path in str(exc_info.value) 290 | 291 | 292 | @pytest.mark.unit 293 | def test_load_calls_load_model_only_after_validation( 294 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 295 | ): 296 | """ 297 | Test that _load_model is only called after all validations pass. 298 | Args: 299 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 300 | mocker (MockerFixture): The pytest-mock fixture for mocking. 301 | """ 302 | 303 | model_path = "/fake/model/path" 304 | 305 | # Make config.json missing 306 | def exists_side_effect(path: str) -> bool: 307 | if path == model_path: 308 | return True 309 | if path == os.path.join(model_path, "config.json"): 310 | return False 311 | return True 312 | 313 | mocker.patch("os.path.exists", side_effect=exists_side_effect) 314 | mock_load_model = mocker.patch.object(concrete_model, "_load_model") 315 | 316 | with pytest.raises(OSError): 317 | concrete_model.load(model_path) 318 | 319 | # _load_model should not be called if validation fails 320 | mock_load_model.assert_not_called() 321 | 322 | 323 | @pytest.mark.unit 324 | def test_load_with_empty_string_path( 325 | concrete_model: ConcreteBaseModel, mocker: MockerFixture 326 | ): 327 | """ 328 | Test that load() handles empty string path appropriately. 329 | Args: 330 | concrete_model (ConcreteBaseModel): The concrete BaseModel instance. 331 | mocker (MockerFixture): The pytest-mock fixture for mocking. 
332 | """ 333 | 334 | model_path = "" 335 | 336 | mocker.patch("os.path.exists", return_value=False) 337 | 338 | with pytest.raises(OSError) as exc_info: 339 | concrete_model.load(model_path) 340 | 341 | assert "does not exist" in str(exc_info.value) --------------------------------------------------------------------------------