├── .github └── workflows │ └── tests.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ml_metrics ├── __init__.py ├── _src │ ├── __init__.py │ ├── aggregates │ │ ├── __init__.py │ │ ├── base.py │ │ ├── base_test.py │ │ ├── classification.py │ │ ├── classification_test.py │ │ ├── keras_metric_wrapper.py │ │ ├── keras_metric_wrapper_test.py │ │ ├── retrieval.py │ │ ├── retrieval_test.py │ │ ├── rolling_stats.py │ │ ├── rolling_stats_test.py │ │ ├── test_utils.py │ │ ├── text.py │ │ ├── text_test.py │ │ ├── types.py │ │ └── utils.py │ ├── chainables │ │ ├── __init__.py │ │ ├── io.py │ │ ├── io_test.py │ │ ├── lazy_fns.py │ │ ├── lazy_fns_test.py │ │ ├── transform.py │ │ ├── transform_test.py │ │ ├── tree.py │ │ ├── tree_fns.py │ │ ├── tree_fns_test.py │ │ └── tree_test.py │ ├── metrics │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── classification_test.py │ │ ├── retrieval.py │ │ ├── retrieval_test.py │ │ ├── rolling_stats.py │ │ ├── rolling_stats_test.py │ │ ├── text.py │ │ ├── text_test.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── signals │ │ ├── cg_score.py │ │ ├── cg_score_test.py │ │ ├── cross_entropy.py │ │ ├── cross_entropy_test.py │ │ ├── flip_masks.py │ │ ├── flip_masks_test.py │ │ ├── image.py │ │ ├── image_test.py │ │ ├── text.py │ │ ├── text_test.py │ │ ├── topk_accuracy.py │ │ └── topk_accuracy_test.py │ ├── tools │ │ └── telemetry │ │ │ └── telemetry.py │ ├── types.py │ ├── types_test.py │ └── utils │ │ ├── df_utils.py │ │ ├── df_utils_test.py │ │ ├── func_utils.py │ │ ├── func_utils_test.py │ │ ├── iter_utils.py │ │ ├── iter_utils_test.py │ │ ├── math_utils.py │ │ ├── math_utils_test.py │ │ ├── proto_utils.py │ │ ├── proto_utils_test.py │ │ ├── test_utils.py │ │ └── test_utils_test.py ├── aggregates.py ├── chainable.py ├── data_sinks.py ├── metrics │ ├── __init__.py │ ├── classification.py │ ├── retrieval.py │ ├── rolling_stats.py │ ├── text.py │ └── utils.py ├── oss │ ├── README.md │ ├── build.Dockerfile │ ├── build_whl.sh │ └── runner.sh ├── signals │ ├── __init__.py │ ├── classification.py │ ├── image.py │ └── text.py └── utils.py ├── pyproject.toml └── setup.py /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Build & Test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.11", "3.12", "3.13"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: "actions/checkout@v4" 21 | - name: Create directory 22 | run: | 23 | mkdir -p /tmp/ml_metrics 24 | cp -r . /tmp/ml_metrics 25 | - name: Build package 26 | run: | 27 | set -xe 28 | export PYTHON_VERSION=${{ matrix.python-version }} 29 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1) 30 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. 
-f2) 31 | export AUDITWHEEL_PLATFORM="manylinux2014_x86_64" 32 | export ML_METRICS_RUN_TESTS=true 33 | cd /tmp/ml_metrics 34 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \ 35 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \ 36 | --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \ 37 | --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \ 38 | -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile 39 | docker run --rm -a stdin -a stdout -a stderr \ 40 | --env PYTHON_VERSION=${PYTHON_VERSION} \ 41 | --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \ 42 | --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \ 43 | --env ML_METRICS_RUN_TESTS=${ML_METRICS_RUN_TESTS} \ 44 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ 45 | --env ML_METRICS_RUN_TESTS=${ML_METRICS_RUN_TESTS} \ 46 | -v /tmp/ml_metrics:/tmp/ml_metrics \ 47 | --name ml_metrics ml_metrics:${PYTHON_VERSION} \ 48 | bash ml_metrics/oss/build_whl.sh -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ML-metrics 2 | 3 | Ml-metrics provides performant and distributed friendly ML metrics implementations. 4 | 5 | This is not an officially supported Google product. -------------------------------------------------------------------------------- /ml_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /ml_metrics/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Base AggregateFn for all the aggregates.""" 15 | from __future__ import annotations 16 | 17 | import abc 18 | from collections.abc import Callable, Iterable 19 | import dataclasses 20 | from typing import Any, Protocol, Self, TypeGuard, TypeVar, runtime_checkable 21 | 22 | from ml_metrics._src import types 23 | from ml_metrics._src.chainables import lazy_fns 24 | from ml_metrics._src.chainables import tree 25 | 26 | _T = TypeVar('_T') 27 | _ResolvableOrMakeable = types.Resolvable[_T] | types.Makeable[_T] 28 | 29 | 30 | @runtime_checkable 31 | class Metric(Protocol): 32 | """MergibleMetric can be used as simple Map, and also MapReduce.""" 33 | 34 | @abc.abstractmethod 35 | def add(self, *inputs, **named_inputs) -> Any: 36 | """Computes the state from a batch while outputting batch output.""" 37 | 38 | @abc.abstractmethod 39 | def result(self) -> Any: 40 | """Returns the result of the metric.""" 41 | 42 | _MetricT = TypeVar('_MetricT', bound=Metric) 43 | 44 | 45 | def as_agg_fn( 46 | cls: Callable[..., _MetricT], 47 | *args, 48 | nested: bool = False, 49 | agg_preprocess_fn: Callable[..., Any] | None = None, 50 | **kwargs, 51 | ) -> AggregateFn: 52 | """Creates an AggregateFn from a metric class.""" 53 | deferred_metric: types.Resolvable[_MetricT] = lazy_fns.trace(cls)( 54 | *args, **kwargs 55 | ) 56 | # Try resolve the target at construction at calltime to detect errors. 57 | _ = lazy_fns.maybe_make(deferred_metric) 58 | agg_fn = MergeableMetricAggFn(deferred_metric) 59 | if nested: 60 | agg_fn = AggFnNested(agg_fn, preprocess_fn=agg_preprocess_fn) 61 | return agg_fn 62 | 63 | 64 | @runtime_checkable 65 | class MergeableMetric(Metric, Protocol): 66 | """MergibleMetric can be used as simple Map, and also MapReduce.""" 67 | 68 | @abc.abstractmethod 69 | def merge(self, other: Self): 70 | """Merges the metric with another metric of the same type.""" 71 | 72 | 73 | class CallableMetric(MergeableMetric, Callable[..., Any]): 74 | """A metric that is also callable. 75 | 76 | The CallableMetric is the recommended interface to implement a metric that 77 | supports both calculating batch result (`process`) and merging batch results 78 | (`merge`). A default `add` method is provided, but should be overwritten if 79 | the `merge` method is not applicable for `add`. 
80 | """ 81 | 82 | @abc.abstractmethod 83 | def new(self, *args, **kwargs) -> Self: 84 | """Calculate the sufficient statistics, should be idempotent.""" 85 | 86 | def add(self, *args, **kwargs): 87 | """Updates the sufficient statistics with a batch of inputs.""" 88 | batch_result = self.new(*args, **kwargs) 89 | self.merge(batch_result) 90 | return batch_result 91 | 92 | def __call__(self, *args, **kwargs): 93 | """Calculates the result from the sufficient statistics.""" 94 | return self.new(*args, **kwargs).result() 95 | 96 | 97 | class HasAsAggFn(Protocol): 98 | """An object that can be used as an aggregate function.""" 99 | 100 | def as_agg_fn(self) -> Aggregatable: 101 | """Returns the aggregate function.""" 102 | 103 | 104 | def has_as_agg_fn(obj: Any) -> TypeGuard[HasAsAggFn]: 105 | """Returns True if the object has an as_agg_fn method.""" 106 | return types.obj_has_method(obj, 'as_agg_fn') 107 | 108 | 109 | @runtime_checkable 110 | class Aggregatable(Protocol): 111 | """An aggregation interface, similar to apche_beam.CombineFn.""" 112 | 113 | def create_state(self) -> Any: 114 | """Creates the initial states for the aggregation.""" 115 | return None 116 | 117 | @abc.abstractmethod 118 | def update_state(self, state, *inputs, **named_inputs): 119 | """Update the state from a batch of inputs. 120 | 121 | Args: 122 | state: the current state. 123 | *inputs: elements to add. 124 | **named_inputs: elements to add. 125 | """ 126 | 127 | def merge_states(self, states): 128 | """Merging multiple states into a one state value. 129 | 130 | This is only required for distributed implementations such as Beam. Only the 131 | first state may be modified and returned for efficiency. 132 | 133 | Args: 134 | states: the states to be merged. 135 | """ 136 | raise NotImplementedError() 137 | 138 | def get_result(self, state): 139 | """Computes and returns the result from state. 140 | 141 | Args: 142 | state: the final state value computed by this CombineFn. 143 | 144 | Returns: 145 | state. 146 | """ 147 | return state 148 | 149 | 150 | class AggregateFn(Aggregatable): 151 | """An aggregation interface, similar to apche_beam.CombineFn.""" 152 | 153 | def __call__(self, *inputs, **named_inputs): 154 | """Directly apply aggregate on inputs.""" 155 | return self.get_result( 156 | self.update_state(self.create_state(), *inputs, **named_inputs) 157 | ) 158 | 159 | 160 | class MergeableMetricAggFn(AggregateFn): 161 | """A aggregation wrapper for MergeableMetric.""" 162 | 163 | metric_maker: _ResolvableOrMakeable[MergeableMetric] 164 | 165 | def __init__(self, metric_maker: _ResolvableOrMakeable[MergeableMetric]): 166 | super().__init__() 167 | if not ( 168 | types.is_resolvable(metric_maker) or types.is_makeable(metric_maker) 169 | ): 170 | raise TypeError( 171 | 'metric_maker must be an instance of Makeable or Resolvable. 
got' 172 | f' {type(metric_maker)}' 173 | ) 174 | self.metric_maker = metric_maker 175 | 176 | def __eq__(self, other, /): 177 | return ( 178 | isinstance(other, MergeableMetricAggFn) 179 | and self.metric_maker == other.metric_maker 180 | ) 181 | 182 | def create_state(self) -> MergeableMetric: 183 | metric = self.metric_maker 184 | if types.is_makeable(metric): 185 | return metric.make() 186 | elif types.is_resolvable(metric): 187 | return metric.result_() 188 | else: 189 | raise TypeError(f'{type(metric)} is not a Makeable or Resolvable.') 190 | 191 | def update_state( 192 | self, state: MergeableMetric, *args, **kwargs 193 | ) -> MergeableMetric: 194 | state.add(*args, **kwargs) 195 | return state 196 | 197 | def merge_states(self, states: Iterable[MergeableMetric]) -> MergeableMetric: 198 | iter_states = iter(states) 199 | result = next(iter_states) 200 | for state in iter_states: 201 | result.merge(state) 202 | return result 203 | 204 | def get_result(self, state: MergeableMetric) -> Any: 205 | return state.result() 206 | 207 | 208 | class AggFnNested(AggregateFn): 209 | """AggregateFn that traverses and aggregates each leaf of a PyTree.""" 210 | 211 | fn: Aggregatable 212 | preprocess_fn: Callable[..., tree.TreeLike[Any]] | None 213 | 214 | def __init__( 215 | self, 216 | fn: Aggregatable, 217 | preprocess_fn: Callable[..., tree.TreeLike[Any]] | None = None, 218 | ): 219 | if preprocess_fn is not None and not callable(preprocess_fn): 220 | raise ValueError(f'preprocess_fn must be a callable. got {preprocess_fn}') 221 | if not isinstance(fn, Aggregatable): 222 | raise ValueError(f'fn must be an instance of Aggregatable. got {fn}') 223 | super().__init__() 224 | self.fn = fn 225 | self.preprocess_fn = preprocess_fn 226 | 227 | def create_state(self): 228 | """Creates the initial states for the aggregation.""" 229 | return None 230 | 231 | def update_state(self, state: tree.TreeMapView, inputs: tree.TreeLike[Any]): 232 | """Update the state from a batch of inputs.""" 233 | if self.preprocess_fn: 234 | inputs = self.preprocess_fn(inputs) 235 | if state is None: 236 | state = tree.TreeMapView( 237 | inputs, map_fn=lambda x: self.fn.create_state() 238 | ).apply() 239 | state = tree.TreeMapView(state) 240 | inputs = tree.TreeMapView.as_view(inputs) 241 | return state.copy_and_update( 242 | (k, self.fn.update_state(state[k], v)) for k, v in inputs.items() 243 | ) 244 | 245 | # TODO: b/311207032 - Implement this. 246 | def merge_states(self, states): 247 | """Merges multiple states into one state value.""" 248 | raise NotImplementedError() 249 | 250 | def get_result(self, state): 251 | """Computes and returns the result from state.""" 252 | return tree.TreeMapView.as_view(state, map_fn=self.fn.get_result).apply() 253 | 254 | 255 | @dataclasses.dataclass(frozen=True) 256 | class UserAggregateFn(AggregateFn): 257 | """An AggregateFn that delegates to a user-provided Aggregatable.""" 258 | 259 | fn: Aggregatable 260 | 261 | def __post_init__(self): 262 | if not isinstance(self.fn, Aggregatable): 263 | raise ValueError( 264 | f'UserAggregateFn must be an instance of Aggregatable. 
got {self.fn}' 265 | ) 266 | 267 | def create_state(self) -> Any: 268 | """Creates the initial states for the aggregation.""" 269 | return self.fn.create_state() 270 | 271 | def update_state(self, state, *inputs, **named_inputs): 272 | """Update the state from a batch of inputs.""" 273 | return self.fn.update_state(state, *inputs, **named_inputs) 274 | 275 | def merge_states(self, states): 276 | """Merging multiple states into a one state value.""" 277 | return self.fn.merge_states(states) 278 | 279 | def get_result(self, state): 280 | """Computes and returns the result from state.""" 281 | return self.fn.get_result(state) 282 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for core lib.""" 15 | 16 | from absl.testing import absltest 17 | from ml_metrics._src.aggregates import base 18 | from ml_metrics._src.aggregates import test_utils 19 | from ml_metrics._src.chainables import lazy_fns 20 | 21 | 22 | class AggregatesTest(absltest.TestCase): 23 | 24 | def test_callable_combinefn_in_process(self): 25 | sum_fn = base.UserAggregateFn(test_utils._SumMetric().as_agg_fn()) 26 | self.assertEqual(sum_fn(list(range(4))), 6) 27 | 28 | def test_mergeable_aggregate_fn_in_process(self): 29 | sum_fn = base.as_agg_fn(test_utils._SumMetric) 30 | self.assertEqual(6, sum_fn([1, 2, 3])) 31 | 32 | def test_mergeable_aggregate_fn_from_resolvable(self): 33 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)() 34 | sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum) 35 | self.assertEqual(6, sum_fn([1, 2, 3])) 36 | 37 | def test_mergeable_aggregate_fn_from_makeable(self): 38 | 39 | class MakeableSum(test_utils._SumMetric): 40 | 41 | def make(self): 42 | return self 43 | 44 | sum_fn = base.MergeableMetricAggFn(MakeableSum()) 45 | self.assertEqual(6, sum_fn([1, 2, 3])) 46 | 47 | def test_mergeable_aggregate_fn_eq(self): 48 | sum_fn = base.MergeableMetricAggFn(lazy_fns.trace(test_utils._SumMetric)()) 49 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)() 50 | other_sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum) 51 | self.assertEqual(sum_fn, other_sum_fn) 52 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)(len) 53 | other_sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum) 54 | self.assertNotEqual(sum_fn, other_sum_fn) 55 | 56 | def test_mergeable_aggregate_fn_unsupported_type(self): 57 | with self.assertRaisesRegex(TypeError, 'must be an instance of.+ got'): 58 | # disable pytype check for the runtime error to surface. 
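      # MergeableMetricAggFn expects a Makeable or Resolvable (e.g. the result
      # of lazy_fns.trace(test_utils._SumMetric)()), not an already-constructed
      # metric instance, so passing the instance raises the TypeError asserted
      # above.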
59 | _ = base.MergeableMetricAggFn(test_utils._SumMetric()) # pytype: disable=wrong-arg-types 60 | 61 | def test_metric_callable(self): 62 | sum_fn = test_utils._SumMetric() 63 | self.assertEqual(6, sum_fn([1, 2, 3])) 64 | 65 | def test_has_as_agg_fn(self): 66 | sum_aggfn = test_utils._SumMetric().as_agg_fn() 67 | self.assertFalse(base.has_as_agg_fn(sum_aggfn)) 68 | self.assertIsInstance(sum_aggfn, base.Aggregatable) 69 | sum_metric = test_utils._SumMetric() 70 | self.assertTrue(base.has_as_agg_fn(sum_metric)) 71 | 72 | 73 | if __name__ == '__main__': 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/keras_metric_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Keras metric wrapper.""" 15 | from __future__ import annotations 16 | 17 | from collections.abc import Iterable 18 | import dataclasses 19 | from typing import Any, Protocol 20 | 21 | from ml_metrics._src.aggregates import base as agg 22 | 23 | 24 | class KerasMetric(Protocol): 25 | """Base interface for Keras metrics.""" 26 | 27 | def update_state(self, *inputs, **named_inputs) -> None: 28 | """Updates the state from a batch of inputs.""" 29 | 30 | def reset_state(self) -> None: 31 | """Resets the state.""" 32 | 33 | def merge_state(self, other: Iterable[KerasMetric]) -> None: 34 | """Merges the state with another state of the same type.""" 35 | 36 | def result(self) -> Any: 37 | """Returns the result of the metric.""" 38 | 39 | 40 | def is_keras_metric(metric: Any) -> bool: 41 | """Duck type check for Keras metric.""" 42 | return ( 43 | hasattr(metric, "update_state") 44 | and hasattr(metric, "reset_state") 45 | and hasattr(metric, "merge_state") 46 | and hasattr(metric, "result") 47 | ) 48 | 49 | 50 | @dataclasses.dataclass 51 | class KerasAggregateFn(agg.AggregateFn): 52 | """AggregateFn for Keras metrics.""" 53 | 54 | metric: KerasMetric 55 | 56 | def __post_init__(self): 57 | if is_keras_metric(self.metric): 58 | self.metric.reset_state() 59 | self._metric = self.metric 60 | else: 61 | try: 62 | assert hasattr(self.metric, "__call__") 63 | self._metric = self.metric() 64 | if not is_keras_metric(self._metric): 65 | raise TypeError("metric must implement Keras metric base interface.") 66 | except Exception as e: 67 | raise TypeError( 68 | f"Cannot construct a Keras metric from {self.metric}." 
69 | ) from e 70 | 71 | def create_state(self) -> KerasMetric: 72 | assert hasattr(self._metric, "reset_state") 73 | self._metric.reset_state() 74 | return self._metric 75 | 76 | def update_state( 77 | self, state: KerasMetric, *inputs: Any, **named_inputs: Any 78 | ) -> KerasMetric: 79 | state.update_state(*inputs, **named_inputs) 80 | return state 81 | 82 | def merge_states(self, states: Iterable[KerasMetric]) -> KerasMetric: 83 | # This in-place merges all the states into the first state and returns it. 84 | iter_states = iter(states) 85 | result = next(iter_states) 86 | result.merge_state(list(iter_states)) 87 | return result 88 | 89 | def get_result(self, state: KerasMetric) -> Any: 90 | result = state.result() 91 | if hasattr(result, "numpy"): 92 | return result.numpy() 93 | return result 94 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/keras_metric_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Wrapper for Keras metrics.""" 15 | 16 | import importlib 17 | 18 | from ml_metrics._src.aggregates import keras_metric_wrapper 19 | 20 | from absl.testing import absltest 21 | 22 | 23 | class MockKerasMetric: 24 | 25 | def __init__(self): 26 | self.reset_state() 27 | 28 | def reset_state(self): 29 | self._state = 0 30 | 31 | def update_state(self, inputs): 32 | self._state += sum(inputs) 33 | 34 | def merge_state(self, states): 35 | for state in states: 36 | self._state += state 37 | 38 | def result(self): 39 | return self._state 40 | 41 | 42 | class KerasTest(absltest.TestCase): 43 | 44 | def test_mock_keras_metric(self): 45 | metric = keras_metric_wrapper.KerasAggregateFn(MockKerasMetric()) 46 | self.assertEqual(6, metric([1, 2, 3])) 47 | 48 | def test_keras_metric_wrapper_merge(self): 49 | try: 50 | tf = importlib.import_module("tensorflow") 51 | except ImportError: 52 | # Ignores the import error if tensorflow is not installed. 53 | return 54 | metric1 = keras_metric_wrapper.KerasAggregateFn( 55 | tf.keras.metrics.Mean(name="mean") 56 | ) 57 | metric2 = keras_metric_wrapper.KerasAggregateFn( 58 | tf.keras.metrics.Mean(name="mean") 59 | ) 60 | state1, state2 = metric1.create_state(), metric2.create_state() 61 | merged_state = metric1.merge_states([ 62 | metric1.update_state(state1, [1, 2, 3]), 63 | metric2.update_state(state2, [4, 5, 6]), 64 | ]) 65 | self.assertEqual(3.5, metric1.get_result(merged_state)) 66 | 67 | def test_keras_metric_wrapper(self): 68 | try: 69 | tf = importlib.import_module("tensorflow") 70 | except ImportError: 71 | # Ignores the import error if tensorflow is not installed. 
72 | return 73 | metric = keras_metric_wrapper.KerasAggregateFn( 74 | tf.keras.metrics.Mean(name="mean") 75 | ) 76 | self.assertEqual(2, metric([1, 2, 3])) 77 | 78 | 79 | if __name__ == "__main__": 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """test_utils.""" 15 | from ml_metrics._src.aggregates import base 16 | 17 | 18 | class _SumMetric(base.CallableMetric, base.HasAsAggFn): 19 | """Mock Metric for test.""" 20 | 21 | def __init__(self, state=0): 22 | self._state = state 23 | 24 | def as_agg_fn(self): 25 | return base.as_agg_fn(self.__class__) 26 | 27 | @property 28 | def state(self): 29 | return self._state 30 | 31 | def new(self, x): 32 | return _SumMetric(state=sum(x)) 33 | 34 | def merge(self, other): 35 | self._state += other.state 36 | 37 | def result(self): 38 | return self._state 39 | 40 | 41 | class _SumAggFn: 42 | """Mock CombineFn for test.""" 43 | 44 | def create_state(self): 45 | return 0 46 | 47 | def update_state(self, state, x): 48 | return state + sum(x) 49 | 50 | def merge_states(self, states): 51 | return sum(states) 52 | 53 | def get_result(self, state): 54 | return state 55 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Individual text stats metrics.""" 15 | 16 | import collections 17 | from collections.abc import Sequence 18 | import dataclasses 19 | import re 20 | 21 | from ml_metrics._src.aggregates import base 22 | from ml_metrics._src.aggregates import utils 23 | from ml_metrics._src.tools.telemetry import telemetry 24 | 25 | FrequencyState = utils.FrequencyState 26 | 27 | 28 | @dataclasses.dataclass(kw_only=True) 29 | class TopKWordNGrams(base.MergeableMetric, base.HasAsAggFn): 30 | """Top k word n-grams metrics. 31 | 32 | Identify the top `k` frequent occurring word n-grams with a case-insensitive 33 | approach. The text will first be cleaned by removing non-alphabetic characters 34 | and spaces, and then converted to lowercase before computing the top k 35 | n-grams. 
When multiple n-grams share the same frequency, alphabetical order 36 | will be used as a tie-breaker. The result is a list of tuples containing the 37 | n-gram pattern and its corresponding frequency. The list includes either `k` 38 | tuples or the number of distinct n-grams, whichever is smaller. 39 | 40 | 41 | Attributes: 42 | k: 43 | Number of most frequent word n-grams. 44 | n: 45 | Number of grams. 46 | use_first_ngram_only: 47 | If `True`, only the first n words of each text will be used to form the 48 | n-grams and `count_duplicate` will be ignored. Otherwise, all words 49 | present in each text will be considered for generating the n-grams. 50 | Defaults to `False`. 51 | count_duplicate: 52 | If `True`, duplicate n-grams within the text are included in the total 53 | count. Otherwise, the count of a unique n-gram will only consider its 54 | first occurrence. Defaults to `True`. 55 | """ 56 | 57 | k: int 58 | n: int 59 | use_first_ngram_only: bool = False 60 | count_duplicate: bool = True 61 | _state: FrequencyState = dataclasses.field( 62 | default_factory=FrequencyState, init=False 63 | ) 64 | 65 | def __post_init__(self): 66 | telemetry.increment_counter( 67 | api='ml_metrics', category='metric', reference=self.__class__.__name__ 68 | ) 69 | if self.k <= 0 or self.n <= 0: 70 | raise ValueError( 71 | f'k and n must be positive integers but k={self.k} and n={self.n} were' 72 | ' passed.' 73 | ) 74 | 75 | def as_agg_fn(self) -> base.AggregateFn: 76 | return base.as_agg_fn( 77 | self.__class__, 78 | k=self.k, 79 | n=self.n, 80 | use_first_ngram_only=self.use_first_ngram_only, 81 | count_duplicate=self.count_duplicate, 82 | ) 83 | 84 | @property 85 | def state(self) -> FrequencyState: 86 | return self._state 87 | 88 | def add(self, texts: Sequence[str]) -> list[tuple[str, float]]: 89 | ngrams_counter = collections.Counter() 90 | for text in texts: 91 | # Remove non-alphabetical and non-space characters 92 | words = re.sub(r'[^a-zA-Z ]+', '', text).lower().split() 93 | if self.n <= len(words): 94 | ngrams = [] 95 | if self.use_first_ngram_only: 96 | ngrams.append(' '.join(words[:self.n])) 97 | else: 98 | for idx in range(len(words) - self.n + 1): 99 | ngrams.append(' '.join(words[idx : idx + self.n])) 100 | if not self.count_duplicate: 101 | ngrams = set(ngrams) 102 | ngrams_counter.update(ngrams) 103 | batch_result = FrequencyState(counter=ngrams_counter, count=len(texts)) 104 | self._state.merge(batch_result) 105 | 106 | result = batch_result.result()[:self.k] 107 | return result 108 | 109 | def merge(self, other: 'TopKWordNGrams'): 110 | # TODO(b/331796958): Optimize storage consumption 111 | self._state.merge(other.state) 112 | 113 | def result(self) -> list[tuple[str, float]]: 114 | return self._state.result()[:self.k] 115 | 116 | 117 | @dataclasses.dataclass(kw_only=True) 118 | class PatternFrequency(base.MergeableMetric, base.HasAsAggFn): 119 | """Pattern frequency metric. 120 | 121 | Identify the frequency of occurrence for each pattern found within the given 122 | texts. 123 | 124 | 125 | Attributes: 126 | patterns: 127 | Sequence of text patterns. 128 | count_duplicate: 129 | If `True`, duplicate patterns within the text are included in the total 130 | count. Otherwise, the count of a pattern will only consider its first 131 | occurrence. Defaults to `True`.
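  A minimal usage sketch (the numbers follow the unit tests in text_test.py;
  with the default `count_duplicate=True`, overlapping matches are counted via
  a lookahead regex):

    metric = PatternFrequency(patterns=['ab', 'xyx'])
    metric.add(['ab ab xyx', 'xyxyx'])
    metric.result()  # -> [('xyx', 1.5), ('ab', 1.0)]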
132 | """ 133 | 134 | patterns: Sequence[str] 135 | count_duplicate: bool = True 136 | _state: FrequencyState = dataclasses.field( 137 | default_factory=FrequencyState, init=False 138 | ) 139 | 140 | def __post_init__(self): 141 | telemetry.increment_counter( 142 | api='ml_metrics', category='metric', reference=self.__class__.__name__ 143 | ) 144 | if not self.patterns: 145 | raise ValueError('Patterns must not be empty.') 146 | 147 | if len(set(self.patterns)) != len(self.patterns): 148 | raise ValueError(f'Patterns must be unique: {self.patterns}') 149 | 150 | def as_agg_fn(self) -> base.AggregateFn: 151 | return base.as_agg_fn( 152 | self.__class__, 153 | patterns=self.patterns, 154 | count_duplicate=self.count_duplicate, 155 | ) 156 | 157 | @property 158 | def state(self) -> FrequencyState: 159 | return self._state 160 | 161 | def add(self, texts: Sequence[str]) -> list[tuple[str, float]]: 162 | batch_frquency_state = FrequencyState() 163 | for pattern in self.patterns: 164 | for text in texts: 165 | num_matches = 0 166 | if self.count_duplicate: 167 | matches = list( 168 | re.finditer(r'(?=({}))'.format(re.escape(pattern)), text) 169 | ) 170 | num_matches = len(matches) 171 | elif text.find(pattern) >= 0: 172 | num_matches = 1 173 | batch_frquency_state.counter[pattern] += num_matches 174 | batch_frquency_state.count = len(texts) 175 | self._state.merge(batch_frquency_state) 176 | return batch_frquency_state.result() 177 | 178 | def merge(self, other: 'PatternFrequency'): 179 | self._state.merge(other.state) 180 | 181 | def result(self) -> list[tuple[str, float]]: 182 | return self._state.result() 183 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/text_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for text.""" 15 | 16 | from absl.testing import parameterized 17 | from ml_metrics._src.aggregates import text 18 | from absl.testing import absltest 19 | 20 | 21 | class TopKWordNGramsTest(parameterized.TestCase): 22 | """Tests for TopKWordNGrams.""" 23 | 24 | @parameterized.named_parameters([ 25 | dict( 26 | testcase_name='distinct_ngrams', 27 | k=2, 28 | n=2, 29 | use_first_ngram_only=False, 30 | count_duplicate=False, 31 | expected_result=[ 32 | ('a a', 1 / 3), 33 | ('b b', 1 / 3), 34 | # This demonstrated that the alphabetical order is used as a 35 | # tie-breaker in the frequency state. 36 | # ('c c', 1 / 3), # Included when k >= 3 37 | # ('d a', 1 / 3), # Included when k >= 4 38 | ], 39 | ), 40 | dict( 41 | testcase_name='count_duplicate', 42 | k=2, 43 | n=2, 44 | use_first_ngram_only=False, 45 | count_duplicate=True, 46 | expected_result=[ 47 | ('b b', 2 / 3), 48 | ('a a', 1 / 3), 49 | # This demonstrated that the alphabetical order is used as a 50 | # tie-breaker in the frequency state. 
51 | # ('c c', 1 / 3), # Included when k >= 3 52 | # ('d a', 1 / 3), # Included when k >= 4 53 | ], 54 | ), 55 | dict( 56 | testcase_name='use_first_ngram_only', 57 | k=2, 58 | n=2, 59 | use_first_ngram_only=True, 60 | # count_duplicate will be ignored. See TopKWordNGrams description. 61 | count_duplicate=True, 62 | expected_result=[ 63 | ('b b', 1 / 3), 64 | ('c c', 1 / 3), 65 | # This demonstrated that the alphabetical order is used as a 66 | # tie-breaker in the frequency state. 67 | # ('d a', 1 / 3), # Included when k >= 3 68 | ], 69 | ), 70 | dict( 71 | testcase_name='large_k', 72 | k=10, 73 | n=2, 74 | use_first_ngram_only=False, 75 | count_duplicate=True, 76 | expected_result=[ 77 | ('b b', 2 / 3), 78 | ('a a', 1 / 3), 79 | ('c c', 1 / 3), 80 | ('d a', 1 / 3), 81 | ], 82 | ), 83 | dict( 84 | testcase_name='large_n', 85 | k=4, 86 | n=10, 87 | use_first_ngram_only=False, 88 | count_duplicate=True, 89 | expected_result=[], 90 | ), 91 | dict( 92 | testcase_name='2k_1n', 93 | k=2, 94 | n=1, 95 | use_first_ngram_only=False, 96 | count_duplicate=True, 97 | expected_result=[ 98 | ('b', 3 / 3), 99 | ('a', 2 / 3), 100 | # This demonstrated that the alphabetical order is used as a 101 | # tie-breaker in the frequency state. 102 | # ('c', 2 / 3), # Included when k >= 3 103 | # ('d', 1 / 3), # Included when k >= 4 104 | ], 105 | ), 106 | dict( 107 | testcase_name='2k_3n', 108 | k=2, 109 | n=3, 110 | use_first_ngram_only=False, 111 | count_duplicate=True, 112 | expected_result=[ 113 | ('b b b', 1 / 3), 114 | ('d a a', 1 / 3), 115 | ], 116 | ), 117 | ]) 118 | def test_compute_topkwordngrams( 119 | self, k, n, use_first_ngram_only, count_duplicate, expected_result 120 | ): 121 | batch = [ 122 | 'c c', 123 | 'b B b', # Case-insensitive 124 | 'd a a', 125 | ] 126 | metric = text.TopKWordNGrams( 127 | k=k, 128 | n=n, 129 | use_first_ngram_only=use_first_ngram_only, 130 | count_duplicate=count_duplicate, 131 | ) 132 | batch_result = metric.add(batch) 133 | self.assertSequenceAlmostEqual(expected_result, batch_result) 134 | self.assertSequenceAlmostEqual(expected_result, metric.result()) 135 | 136 | def test_compute_topkwordngrams_empty(self): 137 | metric = text.TopKWordNGrams(k=2, n=2) 138 | batch_result = metric.add([]) 139 | self.assertSequenceEqual([], batch_result) 140 | 141 | def test_topkwordngrams_add(self): 142 | metric = text.TopKWordNGrams(k=2, n=2) 143 | 144 | batch_0 = [ 145 | 'a b', 146 | 'a b', 147 | 'c d', 148 | ] 149 | batch_1 = [ 150 | 'a b', 151 | 'b c', 152 | 'b c', 153 | ] 154 | 155 | batch_0_result = metric.add(batch_0) 156 | expected_batch_0_result = [ 157 | ('a b', 2 / 3), 158 | ('c d', 1 / 3), 159 | ] 160 | self.assertSequenceAlmostEqual(expected_batch_0_result, batch_0_result) 161 | 162 | batch_1_result = metric.add(batch_1) 163 | expected_batch_1_result = [ 164 | ('b c', 2 / 3), 165 | ('a b', 1 / 3), 166 | ] 167 | self.assertSequenceAlmostEqual(expected_batch_1_result, batch_1_result) 168 | 169 | expected_metric_result = [ 170 | ('a b', 3 / 6), 171 | ('b c', 2 / 6), 172 | ] 173 | self.assertSequenceAlmostEqual(expected_metric_result, metric.result()) 174 | 175 | def test_topkwordngrams_merge(self): 176 | metric_0 = text.TopKWordNGrams(k=2, n=2) 177 | metric_1 = text.TopKWordNGrams(k=2, n=2) 178 | batch_0 = [ 179 | 'a b', 180 | 'c d', 181 | ] 182 | batch_1 = [ 183 | 'b c', 184 | 'b c', 185 | ] 186 | metric_0.add(batch_0) 187 | metric_1.add(batch_1) 188 | 189 | metric_0.merge(metric_1) 190 | 191 | expected_result = [ 192 | ('b c', 2 / 4), 193 | ('a b', 1 / 4), 194 | ] 195 | 
merged_result = metric_0.result() 196 | self.assertSequenceAlmostEqual(expected_result, merged_result) 197 | 198 | def test_topkwordngrams_invalid_kn(self): 199 | with self.assertRaisesRegex( 200 | ValueError, 'k and n must be positive integers.' 201 | ): 202 | text.TopKWordNGrams(k=1, n=0) 203 | 204 | with self.assertRaisesRegex( 205 | ValueError, 'k and n must be positive integers.' 206 | ): 207 | text.TopKWordNGrams(k=0, n=1) 208 | 209 | 210 | class PatternFrequencyTest(parameterized.TestCase): 211 | """Tests for PatternFrequency.""" 212 | 213 | @parameterized.named_parameters([ 214 | dict( 215 | testcase_name='not_count_duplicate', 216 | count_duplicate=False, 217 | batch=['ab ab xyx', 'xyxyx'], 218 | expected_result=[ 219 | ('ab', 1 / 2), 220 | ('mmm', 0), 221 | ('xyx', 2 / 2), 222 | ], 223 | ), 224 | dict( 225 | testcase_name='count_duplicate', 226 | count_duplicate=True, 227 | batch=['ab ab xyx', 'xyxyx'], 228 | expected_result=[ 229 | ('ab', 2 / 2), 230 | ('mmm', 0), 231 | ('xyx', 3 / 2), 232 | ], 233 | ), 234 | ]) 235 | def test_compute_pattern_frequency( 236 | self, batch, count_duplicate, expected_result 237 | ): 238 | metric = text.PatternFrequency( 239 | patterns=['ab', 'xyx', 'mmm'], count_duplicate=count_duplicate 240 | ) 241 | batch_result = sorted(metric.add(batch)) 242 | self.assertSequenceAlmostEqual(expected_result, batch_result) 243 | self.assertSequenceAlmostEqual(expected_result, sorted(metric.result())) 244 | 245 | def test_compute_pattern_frequency_empty(self): 246 | metric = text.PatternFrequency( 247 | patterns=['ab', 'xyx'], count_duplicate=False 248 | ) 249 | batch_result = metric.add([]) 250 | self.assertSequenceEqual([], batch_result) 251 | 252 | def test_pattern_frequency_add(self): 253 | metric = text.PatternFrequency( 254 | patterns=['ab', 'xyx', 'mmm'], count_duplicate=False 255 | ) 256 | 257 | batch_0 = ['ab ab xyx', 'xyx'] 258 | batch_1 = ['mmm'] 259 | 260 | batch_0_result = sorted(metric.add(batch_0)) 261 | expected_batch_0_result = [('ab', 1 / 2), ('mmm', 0), ('xyx', 2 / 2)] 262 | self.assertSequenceAlmostEqual(expected_batch_0_result, batch_0_result) 263 | 264 | batch_1_result = sorted(metric.add(batch_1)) 265 | expected_batch_1_result = [('ab', 0), ('mmm', 1), ('xyx', 0)] 266 | self.assertSequenceAlmostEqual(expected_batch_1_result, batch_1_result) 267 | 268 | expected_metric_result = [('ab', 1 / 3), ('mmm', 1 / 3), ('xyx', 2 / 3)] 269 | self.assertSequenceAlmostEqual( 270 | expected_metric_result, sorted(metric.result()) 271 | ) 272 | 273 | def test_pattern_frequency_merge(self): 274 | metric_0 = text.PatternFrequency(patterns=['ab'], count_duplicate=False) 275 | metric_1 = text.PatternFrequency(patterns=['ab'], count_duplicate=False) 276 | batch_0 = ['ab ab xyx'] 277 | batch_1 = ['ab'] 278 | metric_0.add(batch_0) 279 | metric_1.add(batch_1) 280 | 281 | metric_0.merge(metric_1) 282 | 283 | expected_result = [('ab', 2 / 2)] 284 | merged_result = sorted(metric_0.result()) 285 | self.assertSequenceAlmostEqual(expected_result, merged_result) 286 | 287 | @parameterized.named_parameters([ 288 | dict( 289 | testcase_name='empty_patterns', 290 | patterns=[], 291 | exception_regex='Patterns must not be empty.', 292 | ), 293 | dict( 294 | testcase_name='invalid_patterns', 295 | patterns=['a', 'a', 'b'], 296 | exception_regex='Patterns must be unique', 297 | ), 298 | ]) 299 | def test_pattern_frequency_invalid_patterns(self, patterns, exception_regex): 300 | with self.assertRaisesRegex(ValueError, exception_regex): 301 | 
text.PatternFrequency(patterns=patterns) 302 | 303 | 304 | if __name__ == '__main__': 305 | absltest.main() 306 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common types for the aggregates.""" 15 | 16 | import enum 17 | from numpy import typing as npt 18 | 19 | NumbersT = npt.ArrayLike 20 | DefaultDType = float 21 | 22 | 23 | class InputType(enum.StrEnum): # pylint: disable=invalid-enum-extension 24 | """Label prediction encoding types.""" 25 | 26 | # 1D array per batch, e.g., [0,1,0,1,0], [-1, 1, -1], or ['Y', 'N'] 27 | BINARY = 'binary' 28 | # 1D array of floats typically is the probability for the binary 29 | # classification problem, e.g., [0.2, 0.3, 0.9] 30 | CONTINUOUS = 'continuous' 31 | # 2D array of the floats for the multilabel/multiclass classification problem. 32 | # Dimension: BatchSize x # Class 33 | # e.g., [[0.2, 0.8, 0.9], [0.1, 0.2, 0.7]]. 34 | CONTINUOUS_MULTIOUTPUT = 'continuous-multioutput' 35 | # 1D array of class identifiers, e.g, ['a', 'b'] or [1, 29, 12]. 36 | MULTICLASS = 'multiclass' 37 | # 2D lists of multiclass encodings of the classes, e.g., [[1,2,0], [3,2,0]] 38 | # The list can be ragged, e.g, [ ['a', 'b'], ['c'] ] 39 | MULTICLASS_MULTIOUTPUT = 'multiclass-multioutput' 40 | # 2D array of one-hot encoding of the classes, e.g., [[0,1,0], [0,0,1]] 41 | # This is a special case for "multilabel-indicator" except that only one 42 | # class is set to positive per example. 43 | MULTICLASS_INDICATOR = 'multiclass-indicator' 44 | 45 | 46 | class AverageType(enum.StrEnum): # pylint: disable=invalid-enum-extension 47 | """Average type of the confusion matrix.""" 48 | 49 | # Treats each class as one example and calculates the metrics on the total 50 | # aggregates of the result. 51 | MICRO = 'micro' 52 | # Macro calculates metrics for each class first, then average them across 53 | # classes. 54 | MACRO = 'macro' 55 | # Macro average with explicit weights per class. 56 | WEIGHTED = 'weighted' 57 | # Samples average calculates the metrics per example, and average them across 58 | # all examples. 59 | SAMPLES = 'samples' 60 | # Average for the positive label only. 61 | BINARY = 'binary' 62 | -------------------------------------------------------------------------------- /ml_metrics/_src/aggregates/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for aggregates.""" 15 | from __future__ import annotations 16 | 17 | import collections 18 | import dataclasses 19 | 20 | from ml_metrics._src.aggregates import base 21 | from ml_metrics._src.aggregates import types 22 | from ml_metrics._src.utils import math_utils 23 | 24 | 25 | @dataclasses.dataclass 26 | class MeanState(base.CallableMetric): 27 | """Mergeable states for batch update in an aggregate function.""" 28 | 29 | total: types.NumbersT = 0.0 30 | count: types.NumbersT = 0 31 | 32 | def new(self, inputs: types.NumbersT) -> types.NumbersT: 33 | return MeanState(total=sum(inputs), count=len(inputs)) 34 | 35 | def merge(self, other: MeanState): 36 | self.total += other.total 37 | self.count += other.count 38 | 39 | def result(self): 40 | return math_utils.safe_divide(self.total, self.count) 41 | 42 | 43 | @dataclasses.dataclass 44 | class TupleMeanState(base.CallableMetric): 45 | """MeanState for a tuple of inputs.""" 46 | 47 | states: tuple[MeanState, ...] = () 48 | 49 | def new(self, *inputs: tuple[types.NumbersT, ...]) -> TupleMeanState: 50 | return TupleMeanState(tuple(MeanState().new(x) for x in inputs)) 51 | 52 | def merge(self, other: TupleMeanState): 53 | if not self.states: 54 | self.states = tuple(MeanState() for _ in other.states) 55 | for state, state_other in zip(self.states, other.states, strict=True): 56 | state.merge(state_other) 57 | 58 | def result(self): 59 | return tuple(state.result() for state in self.states) 60 | 61 | 62 | @dataclasses.dataclass 63 | class FrequencyState: 64 | """Mergeable frequency states for batch update in an aggregate function.""" 65 | 66 | # TODO(b/331796958): Optimize storage consumption 67 | counter: collections.Counter[str] = dataclasses.field( 68 | default_factory=collections.Counter 69 | ) 70 | count: int = 0 71 | 72 | def merge(self, other: 'FrequencyState'): 73 | self.counter.update(other.counter) 74 | self.count += other.count 75 | 76 | def result(self) -> list[tuple[str, float]]: 77 | result = [ 78 | (key, math_utils.safe_divide(value, self.count)) 79 | for key, value in self.counter.items() 80 | ] 81 | result = sorted(result, key=lambda x: (-x[1], x[0])) 82 | return result 83 | -------------------------------------------------------------------------------- /ml_metrics/_src/chainables/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /ml_metrics/_src/chainables/io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """I/O utilities for chainables.""" 15 | from __future__ import annotations 16 | 17 | from collections.abc import Iterable, Iterator 18 | import dataclasses as dc 19 | from typing import Any, Self, TypeVar 20 | 21 | from ml_metrics._src import types 22 | from ml_metrics._src.utils import iter_utils 23 | 24 | _T = TypeVar('_T') 25 | 26 | 27 | @dc.dataclass(frozen=True, slots=True) 28 | class ShardConfig: 29 | shard_index: int = 0 30 | num_shards: int = 1 31 | start_index: int = 0 32 | parent: ShardConfig | None = dc.field(default=None, kw_only=True) 33 | 34 | 35 | @dc.dataclass(frozen=True) 36 | class SequenceDataSource(types.Recoverable, Iterable[_T]): 37 | """A shardable sequence data source.""" 38 | data: types.RandomAccessible[_T] 39 | ignore_error: bool = dc.field(kw_only=True, default=False) 40 | batch_size: dc.InitVar[int] = 0 41 | _shard_state: ShardConfig = dc.field(default_factory=ShardConfig) 42 | _start: int = 0 43 | _end: int | None = None 44 | 45 | def __post_init__(self, batch_size: int): 46 | data = self.data 47 | if not hasattr(data, '__getitem__') or not hasattr(data, '__len__'): 48 | raise TypeError(f'data is not indexable, got {type(data)=}') 49 | # Use MergedSequences even for a single sequence to enforce iterating by 50 | # random access so that the iterator is continuable after exception. 
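    # If data is already a MergedSequences, unwrap its underlying sequences so
    # they are re-wrapped below with this instance's batch_size instead of
    # being nested inside another MergedSequences.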
51 | sequences = [self.data] 52 | if isinstance(self.data, iter_utils.MergedSequences): 53 | sequences = self.data.sequences 54 | data = iter_utils.MergedSequences(sequences, batch_size) 55 | object.__setattr__(self, 'data', data) 56 | 57 | @classmethod 58 | def from_sequences( 59 | cls, 60 | sequences: Iterable[types.RandomAccessible[_T]], 61 | batch_size: int = 0, 62 | ignore_error: bool = False, 63 | ) -> Self: 64 | return cls( 65 | iter_utils.MergedSequences(sequences), 66 | ignore_error=ignore_error, 67 | batch_size=batch_size, 68 | ) 69 | 70 | def shard(self, shard_index: int, num_shards: int, offset: int = 0) -> Self: 71 | if num_shards < 1: 72 | raise ValueError(f'num_shards must be positive, got {num_shards=}') 73 | interval, remainder = divmod(self.end - self.start, num_shards) 74 | start, adjusted_interval = self.start, 0 75 | for i in range(shard_index + 1): 76 | adjusted_interval = interval + 1 if i < remainder else interval 77 | start += adjusted_interval if i < shard_index else 0 78 | shard_state = ShardConfig( 79 | shard_index, num_shards, offset, parent=self._shard_state 80 | ) 81 | return dc.replace( 82 | self, 83 | _shard_state=shard_state, 84 | _start=start + offset, 85 | _end=start + adjusted_interval, 86 | ) 87 | 88 | @property 89 | def start(self) -> int: 90 | return self._start 91 | 92 | @property 93 | def end(self) -> int: 94 | return len(self.data) if self._end is None else self._end 95 | 96 | def __len__(self) -> int: 97 | return self.end - self.start 98 | 99 | @property 100 | def state(self) -> ShardConfig: 101 | return self._shard_state 102 | 103 | def from_state(self, shard_state: ShardConfig) -> Self: 104 | """Iterates the data source given a shard index.""" 105 | if shard_state.parent is not None: 106 | result = self.from_state(shard_state.parent) 107 | else: 108 | result = SequenceDataSource(self.data, ignore_error=self.ignore_error) 109 | return result.shard( 110 | shard_state.shard_index, shard_state.num_shards, shard_state.start_index 111 | ) 112 | 113 | def iterate(self) -> SequenceIterator[_T]: 114 | return SequenceIterator(self) 115 | 116 | def __iter__(self) -> Iterator[_T]: 117 | return self.iterate() 118 | 119 | def __getitem__(self, index: int | Any) -> _T: 120 | """Iterates the data source given a shard index.""" 121 | if isinstance(index, slice): 122 | start, stop, step = index.start, index.stop, index.step 123 | start = start or 0 # Convert None to 0. 
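      # Translate the shard-relative slice into absolute positions in the
      # underlying data: shift by this shard's start, map a negative stop from
      # the shard's end, and clamp both bounds to the shard's [start, end)
      # window.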
124 | start += self.start 125 | start = min(start, self.end) 126 | if stop is None: 127 | stop = self.end 128 | else: 129 | stop = self.end + stop if stop < 0 else self.start + stop 130 | stop = min(stop, self.end) 131 | return list(self.data[slice(start, stop, step)]) 132 | 133 | index = self.end + index if index < 0 else self.start + index 134 | return self.data[index] 135 | 136 | 137 | class SequenceIterator(types.Recoverable, Iterator[_T]): 138 | """A sharded data source for chainables.""" 139 | 140 | config: SequenceDataSource 141 | _index: int 142 | 143 | def __init__(self, config: SequenceDataSource): 144 | self._index = config.start 145 | iter_ = iter_utils.iter_ignore_error if config.ignore_error else iter 146 | self._it = iter_(config.data[config.start : config.end]) 147 | self.config = config 148 | 149 | def from_state(self, shard_state: ShardConfig) -> Self: 150 | return self.__class__(self.config.from_state(shard_state)) 151 | 152 | @property 153 | def state(self) -> ShardConfig: 154 | start_index = self._index - self.config.start 155 | return dc.replace(self.config.state, start_index=start_index) 156 | 157 | def __next__(self) -> _T: 158 | """Iterates the data source given a shard index.""" 159 | result = next(self._it) 160 | self._index += 1 161 | return result 162 | 163 | def __iter__(self) -> Self: 164 | """Iterates the data source given a shard index.""" 165 | return self 166 | 167 | 168 | @dc.dataclass(frozen=True) 169 | class ShardedIterable(types.Recoverable, Iterable[_T]): 170 | """A sharded data source for any iterable.""" 171 | data: Iterable[_T] 172 | _shard_state: ShardConfig = dc.field(default_factory=ShardConfig) 173 | 174 | def __post_init__(self): 175 | data = self.data 176 | if not (isinstance(data, Iterable) and not isinstance(data, Iterator)): 177 | raise TypeError( 178 | f'input has to be an iterable but not an iterator, got {type(data)=}' 179 | ) 180 | if self._shard_state.num_shards < 1: 181 | raise ValueError(f'num_shards must be positive, got {self._shard_state=}') 182 | 183 | def shard(self, shard_index: int, num_shards: int) -> Self: 184 | return dc.replace(self, _shard_state=ShardConfig(shard_index, num_shards)) 185 | 186 | @property 187 | def state(self) -> ShardConfig: 188 | return self._shard_state 189 | 190 | def from_state(self, shard_state: ShardConfig) -> Self: 191 | return dc.replace(self, _shard_state=shard_state) 192 | 193 | def iterate(self) -> DataIterator[_T]: 194 | return DataIterator(self) 195 | 196 | def __iter__(self) -> Iterator[_T]: 197 | return self.iterate() 198 | 199 | 200 | class DataIterator(types.Recoverable, Iterator[_T]): 201 | """An sharded iterator for an iterable.""" 202 | 203 | def __init__(self, config: ShardedIterable): 204 | self.config = config 205 | self._index = 0 206 | self._it = iter(config.data) 207 | 208 | def from_state(self, shard_state: ShardConfig) -> Self: 209 | return self.__class__(self.config.from_state(shard_state)) 210 | 211 | @property 212 | def state(self) -> ShardConfig: 213 | return dc.replace(self.config.state, start_index=self._index) 214 | 215 | def __next__(self) -> _T: 216 | """Iterates the data source given a shard index.""" 217 | while self._index < self.config.state.start_index: 218 | _ = next(self._it) 219 | self._index += 1 220 | shard_index = self.config.state.shard_index 221 | num_shards = self.config.state.num_shards 222 | while self._index % num_shards != shard_index: 223 | _ = next(self._it) 224 | self._index += 1 225 | result = next(self._it) 226 | self._index += 1 227 | return 
result 228 | 229 | def __iter__(self) -> Self: 230 | return self 231 | -------------------------------------------------------------------------------- /ml_metrics/_src/chainables/io_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from ml_metrics._src.chainables import io 18 | from ml_metrics._src.utils import iter_utils 19 | from ml_metrics._src.utils import test_utils 20 | 21 | 22 | class SequenceDataSourceTest(parameterized.TestCase): 23 | 24 | @parameterized.named_parameters([ 25 | dict( 26 | testcase_name='with_default_num_shards', 27 | num_shards=1, 28 | expected=[[0, 1, 2]], 29 | ), 30 | dict( 31 | testcase_name='with_two_shards', 32 | num_shards=2, 33 | expected=[[0, 1], [2]], 34 | ), 35 | dict( 36 | testcase_name='with_three_shards', 37 | num_shards=3, 38 | expected=[[0], [1], [2]], 39 | ), 40 | dict( 41 | testcase_name='with_four_shards', 42 | num_shards=4, 43 | expected=[[0], [1], [2], []], 44 | ), 45 | ]) 46 | def test_sharded_sequence(self, num_shards, expected): 47 | ds = io.SequenceDataSource(range(3)) 48 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)] 49 | self.assertEqual(expected, actual) 50 | 51 | def test_sharded_sequence_len(self): 52 | ds = io.SequenceDataSource(range(3)) 53 | self.assertLen(ds, 3) 54 | self.assertLen(ds.shard(0, 1, offset=1), 2) 55 | self.assertLen(ds.shard(0, 2), 2) 56 | self.assertLen(ds.shard(0, 2, offset=1), 1) 57 | self.assertLen(ds.shard(1, 2), 1) 58 | 59 | def test_sharded_sequence_repeated_shard_len(self): 60 | ds = io.SequenceDataSource(range(10)) 61 | self.assertLen(ds, 10) 62 | self.assertLen(ds.shard(0, 2).shard(0, 2), 3) 63 | self.assertLen(ds.shard(0, 2).shard(1, 2), 2) 64 | self.assertLen(ds.shard(1, 2).shard(0, 2), 3) 65 | self.assertLen(ds.shard(1, 2).shard(1, 2), 2) 66 | 67 | def test_sharded_sequence_repeated_shard(self): 68 | ds = io.SequenceDataSource(range(10)) 69 | self.assertEqual([0, 1, 2], list(ds.shard(0, 2).shard(0, 2))) 70 | self.assertEqual([3, 4], list(ds.shard(0, 2).shard(1, 2))) 71 | self.assertEqual([5, 6, 7], list(ds.shard(1, 2).shard(0, 2))) 72 | self.assertEqual([8, 9], list(ds.shard(1, 2).shard(1, 2))) 73 | 74 | def test_sharded_sequence_serialization(self): 75 | ds = io.SequenceDataSource(range(3)) 76 | it = ds.iterate() 77 | self.assertEqual(0, next(it)) 78 | ds = ds.from_state(it.state) 79 | self.assertEqual([1, 2], list(it)) 80 | self.assertEqual([1, 2], list(ds)) 81 | 82 | def test_sharded_sequence_serialization_after_shard(self): 83 | ds = io.SequenceDataSource(range(4)) 84 | it = ds.shard(1, num_shards=2).iterate() 85 | self.assertEqual(2, next(it)) 86 | ds = ds.iterate().from_state(it.state) 87 | self.assertEqual([3], list(it)) 88 | self.assertEqual([3], list(ds)) 89 | 90 | def 
test_sharded_sequence_serialization_after_shard_twice(self): 91 | ds = io.SequenceDataSource(range(8)) 92 | it = ds.shard(0, 2).shard(1, 2).iterate() 93 | self.assertEqual(2, next(it)) 94 | ds = ds.iterate().from_state(it.state) 95 | self.assertEqual([3], list(it)) 96 | self.assertEqual([3], list(ds)) 97 | 98 | def test_merged_sequences(self): 99 | ds = io.SequenceDataSource.from_sequences([range(2), range(2, 6)]) 100 | self.assertEqual([0, 1, 2, 3, 4, 5], list(ds.shard(0, 1))) 101 | self.assertEqual([0, 1, 2], list(ds.shard(0, 2))) 102 | self.assertEqual([3, 4, 5], list(ds.shard(1, 2))) 103 | 104 | def test_sharded_sequence_with_non_indexable_data(self): 105 | with self.assertRaisesRegex(TypeError, 'data is not indexable'): 106 | io.SequenceDataSource(0) # pytype: disable=wrong-arg-types 107 | 108 | def test_sharded_sequence_with_invalid_num_shards_raises_error(self): 109 | with self.assertRaisesRegex(ValueError, 'num_shards must be positive'): 110 | _ = io.SequenceDataSource(range(3)).shard(0, 0) 111 | 112 | def test_sequence_ignore_error_single_sequence(self): 113 | ds = io.SequenceDataSource( 114 | test_utils.SequenceWithExc(5, 2), ignore_error=True 115 | ) 116 | self.assertEqual([0, 1, 3, 4], list(ds)) 117 | 118 | def test_sequence_ignore_error_multiple_sequences(self): 119 | ds = io.SequenceDataSource.from_sequences( 120 | [ 121 | test_utils.SequenceWithExc(5, 2), 122 | test_utils.SequenceWithExc(5, 3), 123 | ], 124 | ignore_error=True, 125 | ) 126 | self.assertEqual([0, 1, 3, 4, 0, 1, 2, 4], list(ds)) 127 | 128 | def test_sequence_batch_size(self): 129 | batch_size = 2 130 | ds = io.SequenceDataSource.from_sequences( 131 | [range(4), range(4, 10)], batch_size=batch_size 132 | ) 133 | sliced = ds.data[:3] 134 | self.assertIsInstance(sliced, iter_utils._RangeIterator) 135 | self.assertEqual(sliced._batch_size, batch_size) 136 | self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], list(ds)) 137 | 138 | @parameterized.named_parameters([ 139 | dict(testcase_name='[0]', at=0, expected=0), 140 | dict(testcase_name='[1]', at=1, expected=1), 141 | dict(testcase_name='[-1]', at=8, expected=8), 142 | dict(testcase_name='[:9]', at=slice(9), expected=list(range(9))), 143 | dict(testcase_name='[:-1]', at=slice(-1), expected=list(range(8))), 144 | dict(testcase_name='[2:-1]', at=slice(2, -1), expected=list(range(2, 8))), 145 | dict(testcase_name='[3:7]', at=slice(3, 7), expected=list(range(3, 7))), 146 | dict(testcase_name='[:]', at=slice(None), expected=list(range(9))), 147 | ]) 148 | def test_sequence_getitem_single_shard(self, at, expected): 149 | ds = io.SequenceDataSource.from_sequences([range(4), range(4, 9)]) 150 | self.assertEqual(expected, ds[at]) 151 | 152 | @parameterized.named_parameters([ 153 | dict(testcase_name='shard0[0]', shard=0, at=0, expected=0), 154 | dict(testcase_name='shard0[1]', shard=0, at=1, expected=1), 155 | dict(testcase_name='shard0[-1]', shard=0, at=-1, expected=4), 156 | dict(testcase_name='shard1[0]', shard=1, at=0, expected=5), 157 | dict(testcase_name='shard1[1]', shard=1, at=1, expected=6), 158 | dict(testcase_name='shard1[-1]', shard=1, at=-1, expected=8), 159 | dict( 160 | testcase_name='shard0[:]', 161 | shard=0, 162 | at=slice(None), 163 | expected=list(range(5)), 164 | ), 165 | dict( 166 | testcase_name='shard1[:]', 167 | shard=1, 168 | at=slice(None), 169 | expected=list(range(5, 9)), 170 | ), 171 | dict( 172 | testcase_name='shard0[1:-1]', 173 | shard=0, 174 | at=slice(1, -1), 175 | expected=list(range(1, 4)), 176 | ), 177 | dict( 178 | 
testcase_name='shard1[1:-1]', 179 | shard=1, 180 | at=slice(1, -1), 181 | expected=list(range(6, 8)), 182 | ), 183 | ]) 184 | def test_sequence_getitem_two_shards(self, shard, at, expected): 185 | ds = io.SequenceDataSource.from_sequences([range(4), range(4, 9)]) 186 | self.assertEqual(expected, ds.shard(shard, 2)[at]) 187 | 188 | def test_sequence_getitem_raise(self): 189 | with self.assertRaisesRegex(IndexError, 'Index 4 is out of range'): 190 | _ = io.SequenceDataSource(range(3))[4] 191 | 192 | 193 | class IterableDataSourceTest(parameterized.TestCase): 194 | 195 | def test_sharded_iterable(self): 196 | ds = io.ShardedIterable(range(3)) 197 | self.assertEqual([0, 1, 2], list(ds)) 198 | 199 | def test_sharded_iterable_shard(self): 200 | ds = io.ShardedIterable(range(3)) 201 | num_shards = 2 202 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)] 203 | expected = [[0, 2], [1]] 204 | self.assertEqual(expected, actual) 205 | 206 | def test_sharded_iterable_num_shards_more_than_data(self): 207 | ds = io.ShardedIterable(range(2)) 208 | num_shards = 3 209 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)] 210 | expected = [[0], [1], []] 211 | self.assertEqual(expected, actual) 212 | 213 | def test_sharded_iterable_serialization(self): 214 | ds = io.ShardedIterable(range(3)) 215 | it = ds.iterate() 216 | self.assertEqual(0, next(it)) 217 | ds = ds.iterate().from_state(it.state) 218 | self.assertEqual([1, 2], list(it)) 219 | self.assertEqual([1, 2], list(ds)) 220 | 221 | def test_sharded_iterable_shard_serialization(self): 222 | ds = io.ShardedIterable(range(6)) 223 | it = ds.shard(1, num_shards=2).iterate() 224 | self.assertEqual(1, next(it)) 225 | ds = ds.from_state(it.state) 226 | self.assertEqual([3, 5], list(it)) 227 | self.assertEqual([3, 5], list(ds)) 228 | 229 | def test_sharded_iterator_with_iterator_raises_error(self): 230 | with self.assertRaisesRegex( 231 | TypeError, 'input has to be an iterable but not an iterator' 232 | ): 233 | _ = list(io.ShardedIterable(iter(range(3)))) 234 | 235 | def test_sharded_iterator_with_non_iteratable_raises_error(self): 236 | with self.assertRaisesRegex( 237 | TypeError, 'input has to be an iterable but not an iterator' 238 | ): 239 | _ = io.ShardedIterable(3) # pytype: disable=wrong-arg-types 240 | 241 | def test_sharded_iterator_with_invalid_num_shards_raises_error(self): 242 | with self.assertRaisesRegex(ValueError, 'num_shards must be positive'): 243 | _ = list(io.ShardedIterable(range(3), io.ShardConfig(num_shards=0))) 244 | 245 | 246 | if __name__ == '__main__': 247 | absltest.main() 248 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
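# Editorial note, not part of the repository: the chainables io classes
# exercised in io_test.py above support checkpoint/resume through `state` and
# `from_state`. A minimal sketch, using only APIs shown in those tests:
#
#   from ml_metrics._src.chainables import io
#
#   ds = io.SequenceDataSource(range(100)).shard(0, num_shards=4)  # items 0..24
#   it = ds.iterate()
#   consumed = [next(it) for _ in range(10)]   # process part of the shard
#   checkpoint = it.state                      # a small ShardConfig snapshot
#   resumed = ds.iterate().from_state(checkpoint)
#   assert list(resumed) == list(it)           # both yield the remaining items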
14 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Individual Retrieval based metrics.""" 15 | 16 | from collections.abc import Sequence 17 | from ml_metrics._src.aggregates import retrieval 18 | from ml_metrics._src.aggregates import types 19 | 20 | 21 | # TODO: b/368688941 - Remove this alias once all users are migrated to the new 22 | # module structure. 23 | TopKRetrievalAggFn = retrieval.TopKRetrievalAggFn 24 | 25 | _METRIC_PYDOC_POSTFIX = """ 26 | 27 | Args: 28 | y_true: array of sample's true labels 29 | y_pred: array of sample's label predictions 30 | k_list: k_list is only applicable for average_type != Samples and 31 | multiclass/multioutput input types. It is a list of topk each of which 32 | slices y_pred by y_pred[:topk] assuming the predictions are sorted in 33 | descending order. Default 'None' means consider all outputs in the 34 | prediction. 35 | input_type: one input type from types.InputType 36 | 37 | Returns: 38 | Tuple with metric value(s) 39 | """ 40 | 41 | 42 | def topk_retrieval_metrics( 43 | metrics: Sequence[retrieval.RetrievalMetric], 44 | *, 45 | y_true, 46 | y_pred, 47 | k_list: list[int] | None = None, 48 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 49 | ) -> tuple[tuple[float, ...], ...]: 50 | """Compute multiple metrics together for better efficiency. 51 | 52 | Args: 53 | metrics: List of CFM metrics 54 | y_true: array of sample's true labels 55 | y_pred: array of sample's label predictions 56 | k_list: k_list is only applicable for average_type != Samples and 57 | multiclass/multioutput input types. It is a list of topk each of which 58 | slices y_pred by y_pred[:topk] assuming the predictions are sorted in 59 | descending order. Default 'None' means consider all outputs in the 60 | prediction. 61 | input_type: one input type from types.InputType 62 | 63 | Returns: 64 | Tuple containing the evaluation metric values. in the corresponding order of 65 | given metric names in metrics list. 
66 | """ 67 | return retrieval.TopKRetrievalAggFn( 68 | metrics=metrics, k_list=k_list, input_type=input_type 69 | )(y_true, y_pred) 70 | 71 | 72 | def precision( 73 | y_true, 74 | y_pred, 75 | k_list: list[int] | None = None, 76 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 77 | ) -> tuple[float, ...]: 78 | """Compute Precision Retrieval metric.""" 79 | return retrieval.TopKRetrieval( 80 | metrics=retrieval.RetrievalMetric.PRECISION, 81 | k_list=k_list, 82 | input_type=input_type, 83 | ).as_agg_fn()(y_true, y_pred) 84 | 85 | 86 | precision.__doc__ += _METRIC_PYDOC_POSTFIX 87 | 88 | 89 | def ppv( 90 | y_true, 91 | y_pred, 92 | k_list: list[int] | None = None, 93 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 94 | ) -> tuple[float, ...]: 95 | """Compute PPV Retrieval metric.""" 96 | return retrieval.TopKRetrievalAggFn( 97 | metrics=retrieval.RetrievalMetric.PPV, 98 | k_list=k_list, 99 | input_type=input_type, 100 | )(y_true, y_pred) 101 | 102 | 103 | ppv.__doc__ += _METRIC_PYDOC_POSTFIX 104 | 105 | 106 | def recall( 107 | y_true, 108 | y_pred, 109 | k_list: list[int] | None = None, 110 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 111 | ) -> tuple[float, ...]: 112 | """Compute Recall Retrieval metric.""" 113 | return retrieval.TopKRetrievalAggFn( 114 | metrics=retrieval.RetrievalMetric.RECALL, 115 | k_list=k_list, 116 | input_type=input_type, 117 | )(y_true, y_pred) 118 | 119 | 120 | recall.__doc__ += _METRIC_PYDOC_POSTFIX 121 | 122 | 123 | def sensitivity( 124 | y_true, 125 | y_pred, 126 | k_list: list[int] | None = None, 127 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 128 | ) -> tuple[float, ...]: 129 | """Compute Sensitivity Retrieval metric.""" 130 | return retrieval.TopKRetrievalAggFn( 131 | metrics=retrieval.RetrievalMetric.SENSITIVITY, 132 | k_list=k_list, 133 | input_type=input_type, 134 | )(y_true, y_pred) 135 | sensitivity.__doc__ += _METRIC_PYDOC_POSTFIX 136 | 137 | 138 | def tpr( 139 | y_true, 140 | y_pred, 141 | k_list: list[int] | None = None, 142 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 143 | ) -> tuple[float, ...]: 144 | """Compute TPR Retrieval metric.""" 145 | return retrieval.TopKRetrievalAggFn( 146 | metrics=retrieval.RetrievalMetric.TPR, 147 | k_list=k_list, 148 | input_type=input_type, 149 | )(y_true, y_pred) 150 | 151 | 152 | tpr.__doc__ += _METRIC_PYDOC_POSTFIX 153 | 154 | 155 | def intersection_over_union( 156 | y_true, 157 | y_pred, 158 | k_list: list[int] | None = None, 159 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 160 | ) -> tuple[float, ...]: 161 | """Compute Intersection Over Union Retrieval metric.""" 162 | return retrieval.TopKRetrievalAggFn( 163 | metrics=retrieval.RetrievalMetric.INTERSECTION_OVER_UNION, 164 | k_list=k_list, 165 | input_type=input_type, 166 | )(y_true, y_pred) 167 | intersection_over_union.__doc__ += _METRIC_PYDOC_POSTFIX 168 | 169 | 170 | def positive_predictive_value( 171 | y_true, 172 | y_pred, 173 | k_list: list[int] | None = None, 174 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 175 | ) -> tuple[float, ...]: 176 | """Compute Positive Predictive Value Retrieval metric.""" 177 | return retrieval.TopKRetrievalAggFn( 178 | metrics=retrieval.RetrievalMetric.POSITIVE_PREDICTIVE_VALUE, 179 | k_list=k_list, 180 | input_type=input_type, 181 | )(y_true, y_pred) 182 | positive_predictive_value.__doc__ += _METRIC_PYDOC_POSTFIX 183 | 184 | 185 | def f1_score( 
186 | y_true, 187 | y_pred, 188 | k_list: list[int] | None = None, 189 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 190 | ) -> tuple[float, ...]: 191 | """Compute F1 Score Retrieval metric.""" 192 | return retrieval.TopKRetrievalAggFn( 193 | metrics=retrieval.RetrievalMetric.F1_SCORE, 194 | k_list=k_list, 195 | input_type=input_type, 196 | )(y_true, y_pred) 197 | 198 | 199 | f1_score.__doc__ += _METRIC_PYDOC_POSTFIX 200 | 201 | 202 | def miss_rate( 203 | y_true, 204 | y_pred, 205 | k_list: list[int] | None = None, 206 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 207 | ) -> tuple[float, ...]: 208 | """Compute Miss Rate Retrieval metric.""" 209 | return retrieval.TopKRetrievalAggFn( 210 | metrics=retrieval.RetrievalMetric.MISS_RATE, 211 | k_list=k_list, 212 | input_type=input_type, 213 | )(y_true, y_pred) 214 | 215 | 216 | miss_rate.__doc__ += _METRIC_PYDOC_POSTFIX 217 | 218 | 219 | def mean_average_precision( 220 | y_true, 221 | y_pred, 222 | k_list: list[int] | None = None, 223 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 224 | ) -> tuple[float, ...]: 225 | """Compute Mean Average Precision Retrieval metric.""" 226 | return retrieval.TopKRetrievalAggFn( 227 | metrics=retrieval.RetrievalMetric.MEAN_AVERAGE_PRECISION, 228 | k_list=k_list, 229 | input_type=input_type, 230 | )(y_true, y_pred) 231 | mean_average_precision.__doc__ += _METRIC_PYDOC_POSTFIX 232 | 233 | 234 | def mean_reciprocal_rank( 235 | y_true, 236 | y_pred, 237 | k_list: list[int] | None = None, 238 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 239 | ) -> tuple[float, ...]: 240 | """Compute Mean Reciprocal Rank Retrieval metric.""" 241 | return retrieval.TopKRetrievalAggFn( 242 | metrics=retrieval.RetrievalMetric.MEAN_RECIPROCAL_RANK, 243 | k_list=k_list, 244 | input_type=input_type, 245 | )(y_true, y_pred) 246 | mean_reciprocal_rank.__doc__ += _METRIC_PYDOC_POSTFIX 247 | 248 | 249 | def accuracy( 250 | y_true, 251 | y_pred, 252 | k_list: list[int] | None = None, 253 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 254 | ) -> tuple[float, ...]: 255 | """Compute Accuracy Retrieval metric.""" 256 | return retrieval.TopKRetrievalAggFn( 257 | metrics=retrieval.RetrievalMetric.ACCURACY, 258 | k_list=k_list, 259 | input_type=input_type, 260 | )(y_true, y_pred) 261 | accuracy.__doc__ += _METRIC_PYDOC_POSTFIX 262 | 263 | 264 | def dcg_score( 265 | y_true, 266 | y_pred, 267 | k_list: list[int] | None = None, 268 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 269 | ) -> tuple[float, ...]: 270 | """Compute DCG Score Retrieval metric.""" 271 | return retrieval.TopKRetrievalAggFn( 272 | metrics=retrieval.RetrievalMetric.DCG_SCORE, 273 | k_list=k_list, 274 | input_type=input_type, 275 | )(y_true, y_pred) 276 | dcg_score.__doc__ += _METRIC_PYDOC_POSTFIX 277 | 278 | 279 | def ndcg_score( 280 | y_true, 281 | y_pred, 282 | k_list: list[int] | None = None, 283 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 284 | ) -> tuple[float, ...]: 285 | """Compute NDCG Score Retrieval metric.""" 286 | return retrieval.TopKRetrievalAggFn( 287 | metrics=retrieval.RetrievalMetric.NDCG_SCORE, 288 | k_list=k_list, 289 | input_type=input_type, 290 | )(y_true, y_pred) 291 | ndcg_score.__doc__ += _METRIC_PYDOC_POSTFIX 292 | 293 | 294 | def fowlkes_mallows_index( 295 | y_true, 296 | y_pred, 297 | k_list: list[int] | None = None, 298 | input_type: types.InputType = 
types.InputType.MULTICLASS_MULTIOUTPUT, 299 | ) -> tuple[float, ...]: 300 | """Compute Fowlkes Mallows Index Retrieval metric.""" 301 | return retrieval.TopKRetrievalAggFn( 302 | metrics=retrieval.RetrievalMetric.FOWLKES_MALLOWS_INDEX, 303 | k_list=k_list, 304 | input_type=input_type, 305 | )(y_true, y_pred) 306 | fowlkes_mallows_index.__doc__ += _METRIC_PYDOC_POSTFIX 307 | 308 | 309 | def false_discovery_rate( 310 | y_true, 311 | y_pred, 312 | k_list: list[int] | None = None, 313 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 314 | ) -> tuple[float, ...]: 315 | """Compute False Discovery Rate Retrieval metric.""" 316 | return retrieval.TopKRetrievalAggFn( 317 | metrics=retrieval.RetrievalMetric.FALSE_DISCOVERY_RATE, 318 | k_list=k_list, 319 | input_type=input_type, 320 | )(y_true, y_pred) 321 | false_discovery_rate.__doc__ += _METRIC_PYDOC_POSTFIX 322 | 323 | 324 | def threat_score( 325 | y_true, 326 | y_pred, 327 | k_list: list[int] | None = None, 328 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT, 329 | ) -> tuple[float, ...]: 330 | """Compute Threat Score Retrieval metric.""" 331 | return retrieval.TopKRetrievalAggFn( 332 | metrics=retrieval.RetrievalMetric.THREAT_SCORE, 333 | k_list=k_list, 334 | input_type=input_type, 335 | )(y_true, y_pred) 336 | threat_score.__doc__ += _METRIC_PYDOC_POSTFIX 337 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/retrieval_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
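# Editorial usage sketch for the retrieval metrics defined in retrieval.py
# above (not part of the original test file). Outputs are omitted here; see the
# parameterized cases below for verified values.
#
#   from ml_metrics._src.aggregates import retrieval as agg_retrieval
#   from ml_metrics._src.aggregates import types
#   from ml_metrics._src.metrics import retrieval
#
#   y_true = [['y'], ['y'], ['n']]        # relevant labels per example
#   y_pred = [['y'], ['n', 'y'], ['y']]   # ranked predictions per example
#
#   # One metric at k=1 and k=2.
#   p_at_k = retrieval.precision(y_true, y_pred, k_list=[1, 2])
#
#   # Computing several metrics in a single pass is cheaper than separate calls.
#   results = retrieval.topk_retrieval_metrics(
#       [agg_retrieval.RetrievalMetric.PRECISION,
#        agg_retrieval.RetrievalMetric.RECALL],
#       y_true=y_true,
#       y_pred=y_pred,
#       k_list=[1, 2],
#       input_type=types.InputType.MULTICLASS_MULTIOUTPUT,
#   )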
14 | """Tests for retrieval.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from ml_metrics._src.aggregates import types 19 | from ml_metrics._src.metrics import retrieval 20 | import numpy as np 21 | 22 | 23 | class RetrievalTest(parameterized.TestCase): 24 | 25 | @parameterized.named_parameters([ 26 | dict( 27 | testcase_name="Precision", 28 | metric_fn=retrieval.precision, 29 | # precision@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8 30 | # precision@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5/8 31 | expected=[5 / 8, 5.5 / 8], 32 | ), 33 | dict( 34 | testcase_name="PPV", 35 | metric_fn=retrieval.ppv, 36 | # ppv@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8 37 | # ppv@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5/8 38 | expected=[5 / 8, 5.5 / 8], 39 | ), 40 | dict( 41 | testcase_name="Recall", 42 | metric_fn=retrieval.recall, 43 | # recall@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8 44 | # recall@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8 45 | expected=[4.5 / 8, 5.5 / 8], 46 | ), 47 | dict( 48 | testcase_name="Sensitivity", 49 | metric_fn=retrieval.sensitivity, 50 | # sensitivity@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8 51 | # sensitivity@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8 52 | expected=[4.5 / 8, 5.5 / 8], 53 | ), 54 | dict( 55 | testcase_name="TPR", 56 | metric_fn=retrieval.tpr, 57 | # tpr@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8 58 | # tpr@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8 59 | expected=[4.5 / 8, 5.5 / 8], 60 | ), 61 | dict( 62 | testcase_name="POSITIVE_PREDICTIVE_VALUE", 63 | metric_fn=retrieval.positive_predictive_value, 64 | # positive_predictive_value@1 = mean([1, 0, 0, 1, 65 | # 1, 1, 0, 1]) = 5/8 66 | # positive_predictive_value@2 = mean([1, 1/2, 0, 1, 67 | # 1, 1, 0, 1]) = 5.5/8 68 | expected=[5 / 8, 5.5 / 8], 69 | ), 70 | dict( 71 | testcase_name="INTERSECTION_OVER_UNION", 72 | metric_fn=retrieval.intersection_over_union, 73 | # intersection_over_union@1 = mean([1, 0, 0, 1, 74 | # 1/2, 1, 0, 1]) = 4.5/8 75 | # intersection_over_union@2 = mean([1, 1/2, 0, 1, 76 | # 1/2, 1, 0, 1]) = 5/8 77 | expected=[4.5 / 8, 5 / 8], 78 | ), 79 | dict( 80 | testcase_name="Accuracy", 81 | metric_fn=retrieval.accuracy, 82 | # accuracy@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8 83 | # accuracy@2 = mean([1, 1, 0, 1, 1, 1, 0, 1]) = 6/8 84 | expected=[5 / 8, 6 / 8], 85 | ), 86 | dict( 87 | testcase_name="F1Score", 88 | metric_fn=retrieval.f1_score, 89 | # f1_score@1 = mean([1, 0, 0, 1, 1/1.5, 1, 0, 1]) = (4+2/3)/8 90 | # f1_score@2 = mean([1, 1/1.5, 0, 1, 1/1.5, 1, 0, 1]) = (4+4/3)/8 91 | expected=[(4 + 2 / 3) / 8, (4 + 4 / 3) / 8], 92 | ), 93 | dict( 94 | testcase_name="MissRate", 95 | metric_fn=retrieval.miss_rate, 96 | # miss_rate@1 = mean([0, 1, 1, 0, 1/2, 0, 1, 0]) = 3.5/8 97 | # miss_rate@2 = mean([0, 0, 1, 0, 1/2, 0, 1, 0]) = 2.5/8 98 | expected=[3.5 / 8, 2.5 / 8], 99 | ), 100 | dict( 101 | testcase_name="mAP", 102 | metric_fn=retrieval.mean_average_precision, 103 | # mAP@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8 104 | # mAP@2 = mean([1, 1/2, 0, 1, 1/2, 1, 0, 1]) = 5/8 105 | expected=[5 / 8, 5 / 8], 106 | ), 107 | dict( 108 | testcase_name="MRR", 109 | metric_fn=retrieval.mean_reciprocal_rank, 110 | # mRR@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8 111 | # mRR@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5 / 8 112 | expected=[5 / 8, 5.5 / 8], 113 | ), 114 | dict( 115 | testcase_name="DCG", 116 | metric_fn=retrieval.dcg_score, 117 | # DCG@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8 118 | # DCG@2 = mean([1, 1/log2(3), 0, 1, 1, 1, 0, 1]) 
119 | # = (5 + 1 / log2(3)) / 8 120 | expected=[5 / 8, (1 / np.log2(3) + 5) / 8], 121 | ), 122 | dict( 123 | testcase_name="NDCG", 124 | metric_fn=retrieval.ndcg_score, 125 | # NDCG@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8 126 | # NDCG@2 = mean([1, 1/log2(3), 0, 1, 1/(1+1/log2(3)), 1, 0, 1]) 127 | # = (4 + 1 / log2(3) + 1/(1+1/log2(3))) / 8 128 | expected=[5 / 8, (4 + 1 / np.log2(3) + 1 / (1 + 1 / np.log2(3))) / 8], 129 | ), 130 | dict( 131 | testcase_name="FMI", 132 | metric_fn=retrieval.fowlkes_mallows_index, 133 | # FMI@1 = mean(sqrt([1, 0, 0, 1, 0.5, 1, 0, 1])) 134 | # FMI@2 = mean(sqrt([1, 0.5, 0, 1, 0.5, 1, 0, 1])) 135 | expected=[ 136 | np.sqrt([1, 0, 0, 1, 0.5, 1, 0, 1]).mean(), 137 | np.sqrt([1, 0.5, 0, 1, 0.5, 1, 0, 1]).mean(), 138 | ], 139 | ), 140 | dict( 141 | testcase_name="FDR", 142 | metric_fn=retrieval.false_discovery_rate, 143 | # FDR@1 = mean([0, 1, 1, 0, 0, 0, 1, 0]) = 3/8 144 | # FDR@2 = mean([0, 1/2, 1, 0, 0, 0, 1, 0]) = 2.5/8 145 | expected=[3 / 8, 2.5 / 8], 146 | ), 147 | dict( 148 | testcase_name="ThreatScore", 149 | metric_fn=retrieval.threat_score, 150 | # threat_score@1 = mean([1, 0, 0, 1, 0.5, 1, 0, 1]) = 4.5 / 8 151 | # threat_score@2 = mean([0.5, 0.5, 0, 0.5, 1/3, 0.5, 0, 0.5]) = 152 | # = (2.5 + 1/3) / 8 153 | expected=[4.5 / 8, (2.5 + 1 / 3) / 8], 154 | ), 155 | ]) 156 | def test_individual_metric(self, metric_fn, expected): 157 | k_list = [1, 2] 158 | y_pred = [["y"], ["n", "y"], ["y"], ["n"], ["y"], ["n"], ["n"], ["u"]] 159 | y_true = [["y"], ["y"], ["n"], ["n"], ["y", "n"], ["n"], ["y"], ["u"]] 160 | metric_doc_details = "\n".join( 161 | metric_fn.__doc__.split("\n")[1:] 162 | ).strip() # ignore the description line for comparison 163 | self.assertEqual( 164 | metric_doc_details, retrieval._METRIC_PYDOC_POSTFIX.strip() 165 | ) 166 | np.testing.assert_allclose( 167 | expected, 168 | metric_fn( 169 | y_true, 170 | y_pred, 171 | k_list=k_list, 172 | input_type=types.InputType.MULTICLASS_MULTIOUTPUT, 173 | ), 174 | ) 175 | 176 | 177 | if __name__ == "__main__": 178 | absltest.main() 179 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/rolling_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Individual statistics metrics.""" 15 | 16 | from ml_metrics._src.aggregates import rolling_stats 17 | from ml_metrics._src.aggregates import types 18 | 19 | 20 | _METRIC_PYDOC_POSTFIX = """ 21 | 22 | The metric is computed based on non-nan values within the batch. 23 | 24 | Args: 25 | batch: A batch of numbers. 26 | 27 | Returns: 28 | Metric value. 
29 | """ 30 | 31 | 32 | def var(batch: types.NumbersT) -> float: 33 | """Computes the variance in a batch.""" 34 | return rolling_stats.MeanAndVariance().add(batch).var 35 | 36 | 37 | var.__doc__ += _METRIC_PYDOC_POSTFIX 38 | 39 | 40 | def stddev(batch: types.NumbersT) -> float: 41 | """Computes the standard deviation in a batch.""" 42 | return rolling_stats.MeanAndVariance().add(batch).stddev 43 | 44 | 45 | stddev.__doc__ += _METRIC_PYDOC_POSTFIX 46 | 47 | 48 | def mean(batch: types.NumbersT) -> float: 49 | """Computes the mean in a batch.""" 50 | return rolling_stats.MeanAndVariance().add(batch).mean 51 | 52 | 53 | mean.__doc__ += _METRIC_PYDOC_POSTFIX 54 | 55 | 56 | def count(batch: types.NumbersT) -> int: 57 | """Computes the number of elements in a batch.""" 58 | return rolling_stats.MeanAndVariance().add(batch).count 59 | 60 | 61 | count.__doc__ += _METRIC_PYDOC_POSTFIX 62 | 63 | 64 | def total(batch: types.NumbersT) -> float: 65 | """Computes the total sum of a batch.""" 66 | return rolling_stats.MeanAndVariance().add(batch).total 67 | 68 | 69 | total.__doc__ += _METRIC_PYDOC_POSTFIX 70 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/rolling_stats_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for stats.""" 15 | 16 | from absl.testing import parameterized 17 | 18 | from ml_metrics._src.metrics import rolling_stats 19 | import numpy as np 20 | 21 | from absl.testing import absltest 22 | 23 | _BATCH_WITH_NAN = np.concatenate( 24 | [np.random.randn(1000) + 1e7, [np.nan] * 100], axis=0 25 | ) 26 | 27 | 28 | class StatsTest(parameterized.TestCase): 29 | 30 | @parameterized.named_parameters([ 31 | dict( 32 | testcase_name='variance', 33 | metric_fn=rolling_stats.var, 34 | expected=np.nanvar(_BATCH_WITH_NAN), 35 | ), 36 | dict( 37 | testcase_name='stddev', 38 | metric_fn=rolling_stats.stddev, 39 | expected=np.nanstd(_BATCH_WITH_NAN), 40 | ), 41 | dict( 42 | testcase_name='mean', 43 | metric_fn=rolling_stats.mean, 44 | expected=np.nanmean(_BATCH_WITH_NAN), 45 | ), 46 | dict( 47 | testcase_name='count', 48 | metric_fn=rolling_stats.count, 49 | expected=np.nansum(~np.isnan(_BATCH_WITH_NAN)), 50 | ), 51 | dict( 52 | testcase_name='total', 53 | metric_fn=rolling_stats.total, 54 | expected=np.nansum(_BATCH_WITH_NAN), 55 | ), 56 | ]) 57 | def test_rolling_stats_individual_metrics(self, metric_fn, expected): 58 | got = metric_fn(_BATCH_WITH_NAN) 59 | np.testing.assert_almost_equal(got, expected) 60 | 61 | @parameterized.named_parameters([ 62 | dict( 63 | testcase_name='variance', 64 | metric_fn=rolling_stats.var, 65 | expected=np.nan, 66 | ), 67 | dict( 68 | testcase_name='stddev', 69 | metric_fn=rolling_stats.stddev, 70 | expected=np.nan, 71 | ), 72 | dict( 73 | testcase_name='mean', 74 | metric_fn=rolling_stats.mean, 75 | expected=np.nan, 76 | ), 77 | dict( 78 | testcase_name='count', 79 | metric_fn=rolling_stats.count, 80 | expected=0, 81 | ), 82 | dict( 83 | testcase_name='total', 84 | metric_fn=rolling_stats.total, 85 | expected=0.0, 86 | ), 87 | ]) 88 | def test_rolling_stats_individual_metrics_empty_batch( 89 | self, metric_fn, expected 90 | ): 91 | got = metric_fn([]) 92 | np.testing.assert_almost_equal(got, expected) 93 | 94 | 95 | if __name__ == '__main__': 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Individual text stats metrics.""" 15 | 16 | from collections.abc import Sequence 17 | 18 | from ml_metrics._src.aggregates import rolling_stats 19 | from ml_metrics._src.aggregates import text 20 | from ml_metrics._src.signals import text as text_scores 21 | from ml_metrics._src.utils import iter_utils 22 | from ml_metrics._src.tools.telemetry import telemetry 23 | 24 | 25 | @telemetry.WithTelemetry( 26 | api='ml_metrics', category='metric', reference='topk_word_ngrams' 27 | ) 28 | def topk_word_ngrams( 29 | texts: Sequence[str], 30 | k: int, 31 | n: int, 32 | use_first_ngram_only: bool = False, 33 | count_duplicate: bool = True, 34 | ) -> list[tuple[str, float]]: 35 | """Top k word n-grams metrics. 36 | 37 | Identify the top `k` frequent occurring word n-grams with a case-insensitive 38 | approach. The text will first be cleaned by removing non-alphabetic characters 39 | and spaces, and then converted to lowercase before computing the top k 40 | n-grams. When multiple n-grams share the same frequency, alphabetical order 41 | will be used as a tie-breaker. The result is a list of tuples containing the 42 | n-gram pattern and its corresponding frequency. The list includes either `k` 43 | or the number of distinct n-grams tuples, whichever is less. 44 | 45 | Args: 46 | texts: 47 | Sequence of texts. 48 | k: 49 | Number of most frequent word n-grams. 50 | n: 51 | Number of grams. 52 | use_first_ngram_only: 53 | If `True`, only the first n words of each text will be used to form the 54 | n-grams and `count_duplicate` will be ignored. Otherwise, all words 55 | present in each text will be considered for generating the n-grams. 56 | Default to `False`. 57 | count_duplicate: 58 | If `True`, duplicate n-grams within the text are included in the total 59 | count. Otherwise, the count of a unique N-gram will only consider its 60 | first occurrence. 61 | 62 | Returns: 63 | List of tuples of ngram and its frequency of appearance as a pair. 64 | """ 65 | 66 | if k <= 0 or n <= 0: 67 | raise ValueError( 68 | f'k and n must be positive integers but k={k} and n={n} was passed.' 69 | ) 70 | 71 | return text.TopKWordNGrams( 72 | k=k, 73 | n=n, 74 | use_first_ngram_only=use_first_ngram_only, 75 | count_duplicate=count_duplicate, 76 | ).as_agg_fn()(texts) 77 | 78 | 79 | @telemetry.WithTelemetry( 80 | api='ml_metrics', category='metric', reference='pattern_frequency' 81 | ) 82 | def pattern_frequency( 83 | texts: Sequence[str], patterns: Sequence[str], count_duplicate: bool = True 84 | ) -> list[tuple[str, float]]: 85 | """Pattern frequency metric. 86 | 87 | Identify the frequency of occurrence for each pattern found within the given 88 | texts. 89 | 90 | 91 | Args: 92 | texts: 93 | Sequence of texts. 94 | patterns: 95 | Sequence of text patterns. 96 | count_duplicate: 97 | If `True`, duplicate pattern within the text are included in the total 98 | count. Otherwise, the count of a pattern will only consider its first 99 | occurrence. Default to `False`. 100 | 101 | Returns: 102 | List of tuples of pattern and its frequency of appearance as a pair. 
103 | """ 104 | 105 | if not patterns: 106 | raise ValueError('Patterns must not be empty.') 107 | 108 | if len(set(patterns)) != len(patterns): 109 | raise ValueError(f'Patterns must be unique: {patterns}') 110 | 111 | return text.PatternFrequency( 112 | patterns=patterns, count_duplicate=count_duplicate 113 | ).as_agg_fn()(texts) 114 | 115 | 116 | @telemetry.WithTelemetry( 117 | api='ml_metrics', category='metric', reference='avg_alphabetical_char_count' 118 | ) 119 | def avg_alphabetical_char_count( 120 | texts: Sequence[str], 121 | ) -> rolling_stats.MeanAndVariance: 122 | """Average alphabetical character count metric.""" 123 | 124 | if not list(texts): 125 | raise ValueError('`texts` must not be empty.') 126 | 127 | batch_scorer_fn = iter_utils.iterate_fn(text_scores.alphabetical_char_count) 128 | return rolling_stats.MeanAndVariance( 129 | batch_score_fn=batch_scorer_fn 130 | ).as_agg_fn()(texts) 131 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/text_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for text.""" 15 | 16 | from absl.testing import parameterized 17 | from ml_metrics._src.aggregates import rolling_stats 18 | from ml_metrics._src.metrics import text 19 | import numpy as np 20 | 21 | from absl.testing import absltest 22 | 23 | 24 | class TopKWordNgramsTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters([ 27 | dict( 28 | testcase_name='distinct_ngrams', 29 | k=2, 30 | n=2, 31 | use_first_ngram_only=False, 32 | count_duplicate=False, 33 | expected_result=[ 34 | ('a a', 1 / 3), 35 | ('b b', 1 / 3), 36 | # ('c c', 1 / 3), 37 | # ('d a', 1 / 3), 38 | ], 39 | ), 40 | dict( 41 | testcase_name='count_duplicate', 42 | k=2, 43 | n=2, 44 | use_first_ngram_only=False, 45 | count_duplicate=True, 46 | expected_result=[ 47 | ('b b', 2 / 3), 48 | ('a a', 1 / 3), 49 | # ('c c', 1 / 3), 50 | # ('d a', 1 / 3), 51 | ], 52 | ), 53 | dict( 54 | testcase_name='use_first_ngram_only', 55 | k=2, 56 | n=2, 57 | use_first_ngram_only=True, 58 | # count_duplicate will be ignored. See TopKWordNGrams description. 
59 | count_duplicate=True, 60 | expected_result=[ 61 | ('b b', 1 / 3), 62 | ('c c', 1 / 3), 63 | # ('d a', 1 / 3), 64 | ], 65 | ), 66 | dict( 67 | testcase_name='large_k', 68 | k=10, 69 | n=2, 70 | use_first_ngram_only=False, 71 | count_duplicate=True, 72 | expected_result=[ 73 | ('b b', 2 / 3), 74 | ('a a', 1 / 3), 75 | ('c c', 1 / 3), 76 | ('d a', 1 / 3), 77 | ], 78 | ), 79 | dict( 80 | testcase_name='large_n', 81 | k=4, 82 | n=10, 83 | use_first_ngram_only=False, 84 | count_duplicate=True, 85 | expected_result=[], 86 | ), 87 | dict( 88 | testcase_name='2k_1n', 89 | k=2, 90 | n=1, 91 | use_first_ngram_only=False, 92 | count_duplicate=True, 93 | expected_result=[ 94 | ('b', 3 / 3), 95 | ('a', 2 / 3), 96 | # ('c', 2 / 3), 97 | # ('d', 1 / 3), 98 | ], 99 | ), 100 | dict( 101 | testcase_name='2k_3n', 102 | k=2, 103 | n=3, 104 | use_first_ngram_only=False, 105 | count_duplicate=True, 106 | expected_result=[ 107 | ('b b b', 1 / 3), 108 | ('d a a', 1 / 3), 109 | ], 110 | ), 111 | ]) 112 | def test_topk_word_ngrams( 113 | self, k, n, use_first_ngram_only, count_duplicate, expected_result 114 | ): 115 | texts = [ 116 | 'c c', 117 | 'b B b', # Case-insensitive 118 | 'd a a', 119 | ] 120 | result = text.topk_word_ngrams( 121 | texts=texts, 122 | k=k, 123 | n=n, 124 | use_first_ngram_only=use_first_ngram_only, 125 | count_duplicate=count_duplicate 126 | ) 127 | self.assertSequenceAlmostEqual(expected_result, result) 128 | 129 | def test_topk_word_ngrams_empty(self): 130 | # Randomly choose k and n. 131 | result = text.topk_word_ngrams(texts=[], k=3, n=2) 132 | self.assertSequenceAlmostEqual([], result) 133 | 134 | def test_topk_word_ngrams_invalid_kn(self): 135 | kn_pairs = [(0, 1), (1, 0)] 136 | for k, n in kn_pairs: 137 | with self.assertRaisesRegex( 138 | ValueError, 'k and n must be positive integers.' 139 | ): 140 | text.topk_word_ngrams(texts=['a'], k=k, n=n) 141 | 142 | 143 | class PatternFrequencyTest(parameterized.TestCase): 144 | 145 | @parameterized.named_parameters([ 146 | dict( 147 | testcase_name='not_count_duplicate', 148 | count_duplicate=False, 149 | expected_result=[ 150 | ('ab', 1 / 2), 151 | ('mmm', 0), 152 | ('xyx', 2 / 2), 153 | ], 154 | ), 155 | dict( 156 | testcase_name='count_duplicate', 157 | count_duplicate=True, 158 | expected_result=[ 159 | ('ab', 2 / 2), 160 | ('mmm', 0), 161 | ('xyx', 3 / 2), 162 | ], 163 | ), 164 | ]) 165 | def test_pattern_frequency(self, count_duplicate, expected_result): 166 | batch = ['ab ab xyx', 'xyxyx'] 167 | result = text.pattern_frequency( 168 | texts=batch, 169 | patterns=['ab', 'xyx', 'mmm'], 170 | count_duplicate=count_duplicate, 171 | ) 172 | self.assertSequenceAlmostEqual(expected_result, sorted(result)) 173 | 174 | def test_pattern_frequency_empty(self): 175 | result = text.pattern_frequency( 176 | texts=[], patterns=['ab', 'xyx'], count_duplicate=False 177 | ) 178 | self.assertSequenceEqual([], result) 179 | 180 | @parameterized.named_parameters([ 181 | dict( 182 | testcase_name='empty_pattern', 183 | patterns=[], 184 | expected_regex='Patterns must not be empty.', 185 | ), 186 | dict( 187 | testcase_name='duplicate_patterns', 188 | patterns=['a', 'a', 'b'], 189 | expected_regex='Patterns must be unique:', 190 | ), 191 | ]) 192 | def test_pattern_frequency_invalid_patterns(self, patterns, expected_regex): 193 | with self.assertRaisesRegex(ValueError, expected_regex): 194 | # `texts` is randomly assigned. 
195 | text.pattern_frequency(texts=['a'], patterns=patterns) 196 | 197 | 198 | def get_expected_mean_and_variance_state_result(batch): 199 | # result_dict key must match MeanAndVariance properties name. 200 | result_dict = { 201 | 'mean': np.mean(batch), 202 | 'var': np.var(batch), 203 | 'stddev': np.std(batch), 204 | 'count': len(batch), 205 | 'total': np.sum(batch), 206 | } 207 | return result_dict 208 | 209 | 210 | class AvgAlphabeticalCharCountTest(absltest.TestCase): 211 | 212 | def test_avg_alphabetical_char_count(self): 213 | got_state = text.avg_alphabetical_char_count( 214 | texts=['ab', 'a b', '', 'ok?'] 215 | ) 216 | self.assertIsInstance(got_state, rolling_stats.MeanAndVariance) 217 | 218 | expected_state = get_expected_mean_and_variance_state_result( 219 | [2, 2, 0, 2] 220 | ) 221 | for metric_name in expected_state: 222 | np.testing.assert_allclose( 223 | expected_state[metric_name], 224 | getattr(got_state, metric_name), 225 | ) 226 | 227 | def test_avg_alphabetical_char_count_empty(self): 228 | with self.assertRaisesRegex(ValueError, '`texts` must not be empty.'): 229 | text.avg_alphabetical_char_count(texts=[]) 230 | 231 | 232 | if __name__ == '__main__': 233 | absltest.main() 234 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/utils.py: -------------------------------------------------------------------------------- 1 | """Utility for ML metrics APIs.""" 2 | 3 | import itertools 4 | from typing import Any 5 | 6 | from ml_metrics._src.aggregates import classification 7 | from ml_metrics._src.aggregates import types 8 | from ml_metrics._src.tools.telemetry import telemetry 9 | 10 | 11 | @telemetry.WithTelemetry( 12 | api='ml_metrics', category='metric', reference='verify_input' 13 | ) 14 | def verify_input(y_true, y_pred, average, input_type, vocab, pos_label): 15 | if ( 16 | average == types.AverageType.BINARY 17 | and input_type == types.InputType.BINARY 18 | ): 19 | _validate_pos_label(y_true, y_pred, pos_label, vocab) 20 | 21 | 22 | def _validate_pos_label( 23 | y_true, y_pred, pos_label: Any, vocab: dict[str, int] | None 24 | ): 25 | vocab = vocab or classification.get_vocab( 26 | itertools.chain(y_true, y_pred), False 27 | ) 28 | labels = list(vocab.keys()) 29 | labels.sort() 30 | if pos_label not in labels: 31 | raise ValueError( 32 | f'Pos label {pos_label} not in labels. Please set a valid pos_label' 33 | f' from: {labels}' 34 | ) 35 | -------------------------------------------------------------------------------- /ml_metrics/_src/metrics/utils_test.py: -------------------------------------------------------------------------------- 1 | """Tests for ML metrics API function utils.""" 2 | 3 | from ml_metrics._src.aggregates import types 4 | from ml_metrics._src.metrics import utils 5 | 6 | from absl.testing import absltest 7 | 8 | 9 | InputType = types.InputType 10 | AverageType = types.AverageType 11 | 12 | 13 | class UtilsTest(absltest.TestCase): 14 | 15 | def test_validate_inputs(self): 16 | with self.assertRaisesRegex( 17 | ValueError, 18 | r"Pos label 1 not in labels. 
Please set a valid pos_label from: \['n'," 19 | r" 'y'\]", 20 | ): 21 | _ = utils.verify_input( 22 | y_true=["y", "n", "n"], 23 | y_pred=["y", "n", "y"], 24 | average=AverageType.BINARY, 25 | input_type=InputType.BINARY, 26 | pos_label=1, 27 | vocab=None, 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | absltest.main() 33 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/cg_score.py: -------------------------------------------------------------------------------- 1 | """Code to compute CG Score. 2 | 3 | The Complexity Gap Score (CG Score) quantifies the "influence of individual 4 | instances." CG scores identify observations with large/small influence on 5 | downstream classification performance, potentially flagging label noise issues. 6 | The current implementation works for binary labels only. 7 | More information about the stochastic method implemented here can be found at 8 | https://arxiv.org/pdf/2301.00930.pdf (Section A.2). 9 | """ 10 | 11 | from ml_metrics._src.aggregates import types 12 | from ml_metrics._src.tools.telemetry import telemetry 13 | import numpy as np 14 | 15 | 16 | @telemetry.WithTelemetry( 17 | api='ml_metrics', category='signal', reference='complexity_gap_score' 18 | ) 19 | def complexity_gap_score( 20 | labels: types.NumbersT, 21 | embeddings: types.NumbersT, 22 | *, 23 | num_repetitions: int = 1, 24 | class_balance_ratio: float = 1.0, 25 | random_seed: int = 0, 26 | ) -> types.NumbersT: 27 | """Calculates the Complexity Gap (CG) score for identifying influential instances. 28 | 29 | Args: 30 | labels: Labels in binary vector representations. 31 | embeddings: Embeddings in vector representations. 32 | num_repetitions: Number of times to repeat the CG score calculation. 33 | class_balance_ratio: Ratio for balancing classes during calculation (e.g., 34 | 1.0 for perfect balance). 35 | random_seed: Random seed for reproducibility. 36 | 37 | Returns: 38 | A NumPy array containing the CG score for each data point. 
39 | """ 40 | if (l := len(np.unique(labels))) > 2: 41 | raise ValueError(f'CG score only works for binary labels, got {l} labels.') 42 | 43 | cg_scores = np.zeros(len(labels)) 44 | 45 | if l == 1: 46 | return cg_scores 47 | 48 | data_by_label = _group_data_by_label(embeddings, labels) 49 | if random_seed: 50 | np.random.seed(random_seed) 51 | 52 | for _ in range(num_repetitions): 53 | for label, data in data_by_label.items(): 54 | data = np.array(data['data']) 55 | other_label = _get_other_label(data_by_label, label) 56 | other_data = np.array(data_by_label[other_label]['data']) 57 | balanced_data = _balance_dataset(data, other_data, class_balance_ratio) 58 | vi_scores = _calculate_influence_scores(balanced_data, data.shape[0]) 59 | cg_scores[data_by_label[label]['indices']] += vi_scores 60 | 61 | return cg_scores / np.maximum(num_repetitions, 1) 62 | 63 | 64 | def _group_data_by_label(embeddings: np.ndarray, labels: np.ndarray): 65 | data_by_label = {} 66 | for i, (embedding, label) in enumerate(zip(embeddings, labels, strict=True)): 67 | data_by_label.setdefault(label, {'data': [], 'indices': []}) 68 | data_float = embedding.astype(np.float64) 69 | data_by_label[label]['data'].append( 70 | data_float / np.linalg.norm(data_float, axis=-1, keepdims=True) 71 | ) 72 | data_by_label[label]['indices'].append(i) 73 | return data_by_label 74 | 75 | 76 | def _get_other_label( 77 | data_by_label: dict[int, object], current_label: int 78 | ) -> int: 79 | return [label for label in data_by_label if label != current_label][0] 80 | 81 | 82 | def _balance_dataset( 83 | data: np.ndarray, other_data: np.ndarray, ratio: float 84 | ) -> np.ndarray: 85 | max_size = min(int(len(data) * ratio), len(other_data)) 86 | selected_indices = np.random.choice(len(other_data), max_size, replace=False) 87 | balanced_other_data = other_data[selected_indices] 88 | return np.concatenate((data, balanced_other_data)) 89 | 90 | 91 | def _calculate_influence_scores( 92 | data: np.ndarray, data_size: int 93 | ) -> types.NumbersT: 94 | """Computes Complexity Gap "influence" scores for each data point.""" 95 | 96 | reformatted_data = data 97 | y = np.concatenate([np.ones(data_size), -np.ones(data.shape[0] - data_size)]) 98 | hermitian_inner = reformatted_data @ np.transpose(reformatted_data) 99 | hermitian = (hermitian_inner * (np.pi - np.arccos(hermitian_inner))) / ( 100 | 2 * np.pi 101 | ) 102 | np.fill_diagonal(hermitian, 0.5) 103 | hermitian[np.isnan(hermitian)] = np.nextafter(np.float64(0), np.float64(1)) 104 | 105 | inv_hermitian = np.linalg.pinv(hermitian, hermitian=True) 106 | original_error = y @ (inv_hermitian @ y) 107 | 108 | influence_scores = np.zeros(data_size) 109 | for k in range(data_size): 110 | without_col = np.delete(inv_hermitian, k, axis=1) 111 | without_k = np.delete(without_col, k, axis=0) 112 | row_expanded = np.expand_dims(without_col[k, :], 0) 113 | with_k = inv_hermitian[k, k] 114 | 115 | inv_hermitian_except_k = ( 116 | without_k - (row_expanded.transpose() @ row_expanded) / with_k 117 | ) 118 | y_except_k = np.delete(y, k, axis=0) 119 | influence_scores[k] = original_error - y_except_k @ ( 120 | inv_hermitian_except_k @ y_except_k 121 | ) 122 | 123 | return influence_scores 124 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/cg_score_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from ml_metrics._src.signals import cg_score 3 | import numpy as np 4 | 5 | 6 
| class CgScoreTest(absltest.TestCase): 7 | 8 | def test_validate_nonbinary_input(self): 9 | labels = ['A', 'B', 'C'] 10 | with self.assertRaisesRegex(ValueError, 'only works for binary label'): 11 | cg_score.complexity_gap_score(labels, np.array([[1, 2], [2, 3], [5, 4]])) 12 | 13 | def test_group_data_by_label(self): 14 | input_embeddings = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 15 | labels = np.array(['a', 'b', 'b']) 16 | result = cg_score._group_data_by_label(input_embeddings, labels) 17 | 18 | self.assertIn('a', result) 19 | self.assertSequenceAlmostEqual([0], result['a']['indices']) 20 | self.assertLen(result['a']['data'], 1) 21 | self.assertAlmostEqual(1, np.linalg.norm(result['a']['data'])) 22 | 23 | self.assertIn('b', result) 24 | self.assertSequenceAlmostEqual([1, 2], result['b']['indices']) 25 | self.assertLen(result['b']['data'], 2) 26 | self.assertAlmostEqual(1, np.linalg.norm(result['b']['data'][0])) 27 | self.assertAlmostEqual(1, np.linalg.norm(result['b']['data'][1])) 28 | 29 | def test_get_other_label(self): 30 | example_dict = { 31 | 'a': { 32 | 'data': np.array([1, 2, 3]), 33 | 'indices': [0, 1], 34 | }, 35 | 'b': { 36 | 'data': np.array([4, 5, 6]), 37 | 'indices': [0, 1], 38 | }, 39 | } 40 | self.assertEqual(cg_score._get_other_label(example_dict, 'a'), 'b') 41 | self.assertEqual(cg_score._get_other_label(example_dict, 'b'), 'a') 42 | 43 | def test_balance_dataset(self): 44 | data_a = np.array([[1, 2, 3]]) 45 | data_b = np.array(np.array([[4, 5, 6], [7, 8, 9]])) 46 | self.assertLen(cg_score._balance_dataset(data_a, data_b, 1), 2) 47 | self.assertLen(cg_score._balance_dataset(data_a, data_b, 2), 3) 48 | self.assertLen(cg_score._balance_dataset(data_b, data_a, 1), 3) 49 | 50 | def test_calculate_complexity_gap_score_result_sanity(self): 51 | embeddings = np.array([[1, 2], [2, 3], [5, 4], [5, 4]]) 52 | labels = np.array([1, 1, 1, 0]) 53 | cg_scores = cg_score.complexity_gap_score( 54 | labels, embeddings, num_repetitions=1, random_seed=0 55 | ) 56 | self.assertLen(cg_scores, 4) 57 | self.assertLess(cg_scores[0], cg_scores[3]) 58 | 59 | def test_calculate_complexity_gap_score_simple(self): 60 | embeddings = np.array([[0, 1], [1, 0], [0, 1], [1, 0]]) 61 | labels = np.array([1, 1, 0, 0]) 62 | cg_scores = cg_score.complexity_gap_score( 63 | labels, embeddings, num_repetitions=1 64 | ) 65 | self.assertSequenceAlmostEqual(cg_scores, [0, 0, 0, 0]) 66 | 67 | def test_num_repetitions(self): 68 | embeddings = np.array([[1, 2], [1, 2], [5, 4], [5, 4]]) 69 | labels = np.array([1, 1, 0, 0]) 70 | cg_scores_one = cg_score.complexity_gap_score( 71 | labels, embeddings, num_repetitions=1 72 | ) 73 | cg_scores_three = cg_score.complexity_gap_score( 74 | labels, embeddings, num_repetitions=3 75 | ) 76 | self.assertLen(cg_scores_one, 4) 77 | self.assertLen(cg_scores_three, 4) 78 | np.testing.assert_allclose(cg_scores_one, cg_scores_three) 79 | 80 | def test_one_single_label(self): 81 | embeddings = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 82 | labels = np.array([1, 1, 1]) 83 | cg_scores = cg_score.complexity_gap_score( 84 | labels, embeddings, num_repetitions=1 85 | ) 86 | self.assertSequenceAlmostEqual(cg_scores, [0, 0, 0]) 87 | 88 | 89 | if __name__ == '__main__': 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 
(the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module for cross entropy loss functions.""" 15 | 16 | from ml_metrics._src.aggregates import types 17 | from ml_metrics._src.tools.telemetry import telemetry 18 | import numpy as np 19 | 20 | 21 | def _check_y_true_contains_only_0_and_1(y_true: types.NumbersT) -> None: 22 | if not all(y == 0 or y == 1 for y in y_true): 23 | raise ValueError( 24 | 'y_true must contain only 0s and 1s, but recieved: {}'.format(y_true) 25 | ) 26 | 27 | 28 | @telemetry.WithTelemetry( 29 | api='ml_metrics', category='signal', reference='binary_cross_entropy' 30 | ) 31 | def binary_cross_entropy( 32 | y_true: types.NumbersT, 33 | y_pred: types.NumbersT, 34 | ) -> float: 35 | """Calculates binary cross entropy loss for two lists of labels. 36 | 37 | Args: 38 | y_true: Truth label. This is either 0 or 1. 39 | y_pred: Predicted value. This is the model's prediction, i.e, a single 40 | floating-point value which represents a probability (i.e., value in (0., 41 | 1.)). 42 | 43 | Returns: 44 | The binary cross-entropy loss between true labels and predicted labels. 45 | """ 46 | _check_y_true_contains_only_0_and_1(y_true) 47 | 48 | return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)) 49 | 50 | 51 | @telemetry.WithTelemetry( 52 | api='ml_metrics', category='signal', reference='categorical_cross_entropy' 53 | ) 54 | def categorical_cross_entropy( 55 | y_true: types.NumbersT, 56 | y_pred: types.NumbersT, 57 | ): 58 | """Calculates categorical cross entropy loss for two lists of labels. 59 | 60 | Args: 61 | y_true: Truth label. This is either 0 or 1. 62 | y_pred: Predicted value. This is the model's prediction, i.e, a single 63 | floating-point value which represents a probability (i.e., value in [0., 64 | 1.]). 65 | 66 | Returns: 67 | The categorical cross-entropy loss between true labels and predicted labels. 68 | """ 69 | _check_y_true_contains_only_0_and_1(y_true) 70 | 71 | return -np.sum(y_true * np.log(y_pred / np.sum(y_pred))) 72 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/cross_entropy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Cross Entropy.""" 15 | 16 | import math 17 | 18 | from ml_metrics._src.signals import cross_entropy 19 | import numpy as np 20 | 21 | from absl.testing import absltest 22 | 23 | 24 | _COMPARED_DECIMAL_PLACES = 6 # Used to tune tests accuracy. 25 | 26 | 27 | class CrossEntropyTest(absltest.TestCase): 28 | 29 | def test_binary_cross_entropy(self): 30 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 1)) 31 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9)) 32 | 33 | # How to reproduce expected_result: 34 | # bce = tensorflow.keras.losses.BinaryCrossentropy() 35 | # expected_result = tensorflow.get_static_value( 36 | # bce(y_true=y_true, y_pred=y_pred) 37 | # ) 38 | expected_result = 0.9587651091286978 39 | 40 | self.assertAlmostEqual( 41 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred), 42 | expected_result, 43 | places=_COMPARED_DECIMAL_PLACES, 44 | ) 45 | 46 | def test_binary_cross_entropy_y_pred_includes_0_and_1(self): 47 | y_true = np.array((0, 1)) 48 | y_pred = np.array((0, 1)) 49 | 50 | self.assertTrue( 51 | math.isnan( 52 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred) 53 | ), 54 | ) 55 | 56 | def test_binary_cross_entropy_invalid_y_true_raises_error(self): 57 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 2)) 58 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9)) 59 | 60 | with self.assertRaisesRegex( 61 | ValueError, 'y_true must contain only 0s and 1s, but recieved: ' 62 | ): 63 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred) 64 | 65 | def test_categorical_cross_entropy(self): 66 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 1)) 67 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9)) 68 | 69 | # How to reproduce expected_result: 70 | # cce = tensorflow.keras.losses.CategoricalCrossentropy() 71 | # expected_result = tensorflow.get_static_value( 72 | # cce(y_true=y_true, y_pred=y_pred) 73 | # ) 74 | expected_result = 9.38023940877158 75 | 76 | self.assertAlmostEqual( 77 | cross_entropy.categorical_cross_entropy(y_true=y_true, y_pred=y_pred), 78 | expected_result, 79 | places=_COMPARED_DECIMAL_PLACES, 80 | ) 81 | 82 | def test_categorical_cross_entropy_invalid_y_true_raises_error(self): 83 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 2)) 84 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9)) 85 | 86 | with self.assertRaisesRegex( 87 | ValueError, 'y_true must contain only 0s and 1s, but recieved: ' 88 | ): 89 | cross_entropy.categorical_cross_entropy(y_true=y_true, y_pred=y_pred) 90 | 91 | 92 | if __name__ == '__main__': 93 | absltest.main() 94 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/flip_masks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Flip Masks.""" 15 | 16 | from ml_metrics._src.aggregates import types 17 | from ml_metrics._src.tools.telemetry import telemetry 18 | import numpy as np 19 | 20 | 21 | @telemetry.WithTelemetry( 22 | api='ml_metrics', category='signal', reference='binary_flip_mask' 23 | ) 24 | def binary_flip_mask( 25 | base_prediction: types.NumbersT, 26 | model_prediction: types.NumbersT, 27 | threshold: types.NumbersT | None = None, 28 | ) -> types.NumbersT: 29 | """AKA symmetric flip mask. Returns a 1 if the predictions don't match.""" 30 | if threshold is not None: 31 | base_prediction = base_prediction > threshold 32 | model_prediction = model_prediction > threshold 33 | 34 | return np.logical_xor(base_prediction, model_prediction).astype(int) 35 | 36 | 37 | @telemetry.WithTelemetry( 38 | api='ml_metrics', category='signal', reference='neg_to_pos_flip_mask' 39 | ) 40 | def neg_to_pos_flip_mask( 41 | base_prediction: types.NumbersT, 42 | model_prediction: types.NumbersT, 43 | threshold: types.NumbersT | None = None, 44 | ) -> types.NumbersT: 45 | """Returns a 1 if base_prediction <= threshold < model_prediction.""" 46 | if threshold is None: 47 | return not base_prediction and model_prediction 48 | 49 | base_under_threshold = base_prediction <= threshold 50 | model_over_threshold = model_prediction > threshold 51 | 52 | return np.logical_and(base_under_threshold, model_over_threshold).astype(int) 53 | 54 | 55 | @telemetry.WithTelemetry( 56 | api='ml_metrics', category='signal', reference='pos_to_neg_flip_mask' 57 | ) 58 | def pos_to_neg_flip_mask( 59 | base_prediction: types.NumbersT, 60 | model_prediction: types.NumbersT, 61 | threshold: types.NumbersT | None = None, 62 | ) -> types.NumbersT: 63 | """Returns a 1 if base_prediction > threshold >= model_prediction.""" 64 | if threshold is None: 65 | return base_prediction and not model_prediction 66 | 67 | base_over_threshold = base_prediction > threshold 68 | model_under_threshold = model_prediction <= threshold 69 | 70 | return np.logical_and(base_over_threshold, model_under_threshold).astype(int) 71 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/flip_masks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Flip Masks.""" 15 | 16 | from ml_metrics._src.signals import flip_masks 17 | import numpy as np 18 | from numpy import testing 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | 24 | class FlipMasksTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | dict( 28 | testcase_name='binary', 29 | flip_mask_fn=flip_masks.binary_flip_mask, 30 | expected_mask=(False, True, True, False), 31 | ), 32 | dict( 33 | testcase_name='neg_to_pos', 34 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 35 | expected_mask=(False, True, False, False), 36 | ), 37 | dict( 38 | testcase_name='pos_to_neg', 39 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 40 | expected_mask=(False, False, True, False), 41 | ), 42 | ) 43 | def test_flip_masks_bool_predictions(self, flip_mask_fn, expected_mask): 44 | base_predictions = np.array((False, False, True, True)) 45 | model_predictions = np.array((False, True, False, True)) 46 | 47 | for base_prediction, model_prediction, expected_flip in zip( 48 | base_predictions, model_predictions, expected_mask 49 | ): 50 | self.assertEqual( 51 | flip_mask_fn(base_prediction, model_prediction), 52 | expected_flip, 53 | ) 54 | 55 | @parameterized.named_parameters( 56 | dict( 57 | testcase_name='binary', 58 | flip_mask_fn=flip_masks.binary_flip_mask, 59 | expected_mask=(0, 1, 1, 0), 60 | ), 61 | dict( 62 | testcase_name='neg_to_pos', 63 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 64 | expected_mask=(0, 1, 0, 0), 65 | ), 66 | dict( 67 | testcase_name='pos_to_neg', 68 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 69 | expected_mask=(0, 0, 1, 0), 70 | ), 71 | ) 72 | def test_flip_masks_int_predictions(self, flip_mask_fn, expected_mask): 73 | base_predictions = np.array((0, 0, 1, 1)) 74 | model_predictions = np.array((0, 1, 0, 1)) 75 | 76 | for base_prediction, model_prediction, expected_flip in zip( 77 | base_predictions, model_predictions, expected_mask 78 | ): 79 | self.assertEqual( 80 | flip_mask_fn(base_prediction, model_prediction), 81 | expected_flip, 82 | ) 83 | 84 | @parameterized.named_parameters( 85 | dict( 86 | testcase_name='binary', 87 | flip_mask_fn=flip_masks.binary_flip_mask, 88 | expected_mask=(0, 1, 1, 0), 89 | ), 90 | dict( 91 | testcase_name='neg_to_pos', 92 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 93 | expected_mask=(0, 1, 0, 0), 94 | ), 95 | dict( 96 | testcase_name='pos_to_neg', 97 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 98 | expected_mask=(0, 0, 1, 0), 99 | ), 100 | ) 101 | def test_flip_masks_threshold(self, flip_mask_fn, expected_mask): 102 | base_predictions = np.array((0.1, 0.1, 0.9, 0.9)) 103 | model_predictions = np.array((0.2, 0.9, 0.1, 0.8)) 104 | 105 | for base_prediction, model_prediction, expected_flip in zip( 106 | base_predictions, model_predictions, expected_mask 107 | ): 108 | self.assertEqual( 109 | flip_mask_fn(base_prediction, model_prediction, threshold=0.5), 110 | expected_flip, 111 | ) 112 | 113 | @parameterized.named_parameters( 114 | dict( 115 | testcase_name='binary', 116 | flip_mask_fn=flip_masks.binary_flip_mask, 117 | expected_mask=(False, True, True, False), 118 | ), 119 | dict( 120 | testcase_name='neg_to_pos', 121 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 122 | expected_mask=(False, True, False, False), 123 | ), 124 | dict( 125 | testcase_name='pos_to_neg', 126 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 127 | expected_mask=(False, False, True, False), 128 | ), 129 | ) 130 | def test_flip_masks_bool_predictions_batched( 131 
| self, flip_mask_fn, expected_mask 132 | ): 133 | base_predictions = np.array((False, False, True, True)) 134 | model_predictions = np.array((False, True, False, True)) 135 | 136 | testing.assert_array_equal( 137 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5), 138 | expected_mask, 139 | ) 140 | 141 | @parameterized.named_parameters( 142 | dict( 143 | testcase_name='binary', 144 | flip_mask_fn=flip_masks.binary_flip_mask, 145 | expected_mask=(0, 1, 1, 0), 146 | ), 147 | dict( 148 | testcase_name='neg_to_pos', 149 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 150 | expected_mask=(0, 1, 0, 0), 151 | ), 152 | dict( 153 | testcase_name='pos_to_neg', 154 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 155 | expected_mask=(0, 0, 1, 0), 156 | ), 157 | ) 158 | def test_flip_masks_int_predictions_batched( 159 | self, flip_mask_fn, expected_mask 160 | ): 161 | base_predictions = np.array((0, 0, 1, 1)) 162 | model_predictions = np.array((0, 1, 0, 1)) 163 | 164 | testing.assert_array_equal( 165 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5), 166 | expected_mask, 167 | ) 168 | 169 | @parameterized.named_parameters( 170 | dict( 171 | testcase_name='binary', 172 | flip_mask_fn=flip_masks.binary_flip_mask, 173 | expected_mask=(0, 1, 1, 0), 174 | ), 175 | dict( 176 | testcase_name='neg_to_pos', 177 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask, 178 | expected_mask=(0, 1, 0, 0), 179 | ), 180 | dict( 181 | testcase_name='pos_to_neg', 182 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask, 183 | expected_mask=(0, 0, 1, 0), 184 | ), 185 | ) 186 | def test_flip_masks_threshold_batched(self, flip_mask_fn, expected_mask): 187 | base_predictions = np.array((0.1, 0.1, 0.9, 0.9)) 188 | model_predictions = np.array((0.2, 0.9, 0.1, 0.8)) 189 | 190 | testing.assert_array_equal( 191 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5), 192 | expected_mask, 193 | ) 194 | 195 | 196 | if __name__ == '__main__': 197 | absltest.main() 198 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Samplewise scoring metrics for image.""" 15 | 16 | import io 17 | 18 | from ml_metrics._src.tools.telemetry import telemetry 19 | from PIL import Image 20 | 21 | 22 | MEGAPIXELS = 1024 * 1024 23 | 24 | 25 | @telemetry.WithTelemetry( 26 | api='ml_metrics', category='signal', reference='content_metadata' 27 | ) 28 | def content_metadata(image_bytes: bytes) -> dict[str, int | float | str]: 29 | """Extracts the content metadata of an image.""" 30 | 31 | img = Image.open(io.BytesIO(image_bytes)) 32 | width, height = img.size 33 | return { # pytype: disable=bad-return-type # pillow-102-upgrade 34 | 'mode': img.mode, 35 | 'format': img.format, 36 | 'pixel_width': width, 37 | 'pixel_height': height, 38 | 'megapixel_resolution': width * height / MEGAPIXELS, 39 | 'aspect_ratio': width / height if height > 0 else 0.0, 40 | } 41 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/image_test.py: -------------------------------------------------------------------------------- 1 | """Tests for image.""" 2 | 3 | import io 4 | 5 | from ml_metrics._src.signals import image 6 | from PIL import Image 7 | 8 | from absl.testing import absltest 9 | 10 | 11 | class ImageTest(absltest.TestCase): 12 | 13 | def test_content_metadata(self): 14 | single_black_pixel_image = Image.new('RGB', (1, 1)) 15 | with io.BytesIO() as output: 16 | single_black_pixel_image.save(output, format='PNG') 17 | single_black_pixel_image_bytes = output.getvalue() 18 | 19 | expected_metadata = { 20 | 'mode': 'RGB', 21 | 'format': 'PNG', 22 | 'pixel_width': 1, 23 | 'pixel_height': 1, 24 | 'megapixel_resolution': 1 / image.MEGAPIXELS, 25 | 'aspect_ratio': 1.0, 26 | } 27 | 28 | metadata = image.content_metadata(single_black_pixel_image_bytes) 29 | self.assertEqual(metadata, expected_metadata) 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Samplewise scoring metrics for text.""" 15 | 16 | from collections.abc import Callable, Sequence 17 | import re 18 | from typing import Any 19 | 20 | from ml_metrics._src.tools.telemetry import telemetry 21 | 22 | 23 | def _maybe_tuple( 24 | reference: str | Sequence[str], 25 | ) -> Sequence[str]: 26 | """Converts reference to a sequence if it is a single string.""" 27 | if isinstance(reference, str): 28 | return (reference,) 29 | return reference 30 | 31 | 32 | @telemetry.WithTelemetry( 33 | api='ml_metrics', category='signal', reference='alphabetical_char_count' 34 | ) 35 | def alphabetical_char_count(text: str) -> int: 36 | """Computes the number of alphabetical characters.""" 37 | return len(re.sub(r'[^a-zA-Z]', '', text)) 38 | 39 | 40 | @telemetry.WithTelemetry( 41 | api='ml_metrics', category='signal', reference='word_count' 42 | ) 43 | def word_count(text: str) -> int: 44 | """Computes the number of words. 45 | 46 | Computes the number of words within the text. Characters that are not letters 47 | or spaces are taken out of the text, leaving only spaces between words. 48 | However, each contraction is counted as only one word. For example, "I'm" is 49 | treated as one word, "Im". 50 | 51 | Args: 52 | text: Input text. 53 | 54 | Returns: 55 | Number of words. 56 | """ 57 | return len(_get_words(text)) 58 | 59 | 60 | @telemetry.WithTelemetry( 61 | api='ml_metrics', category='signal', reference='token_count' 62 | ) 63 | def token_count(text: str, tokenizer: Callable[[str], Sequence[Any]]) -> int: 64 | """Computes the number of tokens.""" 65 | return len(tokenizer(text)) 66 | 67 | 68 | @telemetry.WithTelemetry( 69 | api='ml_metrics', category='signal', reference='token_match_rate' 70 | ) 71 | def token_match_rate( 72 | sample: str, reference: str, tokenizer: Callable[[str], Sequence[Any]] 73 | ) -> float: 74 | """Computes the token match rate between sample and reference.""" 75 | sample_tokens = tokenizer(sample) 76 | reference_tokens = tokenizer(reference) 77 | matched = 0 78 | for t1, t2 in zip(sample_tokens, reference_tokens): 79 | if t1 == t2: 80 | matched += 1 81 | length = max(len(sample_tokens), len(reference_tokens)) 82 | if length == 0: 83 | return 0 84 | return matched / length 85 | 86 | 87 | @telemetry.WithTelemetry( 88 | api='ml_metrics', category='signal', reference='exact_match' 89 | ) 90 | def exact_match(sample: str, reference: str | Sequence[str]) -> bool: 91 | """Computes the exact match between sample and reference.""" 92 | references = _maybe_tuple(reference) 93 | return any(sample == ref for ref in references) 94 | 95 | 96 | @telemetry.WithTelemetry( 97 | api='ml_metrics', 98 | category='signal', 99 | reference='sample_startswith_reference_match', 100 | ) 101 | def sample_startswith_reference_match( 102 | sample: str, reference: str | Sequence[str] 103 | ) -> bool: 104 | """True when the sample starts with reference.""" 105 | references = _maybe_tuple(reference) 106 | return any(sample.startswith(ref) for ref in references) 107 | 108 | 109 | @telemetry.WithTelemetry( 110 | api='ml_metrics', 111 | category='signal', 112 | reference='reference_startswith_sample_match', 113 | ) 114 | def reference_startswith_sample_match( 115 | sample: str, reference: str | Sequence[str] 116 | ) -> bool: 117 | """True when the reference starts with sample.""" 118 | references = _maybe_tuple(reference) 119 | return any(ref.startswith(sample) for ref in references) 120 | 121 | 122 | @telemetry.WithTelemetry( 123 | api='ml_metrics', category='signal', 
reference='reference_in_sample_match' 124 | ) 125 | def reference_in_sample_match( 126 | sample: str, reference: str | Sequence[str] 127 | ) -> bool: 128 | """True when the reference in sample match.""" 129 | references = _maybe_tuple(reference) 130 | return any(ref in sample for ref in references) 131 | 132 | 133 | @telemetry.WithTelemetry( 134 | api='ml_metrics', category='signal', reference='sample_in_reference_match' 135 | ) 136 | def sample_in_reference_match( 137 | sample: str, reference: str | Sequence[str] 138 | ) -> bool: 139 | """True when the sample in reference match.""" 140 | references = _maybe_tuple(reference) 141 | return any(sample in ref for ref in references) 142 | 143 | 144 | @telemetry.WithTelemetry( 145 | api='ml_metrics', category='signal', reference='non_ascii_char_count' 146 | ) 147 | def non_ascii_char_count(text: str) -> int: 148 | """Computes the number of non-ascii characters.""" 149 | return len(re.sub(r'[^\x00-\x7F]+', '', text)) 150 | 151 | 152 | @telemetry.WithTelemetry( 153 | api='ml_metrics', category='signal', reference='is_all_whitespace' 154 | ) 155 | def is_all_whitespace(text: str) -> bool: 156 | r"""Checks if the text is all whitespace. 157 | 158 | Check if string is empty-ish e.g. consisting of whitespace, \n, \t. 159 | 160 | Args: 161 | text: Input text. 162 | 163 | Returns: 164 | True if the text is all whitespace. 165 | """ 166 | return not text.strip() 167 | 168 | 169 | @telemetry.WithTelemetry( 170 | api='ml_metrics', category='signal', reference='average_word_length' 171 | ) 172 | def average_word_length(text: str) -> float: 173 | """Computes the average word length.""" 174 | words = _get_words(text) 175 | if not words: 176 | return 0.0 177 | return sum(len(word) for word in words) / len(words) 178 | 179 | 180 | def _get_words(text: str) -> list[str]: 181 | """Returns the words in the text.""" 182 | return re.sub(r'[^a-zA-Z ]', '', text).split() 183 | 184 | 185 | @telemetry.WithTelemetry( 186 | api='ml_metrics', category='signal', reference='unique_word_count' 187 | ) 188 | def unique_word_count(text: str) -> int: 189 | """Computes the number of unique words.""" 190 | return len(set(_get_words(text))) 191 | 192 | 193 | @telemetry.WithTelemetry( 194 | api='ml_metrics', category='signal', reference='number_of_characters' 195 | ) 196 | def number_of_characters(text: str) -> int: 197 | """Computes the number of characters.""" 198 | return len(text) 199 | 200 | 201 | @telemetry.WithTelemetry( 202 | api='ml_metrics', category='signal', reference='percentage_all_caps' 203 | ) 204 | def percentage_all_caps(text: str) -> float: 205 | """Computes the percentage of all caps.""" 206 | words = _get_words(text) 207 | if not words: 208 | return 0 209 | return len([word for word in words if word.isupper()]) / len(words) 210 | 211 | 212 | @telemetry.WithTelemetry( 213 | api='ml_metrics', 214 | category='signal', 215 | reference='percentage_non_ascii_characters', 216 | ) 217 | def percentage_non_ascii_characters(text: str) -> float: 218 | """Computes the percentage of non-ascii characters.""" 219 | if not number_of_characters(text): 220 | return 0 221 | return 1 - (non_ascii_char_count(text) / number_of_characters(text)) 222 | 223 | 224 | @telemetry.WithTelemetry( 225 | api='ml_metrics', category='signal', reference='type_token_ratio' 226 | ) 227 | def type_token_ratio(text: str) -> float: 228 | """Computes the type token ratio. 229 | 230 | Words with the same letters but different lowercase letters are considered 231 | different. 
232 | 233 | Args: 234 | text: Input text. 235 | 236 | Returns: 237 | The ratio of unique words to total words. 238 | """ 239 | words = _get_words(text) 240 | if not words: 241 | return 0 242 | 243 | return unique_word_count(text) / len(words) 244 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/topk_accuracy.py: -------------------------------------------------------------------------------- 1 | """Topk accuracy metric.""" 2 | 3 | from ml_metrics._src.aggregates import types 4 | from ml_metrics._src.tools.telemetry import telemetry 5 | import numpy as np 6 | 7 | 8 | @telemetry.WithTelemetry( 9 | api='ml_metrics', category='signal', reference='topk_accuracy' 10 | ) 11 | def topk_accurate( 12 | y_pred: types.NumbersT, 13 | label: int, 14 | weights: types.NumbersT = 1.0, 15 | k: int = 1, 16 | ) -> bool: 17 | """Calculate topk accuracy. 18 | 19 | Args: 20 | y_pred: Prediction scores with shape [num_classes]. 21 | label: The ground truth label with values in [0, num_classes). 22 | weights: Weight applied to the prediction scores for computing the top-k 23 | accurate metric. Default is 1.0. 24 | k: The top-k predictions to consider. 25 | 26 | Returns: 27 | True if the label is in the top-k predictions, False otherwise. 28 | """ 29 | weighted_pred = np.asarray(y_pred) * np.asarray(weights) 30 | # Get indices of top-k predictions 31 | topk_predictions = np.argsort(weighted_pred)[-k:] 32 | return label in topk_predictions 33 | -------------------------------------------------------------------------------- /ml_metrics/_src/signals/topk_accuracy_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for topk_accuracy.py.""" 2 | 3 | from ml_metrics._src.signals import topk_accuracy 4 | from absl.testing import absltest 5 | from absl.testing import parameterized 6 | 7 | 8 | class TopkAccuracyTest(parameterized.TestCase): 9 | 10 | @parameterized.named_parameters( 11 | dict( 12 | testcase_name="top_k_accuracy_label1_k1", 13 | y_pred=[0.2, 0.7, 0.1], 14 | label=1, 15 | k=1, 16 | expected=1.0, 17 | ), 18 | dict( 19 | testcase_name="top_k_accuracy_label0_k1", 20 | y_pred=[0.2, 0.7, 0.1], 21 | label=0, 22 | k=1, 23 | expected=0.0, 24 | ), 25 | dict( 26 | testcase_name="top_k_accuracy_label0_k2", 27 | y_pred=[0.2, 0.7, 0.1], 28 | label=0, 29 | k=2, 30 | expected=1.0, 31 | ), 32 | ) 33 | def test_topk_accuracy(self, y_pred, label, k, expected): 34 | self.assertEqual(topk_accuracy.topk_accurate(y_pred, label, k=k), expected) 35 | 36 | @parameterized.named_parameters( 37 | dict( 38 | testcase_name="top_k_accuracy_weight_label1_k1_case0", 39 | y_pred=[0.2, 0.7, 0.1], 40 | label=1, 41 | k=1, 42 | weight=[1.0, 1.0/3.49, 1.0], 43 | expected=1.0, 44 | ), 45 | dict( 46 | testcase_name="top_k_accuracy_weight_label1_k1_case1", 47 | y_pred=[0.2, 0.7, 0.1], 48 | label=1, 49 | k=1, 50 | weight=[1.0, 1.0/3.51, 1.0], 51 | expected=0.0, 52 | ), 53 | dict( 54 | testcase_name="top_k_accuracy_weight_label0_k2", 55 | y_pred=[0.2, 0.7, 0.1], 56 | label=0, 57 | k=2, 58 | weight=[1.0, 1.0/3.51, 1.0], 59 | expected=1.0, 60 | ), 61 | ) 62 | def test_topk_accuracy_with_weights( 63 | self, y_pred, label, k, weight, expected 64 | ): 65 | self.assertEqual( 66 | topk_accuracy.topk_accurate(y_pred, label, weight, k), expected 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /ml_metrics/_src/tools/telemetry/telemetry.py: 
-------------------------------------------------------------------------------- 1 | """Telemetry for MLAT.""" 2 | 3 | from absl import logging 4 | 5 | 6 | def increment_counter( 7 | api: str, category: str, reference: str, execution_succeed: bool 8 | ): 9 | logging.info( 10 | 'Logging counter: api=%s, category=%s, reference=%s,' 11 | ' execution_succeed=%s', 12 | api, 13 | category, 14 | reference, 15 | execution_succeed, 16 | ) 17 | 18 | 19 | class WithTelemetry: 20 | """Decorator to log usage.""" 21 | 22 | def __init__( 23 | self, 24 | api: str, 25 | category: str, 26 | reference: str, 27 | *, 28 | target_methods: list[str] | str | None = None, 29 | ): 30 | self.api = api 31 | self.category = category 32 | self.reference = reference 33 | self.target_methods = target_methods 34 | 35 | def __call__(self, class_or_func_ref): 36 | increment_counter(self.api, self.category, self.reference, True) 37 | return class_or_func_ref 38 | -------------------------------------------------------------------------------- /ml_metrics/_src/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base types used throughout the library.""" 15 | 16 | import abc 17 | from typing import Any, Protocol, TypeGuard, TypeVar, runtime_checkable 18 | from numpy import typing as npt 19 | 20 | _T = TypeVar('_T') 21 | 22 | 23 | @runtime_checkable 24 | class SinkT(Protocol[_T]): 25 | """A sink that can be used to sink the data.""" 26 | 27 | def write(self, data: _T) -> None: 28 | """Writes the data to the sink.""" 29 | 30 | def close(self) -> None: 31 | """Closes the sink.""" 32 | 33 | 34 | @runtime_checkable 35 | class Makeable(Protocol[_T]): 36 | """A config class that can make a Metric class.""" 37 | 38 | @abc.abstractmethod 39 | def make(self) -> _T: 40 | """Makes a new Metric.""" 41 | 42 | 43 | @runtime_checkable 44 | class Resolvable(Protocol[_T]): 45 | """All Resolvables implement a `result_` to resolve the underlying value.""" 46 | 47 | @abc.abstractmethod 48 | def result_(self) -> _T: 49 | """Interface to get the result of the underlying value.""" 50 | 51 | 52 | class Shardable(Protocol): 53 | """A sharded data source for chainables.""" 54 | 55 | @abc.abstractmethod 56 | def shard(self, *args, **kwargs): 57 | """Iterates the data source given a shard index and number of shards.""" 58 | 59 | 60 | class Serializable(Protocol): 61 | """An object that can be both serialized and deserialized.""" 62 | 63 | @abc.abstractmethod 64 | def get_config(self): 65 | """Gets the state of the object that can be used to recover the object.""" 66 | 67 | @abc.abstractmethod 68 | def from_config(self, *args, **kwargs): 69 | """Recreates the object from the config returned by `get_config`.""" 70 | 71 | 72 | class Recoverable(Protocol): 73 | """An object that can be recovered from its stored state.""" 74 | 75 | state: Any 76 | 77 | @abc.abstractmethod 78 | def
from_state(self, *args, **kwargs): 79 | """Recover from the state.""" 80 | 81 | 82 | @runtime_checkable 83 | class Stoppable(Protocol): 84 | """An object that can be stopped.""" 85 | 86 | def maybe_stop(self): 87 | """Stops the concurrent threads if possible.""" 88 | 89 | 90 | MaybeResolvable = Resolvable[_T] | _T 91 | 92 | 93 | class RandomAccessible(Protocol[_T]): 94 | 95 | def __getitem__(self, idx: int | slice) -> _T: 96 | """Same as Sequence.__getitem__.""" 97 | 98 | def __len__(self) -> int: 99 | """Same as Sequence.__len__.""" 100 | 101 | 102 | def obj_has_method(obj: Any, method_name: str) -> bool: 103 | """Checks if the object has a method.""" 104 | method = getattr(obj, method_name, False) 105 | return method and getattr(method, '__self__', None) is obj 106 | 107 | 108 | def is_resolvable(obj: Resolvable[_T] | Any) -> TypeGuard[Resolvable[_T]]: 109 | """Checks if the object is a Resolvable.""" 110 | return obj_has_method(obj, 'result_') 111 | 112 | 113 | def is_makeable(obj: Makeable[_T] | Any) -> TypeGuard[Makeable[_T]]: 114 | """Checks if the object is a Makeable.""" 115 | return obj_has_method(obj, 'make') 116 | 117 | 118 | def is_shardable(obj: Shardable | Any) -> TypeGuard[Shardable]: 119 | """Checks if the object is a Shardable.""" 120 | return obj_has_method(obj, 'shard') 121 | 122 | 123 | def is_serializable(obj: Serializable | Any) -> TypeGuard[Serializable]: 124 | """Checks if the object is a Serializable.""" 125 | return obj_has_method(obj, 'get_config') and obj_has_method( 126 | obj, 'from_config' 127 | ) 128 | 129 | 130 | def is_recoverable(obj: Recoverable | Any) -> TypeGuard[Recoverable]: 131 | """Checks if the object is a Recoverable.""" 132 | return obj_has_method(obj, 'from_state') and hasattr(obj, 'state') 133 | 134 | 135 | def is_array_like(obj: list[Any] | tuple[Any, ...] | npt.ArrayLike) -> bool: 136 | """Checks if the object is an array-like object.""" 137 | return isinstance(obj, (list, tuple)) or ( 138 | hasattr(obj, '__array__') and obj.ndim > 0 139 | ) 140 | -------------------------------------------------------------------------------- /ml_metrics/_src/types_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
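A short sketch of how the duck-typed checks above differ from the `runtime_checkable` protocols: `obj_has_method` insists on a method bound to the instance, so a classmethod with the right name passes `isinstance` but not `is_resolvable`. The class names here are illustrative only.

from ml_metrics._src import types


class EagerValue:

  def __init__(self, value):
    self._value = value

  def result_(self):
    return self._value


class LazyStub:

  @classmethod
  def result_(cls):
    return None


assert types.is_resolvable(EagerValue(3))        # Bound instance method.
assert isinstance(LazyStub(), types.Resolvable)  # The protocol only checks the attribute name.
assert not types.is_resolvable(LazyStub())       # A classmethod is bound to the class, not the instance.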
14 | from absl.testing import absltest 15 | from ml_metrics._src import types 16 | 17 | 18 | class BaseTypesTest(absltest.TestCase): 19 | 20 | def test_is_resolvable(self): 21 | 22 | class Foo: 23 | 24 | @classmethod 25 | def result_(cls): 26 | pass 27 | 28 | self.assertIsInstance(Foo(), types.Resolvable) 29 | self.assertFalse(types.is_resolvable(Foo())) 30 | 31 | def test_is_makeable(self): 32 | 33 | class Foo: 34 | 35 | @classmethod 36 | def make(cls): 37 | pass 38 | 39 | self.assertIsInstance(Foo(), types.Makeable) 40 | self.assertFalse(types.is_makeable(Foo())) 41 | 42 | def test_is_shardable(self): 43 | 44 | class Foo: 45 | 46 | @classmethod 47 | def shard(cls, shard_index): 48 | pass 49 | 50 | self.assertFalse(types.is_shardable(Foo())) 51 | 52 | def test_is_serializable(self): 53 | 54 | class Foo: 55 | 56 | def get_config(self): 57 | pass 58 | 59 | def from_config(self, config): 60 | pass 61 | 62 | self.assertTrue(types.is_serializable(Foo())) 63 | 64 | def test_is_recoverable(self): 65 | 66 | class Foo: 67 | 68 | @property 69 | def state(self): 70 | pass 71 | 72 | def from_state(self, state): 73 | pass 74 | 75 | self.assertTrue(types.is_recoverable(Foo())) 76 | 77 | def test_is_not_recoverable(self): 78 | 79 | class Foo1: 80 | 81 | @property 82 | def state(self): 83 | pass 84 | 85 | self.assertFalse(types.is_recoverable(Foo1())) 86 | 87 | class Foo2: 88 | 89 | def from_state(self, state): 90 | pass 91 | 92 | self.assertFalse(types.is_recoverable(Foo2())) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/df_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for generating DataFrames.""" 15 | 16 | from __future__ import annotations 17 | 18 | import collections 19 | from typing import Any 20 | 21 | from ml_metrics._src.chainables import transform 22 | import pandas as pd 23 | 24 | 25 | _METRIC_NAME = 'metric_name' 26 | _SLICE = 'slice' 27 | _VALUE = 'value' 28 | 29 | 30 | def _first_or_tuple(x: tuple[Any, ...]) -> tuple[Any, ...] | Any: 31 | if isinstance(x, tuple) and len(x) == 1: 32 | return x[0] 33 | return x 34 | 35 | 36 | _StrOrMetricKey = transform.MetricKey | str 37 | 38 | 39 | def metrics_to_df(metrics: dict[_StrOrMetricKey, Any]) -> pd.DataFrame: 40 | """Converts the aggregation result to a DataFrame. 41 | 42 | This always converts the dict aggregation result to a DataFrame with 43 | the following columns: 44 | 45 | - metric_name: the name of the metric. 46 | - slice: the slice of the metric, if a slice is not specified, it will be 47 | 'overall'. 48 | - value: the value of the metric. 49 | 50 | Args: 51 | metrics: the aggregation result. 52 | 53 | Returns: 54 | A DataFrame with the above columns. 
55 | """ 56 | sliced_results = collections.defaultdict(list) 57 | for k, v in metrics.items(): 58 | if isinstance(k, str): 59 | sliced_results[_METRIC_NAME].append(k) 60 | sliced_results[_SLICE].append('overall') 61 | sliced_results[_VALUE].append(v) 62 | elif isinstance(k, transform.MetricKey): 63 | sliced_results[_METRIC_NAME].append(k.metrics) 64 | slice_name = _first_or_tuple(k.slice.features) 65 | slice_value = _first_or_tuple(k.slice.values) 66 | sliced_results[_SLICE].append(f'{slice_name}={slice_value}') 67 | sliced_results[_VALUE].append(v) 68 | return pd.DataFrame(sliced_results) 69 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/df_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from ml_metrics._src.chainables import transform 17 | from ml_metrics._src.chainables import tree_fns 18 | from ml_metrics._src.utils import df_utils 19 | import pandas as pd 20 | 21 | 22 | class DfUtilsTest(absltest.TestCase): 23 | 24 | def test_as_dataframe_default(self): 25 | key = transform.MetricKey( 26 | metrics='m2', slice=tree_fns.SliceKey(('f1',), ('a',)) 27 | ) 28 | agg_result = {'m1': 1, key: 2} 29 | df = df_utils.metrics_to_df(agg_result) 30 | pd.testing.assert_frame_equal( 31 | df, 32 | pd.DataFrame({ 33 | 'metric_name': ['m1', 'm2'], 34 | 'slice': ['overall', 'f1=a'], 35 | 'value': [1, 2], 36 | }), 37 | ) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/func_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Internal function utils, not meant to be used by users.""" 15 | 16 | import collections 17 | from collections.abc import Iterable, Iterator, Mapping 18 | from concurrent import futures 19 | import copy 20 | import dataclasses as dc 21 | import functools 22 | import itertools as itt 23 | from typing import TypeVar 24 | import weakref 25 | 26 | from absl import logging 27 | import more_itertools as mit 28 | 29 | 30 | _KeyT = TypeVar('_KeyT') 31 | _ValueT = TypeVar('_ValueT') 32 | _T = TypeVar('_T') 33 | 34 | 35 | @dc.dataclass(slots=True, frozen=True) 36 | class _CacheInfo: 37 | hits: int 38 | misses: int 39 | currsize: int 40 | maxsize: int = 0 41 | 42 | 43 | class LruCache(Mapping[_KeyT, _ValueT]): 44 | """A mapping like object for caching with limited size.""" 45 | 46 | def __init__(self, maxsize=128): 47 | self.maxsize = maxsize 48 | self.currsize = 0 49 | self.hits = 0 50 | self.misses = 0 51 | self.data = collections.OrderedDict() 52 | 53 | def __getitem__(self, key): 54 | if key not in self.data: 55 | self.misses += 1 56 | raise KeyError() 57 | self.hits += 1 58 | value = self.data[key] 59 | self.data.move_to_end(key) 60 | return value 61 | 62 | def __setitem__(self, key, value): 63 | key_is_new = key not in self.data 64 | self.data[key] = value 65 | if key_is_new: 66 | self.currsize += 1 67 | self.data.move_to_end(key) 68 | if self.currsize > self.maxsize: 69 | oldest = next(iter(self.data)) 70 | del self.data[oldest] 71 | self.currsize -= 1 72 | 73 | def cache_insert(self, key, value): 74 | self.__setitem__(key, value) 75 | 76 | def __contains__(self, key): 77 | return key in self.data 78 | 79 | def __iter__(self) -> Iterator[_KeyT]: 80 | return iter(self.data) 81 | 82 | def __len__(self) -> int: 83 | return self.currsize 84 | 85 | def cache_clear(self): 86 | self.data.clear() 87 | self.currsize = 0 88 | self.hits = 0 89 | self.misses = 0 90 | 91 | def cache_info(self) -> _CacheInfo: 92 | return _CacheInfo( 93 | hits=self.hits, 94 | misses=self.misses, 95 | maxsize=self.maxsize, 96 | currsize=self.currsize, 97 | ) 98 | 99 | 100 | def lru_cache( 101 | fn=None, 102 | *, 103 | settable_kwargs: Iterable[str] = (), 104 | maxsize: int = 128, 105 | ): 106 | """Cache by the positional and specified keyword arguments.""" 107 | 108 | settable_kwargs = set(settable_kwargs) 109 | 110 | def decorator(fn): 111 | cache_ = LruCache(maxsize=maxsize) 112 | 113 | @functools.wraps(fn) 114 | def wrapped(*args, cache_insert_: bool = False, **kwargs): 115 | is_settable = lambda x: x[0] in settable_kwargs 116 | hashed, settables = mit.partition(is_settable, kwargs.items()) 117 | key = hash(tuple(itt.chain(args, hashed))) 118 | if not cache_insert_ and key in cache_: 119 | result = cache_[key] 120 | else: 121 | logging.debug('chainable: %s', f'cache miss {fn}: {args}, {kwargs}') 122 | result = fn(*args, **kwargs) 123 | cache_[key] = result 124 | result_new = None 125 | for k, v in settables: 126 | if v != getattr(result, k): 127 | if result_new is None: 128 | result_new = copy.copy(result) 129 | setattr(result_new, k, v) 130 | return result_new if result_new is not None else result 131 | 132 | wrapped.cache_info = cache_.cache_info 133 | wrapped.cache_clear = cache_.cache_clear 134 | return wrapped 135 | 136 | return decorator if fn is None else decorator(fn) 137 | 138 | 139 | class SingletonMeta(type): 140 | """A metaclass that makes a class a singleton of any "equivalent" instance. 141 | 142 | The actual class instance has to be hashable to test equivalence. 
This also 143 | works with inherited classes. 144 | Example: 145 | ``` 146 | @dc.dataclass(frozen=True) 147 | class Foo(metaclass=SingletonMeta): 148 | a: int 149 | b: str = 'b' 150 | 151 | # The following should be true. 152 | assert Foo(1) is Foo(1) 153 | assert Foo(1, 'b') is Foo(1) 154 | assert Foo(1, b='b') is not Foo(1) 155 | assert Foo(2) is not Foo(1) 156 | assert Foo(1, 'a') is not Foo(1) 157 | ``` 158 | """ 159 | 160 | _instances = weakref.WeakKeyDictionary() 161 | 162 | def __call__(cls: type[_T], *args, **kwargs) -> _T: 163 | obj = super(SingletonMeta, cls).__call__(*args, **kwargs) 164 | if (ref := cls._instances.get(obj, None)) and (result := ref()) is not None: 165 | return result 166 | logging.info('chainable: %s', f'singleton {cls.__name__}, {obj}') 167 | cls._instances[obj] = weakref.ref(obj) 168 | return obj 169 | 170 | @property 171 | def all_instances(cls): 172 | return [obj for ref in cls._instances.values() if (obj := ref())] 173 | 174 | 175 | class SingletonThreadPool(futures.ThreadPoolExecutor, metaclass=SingletonMeta): 176 | """A singleton thread pool executor.""" 177 | 178 | def __hash__(self): 179 | return hash((self._max_workers, self._thread_name_prefix)) 180 | 181 | def __eq__(self, other): 182 | return ( 183 | isinstance(other, SingletonThreadPool) 184 | and self._max_workers == other._max_workers 185 | and self._thread_name_prefix == other._thread_name_prefix 186 | ) 187 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/func_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
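A small sketch of the `LruCache` behaviour above (eviction order and hit counting); the keys are arbitrary.

from ml_metrics._src.utils import func_utils

cache = func_utils.LruCache(maxsize=2)
cache['a'] = 1
cache['b'] = 2
_ = cache['a']   # Reading 'a' makes it the most recently used entry.
cache['c'] = 3   # Exceeds maxsize, so the least recently used entry 'b' is evicted.
assert list(cache) == ['a', 'c']
assert cache.cache_info().hits == 1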
14 | import dataclasses as dc 15 | from absl.testing import absltest 16 | from ml_metrics._src.utils import func_utils 17 | 18 | 19 | @dc.dataclass() 20 | class Foo: 21 | a: int 22 | _b: int = 0 23 | c: int = 0 24 | 25 | @property 26 | def b(self): 27 | return self._b 28 | 29 | @b.setter 30 | def b(self, value): 31 | self._b = value 32 | 33 | def __eq__(self, other): 34 | return self.a == other.a and self.b == other.b and self.c == other.c 35 | 36 | 37 | class CacheByKwargsTest(absltest.TestCase): 38 | 39 | def test_without_kwargs(self): 40 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo) 41 | self.assertEqual(Foo(1, 0, 0), foo_cached(1)) 42 | self.assertEqual(Foo(1, 100, 100), foo_cached(1, b=100, c=100)) 43 | self.assertEqual(foo_cached.cache_info().hits, 1) 44 | 45 | def test_ignore_kwargs(self): 46 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', '_b', 'c'])(Foo) 47 | self.assertEqual(Foo(1, 10, 0), foo_cached(1, _b=10)) 48 | self.assertEqual(Foo(1, 100, 100), foo_cached(1, b=100, c=100)) 49 | self.assertEqual(Foo(1, 10, 0), foo_cached(1, b=10)) 50 | 51 | def test_cache_partial_kwargs(self): 52 | foo_cached = func_utils.lru_cache(settable_kwargs=['b'])(Foo) 53 | self.assertEqual(Foo(1, 0, 10), foo_cached(1, c=10)) 54 | self.assertEqual(Foo(1, 100, 10), foo_cached(1, b=100, c=10)) 55 | self.assertEqual(Foo(1, 10, 10), foo_cached(1, b=10, c=10)) 56 | 57 | def test_cache_info(self): 58 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo) 59 | foo_cached(1, c=10) 60 | foo_cached(1, b=100, c=100) 61 | foo_cached(1, b=10, c=100) 62 | self.assertEqual(foo_cached.cache_info().hits, 2) 63 | self.assertEqual(foo_cached.cache_info().misses, 0) 64 | self.assertEqual(foo_cached.cache_info().currsize, 1) 65 | 66 | def test_cache_insert(self): 67 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo) 68 | foo_cached(1) 69 | self.assertEqual(Foo(1, 0, 1), foo_cached(1, c=1)) 70 | foo_cached(1, c=100, cache_insert_=True) 71 | self.assertEqual(Foo(1, 0, 100), foo_cached(1)) 72 | 73 | def test_attribute_error_raises(self): 74 | 75 | def foo(a, b=1): 76 | return (a, b) 77 | 78 | foo_cached = func_utils.lru_cache(settable_kwargs=['b'])(foo) 79 | with self.assertRaises(AttributeError): 80 | # b is not an attr in the result of foo, thus, cannot uses this kind of 81 | # caching mechanism by setting the uncached attr afterwards. 82 | foo_cached(1, b=10) 83 | 84 | def test_cache_clear(self): 85 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo) 86 | foo_cached(1) 87 | foo_cached(1, b=10, c=100) 88 | foo_cached.cache_clear() 89 | self.assertEqual(foo_cached.cache_info().hits, 0) 90 | self.assertEqual(foo_cached.cache_info().misses, 0) 91 | self.assertEqual(foo_cached.cache_info().currsize, 0) 92 | 93 | 94 | class LruCacheTest(absltest.TestCase): 95 | 96 | def test_insert_and_get_and_clear(self): 97 | cache = func_utils.LruCache(maxsize=128) 98 | cache['a'] = 1 99 | self.assertEqual(cache['a'], 1) 100 | self.assertLen(cache, 1) 101 | self.assertEqual(1, cache.cache_info().currsize) 102 | self.assertEqual(1, cache.cache_info().hits) 103 | self.assertEqual(128, cache.cache_info().maxsize) 104 | cache.cache_clear() 105 | self.assertEmpty(cache) 106 | 107 | def test_cache_evict(self): 108 | cache = func_utils.LruCache(maxsize=2) 109 | cache['a'] = 1 110 | cache['b'] = 2 111 | # This pushes 'a' to the latest. 112 | self.assertEqual(1, cache['a']) 113 | # This causes 'b' to be evicted. 
114 | cache['c'] = 3 115 | self.assertEqual(['a', 'c'], list(cache)) 116 | 117 | def test_missing_key(self): 118 | cache = func_utils.LruCache(maxsize=128) 119 | with self.assertRaises(KeyError): 120 | _ = cache['a'] 121 | 122 | def test_cache_insert(self): 123 | cache = func_utils.LruCache(maxsize=128) 124 | cache.cache_insert('a', 1) 125 | cache.cache_insert('a', 2) 126 | self.assertEqual(cache['a'], 2) 127 | self.assertLen(cache, 1) 128 | 129 | def test_cache_iter(self): 130 | cache = func_utils.LruCache(maxsize=128) 131 | cache.cache_insert('a', 1) 132 | cache.cache_insert('b', 2) 133 | cache.cache_insert('a', 3) 134 | # The order of the reference here will push 'a' to later than b 135 | self.assertEqual(cache['b'], 2) 136 | self.assertEqual(cache['a'], 3) 137 | self.assertEqual(['b', 'a'], list(cache)) 138 | 139 | 140 | @dc.dataclass(frozen=True, eq=True) 141 | class SingletonA(metaclass=func_utils.SingletonMeta): 142 | a: int 143 | 144 | def __eq__(self, other): 145 | return isinstance(other, SingletonA) and self.a == other.a 146 | 147 | 148 | @dc.dataclass(frozen=True, eq=True) 149 | class SingletonB(SingletonA): 150 | # Essentially the same class that will be hashed the same. 151 | b: str = 'b' 152 | 153 | def __eq__(self, other): 154 | return ( 155 | super().__eq__(other) 156 | and isinstance(other, SingletonB) 157 | and self.b == other.b 158 | ) 159 | 160 | 161 | class SingletonMetaTest(absltest.TestCase): 162 | 163 | def test_singleton(self): 164 | self.assertIs(SingletonA(1), SingletonA(1)) 165 | self.assertIsNot(SingletonA(2), SingletonA(1)) 166 | self.assertIsNot(SingletonA(1), SingletonB(1)) 167 | self.assertIs(SingletonB(1, 'b'), SingletonB(1)) 168 | self.assertIs(SingletonB(1, b='b'), SingletonB(1)) 169 | self.assertIsNot(SingletonB(1, b='b1'), SingletonB(1)) 170 | 171 | def test_singleton_all_instances(self): 172 | a = SingletonA(1) 173 | b = SingletonB(1) 174 | self.assertSameElements(SingletonB.all_instances, SingletonA.all_instances) 175 | instances = SingletonB.all_instances 176 | self.assertLen(instances, 2) 177 | self.assertIn(a, instances) 178 | self.assertIn(b, instances) 179 | 180 | 181 | if __name__ == '__main__': 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Math Utils.""" 15 | 16 | from ml_metrics._src import types 17 | from ml_metrics._src.aggregates import types as agg_types 18 | import numpy as np 19 | 20 | 21 | def pos_sqrt(value) -> agg_types.NumbersT: 22 | """Returns sqrt of value or raises ValueError if negative.""" 23 | if np.any(value < 0): 24 | raise ValueError('Attempt to take sqrt of negative value: {}'.format(value)) 25 | return np.sqrt(value) 26 | 27 | 28 | def safe_divide(a, b) -> agg_types.NumbersT: 29 | """Divide arguments element-wise (a / b), but returns zero(s) if b is 0.""" 30 | result = np.divide( 31 | a, b, out=np.zeros_like(a, dtype=agg_types.DefaultDType), where=(b != 0) 32 | ) 33 | 34 | return result.item() if result.ndim == 0 else result 35 | 36 | 37 | def safe_to_scalar(arr: agg_types.NumbersT) -> agg_types.NumbersT: 38 | """Returns tuple, list, or np.ndarray as a scalar. Returns 0.0 if empty. 39 | 40 | Originally from tensorflow_model_analysis/metrics/metric_util.py 41 | 42 | Args: 43 | arr: A one element tuple, list, or numpy.ndarray to be converted to a 44 | scalar. 45 | 46 | Returns: 47 | The Python scalar. 48 | """ 49 | if isinstance(arr, np.ndarray): 50 | if arr.size == 0: 51 | # 0 elements. 52 | return 0.0 53 | if arr.size == 1: 54 | # 1 element. 55 | return arr.item() 56 | else: 57 | # arr is tuple or list. 58 | if not arr: 59 | # 0 elements. 60 | return 0.0 61 | if len(arr) == 1: 62 | # 1 element. 63 | return arr[0] 64 | 65 | # >1 element. 66 | raise ValueError('Array should have exactly 1 value to a Python scalar') 67 | 68 | 69 | def where( 70 | condition: agg_types.NumbersT, x: agg_types.NumbersT, y: agg_types.NumbersT 71 | ): 72 | """Like np.where, but directly returns itself if condition is scalar.""" 73 | if types.is_array_like(condition): 74 | return np.where(condition, x, y) 75 | # If x or y is scalar, return it directly to keep its type (vs. np.where 76 | # that wraps it in a one-element array). 77 | return x if condition else y 78 | 79 | 80 | def nanadd(a: agg_types.NumbersT, b: agg_types.NumbersT) -> agg_types.NumbersT: 81 | """Returns element-wise a + b, but ignores NaN as 0 unless both are NaN.""" 82 | a_nan, b_nan = np.isnan(a), np.isnan(b) 83 | result = where(a_nan, 0, a) + where(b_nan, 0, b) 84 | # Revovers the NaN values if both operands are NaNs. 85 | return where(a_nan & b_nan, np.nan, result) 86 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/math_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Math Utils.""" 15 | 16 | from collections.abc import Sequence 17 | 18 | from ml_metrics._src.utils import math_utils 19 | import numpy as np 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | 25 | class MathUtilsTest(parameterized.TestCase): 26 | 27 | def test_pos_sqrt(self): 28 | self.assertEqual(math_utils.pos_sqrt(4), 2.0) 29 | 30 | def test_pos_sqrt_raises_value_error(self): 31 | with self.assertRaisesRegex( 32 | ValueError, 'Attempt to take sqrt of negative value: -1' 33 | ): 34 | math_utils.pos_sqrt(-1) 35 | 36 | @parameterized.named_parameters( 37 | ('both_zero', 0, 0, 0.0), 38 | ('zero_num', 0, 10, 0.0), 39 | ('zero_denom', 10, 0, 0.0), 40 | ('float_num', 10.5, 3, 3.5), 41 | ('float_denom', 14, 3.5, 4.0), 42 | ('array_num', [2, 4, 6, 8], 2, [1.0, 2.0, 3.0, 4.0]), 43 | ('array_num_denom', [2, 4, 6, 8], [4, 8, 12, 16], [0.5, 0.5, 0.5, 0.5]), 44 | ) 45 | def test_safe_divide(self, a, b, expected_result): 46 | result = math_utils.safe_divide(a, b) 47 | if isinstance(result, (Sequence, np.ndarray)): 48 | self.assertSequenceAlmostEqual(result, expected_result) 49 | else: 50 | self.assertAlmostEqual(result, expected_result) 51 | 52 | # Original Tests safe_to_scalar tests from: 53 | # tensorflow_model_analysis/metrics/metric_util_test.py 54 | @parameterized.named_parameters( 55 | dict(testcase_name='tuple_int', arr=(1,), expected_scalar=1), 56 | dict(testcase_name='tuple_float', arr=(1.0,), expected_scalar=1.0), 57 | dict( 58 | testcase_name='tuple_string', 59 | arr=('string',), 60 | expected_scalar='string', 61 | ), 62 | dict(testcase_name='list_int', arr=[1], expected_scalar=1), 63 | dict(testcase_name='list_float', arr=[1.0], expected_scalar=1.0), 64 | dict( 65 | testcase_name='list_string', arr=['string'], expected_scalar='string' 66 | ), 67 | dict(testcase_name='np_array_int', arr=np.array(1), expected_scalar=1), 68 | dict( 69 | testcase_name='np_array_float', arr=np.array(1.0), expected_scalar=1.0 70 | ), 71 | dict( 72 | testcase_name='np_array_string', 73 | arr=np.array('string'), 74 | expected_scalar='string', 75 | ), 76 | ) 77 | def test_safe_to_scalar_unpack_1d_iterable(self, arr, expected_scalar): 78 | self.assertEqual(math_utils.safe_to_scalar(arr), expected_scalar) 79 | 80 | @parameterized.named_parameters( 81 | dict(testcase_name='empty_tuple', arr=()), 82 | dict(testcase_name='empty_list', arr=[]), 83 | dict(testcase_name='empty_np_array', arr=np.array([])), 84 | ) 85 | def test_safe_to_scalar_empty_iterable(self, arr): 86 | self.assertEqual(math_utils.safe_to_scalar(arr), 0.0) 87 | 88 | @parameterized.named_parameters( 89 | dict(testcase_name='2_elem_tuple', arr=(1, 2)), 90 | dict(testcase_name='2_elem_list', arr=[1, 2]), 91 | dict(testcase_name='2_elem_np_array', arr=np.array((1, 2))), 92 | ) 93 | def test_safe_to_scalar_multi_elem_iterable_raises_value_error(self, arr): 94 | with self.assertRaisesRegex( 95 | ValueError, 'Array should have exactly 1 value to a Python scalar' 96 | ): 97 | math_utils.safe_to_scalar(arr) 98 | 99 | @parameterized.named_parameters( 100 | dict( 101 | testcase_name='partial_nans', 102 | a=[1, 0], 103 | b=[1, np.nan], 104 | expected=[2, 0], 105 | ), 106 | dict( 107 | testcase_name='both_nans', 108 | a=[1, np.nan], 109 | b=[1, np.nan], 110 | expected=[2, np.nan], 111 | ), 112 | dict( 113 | testcase_name='constant_broadcast', 114 | a=[np.nan, 0], 115 | b=3, 116 | expected=[3, 3], 117 | ), 118 | dict( 119 | testcase_name='nan_broadcast', 120 | a=[1, 0], 121 | b=np.nan, 122 | 
expected=[1, 0], 123 | ), 124 | dict( 125 | testcase_name='scalar', 126 | a=1, 127 | b=np.nan, 128 | expected=1, 129 | ), 130 | dict( 131 | testcase_name='scalar_nans', 132 | a=np.nan, 133 | b=np.nan, 134 | expected=np.nan, 135 | ), 136 | ) 137 | def test_nanadd(self, a, b, expected): 138 | actual = math_utils.nanadd(a, b) 139 | np.testing.assert_allclose(actual, expected) 140 | 141 | def test_where(self): 142 | self.assertEqual(1, math_utils.where(True, 1, 0)) 143 | self.assertEqual([0, 1], math_utils.where(False, 1, [0, 1])) 144 | actual = math_utils.where([True, False], [1, 1], [0, 0]) 145 | np.testing.assert_array_equal([1, 0], actual) 146 | 147 | 148 | if __name__ == '__main__': 149 | absltest.main() 150 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/proto_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Proto utils.""" 15 | from collections.abc import Iterable 16 | from typing import Any 17 | from absl import logging 18 | import more_itertools as mit 19 | import numpy as np 20 | from tensorflow.core.example import example_pb2 21 | 22 | _ExampleOrBytes = bytes | example_pb2.Example 23 | 24 | 25 | def _maybe_deserialize(ex: _ExampleOrBytes) -> example_pb2.Example: 26 | if isinstance(ex, bytes): 27 | return example_pb2.Example.FromString(ex) 28 | if isinstance(ex, example_pb2.Example): 29 | return ex 30 | raise TypeError('Unsupported type: %s' % type(ex)) 31 | 32 | 33 | def tf_examples_to_dict(examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes): 34 | """Parses a serialized tf.train.Example to a dict.""" 35 | single_example = False 36 | if isinstance(examples, (bytes, example_pb2.Example)): 37 | single_example = True 38 | examples = [examples] 39 | examples = (_maybe_deserialize(ex) for ex in examples) 40 | examples = mit.peekable(examples) 41 | if (head := examples.peek(None)) is None: 42 | return {} 43 | 44 | result = {k: [] for k in head.features.feature} 45 | for ex in examples: 46 | missing = set(result) 47 | for key, feature in ex.features.feature.items(): 48 | missing.remove(key) 49 | value = getattr(feature, feature.WhichOneof('kind')).value 50 | if value and isinstance(value[0], bytes): 51 | try: 52 | value = [v.decode() for v in value] 53 | except UnicodeDecodeError: 54 | logging.info( 55 | 'chainable: %s', 56 | f'Failed to decode for {key}, forward the raw bytes.', 57 | ) 58 | result[key].extend(value) 59 | if missing: 60 | raise ValueError( 61 | f'Missing keys: {missing}, expecting {set(result)}, got {ex=}' 62 | ) 63 | result = {k: v for k, v in result.items()} 64 | # Scalar value in a single example will be returned with the scalar directly. 
65 | if single_example and all(len(v) == 1 for v in result.values()): 66 | result = {k: v[0] for k, v in result.items()} 67 | return result 68 | 69 | 70 | def dict_to_tf_example(data: dict[str, Any]) -> example_pb2.Example: 71 | """Creates a tf.Example from a dictionary.""" 72 | example = example_pb2.Example() 73 | for key, value in data.items(): 74 | if isinstance(value, (str, bytes, np.floating, float, int, np.integer)): 75 | value = [value] 76 | feature = example.features.feature 77 | if isinstance(value[0], str): 78 | for v in value: 79 | assert isinstance(v, str), f'bad str type: {value}' 80 | feature[key].bytes_list.value.append(v.encode()) 81 | elif isinstance(value[0], bytes): 82 | feature[key].bytes_list.value.extend(value) 83 | elif isinstance(value[0], (int, np.integer)): 84 | feature[key].int64_list.value.extend(value) 85 | elif isinstance(value[0], (float, np.floating)): 86 | feature[key].float_list.value.extend(value) 87 | else: 88 | raise TypeError(f'Value for "{key}" is not a supported type.') 89 | return example 90 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/proto_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
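
A minimal round-trip sketch of the two proto helpers above (the feature names and values are illustrative; floats are stored as float32, so the restored score is only approximately equal, and the TensorFlow example proto dependency must be available):

```python
from ml_metrics._src.utils import proto_utils

example = proto_utils.dict_to_tf_example({'label': 'cat', 'score': 0.9})
serialized = example.SerializeToString()

# tf_examples_to_dict accepts a single Example (or its serialized bytes) or an
# iterable of them; a single example with one value per feature is unpacked
# back to scalars.
restored = proto_utils.tf_examples_to_dict(serialized)
# restored is approximately {'label': 'cat', 'score': 0.9}
```
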
14 | from ml_metrics._src.utils import proto_utils 15 | from ml_metrics._src.utils import test_utils 16 | import numpy as np 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tensorflow.core.example import example_pb2 20 | 21 | 22 | def _get_tf_example(**kwargs): 23 | example = example_pb2.Example() 24 | for k, v in kwargs.items(): 25 | example.features.feature[k].bytes_list.value.append(v) 26 | return example 27 | 28 | 29 | class TFExampleTest(parameterized.TestCase): 30 | 31 | def test_single_example(self): 32 | data = { 33 | 'bytes_key': b'\x80abc', # not utf-8 decodable 34 | 'str_key': 'str_test', 35 | 'init_key': 123, 36 | 'np_int': np.int32(123), 37 | 'float_key': 4.56, 38 | 'np_float': np.float32(123), 39 | } 40 | e = proto_utils.dict_to_tf_example(data).SerializeToString() 41 | actual = proto_utils.tf_examples_to_dict(e) 42 | self.assertDictAlmostEqual(data, actual, places=6) 43 | 44 | def test_batch_example(self): 45 | data = { 46 | 'bytes_key': [b'\x80abc', b'\x80def'], # not utf-8 decodable 47 | 'str_key': ['str_test', 'str_test2'], 48 | 'init_key': [123, 456], 49 | 'np_int': [np.int32(123), np.int32(456)], 50 | 'float_key': [4.56, 7.89], 51 | 'np_float': [np.float32(123), np.float32(456)], 52 | } 53 | e = proto_utils.dict_to_tf_example(data) 54 | actual = proto_utils.tf_examples_to_dict(e) 55 | test_utils.assert_nested_container_equal(self, data, actual, places=6) 56 | 57 | @parameterized.named_parameters( 58 | dict( 59 | testcase_name='with_single_example', 60 | num_elems=1, 61 | ), 62 | dict( 63 | testcase_name='multiple_examples', 64 | num_elems=3, 65 | ), 66 | ) 67 | def test_multiple_examples_as_batch(self, num_elems): 68 | data = { 69 | 'bytes_key': b'\x80abc', # not utf-8 decodable 70 | 'str_key': 'str_test', 71 | 'init_key': 123, 72 | 'np_int': np.int32(123), 73 | 'float_key': 4.56, 74 | 'np_float': np.float32(123), 75 | } 76 | e = [proto_utils.dict_to_tf_example(data) for _ in range(num_elems)] 77 | actual = proto_utils.tf_examples_to_dict(e) 78 | expected = {k: [v] * num_elems for k, v in data.items()} 79 | test_utils.assert_nested_container_equal(self, expected, actual, places=6) 80 | 81 | def test_empty_example(self): 82 | self.assertEmpty(proto_utils.tf_examples_to_dict([])) 83 | 84 | def test_unsupported_type(self): 85 | with self.assertRaisesRegex(TypeError, 'Unsupported type'): 86 | proto_utils.tf_examples_to_dict('unsupported_type') 87 | 88 | def test_unsupported_value_type(self): 89 | with self.assertRaisesRegex( 90 | TypeError, 'Value for "a" is not a supported type' 91 | ): 92 | proto_utils.dict_to_tf_example({'a': [example_pb2.Example()]}) 93 | 94 | def test_multiple_examples_missing_key(self): 95 | data = [{'a': 'a', 'b': 1}, {'b': 2}] 96 | examples = [proto_utils.dict_to_tf_example(d) for d in data] 97 | with self.assertRaisesRegex(ValueError, 'Missing keys'): 98 | _ = proto_utils.tf_examples_to_dict(examples) 99 | 100 | 101 | if __name__ == '__main__': 102 | absltest.main() 103 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for testing, internal use only.""" 15 | 16 | from collections.abc import Iterable 17 | import time 18 | import unittest 19 | 20 | from ml_metrics._src.aggregates import rolling_stats 21 | from ml_metrics._src.chainables import transform 22 | import numpy as np 23 | 24 | 25 | class TestSink: 26 | 27 | def __init__(self): 28 | self.data = [] 29 | self.closed = False 30 | 31 | def write(self, data): 32 | self.data.append(data) 33 | return data 34 | 35 | def close(self): 36 | self.closed = True 37 | 38 | 39 | class Unpickleable: 40 | def __call__(self, x): 41 | return x + 1 42 | 43 | def __getstate__(self): 44 | raise TypeError('Unpickleable') 45 | 46 | 47 | def range_with_exc(x, exc_i): 48 | for i in range(x): 49 | if i == exc_i: 50 | raise ValueError(f'range_with_exc at {i}') 51 | yield i 52 | 53 | 54 | def range_with_sleep(n, sleep: float = 0): 55 | for i in range(n): 56 | if sleep: 57 | time.sleep(sleep) 58 | yield i 59 | 60 | 61 | def range_with_return(n, return_value: int = 0): 62 | return_value = return_value or n 63 | for i in range(n): 64 | yield i 65 | return return_value 66 | 67 | 68 | def inf_range(n): 69 | yield from range(n) 70 | raise ValueError( 71 | 'The runner does not iterate and copy everything into memory, it should' 72 | ' never be exhausted.' 73 | ) 74 | 75 | 76 | class NoLenIter(Iterable): 77 | 78 | def __init__(self, iterable): 79 | self._iteratable = iterable 80 | 81 | def __len__(self): 82 | raise ValueError('Cannot call len()') 83 | 84 | def __iter__(self): 85 | return iter(self._iteratable) 86 | 87 | 88 | class SequenceWithExc: 89 | """A range that raises an exception at a specific value.""" 90 | 91 | def __init__(self, end: int, exc_i: int, error_type=ValueError): 92 | self.end = end 93 | self.exc_i = exc_i 94 | self.error_type = error_type 95 | 96 | def __len__(self): 97 | return self.end 98 | 99 | def __getitem__(self, i): 100 | if i == self.exc_i: 101 | raise self.error_type(f'SequenceWithExc at {i}') 102 | return i 103 | 104 | 105 | def assert_nested_container_equal( 106 | test: unittest.TestCase, a, b, *, strict=False, places=7 107 | ): 108 | """Asserts that two nested containers are equal.""" 109 | try: 110 | if strict: 111 | test.assertEqual(type(a), type(b)) 112 | if isinstance(a, dict) and isinstance(b, dict): 113 | for (k_a, v_a), (k_b, v_b) in zip( 114 | sorted(a.items()), sorted(b.items()), strict=True 115 | ): 116 | test.assertEqual(k_a, k_b) 117 | assert_nested_container_equal( 118 | test, v_a, v_b, strict=strict, places=places 119 | ) 120 | elif isinstance(a, str) and isinstance(b, str): 121 | test.assertEqual(a, b) 122 | elif hasattr(a, '__array__') and hasattr(b, '__array__'): 123 | np.testing.assert_allclose(a, b) 124 | elif isinstance(a, Iterable) and isinstance(b, Iterable): 125 | for a_elem, b_elem in zip(a, b, strict=True): 126 | assert_nested_container_equal( 127 | test, a_elem, b_elem, strict=strict, places=places 128 | ) 129 | else: 130 | test.assertAlmostEqual(a, b, places=places) 131 | except Exception as e: # pylint: disable=broad-except 132 | test.fail(f'Failed to compare {a} 
and {b}: {e}') 133 | 134 | 135 | def sharded_ones( 136 | total_numbers: int, 137 | batch_size: int, 138 | shard_index: int = 0, 139 | num_shards: int = 1, 140 | ): 141 | num_batches, remainder = divmod(total_numbers, batch_size) 142 | for i in range(num_batches): 143 | if i % num_shards == shard_index: 144 | yield batch_size 145 | if not shard_index and remainder: 146 | yield remainder 147 | 148 | 149 | def sharded_pipeline( 150 | total_numbers: int, 151 | batch_size: int, 152 | shard_index: int = 0, 153 | num_shards: int = 1, 154 | fuse_aggregate: bool = True, 155 | num_threads: int = 0, 156 | ): 157 | """A pipeline to calculate the stats of batches of random integers.""" 158 | data_pipeline = transform.TreeTransform.new(name='datasource').data_source( 159 | sharded_ones( 160 | total_numbers, 161 | batch_size=batch_size, 162 | shard_index=shard_index, 163 | num_shards=num_shards, 164 | ) 165 | ) 166 | apply_pipeline = transform.TreeTransform.new( 167 | name='apply', num_threads=num_threads 168 | ).apply( 169 | fn=lambda batch_size: np.random.randint(1, 100, size=batch_size), 170 | ) 171 | 172 | if fuse_aggregate: 173 | return data_pipeline.chain( 174 | apply_pipeline.aggregate( 175 | output_keys='stats', 176 | fn=rolling_stats.MeanAndVariance().as_agg_fn(), 177 | ) 178 | ) 179 | return data_pipeline.chain(apply_pipeline).chain( 180 | transform.TreeTransform.new(name='agg').aggregate( 181 | output_keys='stats', 182 | fn=rolling_stats.MeanAndVariance().as_agg_fn(), 183 | ) 184 | ) 185 | -------------------------------------------------------------------------------- /ml_metrics/_src/utils/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from ml_metrics._src.utils import test_utils 15 | from absl.testing import absltest 16 | 17 | 18 | class TestUtilsTest(absltest.TestCase): 19 | 20 | def test_inf_range(self): 21 | with self.assertRaises(ValueError): 22 | _ = list(test_utils.inf_range(10)) 23 | 24 | def test_nolen_iter(self): 25 | with self.assertRaises(ValueError): 26 | if test_utils.NoLenIter(range(10)): 27 | pass 28 | 29 | 30 | if __name__ == "__main__": 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /ml_metrics/aggregates.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Aggregation interfaces.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.aggregates import utils 19 | from ml_metrics._src.aggregates.base import AggFnNested 20 | from ml_metrics._src.aggregates.base import Aggregatable 21 | from ml_metrics._src.aggregates.base import AggregateFn 22 | from ml_metrics._src.aggregates.base import as_agg_fn 23 | from ml_metrics._src.aggregates.base import CallableMetric 24 | from ml_metrics._src.aggregates.base import MergeableMetric 25 | from ml_metrics._src.aggregates.base import MergeableMetricAggFn 26 | from ml_metrics._src.aggregates.base import Metric 27 | from ml_metrics._src.aggregates.base import UserAggregateFn 28 | from ml_metrics._src.aggregates.keras_metric_wrapper import is_keras_metric 29 | from ml_metrics._src.aggregates.keras_metric_wrapper import KerasAggregateFn 30 | from ml_metrics._src.aggregates.types import AverageType 31 | from ml_metrics._src.aggregates.types import InputType 32 | -------------------------------------------------------------------------------- /ml_metrics/chainable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Chainable interfaces.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.chainables.lazy_fns import cache_info 19 | from ml_metrics._src.chainables.lazy_fns import clear_cache 20 | from ml_metrics._src.chainables.lazy_fns import LazyFn 21 | from ml_metrics._src.chainables.lazy_fns import LazyObject 22 | from ml_metrics._src.chainables.lazy_fns import makeables 23 | from ml_metrics._src.chainables.lazy_fns import maybe_make 24 | from ml_metrics._src.chainables.lazy_fns import pickler 25 | from ml_metrics._src.chainables.lazy_fns import trace 26 | from ml_metrics._src.chainables.transform import AggregateResult 27 | from ml_metrics._src.chainables.transform import AggregateTransform 28 | from ml_metrics._src.chainables.transform import iterate_with_returned 29 | from ml_metrics._src.chainables.transform import MetricKey 30 | from ml_metrics._src.chainables.transform import RunnerMode 31 | from ml_metrics._src.chainables.transform import TreeTransform as Pipeline 32 | from ml_metrics._src.chainables.tree import Key 33 | from ml_metrics._src.chainables.tree import tree_shape 34 | from ml_metrics._src.chainables.tree import TreeMapView 35 | from ml_metrics._src.utils import test_utils 36 | from ml_metrics._src.utils.iter_utils import iterate_fn 37 | -------------------------------------------------------------------------------- /ml_metrics/data_sinks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """All datasinks for ML Metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics.google.data_sinks.hydra import HydraPropertyWriter 19 | -------------------------------------------------------------------------------- /ml_metrics/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/ml-metrics/510ebbc9ab70847d8aff3eca7cb4d8a3309e6319/ml_metrics/metrics/__init__.py -------------------------------------------------------------------------------- /ml_metrics/metrics/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Classification metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.aggregates.classification import ConfusionMatrixAggFn 19 | from ml_metrics._src.aggregates.classification import ConfusionMatrixMetric 20 | from ml_metrics._src.aggregates.classification import SamplewiseClassification 21 | from ml_metrics._src.aggregates.classification import TopKConfusionMatrixAggFn 22 | from ml_metrics._src.metrics.classification import accuracy 23 | from ml_metrics._src.metrics.classification import balanced_accuracy 24 | from ml_metrics._src.metrics.classification import binary_accuracy 25 | from ml_metrics._src.metrics.classification import CalibrationHistogram 26 | from ml_metrics._src.metrics.classification import classification_metrics 27 | from ml_metrics._src.metrics.classification import ClassificationAggFn 28 | from ml_metrics._src.metrics.classification import diagnostic_odds_ratio 29 | from ml_metrics._src.metrics.classification import f1_score 30 | from ml_metrics._src.metrics.classification import fall_out 31 | from ml_metrics._src.metrics.classification import false_discovery_rate 32 | from ml_metrics._src.metrics.classification import false_omission_rate 33 | from ml_metrics._src.metrics.classification import fnr 34 | from ml_metrics._src.metrics.classification import fpr 35 | from ml_metrics._src.metrics.classification import informedness 36 | from ml_metrics._src.metrics.classification import intersection_over_union 37 | from ml_metrics._src.metrics.classification import markedness 38 | from ml_metrics._src.metrics.classification import matthews_correlation_coefficient 39 | from ml_metrics._src.metrics.classification import miss_rate 40 | from ml_metrics._src.metrics.classification import negative_likelihood_ratio 41 | from ml_metrics._src.metrics.classification import negative_prediction_value 42 | from ml_metrics._src.metrics.classification import nvp 43 | from ml_metrics._src.metrics.classification import positive_likelihood_ratio 44 | from ml_metrics._src.metrics.classification import positive_predictive_value 45 | from ml_metrics._src.metrics.classification import ppv 46 | from ml_metrics._src.metrics.classification import precision 47 | from ml_metrics._src.metrics.classification import prevalence 48 | from ml_metrics._src.metrics.classification import prevalence_threshold 49 | from ml_metrics._src.metrics.classification import recall 50 | from ml_metrics._src.metrics.classification import sensitivity 51 | from ml_metrics._src.metrics.classification import specificity 52 | from ml_metrics._src.metrics.classification import threat_score 53 | from ml_metrics._src.metrics.classification import tnr 54 | from ml_metrics._src.metrics.classification import tpr 55 | -------------------------------------------------------------------------------- /ml_metrics/metrics/retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Retrieval metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.aggregates.retrieval import RetrievalMetric 19 | from ml_metrics._src.aggregates.retrieval import TopKRetrieval 20 | from ml_metrics._src.aggregates.retrieval import TopKRetrievalAggFn 21 | from ml_metrics._src.metrics.retrieval import accuracy 22 | from ml_metrics._src.metrics.retrieval import dcg_score 23 | from ml_metrics._src.metrics.retrieval import f1_score 24 | from ml_metrics._src.metrics.retrieval import false_discovery_rate 25 | from ml_metrics._src.metrics.retrieval import fowlkes_mallows_index 26 | from ml_metrics._src.metrics.retrieval import intersection_over_union 27 | from ml_metrics._src.metrics.retrieval import mean_average_precision 28 | from ml_metrics._src.metrics.retrieval import mean_reciprocal_rank 29 | from ml_metrics._src.metrics.retrieval import miss_rate 30 | from ml_metrics._src.metrics.retrieval import ndcg_score 31 | from ml_metrics._src.metrics.retrieval import positive_predictive_value 32 | from ml_metrics._src.metrics.retrieval import ppv 33 | from ml_metrics._src.metrics.retrieval import precision 34 | from ml_metrics._src.metrics.retrieval import recall 35 | from ml_metrics._src.metrics.retrieval import sensitivity 36 | from ml_metrics._src.metrics.retrieval import threat_score 37 | from ml_metrics._src.metrics.retrieval import topk_retrieval_metrics 38 | from ml_metrics._src.metrics.retrieval import tpr 39 | -------------------------------------------------------------------------------- /ml_metrics/metrics/rolling_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Statistics metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.aggregates.rolling_stats import Counter 19 | from ml_metrics._src.aggregates.rolling_stats import FixedSizeSample 20 | from ml_metrics._src.aggregates.rolling_stats import Histogram 21 | from ml_metrics._src.aggregates.rolling_stats import Mean 22 | from ml_metrics._src.aggregates.rolling_stats import MeanAndVariance 23 | from ml_metrics._src.aggregates.rolling_stats import MinMaxAndCount 24 | from ml_metrics._src.aggregates.rolling_stats import R2Tjur 25 | from ml_metrics._src.aggregates.rolling_stats import R2TjurRelative 26 | from ml_metrics._src.aggregates.rolling_stats import RRegression 27 | from ml_metrics._src.aggregates.rolling_stats import SymmetricPredictionDifference 28 | from ml_metrics._src.aggregates.rolling_stats import UnboundedSampler 29 | from ml_metrics._src.aggregates.rolling_stats import Var 30 | from ml_metrics._src.metrics.rolling_stats import count 31 | from ml_metrics._src.metrics.rolling_stats import mean 32 | from ml_metrics._src.metrics.rolling_stats import stddev 33 | from ml_metrics._src.metrics.rolling_stats import total 34 | from ml_metrics._src.metrics.rolling_stats import var 35 | -------------------------------------------------------------------------------- /ml_metrics/metrics/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Text metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | # Eventually move things under /aggregates/text.py to metrics/text.py. Making 19 | # the classes callable and replace the functions in metrics/text.py 20 | from ml_metrics._src.aggregates.text import PatternFrequency 21 | from ml_metrics._src.aggregates.text import TopKWordNGrams 22 | from ml_metrics._src.metrics.text import avg_alphabetical_char_count 23 | from ml_metrics._src.metrics.text import pattern_frequency 24 | from ml_metrics._src.metrics.text import topk_word_ngrams 25 | -------------------------------------------------------------------------------- /ml_metrics/metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Metrics utility.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.metrics.utils import verify_input 19 | -------------------------------------------------------------------------------- /ml_metrics/oss/README.md: -------------------------------------------------------------------------------- 1 | # Steps to build a new py-ml-metrics pip package 2 | 3 | 1. Update the version number in project.toml 4 | 5 | 2. To build pypi wheel, run: 6 | 7 | ``` 8 | cd 9 | sh ml_metrics/oss/runner.sh 10 | ``` 11 | 12 | 3. Wheels are in `/tmp/ml_metrics/all_dist`. 13 | 14 | 4. Upload to PyPI: 15 | 16 | ``` 17 | python3 -m pip install --upgrade twine 18 | python3 -m twine upload /tmp/ml_metrics/all_dist/*-any.whl 19 | ``` 20 | 21 | Authenticate with Twine by following https://pypi.org/help/#apitoken and editing 22 | your `~/.pypirc`. 23 | 24 | 5. Draft the new release in github: https://github.com/google/ml-metrics/releases. 25 | Tag the release commit with the version number. 26 | 27 | # Optional 28 | 29 | * Update workflows in oss/workflows when Python is upgraded. 30 | 31 | * To build for a different python version, change the PYTHON_MINOR_VERSION 32 | and/or PYTHON_MAJOR_VERSION in `ml_metrics/oss/runner.sh`. 33 | 34 | * To use a different docker image, switch it out under 35 | `ml_metrics/oss/build.Dockerfile`. 36 | 37 | * All the dependencies have to be installed manually in 38 | `ml_metrics/oss/build.Dockerfile`. 39 | 40 | * Refer the required dependencies from the dependencies section of `pyproject.toml`. 41 | 42 | * When you do not need to update version number, you can manually adds or 43 | increment build number by renaming the whl files under 44 | `/tmp/ml_metrics/all_dist` following the format of 45 | `py_ml_metrics-{version}-[{build version}]-{py version}-non-any.whl` 46 | e.g.,: 47 | 48 | ``` 49 | mv py_ml_metrics-0.0.1-py310-none-any.whl py_ml_metrics-0.0.1-1-py310-none-any.whl 50 | ``` 51 | -------------------------------------------------------------------------------- /ml_metrics/oss/build.Dockerfile: -------------------------------------------------------------------------------- 1 | # Constructs the environment within which we will build the ml-metrics pip wheels. 
2 | # 3 | # From /tmp/ml_metrics, 4 | # ❯ DOCKER_BUILDKIT=1 docker build \ 5 | # --build-arg PYTHON_VERSION=${PYTHON_VERSION} \ 6 | # -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile 7 | # ❯ docker run --rm -it -v /tmp/ml_metrics:/tmp/ml_metrics \ 8 | # ml_metrics:${PYTHON_VERSION} bash 9 | 10 | FROM quay.io/pypa/manylinux2014_x86_64 11 | LABEL maintainer="ml-metrics team " 12 | 13 | ARG PYTHON_MAJOR_VERSION 14 | ARG PYTHON_MINOR_VERSION 15 | ARG PYTHON_VERSION 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN ulimit -n 1024 && yum install -y rsync 20 | 21 | ENV PATH="/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin:${PATH}" 22 | 23 | # Install dependencies needed for ml-metrics 24 | RUN --mount=type=cache,target=/root/.cache \ 25 | python${PYTHON_VERSION} -m pip install -U \ 26 | absl-py \ 27 | build \ 28 | cloudpickle \ 29 | more-itertools\ 30 | numpy; 31 | 32 | # Install dependencies needed for ml-metrics tests 33 | RUN --mount=type=cache,target=/root/.cache \ 34 | python${PYTHON_VERSION} -m pip install -U \ 35 | auditwheel; 36 | 37 | WORKDIR "/tmp/ml_metrics" -------------------------------------------------------------------------------- /ml_metrics/oss/build_whl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # build wheel for python version specified in $PYTHON 3 | 4 | set -e -x 5 | 6 | CP_VERSION="cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}" 7 | PYTHON_BIN_PATH="/opt/python/${CP_VERSION}-${CP_VERSION}/bin/python" 8 | 9 | function main() { 10 | 11 | DEST="/tmp/ml_metrics/all_dist" 12 | mkdir -p "${DEST}" 13 | 14 | echo "=== Destination directory: ${DEST}" 15 | 16 | if [ "$ML_METRICS_RUN_TESTS" = true ] ; then 17 | python3 -m unittest discover -s ml_metrics -p '*_test.py' 18 | fi 19 | 20 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) 21 | 22 | echo $(date) : "=== Using tmpdir: ${TMPDIR}" 23 | 24 | echo "=== Copy ml_metrics files" 25 | 26 | cp ./setup.py "${TMPDIR}" 27 | cp ./pyproject.toml "${TMPDIR}" 28 | cp ./LICENSE "${TMPDIR}" 29 | rsync -avm -L --exclude="__pycache__/*" ./ml_metrics "${TMPDIR}" 30 | # rsync -avm -L --include="*.so" --include="*_pb2.py" \ 31 | # --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \ 32 | # bazel-bin/ml_metrics "${TMPDIR}" 33 | 34 | pushd ${TMPDIR} 35 | echo $(date) : "=== Building wheel" 36 | 37 | "python${PYTHON_VERSION}" setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION} 38 | cp dist/*.whl "${DEST}" 39 | 40 | echo $(date) : "=== Auditing wheel" 41 | auditwheel repair --plat ${AUDITWHEEL_PLATFORM} -w dist dist/*.whl 42 | 43 | echo $(date) : "=== Listing wheel" 44 | ls -lrt dist/*.whl 45 | cp dist/*.whl "${DEST}" 46 | popd 47 | 48 | echo $(date) : "=== Output wheel file is in: ${DEST}" 49 | } 50 | 51 | main "$@" -------------------------------------------------------------------------------- /ml_metrics/oss/runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script copies ml_metrics from internal repo, builds a docker, and 3 | # builds pip wheels for all Python versions. 4 | 5 | set -e -x 6 | 7 | export TMP_FOLDER="/tmp/ml_metrics" 8 | 9 | # Clean previous folders/images. 
10 | [ -d $TMP_FOLDER ] && rm -rf $TMP_FOLDER
11 | 
12 | PYTHON_MAJOR_VERSION="3"
13 | PYTHON_MINOR_VERSION="11"
14 | PYTHON_VERSION="${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}"
15 | ML_METRICS_RUN_TESTS=true
16 | 
17 | AUDITWHEEL_PLATFORM="manylinux2014_x86_64"
18 | 
19 | 
20 | docker rmi -f ml_metrics:${PYTHON_VERSION}
21 | docker rm -f ml_metrics
22 | 
23 | # Synchronize Copybara in $TMP_FOLDER.
24 | cp -r . $TMP_FOLDER
25 | 
26 | cd $TMP_FOLDER
27 | 
28 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
29 |   --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
30 |   --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
31 |   --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
32 |   -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile
33 | 
34 | docker run --rm -a stdin -a stdout -a stderr \
35 |   --env PYTHON_VERSION=${PYTHON_VERSION} \
36 |   --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
37 |   --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
38 |   --env ML_METRICS_RUN_TESTS=${ML_METRICS_RUN_TESTS} \
39 |   --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
40 |   -v /tmp/ml_metrics:/tmp/ml_metrics \
41 |   --name ml_metrics ml_metrics:${PYTHON_VERSION} \
42 |   bash ml_metrics/oss/build_whl.sh
43 | 
44 | ls $TMP_FOLDER/all_dist/*.whl
--------------------------------------------------------------------------------
/ml_metrics/signals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/ml-metrics/510ebbc9ab70847d8aff3eca7cb4d8a3309e6319/ml_metrics/signals/__init__.py
--------------------------------------------------------------------------------
/ml_metrics/signals/classification.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2025 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classification signals."""
15 | 
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.signals.cg_score import complexity_gap_score
19 | from ml_metrics._src.signals.cross_entropy import binary_cross_entropy
20 | from ml_metrics._src.signals.cross_entropy import categorical_cross_entropy
21 | from ml_metrics._src.signals.flip_masks import binary_flip_mask
22 | from ml_metrics._src.signals.flip_masks import neg_to_pos_flip_mask
23 | from ml_metrics._src.signals.flip_masks import pos_to_neg_flip_mask
24 | from ml_metrics._src.signals.topk_accuracy import topk_accurate
25 | 
--------------------------------------------------------------------------------
/ml_metrics/signals/image.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Image signals.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.signals.image import content_metadata 19 | -------------------------------------------------------------------------------- /ml_metrics/signals/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Text signals.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.signals.text import alphabetical_char_count 19 | from ml_metrics._src.signals.text import average_word_length 20 | from ml_metrics._src.signals.text import exact_match 21 | from ml_metrics._src.signals.text import is_all_whitespace 22 | from ml_metrics._src.signals.text import non_ascii_char_count 23 | from ml_metrics._src.signals.text import reference_in_sample_match 24 | from ml_metrics._src.signals.text import reference_startswith_sample_match 25 | from ml_metrics._src.signals.text import sample_in_reference_match 26 | from ml_metrics._src.signals.text import sample_startswith_reference_match 27 | from ml_metrics._src.signals.text import token_count 28 | from ml_metrics._src.signals.text import token_match_rate 29 | from ml_metrics._src.signals.text import word_count 30 | -------------------------------------------------------------------------------- /ml_metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utitilies.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=unused-import 18 | from ml_metrics._src.utils import func_utils 19 | from ml_metrics._src.utils import iter_utils 20 | from ml_metrics._src.utils import math_utils 21 | from ml_metrics._src.utils import test_utils 22 | from ml_metrics._src.utils.df_utils import metrics_to_df 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "py-ml-metrics" 7 | version = "0.0.3" 8 | description = "ML-metrics: A library that provides performant and distributed friendly ML metrics implementations." 9 | keywords = [] 10 | authors = [ 11 | {name = "ML metrics team", email = "ml-metrics-dev@google.com"}, 12 | ] 13 | dependencies = [ 14 | 'absl-py', 15 | 'cloudpickle', 16 | 'immutabledict', 17 | 'numpy', 18 | 'more-itertools', 19 | ] 20 | readme = "README.md" 21 | license = { file = "LICENSE" } 22 | requires-python = ">=3.11" 23 | classifiers = [ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Operating System :: POSIX :: Linux", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | signals = ['pillow'] 31 | full = ['pillow'] 32 | 33 | [project.urls] 34 | homepage = "https://github.com/google/ml-metrics" 35 | 36 | [tool.setuptools.packages.find] 37 | include = ["ml_metrics*"] 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup.py file for ml-metrics. 2 | 3 | Most project configs are in `pyproject.toml` -- prefer to modify 4 | `pyproject.toml` over this file if possible. 5 | """ 6 | 7 | import setuptools 8 | 9 | 10 | setuptools.setup() 11 | --------------------------------------------------------------------------------