├── tests
├── evaluators
│ └── __init__.py
├── losses
│ ├── __init__.py
│ ├── utils.py
│ ├── test_softnn_loss.py
│ └── test_xbm_loss.py
├── matchers
│ ├── __init__.py
│ ├── test_match_nearest.py
│ └── test_majority_vote.py
├── models
│ ├── __init__.py
│ └── test_similarity_model.py
├── samplers
│ ├── __init__.py
│ ├── test_tfdataset_samplers.py
│ └── test_tfrecord_samplers.py
├── search
│ ├── __init__.py
│ └── test_nmslib.py
├── stores
│ ├── __init__.py
│ ├── test_redis.py
│ ├── test_memory.py
│ └── test_cached.py
├── architectures
│ ├── __init__.py
│ └── test_resnet18.py
├── visualization
│ └── __init__.py
├── retrieval_metrics
│ ├── __init__.py
│ ├── test_recall_at_k.py
│ ├── test_precision_at_k.py
│ ├── test_map_at_k.py
│ └── test_retrieval_metric.py
├── training_metrics
│ └── __init__.py
├── classification_metrics
│ └── __init__.py
├── data
│ └── mnist_fashion_embeddings.npz
├── __init__.py
├── conftest.py
├── augmenters
│ ├── test_augmentation_utils
│ │ ├── test_solarize.py
│ │ ├── test_crop.py
│ │ ├── test_flip.py
│ │ ├── test_blur.py
│ │ └── test_color_jitter.py
│ └── test_augmenters.py
└── integration
│ └── test_readme_minimal.py
├── benchmark
└── supervised
│ ├── requirements.txt
│ ├── components
│ ├── datasets
│ │ ├── __init__.py
│ │ └── types.py
│ ├── __init__.py
│ ├── augmentations.py
│ ├── architectures.py
│ ├── optimizers.py
│ └── experiments.py
│ ├── README.md
│ └── create_datasets.py
├── tensorflow_similarity
├── augmenters
│ ├── augmentation_utils
│ │ ├── __init__.py
│ │ ├── random_apply.py
│ │ ├── solarize.py
│ │ └── flip.py
│ ├── __init__.py
│ ├── augmenter.py
│ └── contrastive.py
├── __init__.py
├── architectures
│ ├── __init__.py
│ └── utils.py
├── models
│ └── __init__.py
├── matchers
│ ├── __init__.py
│ ├── utils.py
│ ├── match_nearest.py
│ └── match_majority_vote.py
├── retrieval_metrics
│ ├── __init__.py
│ └── utils.py
├── classification_metrics
│ ├── __init__.py
│ ├── utils.py
│ ├── recall.py
│ ├── negative_predictive_value.py
│ └── false_positive_rate.py
├── losses
│ ├── __init__.py
│ └── metric_loss.py
├── training_metrics
│ ├── __init__.py
│ └── utils.py
├── visualization
│ ├── __init__.py
│ └── vizualize_views.py
├── samplers
│ ├── __init__.py
│ └── file_samplers.py
├── evaluators
│ └── __init__.py
└── distances
│ ├── cosine.py
│ ├── manhattan.py
│ ├── inner_product.py
│ ├── snr.py
│ └── distance.py
├── .flake8
├── pyproject.toml
├── api
└── TFSimilarity
│ ├── all_symbols.md
│ ├── api_report.pb
│ ├── distances
│ ├── ABC.md
│ ├── distance_canonicalizer.md
│ ├── Distance.md
│ ├── SNRDistance.md
│ └── SquaredEuclideanDistance.md
│ ├── indexer
│ ├── PandasDataFrame.md
│ ├── time.md
│ ├── CalibrationResults.md
│ ├── make_classification_matcher.md
│ └── Lookup.md
│ ├── training_metrics
│ ├── avg_neg.md
│ ├── avg_pos.md
│ ├── max_neg.md
│ ├── max_pos.md
│ ├── min_neg.md
│ ├── min_pos.md
│ ├── sum_neg.md
│ ├── sum_pos.md
│ ├── dist_gap.md
│ └── batch_class_ratio.md
│ ├── utils
│ ├── tf_cap_memory.md
│ └── is_tensor_or_variable.md
│ ├── callbacks
│ ├── unpack_lookup_labels.md
│ ├── unpack_lookup_distances.md
│ ├── unpack_results.md
│ ├── make_classification_metric.md
│ └── SplitValidationLoss.md
│ ├── layers
│ ├── ActivationStdLoggingLayer.md
│ └── MetricEmbedding.md
│ ├── models.md
│ ├── architectures.md
│ ├── augmenters.md
│ ├── visualization.md
│ ├── matchers.md
│ ├── retrieval_metrics.md
│ ├── stores.md
│ ├── layers.md
│ ├── utils.md
│ ├── losses.md
│ ├── training_metrics.md
│ ├── search.md
│ ├── classification_metrics.md
│ ├── evaluators.md
│ ├── distances.md
│ ├── visualization
│ ├── viz_neigbors_imgs.md
│ ├── visualize_views.md
│ ├── confusion_matrix.md
│ └── projector.md
│ ├── augmenters
│ ├── Augmenter.md
│ ├── ContrastiveAugmenter.md
│ ├── BarlowAugmenter.md
│ └── SimCLRAugmenter.md
│ ├── samplers.md
│ ├── samplers
│ └── select_examples.md
│ ├── callbacks.md
│ └── architectures
│ └── ResNet18Sim.md
├── mypy.ini
├── .gitattributes
├── assets
└── images
│ ├── tfsim_overview.png
│ ├── contrastive_batch.png
│ ├── similar-cats-and-dogs.jpg
│ ├── contrastive_model_terms.png
│ ├── self_supervised_overview.png
│ ├── contrastive_loss_functions.png
│ └── similarity-learning-overview.png
├── pytest.ini
├── .git-blame-ignore-revs
├── scripts
├── README.md
└── increment_version.py
├── .pre-commit-config.yaml
├── .coveragerc
├── examples
└── README.md
├── .github
└── workflows
│ ├── python-publish.yml
│ ├── nightly-publish.yml
│ └── test.yml
├── .gitignore
└── CONTRIBUTING.md
/tests/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/losses/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/matchers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/samplers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/search/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/stores/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/architectures/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/visualization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/retrieval_metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/training_metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/classification_metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/supervised/requirements.txt:
--------------------------------------------------------------------------------
1 | keras-tuner
2 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/augmentation_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import utils # noqa
2 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203
3 | exclude = tmp.py, tests/
4 | max-line-length = 120
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | profile = "black"
3 |
4 | [tool.black]
5 | line-length = 119
6 |
--------------------------------------------------------------------------------
/api/TFSimilarity/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow Similarity
2 |
3 |
4 | ## Primary symbols
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = True
3 | warn_return_any = True
4 | warn_unused_configs = True
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto eol=lf
2 | *.{cmd,[cC][mM][dD]} text eol=crlf
3 | *.{bat,[bB][aA][tT]} text eol=crlf
--------------------------------------------------------------------------------
/api/TFSimilarity/api_report.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/api/TFSimilarity/api_report.pb
--------------------------------------------------------------------------------
/assets/images/tfsim_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/tfsim_overview.png
--------------------------------------------------------------------------------
/assets/images/contrastive_batch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_batch.png
--------------------------------------------------------------------------------
/assets/images/similar-cats-and-dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/similar-cats-and-dogs.jpg
--------------------------------------------------------------------------------
/tests/data/mnist_fashion_embeddings.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/tests/data/mnist_fashion_embeddings.npz
--------------------------------------------------------------------------------
/assets/images/contrastive_model_terms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_model_terms.png
--------------------------------------------------------------------------------
/assets/images/self_supervised_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/self_supervised_overview.png
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | CURRRENT_DIR = Path(__file__).parent.absolute()
4 | DATA_DIR = CURRRENT_DIR / "data/"
5 |
--------------------------------------------------------------------------------
/assets/images/contrastive_loss_functions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/contrastive_loss_functions.png
--------------------------------------------------------------------------------
/assets/images/similarity-learning-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/similarity/HEAD/assets/images/similarity-learning-overview.png
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | log_cli = 1
3 | log_level = INFO
4 | filterwarnings =
5 | ignore::DeprecationWarning
6 | ignore::PendingDeprecationWarning
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # Since git version 2.23, git-blame has a feature to ignore any commit in this file.
2 |
3 | # Migrate code to black + isort
4 | 2821d07e7248c412384734b1592b94140754897d
5 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from termcolor import cprint
3 |
4 |
5 | def pytest_configure(config):
6 | tf.config.set_visible_devices([], "GPU")
7 | cprint("Tensorflow set to CPU", "green")
8 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Utility script
2 |
3 |
4 | ## Generate API documentation
5 |
6 | From the `scripts/` directory:
7 |
8 | ```
9 | pip install git+https://github.com/tensorflow/docs
10 | python gen_api_doc.py
11 | ```
--------------------------------------------------------------------------------
/api/TFSimilarity/distances/ABC.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.distances.ABC
2 |
3 |
4 |
5 |
6 |
7 |
8 | Helper class that provides a standard way to create an ABC using
9 |
10 | inheritance.
11 |
12 |
--------------------------------------------------------------------------------
/api/TFSimilarity/indexer/PandasDataFrame.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.indexer.PandasDataFrame
2 |
3 |
4 |
5 |
6 |
7 | Symbolic pandas frame
8 |
9 | Pandas types are too loose (you get an Any); we want a DataFrame.
10 |
11 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/avg_neg.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.avg_neg
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.avg_neg(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/avg_pos.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.avg_pos
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.avg_pos(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/max_neg.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.max_neg
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.max_neg(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/max_pos.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.max_pos
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.max_pos(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/min_neg.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.min_neg
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.min_neg(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/min_pos.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.min_pos
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.min_pos(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/sum_neg.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.sum_neg
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.sum_neg(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/sum_pos.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.sum_pos
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.sum_pos(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/dist_gap.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.dist_gap
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ```python
10 | TFSimilarity.training_metrics.dist_gap(
11 | distance
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/utils/tf_cap_memory.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.utils.tf_cap_memory
2 |
3 |
4 |
5 |
6 |
7 | Prevent TF from hogging memory before it is needed
8 |
9 | ```python
10 | TFSimilarity.utils.tf_cap_memory()
11 | ```
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/api/TFSimilarity/utils/is_tensor_or_variable.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.utils.is_tensor_or_variable
2 |
3 |
4 |
5 |
6 |
7 | Check if a variable is a tf.Tensor or tf.Variable
8 |
9 | ```python
10 | TFSimilarity.utils.is_tensor_or_variable(
11 | x
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 22.6.0
4 | hooks:
5 | - id: black
6 | stages: ['commit']
7 |
8 | - repo: https://github.com/pycqa/isort
9 | rev: 5.12.0
10 | hooks:
11 | - id: isort
12 | name: isort (python)
13 | stages: ['commit']
14 |
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks/unpack_lookup_labels.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.callbacks.unpack_lookup_labels
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | ```python
11 | TFSimilarity.callbacks.unpack_lookup_labels(
12 | lookups: Sequence[Sequence[Lookup]], dtype: Union[str, tf.DType])
13 | ```
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/__init__.py:
--------------------------------------------------------------------------------
1 | from . import metrics # noqa
2 | from . import utils # noqa
3 | from .architectures import make_architecture # noqa
4 | from .augmentations import make_augmentations # noqa
5 | from .experiments import make_experiments # noqa
6 | from .losses import make_loss # noqa
7 | from .optimizers import make_optimizer # noqa
8 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = tensorflow_similarity
3 |
4 | omit =
5 | # omit anything in an api directory anywhere
6 | */api/*
7 | # omit all __init__.py files
8 | */__init__.py
9 |
10 | [report]
11 | omit =
12 | # omit anything in an api directory anywhere
13 | */api/*
14 | # omit all __init__.py files
15 | */__init__.py
16 |
--------------------------------------------------------------------------------
/api/TFSimilarity/indexer/time.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.indexer.time
2 |
3 |
4 |
5 |
6 |
7 |
8 | time() -> floating point number
9 |
10 | ```python
11 | TFSimilarity.indexer.time()
12 | ```
13 |
14 |
15 |
16 |
17 |
18 | Return the current time in seconds since the Epoch.
19 | Fractions of a second may be present if the system clock provides them.
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks/unpack_lookup_distances.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.callbacks.unpack_lookup_distances
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | ```python
11 | TFSimilarity.callbacks.unpack_lookup_distances(
12 | lookups: Sequence[Sequence[Lookup]],
13 | dtype: Union[str, tf.DType],
14 | distance_rounding: Optional[int] = None)
15 | ```
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/api/TFSimilarity/layers/ActivationStdLoggingLayer.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.layers.ActivationStdLoggingLayer
2 |
3 |
4 |
5 |
6 |
7 | Computes the mean std of the activations of a layer.
8 |
9 | ```python
10 | TFSimilarity.layers.ActivationStdLoggingLayer(
11 | name, **kwargs
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 |
19 | x = reduce_std(l2_normalize(inputs, axis=0), axis=-1)
20 |
21 | And then aggregate the per-batch mean of x over each epoch.
22 |
23 |
--------------------------------------------------------------------------------
/benchmark/supervised/README.md:
--------------------------------------------------------------------------------
1 | # Similarity Metric Benchmarking System
2 |
3 | ## Current results
4 |
5 | ## Running the benchmark
6 |
7 | 1. Generate initial training data with preprocessing: `python create_datasets.py -c config/benchmark_prod.json -f "cars.*lamb.*"`
8 |
9 | 2. Train and Hypertune the model parameters: `python hyper_parameter_search.py -c config/benchmark_prod.json -f "cars.*lamb.*"`
10 |
11 | 3. Evaluate performance:
12 |
13 | 4. Analyze results using notebook:
14 |
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks/unpack_results.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.callbacks.unpack_results
2 |
3 |
4 |
5 |
6 |
7 | Updates logs, writes summary, and returns list of strings of
8 |
9 |
10 | ```python
11 | TFSimilarity.callbacks.unpack_results(
12 | results: Dict[str, np.ndarray],
13 | epoch: int,
14 | logs: dict,
15 | tb_writer: tf.summary.SummaryWriter,
16 | name_suffix: Optional[str] =
17 | ) -> List[str]
18 | ```
19 |
20 |
21 |
22 |
23 | evaluation metric
--------------------------------------------------------------------------------
/api/TFSimilarity/models.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.models
2 |
3 |
4 |
5 |
6 |
7 | Specialized `keras.model.Model` subclasses that offer additional
8 |
9 | functionalities to make training and serving similarity models quick and easy.
10 |
11 | ## Classes
12 |
13 | - [`class ContrastiveModel`](../TFSimilarity/models/ContrastiveModel.md): `Model` groups layers into an object with training and inference features.
14 |
15 | - [`class SimilarityModel`](../TFSimilarity/callbacks/SimilarityModel.md): Specialized Keras.Model which implements the core features needed for
16 |
17 |
--------------------------------------------------------------------------------
/api/TFSimilarity/architectures.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.architectures
2 |
3 |
4 |
5 |
6 |
7 | Deep-learning model architectures tweaked to work well for similarity
8 |
9 | training
10 |
11 | ## Functions
12 |
13 | - [`EfficientNetSim(...)`](../TFSimilarity/architectures/EfficientNetSim.md): Build an EfficientNet Model backbone for similarity learning
14 |
15 | - [`ResNet18Sim(...)`](../TFSimilarity/architectures/ResNet18Sim.md): Build a ResNet18 Model backbone for similarity learning
16 |
17 | - [`ResNet50Sim(...)`](../TFSimilarity/architectures/ResNet50Sim.md): Build a ResNet50 Model backbone for similarity learning
18 |
19 |
--------------------------------------------------------------------------------
/tests/matchers/test_match_nearest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.matchers import MatchNearest
5 |
6 |
7 | def test_predict():
8 | mn = MatchNearest()
9 |
10 | lookup_labels = tf.constant([[10, 12], [20, 13], [30, 14], [40, 15]])
11 | lookup_distances = tf.constant([[1.0, 1.5], [1.0, 1.7], [2.0, 2.1], [2.0, 2.2]])
12 |
13 | d_labels, d_dist = mn.derive_match(lookup_labels, lookup_distances)
14 |
15 | np.testing.assert_array_equal(d_labels.numpy(), np.array([[10], [20], [30], [40]]))
16 | np.testing.assert_array_equal(d_dist.numpy(), np.array([[1.0], [1.0], [2.0], [2.0]]))
17 |
--------------------------------------------------------------------------------
/api/TFSimilarity/augmenters.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.augmenters
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class Augmenter`](../TFSimilarity/augmenters/Augmenter.md): Helper class that provides a standard way to create an ABC using
14 |
15 | - [`class BarlowAugmenter`](../TFSimilarity/augmenters/BarlowAugmenter.md): Helper class that provides a standard way to create an ABC using
16 |
17 | - [`class ContrastiveAugmenter`](../TFSimilarity/augmenters/ContrastiveAugmenter.md): Helper class that provides a standard way to create an ABC using
18 |
19 | - [`class SimCLRAugmenter`](../TFSimilarity/augmenters/SimCLRAugmenter.md): SimCLR augmentation pipeline as defined in
20 |
21 |
--------------------------------------------------------------------------------
/api/TFSimilarity/layers/MetricEmbedding.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.layers.MetricEmbedding
2 |
3 |
4 |
5 |
6 |
7 | L2 Normalized `Dense` layer.
8 |
9 | ```python
10 | TFSimilarity.layers.MetricEmbedding(
11 | units,
12 | activation=None,
13 | use_bias=True,
14 | kernel_initializer=glorot_uniform,
15 | bias_initializer=zeros,
16 | kernel_regularizer=None,
17 | bias_regularizer=None,
18 | activity_regularizer=None,
19 | kernel_constraint=None,
20 | bias_constraint=None,
21 | **kwargs
22 | )
23 | ```
24 |
25 |
26 |
27 |
28 |
29 | This layer is usually used as output layer, especially when using cosine
30 | distance as the similarity metric.
31 |
32 |
--------------------------------------------------------------------------------
/api/TFSimilarity/visualization.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.visualization
2 |
3 |
4 |
5 |
6 |
7 | Collection of specialized notebook visualization tools
8 |
9 |
10 |
11 | ## Functions
12 |
13 | - [`confusion_matrix(...)`](../TFSimilarity/visualization/confusion_matrix.md): Plot confusion matrix
14 |
15 | - [`projector(...)`](../TFSimilarity/visualization/projector.md): Visualize the embeddings in 2D or 3D using UMAP projection
16 |
17 | - [`visualize_views(...)`](../TFSimilarity/visualization/visualize_views.md): Display side by side different image views with labels, and predictions
18 |
19 | - [`viz_neigbors_imgs(...)`](../TFSimilarity/visualization/viz_neigbors_imgs.md): Display images' nearest neighbors
20 |
21 |
--------------------------------------------------------------------------------
/tests/matchers/test_majority_vote.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.matchers import MatchMajorityVote
5 |
6 |
7 | def test_predict():
8 | mn = MatchMajorityVote()
9 |
10 | lookup_labels = tf.constant(
11 | [
12 | [10, 12, 10, 12, 10],
13 | # Ties should take the closer label.
14 | [20, 13, 13, 20, 30],
15 | ]
16 | )
17 | lookup_distances = tf.constant([[1.0, 1.1, 1.2, 1.3, 1.4], [2.0, 2.1, 2.2, 2.3, 2.4]])
18 |
19 | d_labels, d_dist = mn.derive_match(lookup_labels, lookup_distances)
20 |
21 | np.testing.assert_array_equal(d_labels.numpy(), np.array([[10], [20]]))
22 | np.testing.assert_allclose(d_dist.numpy(), np.array([[1.2], [2.2]]))
23 |
--------------------------------------------------------------------------------
/tensorflow_similarity/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | __version__ = "0.18.0.dev12"
15 |
16 |
17 | from . import models # noqa
18 | from . import samplers # noqa
19 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .augmenter import Augmenter # noqa
15 | from .barlow import BarlowAugmenter # noqa
16 | from .contrastive import ContrastiveAugmenter # noqa
17 | from .simclr import SimCLRAugmenter # noqa
18 |
--------------------------------------------------------------------------------
/api/TFSimilarity/matchers.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.matchers
2 |
3 |
4 |
5 |
6 |
7 | Matchers define the classification matching strategy when using
8 |
9 | similarity models to match query examples to the classes of indexed
10 | examples.
11 |
12 | ## Classes
13 |
14 | - [`class ClassificationMatch`](../TFSimilarity/callbacks/ClassificationMatch.md): Abstract base class for defining the classification matching strategy.
15 |
16 | - [`class MatchMajorityVote`](../TFSimilarity/matchers/MatchMajorityVote.md): Match metrics for the most common label in a result set.
17 |
18 | - [`class MatchNearest`](../TFSimilarity/matchers/MatchNearest.md): Match metrics for labels at k=1.
19 |
20 | ## Functions
21 |
22 | - [`make_classification_matcher(...)`](../TFSimilarity/indexer/make_classification_matcher.md): Convert classification matcher from str name to object if needed.
23 |
24 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics/batch_class_ratio.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.training_metrics.batch_class_ratio
2 |
3 |
4 |
5 |
6 |
7 | Computes the average number of examples per class within each batch.
8 |
9 | ```python
10 | TFSimilarity.training_metrics.batch_class_ratio(
11 | sampler: Sampler, num_batches: int = 100
12 | ) -> float
13 | ```
14 |
15 |
16 |
17 |
18 | Similarity learning requires at least 2 examples per class in each batch.
19 | This is needed in order to construct the triplets. This function
20 | provides the average number of examples per class within each batch and
21 | can be used to check that a sampler is working correctly.
22 | The ratio should be >= 2.
23 | Args:
24 | sampler: A tf.similarity sampler object.
25 | num_batches: The number of batches to sample.
26 | Returns:
27 | The average number of examples per class.
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | ### TensorFlow Similarity examples
2 |
3 |
4 | | Tutorial | Learning Type | Description |
5 | | ------ | :-----: | ---------- |
6 | | [Hello World](./supervised_hello_world.ipynb) | Supervised | Train and use an image similarity model to find similar looking MNIST digits |
7 | | [Self-Supervised Learning](./unsupervised_hello_world.ipynb) | Unsupervised | Train an image model using SimSiam-based self-supervised contrastive learning. |
8 | | [visualization](./visualization.ipynb) | Supervised | Train an image similarity model on the Stanford Dogs dataset using Evaluation Callbacks and the interactive visualizer |
9 | | [Sampler IO Cookbook](./sampler_io_cookbook.ipynb) | Utils | Examples demonstrating how to use the various in memory batch samplers. |
10 | | [CLIP finetuning](./multimodal_example.ipynb) | Supervised | Finetune CLIP on atric-dataset using multiple negatives ranking loss. |
11 |
--------------------------------------------------------------------------------
/tests/samplers/test_tfdataset_samplers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
4 |
5 |
class DatasetSamplersTest(tf.test.TestCase):
    """Smoke tests for `TFDatasetMultiShotMemorySampler` built from the "mnist" dataset."""

    def test_basic(self):
        """A generated batch has 20 examples of shape (28, 28, 1) for 10 classes per batch."""
        sampler = TFDatasetMultiShotMemorySampler(dataset_name="mnist", classes_per_batch=10)
        # NOTE(review): 20 presumably comes from 10 classes x a default of 2
        # examples per class -- confirm against the sampler defaults.
        examples = sampler.generate_batch(42)[0]
        self.assertEqual(examples.shape, (20, 28, 28, 1))

    def test_wrong_key(self):
        """An unknown `x_key` or `y_key` makes the constructor raise ValueError."""
        # Checked in the same order as before: the X key first, then the Y key.
        for key_arg in ("x_key", "y_key"):
            with self.assertRaises(ValueError):
                TFDatasetMultiShotMemorySampler(
                    dataset_name="mnist",
                    classes_per_batch=4,
                    **{key_arg: "error"},
                )
22 |
--------------------------------------------------------------------------------
/tensorflow_similarity/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Deep-learning model architectures tweaked to work well for similarity
15 | training"""
16 | from .efficientnet import EfficientNetSim # noqa
17 | from .resnet18 import ResNet18Sim # noqa
18 | from .resnet50 import ResNet50Sim # noqa
19 |
--------------------------------------------------------------------------------
/tests/retrieval_metrics/test_recall_at_k.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.retrieval_metrics import RecallAtK
5 |
# (averaging mode, expected recall@2) for the fixture built in `test_compute`.
testdata = [
    ("micro", tf.constant(0.75)),
    ("macro", tf.constant(0.5)),
]


@pytest.mark.parametrize("avg, expected", testdata, ids=["micro", "macro"])
def test_compute(avg, expected):
    """RecallAtK(k=2) returns the expected value for each averaging mode.

    Three of the four queries have a match within their first two results and
    the single label-0 query has none; the expected values are those listed
    in `testdata`.
    """
    labels = tf.constant([1, 1, 1, 0])
    # One row per query, one column per ranked result.
    hits = tf.constant(
        [
            [True, True, False],
            [True, True, False],
            [True, True, False],
            [False, False, True],
        ],
        dtype=bool,
    )

    metric = RecallAtK(k=2, average=avg)

    assert metric.compute(query_labels=labels, match_mask=hits) == expected
34 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmentation_utils/test_solarize.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters.augmentation_utils import solarize
5 |
6 |
def create_img(width=32, height=32, channels=3):
    """Return a random image tensor of shape [width, height, channels] drawn uniformly between 0 and 1."""
    return tf.random.uniform(shape=[width, height, channels], minval=0, maxval=1)
9 |
10 |
def test_solarization():
    """`random_solarize` changes the image when p=1 and leaves it untouched when p=0."""
    img = create_img()
    common = dict(thresh=0, pixel_min=0, pixel_max=1)

    solarized = solarize.random_solarize(img, p=1, **common)
    untouched = solarize.random_solarize(img, p=0, **common)

    # The output shape must match the input regardless of p.
    for result in (solarized, untouched):
        assert (tf.shape(result) == tf.shape(img)).numpy().all()

    # Solarization is applied with p=1 and skipped with p=0.
    assert not (solarized == img).numpy().all()
    assert (untouched == img).numpy().all()
24 |
--------------------------------------------------------------------------------
/api/TFSimilarity/retrieval_metrics.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.retrieval_metrics
2 |
3 |
4 |
5 |
6 |
7 | Retrieval metrics measure the quality of the embedding space given a
8 |
9 | set of query examples and a set of indexed examples. Informally it can be thought
10 | of as how well the space is clustered among other things.
11 |
12 | ## Classes
13 |
14 | - [`class BNDCG`](../TFSimilarity/retrieval_metrics/BNDCG.md): Binary normalized discounted cumulative gain.
15 |
16 | - [`class MapAtK`](../TFSimilarity/retrieval_metrics/MapAtK.md): Mean Average precision - mAP@K is computed as.
17 |
18 | - [`class PrecisionAtK`](../TFSimilarity/retrieval_metrics/PrecisionAtK.md): Precision@K is computed as.
19 |
20 | - [`class RecallAtK`](../TFSimilarity/retrieval_metrics/RecallAtK.md): The metric learning version of Recall@K.
21 |
22 | - [`class RetrievalMetric`](../TFSimilarity/indexer/RetrievalMetric.md): Abstract base class for computing retrieval metrics.
23 |
24 |
--------------------------------------------------------------------------------
/tests/retrieval_metrics/test_precision_at_k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import tensorflow as tf
4 |
5 | from tensorflow_similarity.retrieval_metrics import PrecisionAtK
6 |
# (averaging mode, expected precision@3) for the fixture built in `test_compute`.
testdata = [
    ("micro", tf.constant(0.583333333)),
    ("macro", tf.constant(0.5)),
]


@pytest.mark.parametrize("avg, expected", testdata, ids=["micro", "macro"])
def test_compute(avg, expected):
    """PrecisionAtK(k=3) matches the expected value within an absolute tolerance of 1e-05.

    The label-1 queries each have two matches in their three results and the
    label-0 query has one; the expected values are those listed in `testdata`.
    """
    labels = tf.constant([1, 1, 1, 0])
    # One row per query, one column per ranked result.
    hits = tf.constant(
        [
            [True, True, False],
            [True, True, False],
            [True, True, False],
            [False, False, True],
        ],
        dtype=bool,
    )

    metric = PrecisionAtK(k=3, average=avg)
    observed = metric.compute(query_labels=labels, match_mask=hits)

    np.testing.assert_allclose(observed, expected, atol=1e-05)
35 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmenters.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters import BarlowAugmenter, SimCLRAugmenter
5 |
6 |
def create_imgs(width=32, height=32, channels=3, num=5):
    """Return a random batch of shape [num, width, height, channels] drawn uniformly between 0 and 1."""
    batch_shape = [num, width, height, channels]
    return tf.random.uniform(batch_shape, 0, 1)
9 |
10 |
def test_barlow():
    """`BarlowAugmenter.augment` output has shape [2, NUM, WIDTH, HEIGHT, CHANNELS]."""
    WIDTH, HEIGHT, CHANNELS, NUM = 32, 32, 3, 5
    # These values equal the defaults of `create_imgs`, so the batch is unchanged.
    imgs = create_imgs(width=WIDTH, height=HEIGHT, channels=CHANNELS, num=NUM)

    views = BarlowAugmenter(WIDTH, HEIGHT).augment(imgs)

    expected_shape = tf.constant([2, NUM, WIDTH, HEIGHT, CHANNELS])
    assert (tf.shape(views) == expected_shape).numpy().all()
22 |
23 |
def test_simclr():
    """`SimCLRAugmenter.augment` output has shape [2, NUM, WIDTH, HEIGHT, CHANNELS]."""
    WIDTH, HEIGHT, CHANNELS, NUM = 32, 32, 3, 5
    # These values equal the defaults of `create_imgs`, so the batch is unchanged.
    imgs = create_imgs(width=WIDTH, height=HEIGHT, channels=CHANNELS, num=NUM)

    augmenter = SimCLRAugmenter(HEIGHT, WIDTH)
    # Positional arguments follow Augmenter.augment: x, y,
    # num_augmentations_per_example, is_warmup.
    views = augmenter.augment(imgs, tf.constant([0]), 2, True)

    expected_shape = tf.constant([2, NUM, WIDTH, HEIGHT, CHANNELS])
    assert (tf.shape(views) == expected_shape).numpy().all()
35 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# Builds the sdist and wheel and uploads them to PyPI. Manual trigger only.
name: Publish

on:
  workflow_dispatch:  # Allow manual triggers

jobs:
  publish:
    # Prevent Publish from running on forks.
    if: |
      github.ref == 'refs/heads/master' &&
      github.repository == 'tensorflow/similarity'

    runs-on: ubuntu-latest

    steps:
      # v2 of these actions runs on a deprecated Node runtime; use current majors.
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip

      - name: Install package
        run: |
          pip install ".[tensorflow,dev]"

      - name: Build package
        run: |
          python setup.py sdist bdist_wheel

      - name: Publish package
        # Credentials go through the environment variables twine reads, so they
        # never appear on the command line. The secret names are unchanged and
        # must match the repository secrets (note the spelling PYPY_TOKEN).
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPY_TOKEN }}
        run: |
          twine upload --verbose dist/*
37 |
--------------------------------------------------------------------------------
/api/TFSimilarity/stores.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.stores
2 |
3 |
4 |
5 |
6 |
7 | Key Values Stores store the data associated with the embeddings indexed by
8 |
9 | the `Indexer()`.
10 |
11 | Each key of the store represents a **record** that contains information
12 | about a given embedding.
13 |
14 | The main use-case for the store is to retrieve the records associated
15 | with the ids returned by a nearest neighbor search performed with the
16 | [`Search()`](../search/).
17 |
18 | Additionally one might want to inspect the content of the index which is why
19 | `Store()` class may implement an export to
20 | a [Pandas Dataframe](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html)
21 | via the `to_pandas()` method.
22 |
23 | ## Classes
24 |
25 | - [`class MemoryStore`](../TFSimilarity/indexer/MemoryStore.md): Efficient in-memory dataset store
26 |
27 | - [`class Store`](../TFSimilarity/indexer/Store.md): Helper class that provides a standard way to create an ABC using inheritance.
28 |
29 |
--------------------------------------------------------------------------------
/tensorflow_similarity/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Specialized `keras.model.Model` subclasses that offer additional
15 | functionalities to make training and serving similarity models quick and easy.
16 | """
17 |
18 | from .contrastive_model import ContrastiveModel # noqa
19 | from .contrastive_model import create_contrastive_model # noqa
20 | from .similarity_model import SimilarityModel # noqa
21 |
--------------------------------------------------------------------------------
/api/TFSimilarity/distances/distance_canonicalizer.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.distances.distance_canonicalizer
2 |
3 |
4 |
5 |
6 |
7 | Normalize user requested distance to its matching Distance object.
8 |
9 |
10 | ```python
11 | TFSimilarity.distances.distance_canonicalizer(user_distance)
12 | ```
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | Args |
23 |
24 |
25 | |
26 | user_distance
27 | |
28 |
29 | Requested distance either by name or by object
30 | |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 | Returns |
40 |
41 |
42 | |
43 | Distance
44 | |
45 |
46 | Requested object name.
47 | |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/api/TFSimilarity/layers.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.layers
2 |
3 |
4 |
5 |
6 |
7 | Specialized Similarity `keras.layers`
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class ActivationStdLoggingLayer`](../TFSimilarity/layers/ActivationStdLoggingLayer.md): Computes the mean std of the activations of a layer.
14 |
15 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor
16 |
17 | - [`class GeneralizedMeanPooling`](../TFSimilarity/layers/GeneralizedMeanPooling.md): This is the class from which all layers inherit.
18 |
19 | - [`class GeneralizedMeanPooling1D`](../TFSimilarity/layers/GeneralizedMeanPooling1D.md): Computes the Generalized Mean of each channel in a tensor.
20 |
21 | - [`class GeneralizedMeanPooling2D`](../TFSimilarity/layers/GeneralizedMeanPooling2D.md): Computes the Generalized Mean of each channel in a tensor.
22 |
23 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor
24 |
25 | - [`class MetricEmbedding`](../TFSimilarity/layers/MetricEmbedding.md): L2 Normalized `Dense` layer.
26 |
27 |
--------------------------------------------------------------------------------
/tensorflow_similarity/matchers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Matchers define the classification matching strategy when using
15 | similarity models to match query examples to the classes of indexed
16 | examples."""
17 | # Match Types
18 | from .classification_match import ClassificationMatch # noqa
19 | from .match_majority_vote import MatchMajorityVote # noqa
20 | from .match_nearest import MatchNearest # noqa
21 | from .utils import make_classification_matcher # noqa
22 |
--------------------------------------------------------------------------------
/tests/losses/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
def generate_perfect_test_batch(batch_size=32):
    """Build labels and embeddings in which every example's embedding is the one-hot of its own label.

    Each of the `batch_size` classes appears twice, so the batch holds
    `2 * batch_size` examples.

    Args:
        batch_size: Number of distinct classes, and the one-hot depth.

    Returns:
        A `(y_true, y_preds)` tuple: int32 labels with a trailing axis of
        size 1, and float32 one-hot embeddings of depth `batch_size`.
    """
    class_ids = tf.range(0, batch_size, dtype=tf.int32)
    labels = tf.concat([class_ids, class_ids], axis=0)

    # The embedding of an example is simply the one-hot encoding of its label.
    embeddings = tf.one_hot(labels, depth=batch_size, dtype=tf.float32)

    return tf.expand_dims(labels, axis=1), embeddings
16 |
17 |
def generate_bad_test_batch(batch_size=32):
    """Build labels and embeddings in which the second copy of each class gets another class's embedding.

    The first `batch_size` examples embed as the one-hot of their own label;
    the second `batch_size` examples embed as the one-hot of the class ids in
    reverse order.

    Args:
        batch_size: Number of distinct classes, and the one-hot depth.

    Returns:
        A `(y_true, y_preds)` tuple: int32 labels with a trailing axis of
        size 1, and float32 one-hot embeddings of depth `batch_size`.
    """
    class_ids = tf.range(0, batch_size, dtype=tf.int32)

    # NOTE: for an odd batch_size the middle class id is unchanged by the
    # reversal, so that one class still matches itself.
    embedding_ids = tf.concat([class_ids, class_ids[::-1]], axis=0)
    embeddings = tf.one_hot(embedding_ids, depth=batch_size, dtype=tf.float32)

    labels = tf.concat([class_ids, class_ids], axis=0)
    return tf.expand_dims(labels, axis=1), embeddings
31 |
--------------------------------------------------------------------------------
/tensorflow_similarity/retrieval_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Retrieval metrics measure the quality of the embedding space given a
15 | set of query examples and a set of indexed examples. Informally it can be thought
16 | of as how well the space is clustered among other things.
17 | """
18 | from .bndcg import BNDCG # noqa
19 | from .map_at_k import MapAtK # noqa
20 | from .precision_at_k import PrecisionAtK # noqa
21 | from .recall_at_k import RecallAtK # noqa
22 | from .retrieval_metric import RetrievalMetric # noqa
23 |
--------------------------------------------------------------------------------
/api/TFSimilarity/utils.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.utils
2 |
3 |
4 |
5 |
6 |
7 | Various utility functions for improved quality of life.
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class BoolTensor`](../TFSimilarity/utils/BoolTensor.md): Bool tensor
14 |
15 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor
16 |
17 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor
18 |
19 | - [`class Lookup`](../TFSimilarity/indexer/Lookup.md): Metadata associated with a query match.
20 |
21 | ## Functions
22 |
23 | - [`is_tensor_or_variable(...)`](../TFSimilarity/utils/is_tensor_or_variable.md): check if a variable is tf.Tensor or tf.Variable
24 |
25 | - [`tf_cap_memory(...)`](../TFSimilarity/utils/tf_cap_memory.md): Avoid TF to hog memory before needing it
26 |
27 | - [`unpack_lookup_distances(...)`](../TFSimilarity/callbacks/unpack_lookup_distances.md)
28 |
29 | - [`unpack_lookup_labels(...)`](../TFSimilarity/callbacks/unpack_lookup_labels.md)
30 |
31 | - [`unpack_results(...)`](../TFSimilarity/callbacks/unpack_results.md): Updates logs, writes summary, and returns list of strings of
32 |
33 |
--------------------------------------------------------------------------------
/tensorflow_similarity/classification_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classification metrics measure matching classification quality between a
15 | set of query examples and a set of indexed examples. """
16 | from .binary_accuracy import BinaryAccuracy # noqa
17 | from .classification_metric import ClassificationMetric # noqa
18 | from .f1_score import F1Score # noqa
19 | from .false_positive_rate import FalsePositiveRate # noqa
20 | from .negative_predictive_value import NegativePredictiveValue # noqa
21 | from .precision import Precision # noqa
22 | from .recall import Recall # noqa
23 | from .utils import make_classification_metric # noqa
24 |
--------------------------------------------------------------------------------
/api/TFSimilarity/losses.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.losses
2 |
3 |
4 |
5 |
6 |
7 | Contrastive learning specialized losses.
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class Barlow`](../TFSimilarity/losses/Barlow.md): Computes the Barlow Loss between two batches of embeddings.
14 |
15 | - [`class CircleLoss`](../TFSimilarity/losses/CircleLoss.md): Computes the CircleLoss
16 |
17 | - [`class MetricLoss`](../TFSimilarity/losses/MetricLoss.md): Wraps a loss function in the `Loss` class.
18 |
19 | - [`class MultiSimilarityLoss`](../TFSimilarity/losses/MultiSimilarityLoss.md): Computes the multi similarity loss in an online fashion.
20 |
21 | - [`class PNLoss`](../TFSimilarity/losses/PNLoss.md): Computes the PN loss in an online fashion.
22 |
23 | - [`class SimCLRLoss`](../TFSimilarity/losses/SimCLRLoss.md): SimCLR Loss
24 |
25 | - [`class SimSiamLoss`](../TFSimilarity/losses/SimSiamLoss.md): SimSiam Loss
26 |
27 | - [`class SoftNearestNeighborLoss`](../TFSimilarity/losses/SoftNearestNeighborLoss.md): Computes the soft nearest neighbors loss in an online fashion.
28 |
29 | - [`class TripletLoss`](../TFSimilarity/losses/TripletLoss.md): Computes the triplet loss in an online fashion.
30 |
31 | - [`class VicReg`](../TFSimilarity/losses/VicReg.md): VicReg Loss
32 |
33 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmentation_utils/test_crop.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters.augmentation_utils import cropping
5 |
6 |
def create_img(width=32, height=32, channels=3):
    """Return a random image tensor of shape [width, height, channels] drawn uniformly between 0 and 1."""
    return tf.random.uniform(shape=[width, height, channels], minval=0, maxval=1)
9 |
10 |
def test_center_cropping():
    """`center_crop` with a 0.5 crop proportion returns an image of the input's shape."""
    WIDTH, HEIGHT, CHANNELS = 32, 32, 3
    # These values equal the defaults of `create_img`, so the image is unchanged.
    img = create_img(width=WIDTH, height=HEIGHT, channels=CHANNELS)

    cropped = cropping.center_crop(img, HEIGHT, WIDTH, 0.5)

    assert (tf.shape(cropped) == tf.shape(img)).numpy().all()
20 |
21 |
def test_random_cropping():
    """`random_crop_with_resize` alters the image when its last argument is 1 and not when it is 0."""
    WIDTH, HEIGHT, CHANNELS = 32, 32, 3
    # These values equal the defaults of `create_img`, so the image is unchanged.
    img = create_img(width=WIDTH, height=HEIGHT, channels=CHANNELS)

    # NOTE(review): the last positional argument looks like the probability of
    # applying the crop -- confirm against `random_crop_with_resize`.
    cropped = cropping.random_crop_with_resize(img, HEIGHT, WIDTH, 1)
    untouched = cropping.random_crop_with_resize(img, HEIGHT, WIDTH, 0)

    # The output shape must match the input in both cases.
    for result in (cropped, untouched):
        assert (tf.shape(result) == tf.shape(img)).numpy().all()

    # The crop changes pixel values only when it is applied.
    assert not (cropped == img).numpy().all()
    assert (untouched == img).numpy().all()
40 |
--------------------------------------------------------------------------------
/api/TFSimilarity/training_metrics.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.training_metrics
2 |
3 |
4 |
5 |
6 |
7 | Specialized `keras.metrics` that track how distances evolve during training.
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class DistanceGapMetric`](../TFSimilarity/training_metrics/DistanceGapMetric.md): Encapsulates metric logic and state.
14 |
15 | - [`class DistanceMetric`](../TFSimilarity/training_metrics/DistanceMetric.md): Encapsulates metric logic and state.
16 |
17 | ## Functions
18 |
19 | - [`avg_neg(...)`](../TFSimilarity/training_metrics/avg_neg.md)
20 |
21 | - [`avg_pos(...)`](../TFSimilarity/training_metrics/avg_pos.md)
22 |
23 | - [`batch_class_ratio(...)`](../TFSimilarity/training_metrics/batch_class_ratio.md): Computes the average number of examples per class within each batch.
24 |
25 | - [`dist_gap(...)`](../TFSimilarity/training_metrics/dist_gap.md)
26 |
27 | - [`max_neg(...)`](../TFSimilarity/training_metrics/max_neg.md)
28 |
29 | - [`max_pos(...)`](../TFSimilarity/training_metrics/max_pos.md)
30 |
31 | - [`min_neg(...)`](../TFSimilarity/training_metrics/min_neg.md)
32 |
33 | - [`min_pos(...)`](../TFSimilarity/training_metrics/min_pos.md)
34 |
35 | - [`sum_neg(...)`](../TFSimilarity/training_metrics/sum_neg.md)
36 |
37 | - [`sum_pos(...)`](../TFSimilarity/training_metrics/sum_pos.md)
38 |
39 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/augmenter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | import abc
17 |
18 | from ..types import Tensor
19 |
20 |
class Augmenter(abc.ABC):
    """Abstract interface for augmenters.

    Subclasses implement `augment`; calling an instance forwards to it.
    """

    @abc.abstractmethod
    def augment(
        self,
        x: Tensor,
        y: Tensor,
        num_augmentations_per_example: int,
        is_warmup: bool,
    ) -> list[Tensor]:
        """Produce augmented tensors from `x` and `y`; must be overridden.

        Args:
            x: Input tensor to augment.
            y: Tensor passed alongside `x`.
            num_augmentations_per_example: Number of augmentations requested
                for each example.
            is_warmup: Flag forwarded to the implementation.

        Returns:
            A list of tensors.
        """

    def __call__(
        self,
        x: Tensor,
        y: Tensor,
        num_augmentations_per_example: int,
        is_warmup: bool,
    ) -> list[Tensor]:
        """Forward all four arguments positionally to `augment` and return its result."""
        augmented = self.augment(x, y, num_augmentations_per_example, is_warmup)
        return augmented
40 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/augmentations.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Mapping
4 | from functools import partial
5 | from typing import Any
6 |
7 | from tensorflow.image import random_flip_left_right, random_flip_up_down
8 |
9 | from tensorflow_similarity.augmenters.augmentation_utils.cropping import (
10 | center_crop,
11 | crop_and_resize,
12 | )
13 |
# Registry mapping an augmentation id to a factory. Each factory takes a
# params mapping (read through `.get`, so every key is optional) and returns
# the callable that performs the augmentation.
AUGMENTATIONS = {
    "random_resized_crop": lambda p: partial(
        crop_and_resize,
        height=p.get("height", 277),
        width=p.get("width", 277),
        area_range=p.get("area_range", (0.2, 1.0)),
    ),
    # Any mode other than "horizontal" selects the up/down flip.
    "random_flip": lambda p: (
        random_flip_left_right if p.get("mode", "horizontal") == "horizontal" else random_flip_up_down
    ),
    "center_crop": lambda p: partial(
        center_crop,
        height=p.get("height", 256),
        width=p.get("width", 256),
        crop_proportion=p.get("crop_proportion", 1.0),
    ),
}
30 |
31 |
# TODO(ovallis): Return type should be tuple[Callable[[FloatTensor], FloatTensor]], but
# mypy doesn't recognize the return types of the callables.
def make_augmentations(cfg: Mapping[str, Any]) -> tuple[Any, ...]:
    """Build the augmentation callables described by `cfg`.

    Args:
        cfg: Maps an augmentation id (a key of `AUGMENTATIONS`) to the params
            handed to that augmentation's factory.

    Returns:
        One callable per entry of `cfg`, in the iteration order of `cfg`.

    Raises:
        KeyError: If an augmentation id is not registered in `AUGMENTATIONS`.
    """
    return tuple(AUGMENTATIONS[aug_id](params) for aug_id, params in cfg.items())
36 |
--------------------------------------------------------------------------------
/tensorflow_similarity/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Contrastive learning specialized losses.
17 | """
18 | from .barlow import Barlow # noqa
19 | from .circle_loss import CircleLoss # noqa
20 | from .lifted_structure_loss import LiftedStructLoss # noqa
21 | from .metric_loss import MetricLoss # noqa
22 | from .multinegrank_loss import MultiNegativesRankLoss # noqa
23 | from .multisim_loss import MultiSimilarityLoss # noqa
24 | from .pn_loss import PNLoss # noqa
25 | from .simclr import SimCLRLoss # noqa
26 | from .simsiam import SimSiamLoss # noqa
27 | from .softnn_loss import SoftNearestNeighborLoss # noqa
28 | from .triplet_loss import TripletLoss # noqa
29 | from .vicreg import VicReg # noqa
30 | from .xbm_loss import XBM # noqa
31 |
--------------------------------------------------------------------------------
/tensorflow_similarity/training_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | "Specialized `keras.metrics` that track how distances evolve during training."
15 | from .distance_metrics import DistanceGapMetric # noqa
16 | from .distance_metrics import DistanceMetric # noqa
17 | from .distance_metrics import avg_neg # noqa
18 | from .distance_metrics import avg_pos # noqa
19 | from .distance_metrics import dist_gap # noqa
20 | from .distance_metrics import max_neg # noqa
21 | from .distance_metrics import max_pos # noqa
22 | from .distance_metrics import min_neg # noqa
23 | from .distance_metrics import min_pos # noqa
24 | from .distance_metrics import sum_neg # noqa
25 | from .distance_metrics import sum_pos # noqa
26 | from .utils import batch_class_ratio # noqa
27 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/augmentation_utils/random_apply.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from __future__ import annotations
16 |
17 | from collections.abc import Callable
18 |
19 | import tensorflow as tf
20 |
21 | from tensorflow_similarity.types import Tensor
22 |
23 |
def random_apply(func: Callable[[Tensor], Tensor], p: float, x: Tensor) -> Tensor:
    """Apply `func` to `x` with probability `p`; otherwise return `x` untouched.

    Args:
        func: Transformation invoked on `x` when the random draw succeeds.
        p: Probability of applying `func`. Cast to float32 before the comparison.
        x: Value handed to `func`, or returned as-is.

    Returns:
        `func(x)` when a scalar uniform sample is strictly less than `p`,
        else `x`.
    """
    # One scalar draw decides the branch for the whole input.
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    should_apply = tf.less(draw, tf.cast(p, tf.float32))
    return tf.cond(should_apply, lambda: func(x), lambda: x)
34 |
--------------------------------------------------------------------------------
/api/TFSimilarity/indexer/CalibrationResults.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.indexer.CalibrationResults
2 |
3 |
4 |
5 |
6 |
7 | Cutpoints and thresholds associated with a calibration.
8 |
9 | ```python
10 | TFSimilarity.indexer.CalibrationResults(
11 | cutpoints: Mapping[str, Mapping[str, Union[str, float, int]]],
12 | thresholds: Mapping[str, Sequence[Union[float, int]]]
13 | )
14 | ```
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | Attributes |
27 |
28 |
29 | |
30 | cutpoints
31 | |
32 |
33 | A Dict mapping the cutpoint name to a Dict containing the
34 | ClassificationMetric values associated with a particular distance
35 | threshold, e.g., {'optimal': {'acc': 0.90, 'f1': 0.92}}.
36 | |
37 |
38 | |
39 | thresholds
40 | |
41 |
42 | A Dict mapping ClassificationMetric names to a list
43 | containing the metric's value computed at each of the distance
44 | thresholds, e.g., {'f1': [0.99, 0.80], 'distance': [0.0, 1.0]}.
45 | |
46 |
47 |
48 |
49 |
50 |
51 | ## Methods
52 |
53 | __eq__
54 |
55 | ```python
56 | __eq__(
57 | other
58 | )
59 | ```
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/augmentation_utils/solarize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from __future__ import annotations
16 |
17 | import tensorflow as tf
18 |
19 | from tensorflow_similarity.augmenters.augmentation_utils.random_apply import (
20 | random_apply,
21 | )
22 | from tensorflow_similarity.types import Tensor
23 |
24 |
def random_solarize(
    image: Tensor,
    p: float = 0.2,
    pixel_min: int = 0,
    pixel_max: int = 255,
    thresh: int = 10,
) -> Tensor:
    """Solarize `image` with probability `p`.

    Args:
        image: Image tensor to augment.
        p: Probability of applying the solarization.
        pixel_min: Accepted for interface compatibility; not read by this
            implementation.
        pixel_max: Value from which pixels at or above `thresh` are subtracted.
        thresh: Pixels strictly below this value are kept unchanged.

    Returns:
        The solarized image, or `image` unchanged when the transform is not
        applied.
    """

    def _solarize(img: Tensor) -> Tensor:
        # Keep values below the threshold, invert everything else.
        below_thresh = img < thresh
        inverted = pixel_max - img
        return tf.where(below_thresh, img, inverted)

    return random_apply(_solarize, p=p, x=image)
36 |
--------------------------------------------------------------------------------
/tests/integration/test_readme_minimal.py:
--------------------------------------------------------------------------------
1 | "Test that the examples in the README work correctly"
2 | import os
3 |
4 | import pytest
5 | from tensorflow.keras import layers
6 |
7 | from tensorflow_similarity.layers import MetricEmbedding
8 | from tensorflow_similarity.losses import MultiSimilarityLoss
9 | from tensorflow_similarity.models import SimilarityModel
10 | from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
11 |
12 |
@pytest.fixture
def readme_path(request):
    """Return the README.md path, resolved relative to the requesting test module."""
    # Anchor on the test module's real location so the result does not depend
    # on the directory pytest was launched from.
    module_file = os.path.realpath(request.module.__file__)
    module_dir = os.path.dirname(module_file)
    return os.path.join(module_dir, "..", "..", "README.md")
19 |
20 |
def test_readme_text_directly(readme_path):
    """Execute the Python snippets embedded in README.md as one program.

    Lines following a line that starts with ```python, up to (but excluding)
    the next line that ends with ```, are collected and passed to `exec`.
    Calls to the neighbor visualization helper are skipped.

    Args:
        readme_path: Path to the README file (provided by the fixture).
    """
    # The package exports the helper as `viz_neigbors_imgs` (sic); skip that
    # spelling as well as the corrected one so the plotting call never runs.
    skipped_prefixes = ("viz_neigbors_imgs", "viz_neighbors_imgs")
    code = []
    code_block = False

    # Explicit encoding so the result does not depend on the platform locale.
    with open(readme_path, "r", encoding="utf-8") as f:
        for line in f:
            # A closing fence ends the block before the fence line is recorded.
            if line.endswith("```\n"):
                code_block = False

            if code_block and not line.startswith(skipped_prefixes):
                code.append(line)

            # The opening fence is checked last so it is never part of the code.
            if line.startswith("```python"):
                code_block = True

    # Lines keep their trailing newline; joining with "" avoids inserting blank
    # lines that would break backslash continuations and multi-line strings.
    exec("".join(code))
39 |
--------------------------------------------------------------------------------
/tensorflow_similarity/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | "Collection of specialized notebook visualization tools"
15 | import importlib.util
16 |
# Collect one error per missing optional dependency so the user sees the full
# list at once instead of fixing them one import failure at a time.
excs = []
for mod in ["PIL", "umap", "bokeh", "distinctipy", "matplotlib"]:
    # find_spec only locates the package; it does not import it.
    if importlib.util.find_spec(mod) is None:
        excs.append(
            ModuleNotFoundError(
                f"{mod} is not installed. Please install it with `pip install tensorflow_similarity[visualization]`"
            )
        )

if excs:
    # Raise an ImportError subclass (still an Exception) with one readable line
    # per missing package, so `except ImportError` guards around this optional
    # module work and the message is not a list repr.
    raise ModuleNotFoundError("\n".join(str(exc) for exc in excs))
else:
    from .confusion_matrix_viz import confusion_matrix  # noqa
    from .neighbors_viz import viz_neigbors_imgs  # noqa
    from .projector import projector  # noqa
    from .vizualize_views import visualize_views  # noqa
33 |
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks/make_classification_metric.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.callbacks.make_classification_metric
2 |
3 |
4 |
5 |
6 |
7 | Convert classification metric from str name to object if needed.
8 |
9 |
10 | ```python
11 | TFSimilarity.callbacks.make_classification_metric(
12 | name: str =
13 | ```
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | Args |
24 |
25 |
26 | |
27 | metric
28 | |
29 |
30 | ClassificationMetric() or metric name.
31 | |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 | Raises |
41 |
42 |
43 | |
44 | ValueError
45 | |
46 |
47 | Unknown metric name: metric, typo?
48 | |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | Returns |
58 |
59 |
60 | |
61 | ClassificationMetric
62 | |
63 |
64 | Instantiated metric if needed.
65 | |
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmentation_utils/test_flip.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters.augmentation_utils import flip
5 |
6 |
def create_img(width=32, height=32, channels=3):
    """Return a random tensor of shape [width, height, channels] drawn from [0, 1)."""
    shape = [width, height, channels]
    return tf.random.uniform(shape, 0, 1)
9 |
10 |
def test_flip_left_right():
    """The left/right flip keeps the shape and is a no-op when p == 0."""
    img = create_img()

    flipped_always = flip.random_random_flip_left_right(img, 1)
    flipped_never = flip.random_random_flip_left_right(img, 0)

    # The shape must be preserved whether or not the flip is applied.
    for output in (flipped_always, flipped_never):
        assert (tf.shape(output) == tf.shape(img)).numpy().all()

    # Only the p == 0 output is compared with the input; the p == 1 output is
    # deliberately not asserted to differ, since it may equal the input.
    assert (flipped_never == img).numpy().all()
24 |
25 |
def test_flip_top_bottom():
    """The top/bottom flip keeps the shape and is a no-op when p == 0."""
    img = create_img()

    flipped_always = flip.random_random_flip_top_bottom(img, 1)
    flipped_never = flip.random_random_flip_top_bottom(img, 0)

    # The shape must be preserved whether or not the flip is applied.
    for output in (flipped_always, flipped_never):
        assert (tf.shape(output) == tf.shape(img)).numpy().all()

    # No "output differs from input" check for p == 1: it would fail if the
    # flip were symmetric, so only the p == 0 identity is asserted.
    assert (flipped_never == img).numpy().all()
39 |
--------------------------------------------------------------------------------
/api/TFSimilarity/search.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.search
2 |
3 |
4 |
5 |
6 |
7 | Efficiently find nearest indexed embeddings
8 |
9 |
10 | The search is used to find the closest indexed example embeddings
11 | to a query example embedding.
12 | To do so it performs a sub-linear time
13 | - [ANN (Approximate Nearest Neighbors)](https://en.wikipedia.org/wiki/Nearest_neighbor_search)
14 | search on the indexed set of embedding examples.
15 |
16 | Different ANN libraries have [different performance profiles](https://github.com/erikbern/ann-benchmarks).
17 | TensorFlow Similarity by default uses [NMSLIB](https://github.com/nmslib/nmslib)
18 | which has a strong performance profile and is very portable.
19 |
20 | Adding another backend is fairly straightforward: subclass the
21 | abstract class `Search()` and implement the abstract methods. Then to use it
22 | pass it to the `compile()` method of your [SimilarityModel].
23 |
24 | Adding your search backend as a built-in choice involves
25 | modifying the [Indexer](../indexer.md) and sending a PR. In general, unless
26 | the backend is of general use, it's better to not include it as
27 | a built-in option as it must be supported moving forward.
28 |
29 | ## Classes
30 |
31 | - [`class NMSLibSearch`](../TFSimilarity/indexer/NMSLibSearch.md): Efficiently find nearest embeddings by indexing known embeddings and make
32 |
33 | - [`class Search`](../TFSimilarity/indexer/Search.md): Helper class that provides a standard way to create an ABC using
34 |
35 |
--------------------------------------------------------------------------------
/api/TFSimilarity/classification_metrics.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.classification_metrics
2 |
3 |
4 |
5 |
6 |
7 | Classification metrics measure matching classification quality between a
8 |
9 | set of query examples and a set of indexed examples.
10 |
11 | ## Classes
12 |
13 | - [`class BinaryAccuracy`](../TFSimilarity/classification_metrics/BinaryAccuracy.md): Calculates how often the query label matches the derived lookup label.
14 |
15 | - [`class ClassificationMetric`](../TFSimilarity/callbacks/ClassificationMetric.md): Abstract base class for computing classification metrics.
16 |
17 | - [`class F1Score`](../TFSimilarity/classification_metrics/F1Score.md): Calculates the harmonic mean of precision and recall.
18 |
19 | - [`class FalsePositiveRate`](../TFSimilarity/classification_metrics/FalsePositiveRate.md): Calculates the false positive rate of the query classification.
20 |
21 | - [`class NegativePredictiveValue`](../TFSimilarity/classification_metrics/NegativePredictiveValue.md): Calculates the negative predictive value of the query classification.
22 |
23 | - [`class Precision`](../TFSimilarity/classification_metrics/Precision.md): Calculates the precision of the query classification.
24 |
25 | - [`class Recall`](../TFSimilarity/classification_metrics/Recall.md): Calculates the recall of the query classification.
26 |
27 | ## Functions
28 |
29 | - [`make_classification_metric(...)`](../TFSimilarity/callbacks/make_classification_metric.md): Convert classification metric from str name to object if needed.
30 |
31 |
--------------------------------------------------------------------------------
/api/TFSimilarity/evaluators.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.evaluators
2 |
3 |
4 |
5 |
6 |
7 | Evaluates search index performance and calibrates it.
8 |
9 |
10 | ## Use
11 |
12 | Evaluators are used for two primary purposes:
13 |
14 | - Evaluate model performance on a reference index during training and
15 | evaluation phase via the `evaluate_classification()` and `evaluate_retrieval()`
16 | methods. Evaluation on a reference index is
17 | required to be able to assess model performance using
18 | - [Classification metrics](../classification_metrics/) and
19 | - [Retrieval metrics](../retrieval_metrics/).
20 | Those metrics can't be computed without indexing data and looking up
21 | nearest neighbors.
22 |
23 | - Calibrating the model requires evaluating various distance thresholds
24 | to find the maximal distance threshold. Those thresholds either meet,
25 | if possible, the user supplied `thresholds_targets` performance value or
26 | the optimal value with respect to the calibration `classification metric`.
27 | Calibration is required to perform matching
28 | because the optimal distance thresholds will change depending on
29 | the model, dataset, and training. Accordingly, those thresholds
30 | need to be evaluated empirically for every use-case.
31 |
32 | ## Classes
33 |
34 | - [`class Evaluator`](../TFSimilarity/callbacks/Evaluator.md): Evaluates search index performance and calibrates it.
35 |
36 | - [`class MemoryEvaluator`](../TFSimilarity/callbacks/MemoryEvaluator.md): In memory index performance evaluation and classification.
37 |
38 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/augmentation_utils/flip.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from __future__ import annotations
16 |
17 | import tensorflow as tf
18 |
19 | from tensorflow_similarity.augmenters.augmentation_utils.random_apply import (
20 | random_apply,
21 | )
22 | from tensorflow_similarity.types import Tensor
23 |
24 |
def random_random_flip_left_right(image: Tensor, p: float = 0.5) -> Tensor:
    """With probability `p`, pass `image` through `tf.image.random_flip_left_right`.

    NOTE(review): `tf.image.random_flip_left_right` is itself documented to
    flip with a 1 in 2 chance, so the image presumably ends up mirrored with
    overall probability `p / 2` rather than `p` — confirm this is intended.

    Args:
        image: Image tensor to augment.
        p: Probability of invoking the random flip.

    Returns:
        The (possibly flipped) image, or `image` unchanged.
    """
    return random_apply(tf.image.random_flip_left_right, p=p, x=image)
30 |
31 |
def random_random_flip_top_bottom(image: Tensor, p: float = 0.5) -> Tensor:
    """With probability `p`, pass `image` through `tf.image.random_flip_up_down`.

    NOTE(review): `tf.image.random_flip_up_down` is itself documented to flip
    with a 1 in 2 chance, so the image presumably ends up flipped with overall
    probability `p / 2` rather than `p` — confirm this is intended.

    Args:
        image: Image tensor to augment.
        p: Probability of invoking the random flip.

    Returns:
        The (possibly flipped) image, or `image` unchanged.
    """
    return random_apply(tf.image.random_flip_up_down, p=p, x=image)
37 |
--------------------------------------------------------------------------------
/tensorflow_similarity/augmenters/contrastive.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | import os
17 | from collections.abc import Callable
18 |
19 | import tensorflow as tf
20 |
21 | from ..augmenters import Augmenter
22 | from ..types import Tensor
23 |
24 |
class ContrastiveAugmenter(Augmenter):
    """Augmenter that builds several views of a batch by mapping `process` over it."""

    def __init__(self, process: Callable, num_cpu: int | None = os.cpu_count()):
        """Store the per-example function and the parallelism setting.

        Args:
            process: Callable applied to each element of the stacked inputs.
            num_cpu: Forwarded to `tf.map_fn` as `parallel_iterations`. The
                default is `os.cpu_count()` as evaluated when the class is
                defined.
        """
        self.num_cpu = num_cpu
        self.process = process

    def augment(self, x: Tensor, y: Tensor, num_views: int, is_warmup: bool) -> list[Tensor]:
        """Return `num_views` independently processed copies of the batch.

        Args:
            x: Examples to augment; stacked into a single tensor first.
            y: Not read by this implementation.
            num_views: Number of views to generate.
            is_warmup: Not read by this implementation.

        Returns:
            A list with one mapped tensor per view.
        """
        # NOTE(review): indentation was lost in the listing this was restored
        # from; the mapping is assumed to run inside the CPU device scope.
        with tf.device("/cpu:0"):
            batch = tf.stack(x)
            # Multi-core augmentation: each pass maps `process` over the batch.
            views = [
                tf.map_fn(self.process, batch, parallel_iterations=self.num_cpu)
                for _ in range(num_views)
            ]
        return views
40 |
--------------------------------------------------------------------------------
/api/TFSimilarity/distances.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.distances
2 |
3 |
4 |
5 |
6 |
7 | Vectorized embedding pairwise distances computation functions
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class ABC`](../TFSimilarity/distances/ABC.md): Helper class that provides a standard way to create an ABC using
14 |
15 | - [`class CosineDistance`](../TFSimilarity/distances/CosineDistance.md): Compute pairwise cosine distances between embeddings.
16 |
17 | - [`class Distance`](../TFSimilarity/distances/Distance.md): Note: don't forget to add your distance to the DISTANCES list
18 |
19 | - [`class EuclideanDistance`](../TFSimilarity/distances/EuclideanDistance.md): Compute pairwise euclidean distances between embeddings.
20 |
21 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor
22 |
23 | - [`class InnerProductSimilarity`](../TFSimilarity/distances/InnerProductSimilarity.md): Compute the pairwise inner product between embeddings.
24 |
25 | - [`class ManhattanDistance`](../TFSimilarity/distances/ManhattanDistance.md): Compute pairwise Manhattan distances between embeddings.
26 |
27 | - [`class SNRDistance`](../TFSimilarity/distances/SNRDistance.md): Computes pairwise SNR distances between embeddings.
28 |
29 | - [`class SquaredEuclideanDistance`](../TFSimilarity/distances/SquaredEuclideanDistance.md): Compute pairwise squared Euclidean distance.
30 |
31 | ## Functions
32 |
33 | - [`distance_canonicalizer(...)`](../TFSimilarity/distances/distance_canonicalizer.md): Normalize user requested distance to its matching Distance object.
34 |
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/tests/architectures/test_resnet18.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pytest
4 |
5 | from tensorflow_similarity.architectures import resnet18
6 |
7 |
def test_include_top():
    """With include_top=True the model ends in GeM pooling followed by a MetricEmbedding."""
    resnet = resnet18.ResNet18Sim((32, 32, 3), include_top=True)

    # GeM pooling with p == 3.0 is expected right before the final layer.
    pooling_layer = resnet.layers[-2]
    assert pooling_layer.name == "gem_pool"
    assert pooling_layer.p == 3.0

    # l2_norm defaults to True, so the final layer should be a MetricEmbedding.
    head_layer = resnet.layers[-1]
    assert re.match("metric_embedding", head_layer.name) is not None
18 |
19 |
def test_l2_norm_false():
    """With l2_norm=False the model ends in GeM pooling followed by a dense layer."""
    resnet = resnet18.ResNet18Sim((32, 32, 3), include_top=True, l2_norm=False)

    # GeM pooling with p == 3.0 is expected right before the final layer.
    pooling_layer = resnet.layers[-2]
    assert pooling_layer.name == "gem_pool"
    assert pooling_layer.p == 3.0

    # Without L2 normalization the final layer should be a plain dense layer.
    head_layer = resnet.layers[-1]
    assert re.match("dense", head_layer.name) is not None
29 |
30 |
@pytest.mark.parametrize(
    "pooling, name",
    [("gem", "gem_pool"), ("avg", "avg_pool"), ("max", "max_pool")],
    ids=["gem", "avg", "max"],
)
def test_include_top_false(pooling, name):
    """With include_top=False the final layer is the requested pooling layer."""
    resnet = resnet18.ResNet18Sim((32, 32, 3), include_top=False, pooling=pooling)

    # The layer selected by `pooling` is expected to be the last one.
    final_layer = resnet.layers[-1]
    assert final_layer.name == name
40 |
--------------------------------------------------------------------------------
/api/TFSimilarity/indexer/make_classification_matcher.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.indexer.make_classification_matcher
2 |
3 |
4 |
5 |
6 |
7 | Convert classification matcher from str name to object if needed.
8 |
9 |
10 | ```python
11 | TFSimilarity.indexer.make_classification_matcher(
12 | ```
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | Args |
23 |
24 |
25 | |
26 | matcher
27 | |
28 |
29 | 'match_nearest', 'match_majority_vote' or
30 | ClassificationMatch object. Defines the classification matching,
31 | e.g., match_nearest will count a True Positive if the query_label
32 | is equal to the label of the nearest neighbor and the distance is
33 | less than or equal to the distance threshold.
34 | |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | Raises |
44 |
45 |
46 | |
47 | ValueError
48 | |
49 |
50 | matcher name is invalid.
51 | |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 | Returns |
61 |
62 |
63 | |
64 | ClassificationMatch
65 | |
66 |
67 | Instantiated matcher if needed.
68 | |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/tensorflow_similarity/retrieval_metrics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..types import BoolTensor, IntTensor
22 |
23 |
def compute_match_mask(query_labels: IntTensor, lookup_labels: IntTensor) -> BoolTensor:
    """Build the boolean indicator of which lookups share their query's label.

    Args:
        query_labels: A 1D tensor of the labels associated with the queries.

        lookup_labels: A 2D tensor where the jth row is the labels associated
            with the set of k neighbors for the jth query.

    Returns:
        A 2D boolean tensor that is True wherever a lookup label equals the
        label of its associated query (i.e. marks the true positives).
    """
    needs_column_axis = tf.rank(query_labels) == 1
    if needs_column_axis:
        # Add a trailing axis so each query label broadcasts across its row
        # of lookup labels.
        query_labels = tf.expand_dims(query_labels, axis=-1)

    return tf.math.equal(lookup_labels, query_labels)
43 |
--------------------------------------------------------------------------------
/api/TFSimilarity/visualization/viz_neigbors_imgs.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.visualization.viz_neigbors_imgs
2 |
3 |
4 |
5 |
6 |
7 | Display an image's nearest neighbors
8 |
9 | ```python
10 | TFSimilarity.visualization.viz_neigbors_imgs(
11 | example_class: int,
12 | class_mapping: Optional[Mapping[int, str]] = None,
13 | fig_size: Tuple[int, int] = (24, 4),
14 | cmap: str = 'viridis',
15 | show: bool = True
16 | )
17 | ```
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | Args |
28 |
29 |
30 | |
31 | example
32 | |
33 |
34 | The data used as query input.
35 | |
36 |
37 | |
38 | example_class
39 | |
40 |
41 | The class of the data used as query
42 | |
43 |
44 | |
45 | neighbors
46 | |
47 |
48 | The list of neighbors returned by the lookup()
49 | |
50 |
51 | |
52 | class_mapping
53 | |
54 |
55 | Mapping from class numerical ids to a class name. If not
56 | set, the plot will display the class numerical id instead.
57 | Defaults to None.
58 | |
59 |
60 | |
61 | fig_size
62 | |
63 |
64 | Size of the figure. Defaults to (24, 4).
65 | |
66 |
67 | |
68 | cmap
69 | |
70 |
71 | Default color scheme for black and white images e.g mnist.
72 | Defaults to 'viridis'.
73 | |
74 |
75 | |
76 | show
77 | |
78 |
79 | If the plot is going to be shown or not. Defaults to True.
80 | |
81 |
82 |
83 |
84 |
--------------------------------------------------------------------------------
/api/TFSimilarity/augmenters/Augmenter.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.augmenters.Augmenter
2 |
3 |
4 |
5 |
6 |
7 | Helper class that provides a standard way to create an ABC using
8 |
9 | Inherits From: [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | inheritance.
12 |
13 | ## Methods
14 |
15 | augment
16 |
17 | View source
18 |
19 | ```python
20 | @abc.abstractmethod
21 | augment(
22 | x: TFSimilarity.callbacks.Tensor,
23 | y: TFSimilarity.callbacks.Tensor,
24 | num_augmentations_per_example: int,
25 | is_warmup: bool
26 | ) -> List[TFSimilarity.callbacks.Tensor]
27 | ```
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | __call__
39 |
40 | View source
41 |
42 | ```python
43 | __call__(
44 | x: TFSimilarity.callbacks.Tensor,
45 | y: TFSimilarity.callbacks.Tensor,
46 | num_augmentations_per_example: int,
47 | is_warmup: bool
48 | ) -> List[TFSimilarity.callbacks.Tensor]
49 | ```
50 |
51 |
52 |
53 |
54 |
55 | Call self as a function.
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/api/TFSimilarity/samplers.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.samplers
2 |
3 |
4 |
5 |
6 |
7 | Data Samplers generate balanced batches for smooth training.
8 |
9 |
10 | *A well balanced batch is a batch that contains at least 2 examples for
11 | each class present in the batch*.
12 |
13 | Having well balanced batches is important for many types of similarity learning
14 | including contrastive learning because contrastive losses require at least
15 | two examples (and sometimes more) to be able to compute distances between
16 | the embeddings.
17 |
18 | To address this need, TensorFlow Similarity provides data samplers for
19 | various types of datasets that:
20 | - Ensure that batches contain at least N examples of each class present in
21 | the batch.
22 | - Support restricting the batches to a subset of the classes present in
23 | the dataset.
24 |
25 | ## Classes
26 |
27 | - [`class MultiShotMemorySampler`](../TFSimilarity/samplers/MultiShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset.
28 |
29 | - [`class SingleShotMemorySampler`](../TFSimilarity/samplers/SingleShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset.
30 |
31 | - [`class TFDatasetMultiShotMemorySampler`](../TFSimilarity/samplers/TFDatasetMultiShotMemorySampler.md): Base object for fitting to a sequence of data, such as a dataset.
32 |
33 | ## Functions
34 |
35 | - [`TFRecordDatasetSampler(...)`](../TFSimilarity/samplers/TFRecordDatasetSampler.md): Create a [TFRecordDataset](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset) based sampler.
36 |
37 | - [`select_examples(...)`](../TFSimilarity/samplers/select_examples.md): Randomly select at most N examples per class
38 |
39 |
--------------------------------------------------------------------------------
/api/TFSimilarity/samplers/select_examples.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.samplers.select_examples
2 |
3 |
4 |
5 |
6 |
7 | Randomly select at most N examples per class
8 |
9 | ```python
10 | TFSimilarity.samplers.select_examples(
11 | class_list: Sequence[int] = None,
12 | num_examples_per_class: int = None
13 | ) -> Tuple[np.ndarray, np.ndarray]
14 | ```
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | Args |
25 |
26 |
27 | |
28 | x
29 | |
30 |
31 | A 2-D Tensor containing the data.
32 | |
33 |
34 | |
35 | y
36 | |
37 |
38 | A 1-D Tensor containing the labels.
39 | |
40 |
41 | |
42 | class_list
43 | |
44 |
45 | Filter the list of examples to only keep those that
46 | belong to the supplied class list. If no class is supplied, keep
47 | examples for all the classes. Defaults to None - keep all the examples.
48 | |
49 |
50 | |
51 | num_examples_per_class
52 | |
53 |
54 | Restrict the number of examples for EACH
55 | class to num_examples_per_class if set. If not set, all the available
56 | examples are selected. Defaults to None - no selection.
57 | |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | Returns |
67 |
68 | |
69 | A Tuple containing the subset of x and y.
70 | |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/architectures.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Mapping
4 | from typing import Any
5 |
6 | import tensorflow as tf
7 |
8 | from tensorflow_similarity.architectures import (
9 | EfficientNetSim,
10 | ResNet18Sim,
11 | ResNet50Sim,
12 | )
13 |
# Registry of model factories keyed by architecture id. Each factory takes a
# params mapping `p`; "input_shape" is read with `p[...]` (required) while every
# other entry is read with `p.get(...)` and falls back to the default shown.
ARCHITECTURES = {
    "effnet": lambda p: EfficientNetSim(
        input_shape=p["input_shape"],
        embedding_size=p.get("embedding", 128),
        variant=p.get("variant", "B0"),
        weights=p.get("weights", "imagenet"),
        trainable=p.get("trainable", "frozen"),
        l2_norm=p.get("l2_norm", True),
        include_top=p.get("include_top", True),
        pooling=p.get("pooling", "gem"),
        gem_p=p.get("gem_p", 3.0),
    ),
    "resnet50": lambda p: ResNet50Sim(
        input_shape=p["input_shape"],
        embedding_size=p.get("embedding", 128),
        weights=p.get("weights", "imagenet"),
        trainable=p.get("trainable", "frozen"),
        l2_norm=p.get("l2_norm", True),
        include_top=p.get("include_top", True),
        pooling=p.get("pooling", "gem"),
        gem_p=p.get("gem_p", 3.0),
    ),
    "resnet18": lambda p: ResNet18Sim(
        input_shape=p["input_shape"],
        embedding_size=p.get("embedding", 128),
        l2_norm=p.get("l2_norm", True),
        include_top=p.get("include_top", True),
        pooling=p.get("pooling", "gem"),
        gem_p=p.get("gem_p", 3.0),
    ),
}
44 |
45 |
def make_architecture(architecture_id: str, params: Mapping[str, Any]) -> tf.keras.Model:
    """Build the model registered under `architecture_id`.

    Args:
        architecture_id: Key into the `ARCHITECTURES` registry.
        params: Mapping handed unchanged to the registered factory.

    Returns:
        Whatever the registered factory returns for `params`.

    Raises:
        ValueError: If `architecture_id` is not a key of `ARCHITECTURES`.
    """
    # Look the factory up on its own: a KeyError raised while building the
    # model (e.g. `params` lacks "input_shape") must not be reported as an
    # unknown architecture name.
    try:
        factory = ARCHITECTURES[architecture_id]
    except KeyError as exc:
        raise ValueError(f"Unknown architecture name: {architecture_id}") from exc
    return factory(params)
51 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmentation_utils/test_blur.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters.augmentation_utils import blur
5 |
6 |
def create_img(width=32, height=32, channels=3):
    """Return a random rank-3 tensor of shape [width, height, channels] drawn from [0, 1)."""
    shape = [width, height, channels]
    return tf.random.uniform(shape, minval=0, maxval=1)
9 |
10 |
def test_random_blur():
    """random_blur keeps the image shape and only changes pixels when applied."""
    img = create_img()

    # Going by the result names, the last argument 1.0 is expected to always
    # blur and 0 is expected to never blur.
    random_blurred_always = blur.random_blur(img, 32, 32, 1.0)
    random_blurred_never = blur.random_blur(img, 32, 32, 0)

    # check shapes
    assert (tf.shape(random_blurred_always) == tf.shape(img)).numpy().all()
    assert (tf.shape(random_blurred_never) == tf.shape(img)).numpy().all()

    # check if blur works
    assert not (random_blurred_always == img).numpy().all()
    assert (random_blurred_never == img).numpy().all()
29 |
30 |
def test_batch_random_blur():
    """batch_random_blur keeps the batch shape and only changes pixels when applied."""
    img = create_img()

    # A batch holding a single image.
    batched_img = [img]

    # Going by the result names, the last argument 1.0 is expected to always
    # blur and 0 is expected to never blur.
    random_batched_blurred_always = blur.batch_random_blur(batched_img, 32, 32, 1.0)
    random_batched_blurred_never = blur.batch_random_blur(batched_img, 32, 32, 0)

    # check shapes
    assert (tf.shape(random_batched_blurred_always) == tf.shape(batched_img)).numpy().all()
    assert (tf.shape(random_batched_blurred_never) == tf.shape(batched_img)).numpy().all()

    # check if blur works: flatten the element-wise comparison before reducing
    equality_always = tf.reshape(tf.equal(random_batched_blurred_always, batched_img), [-1])
    equality_never = tf.reshape(tf.equal(random_batched_blurred_never, batched_img), [-1])

    assert not equality_always.numpy().all()
    assert equality_never.numpy().all()
53 |
--------------------------------------------------------------------------------
/tensorflow_similarity/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data Samplers generate balanced batches for smooth training.
16 |
17 | *A well balanced batch is a batch that contains at least 2 examples for
18 | each class present in the batch*.
19 |
20 | Having well balanced batches is important for many types of similarity learning
21 | including contrastive learning because contrastive losses require at least
22 | two examples (and sometimes more) to be able to compute distances between
23 | the embeddings.
24 |
25 | To address this need, TensorFlow Similarity provides data samplers for
26 | various types of datasets that:
27 | - Ensure that batches contain at least N examples of each class present in
28 | the batch.
29 | - Support restricting the batches to a subset of the classes present in
30 | the dataset.
31 | """
32 | from .file_samplers import MultiShotFileSampler # noqa
33 | from .memory_samplers import MultiShotMemorySampler # noqa
34 | from .memory_samplers import SingleShotMemorySampler # noqa
35 | from .tfdata_sampler import TFDataSampler # noqa
36 | from .tfdataset_samplers import TFDatasetMultiShotMemorySampler # noqa
37 | from .tfrecords_samplers import TFRecordDatasetSampler # noqa
38 | from .utils import select_examples # noqa
39 |
--------------------------------------------------------------------------------
/api/TFSimilarity/visualization/visualize_views.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.visualization.visualize_views
2 |
3 |
4 |
5 |
6 |
7 | Display side by side different image views with labels, and predictions
8 |
9 | ```python
10 | TFSimilarity.visualization.visualize_views(
11 | num_imgs: int = None,
12 | views_per_col: int = 4,
13 | fig_size: Tuple[int, int] = (24, 4),
14 | max_pixel_value: float = 1.0,
15 | min_pixel_value: float = 0.0
16 | )
17 | ```
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | Args |
28 |
29 |
30 | |
31 | views
32 | |
33 |
34 | Array of views
35 | |
36 |
37 | |
38 | predictions
39 | |
40 |
41 | model output.
42 | |
43 |
44 | |
45 | labels
46 | |
47 |
48 | image labels
49 | |
50 |
51 | |
52 | num_imgs
53 | |
54 |
55 | number of images to use.
56 | |
57 |
58 | |
59 | views_per_col
60 | |
61 |
62 | Int, number of images in one row. Defaults to 4.
63 | |
64 |
65 | |
66 | max_pixel_value
67 | |
68 |
69 | Max expected value for a pixel. Used to scale the image
70 | between [0,1].
71 | |
72 |
73 | |
74 | min_pixel_value
75 | |
76 |
77 | Min expected value for a pixel. Used to scale the image
78 | between [0,1].
79 | |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 | Returns |
89 |
90 | |
91 | None.
92 | |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/tests/augmenters/test_augmentation_utils/test_color_jitter.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.augmenters.augmentation_utils import color_jitter
5 |
6 |
def create_img(width=32, height=32, channels=3):
    """Return a random rank-3 tensor of shape [width, height, channels] drawn from [0, 1)."""
    shape = [width, height, channels]
    return tf.random.uniform(shape, minval=0, maxval=1)
9 |
10 |
def test_random_color_jitter_multiplicative():
    """Multiplicative color jitter keeps the shape and only changes pixels when applied."""
    img = create_img()

    # Going by the result names, the (1, 1, 1) call is expected to always
    # jitter and the (0) call is expected to leave the image untouched.
    random_jitter_always = color_jitter.random_color_jitter(img, 1, 1, 1, impl="multiplicative")
    random_jitter_never = color_jitter.random_color_jitter(img, 0, impl="multiplicative")

    # check shapes
    assert (tf.shape(random_jitter_always) == tf.shape(img)).numpy().all()
    assert (tf.shape(random_jitter_never) == tf.shape(img)).numpy().all()

    # check if color jitter works
    assert not (random_jitter_always == img).numpy().all()
    assert (random_jitter_never == img).numpy().all()
29 |
30 |
def test_random_color_jitter_additive():
    """Additive color jitter keeps the shape and only changes pixels when applied."""
    img = create_img()

    # Going by the result names, the (1, 1, 1) call is expected to always
    # jitter and the (0) call is expected to leave the image untouched.
    random_jitter_always = color_jitter.random_color_jitter(
        img, 1, 1, 1, impl="additive"  # won't make a difference between barlow/v1
    )
    random_jitter_never = color_jitter.random_color_jitter(
        img, 0, impl="additive"  # won't make a difference between barlow/v1
    )

    # check shapes
    assert (tf.shape(random_jitter_always) == tf.shape(img)).numpy().all()
    assert (tf.shape(random_jitter_never) == tf.shape(img)).numpy().all()

    # check if color jitter works
    assert not (random_jitter_always == img).numpy().all()
    assert (random_jitter_never == img).numpy().all()
53 |
--------------------------------------------------------------------------------
/api/TFSimilarity/augmenters/ContrastiveAugmenter.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.augmenters.ContrastiveAugmenter
2 |
3 |
4 |
5 |
6 |
7 | Helper class that provides a standard way to create an ABC using
8 |
9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | ```python
12 | TFSimilarity.augmenters.ContrastiveAugmenter(
13 | process: Callable, num_cpu: Optional[int] = os.cpu_count()
14 | )
15 | ```
16 |
17 |
18 |
19 |
20 | inheritance.
21 |
22 | ## Methods
23 |
24 | augment
25 |
26 | View source
27 |
28 | ```python
29 | augment(
30 | x: TFSimilarity.callbacks.Tensor,
31 | y: TFSimilarity.callbacks.Tensor,
32 | num_views: int,
33 | is_warmup: bool
34 | ) -> List[TFSimilarity.callbacks.Tensor]
35 | ```
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | __call__
45 |
46 | View source
47 |
48 | ```python
49 | __call__(
50 | x: TFSimilarity.callbacks.Tensor,
51 | y: TFSimilarity.callbacks.Tensor,
52 | num_augmentations_per_example: int,
53 | is_warmup: bool
54 | ) -> List[TFSimilarity.callbacks.Tensor]
55 | ```
56 |
57 |
58 |
59 |
60 |
61 | Call self as a function.
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/api/TFSimilarity/visualization/confusion_matrix.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.visualization.confusion_matrix
2 |
3 |
4 |
5 |
6 |
7 | Plot confusion matrix
8 |
9 | ```python
10 | TFSimilarity.visualization.confusion_matrix(
11 | normalize: bool = True,
12 | title: str = 'Confusion matrix',
13 | cmap: str = 'Blues',
14 | show: bool = True
15 | )
16 | ```
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | Args |
26 |
27 |
28 | |
29 | y_pred
30 | |
31 |
32 | Model prediction returned by model.match()
33 | |
34 |
35 | |
36 | y_true
37 | |
38 |
39 | Expected class_id.
40 | |
41 |
42 | |
43 | normalize
44 | |
45 |
46 | Normalizes matrix values between 0 and 1.
47 | Defaults to True.
48 | |
49 |
50 | |
51 | labels
52 | |
53 |
54 | List of class string label to display instead of the class
55 | numerical ids. Defaults to None.
56 | |
57 |
58 | |
59 | title
60 | |
61 |
62 | Title of the confusion matrix. Defaults to 'Confusion matrix'.
63 | |
64 |
65 | |
66 | cmap
67 | |
68 |
69 | Color schema as CMAP. Defaults to 'Blues'.
70 | |
71 |
72 | |
73 | show
74 | |
75 |
76 | If the plot is going to be shown or not. Defaults to True.
77 | |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 | Returns |
87 |
88 | |
89 | A Tuple containing the plot and confusion matrix.
90 | |
91 |
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/tensorflow_similarity/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Evaluates search index performance and calibrates it.
15 |
16 | ## Use
17 |
18 | Evaluators are used for two primary purposes:
19 |
20 | - Evaluate model performance on a reference index during training and
21 | evaluation phase via the `evaluate_classification()` and `evaluate_retrieval()`
22 | methods. Evaluation on a reference index is
23 | required to be able to assess model performance using
24 | [Classification metrics](../classification_metrics/) and
25 | [Retrieval metrics](../retrieval_metrics/).
26 | Those metrics can't be computed without indexing data and looking up
27 | nearest neighbors.
28 |
29 | - Calibrating the model requires evaluating various distance thresholds
30 | to find the maximal distance threshold. Those thresholds either meet,
31 | if possible, the user supplied `thresholds_targets` performance value or
32 | the optimal value with respect to the calibration `classification metric`.
33 | Calibration is required to perform matching
34 | because the optimal distance thresholds will change depending on
35 | the model, dataset, and training. Accordingly those thresholds
36 | need to be evaluated empirically for every use-case.
37 | """
38 | from .evaluator import Evaluator # noqa
39 | from .memory_evaluator import MemoryEvaluator # noqa
40 |
--------------------------------------------------------------------------------
/api/TFSimilarity/augmenters/BarlowAugmenter.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.augmenters.BarlowAugmenter
2 |
3 |
4 |
5 |
6 |
7 | Helper class that provides a standard way to create an ABC using
8 |
9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | ```python
12 | TFSimilarity.augmenters.BarlowAugmenter(
13 | width: int,
14 | height: int,
15 | flip_probability=0.5,
16 | brightness_multiplier=0.8,
17 | contrast_multiplier=0.6,
18 | saturation_multiplier=0.6,
19 | hue_multiplier=0.2,
20 | jitter_probability=0.8,
21 | greyscale_probability=0.2,
22 | blur_probability=0.2,
23 | blur_min_sigma=0,
24 | blur_max_sigma=1,
25 | solarize_probability=0.2,
26 | solarize_pixel_min=0,
27 | solarize_pixel_max=255,
28 | solarize_thresh=10,
29 | num_cpu: Optional[int] = os.cpu_count()
30 | )
31 | ```
32 |
33 |
34 |
35 |
36 | inheritance.
37 |
38 | ## Methods
39 |
40 | augment
41 |
42 | View source
43 |
44 | ```python
45 | @tf.function
46 | ```
47 | ```python
48 | augment(
49 | x: Any,
50 | y: Any = tf.constant([0]),
51 | num_augmentations_per_example: int = 2,
52 | is_warmup: bool = True
53 | ) -> List[Any]
54 | ```
55 |
56 |
57 |
58 |
59 |
60 | __call__
61 |
62 | View source
63 |
64 | ```python
65 | __call__(
66 | x: Any,
67 | y: Any = tf.constant([0]),
68 | num_augmentations_per_example: int = 2,
69 | is_warmup: bool = True
70 | ) -> List[Any]
71 | ```
72 |
73 |
74 | Call self as a function.
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/tensorflow_similarity/training_metrics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..samplers.samplers import Sampler
22 |
23 |
def batch_class_ratio(sampler: Sampler, num_batches: int = 100) -> float:
    """Computes the average number of examples per class within each batch.

    Similarity learning requires at least 2 examples per class in each batch.
    This is needed in order to construct the triplets. This function
    provides the average number of examples per class within each batch and
    can be used to check that a sampler is working correctly.
    The ratio should be >= 2.

    Args:
        sampler: A tf.similarity sampler object. It is iterated and must
            yield `(x, y)` pairs.
        num_batches: The maximum number of batches to sample.

    Returns:
        The average number of examples per class over the batches that were
        actually sampled, or 0.0 when no batch was sampled.
    """
    ratio = 0.0
    batches_seen = 0
    for _, y in sampler:
        if batches_seen >= num_batches:
            break
        batch_size = tf.shape(y)[0]
        num_classes = tf.shape(tf.unique(y)[0])[0]
        ratio += tf.math.divide(batch_size, num_classes)
        batches_seen += 1

    # Divide by the number of batches that contributed to `ratio`. Using the
    # last enumerate index + 1 over-counts by one whenever the loop stops on
    # the `num_batches` limit, and is undefined for an empty sampler.
    if batches_seen == 0:
        return 0.0
    return float(ratio / batches_seen)
47 |
--------------------------------------------------------------------------------
/scripts/increment_version.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def abs_path(rel_path):
    """Resolve `rel_path` against the directory that contains this script."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(script_dir, rel_path)
7 |
8 |
def read(rel_path):
    """Return the full text of the file at `rel_path` (resolved via `abs_path`)."""
    source = abs_path(rel_path)
    with open(source, "rt") as fd:
        contents = fd.read()
    return contents
12 |
13 |
def write(rel_path, lines):
    """Overwrite the file at `rel_path` with `lines`, appending a newline to each."""
    target = abs_path(rel_path)
    with open(target, "wt") as fd:
        fd.writelines(line + "\n" for line in lines)
18 |
19 |
def increment_dev_version(previous_version):
    """Return `previous_version` with its trailing dev number incremented by one.

    Example: "0.18.0.dev12" -> "0.18.0.dev13".

    Args:
        previous_version: Version string that ends in "dev<integer>".

    Returns:
        The version string with the integer after the last "dev" increased by 1.

    Raises:
        ValueError: If the version has no "dev" marker, or the text after the
            last "dev" is not an integer.
    """
    delim = "dev"
    # Split on the LAST occurrence so an earlier "dev" inside the version does
    # not break the unpacking.
    main_version, found, dev_version = previous_version.rpartition(delim)
    if not found:
        raise ValueError(f"The previous version {previous_version} does not contain a dev suffix")

    new_version = f"{main_version}{delim}{int(dev_version) + 1}"
    return new_version
31 |
32 |
def update_version(rel_path):
    """Increment the dev version on each `__version__` line of a file, in place.

    Args:
        rel_path: File path, resolved by `read` / `write`.

    Returns:
        The new version string (from the last `__version__` line processed).

    Raises:
        ValueError: If no line starts with `__version__`; the file is left
            untouched in that case.
    """
    new_version = None
    lines = []
    for line in read(rel_path).splitlines():
        if line.startswith("__version__"):
            # The version literal may use either quote style.
            delim = '"' if '"' in line else "'"
            previous_version = line.split(delim)[1]
            new_version = increment_dev_version(previous_version)
            line = line.replace(previous_version, new_version)
        lines.append(line)

    # Without this guard the function would rewrite the file and then fail
    # with an unbound `new_version`.
    if new_version is None:
        raise ValueError(f"No __version__ line found in {rel_path}")

    write(rel_path, lines)

    return new_version
46 |
47 |
if __name__ == "__main__":
    # Resolved by abs_path relative to the directory holding this script.
    version_path = "../tensorflow_similarity/__init__.py"

    # Configure the git identity used for the commit below.
    for identity_command in (
        "git config --global user.email 'action@github.com'",
        "git config --global user.name 'Github Actions Bot'",
    ):
        os.system(identity_command)

    # Assumes we are in the scripts/ dir
    new_version = update_version(version_path)

    # Stage tracked changes, commit the bump and push it.
    for git_command in (
        "git add -u",
        f"git commit -m '[nightly] Increase version to {new_version}'",
        "git push",
    ):
        os.system(git_command)
60 |
--------------------------------------------------------------------------------
/tests/stores/test_redis.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | import numpy as np
4 |
5 | from tensorflow_similarity.stores.redis import RedisStore
6 |
7 |
class MockRedis:
    """Dict-backed stand-in for the handful of `redis.Redis` calls used in these tests."""

    def __init__(self, *args, **kwargs):
        # Connection arguments are accepted and ignored; start with an empty cache.
        self.flushdb()

    def get(self, key):
        """Return the value stored under `key`, or None when it is absent."""
        return self.cache.get(key)  # None stands for nil

    def set(self, key, value, *args, **kwargs):
        """Store `value` under `key`; extra arguments are ignored."""
        self.cache[key] = value
        return "OK"

    def flushdb(self):
        """Drop every stored key."""
        self.cache = {}
        return "OK"

    def incr(self, key):
        """Add 1 to the value under `key` (a missing key counts as 0) and return it."""
        updated = self.cache.get(key, 0) + 1
        self.cache[key] = updated
        return updated
31 |
32 |
@patch("redis.Redis", MockRedis)
def test_store_and_retrieve():
    """Records added one by one get sequential int ids and are read back unchanged."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    kv_store = RedisStore()
    idxs = [kv_store.add(embedding, label, data) for embedding, label, data in records]

    # check index numbering: ints counting up from 0 in insertion order
    for expected_idx, idx in enumerate(idxs):
        assert isinstance(idx, int)
        assert expected_idx == idx

    # each lookup returns the three elements of the stored record
    for idx in idxs:
        emb, lbl, dt = kv_store.get(idx)
        expected_emb, expected_lbl, expected_dt = records[idx]
        assert emb == expected_emb
        assert lbl == expected_lbl
        assert dt == expected_dt
54 |
55 |
@patch("redis.Redis", MockRedis)
def test_batch_add():
    """Records added with batch_add are read back equal to the inputs."""
    embeddings = np.array([[0.1, 0.2], [0.2, 0.3]])
    labels = np.array([1, 2])
    payloads = np.array([[0, 0, 0], [1, 1, 1]])

    kv_store = RedisStore()
    # The returned ids are also used to index the input arrays below.
    for idx in kv_store.batch_add(embeddings, labels, payloads):
        emb, lbl, dt = kv_store.get(idx)
        assert np.array_equal(emb, embeddings[idx])
        assert np.array_equal(lbl, labels[idx])
        assert np.array_equal(dt, payloads[idx])
69 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/optimizers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Mapping
4 | from typing import Any
5 |
6 | import keras_tuner
7 | import tensorflow as tf
8 | from tensorflow.keras.optimizers import Adam, RMSprop
9 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule
10 | from tensorflow_addons.optimizers import LAMB
11 |
12 | from . import utils
13 |
# Registry of optimizer factories keyed by optimizer id. Each factory takes
# (params, lr_schedule, hp): a truthy `lr_schedule` is used as the learning
# rate, otherwise the "lr" entry is resolved through `utils.get_param`.
OPTIMIZERS = {
    "adam": lambda p, lrs, hp: Adam(
        learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp),
        beta_1=utils.get_param(p, "beta_1", 0.9, hp),
        beta_2=utils.get_param(p, "beta_2", 0.999, hp),
        epsilon=utils.get_param(p, "epsilon", 1e-07, hp),
        amsgrad=utils.get_param(p, "amsgrad", False, hp),
    ),
    "lamb": lambda p, lrs, hp: LAMB(
        learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp),
        beta_1=utils.get_param(p, "beta_1", 0.9, hp),
        beta_2=utils.get_param(p, "beta_2", 0.999, hp),
        epsilon=utils.get_param(p, "epsilon", 1e-06, hp),
        weight_decay=utils.get_param(p, "weight_decay", 0.0, hp),
        exclude_from_weight_decay=utils.get_param(p, "exclude_from_weight_decay", None, hp),
        exclude_from_layer_adaptation=utils.get_param(p, "exclude_from_layer_adaptation", None, hp),
    ),
    "rmsprop": lambda p, lrs, hp: RMSprop(
        learning_rate=lrs if lrs else utils.get_param(p, "lr", 0.001, hp),
        rho=utils.get_param(p, "rho", 0.9, hp),
        momentum=utils.get_param(p, "momentum", 0.0, hp),
        epsilon=utils.get_param(p, "epsilon", 1e-07, hp),
        centered=utils.get_param(p, "centered", False, hp),
    ),
}
38 |
39 |
def make_optimizer(
    opt_id: str,
    params: Mapping[str, Any],
    lr_schedule: LearningRateSchedule | None,
    hp: keras_tuner.HyperParameters | None = None,
) -> tf.keras.optimizers.Optimizer:
    """Build the optimizer registered under `opt_id`.

    Args:
        opt_id: Key into the `OPTIMIZERS` registry.
        params: Mapping of optimizer parameters handed to the factory.
        lr_schedule: Optional schedule; the registered factories use it as the
            learning rate when it is truthy.
        hp: Optional hyper-parameter container handed to the factory.

    Returns:
        Whatever the registered factory returns.

    Raises:
        ValueError: If `opt_id` is not a key of `OPTIMIZERS`.
    """
    # Look the factory up on its own so a KeyError raised while constructing
    # the optimizer is not reported as an unknown optimizer name.
    try:
        factory = OPTIMIZERS[opt_id]
    except KeyError as exc:
        raise ValueError(f"Unknown optimizer name: {opt_id}") from exc
    return factory(params, lr_schedule, hp)
50 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/datasets/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import inspect
5 | from collections.abc import Mapping, Sequence
6 | from typing import Any
7 |
8 |
@dataclasses.dataclass
class Fold:
    """Class ids and example indices of one train/validation fold.

    Every field defaults to a fresh empty list.
    """

    # Class ids on the training / validation side.
    training_classes: list[int] = dataclasses.field(default_factory=list)
    val_classes: list[int] = dataclasses.field(default_factory=list)
    # Example indices on the training / validation side.
    train_idxs: list[int] = dataclasses.field(default_factory=list)
    val_idxs: list[int] = dataclasses.field(default_factory=list)
15 |
16 |
@dataclasses.dataclass
class Splits:
    """Train/test example indices plus the folds, keyed by string.

    Every field defaults to a fresh empty container.
    """

    # Example indices on the train / test side.
    train_idxs: list[int] = dataclasses.field(default_factory=list)
    test_idxs: list[int] = dataclasses.field(default_factory=list)
    # Fold objects by key; starts as an empty dict.
    folds: Mapping[str, Fold] = dataclasses.field(default_factory=dict)
22 |
23 |
@dataclasses.dataclass
class TrainValSplit:
    """Train/validation split settings: a split count and a validation class percentage."""

    num_splits: int
    val_class_pctg: float


@dataclasses.dataclass
class EvalCallback:
    """Upper bounds on the number of queries and targets used by the eval callback."""

    max_num_queries: int
    max_num_targets: int


@dataclasses.dataclass(eq=True, frozen=True)
class DatasetConfig:
    """Immutable configuration of a benchmark dataset.

    `train_val_splits` defaults to `TrainValSplit(num_splits=0,
    val_class_pctg=0.0)` when it is not supplied. Hashing only uses the
    string fields, so the sequence-valued fields do not need to be hashable.
    """

    name: str
    component: str
    dataset_id: str
    x_key: str
    y_key: str
    train_classes: Sequence[int]
    test_classes: Sequence[int]
    train_val_splits: TrainValSplit | None = None
    eval_callback: EvalCallback | None = None

    def __post_init__(self):
        if self.train_val_splits is None:
            # The dataclass is frozen: a plain `self.x = ...` raises
            # FrozenInstanceError, so the default must be installed through
            # object.__setattr__.
            object.__setattr__(
                self,
                "train_val_splits",
                TrainValSplit(
                    num_splits=0,
                    val_class_pctg=0.0,
                ),
            )

    def __hash__(self):
        return hash((self.name, self.component, self.dataset_id, self.x_key, self.y_key))

    @classmethod
    def from_dict(cls, params: dict[str, Any]) -> DatasetConfig:
        """Create a DatasetConfig from a dictionary.

        This supports kwargs that are not part of the DatasetConfig class:
        keys that are not constructor parameters are ignored.
        """
        # Resolve the constructor signature once rather than once per key.
        accepted = inspect.signature(cls).parameters
        kwargs = {k: v for k, v in params.items() if k in accepted}
        return cls(**kwargs)
66 |
--------------------------------------------------------------------------------
/tensorflow_similarity/matchers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING, Type
17 |
18 | if TYPE_CHECKING:
19 | from .classification_match import ClassificationMatch
20 |
21 | from .match_majority_vote import MatchMajorityVote
22 | from .match_nearest import MatchNearest
23 |
24 |
def make_classification_matcher(matcher: str | ClassificationMatch) -> ClassificationMatch:
    """Return a classification matcher, building it from its name when given a str.

    Args:
        matcher: {'match_nearest', 'match_majority_vote'} (case-insensitive)
            or a ClassificationMatch object. Defines the classification
            matching, e.g., match_nearest will count a True Positive if the
            query_label is equal to the label of the nearest neighbor and the
            distance is less than or equal to the distance threshold.

    Raises:
        ValueError: matcher name is invalid.

    Returns:
        ClassificationMatch: A new instance for a known name; any non-str
        argument is returned unchanged.
    """
    if not isinstance(matcher, str):
        return matcher

    # ! The aliases map to classes, not instances; one is instantiated below.
    matcher_classes: dict[str, Type[ClassificationMatch]] = {
        "match_nearest": MatchNearest,
        "match_majority_vote": MatchMajorityVote,
    }

    matcher_cls = matcher_classes.get(matcher.lower())
    if matcher_cls is None:
        raise ValueError(f"Unknown matcher name: {matcher}, typo?")
    return matcher_cls()
54 |
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks.md:
--------------------------------------------------------------------------------
1 | # Module: TFSimilarity.callbacks
2 |
3 |
4 |
5 |
6 |
7 | Specialized callbacks that track similarity metrics during training
8 |
9 |
10 |
11 | ## Classes
12 |
13 | - [`class Callback`](../TFSimilarity/callbacks/Callback.md): Abstract base class used to build new callbacks.
14 |
15 | - [`class ClassificationMatch`](../TFSimilarity/callbacks/ClassificationMatch.md): Abstract base class for defining the classification matching strategy.
16 |
17 | - [`class ClassificationMetric`](../TFSimilarity/callbacks/ClassificationMetric.md): Abstract base class for computing classification metrics.
18 |
19 | - [`class EvalCallback`](../TFSimilarity/callbacks/EvalCallback.md): Epoch end evaluation callback that builds a test index and evaluates
20 |
21 | - [`class Evaluator`](../TFSimilarity/callbacks/Evaluator.md): Evaluates search index performance and calibrates it.
22 |
23 | - [`class FloatTensor`](../TFSimilarity/callbacks/FloatTensor.md): Float tensor
24 |
25 | - [`class IntTensor`](../TFSimilarity/callbacks/IntTensor.md): Integer tensor
26 |
27 | - [`class MemoryEvaluator`](../TFSimilarity/callbacks/MemoryEvaluator.md): In memory index performance evaluation and classification.
28 |
29 | - [`class Path`](../TFSimilarity/callbacks/Path.md): PurePath subclass that can make system calls.
30 |
31 | - [`class SimilarityModel`](../TFSimilarity/callbacks/SimilarityModel.md): Specialized Keras.Model which implements the core features needed for
32 |
33 | - [`class Tensor`](../TFSimilarity/callbacks/Tensor.md): A `tf.Tensor` represents a multidimensional array of elements.
34 |
35 | ## Functions
36 |
37 | - [`SplitValidationLoss(...)`](../TFSimilarity/callbacks/SplitValidationLoss.md): Creates the validation callbacks.
38 |
39 | - [`make_classification_metric(...)`](../TFSimilarity/callbacks/make_classification_metric.md): Convert classification metric from str name to object if needed.
40 |
41 | - [`unpack_lookup_distances(...)`](../TFSimilarity/callbacks/unpack_lookup_distances.md)
42 |
43 | - [`unpack_lookup_labels(...)`](../TFSimilarity/callbacks/unpack_lookup_labels.md)
44 |
45 | - [`unpack_results(...)`](../TFSimilarity/callbacks/unpack_results.md): Updates logs, writes summary, and returns list of strings of
46 |
47 |
--------------------------------------------------------------------------------
/tensorflow_similarity/distances/cosine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Cosine distance computation functions for embeddings."""
15 | from __future__ import annotations
16 |
17 | from typing import TYPE_CHECKING
18 |
19 | import tensorflow as tf
20 |
21 | if TYPE_CHECKING:
22 | from ..types import FloatTensor
23 |
24 | from .distance import Distance
25 |
26 |
@tf.keras.utils.register_keras_serializable(package="Similarity")
class CosineDistance(Distance):
    """Pairwise cosine distance between two sets of embeddings.

    The [Cosine Distance](https://en.wikipedia.org/wiki/Cosine_similarity) is
    computed as one minus the inner product of the embeddings, so the inputs
    are expected to be normalized. A value of 0 means the embeddings point in
    the same direction; larger values mean more dissimilar embeddings.
    """

    def __init__(self, name: str = "cosine", **kwargs):
        """Init Cosine distance."""
        super().__init__(name=name, **kwargs)

    @tf.function
    def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor:
        """Compute the distance between every query and every key embedding.

        Args:
            query_embeddings: Embeddings forming the rows of the result. The
                embeddings are expected to be normalized.
            key_embeddings: Embeddings forming the columns of the result. The
                embeddings are expected to be normalized.

        Returns:
            FloatTensor: Pairwise distance tensor, with negative values
            clipped to 0.
        """
        similarities = tf.linalg.matmul(query_embeddings, key_embeddings, transpose_b=True)
        # Clip below at 0 so the distance is never negative.
        clipped: FloatTensor = tf.math.maximum(1 - similarities, 0.0)
        return clipped
55 |
--------------------------------------------------------------------------------
/tests/samplers/test_tfrecord_samplers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 |
5 | from tensorflow_similarity.samplers import TFRecordDatasetSampler
6 |
7 |
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list `tf.train.Feature`."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
11 |
12 |
def to_tfrecord(sid, value):
    """Serialize a (shard id, value) pair as a `tf.train.Example` byte string."""
    features = tf.train.Features(
        feature={
            "sid": _int64_feature(sid),
            "value": _int64_feature(value),
        }
    )
    return tf.train.Example(features=features).SerializeToString()
21 |
22 |
def deserialization_fn(serialized_example):
    """Parse one serialized example back into its `(sid, value)` pair."""
    # Both fields are scalar int64 features, mirroring what `to_tfrecord` writes.
    schema = {key: tf.io.FixedLenFeature([], dtype=tf.int64) for key in ("sid", "value")}
    parsed = tf.io.parse_single_example(serialized_example, schema)
    return (parsed["sid"], parsed["value"])
31 |
32 |
class TFRecordSamplerTest(tf.test.TestCase):
    """Checks `TFRecordDatasetSampler` against 100 synthetic shards."""

    def setUp(self):
        """Write 100 shards of 1000 examples; values encode their shard id."""
        super().setUp()

        for sid in range(100):
            shard_path = os.path.join(self.get_temp_dir(), f"tfr_{sid}.tfrec")

            with tf.io.TFRecordWriter(str(shard_path)) as writer:
                # Values in shard `sid` lie in [sid * 1000, sid * 1000 + 999].
                for offset in range(1000):
                    writer.write(to_tfrecord(sid, sid * 1000 + offset))

    def test_basic(self):
        """Batches hold 10 examples, paired by shard, with in-range values."""
        sampler = TFRecordDatasetSampler(
            self.get_temp_dir(),
            deserialization_fn=deserialization_fn,
            batch_size=10,
            example_per_class=2,
        )

        batches = iter(sampler)
        # Drain 10,000 batches first, then inspect the next one.
        for _ in range(10_000):
            next(batches)
        examples = next(batches)

        sids, values = examples[0], examples[1]

        self.assertLen(sids, 10)
        # With two examples per class, entries at even and odd positions must
        # carry the same shard id.
        self.assertAllEqual(sids[::2], sids[1::2])

        for sid, val in zip(sids, values):
            # Recover the per-shard offset written in setUp.
            diff = val - sid * 1_000
            self.assertGreaterEqual(diff, 0)
            self.assertLess(diff, 1000)
70 |
--------------------------------------------------------------------------------
/api/TFSimilarity/augmenters/SimCLRAugmenter.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.augmenters.SimCLRAugmenter
2 |
3 |
4 |
5 |
6 |
7 | SimCLR augmentation pipeline as defined in the paper referenced below.
8 |
9 | Inherits From: [`Augmenter`](../../TFSimilarity/augmenters/Augmenter.md), [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | ```python
12 | TFSimilarity.augmenters.SimCLRAugmenter(
13 | height: int,
14 | width: int,
15 | is_training: bool = True,
16 | color_distort: bool = True,
17 | jitter_stength: float = 1.0,
18 | crop: bool = True,
19 | eval_crop_proportion: float = 0.875,
20 | flip: bool = True,
21 | version: str = v2,
22 | num_cpu: Optional[int] = os.cpu_count()
23 | )
24 | ```
25 |
26 |
27 |
28 |
29 | - [Big Self-Supervised Models are Strong Semi-Supervised Learners](https://arxiv.org/abs/2006.10029)
30 |
31 | Code adapted from the [original GitHub repository](https://github.com/google-research/simclr/tree/master/tf2).
32 |
33 | ## Methods
34 |
35 | augment
36 |
37 | View source
38 |
39 | ``python
40 | @tf.function``
41 |
42 | ```python
43 | augment(
44 | x: TFSimilarity.callbacks.Tensor```
45 | ,
46 | y: TFSimilarity.callbacks.Tensor```
47 | ,
48 | num_views: int,
49 | is_warmup: bool
50 | ) -> List[TFSimilarity.callbacks.Tensor```
51 | ]
52 | ```
53 |
54 |
55 |
56 |
57 |
58 | __call__
59 |
60 | View source
61 |
62 | ```python
63 | __call__(
64 | x: TFSimilarity.callbacks.Tensor```
65 | ,
66 | y: TFSimilarity.callbacks.Tensor```
67 | ,
68 | num_augmentations_per_example: int,
69 | is_warmup: bool
70 | ) -> List[TFSimilarity.callbacks.Tensor```
71 | ]
72 | ```
73 |
74 |
75 | Call self as a function.
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/tensorflow_similarity/visualization/vizualize_views.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import tensorflow as tf
6 | from matplotlib import pyplot as plt
7 |
8 | if TYPE_CHECKING:
9 | from tensorflow import Tensor
10 |
11 |
def visualize_views(
    views: Tensor,
    labels: Tensor = None,
    predictions: Tensor = None,
    num_imgs: int | None = None,
    views_per_col: int = 4,
    fig_size: tuple[int, int] = (24, 4),
    max_pixel_value: float = 1.0,
    min_pixel_value: float = 0.0,
):
    """Display different image views side by side, with labels and predictions.

    Args:
        views: Array of views. `views[j][i]` is the j-th view of the i-th
            image; all views of an image are concatenated horizontally.
        labels: Image labels. When omitted (or empty) the image index is used
            as the label and no label-only title is shown.
        predictions: Model output. `predictions[i][0]` is shown in the title
            when provided.
        num_imgs: Number of images to use. Defaults to all the images of the
            first view.
        views_per_col: Int, number of images in one row. Defaults to 4.
        fig_size: Size of the matplotlib figure. Defaults to (24, 4).
        max_pixel_value: Max expected value for a pixel. Used to scale the image
            between [0,1].
        min_pixel_value: Min expected value for a pixel. Used to scale the image
            between [0,1].

    Returns:
        None.
    """
    # Test against None and length instead of truthiness: bool() of a tensor
    # or array holding more than one element raises.
    has_labels = labels is not None and len(labels) > 0
    has_predictions = predictions is not None and len(predictions) > 0

    num_views = len(views)
    num_imgs = num_imgs if num_imgs else len(views[0])
    num_col = views_per_col
    # Ceiling division: a partially filled last row still needs its own row.
    num_row = num_imgs // num_col
    num_row = num_row + 1 if num_imgs % num_col else num_row

    # The pixel range does not change between images.
    scale = abs(max_pixel_value - min_pixel_value)

    # squeeze=False keeps `axes` two-dimensional for every grid shape,
    # including a single row, a single column or a single cell.
    _, axes = plt.subplots(num_row, num_col, figsize=fig_size, squeeze=False)
    for i in range(num_imgs):
        ax = axes[i // num_col, i % num_col]

        pair = [(views[j][i] - min_pixel_value) / scale for j in range(num_views)]
        ax.imshow(tf.concat(pair, axis=1))
        ax.set_axis_off()

        label = labels[i] if has_labels else i

        if has_predictions:
            ax.set_title("Label: {} | Pred: {:.5f}".format(label, predictions[i][0]))
        elif has_labels:
            ax.set_title("Label: {}".format(label))

        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
69 |
--------------------------------------------------------------------------------
/api/TFSimilarity/indexer/Lookup.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.indexer.Lookup
2 |
3 |
4 |
5 |
6 |
7 | Metadata associated with a query match.
8 |
9 |
10 | ```python
11 | TFSimilarity.indexer.Lookup(
12 | rank: int,
13 | distance: float,
14 | label: Optional[int] = dataclasses.field(default=None),
15 | embedding: Optional[np.ndarray] = dataclasses.field(default=None),
16 | )
17 | ```
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | Attributes |
30 |
31 |
32 | |
33 | rank
34 | |
35 |
36 | Rank of the match with respect to the query distance.
37 | |
38 |
39 | |
40 | distance
41 | |
42 |
43 | The distance from the match to the query.
44 | |
45 |
46 | |
47 | label
48 | |
49 |
50 | The label associated with the match. Default None.
51 | |
52 |
53 | |
54 | embedding
55 | |
56 |
57 | The embedded match vector. Default None.
58 | |
59 |
60 | |
61 | data
62 | |
63 |
64 | The original Tensor representation of the match result.
65 | Default None.
66 | |
67 |
68 |
69 |
70 |
71 |
72 | ## Methods
73 |
74 | __eq__
75 |
76 | View source
77 |
78 | ```python
79 | __eq__(
80 | other
81 | ) -> bool
82 | ```
83 |
84 |
85 | Return self==value.
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 | Class Variables |
96 |
97 |
98 | |
99 | data
100 | |
101 |
102 | None
103 | |
104 |
105 | |
106 | embedding
107 | |
108 |
109 | None
110 | |
111 |
112 | |
113 | label
114 | |
115 |
116 | None
117 | |
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/tensorflow_similarity/distances/manhattan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Manhattan distance computation functions for embeddings."""
15 | from __future__ import annotations
16 |
17 | from typing import TYPE_CHECKING
18 |
19 | import tensorflow as tf
20 |
21 | if TYPE_CHECKING:
22 | from ..types import FloatTensor
23 |
24 | from .distance import Distance
25 |
26 |
@tf.keras.utils.register_keras_serializable(package="Similarity")
class ManhattanDistance(Distance):
    """Pairwise Manhattan (L1) distance between two batches of embeddings.

    The [Manhattan Distance](https://en.wikipedia.org/wiki/Taxicab_geometry)
    is the sum of the absolute coordinate differences between two embeddings.
    The larger the distance, the more dissimilar the embeddings are.
    """

    def __init__(self, name: str = "manhattan", **kwargs):
        """Create the distance, forwarding `name` and `kwargs` to `Distance`."""
        super().__init__(name=name, **kwargs)

    @tf.function
    def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor:
        """Compute the L1 distance between every query/key pair.

        Args:
            query_embeddings: Query embeddings; the first axis indexes the
                examples and all remaining axes are flattened.
            key_embeddings: Key embeddings; the first axis indexes the
                examples and all remaining axes are flattened.

        Returns:
            FloatTensor: Pairwise distance tensor with one row per query and
            one column per key.
        """
        flat_queries = tf.reshape(query_embeddings, shape=[tf.shape(query_embeddings)[0], -1])
        flat_keys = tf.reshape(key_embeddings, shape=[tf.shape(key_embeddings)[0], -1])
        # Broadcast [Q, 1, D] against [1, K, D] to get all pairwise differences.
        pairwise_deltas = tf.expand_dims(flat_queries, axis=1) - tf.expand_dims(flat_keys, axis=0)
        distances: FloatTensor = tf.norm(pairwise_deltas, ord=1, axis=2)
        return distances
57 |
--------------------------------------------------------------------------------
/tensorflow_similarity/distances/inner_product.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Inner product similarity computation functions for embeddings."""
15 | from __future__ import annotations
16 |
17 | from typing import TYPE_CHECKING
18 |
19 | import tensorflow as tf
20 |
21 | if TYPE_CHECKING:
22 | from ..types import FloatTensor
23 |
24 | from .distance import Distance
25 |
26 |
@tf.keras.utils.register_keras_serializable(package="Similarity")
class InnerProductSimilarity(Distance):
    """Pairwise inner product between two batches of embeddings.

    The [Inner product](https://en.wikipedia.org/wiki/Inner_product_space) is
    a measure of similarity where the more similar vectors have the largest
    values.

    NOTE! This is not a distance and is likely not what you want to use with
    the built in losses. At the very least this will flip the sign on the
    margin in many of the losses. This is likely meant to be used with custom
    loss functions that expect a similarity instead of a distance.
    """

    def __init__(self, name: str = "inner_product", **kwargs):
        """Create the similarity, forwarding `name` and `kwargs` to `Distance`."""
        super().__init__(name=name, **kwargs)

    @tf.function
    def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor:
        """Compute the inner product between every query/key pair.

        Args:
            query_embeddings: Query embeddings, one per row.
            key_embeddings: Key embeddings, one per row.

        Returns:
            FloatTensor: Pairwise similarity tensor with one row per query
            and one column per key.
        """
        pairwise_products: FloatTensor = tf.linalg.matmul(
            query_embeddings,
            key_embeddings,
            transpose_b=True,
        )
        return pairwise_products
58 |
--------------------------------------------------------------------------------
/tensorflow_similarity/matchers/match_nearest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | if TYPE_CHECKING:
19 | from ..types import FloatTensor, IntTensor
20 |
21 | from .classification_match import ClassificationMatch
22 |
23 |
class MatchNearest(ClassificationMatch):
    """Match strategy that keeps only the first neighbor (k=1) of each query."""

    def __init__(self, name: str = "nearest", **kwargs) -> None:
        # Only supply the canonical name when the caller did not choose one.
        kwargs.setdefault("canonical_name", "match_nearest")
        super().__init__(name=name, **kwargs)

    def derive_match(self, lookup_labels: IntTensor, lookup_distances: FloatTensor) -> tuple[IntTensor, FloatTensor]:
        """Derive one match label and distance per query from its K neighbors.

        The match is simply the first neighbor of each row.

        Args:
            lookup_labels: A 2D array where the jth row is the labels
                associated with the set of k neighbors for the jth query.

            lookup_distances: A 2D array where the jth row is the distances
                between the jth query and the set of k neighbors.

        Returns:
            A tuple `(derived_labels, derived_distances)`:
                derived_labels: An IntTensor of shape
                [len(lookup_labels), 1] where the jth row contains the derived
                label for the jth query.

                derived_distances: A FloatTensor of shape
                [len(lookup_labels), 1] where the jth row contains the distance
                associated with the jth derived label.
        """
        # Slice with `:1` rather than index with `0` to keep the column axis.
        nearest_labels = lookup_labels[:, :1]
        nearest_distances = lookup_distances[:, :1]
        return nearest_labels, nearest_distances
58 |
--------------------------------------------------------------------------------
/tensorflow_similarity/classification_metrics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import Type
17 |
18 | from .binary_accuracy import BinaryAccuracy # noqa
19 | from .classification_metric import ClassificationMetric # noqa
20 | from .f1_score import F1Score # noqa
21 | from .false_positive_rate import FalsePositiveRate # noqa
22 | from .negative_predictive_value import NegativePredictiveValue # noqa
23 | from .precision import Precision # noqa
24 | from .recall import Recall # noqa
25 |
26 |
def make_classification_metric(metric: str | ClassificationMetric, name: str = "") -> ClassificationMetric:
    """Convert classification metric from str name to object if needed.

    Args:
        metric: ClassificationMetric() or metric name. Names are matched
            case-insensitively against the known aliases.
        name: Optional name assigned to the returned metric, overriding the
            one it already has. Ignored when empty.

    Raises:
        ValueError: Unknown metric name: {metric}, typo?

    Returns:
        ClassificationMetric: Instantiated metric if needed.
    """
    if isinstance(metric, str):
        # The alias table is only needed to resolve a name, so it is built
        # here rather than on every call.
        # ! Metrics must be non-instantiated.
        METRICS_ALIASES: dict[str, Type[ClassificationMetric]] = {
            "recall": Recall,
            "precision": Precision,
            "f1": F1Score,
            "f1score": F1Score,
            "f1_score": F1Score,
            "binary_accuracy": BinaryAccuracy,
            "npv": NegativePredictiveValue,
            "negative_predictive_value": NegativePredictiveValue,
            # Historical misspelling, kept so existing callers keep working.
            "negative_predicitve_value": NegativePredictiveValue,
            "fpr": FalsePositiveRate,
            "false_positive_rate": FalsePositiveRate,
        }

        alias = metric.lower()
        if alias not in METRICS_ALIASES:
            raise ValueError(f"Unknown metric name: {metric}, typo?")
        # The lower-cased alias becomes the default name of the new metric.
        metric = METRICS_ALIASES[alias](name=alias)

    if name:
        metric.name = name

    return metric
63 |
--------------------------------------------------------------------------------
/tests/stores/test_memory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tensorflow_similarity.stores.memory import MemoryStore
4 |
5 |
def build_store(records):
    """Create a `MemoryStore` holding `records` and return it with the added ids.

    Each record is indexed as `(embedding, label, data)`.
    """
    kv_store = MemoryStore()
    idxs = [kv_store.add(r[0], r[1], r[2]) for r in records]
    return kv_store, idxs
13 |
14 |
def test_in_memory_store_and_retrieve():
    """Records added one by one get sequential int ids and read back intact."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    kv_store, idxs = build_store(records)

    # Ids are plain ints numbered from zero in insertion order.
    for position, idx in enumerate(idxs):
        assert isinstance(idx, int)
        assert position == idx

    # The store reports one entry per record.
    assert kv_store.size() == 2

    # Each lookup yields the embedding, label and data that were stored.
    for idx in idxs:
        want_emb, want_lbl, want_dt = records[idx][0], records[idx][1], records[idx][2]
        emb, lbl, dt = kv_store.get(idx)
        assert emb == want_emb
        assert lbl == want_lbl
        assert dt == want_dt
34 |
35 |
def test_reset():
    """`reset()` empties the store, which stays usable afterwards."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    kv_store, idxs = build_store(records)
    assert kv_store.size() == 2

    kv_store.reset()
    assert kv_store.size() == 0

    # Adding after a reset counts from an empty store again.
    emb, lbl, data = records[0][0], records[0][1], records[0][2]
    kv_store.add(emb, lbl, data)
    assert kv_store.size() == 1
49 |
50 |
def test_batch_add():
    """`batch_add` stores every row and returns ids usable with `get`."""
    embs = np.array([[0.1, 0.2], [0.2, 0.3]])
    lbls = np.array([1, 2])
    data = np.array([[0, 0, 0], [1, 1, 1]])

    kv_store = MemoryStore()
    idxs = kv_store.batch_add(embs, lbls, data)

    for idx in idxs:
        stored_emb, stored_lbl, stored_data = kv_store.get(idx)
        assert np.array_equal(stored_emb, embs[idx])
        assert np.array_equal(stored_lbl, lbls[idx])
        assert np.array_equal(stored_data, data[idx])
63 |
64 |
def test_save_and_reload(tmp_path):
    """A store saved to disk reloads with the same size and contents."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    kv_store, idxs = build_store(records)
    kv_store.save(str(tmp_path))

    # Load into a brand new store so nothing is shared with the original.
    reloaded_store = MemoryStore()
    reloaded_store.load(tmp_path)

    assert reloaded_store.size() == 2

    # Every record comes back with its embedding, label and data.
    for idx in idxs:
        want_emb, want_lbl, want_dt = records[idx][0], records[idx][1], records[idx][2]
        emb, lbl, dt = reloaded_store.get(idx)
        assert np.array_equal(emb, want_emb)
        assert np.array_equal(lbl, want_lbl)
        assert np.array_equal(dt, want_dt)
83 |
--------------------------------------------------------------------------------
/benchmark/supervised/create_datasets.py:
--------------------------------------------------------------------------------
1 | "Supervised loss benchmark"
2 | from __future__ import annotations
3 |
4 | import argparse
5 | import json
6 | import os
7 | import re
8 | from collections.abc import Mapping
9 | from typing import Any
10 |
11 | import tensorflow as tf
12 | import tensorflow.random
13 | from components import datasets, make_augmentations, utils
14 | from termcolor import cprint
15 |
16 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
17 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
18 |
19 |
def run(cfg: Mapping[str, Any], filter_pattern: str) -> None:
    """Build and serialize every configured dataset whose name matches a filter.

    The matching datasets are listed and the user is asked to confirm before
    anything is written, because each dataset's output directory is cleared
    first.

    Args:
        cfg: Benchmark configuration. Reads the "version", "random_seed",
            "dataset_dir", "preprocess" and "datasets" keys, plus the optional
            "tfds_data_dir" key, which is exported as TFDS_DATA_DIR.
        filter_pattern: Regular expression matched (with `re.match`) against
            the dataset names; only matching datasets are built.
    """
    if cfg.get("tfds_data_dir", None):
        os.environ["TFDS_DATA_DIR"] = cfg["tfds_data_dir"]

    version = cfg["version"]
    random_seed = cfg["random_seed"]
    output_dir = cfg["dataset_dir"]
    preproc_fns = make_augmentations(cfg["preprocess"])

    p = re.compile(filter_pattern)
    data_dir = os.path.join(output_dir, version)
    # A distinct loop name avoids shadowing the `cfg` parameter.
    dataset_cfgs = [
        datasets.utils.make_dataset_config(name, dataset_params)
        for name, dataset_params in cfg["datasets"].items()
        if p.match(name)
    ]

    for d in dataset_cfgs:
        cprint(f"|-{d.name}", "blue")

    cprint(f"{len(dataset_cfgs)} Datasets\n", "blue")
    # Building clears existing files, so only an explicit "y"/"yes" proceeds;
    # the prompt therefore shows "N" as the default answer.
    answer = input("Would you like to continue: [y/N] ").strip().lower()
    if answer not in ("y", "yes"):
        cprint("Exit", "red")
        return

    cprint("Building datasets", "green")

    for dcfg in dataset_cfgs:
        d = datasets.utils.make_dataset(dcfg, data_dir)
        # Re-seed before each dataset.
        utils.set_random_seeds(random_seed)

        # Make result path
        cprint(f"\n|-Clearing all files in {d.path}", "blue")
        utils.clean_dir(d.path)

        # Load the raw dataset
        cprint(f"\n|-Loading and preprocessing {d.name}\n", "blue")
        d.load_raw_data(preproc_fns)
        d.split_raw_data()
        d.save_serialized_data()
59 |
60 |
if __name__ == "__main__":
    # Command line entry point: load the JSON config and build the datasets.
    parser = argparse.ArgumentParser(description="Generate datasets")
    parser.add_argument("--config", "-c", help="config path")
    parser.add_argument("--filter", "-f", help="run only the config ids that match the regexp", default=".*")
    args = parser.parse_args()

    if not args.config:
        parser.print_usage()
        # `quit()` is an interactive helper injected by `site`; raising
        # SystemExit ends the script with the same (zero) exit status.
        raise SystemExit

    # Close the config file deterministically instead of leaking the handle.
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)
    run(config, filter_pattern=args.filter)
73 |
--------------------------------------------------------------------------------
/api/TFSimilarity/visualization/projector.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.visualization.projector
2 |
3 |
4 |
5 |
6 |
7 | Visualize the embeddings in 2D or 3D using UMAP projection
8 |
9 | ```python
10 | TFSimilarity.visualization.projector(
11 | labels: Optional[Sequence[Any]] = None,
12 | class_mapping: Optional[Sequence[int]] = None,
13 | image_size: int = 64,
14 | tooltips_info: Optional[Mapping[str, Sequence[str]]] = None,
15 | pt_size: int = 3,
16 | colorize: bool = True,
17 | pastel_factor: float = 0.1,
18 | plot_size: int = 600,
19 | active_drag: str = box_zoom,
20 | densmap: bool = True
21 | )
22 | ```
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | Args |
33 |
34 |
35 | |
36 | embeddings
37 | |
38 |
39 | The embeddings output by the model that
40 | are to be visualized
41 | |
42 |
43 | |
44 | labels
45 | |
46 |
47 | Labels associated with the embeddings. If not supplied treat
48 | each example as its own class.
49 | |
50 |
51 | |
52 | class_mapping
53 | |
54 |
55 | Dictionary or list that maps the class numerical ids
56 | to their name.
57 | |
58 |
59 | |
60 | images
61 | |
62 |
63 | Images to display in tooltip on hover. Usually x_test tensor.
64 | |
65 |
66 | |
67 | image_size
68 | |
69 |
70 | size of the images displayed in the tool tip.
71 | Defaults to 64.
72 | |
73 |
74 | |
75 | pt_size
76 | |
77 |
78 | Size of the points displayed on the visualization.
79 | Defaults to 3.
80 | |
81 |
82 | |
83 | tooltips_info
84 | |
85 |
86 | Dictionary of information to display in the tooltips.
87 | |
88 |
89 | |
90 | colorize
91 | |
92 |
93 | Colorize the clusters. Defaults to true.
94 | |
95 |
96 | |
97 | pastel_factor
98 | |
99 |
100 | Modify the color palette to be more pastel.
101 | |
102 |
103 | |
104 | densmap
105 | |
106 |
107 | Use UMAP dense mapper which provides better density
108 | estimation but is a little slower. Defaults to True.
109 | |
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
/benchmark/supervised/components/experiments.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import os
5 | from collections.abc import Mapping
6 | from itertools import product
7 | from typing import Any
8 |
9 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule
10 |
11 | from . import datasets, utils
12 |
13 |
@dataclasses.dataclass(frozen=True, eq=True)
class Component:
    """Immutable description of one configurable piece of an experiment."""

    # Identifier of the component implementation.
    cid: str
    # Name given to this particular configuration.
    name: str
    # Parameters of the component; they take part in equality but not in
    # hashing.
    params: Mapping[str, Any]

    def __hash__(self):
        # Hash on the two string fields only, leaving `params` out.
        identity = (self.cid, self.name)
        return hash(identity)
22 |
23 |
@dataclasses.dataclass
class Experiment:
    """One benchmark run: a dataset, architecture, loss, optimizer and training setup."""

    # Name of the run group this experiment belongs to.
    run_grp: str
    # Path associated with this experiment.
    path: str

    # Components that make up the experiment.
    dataset: Component
    architecture: Component
    loss: Component
    opt: Component
    training: Component

    # Optional learning rate schedule; None until one is assigned.
    lr_schedule: LearningRateSchedule | None = None
34 |
35 |
def make_experiments(cfg: Mapping[str, Any], output_dir: str) -> list[Experiment]:
    """Expand a benchmark config into the full list of experiments to run.

    One experiment is created for every combination of dataset, architecture,
    loss, optimizer and training setup, and for every embedding size listed
    by the architecture.

    Args:
        cfg: Benchmark configuration. Reads the "datasets", "architectures",
            "losses" and "optimizer" mappings and the "training" sequence.
        output_dir: Directory under which each experiment gets a
            sub-directory named after its run group.

    Returns:
        The experiments, in cross-product order.
    """
    experiments = []

    # Generate the cross product of all the experiment params.
    for (dn, dcfg), (an, acfg), (ln, lcfg), (on, ocfg), tcfg in product(
        cfg["datasets"].items(),
        cfg["architectures"].items(),
        cfg["losses"].items(),
        cfg["optimizer"].items(),
        cfg["training"],
    ):
        dataset = datasets.utils.make_dataset_config(name=dn, params=dcfg)
        loss = Component(cid=lcfg["component"], name=ln, params=lcfg)
        opt = Component(cid=ocfg["component"], name=on, params=ocfg)
        training = Component(cid="", name=tcfg["name"], params=tcfg)

        # Architectures that list no embedding sizes use a single size of 128.
        for embedding_size in acfg.get("embedding_sizes", [128]):
            # Give each experiment its own copy of the architecture params.
            # Writing "embedding" into the shared config dict would make every
            # component alias the same dict and report the last size.
            arch_params = {**acfg, "embedding": embedding_size}
            architecture = Component(cid=arch_params["component"], name=an, params=arch_params)

            run_grp = utils.make_run_grp(
                dataset.name,
                architecture.name,
                architecture.params["embedding"],
                loss.name,
                opt.name,
            )
            experiments.append(
                Experiment(
                    run_grp=run_grp,
                    path=os.path.join(output_dir, run_grp),
                    dataset=dataset,
                    architecture=architecture,
                    loss=loss,
                    opt=opt,
                    training=training,
                )
            )
    return experiments
75 |
--------------------------------------------------------------------------------
/tensorflow_similarity/samplers/file_samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING, TypeVar
17 |
18 | import tensorflow as tf
19 |
20 | from .memory_samplers import MultiShotMemorySampler
21 |
22 | if TYPE_CHECKING:
23 | from collections.abc import Callable, Sequence
24 |
25 | from ..types import FloatTensor, IntTensor
26 | from .samplers import Augmenter
27 |
28 | T = TypeVar("T", FloatTensor, IntTensor)
29 |
30 |
def load_image(path: str, target_size: tuple[int, int] | None = None) -> T:
    """Read a JPEG file into a 3-channel float32 image tensor.

    Args:
        path: Path of the image file to read.
        target_size: Optional size passed to `tf.image.resize`. When given,
            the image is resized with Lanczos3 and clipped to [0, 1].

    Returns:
        The decoded (and optionally resized) image.
    """
    raw_bytes = tf.io.read_file(path)
    decoded: T = tf.image.decode_jpeg(raw_bytes, channels=3)
    image = tf.image.convert_image_dtype(decoded, tf.float32)
    if not target_size:
        return image

    resized = tf.image.resize(image, target_size, method=tf.image.ResizeMethod.LANCZOS3)
    # NOTE(review): presumably the clip guards against Lanczos resampling
    # leaving the [0, 1] range - confirm.
    return tf.clip_by_value(resized, 0.0, 1.0)
39 |
40 |
class MultiShotFileSampler(MultiShotMemorySampler):
    """`MultiShotMemorySampler` whose examples are loaded with `load_image` by default."""

    def __init__(
        self,
        x,
        y,
        load_example_fn: Callable = load_image,
        classes_per_batch: int = 2,
        examples_per_class_per_batch: int = 2,
        steps_per_epoch: int = 1000,
        class_list: Sequence[int] | None = None,
        total_examples_per_class: int | None = None,
        augmenter: Augmenter | None = None,
        warmup: int = -1,
    ):
        """Forward every argument unchanged to `MultiShotMemorySampler`.

        The only difference with the parent constructor visible here is that
        `load_example_fn` defaults to `load_image`.
        """
        # Collect the keyword options first so the parent call stays short.
        sampler_options = dict(
            load_example_fn=load_example_fn,
            classes_per_batch=classes_per_batch,
            examples_per_class_per_batch=examples_per_class_per_batch,
            steps_per_epoch=steps_per_epoch,
            class_list=class_list,
            total_examples_per_class=total_examples_per_class,
            augmenter=augmenter,
            warmup=warmup,
        )
        super().__init__(x, y, **sampler_options)
67 |
--------------------------------------------------------------------------------
/tensorflow_similarity/distances/snr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Signal-to-noise ratio (SNR) distance computation functions for embeddings."""
15 | from __future__ import annotations
16 |
17 | from typing import TYPE_CHECKING
18 |
19 | import tensorflow as tf
20 |
21 | if TYPE_CHECKING:
22 | from ..types import FloatTensor
23 |
24 | from .distance import Distance
25 |
26 |
@tf.keras.utils.register_keras_serializable(package="Similarity")
class SNRDistance(Distance):
    """
    Computes pairwise SNR distances between embeddings.

    The [Signal-to-Noise Ratio distance](https://arxiv.org/abs/1904.02616)
    is the ratio of noise variance to the feature variance.
    """

    def __init__(self, name: str = "snr", **kwargs):
        "Init SNR distance"
        super().__init__(name=name, **kwargs)

    @tf.function
    def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor:
        """Compute pairwise snr distances for a given batch of embeddings.
        SNR(i, j): anchor i and compared feature j
        SNR(i,j) may not be equal to SNR(j, i)

        Args:
            query_embeddings: Anchor embeddings. The first axis indexes the
                examples; all remaining axes are flattened.
            key_embeddings: Embeddings compared against each anchor. The
                first axis indexes the examples; all remaining axes are
                flattened.

        Returns:
            FloatTensor: Pairwise distance tensor with one row per query and
            one column per key. An anchor with zero feature variance makes
            the division in its row divide by zero.
        """
        q_rs = tf.reshape(query_embeddings, shape=[tf.shape(query_embeddings)[0], -1])
        k_rs = tf.reshape(key_embeddings, shape=[tf.shape(key_embeddings)[0], -1])

        # Calculating feature variance for each example. The flattened
        # queries are used so the variance covers every feature of an anchor,
        # matching how the noise variance below is computed.
        anchor_var = tf.math.reduce_variance(q_rs, axis=1)

        # Calculating pairwise noise variances
        delta = tf.expand_dims(q_rs, axis=1) - tf.expand_dims(k_rs, axis=0)
        noise_var = tf.math.reduce_variance(delta, axis=2)

        # Divide each row of noise variances by its anchor's feature variance.
        distances: FloatTensor = tf.divide(noise_var, tf.expand_dims(anchor_var, axis=1))

        return distances
64 |
--------------------------------------------------------------------------------
/api/TFSimilarity/architectures/ResNet18Sim.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.architectures.ResNet18Sim
2 |
3 |
4 |
5 |
6 |
7 | Build a ResNet18 model backbone for similarity learning.
8 |
9 | ```python
10 | TFSimilarity.architectures.ResNet18Sim(
11 | input_shape: Tuple[int, int, int],
12 | embedding_size: int = 128,
13 | l2_norm: bool = True,
14 | include_top: bool = True,
15 | pooling: str = gem,
16 | gem_p=3.0)
17 | ```
18 |
19 |
20 |
21 |
22 |
23 | Architecture from [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
24 |
25 |
26 |
27 |
28 | Args |
29 |
30 |
31 | |
32 | input_shape
33 | |
34 |
35 | Expected to be between 32 and 224 and in the (H, W, C)
36 | data_format.
37 | |
38 |
39 | |
40 | embedding_size
41 | |
42 |
43 | Size of the output embedding. Usually between 64
44 | and 512. Defaults to 128.
45 | |
46 |
47 | |
48 | l2_norm
49 | |
50 |
51 | If True and include_top is also True, then
52 | tfsim.layers.MetricEmbedding is used as the last layer, otherwise
53 | keras.layers.Dense is used. This should be true when using cosine
54 | distance. Defaults to True.
55 | |
56 |
57 | |
58 | include_top
59 | |
60 |
61 | Whether to include a fully-connected layer of
62 | embedding_size at the top of the network. Defaults to True.
63 | |
64 |
65 | |
66 | pooling
67 | |
68 |
69 | Optional pooling mode for feature extraction when
70 | include_top is False. Defaults to gem.
71 | - None means that the output of the model will be the 4D tensor
72 | output of the last convolutional layer.
73 | - avg means that global average pooling will be applied to the
74 | output of the last convolutional layer, and thus the output of the
75 | model will be a 2D tensor.
76 | - max means that global max pooling will be applied.
77 | - gem means that global GeneralizedMeanPooling2D will be applied.
78 | The gem_p param sets the contrast amount on the pooling.
79 | |
80 |
81 | |
82 | gem_p
83 | |
84 |
85 | Sets the power in the GeneralizedMeanPooling2D layer. A value
86 | of 1.0 is equivalent to GlobalMeanPooling2D, while larger values
87 | will increase the contrast between activations within each feature
88 | map, and a value of math.inf will be equivalent to MaxPool2d.
89 | |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/tensorflow_similarity/architectures/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | import tensorflow as tf
17 |
18 |
def convert_sync_batchnorm(model: tf.keras.Model) -> tf.keras.Model:
    """Rebuild `model` with synchronized BatchNormalization layers.

    Walks `model.layers` in order, re-applies every layer on a fresh input
    and swaps each `BatchNormalization` layer for its synchronized
    counterpart. The weights of `model` are then copied to the new model.

    WARNINGS:
    * This function is tested only with efficientnet and resnet
    * The returned model has shared layers with the input one. One of them should be disposed.

    Args:
        model: Model to convert. Every layer after the input must have
            exactly one inbound node.

    Returns:
        A new model with the same name and weights as `model` in which the
        BatchNormalization layers are synchronized.
    """
    # Only major and minor are needed. Parsing just those keeps pre-release
    # versions such as "2.13.0rc1" from raising in int().
    major, minor = (int(v) for v in tf.__version__.split(".")[:2])
    use_experimental_sync_bn = major == 2 and minor < 12

    # Maps a layer name in the original model to its output tensor in the
    # graph being rebuilt.
    layer2newtensor = {}
    in_ = tf.keras.layers.Input(model.input.shape[1:])
    layer2newtensor[model.input.name] = in_
    for layer in model.layers[1:]:
        assert len(layer.inbound_nodes) == 1
        inbound_layers = layer.inbound_nodes[0].inbound_layers
        if isinstance(inbound_layers, list):  # multiple inputs
            x = [layer2newtensor[in_l.name] for in_l in inbound_layers]
        else:
            x = layer2newtensor[inbound_layers.name]

        if isinstance(layer, tf.keras.layers.BatchNormalization):
            config = layer.get_config()
            if use_experimental_sync_bn:
                layer = tf.keras.layers.experimental.SyncBatchNormalization(**config)
            else:
                # Set the flag through the dict instead of an extra keyword so
                # a config that already carries a `synchronized` entry cannot
                # trigger a duplicate-keyword TypeError.
                config["synchronized"] = True
                layer = tf.keras.layers.BatchNormalization(**config)

        if "truediv" in layer.name:
            # efficientnet edge case
            # https://github.com/keras-team/keras/blob/v2.9.0/keras/applications/efficientnet.py#L334
            x = layer(x, layer.inbound_nodes[0]._flat_arguments[1])
        else:
            x = layer(x)

        layer2newtensor[layer.name] = x

    out_ = layer2newtensor[model.layers[-1].name]
    new_model = tf.keras.Model(inputs=in_, outputs=out_, name=model.name)
    new_model.set_weights(model.get_weights())
    return new_model
54 |
--------------------------------------------------------------------------------
/.github/workflows/nightly-publish.yml:
--------------------------------------------------------------------------------
1 | name: NightlyPublish
2 |
3 | on:
4 | workflow_dispatch: # Allow manual triggers
5 | schedule:
6 | # Runs every day at 3:07am UTC.
7 | - cron: '7 3 * * *'
8 |
9 | jobs:
10 | check:
11 |
12 | name: Check files
13 | # Prevent Publish from running on forks.
14 | if: github.repository == 'tensorflow/similarity'
15 | outputs:
16 | run_job: ${{ steps.check_files.outputs.run_job }}
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | with:
22 | fetch-depth: 2
23 | ref: 'development'
24 |
25 | - name: check modified files
26 | id: check_files
27 | run: |
28 | echo "=============== list modified files ==============="
29 | git diff --name-only HEAD^ HEAD
30 |
31 | echo "========== check paths of modified files =========="
32 | git diff --name-only HEAD^ HEAD > files.txt
33 | while IFS= read -r file
34 | do
35 | echo $file
36 | if [[ $file = tensorflow_similarity/* && $file != tensorflow_similarity/__init__.py ]]; then
37 | echo "This modified file is under the 'tensorflow_similarity' folder."
38 | echo "::set-output name=run_job::true"
39 | break
40 | else
41 | echo "::set-output name=run_job::false"
42 | fi
43 | done < files.txt
44 |
  # Builds the package from the 'development' branch and uploads it as the
  # 'tfsim-nightly' project. Runs only when the 'check' job reported
  # run_job == 'true'.
  publish:

    name: Publish nightly
    needs: check
    # Prevent Publish from running on forks.
    if: |
      github.repository == 'tensorflow/similarity' &&
      needs.check.outputs.run_job == 'true'
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
        with:
          ref: 'development'

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip

      # Installs the package together with its 'tensorflow' and 'dev' extras.
      - name: Install package
        run: |
          pip install ".[tensorflow,dev]"

      - name: Increment dev version
        run: |
          # Increments the dev version and pushes the changes to development.
          python scripts/increment_version.py

      - name: Build package
        run: |
          python setup.py sdist bdist_wheel --project_name tfsim-nightly

      # NOTE(review): the credentials are passed as command-line arguments;
      # consider the TWINE_USERNAME / TWINE_PASSWORD environment variables.
      - name: Publish package
        run: |
          twine upload -u ${{ secrets.PYPI_NIGHTLY_USERNAME }} -p ${{ secrets.PYPI_NIGHTLY_TOKEN }} dist/* --verbose
85 |
--------------------------------------------------------------------------------
/tensorflow_similarity/distances/distance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Vectorized embedding pairwise distances computation functions"""
15 | from __future__ import annotations
16 |
17 | from abc import ABC, abstractmethod
18 | from typing import TYPE_CHECKING, Any
19 |
20 | if TYPE_CHECKING:
21 | from ..types import FloatTensor
22 |
23 |
class Distance(ABC):
    """Abstract class for distance computation.

    Subclasses implement `call()`; instances are callable and delegate to it.
    """

    def __init__(self, name: str, **kwargs):
        """Store the distance name.

        Args:
            name: Name of the distance, returned by `str()` and stored in the
                config.
            **kwargs: Accepted for subclass compatibility; not used here.
        """
        self.name = name

    @abstractmethod
    def call(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor) -> FloatTensor:
        """Compute pairwise distances for a given batch.

        Args:
            query_embeddings: Embeddings to compute the pairwise one.
            key_embeddings: Embeddings to compute the pairwise one.

        Returns:
            FloatTensor: Pairwise distance tensor.
        """

    def __call__(self, query_embeddings: FloatTensor, key_embeddings: FloatTensor):
        """Call self as a function; delegates to `call()`."""
        return self.call(query_embeddings, key_embeddings)

    def __str__(self) -> str:
        return self.name

    def get_config(self) -> dict[str, Any]:
        """Contains the distance configuration.

        Returns:
            A Python dict containing the configuration of the distance obj.
        """
        config = {"name": self.name}

        return config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> Distance:
        """Build a distance from a config.

        Args:
            config: A Python dict containing the configuration of the distance.

        Returns:
            A distance instance.

        Raises:
            TypeError: If the class cannot be instantiated from `config`. The
                original exception is attached as the cause.
        """
        try:
            return cls(**config)
        except Exception as e:
            # Any construction failure is reported as a TypeError so callers
            # have a single exception type to handle.
            raise TypeError(
                f"Error when deserializing '{cls.__name__}' using config={config}.\n\nException encountered: {e}"
            ) from e
74 |
--------------------------------------------------------------------------------
/tests/losses/test_softnn_loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from absl.testing import parameterized
3 | from tensorflow.keras.losses import Reduction
4 | from tensorflow.python.framework import combinations
5 |
6 | from tensorflow_similarity import losses
7 |
8 |
@tf.function
def softnn_util(y_true, x, temperature=1):
    """
    A simple loop based implementation of soft
    nearest neighbor loss to test the code.
    https://arxiv.org/pdf/1902.01889.pdf

    Args:
        y_true: Labels; `y_true[i]` is compared with `y_true[j]` to decide
            whether examples i and j belong to the same class.
        x: Embeddings, indexed by example along the first axis.
        temperature: Divisor applied to the squared distance before the
            exponential. Defaults to 1.

    Returns:
        The negated sum of the per-example log ratios divided by the batch
        size.
    """

    batch_size = tf.shape(y_true)[0]
    loss = 0.0
    for i in tf.range(batch_size, dtype=tf.int32):
        # numerator: sum over same-label neighbors of i.
        # denominator: sum over all neighbors of i (excluding i itself).
        numerator = 0.0
        denominator = 0.0
        for j in tf.range(batch_size, dtype=tf.int32):
            if i == j:
                continue
            if y_true[i] == y_true[j]:
                numerator += tf.math.exp(-1 * tf.math.reduce_sum(tf.math.square(x[i] - x[j])) / temperature)
            denominator += tf.math.exp(-1 * tf.math.reduce_sum(tf.math.square(x[i] - x[j])) / temperature)
        # An example whose numerator stayed at zero contributes nothing;
        # skipping it also avoids log(0).
        if numerator == 0.0:
            continue
        loss += tf.math.log(numerator / denominator)
    return -loss / tf.cast(batch_size, tf.float32)
32 |
33 |
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
    )
)
class SoftNNLossTest(tf.test.TestCase, parameterized.TestCase):
    """Tests for `losses.SoftNearestNeighborLoss`, run in graph and eager mode."""

    def test_config(self):
        """Constructor arguments are exposed on the loss object."""
        loss_fn = losses.SoftNearestNeighborLoss(
            reduction=Reduction.SUM,
            name="softnn_loss",
            distance="cosine",
        )
        self.assertEqual(loss_fn.distance.name, "cosine")
        self.assertEqual(loss_fn.name, "softnn_loss")
        self.assertEqual(loss_fn.reduction, Reduction.SUM)

    @parameterized.parameters(0.1, 0.5, 1, 2, 5, 10, 50)
    def test_all_correct(self, temperature):
        """The loss is compared with the loop-based reference implementation."""
        num_inputs = 10
        n_classes = 10

        loss_fn = losses.SoftNearestNeighborLoss(
            reduction=Reduction.SUM_OVER_BATCH_SIZE,
            temperature=temperature,
        )

        # Random integer labels and random embeddings in [0, 1).
        labels = tf.random.uniform((num_inputs, 1), 0, n_classes, dtype=tf.int32)
        embeddings = tf.random.uniform((num_inputs, 20), 0, 1)

        actual = loss_fn(labels, embeddings)
        reference = softnn_util(labels, embeddings, temperature)

        # NOTE(review): this check is one-sided; it passes whenever the loss
        # is not larger than the reference by 1e-3 or more.
        self.assertLess(self.evaluate(actual - reference), 1e-3)
74 |
--------------------------------------------------------------------------------
/tests/models/test_similarity_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import tensorflow as tf
15 |
16 | from tensorflow_similarity.losses import TripletLoss
17 | from tensorflow_similarity.models import SimilarityModel
18 |
19 |
class SimilarityModelTest(tf.test.TestCase):
    """Save/reload and indexing tests for `SimilarityModel`."""

    def _make_model(self, compile_model):
        """Return a 3-input, 2-output dense SimilarityModel, optionally compiled."""
        inputs = tf.keras.layers.Input(shape=(3,))
        outputs = tf.keras.layers.Dense(2)(inputs)
        model = SimilarityModel(inputs, outputs)
        if compile_model:
            model.compile(optimizer="adam", loss=TripletLoss())
        return model

    # TODO(ovallis): Add tests for graph mode.
    def test_save_and_reload(self):
        """A saved model and its index can be loaded back."""
        out_dir = self.get_temp_dir()
        model = self._make_model(compile_model=True)

        # Index two examples with their labels.
        examples = tf.constant([[1, 1, 3], [3, 1, 2]], dtype="float32")
        labels = tf.constant([1, 2])
        model.index(examples, labels)

        model.save(out_dir)

        # Reload the model, then the index, from the same directory.
        restored = tf.keras.models.load_model(out_dir)
        restored.load_index(out_dir)
        self.assertEqual(restored._index.size(), len(labels))

    def test_save_no_compile(self):
        """A model that was never compiled can still be saved and reloaded."""
        out_dir = self.get_temp_dir()
        model = self._make_model(compile_model=False)

        model.save(out_dir)
        restored = tf.keras.models.load_model(out_dir)
        self.assertIsInstance(restored, type(model))

    def test_index_single(self):
        """Unit Test for issues #161 & #162"""
        model = self._make_model(compile_model=True)

        # A single, un-batched example and its label.
        example = tf.constant([1, 1, 3], dtype="float32")
        label = tf.constant([1])

        # Index the individual sample and check it was stored.
        model.index_single(example, label, data=example)
        self.assertEqual(model._index.size(), 1)
68 |
69 |
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
    tf.test.main()
72 |
--------------------------------------------------------------------------------
/tests/retrieval_metrics/test_map_at_k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from tensorflow_similarity.retrieval_metrics import MapAtK
5 |
6 |
def test_concrete_instance():
    """Check the attributes and config of a MapAtK built with default arguments."""
    metric = MapAtK(r={1: 4, 0: 4})
    assert metric.name == "map@5"

    # Once a distance threshold is set, it is appended to the name.
    metric.distance_threshold = 0.1
    threshold_name = "map@5 : distance_threshold@0.1"
    assert metric.name == threshold_name
    assert repr(metric) == "map@k : map@5 : distance_threshold@0.1"
    assert metric.canonical_name == "map@k"
    assert metric.k == 5
    assert metric.distance_threshold == 0.1
    assert metric.average == "micro"

    assert metric.get_config() == {
        "r": {1: 4, 0: 4},
        "name": threshold_name,
        "canonical_name": "map@k",
        "clip_at_r": False,
        "k": 5,
        "distance_threshold": 0.1,
    }
29 |
30 |
def test_compute():
    """MapAtK.compute on three class-1 queries and one class-0 query."""
    labels = tf.constant([1, 1, 1, 0])
    class_one_hits = [True, True, False]
    class_zero_hits = [False, False, True]
    mask = tf.constant(
        [class_one_hits, class_one_hits, class_one_hits, class_zero_hits],
        dtype=bool,
    )
    metric = MapAtK(r={0: 10, 1: 3}, k=3)

    result = metric.compute(query_labels=labels, match_mask=mask)

    # map@k is sum(precision@k * relevancy_mask) / R for each query.
    # Class 1 (R == 3) has three result sets, all T,T,F:
    #   (1.0*True + 1.0*True + 0.667*False) / 3 = 0.66667
    # Class 0 (R == 10) has one result set, F,F,T:
    #   (0.0*False + 0.0*False + 0.333*True) / 10 = 0.03333
    # Averaged over the four queries: (0.66667*3 + 0.03333) / 4
    expected = tf.constant(0.50833333332)

    np.testing.assert_allclose(result, expected)
55 |
56 |
def test_clip_at_r():
    """MapAtK.compute with clip_at_r=True on one query per class."""
    labels = tf.constant([0, 1])
    hits = [False, False, False, True]
    mask = tf.constant([hits, hits], dtype=bool)
    metric = MapAtK(r={0: 4, 1: 3}, k=4, clip_at_r=True)

    result = metric.compute(query_labels=labels, match_mask=mask)

    # map@k is sum(precision@k * relevancy_mask) / R, but here each result
    # set is clipped at the R associated with the query label.
    # Class 0 has one result set F,F,F,T with R == 4 (all four results kept):
    #   (0.0*False + 0.0*False + 0.0*False + 0.25*True) / 4 = 0.0625
    # Class 1 has one result set F,F,F,T with R == 3 (the match is clipped):
    #   (0.0*False + 0.0*False + 0.0*False) / 3 = 0.0
    # Averaged over the two queries: (0.0625 + 0.0) / 2
    expected = tf.constant(0.03125)

    np.testing.assert_allclose(result, expected)
80 |
--------------------------------------------------------------------------------
/tests/stores/test_cached.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | from tensorflow_similarity.stores.cached import CachedStore
6 |
7 |
def build_store(records, path):
    """Create a CachedStore at `path` and add every record to it.

    Args:
        records: Sequence of records; items 0, 1 and 2 of each record are
            passed positionally to `CachedStore.add`.
        path: Path handed to the `CachedStore` constructor.

    Returns:
        A `(store, idxs)` tuple where `idxs` holds the value returned by
        `add` for each record, in order.
    """
    store = CachedStore(path=path)
    added = [store.add(rec[0], rec[1], rec[2]) for rec in records]
    return store, added
15 |
16 |
def test_cached_store_and_retrieve(tmp_path):
    """Records added one by one get sequential ids and can be read back."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    store, idxs = build_store(records, tmp_path)

    # Ids are integers numbered from zero in insertion order.
    for position, idx in enumerate(idxs):
        assert isinstance(idx, int)
        assert position == idx

    # The store reports one entry per added record.
    assert store.size() == 2

    # Each stored record is returned unchanged.
    for idx in idxs:
        embedding, label, data = store.get(idx)
        expected_embedding, expected_label, expected_data = records[idx]
        assert embedding == expected_embedding
        assert label == expected_label
        assert data == expected_data
36 |
37 |
def test_reset(tmp_path):
    """reset() empties the store, which can then be filled again."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    store, _ = build_store(records, tmp_path)
    assert store.size() == 2

    store.reset()
    assert store.size() == 0

    # Adding after a reset starts counting from an empty store.
    first = records[0]
    store.add(first[0], first[1], first[2])
    assert store.size() == 1
51 |
52 |
def test_batch_add(tmp_path):
    """batch_add() stores every row and returns ids usable with get()."""
    embeddings = np.array([[0.1, 0.2], [0.2, 0.3]])
    labels = np.array([1, 2])
    payloads = np.array([[0, 0, 0], [1, 1, 1]])

    store = CachedStore(path=tmp_path)
    for idx in store.batch_add(embeddings, labels, payloads):
        embedding, label, payload = store.get(idx)
        assert np.array_equal(embedding, embeddings[idx])
        assert np.array_equal(label, labels[idx])
        assert np.array_equal(payload, payloads[idx])
65 |
66 |
def test_save_and_reload(tmp_path):
    """A store saved to disk can be loaded into a fresh CachedStore."""
    records = [[[0.1, 0.2], 1, [0, 0, 0]], [[0.2, 0.3], 2, [0, 0, 0]]]

    # Separate directories for the saved copy and the live store.
    save_path = tmp_path / "save"
    os.mkdir(save_path)
    obj_path = tmp_path / "obj"
    os.mkdir(obj_path)

    original, idxs = build_store(records, obj_path)
    original.save(save_path)

    # Load the saved data into a new store object.
    restored = CachedStore()
    print(f"loading from {save_path}")
    restored.load(save_path)

    assert restored.size() == 2

    # Both records come back with their original content.
    for idx in idxs:
        embedding, label, data = restored.get(idx)
        expected = records[idx]
        assert np.array_equal(embedding, expected[0])
        assert np.array_equal(label, expected[1])
        assert np.array_equal(data, expected[2])
91 |
--------------------------------------------------------------------------------
/api/TFSimilarity/distances/Distance.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.distances.Distance
2 |
3 |
4 |
5 |
6 |
7 | Note: don't forget to add your distance to the DISTANCES list
8 |
9 | Inherits From: [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 |
12 | ```python
13 | TFSimilarity.distances.Distance(
14 | name: str, aliases: List[str] = []
15 | )
16 | ```
17 |
18 |
19 |
20 |
21 | and add alias names in it.
22 |
23 | ## Methods
24 |
25 | call
26 |
27 | View source
28 |
29 | ```python
30 | call(
31 | query_embeddings: TFSimilarity.callbacks.FloatTensor```
32 | ,
33 | key_embeddings: TFSimilarity.callbacks.FloatTensor```
34 |
35 | ) -> TFSimilarity.callbacks.FloatTensor```
36 |
37 | ```
38 |
39 |
40 | Compute pairwise distances for a given batch.
41 |
42 |
43 |
44 |
45 |
46 | | Args |
47 |
48 |
49 | |
50 | query_embeddings
51 | |
52 |
53 | Embeddings to compute the pairwise one.
54 | |
55 |
56 | |
57 | key_embeddings
58 | |
59 |
60 | Embeddings to compute the pairwise one.
61 | |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 | | Returns |
71 |
72 |
73 | |
74 | FloatTensor
75 | |
76 |
77 | Pairwise distance tensor.
78 | |
79 |
80 |
81 |
82 |
83 |
84 | get_config
85 |
86 | View source
87 |
88 | ```python
89 | get_config()
90 | ```
91 |
92 |
93 |
94 |
95 |
96 | __call__
97 |
98 | View source
99 |
100 | ```python
101 | __call__(
102 | query_embeddings: TFSimilarity.callbacks.FloatTensor```
103 | ,
104 | key_embeddings: TFSimilarity.callbacks.FloatTensor```
105 |
106 | )
107 | ```
108 |
109 |
110 | Call self as a function.
111 |
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
/api/TFSimilarity/callbacks/SplitValidationLoss.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.callbacks.SplitValidationLoss
2 |
3 |
4 |
5 |
6 |
7 | Creates the validation callbacks.
8 |
9 | ```python
10 | TFSimilarity.callbacks.SplitValidationLoss(
11 | query_labels: Sequence[int],
12 | target_labels: Sequence[int],
13 | distance: str = cosine,
14 | metrics: Sequence[Union[str, ClassificationMetric]] = [binary_accuracy, f1score],
15 | tb_logdir: str = None,
16 | k: int = 1,
17 | )
18 | ```
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | Args |
29 |
30 |
31 | |
32 | queries
33 | |
34 |
35 | Test examples that will be tested against the built index.
36 | |
37 |
38 | |
39 | query_labels
40 | |
41 |
42 | Queries nearest neighbors expected labels.
43 | |
44 |
45 | |
46 | targets
47 | |
48 |
49 | Examples that are indexed.
50 | |
51 |
52 | |
53 | target_labels
54 | |
55 |
56 | Target examples labels.
57 | |
58 |
59 | |
60 | known_classes
61 | |
62 |
63 | The set of classes seen during training.
64 | |
65 |
66 | |
67 | distance
68 | |
69 |
70 | Distance function used to compute pairwise distance
71 | between examples embeddings.
72 | |
73 |
74 | |
75 | metrics
76 | |
77 |
78 | List of
79 | `tf.similarity.classification_metrics.ClassificationMetric()` to
80 | compute during the evaluation. Defaults to ['binary_accuracy',
81 | 'f1score'].
82 | |
83 |
84 | |
85 | tb_logdir
86 | |
87 |
88 | Where to write TensorBoard logs. Defaults to None.
89 | |
90 |
91 | |
92 | k
93 | |
94 |
95 | The number of nearest neighbors to return for each query.
96 | The lookups are consumed by the Matching Strategy and used to
97 | derive the matching label and distance.
98 | |
99 |
100 | |
101 | matcher
102 | |
103 |
104 | 'match_nearest', 'match_majority_vote' or
105 | ClassificationMatch object. Defines the classification matching,
106 | e.g., match_nearest will count a True Positive if the query_label
107 | is equal to the label of the nearest neighbor and the distance is
108 | less than or equal to the distance threshold.
109 | |
110 |
111 | |
112 | distance_thresholds
113 | |
114 |
115 | A 1D tensor denoting the distance points at
116 | which we compute the metrics. If None, distance_thresholds is set
117 | to tf.constant([math.inf])
118 | |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | tmp/
2 | tmp.py
3 | tensorflow_similarity.egg-info/
4 | logs/
5 | examples/models/
6 | examples/logs/
7 | .vscode/
8 | site/
9 | release.sh
10 | .DS_Store
11 | benchmark/supervised/datasets/
12 | benchmark/supervised/models/
13 | datasets/
14 | multi_modal_datasets/
15 | *.h5
16 |
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 |
21 | *$py.class
22 | *.prof
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | *.py,cover
68 | .hypothesis/
69 | .pytest_cache/
70 | .pytype
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 | db.sqlite3-journal
81 |
82 | # Flask stuff:
83 | instance/
84 | .webassets-cache
85 |
86 | # Scrapy stuff:
87 | .scrapy
88 |
89 | # Sphinx documentation
90 | docs/_build/
91 |
92 | # PyBuilder
93 | target/
94 |
95 | # Jupyter Notebook
96 | .ipynb_checkpoints
97 |
98 | # IPython
99 | profile_default/
100 | ipython_config.py
101 |
102 | # pyenv
103 | .python-version
104 |
105 | # pipenv
106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
109 | # install all needed dependencies.
110 | #Pipfile.lock
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # Pipfile
150 | Pipfile
151 | Pipfile.lock
152 |
--------------------------------------------------------------------------------
/tensorflow_similarity/losses/metric_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Metric losses base class."""
16 | from __future__ import annotations
17 |
18 | from typing import TYPE_CHECKING, Any
19 |
20 | import tensorflow as tf
21 |
22 | if TYPE_CHECKING:
23 | from collections.abc import Callable
24 | from ..types import FloatTensor
25 |
26 | from ..utils import is_tensor_or_variable
27 |
28 |
class MetricLoss(tf.keras.losses.Loss):
    """Keras `Loss` that delegates the computation to a wrapped function."""

    def __init__(
        self, fn: Callable, reduction: Callable = tf.keras.losses.Reduction.AUTO, name: str | None = None, **kwargs
    ):
        """Store the wrapped function and its extra keyword arguments.

        Args:
            fn: The loss function to wrap. It is invoked as
                `fn(y_true, y_pred, y_true, y_pred, **kwargs)`.
            reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply
                to loss. Default value is `AUTO`.
            name: (Optional) name for the loss.
            **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true: FloatTensor, y_pred: FloatTensor) -> FloatTensor:
        """Evaluate the wrapped function on a batch.

        Args:
            y_true: Ground truth values.
            y_pred: The predicted values.

        Returns:
            Loss values per sample.
        """
        # The batch is passed twice: `fn` receives the same labels and
        # predictions for its first and second pair of arguments.
        return self.fn(y_true, y_pred, y_true, y_pred, **self._fn_kwargs)

    def get_config(self) -> dict[str, Any]:
        """Contains the loss configuration.

        Returns:
            A Python dict containing the base `Loss` config plus every keyword
            argument given at construction; tensors and variables are
            evaluated to their values.
        """
        config: dict[str, Any] = super().get_config()
        config.update(
            {
                key: tf.keras.backend.eval(value) if is_tensor_or_variable(value) else value
                for key, value in self._fn_kwargs.items()
            }
        )
        return config
73 |
--------------------------------------------------------------------------------
/tensorflow_similarity/classification_metrics/recall.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..types import FloatTensor
22 |
23 | from .classification_metric import ClassificationMetric
24 |
25 |
class Recall(ClassificationMetric):
    r"""Calculates the recall of the query classification.

    Computes the recall given the query classification counts.

    $$
    Recall = \frac{\textrm{true_positives}}{\textrm{true_positives} +
    \textrm{false_negatives}}
    $$

    args:
        name: Name associated with a specific metric object, e.g.,
        recall@0.1

    Usage with `tf.similarity.models.SimilarityModel()`:

    ```python
    model.calibrate(x=query_examples,
                    y=query_labels,
                    calibration_metric='recall')
    ```
    """

    # The class docstring is a raw string so that the LaTeX commands keep
    # their backslashes instead of being read as form-feed and tab escapes.

    def __init__(self, name: str = "recall") -> None:
        super().__init__(name=name, canonical_name="recall")

    def compute(self, tp: FloatTensor, fp: FloatTensor, tn: FloatTensor, fn: FloatTensor, count: int) -> FloatTensor:
        """Compute the classification metric.

        The `compute()` method supports computing the metric for a set of
        values, where each value represents the counts at a specific distance
        threshold.

        Args:
            tp: A 1D FloatTensor containing the count of True Positives at each
            distance threshold.

            fp: A 1D FloatTensor containing the count of False Positives at
            each distance threshold. Not used by this metric.

            tn: A 1D FloatTensor containing the count of True Negatives at each
            distance threshold. Not used by this metric.

            fn: A 1D FloatTensor containing the count of False Negatives at
            each distance threshold.

            count: The total number of queries. Not used by this metric.

        Returns:
            A 1D FloatTensor containing the metric at each distance threshold.
        """
        # divide_no_nan returns 0 where tp + fn == 0 instead of NaN.
        result: FloatTensor = tf.math.divide_no_nan(tp, tp + fn)
        return result
79 |
--------------------------------------------------------------------------------
/tensorflow_similarity/classification_metrics/negative_predictive_value.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..types import FloatTensor
22 |
23 | from .classification_metric import ClassificationMetric
24 |
25 |
class NegativePredictiveValue(ClassificationMetric):
    r"""Calculates the negative predictive value of the query classification.

    Computes the NPV given the query classification counts.

    $$
    NPV = \frac{\textrm{true_negatives}}{\textrm{false_negatives} +
    \textrm{true_negatives}}
    $$

    Args:
        name: Name associated with a specific metric object, e.g.
        npv@0.1

    Usage with `tf.similarity.models.SimilarityModel()`:

    ```python
    model.calibrate(x=query_examples,
                    y=query_labels,
                    calibration_metric='npv')
    ```
    """

    # NOTE: the class docstring is a raw string on purpose. In a regular
    # string the LaTeX commands `\frac` and `\textrm` would be read as the
    # escape sequences `\f` (form feed) and `\t` (tab) and corrupt `__doc__`.
    # NOTE(review): the usage example assumes 'npv' is a registered alias for
    # this metric (it is the default `name`) - confirm against the factory.

    def __init__(self, name: str = "npv") -> None:
        super().__init__(name=name, canonical_name="negative_predictive_value")

    def compute(self, tp: FloatTensor, fp: FloatTensor, tn: FloatTensor, fn: FloatTensor, count: int) -> FloatTensor:
        """Compute the classification metric.

        The `compute()` method supports computing the metric for a set of
        values, where each value represents the counts at a specific distance
        threshold.

        Only `tn` and `fn` are used by this metric; the remaining arguments
        are accepted so that the signature matches the other metrics.

        Args:
            tp: A 1D FloatTensor containing the count of True Positives at each
            distance threshold.

            fp: A 1D FloatTensor containing the count of False Positives at
            each distance threshold.

            tn: A 1D FloatTensor containing the count of True Negatives at each
            distance threshold.

            fn: A 1D FloatTensor containing the count of False Negatives at
            each distance threshold.

            count: The total number of queries

        Returns:
            A 1D FloatTensor containing the metric at each distance threshold.
        """
        # divide_no_nan avoids NaN when tn + fn is zero at a threshold.
        result: FloatTensor = tf.math.divide_no_nan(tn, tn + fn)
        return result
79 |
--------------------------------------------------------------------------------
/tests/retrieval_metrics/test_retrieval_metric.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pytest
4 | import tensorflow as tf
5 |
6 | from tensorflow_similarity.retrieval_metrics import RetrievalMetric
7 | from tensorflow_similarity.types import BoolTensor, FloatTensor, IntTensor
8 |
9 |
class ConcreteRetrievalMetric(RetrievalMetric):
    """Minimal concrete `RetrievalMetric` used to exercise the base class."""

    def compute(
        self,
        *,  # everything after this marker must be passed by keyword
        query_labels: IntTensor,
        lookup_labels: IntTensor,
        lookup_distances: FloatTensor,
        match_mask: BoolTensor,
    ) -> FloatTensor:
        """Ignore every input and return a constant scalar tensor of 1.0."""
        constant_score = tf.constant(1.0)
        return constant_score
20 |
21 |
def test_concrete_instance():
    """Check the name, repr, attributes and config of a concrete metric."""
    metric = ConcreteRetrievalMetric(name="foo", canonical_name="bar", k=6, average="macro")

    # Without a distance threshold the name only carries the "@k" suffix.
    assert metric.name == "foo@6"

    # Once the threshold is assigned it is appended to the reported name.
    metric.distance_threshold = 0.1
    name_with_threshold = "foo@6 : distance_threshold@0.1"
    assert metric.name == name_with_threshold
    assert repr(metric) == "bar : foo@6 : distance_threshold@0.1"

    assert metric.canonical_name == "bar"
    assert metric.k == 6
    assert metric.distance_threshold == 0.1
    assert metric.average == "macro"

    assert metric.get_config() == {
        "name": name_with_threshold,
        "canonical_name": "bar",
        "k": 6,
        "distance_threshold": 0.1,
    }
47 |
48 |
def test_k_greater_than_num_lookups():
    """`_check_shape` must reject a match mask with fewer columns than K."""
    metric = ConcreteRetrievalMetric(name="foo", canonical_name="bar", k=5, average="macro")

    labels = tf.constant([1, 1])
    # Two queries with four neighbors each, while the metric asks for k=5.
    mask = tf.constant(
        [
            [True, True, False, False],
            [True, False, False, True],
        ],
        dtype=bool,
    )

    expected_error = "The number of neighbors must be >= K. Number of neighbors is 4 but K is 5."

    with pytest.raises(ValueError, match=re.escape(expected_error)):
        metric._check_shape(query_labels=labels, match_mask=mask)
69 |
70 |
def test_query_and_match_mask_different_dims():
    """`_check_shape` must reject a mask whose row count differs from the labels."""
    metric = ConcreteRetrievalMetric(name="foo", canonical_name="bar", k=4, average="macro")

    # Four query labels but only two rows of lookups in the mask.
    labels = tf.constant([1, 2, 3, 4])
    mask = tf.constant(
        [
            [True, True, False, False],
            [True, False, False, True],
        ],
        dtype=bool,
    )

    expected_error = "".join(
        [
            "The number of lookup sets must equal the number of query labels. ",
            "Number of lookup sets is 2 but the number of query labels is 4.",
        ]
    )

    with pytest.raises(ValueError, match=re.escape(expected_error)):
        metric._check_shape(query_labels=labels, match_mask=mask)
94 |
--------------------------------------------------------------------------------
/api/TFSimilarity/distances/SNRDistance.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.distances.SNRDistance
2 |
3 |
4 |
5 |
6 |
7 | Computes pairwise SNR distances between embeddings.
8 |
9 | Inherits From: [`Distance`](../../TFSimilarity/distances/Distance.md), [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | ```python
12 | TFSimilarity.distances.SNRDistance()
13 | ```
14 |
15 |
16 |
17 |
18 |
19 | The [Signal-to-Noise Ratio distance](https://arxiv.org/abs/1904.02616)
20 | is the ratio of noise variance to the feature variance.
21 |
22 | ## Methods
23 |
24 | call
25 |
26 | View source
27 |
28 | ```python
29 | @tf.function
30 | ```
31 |
32 | ```python
33 | call(
34 |     query_embeddings: TFSimilarity.callbacks.FloatTensor,
35 |     key_embeddings: TFSimilarity.callbacks.FloatTensor
36 | ) -> TFSimilarity.callbacks.FloatTensor
37 | ```
38 |
39 |
40 |
41 |
42 | Compute pairwise snr distances for a given batch of embeddings.
43 | SNR(i, j): anchor i and compared feature j
44 | SNR(i,j) may not be equal to SNR(j, i)
45 |
46 |
47 |
48 |
49 | | Args |
50 |
51 |
52 | |
53 | query_embeddings
54 | |
55 |
56 | Embeddings to compute the pairwise one.
57 | |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | | Returns |
67 |
68 |
69 | |
70 | FloatTensor
71 | |
72 |
73 | Pairwise distance tensor.
74 | |
75 |
76 |
77 |
78 |
79 |
80 | get_config
81 |
82 | View source
83 |
84 | ```python
85 | get_config()
86 | ```
87 |
88 |
89 |
90 |
91 |
92 | __call__
93 |
94 | View source
95 |
96 | ```python
97 | __call__(
98 |     query_embeddings: TFSimilarity.callbacks.FloatTensor,
99 |     key_embeddings: TFSimilarity.callbacks.FloatTensor
100 | )
101 | ```
102 |
103 |
104 |
105 |
106 | Call self as a function.
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
/tests/losses/test_xbm_loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.losses import Reduction
3 | from tensorflow.python.framework import combinations
4 |
5 | from tensorflow_similarity import losses
6 |
7 |
8 | # TODO(ovallis): Refactor XBM loss to work with graph mode and update tests.
@combinations.generate(combinations.combine(mode=["eager"]))
class XBMLossTest(tf.test.TestCase):
    """Eager-mode tests for the `losses.XBM` wrapper."""

    def test_config(self):
        """Constructor arguments are exposed as attributes on the XBM loss."""
        wrapped_loss = losses.TripletLoss(distance="cosine")
        xbm = losses.XBM(
            loss=wrapped_loss,
            memory_size=12,
            reduction=Reduction.SUM,
            name="xbm_triplet_loss",
        )

        self.assertEqual(xbm.distance.name, "cosine")
        self.assertEqual(xbm.name, "xbm_triplet_loss")
        self.assertEqual(xbm.memory_size, 12)
        self.assertEqual(xbm.reduction, Reduction.SUM)

    def test_xbm_loss(self):
        """Exercise enqueue, dequeue and warmup of the XBM memory."""
        batch_size, embed_dim = 6, 16
        one_batch = (batch_size, embed_dim)
        two_batches = (2 * batch_size, embed_dim)

        embeddings1 = tf.random.uniform(shape=[batch_size, embed_dim])
        labels1 = tf.constant([[1], [1], [2], [2], [3], [3]], dtype=tf.int32)

        embeddings2 = tf.random.uniform(shape=[batch_size, embed_dim])
        labels2 = tf.constant([[4], [4], [5], [5], [6], [6]], dtype=tf.int32)

        loss = losses.MultiSimilarityLoss(distance="cosine")

        # --- no warmup: every call writes to the memory ---
        loss_nowarm = losses.XBM(loss, memory_size=12, warmup_steps=0)

        # Enqueue: the first batch is stored as-is.
        loss_nowarm(labels1, embeddings1)
        self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, one_batch)
        self.assertAllEqual(loss_nowarm._y_true_memory, labels1)

        # Enqueue: the second batch ends up in front of the first one.
        loss_nowarm(labels2, embeddings2)
        self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, two_batches)
        self.assertAllEqual(loss_nowarm._y_true_memory, tf.concat([labels2, labels1], axis=0))

        # Dequeue: the memory already holds 12 rows, so the first batch is dropped.
        loss_nowarm(labels2, embeddings2)
        self.assertAllEqual(loss_nowarm._y_pred_memory.numpy().shape, two_batches)
        self.assertAllEqual(loss_nowarm._y_true_memory, tf.concat([labels2, labels2], axis=0))

        # --- one warmup step: the first call leaves the memory empty ---
        loss_warm = losses.XBM(loss, memory_size=12, warmup_steps=1)

        loss_warm(labels1, embeddings1)
        empty_labels = tf.constant([[]], shape=(0, 1), dtype=tf.int32)
        self.assertAllEqual(loss_warm._y_pred_memory.numpy().shape, (0, embed_dim))
        self.assertAllEqual(loss_warm._y_true_memory, empty_labels)

        loss_warm(labels2, embeddings2)
        self.assertAllEqual(loss_warm._y_pred_memory.numpy().shape, one_batch)
        self.assertAllEqual(loss_warm._y_true_memory, labels2)
71 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | Thanks for considering contributing to TF similarity!
3 |
4 | Here is what you need to know to make a successful contribution. There are
5 | just a few small guidelines you need to follow.
6 |
7 | ## Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a Contributor License
10 | Agreement. You (or your employer) retain the copyright to your contribution;
11 | this simply gives us permission to use and redistribute your contributions as
12 | part of the project. Head over to the
13 | [Google contribution agreement server](https://cla.developers.google.com/) to see
14 | your current agreements on file or to sign a new one.
15 |
16 | You generally only need to submit a CLA once, so if you've already submitted one
17 | (even if it was for a different project), you probably don't need to do it
18 | again.
19 |
20 | ## Code reviews
21 |
22 | All submissions, including submissions by project members, require review. We
23 | use GitHub pull requests for this purpose. Consult
24 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
25 | information on using pull requests.
26 |
27 | ## Code submission guidelines
28 |
29 | Here are the few steps you need to follow to ensure a smooth review of your
30 | pull request:
31 |
32 | - Ideally, one PR corresponds to one feature or improvement to make it easier to
33 | review. So **try** to split your contribution into meaningful logical units.
34 | - Your code **must** pass the unit-tests. We use `pytest` so simply run it at the root of the project.
35 | - Your code **must** pass static analysis. We use `mypy` so simply run `mypy tensorflow_similarity/` from the root of the project.
36 | - Your code **must** come with unit-tests to ensure long-term quality.
37 | - Your functions **must** be documented except obvious ones using the Google style.
38 | - Your functions **must** be typed.
39 | - Your code must pass `flake8` to ensure you follow Python coding standards.
40 | - You **must** provide documentation (located in documentation/src) edits that document any new function/features you added.
41 | - You **should** provide reference to research papers that describe the feature added if any exists.
42 | - Please **try** to include a notebook in [examples/](examples/) that demonstrates the features/model if possible or extend an existing one. If you add a new notebook, make sure to provide an edit to the `README.md` that links to it.
43 | - You **must** update `documentation/src/changelog.md` to document the changes introduced by the PR.
44 | - Your code must be formatted using `black -l 119`. For automated formatting during git commit, run `pre-commit install`.
45 |
46 | ### Configure git blame to ignore the commit that migrated the code style to black.
47 | ~~~bash
48 | git config blame.ignoreRevsFile .git-blame-ignore-revs
49 | ~~~
50 |
51 |
52 | ## Community Guidelines
53 |
54 | This project follows
55 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
56 |
--------------------------------------------------------------------------------
/tensorflow_similarity/classification_metrics/false_positive_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..types import FloatTensor
22 |
23 | from .classification_metric import ClassificationMetric
24 |
25 |
class FalsePositiveRate(ClassificationMetric):
    r"""Calculates the false positive rate of the query classification.

    Computes the FPR given the query classification counts.

    $$
    FPR = \frac{\textrm{false_positives}}{\textrm{false_positives} +
    \textrm{true_negatives}}
    $$

    Args:
        name: Name associated with a specific metric object, e.g.,
        fpr@0.1

    Usage with `tf.similarity.models.SimilarityModel()`:

    ```python
    model.calibrate(x=query_examples,
                    y=query_labels,
                    calibration_metric='fpr')
    ```
    """

    # NOTE: the class docstring is a raw string on purpose. In a regular
    # string the LaTeX commands `\frac` and `\textrm` would be read as the
    # escape sequences `\f` (form feed) and `\t` (tab) and corrupt `__doc__`.

    def __init__(self, name: str = "fpr") -> None:
        # maximize=False is passed because, unlike the sibling metrics, this
        # one is flagged as not-to-be-maximized.
        super().__init__(name=name, canonical_name="false_positive_rate", maximize=False)

    def compute(
        self,
        tp: FloatTensor,
        fp: FloatTensor,
        tn: FloatTensor,
        fn: FloatTensor,
        count: int,
    ) -> FloatTensor:
        """Compute the classification metric.

        The `compute()` method supports computing the metric for a set of
        values, where each value represents the counts at a specific distance
        threshold.

        Only `fp` and `tn` are used by this metric; the remaining arguments
        are accepted so that the signature matches the other metrics.

        Args:
            tp: A 1D FloatTensor containing the count of True Positives at each
            distance threshold.

            fp: A 1D FloatTensor containing the count of False Positives at
            each distance threshold.

            tn: A 1D FloatTensor containing the count of True Negatives at each
            distance threshold.

            fn: A 1D FloatTensor containing the count of False Negatives at
            each distance threshold.

            count: The total number of queries

        Returns:
            A 1D FloatTensor containing the metric at each distance threshold.
        """
        # divide_no_nan avoids NaN when fp + tn is zero at a threshold.
        result: FloatTensor = tf.math.divide_no_nan(fp, fp + tn)
        return result
86 |
--------------------------------------------------------------------------------
/tensorflow_similarity/matchers/match_majority_vote.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import annotations
15 |
16 | from typing import TYPE_CHECKING
17 |
18 | import tensorflow as tf
19 |
20 | if TYPE_CHECKING:
21 | from ..types import FloatTensor, IntTensor
22 |
23 | from .classification_match import ClassificationMatch
24 |
25 |
class MatchMajorityVote(ClassificationMatch):
    """Match strategy that labels each query with its most common neighbor label."""

    def __init__(self, name: str = "majority_vote", **kwargs) -> None:
        # Supply the default canonical name only when the caller gave none.
        kwargs.setdefault("canonical_name", "match_majority_vote")
        super().__init__(name=name, **kwargs)

    def derive_match(self, lookup_labels: IntTensor, lookup_distances: FloatTensor) -> tuple[IntTensor, FloatTensor]:
        """Reduce each query's K neighbors to one label and one distance.

        The label is the most frequent label among the query's neighbors and
        the distance is the mean of the distances to those neighbors.

        Args:
            lookup_labels: A 2D array where the jth row is the labels
            associated with the set of k neighbors for the jth query.

            lookup_distances: A 2D array where the jth row is the distances
            between the jth query and the set of k neighbors.

        Returns:
            A tuple `(derived_labels, derived_distances)`:
                derived_labels: An IntTensor of shape
                [len(lookup_labels), 1] where the jth row contains the derived
                label for the jth query.

                derived_distances: A FloatTensor of shape
                [len(lookup_labels), 1] where the jth row contains the distance
                associated with the jth derived label.
        """
        # TODO(ovallis): Add parallel for callback or inline evaluation.
        voted_labels = tf.expand_dims(tf.map_fn(self._majority_vote, lookup_labels), axis=-1)
        mean_distances = tf.expand_dims(tf.math.reduce_mean(lookup_distances, 1), axis=-1)

        return voted_labels, mean_distances

    def _majority_vote(self, lookup_labels):
        """Return the label with the highest count in one row of lookup labels."""
        unique_labels, _, label_counts = tf.unique_with_counts(lookup_labels)
        return tf.gather(unique_labels, tf.argmax(label_counts))
73 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Testing
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | workflow_dispatch: # Allow manual triggers
8 |
9 | jobs:
10 | test:
11 |
12 | runs-on: ubuntu-latest
13 | strategy:
14 | fail-fast: false
15 | matrix:
16 | include:
17 | - python-version: '3.7'
18 | tf-version: '2.8'
19 | - python-version: '3.7'
20 | tf-version: '2.11'
21 | - python-version: '3.10'
22 | tf-version: '2.8'
23 | - python-version: '3.10'
24 | tf-version: '2.11'
25 |
26 | steps:
27 | - uses: actions/checkout@v2
28 | - name: Set up Python ${{ matrix.python-version }}
29 | uses: actions/setup-python@v2
30 | with:
31 | python-version: ${{ matrix.python-version }}
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install coveralls
36 |
37 | - name: Install TF package
38 | run: |
39 | pip install tensorflow==${{ matrix.tf-version }}
40 | # Fix proto dep issue in protobuf 4
41 | pip install protobuf==3.20.*
42 |
43 | - name: Install dev packages
44 | run: |
45 | pip install ".[dev,visualization,redis,faiss,nmslib]"
46 |
47 | - name: Lint with flake8
48 | run: |
49 | # stop the build if there are Python syntax errors or undefined names
50 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
51 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
52 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
53 |
54 | - name: Static analysis
55 | run: |
56 | mypy tensorflow_similarity/ --raise-exceptions
57 |
58 | - name: Test with pytest
59 | run: |
60 | coverage run -m pytest tests/
61 |
62 | - name: Upload coverage to Coveralls
63 | if: ${{ github.ref == 'refs/heads/master' }}
64 | run: |
65 | coveralls --service=github
66 | env:
67 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
68 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }}
69 | COVERALLS_PARALLEL: true
70 |
71 | - name: Check black formatting
72 | uses: psf/black@stable
73 | with:
74 | options: "--check --verbose"
75 | version: "22.6.0"
76 | - name: Check isort
77 | uses: isort/isort-action@master
78 |
79 | coveralls:
80 |
81 | name: Indicate completion to coveralls.io
82 | needs: test
83 | runs-on: ubuntu-latest
84 |
85 | steps:
86 | - name: Install dependencies
87 | if: ${{ github.ref == 'refs/heads/master' }}
88 | run: |
89 | python -m pip install --upgrade pip
90 | pip install coveralls
91 |
92 | - name: Finished
93 | if: ${{ github.ref == 'refs/heads/master' }}
94 | run: |
95 | coveralls --service=github --finish
96 | env:
97 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
98 |
--------------------------------------------------------------------------------
/tests/search/test_nmslib.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tensorflow_similarity.search.nmslib import NMSLibSearch
4 |
5 |
def test_index_match():
    """Looking up a query returns both indexed items in the expected order."""
    query = np.array([1, 1, 2], dtype="float32")
    indexed = np.array([[1, 1, 3], [3, 1, 2]], dtype="float32")

    index = NMSLibSearch("cosine", 3)
    for item_id, embedding in enumerate(indexed):
        index.add(embedding, item_id)

    found_ids, found_values = index.lookup(query)

    assert len(found_values) == 2
    assert list(found_ids) == [0, 1]
18 |
19 |
def test_reset():
    """After `reset()` the index must only reflect items added afterwards.

    The same two embeddings are indexed twice with swapped ids; the ids
    returned by the second lookup must follow the new assignment.
    """
    target = np.array([1, 2, 3], dtype="float32")
    embs = np.array([[4, 2, 1], [2, 3, 5]], dtype="float32")

    search_index = NMSLibSearch("cosine", 3)
    search_index.add(embs[0], 0)
    search_index.add(embs[1], 1)

    idxs, dists = search_index.lookup(target, k=2)

    assert len(dists) == 2
    assert list(idxs) == [1, 0]

    search_index.reset()
    # Re-add the same embeddings with their ids swapped.
    search_index.add(embs[1], 0)
    search_index.add(embs[0], 1)
    idxs, dists = search_index.lookup(target, k=2)

    assert len(dists) == 2
    assert list(idxs) == [0, 1]
42 |
43 |
def test_index_save(tmp_path):
    """An index saved to disk and reloaded must answer lookups like the original.

    Args:
        tmp_path: pytest fixture providing a temporary directory for the save.
    """
    target = np.array([1, 1, 2], dtype="float32")
    embs = np.array([[1, 1, 3], [3, 1, 2]], dtype="float32")

    search_index = NMSLibSearch("cosine", 3)
    search_index.add(embs[0], 0)
    search_index.add(embs[1], 1)

    idxs, dists = search_index.lookup(target)

    assert len(dists) == 2
    assert list(idxs) == [0, 1]

    search_index.save(tmp_path)

    search_index2 = NMSLibSearch("cosine", 3)
    search_index2.load(tmp_path)

    # Query the *reloaded* index; querying the original one here would not
    # verify that save/load round-trips the data.
    idxs2, dists2 = search_index2.lookup(target)
    assert len(dists2) == 2
    assert list(idxs2) == [0, 1]

    # The reloaded index must still accept new items.
    search_index2.add(np.array([3.0, 3.0, 3.0]), 3)
    idxs3, dists3 = search_index2.lookup(target)
    assert len(dists3) == 3
    assert list(idxs3) == [0, 3, 1]
71 |
72 |
def test_batch_vs_single(tmp_path):
    """Batch lookup and one-by-one lookup must agree on the first 3 neighbors."""
    num_targets = 10
    index_size = 100
    vect_dim = 16

    # Generate ids, then random queries and random indexed embeddings.
    item_ids = list(range(index_size))
    targets = np.random.random((num_targets, vect_dim)).astype("float32")
    indexed = np.random.random((index_size, vect_dim)).astype("float32")

    # Build the search index in a single batch insert.
    search_index = NMSLibSearch("cosine", vect_dim)
    search_index.batch_add(indexed, item_ids)

    # Batch lookup: only the returned ids are compared.
    batch_idxs, _ = search_index.batch_lookup(targets)

    # Single lookups, one query at a time.
    singles_idxs = []
    for query in targets:
        found_ids, _unused = search_index.lookup(query)
        singles_idxs.append(found_ids)

    # Are the first three neighbors the same for every query?
    for row in range(num_targets):
        for rank in range(3):
            assert batch_idxs[row][rank] == singles_idxs[row][rank]
101 |
--------------------------------------------------------------------------------
/api/TFSimilarity/distances/SquaredEuclideanDistance.md:
--------------------------------------------------------------------------------
1 | # TFSimilarity.distances.SquaredEuclideanDistance
2 |
3 |
4 |
5 |
6 |
7 | Compute pairwise squared Euclidean distance.
8 |
9 | Inherits From: [`Distance`](../../TFSimilarity/distances/Distance.md), [`ABC`](../../TFSimilarity/distances/ABC.md)
10 |
11 | ```python
12 | TFSimilarity.distances.SquaredEuclideanDistance()
13 | ```
14 |
15 |
16 |
17 |
18 |
19 | The [Squared Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance) is
20 | a distance that varies from 0 (similar) to infinity (dissimilar).
21 |
22 | ## Methods
23 |
24 | call
25 |
26 | View source
27 |
28 | ```python
29 | @tf.function
30 | ```
31 |
32 | ```python
33 | call(
34 |     query_embeddings: TFSimilarity.callbacks.FloatTensor,
35 |     key_embeddings: TFSimilarity.callbacks.FloatTensor
36 | ) -> TFSimilarity.callbacks.FloatTensor
37 | ```
38 |
39 |
40 |
41 |
42 | Compute pairwise distances for a given batch of embeddings.
43 |
44 |
45 |
46 |
47 |
48 | | Args |
49 |
50 |
51 | |
52 | query_embeddings
53 | |
54 |
55 | Embeddings to compute the pairwise one.
56 | |
57 |
58 | |
59 | key_embeddings
60 | |
61 |
62 | Embeddings to compute the pairwise one.
63 | |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | | Returns |
73 |
74 |
75 | |
76 | FloatTensor
77 | |
78 |
79 | Pairwise distance tensor.
80 | |
81 |
82 |
83 |
84 |
85 |
86 | get_config
87 |
88 | View source
89 |
90 | ```python
91 | get_config()
92 | ```
93 |
94 |
95 |
96 |
97 |
98 | __call__
99 |
100 | View source
101 |
102 | ```python
103 | __call__(
104 |     query_embeddings: TFSimilarity.callbacks.FloatTensor,
105 |     key_embeddings: TFSimilarity.callbacks.FloatTensor
106 | )
107 | ```
108 |
109 |
110 |
111 |
112 | Call self as a function.
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------