├── tests ├── evaluators │ └── __init__.py ├── losses │ ├── __init__.py │ ├── utils.py │ ├── test_softnn_loss.py │ └── test_xbm_loss.py ├── matchers │ ├── __init__.py │ ├── test_match_nearest.py │ └── test_majority_vote.py ├── models │ ├── __init__.py │ └── test_similarity_model.py ├── samplers │ ├── __init__.py │ ├── test_tfdataset_samplers.py │ └── test_tfrecord_samplers.py ├── search │ ├── __init__.py │ └── test_nmslib.py ├── stores │ ├── __init__.py │ ├── test_redis.py │ ├── test_memory.py │ └── test_cached.py ├── architectures │ ├── __init__.py │ └── test_resnet18.py ├── visualization │ └── __init__.py ├── retrieval_metrics │ ├── __init__.py │ ├── test_recall_at_k.py │ ├── test_precision_at_k.py │ ├── test_map_at_k.py │ └── test_retrieval_metric.py ├── training_metrics │ └── __init__.py ├── classification_metrics │ └── __init__.py ├── data │ └── mnist_fashion_embeddings.npz ├── __init__.py ├── conftest.py ├── augmenters │ ├── test_augmentation_utils │ │ ├── test_solarize.py │ │ ├── test_crop.py │ │ ├── test_flip.py │ │ ├── test_blur.py │ │ └── test_color_jitter.py │ └── test_augmenters.py └── integration │ └── test_readme_minimal.py ├── benchmark └── supervised │ ├── requirements.txt │ ├── components │ ├── datasets │ │ ├── __init__.py │ │ └── types.py │ ├── __init__.py │ ├── augmentations.py │ ├── architectures.py │ ├── optimizers.py │ └── experiments.py │ ├── README.md │ └── create_datasets.py ├── tensorflow_similarity ├── augmenters │ ├── augmentation_utils │ │ ├── __init__.py │ │ ├── random_apply.py │ │ ├── solarize.py │ │ └── flip.py │ ├── __init__.py │ ├── augmenter.py │ └── contrastive.py ├── __init__.py ├── architectures │ ├── __init__.py │ └── utils.py ├── models │ └── __init__.py ├── matchers │ ├── __init__.py │ ├── utils.py │ ├── match_nearest.py │ └── match_majority_vote.py ├── retrieval_metrics │ ├── __init__.py │ └── utils.py ├── classification_metrics │ ├── __init__.py │ ├── utils.py │ ├── recall.py │ ├── negative_predictive_value.py │ └── 
false_positive_rate.py ├── losses │ ├── __init__.py │ └── metric_loss.py ├── training_metrics │ ├── __init__.py │ └── utils.py ├── visualization │ ├── __init__.py │ └── vizualize_views.py ├── samplers │ ├── __init__.py │ └── file_samplers.py ├── evaluators │ └── __init__.py └── distances │ ├── cosine.py │ ├── manhattan.py │ ├── inner_product.py │ ├── snr.py │ └── distance.py ├── .flake8 ├── pyproject.toml ├── api └── TFSimilarity │ ├── all_symbols.md │ ├── api_report.pb │ ├── distances │ ├── ABC.md │ ├── distance_canonicalizer.md │ ├── Distance.md │ ├── SNRDistance.md │ └── SquaredEuclideanDistance.md │ ├── indexer │ ├── PandasDataFrame.md │ ├── time.md │ ├── CalibrationResults.md │ ├── make_classification_matcher.md │ └── Lookup.md │ ├── training_metrics │ ├── avg_neg.md │ ├── avg_pos.md │ ├── max_neg.md │ ├── max_pos.md │ ├── min_neg.md │ ├── min_pos.md │ ├── sum_neg.md │ ├── sum_pos.md │ ├── dist_gap.md │ └── batch_class_ratio.md │ ├── utils │ ├── tf_cap_memory.md │ └── is_tensor_or_variable.md │ ├── callbacks │ ├── unpack_lookup_labels.md │ ├── unpack_lookup_distances.md │ ├── unpack_results.md │ ├── make_classification_metric.md │ └── SplitValidationLoss.md │ ├── layers │ ├── ActivationStdLoggingLayer.md │ └── MetricEmbedding.md │ ├── models.md │ ├── architectures.md │ ├── augmenters.md │ ├── visualization.md │ ├── matchers.md │ ├── retrieval_metrics.md │ ├── stores.md │ ├── layers.md │ ├── utils.md │ ├── losses.md │ ├── training_metrics.md │ ├── search.md │ ├── classification_metrics.md │ ├── evaluators.md │ ├── distances.md │ ├── visualization │ ├── viz_neigbors_imgs.md │ ├── visualize_views.md │ ├── confusion_matrix.md │ └── projector.md │ ├── augmenters │ ├── Augmenter.md │ ├── ContrastiveAugmenter.md │ ├── BarlowAugmenter.md │ └── SimCLRAugmenter.md │ ├── samplers.md │ ├── samplers │ └── select_examples.md │ ├── callbacks.md │ └── architectures │ └── ResNet18Sim.md ├── mypy.ini ├── .gitattributes ├── assets └── images │ ├── tfsim_overview.png │ ├── 
contrastive_batch.png │ ├── similar-cats-and-dogs.jpg │ ├── contrastive_model_terms.png │ ├── self_supervised_overview.png │ ├── contrastive_loss_functions.png │ └── similarity-learning-overview.png ├── pytest.ini ├── .git-blame-ignore-revs ├── scripts ├── README.md └── increment_version.py ├── .pre-commit-config.yaml ├── .coveragerc ├── examples └── README.md ├── .github └── workflows │ ├── python-publish.yml │ ├── nightly-publish.yml │ └── test.yml ├── .gitignore └── CONTRIBUTING.md /tests/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/matchers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/search/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/stores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/retrieval_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/training_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/classification_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/supervised/requirements.txt: -------------------------------------------------------------------------------- 1 | keras-tuner 2 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/augmentation_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/supervised/components/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import utils # noqa 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203 3 | exclude = tmp.py, tests/ 4 | max-line-length = 120 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [tool.black] 5 | line-length = 119 6 | -------------------------------------------------------------------------------- /api/TFSimilarity/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow Similarity 2 | 3 | 4 | ## Primary symbols -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | warn_return_any = True 4 | warn_unused_configs = True -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.{cmd,[cC][mM][dD]} text eol=crlf 3 | *.{bat,[bB][aA][tT]} text eol=crlf -------------------------------------------------------------------------------- /api/TFSimilarity/api_report.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/api/TFSimilarity/api_report.pb -------------------------------------------------------------------------------- /assets/images/tfsim_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/tfsim_overview.png 
-------------------------------------------------------------------------------- /assets/images/contrastive_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_batch.png -------------------------------------------------------------------------------- /assets/images/similar-cats-and-dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/similar-cats-and-dogs.jpg -------------------------------------------------------------------------------- /tests/data/mnist_fashion_embeddings.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/tests/data/mnist_fashion_embeddings.npz -------------------------------------------------------------------------------- /assets/images/contrastive_model_terms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_model_terms.png -------------------------------------------------------------------------------- /assets/images/self_supervised_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/self_supervised_overview.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | CURRRENT_DIR = Path(__file__).parent.absolute() 4 | DATA_DIR = CURRRENT_DIR / "data/" 5 | -------------------------------------------------------------------------------- /assets/images/contrastive_loss_functions.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_loss_functions.png -------------------------------------------------------------------------------- /assets/images/similarity-learning-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/similarity-learning-overview.png -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli = 1 3 | log_level = INFO 4 | filterwarnings = 5 | ignore::DeprecationWarning 6 | ignore::PendingDeprecationWarning -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Since git version 2.23, git-blame has a feature to ignore any commit in this file. 
2 | 3 | # Migrate code to black + isort 4 | 2821d07e7248c412384734b1592b94140754897d 5 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from termcolor import cprint 3 | 4 | 5 | def pytest_configure(config): 6 | tf.config.set_visible_devices([], "GPU") 7 | cprint("Tensorflow set to CPU", "green") 8 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Utility scripts 2 | 3 | 4 | ## Generate API documentation 5 | 6 | From the `scripts/` directory: 7 | 8 | ``` 9 | pip install git+https://github.com/tensorflow/docs 10 | python gen_api_doc.py 11 | ``` -------------------------------------------------------------------------------- /api/TFSimilarity/distances/ABC.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.distances.ABC 2 | 3 | 4 | 5 | 6 | 7 | 8 | Helper class that provides a standard way to create an ABC using inheritance. 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /api/TFSimilarity/indexer/PandasDataFrame.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.indexer.PandasDataFrame 2 | 3 | 4 | 5 | 6 | 7 | Symbolic pandas frame 8 | 9 | Pandas types are too loose; you get an Any.
We want a DataFrame 10 | 11 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/avg_neg.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.avg_neg 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.avg_neg( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/avg_pos.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.avg_pos 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.avg_pos( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/max_neg.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.max_neg 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.max_neg( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/max_pos.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.max_pos 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.max_pos( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/min_neg.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.min_neg 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.min_neg( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | 
-------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/min_pos.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.min_pos 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.min_pos( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/sum_neg.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.sum_neg 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.sum_neg( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/sum_pos.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.sum_pos 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.sum_pos( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/dist_gap.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.dist_gap 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ```python 10 | TFSimilarity.training_metrics.dist_gap( 11 | distance 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/utils/tf_cap_memory.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.utils.tf_cap_memory 2 | 3 | 4 | 5 | 6 | 7 | Prevent TF from hogging memory before it is needed. 
8 | 9 | ```python 10 | TFSimilarity.utils.tf_cap_memory() 11 | ``` 12 | 13 | 14 | 15 | 16 |
-------------------------------------------------------------------------------- /api/TFSimilarity/utils/is_tensor_or_variable.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.utils.is_tensor_or_variable 2 | 3 | 4 | 5 | 6 | 7 | Check if a variable is a tf.Tensor or tf.Variable. 8 | 9 | ```python 10 | TFSimilarity.utils.is_tensor_or_variable( 11 | x 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.6.0 4 | hooks: 5 | - id: black 6 | stages: ['commit'] 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | name: isort (python) 13 | stages: ['commit'] 14 | -------------------------------------------------------------------------------- /api/TFSimilarity/callbacks/unpack_lookup_labels.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.callbacks.unpack_lookup_labels 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ```python 11 | TFSimilarity.callbacks.unpack_lookup_labels( 12 | lookups: Sequence[Sequence[Lookup]], dtype: Union[str, tf.DType]) 13 | ``` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /benchmark/supervised/components/__init__.py: -------------------------------------------------------------------------------- 1 | from . import metrics # noqa 2 | from .
import utils # noqa 3 | from .architectures import make_architecture # noqa 4 | from .augmentations import make_augmentations # noqa 5 | from .experiments import make_experiments # noqa 6 | from .losses import make_loss # noqa 7 | from .optimizers import make_optimizer # noqa 8 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = tensorflow_similarity 3 | 4 | omit = 5 | # omit anything in an api directory anywhere 6 | */api/* 7 | # omit all __init__.py files 8 | */__init__.py 9 | 10 | [report] 11 | omit = 12 | # omit anything in an api directory anywhere 13 | */api/* 14 | # omit all __init__.py files 15 | */__init__.py 16 | -------------------------------------------------------------------------------- /api/TFSimilarity/indexer/time.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.indexer.time 2 | 3 | 4 | 5 | 6 | 7 | 8 | time() -> floating point number 9 | 10 | ```python 11 | TFSimilarity.indexer.time() 12 | ``` 13 | 14 | 15 | 16 | 17 | 18 | Return the current time in seconds since the Epoch. 19 | Fractions of a second may be present if the system clock provides them. 
-------------------------------------------------------------------------------- /api/TFSimilarity/callbacks/unpack_lookup_distances.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.callbacks.unpack_lookup_distances 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ```python 11 | TFSimilarity.callbacks.unpack_lookup_distances( 12 | lookups: Sequence[Sequence[Lookup]], 13 | dtype: Union[str, tf.DType], 14 | distance_rounding: Optional[int] = None 15 | ``` 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /api/TFSimilarity/layers/ActivationStdLoggingLayer.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.layers.ActivationStdLoggingLayer 2 | 3 | 4 | 5 | 6 | 7 | Computes the mean std of the activations of a layer. 8 | 9 | ```python 10 | TFSimilarity.layers.ActivationStdLoggingLayer( 11 | name, **kwargs 12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | 19 | x = reduce_std(l2_normalize(inputs, axis=0), axis=-1) 20 | 21 | And then aggregate the per-batch mean of x over each epoch. 22 | 23 | -------------------------------------------------------------------------------- /benchmark/supervised/README.md: -------------------------------------------------------------------------------- 1 | # Similarity Metric Benchmarking System 2 | 3 | ## Current results 4 | 5 | ## Running the benchmark 6 | 7 | 1. Generate initial training data with preprocessing: `python create_datasets.py -c config/benchmark_prod.json -f "cars.*lamb.*"` 8 | 9 | 2. Train and hypertune the model parameters: `python hyper_parameter_search.py -c config/benchmark_prod.json -f "cars.*lamb.*"` 10 | 11 | 3. Evaluate performance: 12 | 13 | 4.
Analyze results using notebook: 14 | -------------------------------------------------------------------------------- /api/TFSimilarity/callbacks/unpack_results.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.callbacks.unpack_results 2 | 3 | 4 | 5 | 6 | 7 | Updates logs, writes summary, and returns list of strings of 8 | 9 | 10 | ```python 11 | TFSimilarity.callbacks.unpack_results( 12 | results: Dict[str, np.ndarray], 13 | epoch: int, 14 | logs: dict, 15 | tb_writer: tf.summary.SummaryWriter, 16 | name_suffix: Optional[str] = 17 | ) -> List[str] 18 | ``` 19 | 20 | 21 | 22 | 23 | evaluation metric -------------------------------------------------------------------------------- /api/TFSimilarity/models.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.models 2 | 3 | 4 | 5 | 6 | 7 | Specialized `keras.model.Model` subclasses that offers additional 8 | 9 | functionalities to make training and serving similarity models quick and easy. 10 | 11 | ## Classes 12 | 13 | - [`class ContrastiveModel`](../TFSimilarity/models/ContrastiveModel.md): `Model` groups layers into an object with training and inference features. 
14 | 15 | - [`class SimilarityModel`](../TFSimilarity/callbacks/SimilarityModel.md): Specialized Keras.Model which implements the core features needed for 16 | 17 | -------------------------------------------------------------------------------- /api/TFSimilarity/architectures.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.architectures 2 | 3 | 4 | 5 | 6 | 7 | Deep-learning model architectures tweaked to work well for similarity 8 | 9 | training 10 | 11 | ## Functions 12 | 13 | - [`EfficientNetSim(...)`](../TFSimilarity/architectures/EfficientNetSim.md): Build an EfficientNet Model backbone for similarity learning 14 | 15 | - [`ResNet18Sim(...)`](../TFSimilarity/architectures/ResNet18Sim.md): Build a ResNet18 Model backbone for similarity learning 16 | 17 | - [`ResNet50Sim(...)`](../TFSimilarity/architectures/ResNet50Sim.md): Build a ResNet50 Model backbone for similarity learning 18 | 19 | -------------------------------------------------------------------------------- /tests/matchers/test_match_nearest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.matchers import MatchNearest 5 | 6 | 7 | def test_predict(): 8 | mn = MatchNearest() 9 | 10 | lookup_labels = tf.constant([[10, 12], [20, 13], [30, 14], [40, 15]]) 11 | lookup_distances = tf.constant([[1.0, 1.5], [1.0, 1.7], [2.0, 2.1], [2.0, 2.2]]) 12 | 13 | d_labels, d_dist = mn.derive_match(lookup_labels, lookup_distances) 14 | 15 | np.testing.assert_array_equal(d_labels.numpy(), np.array([[10], [20], [30], [40]])) 16 | np.testing.assert_array_equal(d_dist.numpy(), np.array([[1.0], [1.0], [2.0], [2.0]])) 17 | -------------------------------------------------------------------------------- /api/TFSimilarity/augmenters.md: -------------------------------------------------------------------------------- 1 | # Module:
TFSimilarity.augmenters 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class Augmenter`](../TFSimilarity/augmenters/Augmenter.md): Helper class that provides a standard way to create an ABC using 14 | 15 | - [`class BarlowAugmenter`](../TFSimilarity/augmenters/BarlowAugmenter.md): Helper class that provides a standard way to create an ABC using 16 | 17 | - [`class ContrastiveAugmenter`](../TFSimilarity/augmenters/ContrastiveAugmenter.md): Helper class that provides a standard way to create an ABC using 18 | 19 | - [`class SimCLRAugmenter`](../TFSimilarity/augmenters/SimCLRAugmenter.md): SimCLR augmentation pipeline as defined in 20 | 21 | -------------------------------------------------------------------------------- /api/TFSimilarity/layers/MetricEmbedding.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.layers.MetricEmbedding 2 | 3 | 4 | 5 | 6 | 7 | L2 Normalized `Dense` layer. 8 | 9 | ```python 10 | TFSimilarity.layers.MetricEmbedding( 11 | units, 12 | activation=None, 13 | use_bias=True, 14 | kernel_initializer=glorot_uniform, 15 | bias_initializer=zeros, 16 | kernel_regularizer=None, 17 | bias_regularizer=None, 18 | activity_regularizer=None, 19 | kernel_constraint=None, 20 | bias_constraint=None, 21 | **kwargs 22 | ) 23 | ``` 24 | 25 | 26 | 27 | 28 | 29 | This layer is usually used as output layer, especially when using cosine 30 | distance as the similarity metric. 
31 | 32 | -------------------------------------------------------------------------------- /api/TFSimilarity/visualization.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.visualization 2 | 3 | 4 | 5 | 6 | 7 | Collection of specialized notebook visualization tools 8 | 9 | 10 | 11 | ## Functions 12 | 13 | - [`confusion_matrix(...)`](../TFSimilarity/visualization/confusion_matrix.md): Plot confusion matrix 14 | 15 | - [`projector(...)`](../TFSimilarity/visualization/projector.md): Visualize the embeddings in 2D or 3D using UMAP projection 16 | 17 | - [`visualize_views(...)`](../TFSimilarity/visualization/visualize_views.md): Display different image views side by side, with labels and predictions 18 | 19 | - [`viz_neigbors_imgs(...)`](../TFSimilarity/visualization/viz_neigbors_imgs.md): Display the nearest-neighbor images 20 | 21 | -------------------------------------------------------------------------------- /tests/matchers/test_majority_vote.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.matchers import MatchMajorityVote 5 | 6 | 7 | def test_predict(): 8 | mn = MatchMajorityVote() 9 | 10 | lookup_labels = tf.constant( 11 | [ 12 | [10, 12, 10, 12, 10], 13 | # Ties should take the closer label.
14 | [20, 13, 13, 20, 30], 15 | ] 16 | ) 17 | lookup_distances = tf.constant([[1.0, 1.1, 1.2, 1.3, 1.4], [2.0, 2.1, 2.2, 2.3, 2.4]]) 18 | 19 | d_labels, d_dist = mn.derive_match(lookup_labels, lookup_distances) 20 | 21 | np.testing.assert_array_equal(d_labels.numpy(), np.array([[10], [20]])) 22 | np.testing.assert_allclose(d_dist.numpy(), np.array([[1.2], [2.2]])) 23 | -------------------------------------------------------------------------------- /tensorflow_similarity/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | __version__ = "0.18.0.dev12" 15 | 16 | 17 | from . import models # noqa 18 | from . import samplers # noqa 19 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .augmenter import Augmenter # noqa 15 | from .barlow import BarlowAugmenter # noqa 16 | from .contrastive import ContrastiveAugmenter # noqa 17 | from .simclr import SimCLRAugmenter # noqa 18 | -------------------------------------------------------------------------------- /api/TFSimilarity/matchers.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.matchers 2 | 3 | 4 | 5 | 6 | 7 | Matchers define the classification matching strategy when using 8 | 9 | similarity models to match query examples to the classes of indexed 10 | examples. 11 | 12 | ## Classes 13 | 14 | - [`class ClassificationMatch`](../TFSimilarity/callbacks/ClassificationMatch.md): Abstract base class for defining the classification matching strategy. 15 | 16 | - [`class MatchMajorityVote`](../TFSimilarity/matchers/MatchMajorityVote.md): Match metrics for the most common label in a result set. 17 | 18 | - [`class MatchNearest`](../TFSimilarity/matchers/MatchNearest.md): Match metrics for labels at k=1. 19 | 20 | ## Functions 21 | 22 | - [`make_classification_matcher(...)`](../TFSimilarity/indexer/make_classification_matcher.md): Convert classification matcher from str name to object if needed. 
23 | 24 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics/batch_class_ratio.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.training_metrics.batch_class_ratio 2 | 3 | 4 | 5 | 6 | 7 | Computes the average number of examples per class within each batch. 8 | 9 | ```python 10 | TFSimilarity.training_metrics.batch_class_ratio( 11 | sampler: Sampler, num_batches: int = 100 12 | ) -> float 13 | ``` 14 | 15 | 16 | 17 | 18 | Similarity learning requires at least 2 examples per class in each batch. 19 | This is needed in order to construct the triplets. This function 20 | provides the average number of examples per class within each batch and 21 | can be used to check that a sampler is working correctly. 22 | The ratio should be >= 2. 23 | Args: 24 | sampler: A tf.similarity sampler object. 25 | num_batches: The number of batches to sample. 26 | Returns: 27 | The average number of examples per class. -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ### TensorFlow Similarity examples 2 | 3 | 4 | | Tutorial | Learning Type | Description | 5 | | ------ | :-----: | ---------- | 6 | | [Hello World](./supervised_hello_world.ipynb) | Supervised | Train and use an image similarity model to find similar looking MNIST digits | 7 | | [Self-Supervised Learning](./unsupervised_hello_world.ipynb) | Unsupervised | Train an image model using the SimSiam based self-supervised contrastive learning. | 8 | | [visualization](./visualization.ipynb) | Supervised | Train an image similarity model on the Stanford Dogs dataset using Evaluation Callbacks and the interactive visualizer | 9 | | [Sampler IO Cookbook](./sampler_io_cookbook.ipynb) | Utils | Examples demonstrating how to use the various in memory batch samplers. 
| 10 | | [CLIP finetuning](./multimodal_example.ipynb) | Supervised | Finetune CLIP on atric-dataset using multiple negatives ranking loss. 11 | -------------------------------------------------------------------------------- /tests/samplers/test_tfdataset_samplers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler 4 | 5 | 6 | class DatasetSamplersTest(tf.test.TestCase): 7 | def test_basic(self): 8 | dataset_name = "mnist" 9 | sampler = TFDatasetMultiShotMemorySampler(dataset_name=dataset_name, classes_per_batch=10) 10 | batch = sampler.generate_batch(42) 11 | self.assertEqual(batch[0].shape, (20, 28, 28, 1)) 12 | 13 | def test_wrong_key(self): 14 | dataset_name = "mnist" 15 | 16 | # X 17 | with self.assertRaises(ValueError): 18 | TFDatasetMultiShotMemorySampler(dataset_name=dataset_name, classes_per_batch=4, x_key="error") 19 | # Y 20 | with self.assertRaises(ValueError): 21 | TFDatasetMultiShotMemorySampler(dataset_name=dataset_name, classes_per_batch=4, y_key="error") 22 | -------------------------------------------------------------------------------- /tensorflow_similarity/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Deep-learning model architectures tweaked to work well for similarity 15 | training""" 16 | from .efficientnet import EfficientNetSim # noqa 17 | from .resnet18 import ResNet18Sim # noqa 18 | from .resnet50 import ResNet50Sim # noqa 19 | -------------------------------------------------------------------------------- /tests/retrieval_metrics/test_recall_at_k.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.retrieval_metrics import RecallAtK 5 | 6 | testdata = [ 7 | ( 8 | "micro", 9 | tf.constant(0.75), 10 | ), 11 | ( 12 | "macro", 13 | tf.constant(0.5), 14 | ), 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("avg, expected", testdata, ids=["micro", "macro"]) 19 | def test_compute(avg, expected): 20 | query_labels = tf.constant([1, 1, 1, 0]) 21 | match_mask = tf.constant( 22 | [ 23 | [True, True, False], 24 | [True, True, False], 25 | [True, True, False], 26 | [False, False, True], 27 | ], 28 | dtype=bool, 29 | ) 30 | rm = RecallAtK(k=2, average=avg) 31 | 32 | recall = rm.compute(query_labels=query_labels, match_mask=match_mask) 33 | assert recall == expected 34 | -------------------------------------------------------------------------------- /tests/augmenters/test_augmentation_utils/test_solarize.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters.augmentation_utils import solarize 5 | 6 | 7 | def create_img(width=32, height=32, channels=3): 8 | return tf.random.uniform([width, height, channels], 0, 1) 9 | 10 | 11 | def test_solarization(): 12 | img = create_img() 13 | 14 | random_solarize_always = solarize.random_solarize(img, p=1, thresh=0, pixel_min=0, pixel_max=1) 15 | random_solarize_never = solarize.random_solarize(img, p=0, thresh=0, pixel_min=0, pixel_max=1) 16 | 17 | # check shapes 18 | assert 
(tf.shape(random_solarize_always) == tf.shape(img)).numpy().all() 19 | assert (tf.shape(random_solarize_never) == tf.shape(img)).numpy().all() 20 | 21 | # check if solarize works 22 | assert not (random_solarize_always == img).numpy().all() 23 | assert (random_solarize_never == img).numpy().all() 24 | -------------------------------------------------------------------------------- /api/TFSimilarity/retrieval_metrics.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.retrieval_metrics 2 | 3 | 4 | 5 | 6 | 7 | Retrieval metrics measure the quality of the embedding space given a 8 | 9 | set of query examples and a set of indexed examples. Informally, it can be thought 10 | of as how well the space is clustered, among other things. 11 | 12 | ## Classes 13 | 14 | - [`class BNDCG`](../TFSimilarity/retrieval_metrics/BNDCG.md): Binary normalized discounted cumulative gain. 15 | 16 | - [`class MapAtK`](../TFSimilarity/retrieval_metrics/MapAtK.md): Mean Average precision - mAP@K is computed as. 17 | 18 | - [`class PrecisionAtK`](../TFSimilarity/retrieval_metrics/PrecisionAtK.md): Precision@K is computed as. 19 | 20 | - [`class RecallAtK`](../TFSimilarity/retrieval_metrics/RecallAtK.md): The metric learning version of Recall@K. 21 | 22 | - [`class RetrievalMetric`](../TFSimilarity/indexer/RetrievalMetric.md): Abstract base class for computing retrieval metrics.
23 | 24 | -------------------------------------------------------------------------------- /tests/retrieval_metrics/test_precision_at_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tensorflow_similarity.retrieval_metrics import PrecisionAtK 6 | 7 | testdata = [ 8 | ( 9 | "micro", 10 | tf.constant(0.583333333), 11 | ), 12 | ( 13 | "macro", 14 | tf.constant(0.5), 15 | ), 16 | ] 17 | 18 | 19 | @pytest.mark.parametrize("avg, expected", testdata, ids=["micro", "macro"]) 20 | def test_compute(avg, expected): 21 | query_labels = tf.constant([1, 1, 1, 0]) 22 | match_mask = tf.constant( 23 | [ 24 | [True, True, False], 25 | [True, True, False], 26 | [True, True, False], 27 | [False, False, True], 28 | ], 29 | dtype=bool, 30 | ) 31 | rm = PrecisionAtK(k=3, average=avg) 32 | 33 | precision = rm.compute(query_labels=query_labels, match_mask=match_mask) 34 | np.testing.assert_allclose(precision, expected, atol=1e-05) 35 | -------------------------------------------------------------------------------- /tests/augmenters/test_augmenters.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters import BarlowAugmenter, SimCLRAugmenter 5 | 6 | 7 | def create_imgs(width=32, height=32, channels=3, num=5): 8 | return tf.random.uniform([num, width, height, channels], 0, 1) 9 | 10 | 11 | def test_barlow(): 12 | imgs = create_imgs() 13 | WIDTH = 32 14 | HEIGHT = 32 15 | CHANNELS = 3 16 | NUM = 5 17 | 18 | aug = BarlowAugmenter(WIDTH, HEIGHT) 19 | augmented = aug.augment(imgs) 20 | 21 | assert (tf.shape(augmented) == tf.constant([2, NUM, WIDTH, HEIGHT, CHANNELS])).numpy().all() 22 | 23 | 24 | def test_simclr(): 25 | imgs = create_imgs() 26 | WIDTH = 32 27 | HEIGHT = 32 28 | CHANNELS = 3 29 | NUM = 5 30 | 31 | aug = SimCLRAugmenter(HEIGHT, WIDTH) 32 | 
augmented = aug.augment(imgs, tf.constant([0]), 2, True) 33 | 34 | assert (tf.shape(augmented) == tf.constant([2, NUM, WIDTH, HEIGHT, CHANNELS])).numpy().all() 35 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: # Allow manual triggers 5 | 6 | jobs: 7 | publish: 8 | # Prevent Publish from running on forks. 9 | if: | 10 | github.ref == 'refs/heads/master' && 11 | github.repository == 'tensorflow/similarity' 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.9' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | 26 | - name: Install package 27 | run: | 28 | pip install ".[tensorflow,dev]" 29 | 30 | - name: Build package 31 | run: | 32 | python setup.py sdist bdist_wheel 33 | 34 | - name: Publish package 35 | run: | 36 | twine upload -u ${{ secrets.PYPI_USERNAME }} -p ${{ secrets.PYPY_TOKEN }} dist/* --verbose 37 | -------------------------------------------------------------------------------- /api/TFSimilarity/stores.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.stores 2 | 3 | 4 | 5 | 6 | 7 | Key-value stores store the data associated with the embeddings indexed by 8 | 9 | the `Indexer()`. 10 | 11 | Each key of the store represents a **record** that contains information 12 | about a given embedding. 13 | 14 | The main use-case for the store is to retrieve the records associated 15 | with the ids returned by a nearest-neighbor search performed with the 16 | [`Search()`](../search/).
17 | 18 | Additionally one might want to inspect the content of the index which is why 19 | `Store()` class may implement an export to 20 | a [Pandas Dataframe](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) 21 | via the `to_pandas()` method. 22 | 23 | ## Classes 24 | 25 | - [`class MemoryStore`](../TFSimilarity/indexer/MemoryStore.md): Efficient in-memory dataset store 26 | 27 | - [`class Store`](../TFSimilarity/indexer/Store.md): Helper class that provides a standard way to create an ABC using 28 | 29 | -------------------------------------------------------------------------------- /tensorflow_similarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Specialized `keras.model.Model` subclasses that offers additional 15 | functionalities to make training and serving similarity models quick and easy. 
16 | """ 17 | 18 | from .contrastive_model import ContrastiveModel # noqa 19 | from .contrastive_model import create_contrastive_model # noqa 20 | from .similarity_model import SimilarityModel # noqa 21 | -------------------------------------------------------------------------------- /api/TFSimilarity/distances/distance_canonicalizer.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.distances.distance_canonicalizer 2 | 3 | 4 | 5 | 6 | 7 | Normalize user requested distance to its matching Distance object. 8 | 9 | 10 | ```python 11 | TFSimilarity.distances.distance_canonicalizer( 12 | ``` 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 31 | 32 |
26 | user_distance 27 | 29 | Requested distance either by name or by object 30 |
33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 45 | 48 | 49 |
43 | Distance 44 | 46 | Requested object name. 47 |
50 | 51 | -------------------------------------------------------------------------------- /api/TFSimilarity/layers.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.layers 2 | 3 | 4 | 5 | 6 | 7 | Specialized Similarity `keras.layers` 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class ActivationStdLoggingLayer`](../TFSimilarity/layers/ActivationStdLoggingLayer.md): Computes the mean std of the activations of a layer. 14 | 15 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor 16 | 17 | - [`class GeneralizedMeanPooling`](../TFSimilarity/layers/GeneralizedMeanPooling.md): This is the class from which all layers inherit. 18 | 19 | - [`class GeneralizedMeanPooling1D`](../TFSimilarity/layers/GeneralizedMeanPooling1D.md): Computes the Generalized Mean of each channel in a tensor. 20 | 21 | - [`class GeneralizedMeanPooling2D`](../TFSimilarity/layers/GeneralizedMeanPooling2D.md): Computes the Generalized Mean of each channel in a tensor. 22 | 23 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor 24 | 25 | - [`class MetricEmbedding`](../TFSimilarity/layers/MetricEmbedding.md): L2 Normalized `Dense` layer. 26 | 27 | -------------------------------------------------------------------------------- /tensorflow_similarity/matchers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Matchers define the classification matching strategy when using 15 | similarity models to match query examples to the classes of indexed 16 | examples.""" 17 | # Match Types 18 | from .classification_match import ClassificationMatch # noqa 19 | from .match_majority_vote import MatchMajorityVote # noqa 20 | from .match_nearest import MatchNearest # noqa 21 | from .utils import make_classification_matcher # noqa 22 | -------------------------------------------------------------------------------- /tests/losses/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def generate_perfect_test_batch(batch_size=32): 5 | """Generate a batch of embeddings and labels that will result in a perfect loss score.""" 6 | 7 | # y_true: labels 8 | y_true = tf.range(0, batch_size, dtype=tf.int32) 9 | y_true = tf.concat([y_true, y_true], axis=0) 10 | 11 | # y_preds: embedding 12 | y_preds = tf.one_hot(y_true, depth=batch_size, dtype=tf.float32) 13 | 14 | y_true = tf.expand_dims(y_true, axis=1) 15 | return y_true, y_preds 16 | 17 | 18 | def generate_bad_test_batch(batch_size=32): 19 | """Generate a batch of embeddings and labels that will result in a mismatch for all classes.""" 20 | 21 | # y_true: labels 22 | y_true = tf.range(0, batch_size, dtype=tf.int32) 23 | 24 | # y_preds: embedding 25 | y_preds = tf.concat([y_true, y_true[::-1]], axis=0) 26 | y_preds = tf.one_hot(y_preds, depth=batch_size, dtype=tf.float32) 27 | 28 | y_true = tf.concat([y_true, y_true], axis=0) 29 | y_true = tf.expand_dims(y_true, axis=1) 30 | return y_true, y_preds 31 | -------------------------------------------------------------------------------- /tensorflow_similarity/retrieval_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow 
Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Retrieval metrics measure the quality of the embedding space given a 15 | set query examples and a set of indexed examples. Informally it can be thought 16 | of as how well the space is clustered among other things. 17 | """ 18 | from .bndcg import BNDCG # noqa 19 | from .map_at_k import MapAtK # noqa 20 | from .precision_at_k import PrecisionAtK # noqa 21 | from .recall_at_k import RecallAtK # noqa 22 | from .retrieval_metric import RetrievalMetric # noqa 23 | -------------------------------------------------------------------------------- /api/TFSimilarity/utils.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.utils 2 | 3 | 4 | 5 | 6 | 7 | Various utilities functions for improved quality of life. 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class BoolTensor`](../TFSimilarity/utils/BoolTensor.md): Bool tensor 14 | 15 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor 16 | 17 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor 18 | 19 | - [`class Lookup`](../TFSimilarity/indexer/Lookup.md): Metadata associated with a query match. 
20 | 21 | ## Functions 22 | 23 | - [`is_tensor_or_variable(...)`](../TFSimilarity/utils/is_tensor_or_variable.md): Check if a variable is a tf.Tensor or tf.Variable 24 | 25 | - [`tf_cap_memory(...)`](../TFSimilarity/utils/tf_cap_memory.md): Prevent TF from hogging memory before it is needed 26 | 27 | - [`unpack_lookup_distances(...)`](../TFSimilarity/callbacks/unpack_lookup_distances.md) 28 | 29 | - [`unpack_lookup_labels(...)`](../TFSimilarity/callbacks/unpack_lookup_labels.md) 30 | 31 | - [`unpack_results(...)`](../TFSimilarity/callbacks/unpack_results.md): Updates logs, writes summary, and returns list of strings of 32 | 33 | -------------------------------------------------------------------------------- /tensorflow_similarity/classification_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Classification metrics measure matching classification quality between a 15 | set of query examples and a set of indexed examples.
""" 16 | from .binary_accuracy import BinaryAccuracy # noqa 17 | from .classification_metric import ClassificationMetric # noqa 18 | from .f1_score import F1Score # noqa 19 | from .false_positive_rate import FalsePositiveRate # noqa 20 | from .negative_predictive_value import NegativePredictiveValue # noqa 21 | from .precision import Precision # noqa 22 | from .recall import Recall # noqa 23 | from .utils import make_classification_metric # noqa 24 | -------------------------------------------------------------------------------- /api/TFSimilarity/losses.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.losses 2 | 3 | 4 | 5 | 6 | 7 | Contrastive learning specialized losses. 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class Barlow`](../TFSimilarity/losses/Barlow.md): Computes the Barlow Loss between two batches of embeddings. 14 | 15 | - [`class CircleLoss`](../TFSimilarity/losses/CircleLoss.md): Computes the CircleLoss 16 | 17 | - [`class MetricLoss`](../TFSimilarity/losses/MetricLoss.md): Wraps a loss function in the `Loss` class. 18 | 19 | - [`class MultiSimilarityLoss`](../TFSimilarity/losses/MultiSimilarityLoss.md): Computes the multi similarity loss in an online fashion. 20 | 21 | - [`class PNLoss`](../TFSimilarity/losses/PNLoss.md): Computes the PN loss in an online fashion. 22 | 23 | - [`class SimCLRLoss`](../TFSimilarity/losses/SimCLRLoss.md): SimCLR Loss 24 | 25 | - [`class SimSiamLoss`](../TFSimilarity/losses/SimSiamLoss.md): SimSiam Loss 26 | 27 | - [`class SoftNearestNeighborLoss`](../TFSimilarity/losses/SoftNearestNeighborLoss.md): Computes the soft nearest neighbors loss in an online fashion. 28 | 29 | - [`class TripletLoss`](../TFSimilarity/losses/TripletLoss.md): Computes the triplet loss in an online fashion. 
30 | 31 | - [`class VicReg`](../TFSimilarity/losses/VicReg.md): VicReg Loss 32 | 33 | -------------------------------------------------------------------------------- /tests/augmenters/test_augmentation_utils/test_crop.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters.augmentation_utils import cropping 5 | 6 | 7 | def create_img(width=32, height=32, channels=3): 8 | return tf.random.uniform([width, height, channels], 0, 1) 9 | 10 | 11 | def test_center_cropping(): 12 | img = create_img() 13 | WIDTH = 32 14 | HEIGHT = 32 15 | CHANNELS = 3 16 | 17 | center_cropped = cropping.center_crop(img, HEIGHT, WIDTH, 0.5) 18 | 19 | assert (tf.shape(center_cropped) == tf.shape(img)).numpy().all() 20 | 21 | 22 | def test_random_cropping(): 23 | # Random Crop 24 | img = create_img() 25 | WIDTH = 32 26 | HEIGHT = 32 27 | CHANNELS = 3 28 | 29 | random_cropping_always = cropping.random_crop_with_resize(img, HEIGHT, WIDTH, 1) 30 | 31 | random_cropping_never = cropping.random_crop_with_resize(img, HEIGHT, WIDTH, 0) 32 | 33 | # check shapes 34 | assert (tf.shape(random_cropping_always) == tf.shape(img)).numpy().all() 35 | assert (tf.shape(random_cropping_never) == tf.shape(img)).numpy().all() 36 | 37 | # check if crop works 38 | assert not (random_cropping_always == img).numpy().all() 39 | assert (random_cropping_never == img).numpy().all() 40 | -------------------------------------------------------------------------------- /api/TFSimilarity/training_metrics.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.training_metrics 2 | 3 | 4 | 5 | 6 | 7 | Specialized `keras.metrics` that tracks how distances evolves during training. 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class DistanceGapMetric`](../TFSimilarity/training_metrics/DistanceGapMetric.md): Encapsulates metric logic and state. 
14 | 15 | - [`class DistanceMetric`](../TFSimilarity/training_metrics/DistanceMetric.md): Encapsulates metric logic and state. 16 | 17 | ## Functions 18 | 19 | - [`avg_neg(...)`](../TFSimilarity/training_metrics/avg_neg.md) 20 | 21 | - [`avg_pos(...)`](../TFSimilarity/training_metrics/avg_pos.md) 22 | 23 | - [`batch_class_ratio(...)`](../TFSimilarity/training_metrics/batch_class_ratio.md): Computes the average number of examples per class within each batch. 24 | 25 | - [`dist_gap(...)`](../TFSimilarity/training_metrics/dist_gap.md) 26 | 27 | - [`max_neg(...)`](../TFSimilarity/training_metrics/max_neg.md) 28 | 29 | - [`max_pos(...)`](../TFSimilarity/training_metrics/max_pos.md) 30 | 31 | - [`min_neg(...)`](../TFSimilarity/training_metrics/min_neg.md) 32 | 33 | - [`min_pos(...)`](../TFSimilarity/training_metrics/min_pos.md) 34 | 35 | - [`sum_neg(...)`](../TFSimilarity/training_metrics/sum_neg.md) 36 | 37 | - [`sum_pos(...)`](../TFSimilarity/training_metrics/sum_pos.md) 38 | 39 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/augmenter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | import abc 17 | 18 | from ..types import Tensor 19 | 20 | 21 | class Augmenter(abc.ABC): # abstract interface: subclasses implement augment(); calling an instance delegates to it 22 | @abc.abstractmethod 23 | def augment( 24 | self, 25 | x: Tensor, 26 | y: Tensor, 27 | num_augmentations_per_example: int, 28 | is_warmup: bool, 29 | ) -> list[Tensor]: 30 | pass 31 | 32 | def __call__( 33 | self, 34 | x: Tensor, 35 | y: Tensor, 36 | num_augmentations_per_example: int, 37 | is_warmup: bool, 38 | ) -> list[Tensor]: 39 | return self.augment(x, y, num_augmentations_per_example, is_warmup) 40 | -------------------------------------------------------------------------------- /benchmark/supervised/components/augmentations.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from functools import partial 5 | from typing import Any 6 | 7 | from tensorflow.image import random_flip_left_right, random_flip_up_down 8 | 9 | from tensorflow_similarity.augmenters.augmentation_utils.cropping import ( 10 | center_crop, 11 | crop_and_resize, 12 | ) 13 | 14 | AUGMENTATIONS = {} # name -> factory that takes a params mapping `p` and returns an augmentation callable 15 | AUGMENTATIONS["random_resized_crop"] = lambda p: partial( 16 | crop_and_resize, 17 | height=p.get("height", 277), # NOTE(review): 277 is an unusual default (224/227 are common) - confirm 18 | width=p.get("width", 277), 19 | area_range=p.get("area_range", (0.2, 1.0)), 20 | ) 21 | AUGMENTATIONS["random_flip"] = ( 22 | lambda p: random_flip_left_right if p.get("mode", "horizontal") == "horizontal" else random_flip_up_down 23 | ) 24 | AUGMENTATIONS["center_crop"] = lambda p: partial( 25 | center_crop, 26 | height=p.get("height", 256), 27 | width=p.get("width", 256), 28 | crop_proportion=p.get("crop_proportion", 1.0), 29 | ) 30 | 31 | 32 | # TODO(ovallis): Return type should be tuple[Callable[[FloatTensor], FloatTensor]], but 33 | # mypy doesn't recognize the return types of the callables.
34 | def make_augmentations(cfg: Mapping[str, Any]) -> tuple[Any, ...]: 35 | return tuple([AUGMENTATIONS[aug_id](params) for aug_id, params in cfg.items()]) 36 | -------------------------------------------------------------------------------- /tensorflow_similarity/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Contrastive learning specialized losses. 
17 | """ 18 | from .barlow import Barlow # noqa 19 | from .circle_loss import CircleLoss # noqa 20 | from .lifted_structure_loss import LiftedStructLoss # noqa 21 | from .metric_loss import MetricLoss # noqa 22 | from .multinegrank_loss import MultiNegativesRankLoss # noqa 23 | from .multisim_loss import MultiSimilarityLoss # noqa 24 | from .pn_loss import PNLoss # noqa 25 | from .simclr import SimCLRLoss # noqa 26 | from .simsiam import SimSiamLoss # noqa 27 | from .softnn_loss import SoftNearestNeighborLoss # noqa 28 | from .triplet_loss import TripletLoss # noqa 29 | from .vicreg import VicReg # noqa 30 | from .xbm_loss import XBM # noqa 31 | -------------------------------------------------------------------------------- /tensorflow_similarity/training_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | "Specialized `keras.metrics` that track how distances evolve during training."
15 | from .distance_metrics import DistanceGapMetric # noqa 16 | from .distance_metrics import DistanceMetric # noqa 17 | from .distance_metrics import avg_neg # noqa 18 | from .distance_metrics import avg_pos # noqa 19 | from .distance_metrics import dist_gap # noqa 20 | from .distance_metrics import max_neg # noqa 21 | from .distance_metrics import max_pos # noqa 22 | from .distance_metrics import min_neg # noqa 23 | from .distance_metrics import min_pos # noqa 24 | from .distance_metrics import sum_neg # noqa 25 | from .distance_metrics import sum_pos # noqa 26 | from .utils import batch_class_ratio # noqa 27 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/augmentation_utils/random_apply.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from __future__ import annotations 16 | 17 | from collections.abc import Callable 18 | 19 | import tensorflow as tf 20 | 21 | from tensorflow_similarity.types import Tensor 22 | 23 | 24 | def random_apply(func: Callable[[Tensor], Tensor], p: float, x: Tensor) -> Tensor: 25 | """Return func(x) with probability p; otherwise return x unchanged.""" 26 | return tf.cond( 27 | tf.less( 28 | tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32), # one scalar draw in [0, 1) 29 | tf.cast(p, tf.float32), # p >= 1 always applies func; p <= 0 never does 30 | ), 31 | lambda: func(x), 32 | lambda: x, 33 | ) 34 | -------------------------------------------------------------------------------- /api/TFSimilarity/indexer/CalibrationResults.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.indexer.CalibrationResults 2 | 3 | 4 | 5 | 6 | 7 | Cutpoints and thresholds associated with a calibration. 8 | 9 | ```python 10 | TFSimilarity.indexer.CalibrationResults( 11 | cutpoints: Mapping[str, Mapping[str, Union[str, float, int]]], 12 | thresholds: Mapping[str, Sequence[Union[float, int]]] 13 | ) 14 | ``` 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 32 | 37 | 38 | 41 | 46 | 47 |
30 | cutpoints 31 | 33 | A Dict mapping the cutpoint name to a Dict containing the 34 | ClassificationMetric values associated with a particular distance 35 | threshold, e.g., 'optimal' : 'acc': 0.90, 'f1': 0.92. 36 |
39 | thresholds 40 | 42 | A Dict mapping ClassificationMetric names to a list 43 | containing the metric's value computed at each of the distance 44 | thresholds, e.g., 'f1': [0.99, 0.80], 'distance': [0.0, 1.0]. 45 |
48 | 49 | 50 | 51 | ## Methods 52 | 53 |

__eq__

54 | 55 | ```python 56 | __eq__( 57 | other 58 | ) 59 | ``` 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/augmentation_utils/solarize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from __future__ import annotations 16 | 17 | import tensorflow as tf 18 | 19 | from tensorflow_similarity.augmenters.augmentation_utils.random_apply import ( 20 | random_apply, 21 | ) 22 | from tensorflow_similarity.types import Tensor 23 | 24 | 25 | def random_solarize( 26 | image: Tensor, 27 | p: float = 0.2, 28 | pixel_min: int = 0, # NOTE(review): accepted but not used by _transform 29 | pixel_max: int = 255, 30 | thresh: int = 10, 31 | ) -> Tensor: 32 | def _transform(image: Tensor) -> Tensor: # keep values below thresh, replace the rest with pixel_max - value 33 | return tf.where(image < thresh, image, pixel_max - image) 34 | 35 | return random_apply(_transform, p=p, x=image) # apply _transform with probability p 36 | -------------------------------------------------------------------------------- /tests/integration/test_readme_minimal.py: -------------------------------------------------------------------------------- 1 | "Test that the example in the README works correctly" 2 | import os 3 | 4 | import pytest 5 | from tensorflow.keras import layers 6 | 7
| from tensorflow_similarity.layers import MetricEmbedding 8 | from tensorflow_similarity.losses import MultiSimilarityLoss 9 | from tensorflow_similarity.models import SimilarityModel 10 | from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler 11 | 12 | 13 | @pytest.fixture 14 | def readme_path(request): 15 | """Helper to load README relative to the test file.""" 16 | # README path needs to be relative to the test. 17 | test_path = os.path.dirname(os.path.realpath(request.module.__file__)) 18 | return os.path.join(test_path, "..", "..", "README.md") 19 | 20 | 21 | def test_readme_text_directly(readme_path): 22 | """Quick and dirty test of the README.md code snippets.""" 23 | code = [] 24 | code_block = False 25 | 26 | with open(readme_path, "r") as f: 27 | for line in f: 28 | if line.endswith("```\n"): 29 | code_block = False 30 | 31 | # Add all code lines except for the viz function. 32 | if code_block and not line.startswith("viz_neighbors_imgs"): 33 | code.append(line) 34 | 35 | if line.startswith("```python"): 36 | code_block = True 37 | 38 | exec(("\n").join(code)) 39 | -------------------------------------------------------------------------------- /tensorflow_similarity/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | "Collection of specialized notebook vizualization tools" 15 | import importlib.util 16 | 17 | excs = [] 18 | for mod in ["PIL", "umap", "bokeh", "distinctipy", "matplotlib"]: 19 | if importlib.util.find_spec(mod) is None: 20 | excs.append( 21 | ModuleNotFoundError( 22 | f"{mod} is not installed. Please install it with `pip install tensorflow_similarity[visualization]`" 23 | ) 24 | ) 25 | 26 | if excs: 27 | raise Exception(excs) 28 | else: 29 | from .confusion_matrix_viz import confusion_matrix # noqa 30 | from .neighbors_viz import viz_neigbors_imgs # noqa 31 | from .projector import projector # noqa 32 | from .vizualize_views import visualize_views # noqa 33 | -------------------------------------------------------------------------------- /api/TFSimilarity/callbacks/make_classification_metric.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.callbacks.make_classification_metric 2 | 3 | 4 | 5 | 6 | 7 | Convert classification metric from str name to object if needed. 8 | 9 | 10 | ```python 11 | TFSimilarity.callbacks.make_classification_metric( 12 | name: str = 13 | ``` 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 32 | 33 |
27 | metric 28 | 30 | ClassificationMetric() or metric name. 31 |
34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 46 | 49 | 50 |
44 | ValueError 45 | 47 | Unknown metric name: metric, typo? 48 |
51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 63 | 66 | 67 |
61 | ClassificationMetric 62 | 64 | Instantiated metric if needed. 65 |
68 | 69 | -------------------------------------------------------------------------------- /tests/augmenters/test_augmentation_utils/test_flip.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters.augmentation_utils import flip 5 | 6 | 7 | def create_img(width=32, height=32, channels=3): 8 | return tf.random.uniform([width, height, channels], 0, 1) 9 | 10 | 11 | def test_flip_left_right(): 12 | img = create_img() 13 | 14 | random_flip_always = flip.random_random_flip_left_right(img, 1) 15 | random_flip_never = flip.random_random_flip_left_right(img, 0) 16 | 17 | # check shapes 18 | assert (tf.shape(random_flip_always) == tf.shape(img)).numpy().all() 19 | assert (tf.shape(random_flip_never) == tf.shape(img)).numpy().all() 20 | 21 | # check if flip works 22 | # assert not (random_flip_always == img).numpy().all() 23 | assert (random_flip_never == img).numpy().all() 24 | 25 | 26 | def test_flip_top_bottom(): 27 | img = create_img() 28 | 29 | random_flip_always = flip.random_random_flip_top_bottom(img, 1) 30 | random_flip_never = flip.random_random_flip_top_bottom(img, 0) 31 | 32 | # check shapes 33 | assert (tf.shape(random_flip_always) == tf.shape(img)).numpy().all() 34 | assert (tf.shape(random_flip_never) == tf.shape(img)).numpy().all() 35 | 36 | # check if flip works 37 | # assert not (random_flip_always == img).numpy().all() # removing because would fail if flip were symmetric 38 | assert (random_flip_never == img).numpy().all() 39 | -------------------------------------------------------------------------------- /api/TFSimilarity/search.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.search 2 | 3 | 4 | 5 | 6 | 7 | Efficiently find nearest indexed embeddings 8 | 9 | 10 | The search is used to find the closest indexed example embeddings 11 | to a query example embebbeding. 
12 | To do so, it performs a sub-linear time 13 | [ANN (Approximate Nearest Neighbors)](https://en.wikipedia.org/wiki/Nearest_neighbor_search) 14 | search on the indexed set of embedding examples. 15 | 16 | Different ANN libraries have [different performance profiles](https://github.com/erikbern/ann-benchmarks). 17 | By default, TensorFlow Similarity uses [NMSLIB](https://github.com/nmslib/nmslib), 18 | which has a strong performance profile and is very portable. 19 | 20 | Adding another backend is fairly straightforward: subclass the 21 | abstract class `Search()` and implement the abstract methods. Then, to use it, 22 | pass it to the `compile()` method of your `SimilarityModel`. 23 | 24 | Adding your search backend as a built-in choice involves 25 | modifying the [Indexer](../indexer.md) and sending a PR. In general, unless 26 | the backend is of general use, it's better not to include it as 27 | a built-in option, since it must be supported moving forward. 28 | 29 | ## Classes 30 | 31 | - [`class NMSLibSearch`](../TFSimilarity/indexer/NMSLibSearch.md): Efficiently find nearest embeddings by indexing known embeddings and make 32 | 33 | - [`class Search`](../TFSimilarity/indexer/Search.md): Helper class that provides a standard way to create an ABC using 34 | 35 | -------------------------------------------------------------------------------- /api/TFSimilarity/classification_metrics.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.classification_metrics 2 | 3 | 4 | 5 | 6 | 7 | Classification metrics measure matching classification quality between a 8 | 9 | set of query examples and a set of indexed examples. 10 | 11 | ## Classes 12 | 13 | - [`class BinaryAccuracy`](../TFSimilarity/classification_metrics/BinaryAccuracy.md): Calculates how often the query label matches the derived lookup label.
14 | 15 | - [`class ClassificationMetric`](../TFSimilarity/callbacks/ClassificationMetric.md): Abstract base class for computing classification metrics. 16 | 17 | - [`class F1Score`](../TFSimilarity/classification_metrics/F1Score.md): Calculates the harmonic mean of precision and recall. 18 | 19 | - [`class FalsePositiveRate`](../TFSimilarity/classification_metrics/FalsePositiveRate.md): Calculates the false positive rate of the query classification. 20 | 21 | - [`class NegativePredictiveValue`](../TFSimilarity/classification_metrics/NegativePredictiveValue.md): Calculates the negative predictive value of the query classification. 22 | 23 | - [`class Precision`](../TFSimilarity/classification_metrics/Precision.md): Calculates the precision of the query classification. 24 | 25 | - [`class Recall`](../TFSimilarity/classification_metrics/Recall.md): Calculates the recall of the query classification. 26 | 27 | ## Functions 28 | 29 | - [`make_classification_metric(...)`](../TFSimilarity/callbacks/make_classification_metric.md): Convert classification metric from str name to object if needed. 30 | 31 | -------------------------------------------------------------------------------- /api/TFSimilarity/evaluators.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.evaluators 2 | 3 | 4 | 5 | 6 | 7 | Evaluates search index performance and calibrates it. 8 | 9 | 10 | ## Use 11 | 12 | Evaluators are used for two primary purposes: 13 | 14 | - Evaluate model performance on a reference index during training and 15 | evaluation phase via the `evaluate_classification()` and `evaluate_retrieval()` 16 | methods. Evaluation on a reference index is 17 | required to be able to assess model performance using 18 | - [Classification metrics](../classification_metrics/) and 19 | - [Retrieval metrics](../retrieval_metrics/). 20 | Those metrics can't be computed without indexing data and looking up 21 | nearest neighbors. 
22 | 23 | - Calibrating the model requires evaluating various distance thresholds 24 | to find the maximal distance threshold. Those thresholds either meet, 25 | if possible, the user-supplied `thresholds_targets` performance value or 26 | the optimal value with respect to the calibration `classification metric`. 27 | Calibration is required to perform matching 28 | because the optimal distance thresholds will change depending on 29 | the model, dataset, and training. Accordingly, those thresholds 30 | need to be evaluated empirically for every use case. 31 | 32 | ## Classes 33 | 34 | - [`class Evaluator`](../TFSimilarity/callbacks/Evaluator.md): Evaluates search index performance and calibrates it. 35 | 36 | - [`class MemoryEvaluator`](../TFSimilarity/callbacks/MemoryEvaluator.md): In-memory index performance evaluation and classification. 37 | 38 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/augmentation_utils/flip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | from __future__ import annotations 16 | 17 | import tensorflow as tf 18 | 19 | from tensorflow_similarity.augmenters.augmentation_utils.random_apply import ( 20 | random_apply, 21 | ) 22 | from tensorflow_similarity.types import Tensor 23 | 24 | 25 | def random_random_flip_left_right(image: Tensor, p: float = 0.5) -> Tensor: 26 | def _transform(image: Tensor) -> Tensor: 27 | return tf.image.random_flip_left_right(image) # NOTE(review): this op itself flips only ~50% of the time, so the overall flip probability is ~p/2 - confirm intended 28 | 29 | return random_apply(_transform, p=p, x=image) 30 | 31 | 32 | def random_random_flip_top_bottom(image: Tensor, p: float = 0.5) -> Tensor: 33 | def _transform(image: Tensor) -> Tensor: 34 | return tf.image.random_flip_up_down(image) # NOTE(review): this op is also random (~50%), so the overall flip probability is ~p/2 - confirm intended 35 | 36 | return random_apply(_transform, p=p, x=image) 37 | -------------------------------------------------------------------------------- /tensorflow_similarity/augmenters/contrastive.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | from __future__ import annotations 15 | 16 | import os 17 | from collections.abc import Callable 18 | 19 | import tensorflow as tf 20 | 21 | from ..augmenters import Augmenter 22 | from ..types import Tensor 23 | 24 | 25 | class ContrastiveAugmenter(Augmenter): # returns num_views independently augmented copies of the input batch 26 | def __init__(self, process: Callable, num_cpu: int | None = os.cpu_count()): # num_cpu default is evaluated once, at import time 27 | self.process = process 28 | self.num_cpu = num_cpu 29 | 30 | def augment(self, x: Tensor, y: Tensor, num_views: int, is_warmup: bool) -> list[Tensor]: # y and is_warmup are accepted for the Augmenter interface but unused here 31 | with tf.device("/cpu:0"): # augmentation is pinned to the CPU 32 | inputs = tf.stack(x) 33 | 34 | views = [] 35 | for _ in range(num_views): 36 | # multi-core augmentation: map self.process over every example, up to num_cpu at a time 37 | view = tf.map_fn(self.process, inputs, parallel_iterations=self.num_cpu) 38 | views.append(view) 39 | return views 40 | -------------------------------------------------------------------------------- /api/TFSimilarity/distances.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.distances 2 | 3 | 4 | 5 | 6 | 7 | Vectorized embedding pairwise distances computation functions 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class ABC`](../TFSimilarity/distances/ABC.md): Helper class that provides a standard way to create an ABC using 14 | 15 | - [`class CosineDistance`](../TFSimilarity/distances/CosineDistance.md): Compute pairwise cosine distances between embeddings. 16 | 17 | - [`class Distance`](../TFSimilarity/distances/Distance.md): Note: don't forget to add your distance to the DISTANCES list 18 | 19 | - [`class EuclideanDistance`](../TFSimilarity/distances/EuclideanDistance.md): Compute pairwise euclidean distances between embeddings. 20 | 21 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor 22 | 23 | - [`class InnerProductSimilarity`](../TFSimilarity/distances/InnerProductSimilarity.md): Compute the pairwise inner product between embeddings.
24 | 25 | - [`class ManhattanDistance`](../TFSimilarity/distances/ManhattanDistance.md): Compute pairwise Manhattan distances between embeddings. 26 | 27 | - [`class SNRDistance`](../TFSimilarity/distances/SNRDistance.md): Computes pairwise SNR distances between embeddings. 28 | 29 | - [`class SquaredEuclideanDistance`](../TFSimilarity/distances/SquaredEuclideanDistance.md): Compute pairwise squared Euclidean distance. 30 | 31 | ## Functions 32 | 33 | - [`distance_canonicalizer(...)`](../TFSimilarity/distances/distance_canonicalizer.md): Normalize user requested distance to its matching Distance object. 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /tests/architectures/test_resnet18.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pytest 4 | 5 | from tensorflow_similarity.architectures import resnet18 6 | 7 | 8 | def test_include_top(): 9 | input_shape = (32, 32, 3) 10 | resnet = resnet18.ResNet18Sim(input_shape, include_top=True) 11 | 12 | # The second to last layer should use gem pooling when include_top is True 13 | assert resnet.layers[-2].name == "gem_pool" 14 | assert resnet.layers[-2].p == 3.0 15 | # The default is l2_norm True, so we expect the last layer to be 16 | # MetricEmbedding. 17 | assert re.match("metric_embedding", resnet.layers[-1].name) is not None 18 | 19 | 20 | def test_l2_norm_false(): 21 | input_shape = (32, 32, 3) 22 | resnet = resnet18.ResNet18Sim(input_shape, include_top=True, l2_norm=False) 23 | 24 | # The second to last layer should use gem pooling when include_top is True 25 | assert resnet.layers[-2].name == "gem_pool" 26 | assert resnet.layers[-2].p == 3.0 27 | # If l2_norm is False, we should return a dense layer as the last layer. 
28 | assert re.match("dense", resnet.layers[-1].name) is not None 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "pooling, name", zip(["gem", "avg", "max"], ["gem_pool", "avg_pool", "max_pool"]), ids=["gem", "avg", "max"] 33 | ) 34 | def test_include_top_false(pooling, name): 35 | input_shape = (32, 32, 3) 36 | resnet = resnet18.ResNet18Sim(input_shape, include_top=False, pooling=pooling) 37 | 38 | # The second to last layer should use gem pooling when include_top is True 39 | assert resnet.layers[-1].name == name 40 | -------------------------------------------------------------------------------- /api/TFSimilarity/indexer/make_classification_matcher.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.indexer.make_classification_matcher 2 | 3 | 4 | 5 | 6 | 7 | Convert classification matcher from str name to object if needed. 8 | 9 | 10 | ```python 11 | TFSimilarity.indexer.make_classification_matcher( 12 | ``` 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 35 | 36 |
26 | matcher 27 | 29 | 'match_nearest', 'match_majority_vote' or 30 | ClassificationMatch object. Defines the classification matching, 31 | e.g., match_nearest will count a True Positive if the query_label 32 | is equal to the label of the nearest neighbor and the distance is 33 | less than or equal to the distance threshold. 34 |
37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 49 | 52 | 53 |
47 | ValueError 48 | 50 | matcher name is invalid. 51 |
54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 66 | 69 | 70 |
64 | ClassificationMatch 65 | 67 | Instantiated matcher if needed. 68 |
71 | 72 | -------------------------------------------------------------------------------- /tensorflow_similarity/retrieval_metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..types import BoolTensor, IntTensor 22 | 23 | 24 | def compute_match_mask(query_labels: IntTensor, lookup_labels: IntTensor) -> BoolTensor: 25 | """Compute a boolean mask (indicator function) marking the TPs in the results. 26 | 27 | Args: 28 | query_labels: A 1D tensor of the labels associated with the queries. 29 | 30 | lookup_labels: A 2D tensor where the jth row is the labels associated 31 | with the set of k neighbors for the jth query. 32 | 33 | Returns: 34 | A 2D boolean tensor indicating which lookups match the label of their 35 | associated query. 
36 | """ 37 | if tf.rank(query_labels) == 1: 38 | query_labels = tf.expand_dims(query_labels, axis=-1) 39 | 40 | match_mask: BoolTensor = tf.math.equal(lookup_labels, query_labels) 41 | 42 | return match_mask 43 | -------------------------------------------------------------------------------- /api/TFSimilarity/visualization/viz_neigbors_imgs.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.visualization.viz_neigbors_imgs 2 | 3 | 4 | 5 | 6 | 7 | Display images nearest neighboors 8 | 9 | ```python 10 | TFSimilarity.visualization.viz_neigbors_imgs( 11 | example_class: int, 12 | class_mapping: Optional[Mapping[int, str]] = None, 13 | fig_size: Tuple[int, int] = (24, 4), 14 | cmap: str = viridis, 15 | show: bool = True 16 | ) 17 | ``` 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 | 54 | 59 | 60 | 63 | 66 | 67 | 70 | 74 | 75 | 78 | 81 | 82 |
31 | example 32 | 34 | The data used as query input. 35 |
38 | example_class 39 | 41 | The class of the data used as query 42 |
45 | neighbors 46 | 48 | The list of neighbors returned by the lookup() 49 |
52 | class_mapping 53 | 55 | Mapping from class numerical ids to a class name. If not 56 | set, the plot will display the class numerical id instead. 57 | Defaults to None. 58 |
61 | fig_size 62 | 64 | Size of the figure. Defaults to (24, 4). 65 |
68 | cmap 69 | 71 | Default color scheme for black and white images e.g mnist. 72 | Defaults to 'viridis'. 73 |
76 | show 77 | 79 | If the plot is going to be shown or not. Defaults to True. 80 |
83 | 84 | -------------------------------------------------------------------------------- /api/TFSimilarity/augmenters/Augmenter.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.augmenters.Augmenter 2 | 3 | 4 | 5 | 6 | 7 | Helper class that provides a standard way to create an ABC using 8 | 9 | Inherits From: [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | inheritance. 12 | 13 | ## Methods 14 | 15 |

augment

16 | 17 | View source 18 | 19 | ```python 20 | @abc.abstractmethod 21 | ``` 22 | ```python 23 | augment( 24 | x: TFSimilarity.callbacks.Tensor, 25 | y: TFSimilarity.callbacks.Tensor, 26 | num_augmentations_per_example: int, 27 | is_warmup: bool 28 | ) -> List[TFSimilarity.callbacks.Tensor] 29 | ``` 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |

__call__

39 | 40 | View source 41 | 42 | ```python 43 | __call__( 44 | x: TFSimilarity.callbacks.Tensor, 45 | y: TFSimilarity.callbacks.Tensor, 46 | num_augmentations_per_example: int, 47 | is_warmup: bool 48 | ) -> List[TFSimilarity.callbacks.Tensor] 49 | ``` 50 | 51 | 52 | 53 | 54 | 55 | Call self as a function. 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /api/TFSimilarity/samplers.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.samplers 2 | 3 | 4 | 5 | 6 | 7 | Data Samplers generate balanced batches for smooth training. 8 | 9 | 10 | *A well balanced batch is a batch that contains at least 2 examples for 11 | each class present in the batch*. 12 | 13 | Having well balanced batches is important for many types of similarity learning 14 | including contrastive learning because contrastive losses require at least 15 | two examples (and sometimes more) to be able to compute distances between 16 | the embeddings. 17 | 18 | To address this need, TensorFlow Similarity provides data samplers for 19 | various types of datasets that: 20 | - Ensure that batches contain at least N examples of each class present in 21 | the batch. 22 | - Support restricting the batches to a subset of the classes present in 23 | the dataset. 24 | 25 | ## Classes 26 | 27 | - [`class MultiShotMemorySampler`](../TFSimilarity/samplers/MultiShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset. 28 | 29 | - [`class SingleShotMemorySampler`](../TFSimilarity/samplers/SingleShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset. 30 | 31 | - [`class TFDatasetMultiShotMemorySampler`](../TFSimilarity/samplers/TFDatasetMultiShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset.
32 | 33 | ## Functions 34 | 35 | - [`TFRecordDatasetSampler(...)`](../TFSimilarity/samplers/TFRecordDatasetSampler.md): Create a [TFRecordDataset](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset) based sampler. 36 | 37 | - [`select_examples(...)`](../TFSimilarity/samplers/select_examples.md): Randomly select at most N examples per class 38 | 39 | -------------------------------------------------------------------------------- /api/TFSimilarity/samplers/select_examples.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.samplers.select_examples 2 | 3 | 4 | 5 | 6 | 7 | Randomly select at most N examples per class 8 | 9 | ```python 10 | TFSimilarity.samplers.select_examples( 11 | class_list: Sequence[int] = None, 12 | num_examples_per_class: int = None 13 | ) -> Tuple[np.ndarray, np.ndarray] 14 | ``` 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 30 | 33 | 34 | 37 | 40 | 41 | 44 | 49 | 50 | 53 | 58 | 59 |
28 | x 29 | 31 | A 2-D Tensor containing the data. 32 |
35 | y 36 | 38 | A 1-D Tensor containing the labels. 39 |
42 | class_list 43 | 45 | Filter the list of examples to only keep those that 46 | belong to the supplied class list. If no class is supplied, keep 47 | examples for all the classes. Defaults to None - keep all the examples. 48 |
51 | num_examples_per_class 52 | 54 | Restrict the number of examples for EACH 55 | class to num_examples_per_class if set. If not set, all the available 56 | examples are selected. Defaults to None - no selection. 57 |
60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 71 | 72 | 73 |
69 | A Tuple containing the subset of x and y. 70 |
74 | 75 | -------------------------------------------------------------------------------- /benchmark/supervised/components/architectures.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from typing import Any 5 | 6 | import tensorflow as tf 7 | 8 | from tensorflow_similarity.architectures import ( 9 | EfficientNetSim, 10 | ResNet18Sim, 11 | ResNet50Sim, 12 | ) 13 | 14 | ARCHITECTURES = {} 15 | ARCHITECTURES["effnet"] = lambda p: EfficientNetSim( 16 | input_shape=p["input_shape"], 17 | embedding_size=p.get("embedding", 128), 18 | variant=p.get("variant", "B0"), 19 | weights=p.get("weights", "imagenet"), 20 | trainable=p.get("trainable", "frozen"), 21 | l2_norm=p.get("l2_norm", True), 22 | include_top=p.get("include_top", True), 23 | pooling=p.get("pooling", "gem"), 24 | gem_p=p.get("gem_p", 3.0), 25 | ) 26 | ARCHITECTURES["resnet50"] = lambda p: ResNet50Sim( 27 | input_shape=p["input_shape"], 28 | embedding_size=p.get("embedding", 128), 29 | weights=p.get("weights", "imagenet"), 30 | trainable=p.get("trainable", "frozen"), 31 | l2_norm=p.get("l2_norm", True), 32 | include_top=p.get("include_top", True), 33 | pooling=p.get("pooling", "gem"), 34 | gem_p=p.get("gem_p", 3.0), 35 | ) 36 | ARCHITECTURES["resnet18"] = lambda p: ResNet18Sim( 37 | input_shape=p["input_shape"], 38 | embedding_size=p.get("embedding", 128), 39 | l2_norm=p.get("l2_norm", True), 40 | include_top=p.get("include_top", True), 41 | pooling=p.get("pooling", "gem"), 42 | gem_p=p.get("gem_p", 3.0), 43 | ) 44 | 45 | 46 | def make_architecture(architecture_id: str, params: Mapping[str, Any]) -> tf.keras.Model: 47 | try: 48 | return ARCHITECTURES[architecture_id](params) 49 | except KeyError as exc: 50 | raise ValueError(f"Unknown architecture name: {architecture_id}") from exc 51 | -------------------------------------------------------------------------------- 
/tests/augmenters/test_augmentation_utils/test_blur.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters.augmentation_utils import blur 5 | 6 | 7 | def create_img(width=32, height=32, channels=3): 8 | return tf.random.uniform([width, height, channels], 0, 1) 9 | 10 | 11 | def test_random_blur(): 12 | # Random Blur 13 | img = create_img() 14 | WIDTH = 32 15 | HEIGHT = 32 16 | CHANNELS = 3 17 | 18 | random_blurred_always = blur.random_blur(img, 32, 32, 1.0) 19 | 20 | random_blurred_never = blur.random_blur(img, 32, 32, 0) 21 | 22 | # check shapes 23 | assert (tf.shape(random_blurred_always) == tf.shape(img)).numpy().all() 24 | assert (tf.shape(random_blurred_never) == tf.shape(img)).numpy().all() 25 | 26 | # check if blur works 27 | assert not (random_blurred_always == img).numpy().all() 28 | assert (random_blurred_never == img).numpy().all() 29 | 30 | 31 | def test_batch_random_blur(): 32 | img = create_img() 33 | WIDTH = 32 34 | HEIGHT = 32 35 | CHANNELS = 3 36 | 37 | batched_img = [img] 38 | 39 | random_batched_blurred_always = blur.batch_random_blur(batched_img, 32, 32, 1.0) 40 | 41 | random_batched_blurred_never = blur.batch_random_blur(batched_img, 32, 32, 0) 42 | 43 | # check shapes 44 | assert (tf.shape(random_batched_blurred_always) == tf.shape(batched_img)).numpy().all() 45 | assert (tf.shape(random_batched_blurred_never) == tf.shape(batched_img)).numpy().all() 46 | 47 | # check if blur works 48 | equality_always = tf.reshape(tf.equal(random_batched_blurred_always, batched_img), [-1]) 49 | equality_never = tf.reshape(tf.equal(random_batched_blurred_never, batched_img), [-1]) 50 | 51 | assert not equality_always.numpy().all() 52 | assert equality_never.numpy().all() 53 | -------------------------------------------------------------------------------- /tensorflow_similarity/samplers/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data Samplers generate balanced batches for smooth training. 16 | 17 | *A well balanced batch is a batch that contains at least 2 examples for 18 | each class present in the batch*. 19 | 20 | Having well balanced batches is important for many types of similarity learning 21 | including contrastive learning because contrastive losses require at least 22 | two examples (and sometimes more) to be able to compute distances between 23 | the embeddings. 24 | 25 | To address this need, TensorFlow Similarity provides data samplers for 26 | various types of datasets that: 27 | - Ensure that batches contain at least N examples of each class present in 28 | the batch. 29 | - Support restricting the batches to a subset of the classes present in 30 | the dataset. 
31 | """ 32 | from .file_samplers import MultiShotFileSampler # noqa 33 | from .memory_samplers import MultiShotMemorySampler # noqa 34 | from .memory_samplers import SingleShotMemorySampler # noqa 35 | from .tfdata_sampler import TFDataSampler # noqa 36 | from .tfdataset_samplers import TFDatasetMultiShotMemorySampler # noqa 37 | from .tfrecords_samplers import TFRecordDatasetSampler # noqa 38 | from .utils import select_examples # noqa 39 | -------------------------------------------------------------------------------- /api/TFSimilarity/visualization/visualize_views.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.visualization.visualize_views 2 | 3 | 4 | 5 | 6 | 7 | Display side by side different image views with labels, and predictions 8 | 9 | ```python 10 | TFSimilarity.visualization.visualize_views( 11 | num_imgs: int = None, 12 | views_per_col: int = 4, 13 | fig_size: Tuple[int, int] = (24, 4), 14 | max_pixel_value: float = 1.0, 15 | min_pixel_value: float = 0.0 16 | ) 17 | ``` 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 | 54 | 57 | 58 | 61 | 64 | 65 | 68 | 72 | 73 | 76 | 80 | 81 |
31 | views 32 | 34 | Array of views 35 |
38 | predictions 39 | 41 | model output. 42 |
45 | labels 46 | 48 | image labels 49 |
52 | num_imgs 53 | 55 | number of images to use. 56 |
59 | views_per_col 60 | 62 | Int, number of images in one row. Defaults to 4. 63 |
66 | max_pixel_value 67 | 69 | Max expected value for a pixel. Used to scale the image 70 | between [0,1]. 71 |
74 | min_pixel_value 75 | 77 | Min expected value for a pixel. Used to scale the image 78 | between [0,1]. 79 |
82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 93 | 94 | 95 |
91 | None. 92 |
96 | 97 | -------------------------------------------------------------------------------- /tests/augmenters/test_augmentation_utils/test_color_jitter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.augmenters.augmentation_utils import color_jitter 5 | 6 | 7 | def create_img(width=32, height=32, channels=3): 8 | return tf.random.uniform([width, height, channels], 0, 1) 9 | 10 | 11 | def test_random_color_jitter_multiplicative(): 12 | # Random Color Jitter 13 | img = create_img() 14 | WIDTH = 32 15 | HEIGHT = 32 16 | CHANNELS = 3 17 | 18 | random_jitter_always = color_jitter.random_color_jitter(img, 1, 1, 1, impl="multiplicative") 19 | 20 | random_jitter_never = color_jitter.random_color_jitter(img, 0, impl="multiplicative") 21 | 22 | # check shapes 23 | assert (tf.shape(random_jitter_always) == tf.shape(img)).numpy().all() 24 | assert (tf.shape(random_jitter_never) == tf.shape(img)).numpy().all() 25 | 26 | # check if blur works 27 | assert not (random_jitter_always == img).numpy().all() 28 | assert (random_jitter_never == img).numpy().all() 29 | 30 | 31 | def test_random_color_jitter_additive(): 32 | # Random Color Jitter 33 | img = create_img() 34 | WIDTH = 32 35 | HEIGHT = 32 36 | CHANNELS = 3 37 | 38 | random_jitter_always = color_jitter.random_color_jitter( 39 | img, 1, 1, 1, impl="additive" # won't make a difference between barlow/v1 40 | ) 41 | 42 | random_jitter_never = color_jitter.random_color_jitter( 43 | img, 0, impl="additive" # won't make a difference between barlow/v1 44 | ) 45 | 46 | # check shapes 47 | assert (tf.shape(random_jitter_always) == tf.shape(img)).numpy().all() 48 | assert (tf.shape(random_jitter_never) == tf.shape(img)).numpy().all() 49 | 50 | # check if color jitter works 51 | assert not (random_jitter_always == img).numpy().all() 52 | assert (random_jitter_never == img).numpy().all() 53 | 
-------------------------------------------------------------------------------- /api/TFSimilarity/augmenters/ContrastiveAugmenter.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.augmenters.ContrastiveAugmenter 2 | 3 | 4 | 5 | 6 | 7 | Helper class that provides a standard way to create an ABC using 8 | 9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | ```python 12 | TFSimilarity.augmenters.ContrastiveAugmenter( 13 | process: Callable, num_cpu: Optional[int] = os.cpu_count() 14 | ) 15 | ``` 16 | 17 | 18 | 19 | 20 | inheritance. 21 | 22 | ## Methods 23 | 24 |

augment

25 | 26 | View source 27 | 28 | ```python 29 | augment( 30 | x: TFSimilarity.callbacks.Tensor, 31 | y: TFSimilarity.callbacks.Tensor, 32 | num_views: int, 33 | is_warmup: bool 34 | ) -> List[TFSimilarity.callbacks.Tensor] 35 | ``` 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |

__call__

45 | 46 | View source 47 | 48 | ```python 49 | __call__( 50 | x: TFSimilarity.callbacks.Tensor, 51 | y: TFSimilarity.callbacks.Tensor, 52 | num_augmentations_per_example: int, 53 | is_warmup: bool 54 | ) -> List[TFSimilarity.callbacks.Tensor] 55 | ``` 56 | 57 | 58 | 59 | 60 | 61 | Call self as a function. 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /api/TFSimilarity/visualization/confusion_matrix.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.visualization.confusion_matrix 2 | 3 | 4 | 5 | 6 | 7 | Plot confusion matrix 8 | 9 | ```python 10 | TFSimilarity.visualization.confusion_matrix( 11 | normalize: bool = True, 12 | title: str = 'Confusion matrix', 13 | cmap: str = 'Blues', 14 | show: bool = True 15 | ) 16 | ``` 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | y_pred 30 | 32 | Model prediction returned by model.match() 33 |
36 | y_true 37 | 39 | Expected class_id. 40 |
43 | normalize 44 | 46 | Normalizes matrix values between 0 and 1. 47 | Defaults to True. 48 |
51 | labels 52 | 54 | List of class string label to display instead of the class 55 | numerical ids. Defaults to None. 56 |
59 | title 60 | 62 | Title of the confusion matrix. Defaults to 'Confusion matrix'. 63 |
66 | cmap 67 | 69 | Color schema as CMAP. Defaults to 'Blues'. 70 |
73 | show 74 | 76 | If the plot is going to be shown or not. Defaults to True. 77 |
80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 91 | 92 | 93 |
89 | A Tuple containing the plot and confusion matrix. 90 |
94 | 95 | -------------------------------------------------------------------------------- /tensorflow_similarity/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Evaluates search index performance and calibrates it. 15 | 16 | ## Use 17 | 18 | Evaluators are used for two primary purposes: 19 | 20 | - Evaluate model performance on a reference index during the training and 21 | evaluation phases via the `evaluate_classification()` and `evaluate_retrieval()` 22 | methods. Evaluation on a reference index is 23 | required to be able to assess model performance using 24 | [Classification metrics](../classification_metrics/) and 25 | [Retrieval metrics](../retrieval_metrics/). 26 | Those metrics can't be computed without indexing data and looking up 27 | nearest neighbors. 28 | 29 | - Calibrating the model requires evaluating various distance thresholds 30 | to find the maximal distance threshold. Those thresholds either meet, 31 | if possible, the user-supplied `thresholds_targets` performance value or 32 | the optimal value with respect to the calibration `classification metric`. 33 | Calibration is required to perform matching 34 | because the optimal distance thresholds will change depending on 35 | the model, dataset, and training.
Accordingly those thresholds 36 | need to be evaluated empirically for every use-case. 37 | """ 38 | from .evaluator import Evaluator # noqa 39 | from .memory_evaluator import MemoryEvaluator # noqa 40 | -------------------------------------------------------------------------------- /api/TFSimilarity/augmenters/BarlowAugmenter.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.augmenters.BarlowAugmenter 2 | 3 | 4 | 5 | 6 | 7 | Helper class that provides a standard way to create an ABC using 8 | 9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | ```python 12 | TFSimilarity.augmenters.BarlowAugmenter( 13 | width: int, 14 | height: int, 15 | flip_probability=0.5, 16 | brightness_multiplier=0.8, 17 | contrast_multiplier=0.6, 18 | saturation_multiplier=0.6, 19 | hue_multiplier=0.2, 20 | jitter_probability=0.8, 21 | greyscale_probability=0.2, 22 | blur_probability=0.2, 23 | blur_min_sigma=0, 24 | blur_max_sigma=1, 25 | solarize_probability=0.2, 26 | solarize_pixel_min=0, 27 | solarize_pixel_max=255, 28 | solarize_thresh=10, 29 | num_cpu: Optional[int] = os.cpu_count() 30 | ) 31 | ``` 32 | 33 | 34 | 35 | 36 | inheritance. 37 | 38 | ## Methods 39 | 40 |

augment

41 | 42 | View source 43 | 44 | ```python 45 | @tf.function 46 | ``` 47 | ```python 48 | augment( 49 | x: Any, 50 | y: Any = tf.constant([0]), 51 | num_augmentations_per_example: int = 2, 52 | is_warmup: bool = True 53 | ) -> List[Any] 54 | ``` 55 | 56 | 57 | 58 | 59 | 60 |

__call__

61 | 62 | View source 63 | 64 | ```python 65 | __call__( 66 | x: Any, 67 | y: Any = tf.constant([0]), 68 | num_augmentations_per_example: int = 2, 69 | is_warmup: bool = True 70 | ) -> List[Any] 71 | ``` 72 | 73 | 74 | Call self as a function. 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /tensorflow_similarity/training_metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..samplers.samplers import Sampler 22 | 23 | 24 | def batch_class_ratio(sampler: Sampler, num_batches: int = 100) -> float: 25 | """Computes the average number of examples per class within each batch. 26 | Similarity learning requires at least 2 examples per class in each batch. 27 | This is needed in order to construct the triplets. This function 28 | provides the average number of examples per class within each batch and 29 | can be used to check that a sampler is working correctly. 30 | The ratio should be >= 2. 31 | Args: 32 | sampler: A tf.similarity sampler object. 33 | num_batches: The number of batches to sample. 34 | Returns: 35 | The average number of examples per class. 
36 | """ 37 | ratio = 0 38 | for batch_count, (_, y) in enumerate(sampler): 39 | if batch_count < num_batches: 40 | batch_size = tf.shape(y)[0] 41 | num_classes = tf.shape(tf.unique(y)[0])[0] 42 | ratio += tf.math.divide(batch_size, num_classes) 43 | else: 44 | break 45 | 46 | return float(ratio / (batch_count + 1)) 47 | -------------------------------------------------------------------------------- /scripts/increment_version.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def abs_path(rel_path): 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | return os.path.join(here, rel_path) 7 | 8 | 9 | def read(rel_path): 10 | with open(abs_path(rel_path), "rt") as fd: 11 | return fd.read() 12 | 13 | 14 | def write(rel_path, lines): 15 | with open(abs_path(rel_path), "wt") as fd: 16 | for line in lines: 17 | fd.write(line + "\n") 18 | 19 | 20 | def increment_dev_version(previous_version): 21 | delim = "dev" 22 | if delim not in previous_version: 23 | raise ValueError(f"The previous version {previous_version} does contain a dev suffix") 24 | # Split and increment dev version 25 | main_version, dev_version = previous_version.split(delim) 26 | dev_version = int(dev_version) + 1 27 | # Construct new version 28 | new_version = f"{main_version}{delim}{dev_version}" 29 | 30 | return new_version 31 | 32 | 33 | def update_version(rel_path): 34 | lines = [] 35 | for line in read(rel_path).splitlines(): 36 | if line.startswith("__version__"): 37 | delim = '"' if '"' in line else "'" 38 | previous_version = line.split(delim)[1] 39 | new_version = increment_dev_version(previous_version) 40 | line = line.replace(previous_version, new_version) 41 | lines.append(line) 42 | 43 | write(rel_path, lines) 44 | 45 | return new_version 46 | 47 | 48 | if __name__ == "__main__": 49 | version_path = "../tensorflow_similarity/__init__.py" 50 | 51 | os.system("git config --global user.email 'action@github.com'") 52 | os.system("git 
config --global user.name 'Github Actions Bot'") 53 | 54 | # Assumes we are in the scripts/ dir 55 | new_version = update_version(version_path) 56 | 57 | os.system("git add -u") 58 | os.system(f"git commit -m '[nightly] Increase version to {new_version}'") 59 | os.system("git push") 60 | -------------------------------------------------------------------------------- /tests/stores/test_redis.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import numpy as np 4 | 5 | from tensorflow_similarity.stores.redis import RedisStore 6 | 7 | 8 | class MockRedis: 9 | def __init__(self, *args, **kwargs): 10 | self.flushdb() 11 | 12 | def get(self, key): 13 | if key in self.cache: 14 | return self.cache[key] 15 | return None # return nil 16 | 17 | def set(self, key, value, *args, **kwargs): 18 | self.cache[key] = value 19 | return "OK" 20 | 21 | def flushdb(self): 22 | self.cache = {} 23 | return "OK" 24 | 25 | def incr(self, key): 26 | if key in self.cache: 27 | self.cache[key] += 1 28 | else: 29 | self.cache[key] = 1 30 | return self.cache[key] 31 | 32 | 33 | @patch("redis.Redis", MockRedis) 34 | def test_store_and_retrieve(): 35 | recs = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 36 | 37 | kv_store = RedisStore() 38 | idxs = [] 39 | for r in recs: 40 | idx = kv_store.add(r[0], r[1], r[2]) 41 | idxs.append(idx) 42 | 43 | # check index numbering 44 | for gt, idx in enumerate(idxs): 45 | assert isinstance(idx, int) 46 | assert gt == idx 47 | 48 | # get back three elements 49 | for idx in idxs: 50 | emb, lbl, dt = kv_store.get(idx) 51 | assert emb == recs[idx][0] 52 | assert lbl == recs[idx][1] 53 | assert dt == recs[idx][2] 54 | 55 | 56 | @patch("redis.Redis", MockRedis) 57 | def test_batch_add(): 58 | embs = np.array([[0.1, 0.2], [0.2, 0.3]]) 59 | lbls = np.array([1, 2]) 60 | data = np.array([[0, 0, 0], [1, 1, 1]]) 61 | 62 | kv_store = RedisStore() 63 | idxs = kv_store.batch_add(embs, 
lbls, data) 64 | for idx in idxs: 65 | emb, lbl, dt = kv_store.get(idx) 66 | assert np.array_equal(emb, embs[idx]) 67 | assert np.array_equal(lbl, lbls[idx]) 68 | assert np.array_equal(dt, data[idx]) 69 | -------------------------------------------------------------------------------- /benchmark/supervised/components/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from typing import Any 5 | 6 | import keras_tuner 7 | import tensorflow as tf 8 | from tensorflow.keras.optimizers import Adam, RMSprop 9 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule 10 | from tensorflow_addons.optimizers import LAMB 11 | 12 | from . import utils 13 | 14 | OPTIMIZERS = {} 15 | OPTIMIZERS["adam"] = lambda p, lrs, hp: Adam( 16 | learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp), 17 | beta_1=utils.get_param(p, "beta_1", 0.9, hp), 18 | beta_2=utils.get_param(p, "beta_2", 0.999, hp), 19 | epsilon=utils.get_param(p, "epsilon", 1e-07, hp), 20 | amsgrad=utils.get_param(p, "amsgrad", False, hp), 21 | ) 22 | OPTIMIZERS["lamb"] = lambda p, lrs, hp: LAMB( 23 | learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp), 24 | beta_1=utils.get_param(p, "beta_1", 0.9, hp), 25 | beta_2=utils.get_param(p, "beta_2", 0.999, hp), 26 | epsilon=utils.get_param(p, "epsilon", 1e-06, hp), 27 | weight_decay=utils.get_param(p, "weight_decay", 0.0, hp), 28 | exclude_from_weight_decay=utils.get_param(p, "exclude_from_weight_decay", None, hp), 29 | exclude_from_layer_adaptation=utils.get_param(p, "exclude_from_layer_adaptation", None, hp), 30 | ) 31 | OPTIMIZERS["rmsprop"] = lambda p, lrs, hp: RMSprop( 32 | learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp), 33 | rho=utils.get_param(p, "rho", 0.9, hp), 34 | momentum=utils.get_param(p, "momentum", 0.0, hp), 35 | epsilon=utils.get_param(p, "epsilon", 1e-07, hp), 36 | 
centered=utils.get_param(p, "centered", False, hp), 37 | ) 38 | 39 | 40 | def make_optimizer( 41 | opt_id: str, 42 | params: Mapping[str, Any], 43 | lr_schedule: LearningRateSchedule | None, 44 | hp: keras_tuner.HyperParameters | None = None, 45 | ) -> tf.keras.Optimizer: 46 | try: 47 | return OPTIMIZERS[opt_id](params, lr_schedule, hp) 48 | except KeyError as exc: 49 | raise ValueError(f"Unknown optimizer name: {opt_id}") from exc 50 | -------------------------------------------------------------------------------- /benchmark/supervised/components/datasets/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import inspect 5 | from collections.abc import Mapping, Sequence 6 | from typing import Any 7 | 8 | 9 | @dataclasses.dataclass 10 | class Fold: 11 | training_classes: list[int] = dataclasses.field(default_factory=list) 12 | val_classes: list[int] = dataclasses.field(default_factory=list) 13 | train_idxs: list[int] = dataclasses.field(default_factory=list) 14 | val_idxs: list[int] = dataclasses.field(default_factory=list) 15 | 16 | 17 | @dataclasses.dataclass 18 | class Splits: 19 | train_idxs: list[int] = dataclasses.field(default_factory=list) 20 | test_idxs: list[int] = dataclasses.field(default_factory=list) 21 | folds: Mapping[str, Fold] = dataclasses.field(default_factory=dict) 22 | 23 | 24 | @dataclasses.dataclass 25 | class TrainValSplit: 26 | num_splits: int 27 | val_class_pctg: float 28 | 29 | 30 | @dataclasses.dataclass 31 | class EvalCallback: 32 | max_num_queries: int 33 | max_num_targets: int 34 | 35 | 36 | @dataclasses.dataclass(eq=True, frozen=True) 37 | class DatasetConfig: 38 | name: str 39 | component: str 40 | dataset_id: str 41 | x_key: str 42 | y_key: str 43 | train_classes: Sequence[int] 44 | test_classes: Sequence[int] 45 | train_val_splits: TrainValSplit | None = None 46 | eval_callback: EvalCallback | None = None 47 | 48 | def 
__post_init__(self): 49 | if self.train_val_splits is None: 50 | self.train_val_splits = TrainValSplit( 51 | num_splits=0, 52 | val_class_pctg=0.0, 53 | ) 54 | 55 | def __hash__(self): 56 | return hash((self.name, self.component, self.dataset_id, self.x_key, self.y_key)) 57 | 58 | @classmethod 59 | def from_dict(cls, params: dict[str, Any]) -> DatasetConfig: 60 | """Create a DatasetConfig from a dictionary. 61 | 62 | This supports kwargs that are not part of the DatasetConfig class. 63 | """ 64 | kwargs = {k: v for k, v in params.items() if k in inspect.signature(cls).parameters} 65 | return cls(**kwargs) 66 | -------------------------------------------------------------------------------- /tensorflow_similarity/matchers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING, Type 17 | 18 | if TYPE_CHECKING: 19 | from .classification_match import ClassificationMatch 20 | 21 | from .match_majority_vote import MatchMajorityVote 22 | from .match_nearest import MatchNearest 23 | 24 | 25 | def make_classification_matcher(matcher: str | ClassificationMatch) -> ClassificationMatch: 26 | """Convert classification matcher from str name to object if needed. 
27 | 28 | Args: 29 | matcher: {'match_nearest', 'match_majority_vote'} or 30 | ClassificationMatch object. Defines the classification matching, 31 | e.g., match_nearest will count a True Positive if the query_label 32 | is equal to the label of the nearest neighbor and the distance is 33 | less than or equal to the distance threshold. 34 | 35 | Raises: 36 | ValueError: matcher name is invalid. 37 | 38 | Returns: 39 | ClassificationMatch: Instantiated matcher if needed. 40 | """ 41 | # ! Matcher must be non-instantiated. 42 | MATCHER_ALIASES: dict[str, Type[ClassificationMatch]] = { 43 | "match_nearest": MatchNearest, 44 | "match_majority_vote": MatchMajorityVote, 45 | } 46 | 47 | if isinstance(matcher, str): 48 | if matcher.lower() in MATCHER_ALIASES: 49 | matcher = MATCHER_ALIASES[matcher.lower()]() 50 | else: 51 | raise ValueError(f"Unknown matcher name: {matcher}, typo?") 52 | 53 | return matcher 54 | -------------------------------------------------------------------------------- /api/TFSimilarity/callbacks.md: -------------------------------------------------------------------------------- 1 | # Module: TFSimilarity.callbacks 2 | 3 | 4 | 5 | 6 | 7 | Specialized callbacks that track similarity metrics during training 8 | 9 | 10 | 11 | ## Classes 12 | 13 | - [`class Callback`](../TFSimilarity/callbacks/Callback.md): Abstract base class used to build new callbacks. 14 | 15 | - [`class ClassificationMatch`](../TFSimilarity/callbacks/ClassificationMatch.md): Abstract base class for defining the classification matching strategy. 16 | 17 | - [`class ClassificationMetric`](../TFSimilarity/callbacks/ClassificationMetric.md): Abstract base class for computing classification metrics. 18 | 19 | - [`class EvalCallback`](../TFSimilarity/callbacks/EvalCallback.md): Epoch end evaluation callback that build a test index and evaluate 20 | 21 | - [`class Evaluator`](../TFSimilarity/callbacks/Evaluator.md): Evaluates search index performance and calibrates it. 
22 | 23 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor 24 | 25 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor 26 | 27 | - [`class MemoryEvaluator`](../TFSimilarity/callbacks/MemoryEvaluator.md): In memory index performance evaluation and classification. 28 | 29 | - [`class Path`](../TFSimilarity/callbacks/Path.md): PurePath subclass that can make system calls. 30 | 31 | - [`class SimilarityModel`](../TFSimilarity/callbacks/SimilarityModel.md): Specialized Keras.Model which implement the core features needed for 32 | 33 | - [`class Tensor`](../TFSimilarity/callbacks/Tensor.md): A `tf.Tensor` represents a multidimensional array of elements. 34 | 35 | ## Functions 36 | 37 | - [`SplitValidationLoss(...)`](../TFSimilarity/callbacks/SplitValidationLoss.md): Creates the validation callbacks. 38 | 39 | - [`make_classification_metric(...)`](../TFSimilarity/callbacks/make_classification_metric.md): Convert classification metric from str name to object if needed. 40 | 41 | - [`unpack_lookup_distances(...)`](../TFSimilarity/callbacks/unpack_lookup_distances.md) 42 | 43 | - [`unpack_lookup_labels(...)`](../TFSimilarity/callbacks/unpack_lookup_labels.md) 44 | 45 | - [`unpack_results(...)`](../TFSimilarity/callbacks/unpack_results.md): Updates logs, writes summary, and returns list of strings of 46 | 47 | -------------------------------------------------------------------------------- /tensorflow_similarity/distances/cosine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Cosine distance computation functions for embeddings.""" 15 | from __future__ import annotations 16 | 17 | from typing import TYPE_CHECKING 18 | 19 | import tensorflow as tf 20 | 21 | if TYPE_CHECKING: 22 | from ..types import FloatTensor 23 | 24 | from .distance import Distance 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Similarity") 28 | class CosineDistance(Distance): 29 | """Compute pairwise cosine distances between embeddings. 30 | 31 | The [Cosine Distance](https://en.wikipedia.org/wiki/Cosine_similarity) is 32 | computed here as `1 - <query, key>`, which equals `1 - cosine_similarity` 33 | for normalized embeddings. In that case it varies from 0 (same direction, 34 | similar) to 2 (opposite directions, dissimilar). 35 | """ 36 | 37 | def __init__(self, name: str = "cosine", **kwargs): 38 | "Init Cosine distance" 39 | super().__init__(name=name, **kwargs) 40 | 41 | @tf.function 42 | def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor: 43 | """Compute pairwise distances for a given batch of embeddings. 44 | 45 | Args: 46 | query_embeddings: Embeddings of the queries. The embeddings 47 | are expected to be normalized. 48 | key_embeddings: Embeddings of the keys to compare against. The embeddings 49 | are expected to be normalized. 50 | 51 | Returns: 52 | FloatTensor: Pairwise distance tensor; entry [i, j] is the distance 53 | between query_embeddings[i] and key_embeddings[j].
54 | """ 55 | distances = 1 - tf.linalg.matmul(query_embeddings, key_embeddings, transpose_b=True) 56 | # Clamp at 0: floating-point rounding can push the dot product of two 57 | # (near-)identical unit vectors slightly above 1. 58 | min_clip_distances: FloatTensor = tf.math.maximum(distances, 0.0) 59 | return min_clip_distances 60 | -------------------------------------------------------------------------------- /tests/samplers/test_tfrecord_samplers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_similarity.samplers import TFRecordDatasetSampler 6 | 7 | 8 | def _int64_feature(value): 9 | """Returns an int64_list from a bool / enum / int / uint.""" 10 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 11 | 12 | 13 | def to_tfrecord(sid, value): 14 | feature = { 15 | "sid": _int64_feature(sid), 16 | "value": _int64_feature(value), 17 | } 18 | 19 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 20 | return example.SerializeToString() 21 | 22 | 23 | def deserialization_fn(serialized_example): 24 | fd = { 25 | "sid": tf.io.FixedLenFeature([], dtype=tf.int64), 26 | "value": tf.io.FixedLenFeature([], dtype=tf.int64), 27 | } 28 | sample = tf.io.parse_single_example(serialized_example, fd) 29 | 30 | return (sample["sid"], sample["value"]) 31 | 32 | 33 | class TFRecordSamplerTest(tf.test.TestCase): 34 | def setUp(self): 35 | super().setUp() 36 | 37 | for sid in range(100): 38 | shard_path = os.path.join(self.get_temp_dir(), f"tfr_{sid}.tfrec") 39 | 40 | with tf.io.TFRecordWriter(str(shard_path)) as w: 41 | for value in range(1000): 42 | example = to_tfrecord(sid, sid * 1000 + value) 43 | w.write(example) 44 | 45 | def test_basic(self): 46 | sampler = TFRecordDatasetSampler( 47 | self.get_temp_dir(), 48 | deserialization_fn=deserialization_fn, 49 | batch_size=10, 50 | example_per_class=2, 51 | ) 52 | 53 | si = iter(sampler) 54 | [next(si) for _ in range(10_000)] 55 | examples = next(si) 56 | 57 | # batch_size=10 with example_per_class=2 should give 5 pairs of shard IDs 58 | sids =
examples[0] 59 | values = examples[1] 60 | first_sid = sids[::2] 61 | second_sid = sids[1::2] 62 | 63 | self.assertLen(sids, 10) 64 | self.assertAllEqual(first_sid, second_sid) 65 | 66 | for sid, val in zip(sids, values): 67 | diff = val - sid * 1_000 68 | self.assertGreaterEqual(diff, 0) 69 | self.assertLess(diff, 1000) 70 | -------------------------------------------------------------------------------- /api/TFSimilarity/augmenters/SimCLRAugmenter.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.augmenters.SimCLRAugmenter 2 | 3 | 4 | 5 | 6 | 7 | SimCLR augmentation pipeline as defined in 8 | 9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | ```python 12 | TFSimilarity.augmenters.SimCLRAugmenter( 13 | height: int, 14 | width: int, 15 | is_training: bool = True, 16 | color_distort: bool = True, 17 | jitter_stength: float = 1.0, 18 | crop: bool = True, 19 | eval_crop_proportion: float = 0.875, 20 | flip: bool = True, 21 | version: str = v2, 22 | num_cpu: Optional[int] = os.cpu_count() 23 | ) 24 | ``` 25 | 26 | 27 | 28 | 29 | - [Big Self-Supervised Models are Strong Semi-Supervised Learners](https://arxiv.org/abs/2006.10029) 30 | 31 | code adapted from [orignal github](https://github.com/google-research/simclr/tree/master/tf2) 32 | 33 | ## Methods 34 | 35 |

augment

36 | 37 | View source 38 | 39 | ``python 40 | @tf.function`` 41 | 42 | ```python 43 | augment( 44 | x: TFSimilarity.callbacks.Tensor``` 45 | , 46 | y: TFSimilarity.callbacks.Tensor``` 47 | , 48 | num_views: int, 49 | is_warmup: bool 50 | ) -> List[TFSimilarity.callbacks.Tensor``` 51 | ] 52 | ``` 53 | 54 | 55 | 56 | 57 | 58 |

__call__

59 | 60 | View source 61 | 62 | ```python 63 | __call__( 64 | x: TFSimilarity.callbacks.Tensor``` 65 | , 66 | y: TFSimilarity.callbacks.Tensor``` 67 | , 68 | num_augmentations_per_example: int, 69 | is_warmup: bool 70 | ) -> List[TFSimilarity.callbacks.Tensor``` 71 | ] 72 | ``` 73 | 74 | 75 | Call self as a function. 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /tensorflow_similarity/visualization/vizualize_views.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import tensorflow as tf 6 | from matplotlib import pyplot as plt 7 | 8 | if TYPE_CHECKING: 9 | from tensorflow import Tensor 10 | 11 | 12 | def visualize_views( 13 | views: Tensor, 14 | labels: Tensor = None, 15 | predictions: Tensor = None, 16 | num_imgs: int | None = None, 17 | views_per_col: int = 4, 18 | fig_size: tuple[int, int] = (24, 4), 19 | max_pixel_value: float = 1.0, 20 | min_pixel_value: float = 0.0, 21 | ): 22 | """Display side by side different image views with labels, and predictions 23 | 24 | Args: 25 | views: Aray of views 26 | predictions: model output. 27 | labels: image labels 28 | num_imgs: number of images to use. 29 | views_per_col: Int, number of images in one row. Defaults to 3. 30 | max_pixel_value: Max expected value for a pixel. Used to scale the image 31 | between [0,1]. 32 | min_pixel_value: Min expected value for a pixel. Used to scale the image 33 | between [0,1]. 34 | 35 | Returns: 36 | None. 
37 | """ 38 | num_views = len(views) 39 | num_imgs = num_imgs if num_imgs else len(views[0]) 40 | num_col = views_per_col 41 | num_row = num_imgs // num_col 42 | num_row = num_row + 1 if num_imgs % num_col else num_row 43 | 44 | # Plot the images 45 | fig, axes = plt.subplots(num_row, num_col, figsize=fig_size) 46 | for i in range(num_imgs): 47 | 48 | # If the number of rows is 1, the axes array is one-dimensional 49 | if num_row == 1: 50 | ax = axes[i % num_col] 51 | else: 52 | ax = axes[i // num_col, i % num_col] 53 | 54 | scale = abs(max_pixel_value - min_pixel_value) 55 | pair = [(views[j][i] - min_pixel_value) / scale for j in range(num_views)] 56 | ax.imshow(tf.concat(pair, axis=1)) 57 | ax.set_axis_off() 58 | 59 | label = labels[i] if labels else i 60 | 61 | if predictions: 62 | ax.set_title("Label: {} | Pred: {:.5f}".format(label, predictions[i][0])) 63 | elif labels: 64 | ax.set_title("Label: {}".format(label)) 65 | 66 | ax.set_xticks([]) 67 | ax.set_yticks([]) 68 | plt.show() 69 | -------------------------------------------------------------------------------- /api/TFSimilarity/indexer/Lookup.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.indexer.Lookup 2 | 3 | 4 | 5 | 6 | 7 | Metadata associated with a query match. 8 | 9 | 10 | ```python 11 | TFSimilarity.indexer.Lookup( 12 | rank: int, 13 | distance: float, 14 | label: Optional[int] = dataclasses.field(default=None), 15 | embedding: Optional[np.ndarray] = dataclasses.field(default=None), 16 | ) 17 | ``` 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 35 | 38 | 39 | 42 | 45 | 46 | 49 | 52 | 53 | 56 | 59 | 60 | 63 | 67 | 68 |
33 | rank 34 | 36 | Rank of the match with respect to the query distance. 37 |
40 | distance 41 | 43 | The distance from the match to the query. 44 |
47 | label 48 | 50 | The label associated with the match. Default None. 51 |
54 | embedding 55 | 57 | The embedded match vector. Default None. 58 |
61 | data 62 | 64 | The original Tensor representation of the match result. 65 | Default None. 66 |
69 | 70 | 71 | 72 | ## Methods 73 | 74 |

__eq__

75 | 76 | View source 77 | 78 | ```python 79 | __eq__( 80 | other 81 | ) -> bool 82 | ``` 83 | 84 | 85 | Return self==value. 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 101 | 104 | 105 | 108 | 111 | 112 | 115 | 118 | 119 |
99 | data 100 | 102 | None 103 |
106 | embedding 107 | 109 | None 110 |
113 | label 114 | 116 | None 117 |
120 | 121 | -------------------------------------------------------------------------------- /tensorflow_similarity/distances/manhattan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Inner product similarity computation functions for embeddings.""" 15 | from __future__ import annotations 16 | 17 | from typing import TYPE_CHECKING 18 | 19 | import tensorflow as tf 20 | 21 | if TYPE_CHECKING: 22 | from ..types import FloatTensor 23 | 24 | from .distance import Distance 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Similarity") 28 | class ManhattanDistance(Distance): 29 | """Compute pairwise Manhattan distances between embeddings. 30 | 31 | The [Manhattan Distance](https://en.wikipedia.org/wiki/Euclidean_distance) 32 | is the sum of the lengths of the projections of the line segment between 33 | two embeddings onto the Cartesian axes. The larger the distance the more 34 | dissimilar the embeddings are. 35 | """ 36 | 37 | def __init__(self, name: str = "manhattan", **kwargs): 38 | "Init Manhattan distance" 39 | super().__init__(name=name, **kwargs) 40 | 41 | @tf.function 42 | def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor: 43 | """Compute pairwise distances for a given batch of embeddings. 
44 | 45 | Args: 46 | query_embeddings: Query embeddings; every axis after the first is 47 | flattened before computing distances. 48 | key_embeddings: Key embeddings; flattened the same way. 49 | 50 | Returns: 51 | FloatTensor: Pairwise distance tensor; entry [i, j] is the L1 distance 52 | between query i and key j. 53 | """ 54 | # Flatten every axis after the first so higher-rank inputs are compared as flat vectors. 55 | q_rs = tf.reshape(query_embeddings, shape=[tf.shape(query_embeddings)[0], -1]) 56 | k_rs = tf.reshape(key_embeddings, shape=[tf.shape(key_embeddings)[0], -1]) 57 | # Broadcast to [num_queries, num_keys, dim]. This materializes the full 58 | # difference tensor, so memory grows with num_queries * num_keys * dim. 59 | deltas = tf.expand_dims(q_rs, axis=1) - tf.expand_dims(k_rs, axis=0) 60 | # ord=1 norm over the last axis = sum of absolute differences. 61 | distances: FloatTensor = tf.norm(deltas, 1, axis=2) 62 | return distances 63 | -------------------------------------------------------------------------------- /tensorflow_similarity/distances/inner_product.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Inner product similarity computation functions for embeddings.""" 15 | from __future__ import annotations 16 | 17 | from typing import TYPE_CHECKING 18 | 19 | import tensorflow as tf 20 | 21 | if TYPE_CHECKING: 22 | from ..types import FloatTensor 23 | 24 | from .distance import Distance 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Similarity") 28 | class InnerProductSimilarity(Distance): 29 | """Compute the pairwise inner product between embeddings.
30 | 31 | The [Inner product](https://en.wikipedia.org/wiki/Inner_product_space) is 32 | a measure of similarity where the more similar vectors have the largest 33 | values. 34 | 35 | NOTE! This is not a distance and is likely not what you want to use with 36 | the built in losses. At the very least this will flip the sign on the 37 | margin in many of the losses. This is likely meant to be used with custom 38 | loss functions that expect a similarity instead of a distance. 39 | """ 40 | 41 | def __init__(self, name: str = "inner_product", **kwargs): 42 | "Init Inner product similarity" 43 | super().__init__(name=name, **kwargs) 44 | 45 | @tf.function 46 | def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor: 47 | """Compute pairwise similarities for a given batch of embeddings. 48 | 49 | Args: 50 | query_embeddings: Embeddings of the queries. 51 | key_embeddings: Embeddings of the keys to compare against. 52 | 53 | Returns: 54 | FloatTensor: Pairwise similarity tensor; entry [i, j] is the inner 55 | product of query_embeddings[i] and key_embeddings[j]. 56 | """ 57 | sims: FloatTensor = tf.linalg.matmul(query_embeddings, key_embeddings, transpose_b=True) 58 | return sims 59 | -------------------------------------------------------------------------------- /tensorflow_similarity/matchers/match_nearest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | if TYPE_CHECKING: 19 | from ..types import FloatTensor, IntTensor 20 | 21 | from .classification_match import ClassificationMatch 22 | 23 | 24 | class MatchNearest(ClassificationMatch): 25 | """Match metrics for labels at k=1.""" 26 | 27 | def __init__(self, name: str = "nearest", **kwargs) -> None: 28 | if "canonical_name" not in kwargs: 29 | kwargs["canonical_name"] = "match_nearest" 30 | 31 | super().__init__(name=name, **kwargs) 32 | 33 | def derive_match(self, lookup_labels: IntTensor, lookup_distances: FloatTensor) -> tuple[IntTensor, FloatTensor]: 34 | """Derive a match label and distance from a set of K neighbors. 35 | 36 | For each query, derive a single match label and distance given the 37 | associated set of lookup labels and distances. 38 | 39 | Args: 40 | lookup_labels: A 2D array where the jth row is the labels 41 | associated with the set of k neighbors for the jth query. 42 | 43 | lookup_distances: A 2D array where the jth row is the distances 44 | between the jth query and the set of k neighbors. 45 | 46 | Returns: 47 | A Tuple of FloatTensors: 48 | derived_labels: A FloatTensor of shape 49 | [len(lookup_labels), 1] where the jth row contains the derived 50 | label for the jth query. 51 | 52 | derived_distances: A FloatTensor of shape 53 | [len(lookup_labels), 1] where the jth row contains the distance 54 | associated with the jth derived label. 
55 | """ 56 | 57 | return lookup_labels[:, :1], lookup_distances[:, :1] 58 | -------------------------------------------------------------------------------- /tensorflow_similarity/classification_metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import Type 17 | 18 | from .binary_accuracy import BinaryAccuracy # noqa 19 | from .classification_metric import ClassificationMetric # noqa 20 | from .f1_score import F1Score # noqa 21 | from .false_positive_rate import FalsePositiveRate # noqa 22 | from .negative_predictive_value import NegativePredictiveValue # noqa 23 | from .precision import Precision # noqa 24 | from .recall import Recall # noqa 25 | 26 | 27 | def make_classification_metric(metric: str | ClassificationMetric, name: str = "") -> ClassificationMetric: 28 | """Convert classification metric from str name to object if needed. 29 | 30 | Args: 31 | metric: ClassificationMetric() or metric name. 32 | 33 | Raises: 34 | ValueError: Unknown metric name: {metric}, typo? 35 | 36 | Returns: 37 | ClassificationMetric: Instantiated metric if needed. 38 | """ 39 | # ! Metrics must be non-instantiated. 
40 | METRICS_ALIASES: dict[str, Type[ClassificationMetric]] = { 41 | "recall": Recall, 42 | "precision": Precision, 43 | "f1": F1Score, 44 | "f1score": F1Score, 45 | "f1_score": F1Score, 46 | "binary_accuracy": BinaryAccuracy, 47 | "npv": NegativePredictiveValue, 48 | "negative_predicitve_value": NegativePredictiveValue, 49 | "fpr": FalsePositiveRate, 50 | "false_positive_rate": FalsePositiveRate, 51 | } 52 | 53 | if isinstance(metric, str): 54 | if metric.lower() in METRICS_ALIASES: 55 | metric = METRICS_ALIASES[metric.lower()](name=metric.lower()) 56 | else: 57 | raise ValueError(f"Unknown metric name: {metric}, typo?") 58 | 59 | if name: 60 | metric.name = name 61 | 62 | return metric 63 | -------------------------------------------------------------------------------- /tests/stores/test_memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tensorflow_similarity.stores.memory import MemoryStore 4 | 5 | 6 | def build_store(records): 7 | kv_store = MemoryStore() 8 | idxs = [] 9 | for r in records: 10 | idx = kv_store.add(r[0], r[1], r[2]) 11 | idxs.append(idx) 12 | return kv_store, idxs 13 | 14 | 15 | def test_in_memory_store_and_retrieve(): 16 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 17 | 18 | kv_store, idxs = build_store(records) 19 | 20 | # check index numbering 21 | for gt, idx in enumerate(idxs): 22 | assert isinstance(idx, int) 23 | assert gt == idx 24 | 25 | # check reference counting 26 | assert kv_store.size() == 2 27 | 28 | # get back three elements 29 | for idx in idxs: 30 | emb, lbl, dt = kv_store.get(idx) 31 | assert emb == records[idx][0] 32 | assert lbl == records[idx][1] 33 | assert dt == records[idx][2] 34 | 35 | 36 | def test_reset(): 37 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 38 | 39 | kv_store, idxs = build_store(records) 40 | 41 | # check reference counting 42 | assert kv_store.size() == 2 43 | 44 | 
kv_store.reset() 45 | assert kv_store.size() == 0 46 | 47 | kv_store.add(records[0][0], records[0][1], records[0][2]) 48 | assert kv_store.size() == 1 49 | 50 | 51 | def test_batch_add(): 52 | embs = np.array([[0.1, 0.2], [0.2, 0.3]]) 53 | lbls = np.array([1, 2]) 54 | data = np.array([[0, 0, 0], [1, 1, 1]]) 55 | 56 | kv_store = MemoryStore() 57 | idxs = kv_store.batch_add(embs, lbls, data) 58 | for idx in idxs: 59 | emb, lbl, dt = kv_store.get(idx) 60 | assert np.array_equal(emb, embs[idx]) 61 | assert np.array_equal(lbl, lbls[idx]) 62 | assert np.array_equal(dt, data[idx]) 63 | 64 | 65 | def test_save_and_reload(tmp_path): 66 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 67 | 68 | kv_store, idxs = build_store(records) 69 | kv_store.save(str(tmp_path)) 70 | 71 | # reload 72 | reloaded_store = MemoryStore() 73 | reloaded_store.load(tmp_path) 74 | 75 | assert reloaded_store.size() == 2 76 | 77 | # get back three elements 78 | for idx in idxs: 79 | emb, lbl, dt = reloaded_store.get(idx) 80 | assert np.array_equal(emb, records[idx][0]) 81 | assert np.array_equal(lbl, records[idx][1]) 82 | assert np.array_equal(dt, records[idx][2]) 83 | -------------------------------------------------------------------------------- /benchmark/supervised/create_datasets.py: -------------------------------------------------------------------------------- 1 | "Supervised loss benchmark" 2 | from __future__ import annotations 3 | 4 | import argparse 5 | import json 6 | import os 7 | import re 8 | from collections.abc import Mapping 9 | from typing import Any 10 | 11 | import tensorflow as tf 12 | import tensorflow.random 13 | from components import datasets, make_augmentations, utils 14 | from termcolor import cprint 15 | 16 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 17 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 18 | 19 | 20 | def run(cfg: Mapping[str, Any], filter_pattern: str) -> None: 21 | if cfg.get("tfds_data_dir", None): 22 | 
os.environ["TFDS_DATA_DIR"] = cfg["tfds_data_dir"] 23 | 24 | version = cfg["version"] 25 | random_seed = cfg["random_seed"] 26 | output_dir = cfg["dataset_dir"] 27 | preproc_fns = make_augmentations(cfg["preprocess"]) 28 | 29 | p = re.compile(filter_pattern) 30 | data_dir = os.path.join(output_dir, version) 31 | dataset_cfgs = [] 32 | for name, cfg in cfg["datasets"].items(): 33 | if p.match(name): 34 | dataset_cfgs.append(datasets.utils.make_dataset_config(name, cfg)) 35 | 36 | for d in dataset_cfgs: 37 | cprint(f"|-{d.name}", "blue") 38 | 39 | cprint(f"{len(dataset_cfgs)} Datasets\n", "blue") 40 | if input("Would you like to continue: [Y/n] ").lower() != "y": 41 | cprint("Exit", "red") 42 | return 43 | else: 44 | cprint("Building datasets", "green") 45 | 46 | for dcfg in dataset_cfgs: 47 | d = datasets.utils.make_dataset(dcfg, data_dir) 48 | utils.set_random_seeds(random_seed) 49 | 50 | # Make result path 51 | cprint(f"\n|-Clearing all files in {d.path}", "blue") 52 | utils.clean_dir(d.path) 53 | 54 | # Load the raw dataset 55 | cprint(f"\n|-Loading and preprocessing {d.name}\n", "blue") 56 | d.load_raw_data(preproc_fns) 57 | d.split_raw_data() 58 | d.save_serialized_data() 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser(description="Generate datasets") 63 | parser.add_argument("--config", "-c", help="config path") 64 | parser.add_argument("--filter", "-f", help="run only the config ids that match the regexp", default=".*") 65 | args = parser.parse_args() 66 | 67 | if not args.config: 68 | parser.print_usage() 69 | quit() 70 | 71 | config = json.loads(open(args.config).read()) 72 | run(config, filter_pattern=args.filter) 73 | -------------------------------------------------------------------------------- /api/TFSimilarity/visualization/projector.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.visualization.projector 2 | 3 | 4 | 5 | 6 | 7 | Visualize the embeddings in 2D or 3D 
using UMAP projection 8 | 9 | ```python 10 | TFSimilarity.visualization.projector( 11 | labels: Optional[Sequence[Any]] = None, 12 | class_mapping: Optional[Sequence[int]] = None, 13 | image_size: int = 64, 14 | tooltips_info: Optional[Mapping[str, Sequence[str]]] = None, 15 | pt_size: int = 3, 16 | colorize: bool = True, 17 | pastel_factor: float = 0.1, 18 | plot_size: int = 600, 19 | active_drag: str = box_zoom, 20 | densmap: bool = True 21 | ) 22 | ``` 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 38 | 42 | 43 | 46 | 50 | 51 | 54 | 58 | 59 | 62 | 65 | 66 | 69 | 73 | 74 | 77 | 81 | 82 | 85 | 88 | 89 | 92 | 95 | 96 | 99 | 102 | 103 | 106 | 110 | 111 |
36 | embeddings 37 | 39 | The embeddings output by the model that 40 | are to be visualized. 41 |
44 | labels 45 | 47 | Labels associated with the embeddings. If not supplied, treat 48 | each example as its own class. 49 |
52 | class_mapping 53 | 55 | Dictionary or list that maps the class numerical ids 56 | to their name. 57 |
60 | images 61 | 63 | Images to display in tooltip on hover. Usually x_test tensor. 64 |
67 | image_size 68 | 70 | Size of the images displayed in the tooltip. 71 | Defaults to 64. 72 |
75 | pt_size 76 | 78 | Size of the points displayed on the visualization. 79 | Defaults to 3. 80 |
83 | tooltips_info 84 | 86 | Dictionary of information to display in the tooltips. 87 |
90 | colorize 91 | 93 | Colorize the clusters. Defaults to true. 94 |
97 | pastel_factor 98 | 100 | Modify the color palette to be more pastel. 101 |
104 | densmap 105 | 107 | Use UMAP's densMAP variant, which provides better density 108 | estimation but is a little slower. Defaults to True. 109 |
112 | 113 | -------------------------------------------------------------------------------- /benchmark/supervised/components/experiments.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import os 5 | from collections.abc import Mapping 6 | from itertools import product 7 | from typing import Any 8 | 9 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule 10 | 11 | from . import datasets, utils 12 | 13 | 14 | @dataclasses.dataclass(eq=True, frozen=True) 15 | class Component: 16 | cid: str 17 | name: str 18 | params: Mapping[str, Any] 19 | 20 | def __hash__(self): 21 | return hash((self.cid, self.name)) 22 | 23 | 24 | @dataclasses.dataclass 25 | class Experiment: 26 | run_grp: str 27 | path: str 28 | dataset: Component 29 | architecture: Component 30 | loss: Component 31 | opt: Component 32 | training: Component 33 | lr_schedule: LearningRateSchedule | None = None 34 | 35 | 36 | def make_experiments(cfg: Mapping[str, Any], output_dir: str) -> list[Experiment]: 37 | experiments = [] 38 | 39 | # Generate the cross product of all the experiment params. 
40 | for (dn, dcfg), (an, acfg), (ln, lcfg), (on, ocfg), tcfg in product( 41 | cfg["datasets"].items(), 42 | cfg["architectures"].items(), 43 | cfg["losses"].items(), 44 | cfg["optimizer"].items(), 45 | cfg["training"], 46 | ): 47 | dataset = datasets.utils.make_dataset_config(name=dn, params=dcfg) 48 | loss = Component(cid=lcfg["component"], name=ln, params=lcfg) 49 | opt = Component(cid=ocfg["component"], name=on, params=ocfg) 50 | training = Component(cid="", name=tcfg["name"], params=tcfg) 51 | 52 | for embedding_size in acfg.get("embedding_sizes", [128]): 53 | acfg["embedding"] = embedding_size 54 | architecture = Component(cid=acfg["component"], name=an, params=acfg) 55 | 56 | run_grp = utils.make_run_grp( 57 | dataset.name, 58 | architecture.name, 59 | architecture.params["embedding"], 60 | loss.name, 61 | opt.name, 62 | ) 63 | experiments.append( 64 | Experiment( 65 | run_grp=run_grp, 66 | path=os.path.join(output_dir, run_grp), 67 | dataset=dataset, 68 | architecture=architecture, 69 | loss=loss, 70 | opt=opt, 71 | training=training, 72 | ) 73 | ) 74 | return experiments 75 | -------------------------------------------------------------------------------- /tensorflow_similarity/samplers/file_samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING, TypeVar 17 | 18 | import tensorflow as tf 19 | 20 | from .memory_samplers import MultiShotMemorySampler 21 | 22 | if TYPE_CHECKING: 23 | from collections.abc import Callable, Sequence 24 | 25 | from ..types import FloatTensor, IntTensor 26 | from .samplers import Augmenter 27 | 28 | T = TypeVar("T", FloatTensor, IntTensor) 29 | 30 | 31 | def load_image(path: str, target_size: tuple[int, int] | None = None) -> T: # Read a JPEG file into a float32 image tensor, optionally resized to target_size. 32 | image_string = tf.io.read_file(path) 33 | image: T = tf.image.decode_jpeg(image_string, channels=3) # always decode to 3 channels 34 | image = tf.image.convert_image_dtype(image, tf.float32) # rescales integer pixel values into [0, 1] floats 35 | if target_size: 36 | image = tf.image.resize(image, target_size, method=tf.image.ResizeMethod.LANCZOS3) 37 | image = tf.clip_by_value(image, 0.0, 1.0) # Lanczos resampling can overshoot the [0, 1] range 38 | return image 39 | 40 | 41 | class MultiShotFileSampler(MultiShotMemorySampler): # MultiShotMemorySampler defaulting load_example_fn to load_image; presumably x holds image file paths 42 | def __init__( 43 | self, 44 | x, 45 | y, 46 | load_example_fn: Callable = load_image, 47 | classes_per_batch: int = 2, 48 | examples_per_class_per_batch: int = 2, 49 | steps_per_epoch: int = 1000, 50 | class_list: Sequence[int] | None = None, 51 | total_examples_per_class: int | None = None, 52 | augmenter: Augmenter | None = None, 53 | warmup: int = -1, 54 | ): 55 | super().__init__( 56 | x, 57 | y, 58 | load_example_fn=load_example_fn, 59 | classes_per_batch=classes_per_batch, 60 | examples_per_class_per_batch=examples_per_class_per_batch, 61 | steps_per_epoch=steps_per_epoch, 62 | class_list=class_list, 63 | total_examples_per_class=total_examples_per_class, 64 | augmenter=augmenter, 65 | warmup=warmup, 66 | ) 67 | -------------------------------------------------------------------------------- /tensorflow_similarity/distances/snr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Signal-to-Noise Ratio (SNR) distance computation for embeddings.""" 15 | from __future__ import annotations 16 | 17 | from typing import TYPE_CHECKING 18 | 19 | import tensorflow as tf 20 | 21 | if TYPE_CHECKING: 22 | from ..types import FloatTensor 23 | 24 | from .distance import Distance 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Similarity") 28 | class SNRDistance(Distance): 29 | """ 30 | Computes pairwise SNR distances between embeddings. 31 | 32 | The [Signal-to-Noise Ratio distance](https://arxiv.org/abs/1904.02616) 33 | is the ratio of noise variance to the feature variance. 34 | """ 35 | 36 | def __init__(self, name: str = "snr", **kwargs): 37 | """Initialize the SNR distance.""" 38 | super().__init__(name=name, **kwargs) 39 | 40 | @tf.function 41 | def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor: 42 | """Compute pairwise SNR distances for a given batch of embeddings. 43 | SNR(i, j): anchor (query) i compared with key j. 44 | SNR is not symmetric: SNR(i, j) may differ from SNR(j, i). 45 | 46 | Args: 47 | query_embeddings: Anchor embeddings; the variance of each row is the denominator of its SNR. 48 | key_embeddings: Embeddings compared against each anchor; the variance of (anchor - key) is the numerator. 49 | Returns: 50 | FloatTensor: Pairwise distance tensor of shape [num_queries, num_keys] for rank-2 inputs.
51 | """ 52 | # Calculating feature variance for each example 53 | anchor_var = tf.math.reduce_variance(query_embeddings, axis=1) 54 | 55 | # Calculating pairwise noise variances 56 | q_rs = tf.reshape(query_embeddings, shape=[tf.shape(query_embeddings)[0], -1]) 57 | k_rs = tf.reshape(key_embeddings, shape=[tf.shape(key_embeddings)[0], -1]) 58 | delta = tf.expand_dims(q_rs, axis=1) - tf.expand_dims(k_rs, axis=0) # broadcast to [num_queries, num_keys, features] 59 | noise_var = tf.math.reduce_variance(delta, axis=2) 60 | 61 | distances: FloatTensor = tf.divide(noise_var, tf.expand_dims(anchor_var, axis=1)) # divide each row by its anchor's variance 62 | 63 | return distances 64 | -------------------------------------------------------------------------------- /api/TFSimilarity/architectures/ResNet18Sim.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.architectures.ResNet18Sim 2 | 3 | 4 | 5 | 6 | 7 | Build an ResNet18 Model backbone for similarity learning 8 | 9 | ```python 10 | TFSimilarity.architectures.ResNet18Sim( 11 | input_shape: Tuple[int, int, int], 12 | embedding_size: int = 128, 13 | l2_norm: bool = True, 14 | include_top: bool = True, 15 | pooling: str = gem, 16 | gem_p=3.0 17 | ``` 18 | 19 | 20 | 21 | 22 | 23 | 24 | Architecture from [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) 25 | 26 | 27 | 28 | 29 | 30 | 31 |
32 | input_shape 33 | 35 | Expected to be between 32 and 224 and in the (H, W, C) 36 | data_format. 37 |
40 | embedding_size 41 | 43 | Size of the output embedding. Usually between 64 44 | and 512. Defaults to 128. 45 |
48 | l2_norm 49 | 51 | If True and include_top is also True, then 52 | tfsim.layers.MetricEmbedding is used as the last layer, otherwise 53 | keras.layers.Dense is used. This should be true when using cosine 54 | distance. Defaults to True. 55 |
58 | include_top 59 | 61 | Whether to include a fully-connected layer of 62 | embedding_size at the top of the network. Defaults to True. 63 |
66 | pooling 67 | 69 | Optional pooling mode for feature extraction when 70 | include_top is False. Defaults to gem. 71 | - None means that the output of the model will be the 4D tensor 72 | output of the last convolutional layer. 73 | - avg means that global average pooling will be applied to the 74 | output of the last convolutional layer, and thus the output of the 75 | model will be a 2D tensor. 76 | - max means that global max pooling will be applied. 77 | - gem means that global GeneralizedMeanPooling2D will be applied. 78 | The gem_p param sets the contrast amount on the pooling. 79 |
82 | gem_p 83 | 85 | Sets the power in the GeneralizedMeanPooling2D layer. A value 86 | of 1.0 is equivalent to GlobalAveragePooling2D, while larger values 87 | will increase the contrast between activations within each feature 88 | map, and a value of math.inf will be equivalent to GlobalMaxPooling2D. 89 |
92 | 93 | -------------------------------------------------------------------------------- /tensorflow_similarity/architectures/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | import tensorflow as tf 17 | 18 | 19 | def convert_sync_batchnorm(model: tf.keras.Model) -> tf.keras.Model: 20 | """Replace BatchNormalization layers to SyncBatchNormalization in place. 21 | WARNINGS: 22 | * This function is tested only with efficientnet and resnet 23 | * The returned model has shared layers with the input one. One of them should be disposed. 
24 | """ 25 | 26 | layer2newtensor = {} 27 | in_ = tf.keras.layers.Input(model.input.shape[1:]) 28 | layer2newtensor[model.input.name] = in_ 29 | for layer in model.layers[1:]: 30 | assert len(layer.inbound_nodes) == 1 31 | if isinstance(layer.inbound_nodes[0].inbound_layers, list): # mutliple inputs 32 | x = [layer2newtensor[in_l.name] for in_l in layer.inbound_nodes[0].inbound_layers] 33 | else: 34 | x = layer2newtensor[layer.inbound_nodes[0].inbound_layers.name] 35 | if isinstance(layer, tf.keras.layers.BatchNormalization): 36 | tf_version = [int(v) for v in tf.__version__.split(".")] 37 | if tf_version[0] == 2 and tf_version[1] < 12: 38 | layer = tf.keras.layers.experimental.SyncBatchNormalization(**layer.get_config()) 39 | else: 40 | layer = tf.keras.layers.BatchNormalization(**layer.get_config(), synchronized=True) 41 | 42 | if "truediv" in layer.name: 43 | # efficeientnet edge case 44 | # https://github.com/keras-team/keras/blob/v2.9.0/keras/applications/efficientnet.py#L334 45 | x = layer(x, layer.inbound_nodes[0]._flat_arguments[1]) 46 | else: 47 | x = layer(x) 48 | 49 | layer2newtensor[layer.name] = x 50 | out_ = layer2newtensor[model.layers[-1].name] 51 | new_model = tf.keras.Model(inputs=in_, outputs=out_, name=model.name) 52 | new_model.set_weights(model.get_weights()) 53 | return new_model 54 | -------------------------------------------------------------------------------- /.github/workflows/nightly-publish.yml: -------------------------------------------------------------------------------- 1 | name: NightlyPublish 2 | 3 | on: 4 | workflow_dispatch: # Allow manual triggers 5 | schedule: 6 | # Runs every day at 3:07am UTC. 7 | - cron: '7 3 * * *' 8 | 9 | jobs: 10 | check: 11 | 12 | name: Check files 13 | # Prevent Publish from running on forks. 
14 | if: github.repository == 'tensorflow/similarity' 15 | outputs: 16 | run_job: ${{ steps.check_files.outputs.run_job }} 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | with: 22 | fetch-depth: 2 23 | ref: 'development' 24 | 25 | - name: check modified files 26 | id: check_files 27 | run: | 28 | echo "=============== list modified files ===============" 29 | git diff --name-only HEAD^ HEAD 30 | 31 | echo "========== check paths of modified files ==========" 32 | git diff --name-only HEAD^ HEAD > files.txt 33 | run_job=false 34 | while IFS= read -r file; do 35 | echo "$file" 36 | if [[ $file = tensorflow_similarity/* && $file != tensorflow_similarity/__init__.py ]]; then 37 | echo "This modified file is under the 'tensorflow_similarity' folder." 38 | run_job=true 39 | break 40 | fi 41 | done < files.txt 42 | # ::set-output is deprecated; publish the step output once, via $GITHUB_OUTPUT, after the scan. 43 | echo "run_job=$run_job" >> "$GITHUB_OUTPUT" 44 | 45 | publish: 46 | 47 | name: Publish nightly 48 | needs: check 49 | # Prevent Publish from running on forks. 50 | if: | 51 | github.repository == 'tensorflow/similarity' && 52 | needs.check.outputs.run_job == 'true' 53 | runs-on: ubuntu-latest 54 | 55 | steps: 56 | - uses: actions/checkout@v2 57 | with: 58 | ref: 'development' 59 | 60 | - name: Set up Python 61 | uses: actions/setup-python@v2 62 | with: 63 | python-version: '3.9' 64 | 65 | - name: Install dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | 69 | - name: Install package 70 | run: | 71 | pip install ".[tensorflow,dev]" 72 | 73 | - name: Increment dev version 74 | run: | 75 | # Increments the dev version and pushes the changes to development.
76 | python scripts/increment_version.py 77 | 78 | - name: Build package 79 | run: | 80 | python setup.py sdist bdist_wheel --project_name tfsim-nightly 81 | 82 | - name: Publish package 83 | run: | 84 | twine upload -u ${{ secrets.PYPI_NIGHTLY_USERNAME }} -p ${{ secrets.PYPI_NIGHTLY_TOKEN }} dist/* --verbose 85 | -------------------------------------------------------------------------------- /tensorflow_similarity/distances/distance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Vectorized embedding pairwise distances computation functions""" 15 | from __future__ import annotations 16 | 17 | from abc import ABC, abstractmethod 18 | from typing import TYPE_CHECKING, Any 19 | 20 | if TYPE_CHECKING: 21 | from ..types import FloatTensor 22 | 23 | 24 | class Distance(ABC): 25 | """Abstract class for distance computation.""" 26 | 27 | def __init__(self, name: str, **kwargs): 28 | self.name = name 29 | 30 | @abstractmethod 31 | def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor: 32 | """Compute pairwise distances for a given batch. 33 | 34 | Args: 35 | query_embeddings: Embeddings to compute the pairwise one. 36 | key_embeddings: Embeddings to compute the pairwise one. 37 | 38 | Returns: 39 | FloatTensor: Pairwise distance tensor. 
40 | """ 41 | 42 | def __call__(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor): 43 | return self.call(query_embeddings, key_embeddings) 44 | 45 | def __str__(self) -> str: 46 | return self.name 47 | 48 | def get_config(self) -> dict[str, Any]: 49 | """Contains the distance configuration. 50 | 51 | Returns: 52 | A Python dict containing the configuration of the distance obj. 53 | """ 54 | config = {"name": self.name} 55 | 56 | return config 57 | 58 | @classmethod 59 | def from_config(cls, config: dict[str, Any]) -> Distance: 60 | """Build a distance from a config. 61 | 62 | Args: 63 | config: A Python dict containing the configuration of the distance. 64 | 65 | Returns: 66 | A distance instance. 67 | """ 68 | try: 69 | return cls(**config) 70 | except Exception as e: 71 | raise TypeError( 72 | f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}" 73 | ) 74 | -------------------------------------------------------------------------------- /tests/losses/test_softnn_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from absl.testing import parameterized 3 | from tensorflow.keras.losses import Reduction 4 | from tensorflow.python.framework import combinations 5 | 6 | from tensorflow_similarity import losses 7 | 8 | 9 | @tf.function 10 | def softnn_util(y_true, x, temperature=1): 11 | """ 12 | A simple loop based implementation of soft 13 | nearest neighbor loss to test the code. 
14 | https://arxiv.org/pdf/1902.01889.pdf 15 | """ 16 | 17 | batch_size = tf.shape(y_true)[0] 18 | loss = 0.0 19 | for i in tf.range(batch_size, dtype=tf.int32): 20 | numerator = 0.0 21 | denominator = 0.0 22 | for j in tf.range(batch_size, dtype=tf.int32): 23 | if i == j: 24 | continue 25 | if y_true[i] == y_true[j]: 26 | numerator += tf.math.exp(-1 * tf.math.reduce_sum(tf.math.square(x[i] - x[j])) / temperature) 27 | denominator += tf.math.exp(-1 * tf.math.reduce_sum(tf.math.square(x[i] - x[j])) / temperature) 28 | if numerator == 0.0: 29 | continue 30 | loss += tf.math.log(numerator / denominator) 31 | return -loss / tf.cast(batch_size, tf.float32) 32 | 33 | 34 | @combinations.generate( 35 | combinations.combine( 36 | mode=["graph", "eager"], 37 | ) 38 | ) 39 | class SoftNNLossTest(tf.test.TestCase, parameterized.TestCase): 40 | def test_config(self): 41 | softnn_obj = losses.SoftNearestNeighborLoss( 42 | reduction=Reduction.SUM, 43 | name="softnn_loss", 44 | distance="cosine", 45 | ) 46 | self.assertEqual(softnn_obj.distance.name, "cosine") 47 | self.assertEqual(softnn_obj.name, "softnn_loss") 48 | self.assertEqual(softnn_obj.reduction, Reduction.SUM) 49 | 50 | @parameterized.parameters((0.1), (0.5), (1), (2), (5), (10), (50)) 51 | def test_all_correct(self, temperature): 52 | num_inputs = 10 53 | n_classes = 10 54 | 55 | softnn_obj = losses.SoftNearestNeighborLoss( 56 | reduction=Reduction.SUM_OVER_BATCH_SIZE, 57 | temperature=temperature, 58 | ) 59 | 60 | # y_true: labels 61 | y_true = tf.random.uniform((num_inputs, 1), 0, n_classes, dtype=tf.int32) 62 | # x: embeddings 63 | y_preds = tf.random.uniform((num_inputs, 20), 0, 1) 64 | 65 | loss = softnn_obj(y_true, y_preds) 66 | loss_check = softnn_util( 67 | y_true, 68 | y_preds, 69 | temperature, 70 | ) 71 | 72 | loss_diff = loss - loss_check 73 | self.assertLess(self.evaluate(loss_diff), 1e-3) 74 | -------------------------------------------------------------------------------- 
/tests/models/test_similarity_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | 16 | from tensorflow_similarity.losses import TripletLoss 17 | from tensorflow_similarity.models import SimilarityModel 18 | 19 | 20 | class SimilarityModelTest(tf.test.TestCase): 21 | # TODO(ovallis): Add tests for graph mode. 
22 | def test_save_and_reload(self): 23 | out_dir = self.get_temp_dir() 24 | 25 | inputs = tf.keras.layers.Input(shape=(3,)) 26 | outputs = tf.keras.layers.Dense(2)(inputs) 27 | model = SimilarityModel(inputs, outputs) 28 | model.compile(optimizer="adam", loss=TripletLoss()) 29 | 30 | # index data 31 | x = tf.constant([[1, 1, 3], [3, 1, 2]], dtype="float32") 32 | y = tf.constant([1, 2]) 33 | model.index(x, y) 34 | 35 | # save 36 | model.save(out_dir) 37 | 38 | # reload 39 | loaded_model = tf.keras.models.load_model(out_dir) 40 | loaded_model.load_index(out_dir) 41 | self.assertEqual(loaded_model._index.size(), len(y)) 42 | 43 | def test_save_no_compile(self): 44 | out_dir = self.get_temp_dir() 45 | 46 | inputs = tf.keras.layers.Input(shape=(3,)) 47 | outputs = tf.keras.layers.Dense(2)(inputs) 48 | model = SimilarityModel(inputs, outputs) 49 | 50 | model.save(out_dir) 51 | model2 = tf.keras.models.load_model(out_dir) 52 | self.assertIsInstance(model2, type(model)) 53 | 54 | def test_index_single(self): 55 | """Unit Test for issues #161 & #162""" 56 | inputs = tf.keras.layers.Input(shape=(3,)) 57 | outputs = tf.keras.layers.Dense(2)(inputs) 58 | model = SimilarityModel(inputs, outputs) 59 | model.compile(optimizer="adam", loss=TripletLoss()) 60 | 61 | # index data 62 | x = tf.constant([1, 1, 3], dtype="float32") 63 | y = tf.constant([1]) 64 | 65 | # run individual sample & index 66 | model.index_single(x, y, data=x) 67 | self.assertEqual(model._index.size(), 1) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /tests/retrieval_metrics/test_map_at_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow_similarity.retrieval_metrics import MapAtK 5 | 6 | 7 | def test_concrete_instance(): 8 | rm = MapAtK(r={1: 4, 0: 4}) 9 | 10 | assert rm.name == "map@5" 11 | # Check the 
name once we have updated the threshold 12 | rm.distance_threshold = 0.1 13 | assert rm.name == "map@5 : distance_threshold@0.1" 14 | assert repr(rm) == "map@k : map@5 : distance_threshold@0.1" 15 | assert rm.canonical_name == "map@k" 16 | assert rm.k == 5 17 | assert rm.distance_threshold == 0.1 18 | assert rm.average == "micro" 19 | 20 | expected_config = { 21 | "r": {1: 4, 0: 4}, 22 | "name": "map@5 : distance_threshold@0.1", 23 | "canonical_name": "map@k", 24 | "clip_at_r": False, 25 | "k": 5, 26 | "distance_threshold": 0.1, 27 | } 28 | assert rm.get_config() == expected_config 29 | 30 | 31 | def test_compute(): 32 | query_labels = tf.constant([1, 1, 1, 0]) 33 | match_mask = tf.constant( 34 | [ 35 | [True, True, False], 36 | [True, True, False], 37 | [True, True, False], 38 | [False, False, True], 39 | ], 40 | dtype=bool, 41 | ) 42 | rm = MapAtK(r={0: 10, 1: 3}, k=3) 43 | 44 | mapk = rm.compute(query_labels=query_labels, match_mask=match_mask) 45 | 46 | # mapk should be sum(precision@k*Relevancy_Mask)/R 47 | # class 1 has 3 result sets that are all T,T,F: 48 | # (1.0*True+1.0*True+0.66*False)/3 = 0.66667 49 | # class 0 has 1 result set that is F,F,T 50 | # (0.0*False+0.0*False+0.33*True)/10 = 0.03333 51 | # mapk = (0.66667*3 + 0.03333)/4 = 0.50833 52 | expected = tf.constant(0.50833333332) 53 | 54 | np.testing.assert_allclose(mapk, expected) 55 | 56 | 57 | def test_clip_at_r(): 58 | query_labels = tf.constant([0, 1]) 59 | match_mask = tf.constant( 60 | [ 61 | [False, False, False, True], 62 | [False, False, False, True], 63 | ], 64 | dtype=bool, 65 | ) 66 | rm = MapAtK(r={0: 4, 1: 3}, k=4, clip_at_r=True) 67 | 68 | mapk = rm.compute(query_labels=query_labels, match_mask=match_mask) 69 | 70 | # mapk should be sum(precision@k*Relevancy_Mask)/R 71 | # but here we clip the result set at the r associated with the query label.
72 | # class 0 has 1 result set of F,F,F,T with R == 4, so nothing is clipped: 73 | # (0.0*False+0.0*False+0.0*False+0.25*True)/4 = 0.0625 74 | # class 1 has 1 result set of F,F,F,T with R == 3, clipped to F,F,F: 75 | # (0.0*False+0.0*False+0.0*False)/3 = 0.0 76 | # mapk = (0.0625 + 0.0)/2 77 | expected = tf.constant(0.03125) 78 | 79 | np.testing.assert_allclose(mapk, expected) 80 | -------------------------------------------------------------------------------- /tests/stores/test_cached.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from tensorflow_similarity.stores.cached import CachedStore 6 | 7 | 8 | def build_store(records, path): 9 | kv_store = CachedStore(path=path) 10 | idxs = [] 11 | for r in records: 12 | idx = kv_store.add(r[0], r[1], r[2]) 13 | idxs.append(idx) 14 | return kv_store, idxs 15 | 16 | 17 | def test_cached_store_and_retrieve(tmp_path): 18 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 19 | 20 | kv_store, idxs = build_store(records, tmp_path) 21 | 22 | # check index numbering 23 | for gt, idx in enumerate(idxs): 24 | assert isinstance(idx, int) 25 | assert gt == idx 26 | 27 | # check reference counting 28 | assert kv_store.size() == 2 29 | 30 | # get back both elements 31 | for idx in idxs: 32 | emb, lbl, dt = kv_store.get(idx) 33 | assert emb == records[idx][0] 34 | assert lbl == records[idx][1] 35 | assert dt == records[idx][2] 36 | 37 | 38 | def test_reset(tmp_path): 39 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 40 | 41 | kv_store, idxs = build_store(records, tmp_path) 42 | 43 | # check reference counting 44 | assert kv_store.size() == 2 45 | 46 | kv_store.reset() 47 | assert kv_store.size() == 0 48 | 49 | kv_store.add(records[0][0], records[0][1], records[0][2]) 50 | assert kv_store.size() == 1 51 | 52 | 53 | def test_batch_add(tmp_path): 54 | embs = np.array([[0.1, 0.2], [0.2, 0.3]]) 55 | lbls = np.array([1, 2]) 56 | data = np.array([[0, 0, 0],
[1, 1, 1]]) 57 | 58 | kv_store = CachedStore(path=tmp_path) 59 | idxs = kv_store.batch_add(embs, lbls, data) 60 | for idx in idxs: 61 | emb, lbl, dt = kv_store.get(idx) 62 | assert np.array_equal(emb, embs[idx]) 63 | assert np.array_equal(lbl, lbls[idx]) 64 | assert np.array_equal(dt, data[idx]) 65 | 66 | 67 | def test_save_and_reload(tmp_path): 68 | records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]] 69 | 70 | save_path = tmp_path / "save" 71 | os.mkdir(save_path) 72 | obj_path = tmp_path / "obj" 73 | os.mkdir(obj_path) 74 | 75 | kv_store, idxs = build_store(records, obj_path) 76 | kv_store.save(save_path) 77 | 78 | # reload 79 | reloaded_store = CachedStore() 80 | print(f"loading from {save_path}") 81 | reloaded_store.load(save_path) 82 | 83 | assert reloaded_store.size() == 2 84 | 85 | # get back both elements 86 | for idx in idxs: 87 | emb, lbl, dt = reloaded_store.get(idx) 88 | assert np.array_equal(emb, records[idx][0]) 89 | assert np.array_equal(lbl, records[idx][1]) 90 | assert np.array_equal(dt, records[idx][2]) 91 | -------------------------------------------------------------------------------- /api/TFSimilarity/distances/Distance.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.distances.Distance 2 | 3 | 4 | 5 | 6 | 7 | Note: don't forget to add your distance to the DISTANCES list 8 | 9 | Inherits From: [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | 12 | ```python 13 | TFSimilarity.distances.Distance( 14 | name: str, aliases: List[str] = [] 15 | ) 16 | ``` 17 | 18 | 19 | 20 | 21 | and add alias names in it. 22 | 23 | ## Methods 24 | 25 |

call

26 | 27 | View source 28 | 29 | ```python 30 | call( 31 | query_embeddings: TFSimilarity.callbacks.FloatTensor``` 32 | , 33 | key_embeddings: TFSimilarity.callbacks.FloatTensor``` 34 | 35 | ) -> TFSimilarity.callbacks.FloatTensor``` 36 | 37 | ``` 38 | 39 | 40 | Compute pairwise distances for a given batch. 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 52 | 55 | 56 | 59 | 62 | 63 |
Args
50 | query_embeddings 51 | 53 | Embeddings to compute the pairwise one. 54 |
57 | key_embeddings 58 | 60 | Embeddings to compute the pairwise one. 61 |
64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 76 | 79 | 80 |
Returns
74 | FloatTensor 75 | 77 | Pairwise distance tensor. 78 |
81 | 82 | 83 | 84 |

get_config

85 | 86 | View source 87 | 88 | ```python 89 | get_config() 90 | ``` 91 | 92 | 93 | 94 | 95 | 96 |

__call__

97 | 98 | View source 99 | 100 | ```python 101 | __call__( 102 | query_embeddings: TFSimilarity.callbacks.FloatTensor``` 103 | , 104 | key_embeddings: TFSimilarity.callbacks.FloatTensor``` 105 | 106 | ) 107 | ``` 108 | 109 | 110 | Call self as a function. 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /api/TFSimilarity/callbacks/SplitValidationLoss.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.callbacks.SplitValidationLoss 2 | 3 | 4 | 5 | 6 | 7 | Creates the validation callbacks. 8 | 9 | ```python 10 | TFSimilarity.callbacks.SplitValidationLoss( 11 | query_labels: Sequence[int], 12 | target_labels: Sequence[int], 13 | distance: str = cosine, 14 | metrics: Sequence[Union[str, ClassificationMetric]] = [binary_accuracy, f1score], 15 | tb_logdir: str = None, 16 | k: int = 1, 17 | ) 18 | ``` 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 34 | 37 | 38 | 41 | 44 | 45 | 48 | 51 | 52 | 55 | 58 | 59 | 62 | 65 | 66 | 69 | 73 | 74 | 77 | 83 | 84 | 87 | 90 | 91 | 94 | 99 | 100 | 103 | 110 | 111 | 114 | 119 | 120 |
32 | queries 33 | 35 | Test examples that will be tested against the built index. 36 |
39 | query_labels 40 | 42 | Labels of the queries, i.e. the labels their nearest neighbors are expected to have. 43 |
46 | targets 47 | 49 | Examples that are indexed. 50 |
53 | target_labels 54 | 56 | Target examples labels. 57 |
60 | known_classes 61 | 63 | The set of classes seen during training. 64 |
67 | distance 68 | 70 | Distance function used to compute pairwise distance 71 | between example embeddings. 72 |
75 | metrics 76 | 78 | List of 79 | `tf.similarity.classification_metrics.ClassificationMetric()` to 80 | compute during the evaluation. Defaults to ['binary_accuracy', 81 | 'f1score']. 82 |
85 | tb_logdir 86 | 88 | Where to write TensorBoard logs. Defaults to None. 89 |
92 | k 93 | 95 | The number of nearest neighbors to return for each query. 96 | The lookups are consumed by the Matching Strategy and used to 97 | derive the matching label and distance. 98 |
101 | matcher 102 | 104 | 'match_nearest', 'match_majority_vote' or 105 | ClassificationMatch object. Defines the classification matching, 106 | e.g., match_nearest will count a True Positive if the query_label 107 | is equal to the label of the nearest neighbor and the distance is 108 | less than or equal to the distance threshold. 109 |
112 | distance_thresholds 113 | 115 | A 1D tensor denoting the distance points at 116 | which we compute the metrics. If None, distance_thresholds is set 117 | to tf.constant([math.inf]) 118 |
121 | 122 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | tmp.py 3 | tensorflow_similarity.egg-info/ 4 | logs/ 5 | examples/models/ 6 | examples/logs/ 7 | .vscode/ 8 | site/ 9 | release.sh 10 | .DS_Store 11 | benchmark/supervised/datasets/ 12 | benchmark/supervised/models/ 13 | datasets/ 14 | multi_modal_datasets/ 15 | *.h5 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | 21 | *$py.class 22 | *.prof 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | .pytype 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # Pipfile 150 | Pipfile 151 | Pipfile.lock 152 | -------------------------------------------------------------------------------- /tensorflow_similarity/losses/metric_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Metric losses base class.""" 16 | from __future__ import annotations 17 | 18 | from typing import TYPE_CHECKING, Any 19 | 20 | import tensorflow as tf 21 | 22 | if TYPE_CHECKING: 23 | from collections.abc import Callable 24 | from ..types import FloatTensor 25 | 26 | from ..utils import is_tensor_or_variable 27 | 28 | 29 | class MetricLoss(tf.keras.losses.Loss): 30 | """Wraps a loss function in the `Loss` class.""" 31 | 32 | def __init__( 33 | self, fn: Callable, reduction: Callable = tf.keras.losses.Reduction.AUTO, name: str | None = None, **kwargs 34 | ): 35 | """Initializes `LossFunctionWrapper` class. 36 | Args: 37 | fn: The loss function to wrap, with signature `fn(y_true, y_pred, 38 | **kwargs)`. 39 | reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 40 | loss. Default value is `AUTO`. 41 | name: (Optional) name for the loss. 42 | **kwargs: The keyword arguments that are passed on to `fn`. 43 | """ 44 | super().__init__(reduction=reduction, name=name) 45 | self.fn = fn 46 | self._fn_kwargs = kwargs 47 | 48 | def call(self, y_true: FloatTensor, y_pred: FloatTensor) -> FloatTensor: 49 | """Invokes the `LossFunctionWrapper` instance. 50 | Args: 51 | y_true: Ground truth values. 52 | y_pred: The predicted values. 53 | Returns: 54 | Loss values per sample. 55 | """ 56 | loss: FloatTensor = self.fn(y_true, y_pred, y_true, y_pred, **self._fn_kwargs) 57 | return loss 58 | 59 | def get_config(self) -> dict[str, Any]: 60 | """Contains the loss configuration. 61 | 62 | Returns: 63 | A Python dict containing the configuration of the loss. 
64 | """ 65 | config: dict[str, Any] = super().get_config() 66 | for k, v in iter(self._fn_kwargs.items()): 67 | if is_tensor_or_variable(v): 68 | config[k] = tf.keras.backend.eval(v) 69 | else: 70 | config[k] = v 71 | 72 | return config 73 | -------------------------------------------------------------------------------- /tensorflow_similarity/classification_metrics/recall.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..types import FloatTensor 22 | 23 | from .classification_metric import ClassificationMetric 24 | 25 | 26 | class Recall(ClassificationMetric): 27 | """Calculates the recall of the query classification. 28 | 29 | Computes the recall given the query classification counts. 
30 | 31 | $$ 32 | Recall = \frac{\textrm{true_positives}}{\textrm{true_positives} + 33 | \textrm{false_negatives}} 34 | $$ 35 | 36 | args: 37 | name: Name associated with a specific metric object, e.g., 38 | recall@0.1 39 | 40 | Usage with `tf.similarity.models.SimilarityModel()`: 41 | 42 | ```python 43 | model.calibrate(x=query_examples, 44 | y=query_labels, 45 | calibration_metric='recall') 46 | ``` 47 | """ 48 | 49 | def __init__(self, name: str = "recall") -> None: 50 | super().__init__(name=name, canonical_name="recall") 51 | 52 | def compute(self, tp: FloatTensor, fp: FloatTensor, tn: FloatTensor, fn: FloatTensor, count: int) -> FloatTensor: 53 | """Compute the classification metric. 54 | 55 | The `compute()` method supports computing the metric for a set of 56 | values, where each value represents the counts at a specific distance 57 | threshold. 58 | 59 | Args: 60 | tp: A 1D FloatTensor containing the count of True Positives at each 61 | distance threshold. 62 | 63 | fp: A 1D FloatTensor containing the count of False Positives at 64 | each distance threshold. 65 | 66 | tn: A 1D FloatTensor containing the count of True Negatives at each 67 | distance threshold. 68 | 69 | fn: A 1D FloatTensor containing the count of False Negatives at 70 | each distance threshold. 71 | 72 | count: The total number of queries 73 | 74 | Returns: 75 | A 1D FloatTensor containing the metric at each distance threshold. 76 | """ 77 | result: FloatTensor = tf.math.divide_no_nan(tp, tp + fn) 78 | return result 79 | -------------------------------------------------------------------------------- /tensorflow_similarity/classification_metrics/negative_predictive_value.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..types import FloatTensor 22 | 23 | from .classification_metric import ClassificationMetric 24 | 25 | 26 | class NegativePredictiveValue(ClassificationMetric): 27 | """Calculates the negative predictive value of the query classification. 28 | 29 | Computes the NPV given the query classification counts. 30 | 31 | $$ 32 | NPV = \frac{\textrm{true_negatives}}{\textrm{false_negatives} + 33 | \textrm{true_negatives}} 34 | $$ 35 | 36 | args: 37 | name: Name associated with a specific metric object, e.g., 38 | npv@0.1 39 | 40 | Usage with `tf.similarity.models.SimilarityModel()`: 41 | 42 | ```python 43 | model.calibrate(x=query_examples, 44 | y=query_labels, 45 | calibration_metric='npv') 46 | ``` 47 | """ 48 | 49 | def __init__(self, name: str = "npv") -> None: 50 | super().__init__(name=name, canonical_name="negative_predictive_value") 51 | 52 | def compute(self, tp: FloatTensor, fp: FloatTensor, tn: FloatTensor, fn: FloatTensor, count: int) -> FloatTensor: 53 | """Compute the classification metric. 54 | 55 | The `compute()` method supports computing the metric for a set of 56 | values, where each value represents the counts at a specific distance 57 | threshold. 58 | 59 | Args: 60 | tp: A 1D FloatTensor containing the count of True Positives at each 61 | distance threshold. 62 | 63 | fp: A 1D FloatTensor containing the count of False Positives at 64 | each distance threshold.
65 | 66 | tn: A 1D FloatTensor containing the count of True Negatives at each 67 | distance threshold. 68 | 69 | fn: A 1D FloatTensor containing the count of False Negatives at 70 | each distance threshold. 71 | 72 | count: The total number of queries 73 | 74 | Returns: 75 | A 1D FloatTensor containing the metric at each distance threshold. 76 | """ 77 | result: FloatTensor = tf.math.divide_no_nan(tn, tn + fn) 78 | return result 79 | -------------------------------------------------------------------------------- /tests/retrieval_metrics/test_retrieval_metric.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pytest 4 | import tensorflow as tf 5 | 6 | from tensorflow_similarity.retrieval_metrics import RetrievalMetric 7 | from tensorflow_similarity.types import BoolTensor, FloatTensor, IntTensor 8 | 9 | 10 | class ConcreteRetrievalMetric(RetrievalMetric): 11 | def compute( 12 | self, 13 | *, # keyword only arguments see PEP-570 14 | query_labels: IntTensor, 15 | lookup_labels: IntTensor, 16 | lookup_distances: FloatTensor, 17 | match_mask: BoolTensor, 18 | ) -> FloatTensor: 19 | return tf.constant(1.0) 20 | 21 | 22 | def test_concrete_instance(): 23 | rm = ConcreteRetrievalMetric( 24 | name="foo", 25 | canonical_name="bar", 26 | k=6, 27 | average="macro", 28 | ) 29 | 30 | assert rm.name == "foo@6" 31 | # Check the name once we have updated the threshold 32 | rm.distance_threshold = 0.1 33 | assert rm.name == "foo@6 : distance_threshold@0.1" 34 | assert repr(rm) == "bar : foo@6 : distance_threshold@0.1" 35 | assert rm.canonical_name == "bar" 36 | assert rm.k == 6 37 | assert rm.distance_threshold == 0.1 38 | assert rm.average == "macro" 39 | 40 | expected_config = { 41 | "name": "foo@6 : distance_threshold@0.1", 42 | "canonical_name": "bar", 43 | "k": 6, 44 | "distance_threshold": 0.1, 45 | } 46 | assert rm.get_config() == expected_config 47 | 48 | 49 | def test_k_greater_than_num_lookups(): 50 | 
query_labels = tf.constant([1, 1]) 51 | match_mask = tf.constant( 52 | [ 53 | [True, True, False, False], 54 | [True, False, False, True], 55 | ], 56 | dtype=bool, 57 | ) 58 | rm = ConcreteRetrievalMetric( 59 | name="foo", 60 | canonical_name="bar", 61 | k=5, 62 | average="macro", 63 | ) 64 | 65 | msg = "The number of neighbors must be >= K. Number of neighbors is 4 but " "K is 5." 66 | 67 | with pytest.raises(ValueError, match=re.escape(msg)): 68 | _ = rm._check_shape(query_labels=query_labels, match_mask=match_mask) 69 | 70 | 71 | def test_query_and_match_mask_different_dims(): 72 | query_labels = tf.constant([1, 2, 3, 4]) 73 | match_mask = tf.constant( 74 | [ 75 | [True, True, False, False], 76 | [True, False, False, True], 77 | ], 78 | dtype=bool, 79 | ) 80 | rm = ConcreteRetrievalMetric( 81 | name="foo", 82 | canonical_name="bar", 83 | k=4, 84 | average="macro", 85 | ) 86 | 87 | msg = ( 88 | "The number of lookup sets must equal the number of query labels. " 89 | "Number of lookup sets is 2 but the number of query labels is 4." 90 | ) 91 | 92 | with pytest.raises(ValueError, match=re.escape(msg)): 93 | _ = rm._check_shape(query_labels=query_labels, match_mask=match_mask) 94 | -------------------------------------------------------------------------------- /api/TFSimilarity/distances/SNRDistance.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.distances.SNRDistance 2 | 3 | 4 | 5 | 6 | 7 | Computes pairwise SNR distances between embeddings. 8 | 9 | Inherits From: [`Distance`](../../TFSimilarity/distances/Distance.md), [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | ```python 12 | TFSimilarity.distances.SNRDistance() 13 | ``` 14 | 15 | 16 | 17 | 18 | 19 | The [Signal-to-Noise Ratio distance](https://arxiv.org/abs/1904.02616) 20 | is the ratio of noise variance to the feature variance. 21 | 22 | ## Methods 23 | 24 |

call

25 | 26 | View source 27 | 28 | ``python 29 | @tf.function`` 30 | 31 | ```python 32 | call( 33 | query_embeddings: TFSimilarity.callbacks.FloatTensor``` 34 | , 35 | key_embeddings: TFSimilarity.callbacks.FloatTensor``` 36 | 37 | ) -> TFSimilarity.callbacks.FloatTensor``` 38 | 39 | ``` 40 | 41 | 42 | Compute pairwise snr distances for a given batch of embeddings. 43 | SNR(i, j): anchor i and compared feature j 44 | SNR(i,j) may not be equal to SNR(j, i) 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 55 | 58 | 59 |
Args
53 | query_embeddings 54 | 56 | Query embeddings to compute pairwise distances from. 57 |
60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 72 | 75 | 76 |
Returns
70 | FloatTensor 71 | 73 | Pairwise distance tensor. 74 |
77 | 78 | 79 | 80 |

get_config

81 | 82 | View source 83 | 84 | ```python 85 | get_config() 86 | ``` 87 | 88 | 89 | 90 | 91 | 92 |

__call__

93 | 94 | View source 95 | 96 | ```python 97 | __call__( 98 | query_embeddings: TFSimilarity.callbacks.FloatTensor``` 99 | , 100 | key_embeddings: TFSimilarity.callbacks.FloatTensor``` 101 | 102 | ) 103 | ``` 104 | 105 | 106 | Call self as a function. 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /tests/losses/test_xbm_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.losses import Reduction 3 | from tensorflow.python.framework import combinations 4 | 5 | from tensorflow_similarity import losses 6 | 7 | 8 | # TODO(ovallis): Refactor XBM loss to work with graph mode and update tests. 9 | @combinations.generate( 10 | combinations.combine( 11 | mode=["eager"], 12 | ) 13 | ) 14 | class XBMLossTest(tf.test.TestCase): 15 | def test_config(self): 16 | xbm_tpl_obj = losses.XBM( 17 | loss=losses.TripletLoss(distance="cosine"), 18 | memory_size=12, 19 | reduction=Reduction.SUM, 20 | name="xbm_triplet_loss", 21 | ) 22 | self.assertEqual(xbm_tpl_obj.distance.name, "cosine") 23 | self.assertEqual(xbm_tpl_obj.name, "xbm_triplet_loss") 24 | self.assertEqual(xbm_tpl_obj.memory_size, 12) 25 | self.assertEqual(xbm_tpl_obj.reduction, Reduction.SUM) 26 | 27 | def test_xbm_loss(self): 28 | batch_size = 6 29 | embed_dim = 16 30 | 31 | embeddings1 = tf.random.uniform(shape=[batch_size, embed_dim]) 32 | labels1 = tf.constant( 33 | [[1], [1], [2], [2], [3], [3]], 34 | dtype=tf.int32, 35 | ) 36 | 37 | embeddings2 = tf.random.uniform(shape=[batch_size, embed_dim]) 38 | labels2 = tf.constant( 39 | [[4], [4], [5], [5], [6], [6]], 40 | dtype=tf.int32, 41 | ) 42 | 43 | distance = "cosine" 44 | loss = losses.MultiSimilarityLoss(distance=distance) 45 | loss_nowarm = losses.XBM(loss, memory_size=12, warmup_steps=0) 46 | 47 | # test enqueue 48 | loss_nowarm(labels1, embeddings1) 49 | 
self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, (batch_size, embed_dim)) 50 | self.assertAllEqual(loss_nowarm._y_true_memory, labels1) 51 | 52 | loss_nowarm(labels2, embeddings2) 53 | self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, (2 * batch_size, embed_dim)) 54 | self.assertAllEqual(loss_nowarm._y_true_memory, tf.concat([labels2, labels1], axis=0)) 55 | 56 | # test dequeue 57 | loss_nowarm(labels2, embeddings2) 58 | self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, (2 * batch_size, embed_dim)) 59 | self.assertAllEqual(loss_nowarm._y_true_memory, tf.concat([labels2, labels2], axis=0)) 60 | 61 | # test warmup 62 | loss_warm = losses.XBM(loss, memory_size=12, warmup_steps=1) 63 | 64 | loss_warm(labels1, embeddings1) 65 | self.assertAllEqual(loss_warm._y_pred_memory.numpy().shape, (0, embed_dim)) 66 | self.assertAllEqual(loss_warm._y_true_memory, tf.constant([[]], shape=(0, 1), dtype=tf.int32)) 67 | 68 | loss_warm(labels2, embeddings2) 69 | self.assertAllEqual(loss_warm._y_pred_memory.numpy().shape, (batch_size, embed_dim)) 70 | self.assertAllEqual(loss_warm._y_true_memory, labels2) 71 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | Thanks for considering contributing to TF similarity! 3 | 4 | Here is what you need to know to make a successful contribution. There are 5 | just a few small guidelines you need to follow. 6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution; 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. 
Head over to the 13 | [Google contribution agreement server](https://cla.developers.google.com/) to see 14 | your current agreements on file or to sign a new one. 15 | 16 | You generally only need to submit a CLA once, so if you've already submitted one 17 | (even if it was for a different project), you probably don't need to do it 18 | again. 19 | 20 | ## Code reviews 21 | 22 | All submissions, including submissions by project members, require review. We 23 | use GitHub pull requests for this purpose. Consult 24 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 25 | information on using pull requests. 26 | 27 | ## Code submission guidelines 28 | 29 | Here are the few steps you need to follow to ensure a smooth review of your 30 | pull request: 31 | 32 | - Ideally each PR corresponds to one feature or improvement, to make it easier to 33 | review. So **try** to split your contribution into meaningful logical units. 34 | - Your code **must** pass the unit-tests. We use `pytest` so simply run it at the root of the project. 35 | - Your code **must** pass static analysis. We use `mypy` so simply run `mypy tensorflow_similarity/` from the root of the project. 36 | - Your code **must** come with unit-tests to ensure long-term quality. 37 | - Your functions **must** be documented using the Google style, except for obvious ones. 38 | - Your functions **must** be typed. 39 | - Your code must pass `flake8` to ensure you follow Python coding standards. 40 | - You **must** provide documentation (located in documentation/src) edits that document any new function/features you added. 41 | - You **should** provide references to research papers that describe the added feature, if any exist. 42 | - Please **try** to include a notebook in [examples/](examples/) that demonstrates the features/model if possible or extend an existing one. If you add a new notebook, make sure to provide an edit to the `README.md` that links to it.
43 | - You **must** update the documentation/src/changelog.md that document the changes triggered by the PR. 44 | - Your code must be formatted using `black -l 119`. For automated formatting during git commit, run `pre-commit install`. 45 | 46 | ### Configure git blame to ignore the commit that migrated the code style to black. 47 | ~~~bash 48 | git config blame.ignoreRevsFile .git-blame-ignore-revs 49 | ~~~ 50 | 51 | 52 | ## Community Guidelines 53 | 54 | This project follows 55 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 56 | -------------------------------------------------------------------------------- /tensorflow_similarity/classification_metrics/false_positive_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..types import FloatTensor 22 | 23 | from .classification_metric import ClassificationMetric 24 | 25 | 26 | class FalsePositiveRate(ClassificationMetric): 27 | """Calculates the false positive rate of the query classification. 28 | 29 | Computes the FPR given the query classification counts. 
30 | 31 | $$ 32 | FPR = \frac{\textrm{false_positives}}{\textrm{false_positives} + 33 | \textrm{true_negatives}} 34 | $$ 35 | 36 | args: 37 | name: Name associated with a specific metric object, e.g., 38 | fpr@0.1 39 | 40 | Usage with `tf.similarity.models.SimilarityModel()`: 41 | 42 | ```python 43 | model.calibrate(x=query_examples, 44 | y=query_labels, 45 | calibration_metric='fpr') 46 | ``` 47 | """ 48 | 49 | def __init__(self, name: str = "fpr") -> None: 50 | super().__init__(name=name, canonical_name="false_positive_rate", maximize=False) 51 | 52 | def compute( 53 | self, 54 | tp: FloatTensor, 55 | fp: FloatTensor, 56 | tn: FloatTensor, 57 | fn: FloatTensor, 58 | count: int, 59 | ) -> FloatTensor: 60 | """Compute the classification metric. 61 | 62 | The `compute()` method supports computing the metric for a set of 63 | values, where each value represents the counts at a specific distance 64 | threshold. 65 | 66 | Args: 67 | tp: A 1D FloatTensor containing the count of True Positives at each 68 | distance threshold. 69 | 70 | fp: A 1D FloatTensor containing the count of False Positives at 71 | each distance threshold. 72 | 73 | tn: A 1D FloatTensor containing the count of True Negatives at each 74 | distance threshold. 75 | 76 | fn: A 1D FloatTensor containing the count of False Negatives at 77 | each distance threshold. 78 | 79 | count: The total number of queries 80 | 81 | Returns: 82 | A 1D FloatTensor containing the metric at each distance threshold. 83 | """ 84 | result: FloatTensor = tf.math.divide_no_nan(fp, fp + tn) 85 | return result 86 | -------------------------------------------------------------------------------- /tensorflow_similarity/matchers/match_majority_vote.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | import tensorflow as tf 19 | 20 | if TYPE_CHECKING: 21 | from ..types import FloatTensor, IntTensor 22 | 23 | from .classification_match import ClassificationMatch 24 | 25 | 26 | class MatchMajorityVote(ClassificationMatch): 27 | """Match metrics for the most common label in a result set.""" 28 | 29 | def __init__(self, name: str = "majority_vote", **kwargs) -> None: 30 | if "canonical_name" not in kwargs: 31 | kwargs["canonical_name"] = "match_majority_vote" 32 | 33 | super().__init__(name=name, **kwargs) 34 | 35 | def derive_match(self, lookup_labels: IntTensor, lookup_distances: FloatTensor) -> tuple[IntTensor, FloatTensor]: 36 | """Derive a match label and distance from a set of K neighbors. 37 | 38 | For each query, derive a single match label and distance given the 39 | associated set of lookup labels and distances. 40 | 41 | Args: 42 | lookup_labels: A 2D array where the jth row is the labels 43 | associated with the set of k neighbors for the jth query. 44 | 45 | lookup_distances: A 2D array where the jth row is the distances 46 | between the jth query and the set of k neighbors. 47 | 48 | Returns: 49 | A Tuple of FloatTensors: 50 | derived_labels: A FloatTensor of shape 51 | [len(lookup_labels), 1] where the jth row contains the derived 52 | label for the jth query. 
53 | 54 | derived_distances: A FloatTensor of shape 55 | [len(lookup_labels), 1] where the jth row contains the distance 56 | associated with the jth derived label. 57 | """ 58 | 59 | # TODO(ovallis): Add parallel for callback or inline evaluation. 60 | pred_labels = tf.map_fn(self._majority_vote, lookup_labels) 61 | pred_labels = tf.expand_dims(pred_labels, axis=-1) 62 | 63 | agg_dist = tf.math.reduce_mean(lookup_distances, 1) 64 | agg_dist = tf.expand_dims(agg_dist, axis=-1) 65 | 66 | return pred_labels, agg_dist 67 | 68 | def _majority_vote(self, lookup_labels): 69 | labels, _, counts = tf.unique_with_counts(lookup_labels) 70 | majority = tf.argmax(counts) 71 | 72 | return tf.gather(labels, majority) 73 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | workflow_dispatch: # Allow manual triggers 8 | 9 | jobs: 10 | test: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | include: 17 | - python-version: '3.7' 18 | tf-version: '2.8' 19 | - python-version: '3.7' 20 | tf-version: '2.11' 21 | - python-version: '3.10' 22 | tf-version: '2.8' 23 | - python-version: '3.10' 24 | tf-version: '2.11' 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install coveralls 36 | 37 | - name: Install TF package 38 | run: | 39 | pip install tensorflow==${{ matrix.tf-version }} 40 | # Fix proto dep issue in protobuf 4 41 | pip install protobuf==3.20.* 42 | 43 | - name: Install dev packages 44 | run: | 45 | pip install ".[dev,visualization,redis,faiss,nmslib]" 46 | 47 | - name: Lint 
with flake8 48 | run: | 49 | # stop the build if there are Python syntax errors or undefined names 50 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 51 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 52 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 53 | 54 | - name: Static analysis 55 | run: | 56 | mypy tensorflow_similarity/ --raise-exceptions 57 | 58 | - name: Test with pytest 59 | run: | 60 | coverage run -m pytest tests/ 61 | 62 | - name: Upload coverage to Coveralls 63 | if: ${{ github.ref == 'refs/heads/master' }} 64 | run: | 65 | coveralls --service=github 66 | env: 67 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 68 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }} 69 | COVERALLS_PARALLEL: true 70 | 71 | - name: Check black formatting 72 | uses: psf/black@stable 73 | with: 74 | options: "--check --verbose" 75 | version: "22.6.0" 76 | - name: Check isort 77 | uses: isort/isort-action@master 78 | 79 | coveralls: 80 | 81 | name: Indicate completion to coveralls.io 82 | needs: test 83 | runs-on: ubuntu-latest 84 | 85 | steps: 86 | - name: Install dependencies 87 | if: ${{ github.ref == 'refs/heads/master' }} 88 | run: | 89 | python -m pip install --upgrade pip 90 | pip install coveralls 91 | 92 | - name: Finished 93 | if: ${{ github.ref == 'refs/heads/master' }} 94 | run: | 95 | coveralls --service=github --finish 96 | env: 97 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 98 | -------------------------------------------------------------------------------- /tests/search/test_nmslib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tensorflow_similarity.search.nmslib import NMSLibSearch 4 | 5 | 6 | def test_index_match(): 7 | target = np.array([1, 1, 2], dtype="float32") 8 | embs = np.array([[1, 1, 3], [3, 1, 2]], dtype="float32") 9 | 10 | search_index = NMSLibSearch("cosine", 3) 11 | 
search_index.add(embs[0], 0) 12 | search_index.add(embs[1], 1) 13 | 14 | idxs, embs = search_index.lookup(target) 15 | 16 | assert len(embs) == 2 17 | assert list(idxs) == [0, 1] 18 | 19 | 20 | def test_reset(): 21 | target = np.array([1, 2, 3], dtype="float32") 22 | embs = np.array([[4, 2, 1], [2, 3, 5]], dtype="float32") 23 | 24 | search_index = NMSLibSearch("cosine", 3) 25 | search_index.add(embs[0], 0) 26 | search_index.add(embs[1], 1) 27 | 28 | idxs, dists = search_index.lookup(target, k=2) 29 | 30 | assert len(dists) == 2 31 | print(dists) 32 | assert list(idxs) == [1, 0] 33 | 34 | search_index.reset() 35 | # switch order 36 | search_index.add(embs[1], 0) 37 | search_index.add(embs[0], 1) 38 | idxs, dists = search_index.lookup(target, k=2) 39 | 40 | assert len(dists) == 2 41 | assert list(idxs) == [0, 1] 42 | 43 | 44 | def test_index_save(tmp_path): 45 | target = np.array([1, 1, 2], dtype="float32") 46 | embs = np.array([[1, 1, 3], [3, 1, 2]], dtype="float32") 47 | 48 | search_index = NMSLibSearch("cosine", 3) 49 | search_index.add(embs[0], 0) 50 | search_index.add(embs[1], 1) 51 | 52 | idxs, embs = search_index.lookup(target) 53 | 54 | assert len(embs) == 2 55 | assert list(idxs) == [0, 1] 56 | 57 | search_index.save(tmp_path) 58 | 59 | search_index2 = NMSLibSearch("cosine", 3) 60 | search_index2.load(tmp_path) 61 | 62 | idxs2, embs2 = search_index.lookup(target) 63 | assert len(embs2) == 2 64 | assert list(idxs2) == [0, 1] 65 | 66 | # add more 67 | search_index2.add(np.array([3.0, 3.0, 3.0]), 3) 68 | idxs3, embs3 = search_index2.lookup(target) 69 | assert len(embs3) == 3 70 | assert list(idxs3) == [0, 3, 1] 71 | 72 | 73 | def test_batch_vs_single(tmp_path): 74 | num_targets = 10 75 | index_size = 100 76 | vect_dim = 16 77 | 78 | # gen 79 | idxs = list(range(index_size)) 80 | 81 | targets = np.random.random((num_targets, vect_dim)).astype("float32") 82 | embs = np.random.random((index_size, vect_dim)).astype("float32") 83 | 84 | # build search_index 85 | 
search_index = NMSLibSearch("cosine", vect_dim) 86 | search_index.batch_add(embs, idxs) 87 | 88 | # batch 89 | batch_idxs, _ = search_index.batch_lookup(targets) 90 | 91 | # single 92 | singles_idxs = [] 93 | for t in targets: 94 | idxs, embs = search_index.lookup(t) 95 | singles_idxs.append(idxs) 96 | 97 | for i in range(num_targets): 98 | # k neigboors are the same? 99 | for k in range(3): 100 | assert batch_idxs[i][k] == singles_idxs[i][k] 101 | -------------------------------------------------------------------------------- /api/TFSimilarity/distances/SquaredEuclideanDistance.md: -------------------------------------------------------------------------------- 1 | # TFSimilarity.distances.SquaredEuclideanDistance 2 | 3 | 4 | 5 | 6 | 7 | Compute pairwise squared Euclidean distance. 8 | 9 | Inherits From: [`Distance`](../../TFSimilarity/distances/Distance.md), [`ABC`](../../TFSimilarity/distances/ABC.md) 10 | 11 | ```python 12 | TFSimilarity.distances.SquaredEuclideanDistance() 13 | ``` 14 | 15 | 16 | 17 | 18 | 19 | The [Squared Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance) is 20 | a distance that varies from 0 (similar) to infinity (dissimilar). 21 | 22 | ## Methods 23 | 24 |

### `call`

25 | 26 | View source 27 | 28 | ```python 29 | @tf.function 30 | ``` 31 | ```python 32 | call( 33 | query_embeddings: TFSimilarity.callbacks.FloatTensor, 34 | 35 | key_embeddings: TFSimilarity.callbacks.FloatTensor 36 | 37 | ) -> TFSimilarity.callbacks.FloatTensor 38 | 39 | ``` 40 | 41 | 42 | Compute pairwise distances for a given batch of embeddings. 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 54 | 57 | 58 | 61 | 64 | 65 |
**Args**
52 | `query_embeddings` 53 | 55 | Embeddings of the queries whose distances to `key_embeddings` are computed. 56 |
59 | `key_embeddings` 60 | 62 | Embeddings of the keys that each query embedding is compared against. 63 |
66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 78 | 81 | 82 |
**Returns**
76 | `FloatTensor` 77 | 79 | Tensor of pairwise distances between the query and key embeddings. 80 |
83 | 84 | 85 | 86 |

### `get_config`

87 | 88 | View source 89 | 90 | ```python 91 | get_config() 92 | ``` 93 | 94 | 95 | 96 | 97 | 98 |

### `__call__`

99 | 100 | View source 101 | 102 | ```python 103 | __call__( 104 | query_embeddings: TFSimilarity.callbacks.FloatTensor, 105 | 106 | key_embeddings: TFSimilarity.callbacks.FloatTensor 107 | 108 | ) 109 | ``` 110 | 111 | 112 | Call self as a function. 113 | 114 | 115 | 116 | 117 | --------------------------------------------------------------------------------