├── .github
│   └── workflows
│       └── tests.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── ml_metrics
│   ├── __init__.py
│   ├── _src
│   │   ├── __init__.py
│   │   ├── aggregates
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── base_test.py
│   │   │   ├── classification.py
│   │   │   ├── classification_test.py
│   │   │   ├── keras_metric_wrapper.py
│   │   │   ├── keras_metric_wrapper_test.py
│   │   │   ├── retrieval.py
│   │   │   ├── retrieval_test.py
│   │   │   ├── rolling_stats.py
│   │   │   ├── rolling_stats_test.py
│   │   │   ├── test_utils.py
│   │   │   ├── text.py
│   │   │   ├── text_test.py
│   │   │   ├── types.py
│   │   │   └── utils.py
│   │   ├── chainables
│   │   │   ├── __init__.py
│   │   │   ├── io.py
│   │   │   ├── io_test.py
│   │   │   ├── lazy_fns.py
│   │   │   ├── lazy_fns_test.py
│   │   │   ├── transform.py
│   │   │   ├── transform_test.py
│   │   │   ├── tree.py
│   │   │   ├── tree_fns.py
│   │   │   ├── tree_fns_test.py
│   │   │   └── tree_test.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── classification_test.py
│   │   │   ├── retrieval.py
│   │   │   ├── retrieval_test.py
│   │   │   ├── rolling_stats.py
│   │   │   ├── rolling_stats_test.py
│   │   │   ├── text.py
│   │   │   ├── text_test.py
│   │   │   ├── utils.py
│   │   │   └── utils_test.py
│   │   ├── signals
│   │   │   ├── cg_score.py
│   │   │   ├── cg_score_test.py
│   │   │   ├── cross_entropy.py
│   │   │   ├── cross_entropy_test.py
│   │   │   ├── flip_masks.py
│   │   │   ├── flip_masks_test.py
│   │   │   ├── image.py
│   │   │   ├── image_test.py
│   │   │   ├── text.py
│   │   │   ├── text_test.py
│   │   │   ├── topk_accuracy.py
│   │   │   └── topk_accuracy_test.py
│   │   ├── tools
│   │   │   └── telemetry
│   │   │       └── telemetry.py
│   │   ├── types.py
│   │   ├── types_test.py
│   │   └── utils
│   │       ├── df_utils.py
│   │       ├── df_utils_test.py
│   │       ├── func_utils.py
│   │       ├── func_utils_test.py
│   │       ├── iter_utils.py
│   │       ├── iter_utils_test.py
│   │       ├── math_utils.py
│   │       ├── math_utils_test.py
│   │       ├── proto_utils.py
│   │       ├── proto_utils_test.py
│   │       ├── test_utils.py
│   │       └── test_utils_test.py
│   ├── aggregates.py
│   ├── chainable.py
│   ├── data_sinks.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   ├── retrieval.py
│   │   ├── rolling_stats.py
│   │   ├── text.py
│   │   └── utils.py
│   ├── oss
│   │   ├── README.md
│   │   ├── build.Dockerfile
│   │   ├── build_whl.sh
│   │   └── runner.sh
│   ├── signals
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   ├── image.py
│   │   └── text.py
│   └── utils.py
├── pyproject.toml
└── setup.py
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Build & Test
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["main"]
8 |
9 | jobs:
10 | build-and-test:
11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12 | runs-on: "${{ matrix.os }}"
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.11", "3.12", "3.13"]
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: "actions/checkout@v4"
21 | - name: Create directory
22 | run: |
23 | mkdir -p /tmp/ml_metrics
24 | cp -r . /tmp/ml_metrics
25 | - name: Build package
26 | run: |
27 | set -xe
28 | export PYTHON_VERSION=${{ matrix.python-version }}
29 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1)
30 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2)
31 | export AUDITWHEEL_PLATFORM="manylinux2014_x86_64"
32 | export ML_METRICS_RUN_TESTS=true
33 | cd /tmp/ml_metrics
34 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
35 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
36 | --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
37 | --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
38 | -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile
39 | docker run --rm -a stdin -a stdout -a stderr \
40 | --env PYTHON_VERSION=${PYTHON_VERSION} \
41 | --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
42 | --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
43 | --env ML_METRICS_RUN_TESTS=${ML_METRICS_RUN_TESTS} \
44 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
45 | -v /tmp/ml_metrics:/tmp/ml_metrics \
46 | --name ml_metrics ml_metrics:${PYTHON_VERSION} \
47 | bash ml_metrics/oss/build_whl.sh
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ML-metrics
2 |
3 | ML-metrics provides performant and distribution-friendly implementations of common ML metrics.
4 |
5 | This is not an officially supported Google product.
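6 | 
7 | ## Example
8 | 
9 | A minimal, illustrative sketch using one of the lower-level aggregates shipped
10 | in this repo (see `ml_metrics/_src/aggregates/text.py`); the stable public API
11 | surface may differ:
12 | 
13 | ```python
14 | from ml_metrics._src.aggregates import text
15 | 
16 | # Track the top-2 most frequent word bigrams, batch by batch.
17 | metric = text.TopKWordNGrams(k=2, n=2)
18 | metric.add(['a b', 'a b', 'c d'])
19 | print(metric.result())  # [('a b', 0.666...), ('c d', 0.333...)]
20 | ```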
--------------------------------------------------------------------------------
/ml_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/ml_metrics/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base AggregateFn for all the aggregates."""
15 | from __future__ import annotations
16 |
17 | import abc
18 | from collections.abc import Callable, Iterable
19 | import dataclasses
20 | from typing import Any, Protocol, Self, TypeGuard, TypeVar, runtime_checkable
21 |
22 | from ml_metrics._src import types
23 | from ml_metrics._src.chainables import lazy_fns
24 | from ml_metrics._src.chainables import tree
25 |
26 | _T = TypeVar('_T')
27 | _ResolvableOrMakeable = types.Resolvable[_T] | types.Makeable[_T]
28 |
29 |
30 | @runtime_checkable
31 | class Metric(Protocol):
32 | """MergibleMetric can be used as simple Map, and also MapReduce."""
33 |
34 | @abc.abstractmethod
35 | def add(self, *inputs, **named_inputs) -> Any:
36 | """Computes the state from a batch while outputting batch output."""
37 |
38 | @abc.abstractmethod
39 | def result(self) -> Any:
40 | """Returns the result of the metric."""
41 |
42 | _MetricT = TypeVar('_MetricT', bound=Metric)
43 |
44 |
45 | def as_agg_fn(
46 | cls: Callable[..., _MetricT],
47 | *args,
48 | nested: bool = False,
49 | agg_preprocess_fn: Callable[..., Any] | None = None,
50 | **kwargs,
51 | ) -> AggregateFn:
52 | """Creates an AggregateFn from a metric class."""
53 | deferred_metric: types.Resolvable[_MetricT] = lazy_fns.trace(cls)(
54 | *args, **kwargs
55 | )
56 | # Try to resolve the target at construction time to surface errors early.
57 | _ = lazy_fns.maybe_make(deferred_metric)
58 | agg_fn = MergeableMetricAggFn(deferred_metric)
59 | if nested:
60 | agg_fn = AggFnNested(agg_fn, preprocess_fn=agg_preprocess_fn)
61 | return agg_fn
62 |
63 |
64 | @runtime_checkable
65 | class MergeableMetric(Metric, Protocol):
66 | """MergibleMetric can be used as simple Map, and also MapReduce."""
67 |
68 | @abc.abstractmethod
69 | def merge(self, other: Self):
70 | """Merges the metric with another metric of the same type."""
71 |
72 |
73 | class CallableMetric(MergeableMetric, Callable[..., Any]):
74 | """A metric that is also callable.
75 |
76 | CallableMetric is the recommended interface for implementing a metric that
77 | supports both computing a batch result (`new`) and merging batch results
78 | (`merge`). A default `add` method is provided, but should be overridden if
79 | the `merge` method is not applicable for `add`.
80 | """
81 |
82 | @abc.abstractmethod
83 | def new(self, *args, **kwargs) -> Self:
84 | """Calculate the sufficient statistics, should be idempotent."""
85 |
86 | def add(self, *args, **kwargs):
87 | """Updates the sufficient statistics with a batch of inputs."""
88 | batch_result = self.new(*args, **kwargs)
89 | self.merge(batch_result)
90 | return batch_result
91 |
92 | def __call__(self, *args, **kwargs):
93 | """Calculates the result from the sufficient statistics."""
94 | return self.new(*args, **kwargs).result()
95 |
96 |
97 | class HasAsAggFn(Protocol):
98 | """An object that can be used as an aggregate function."""
99 |
100 | def as_agg_fn(self) -> Aggregatable:
101 | """Returns the aggregate function."""
102 |
103 |
104 | def has_as_agg_fn(obj: Any) -> TypeGuard[HasAsAggFn]:
105 | """Returns True if the object has an as_agg_fn method."""
106 | return types.obj_has_method(obj, 'as_agg_fn')
107 |
108 |
109 | @runtime_checkable
110 | class Aggregatable(Protocol):
111 | """An aggregation interface, similar to apche_beam.CombineFn."""
112 |
113 | def create_state(self) -> Any:
114 | """Creates the initial states for the aggregation."""
115 | return None
116 |
117 | @abc.abstractmethod
118 | def update_state(self, state, *inputs, **named_inputs):
119 | """Update the state from a batch of inputs.
120 |
121 | Args:
122 | state: the current state.
123 | *inputs: elements to add.
124 | **named_inputs: elements to add.
125 | """
126 |
127 | def merge_states(self, states):
128 | """Merging multiple states into a one state value.
129 |
130 | This is only required for distributed implementations such as Beam. Only the
131 | first state may be modified and returned for efficiency.
132 |
133 | Args:
134 | states: the states to be merged.
135 | """
136 | raise NotImplementedError()
137 |
138 | def get_result(self, state):
139 | """Computes and returns the result from state.
140 |
141 | Args:
142 | state: the final state value computed by this CombineFn.
143 |
144 | Returns:
145 | state.
146 | """
147 | return state
148 |
149 |
150 | class AggregateFn(Aggregatable):
151 | """An aggregation interface, similar to apche_beam.CombineFn."""
152 |
153 | def __call__(self, *inputs, **named_inputs):
154 | """Directly apply aggregate on inputs."""
155 | return self.get_result(
156 | self.update_state(self.create_state(), *inputs, **named_inputs)
157 | )
158 |
159 |
160 | class MergeableMetricAggFn(AggregateFn):
161 | """A aggregation wrapper for MergeableMetric."""
162 |
163 | metric_maker: _ResolvableOrMakeable[MergeableMetric]
164 |
165 | def __init__(self, metric_maker: _ResolvableOrMakeable[MergeableMetric]):
166 | super().__init__()
167 | if not (
168 | types.is_resolvable(metric_maker) or types.is_makeable(metric_maker)
169 | ):
170 | raise TypeError(
171 | 'metric_maker must be an instance of Makeable or Resolvable. got'
172 | f' {type(metric_maker)}'
173 | )
174 | self.metric_maker = metric_maker
175 |
176 | def __eq__(self, other, /):
177 | return (
178 | isinstance(other, MergeableMetricAggFn)
179 | and self.metric_maker == other.metric_maker
180 | )
181 |
182 | def create_state(self) -> MergeableMetric:
183 | metric = self.metric_maker
184 | if types.is_makeable(metric):
185 | return metric.make()
186 | elif types.is_resolvable(metric):
187 | return metric.result_()
188 | else:
189 | raise TypeError(f'{type(metric)} is not a Makeable or Resolvable.')
190 |
191 | def update_state(
192 | self, state: MergeableMetric, *args, **kwargs
193 | ) -> MergeableMetric:
194 | state.add(*args, **kwargs)
195 | return state
196 |
197 | def merge_states(self, states: Iterable[MergeableMetric]) -> MergeableMetric:
198 | iter_states = iter(states)
199 | result = next(iter_states)
200 | for state in iter_states:
201 | result.merge(state)
202 | return result
203 |
204 | def get_result(self, state: MergeableMetric) -> Any:
205 | return state.result()
206 |
207 |
208 | class AggFnNested(AggregateFn):
209 | """AggregateFn that traverses and aggregates each leaf of a PyTree."""
210 |
211 | fn: Aggregatable
212 | preprocess_fn: Callable[..., tree.TreeLike[Any]] | None
213 |
214 | def __init__(
215 | self,
216 | fn: Aggregatable,
217 | preprocess_fn: Callable[..., tree.TreeLike[Any]] | None = None,
218 | ):
219 | if preprocess_fn is not None and not callable(preprocess_fn):
220 | raise ValueError(f'preprocess_fn must be a callable. got {preprocess_fn}')
221 | if not isinstance(fn, Aggregatable):
222 | raise ValueError(f'fn must be an instance of Aggregatable. got {fn}')
223 | super().__init__()
224 | self.fn = fn
225 | self.preprocess_fn = preprocess_fn
226 |
227 | def create_state(self):
228 | """Creates the initial states for the aggregation."""
229 | return None
230 |
231 | def update_state(self, state: tree.TreeMapView, inputs: tree.TreeLike[Any]):
232 | """Update the state from a batch of inputs."""
233 | if self.preprocess_fn:
234 | inputs = self.preprocess_fn(inputs)
235 | if state is None:
236 | state = tree.TreeMapView(
237 | inputs, map_fn=lambda x: self.fn.create_state()
238 | ).apply()
239 | state = tree.TreeMapView(state)
240 | inputs = tree.TreeMapView.as_view(inputs)
241 | return state.copy_and_update(
242 | (k, self.fn.update_state(state[k], v)) for k, v in inputs.items()
243 | )
244 |
245 | # TODO: b/311207032 - Implement this.
246 | def merge_states(self, states):
247 | """Merging multiple states into a one state value."""
248 | raise NotImplementedError()
249 |
250 | def get_result(self, state):
251 | """Computes and returns the result from state."""
252 | return tree.TreeMapView.as_view(state, map_fn=self.fn.get_result).apply()
253 |
254 |
255 | @dataclasses.dataclass(frozen=True)
256 | class UserAggregateFn(AggregateFn):
257 | """An aggregation interface, similar to apche_beam.CombineFn."""
258 |
259 | fn: Aggregatable
260 |
261 | def __post_init__(self):
262 | if not isinstance(self.fn, Aggregatable):
263 | raise ValueError(
264 | f'UserAggregateFn must be an instance of Aggregatable. got {self.fn}'
265 | )
266 |
267 | def create_state(self) -> Any:
268 | """Creates the initial states for the aggregation."""
269 | return self.fn.create_state()
270 |
271 | def update_state(self, state, *inputs, **named_inputs):
272 | """Update the state from a batch of inputs."""
273 | return self.fn.update_state(state, *inputs, **named_inputs)
274 |
275 | def merge_states(self, states):
276 | """Merging multiple states into a one state value."""
277 | return self.fn.merge_states(states)
278 |
279 | def get_result(self, state):
280 | """Computes and returns the result from state."""
281 | return self.fn.get_result(state)
282 |
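283 | # Illustrative usage (a sketch mirroring aggregates/test_utils.py and
284 | # base_test.py): a CallableMetric that sums its inputs can be wrapped into an
285 | # AggregateFn via `as_agg_fn`, which then supports both direct application and
286 | # the create/update/merge/result protocol:
287 | #
288 | #   sum_fn = as_agg_fn(_SumMetric)     # _SumMetric lives in test_utils.py
289 | #   sum_fn([1, 2, 3])                  # -> 6, via create/update/get_result
290 | #   state = sum_fn.update_state(sum_fn.create_state(), [1, 2, 3])
291 | #   sum_fn.get_result(sum_fn.merge_states([state]))  # -> 6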
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/base_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for core lib."""
15 |
16 | from absl.testing import absltest
17 | from ml_metrics._src.aggregates import base
18 | from ml_metrics._src.aggregates import test_utils
19 | from ml_metrics._src.chainables import lazy_fns
20 |
21 |
22 | class AggregatesTest(absltest.TestCase):
23 |
24 | def test_callable_combinefn_in_process(self):
25 | sum_fn = base.UserAggregateFn(test_utils._SumMetric().as_agg_fn())
26 | self.assertEqual(sum_fn(list(range(4))), 6)
27 |
28 | def test_mergeable_aggregate_fn_in_process(self):
29 | sum_fn = base.as_agg_fn(test_utils._SumMetric)
30 | self.assertEqual(6, sum_fn([1, 2, 3]))
31 |
32 | def test_mergeable_aggregate_fn_from_resolvable(self):
33 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)()
34 | sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum)
35 | self.assertEqual(6, sum_fn([1, 2, 3]))
36 |
37 | def test_mergeable_aggregate_fn_from_makeable(self):
38 |
39 | class MakeableSum(test_utils._SumMetric):
40 |
41 | def make(self):
42 | return self
43 |
44 | sum_fn = base.MergeableMetricAggFn(MakeableSum())
45 | self.assertEqual(6, sum_fn([1, 2, 3]))
46 |
47 | def test_mergeable_aggregate_fn_eq(self):
48 | sum_fn = base.MergeableMetricAggFn(lazy_fns.trace(test_utils._SumMetric)())
49 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)()
50 | other_sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum)
51 | self.assertEqual(sum_fn, other_sum_fn)
52 | makeable_deferred_sum = lazy_fns.trace(test_utils._SumMetric)(len)
53 | other_sum_fn = base.MergeableMetricAggFn(makeable_deferred_sum)
54 | self.assertNotEqual(sum_fn, other_sum_fn)
55 |
56 | def test_mergeable_aggregate_fn_unsupported_type(self):
57 | with self.assertRaisesRegex(TypeError, 'must be an instance of.+ got'):
58 | # disable pytype check for the runtime error to surface.
59 | _ = base.MergeableMetricAggFn(test_utils._SumMetric()) # pytype: disable=wrong-arg-types
60 |
61 | def test_metric_callable(self):
62 | sum_fn = test_utils._SumMetric()
63 | self.assertEqual(6, sum_fn([1, 2, 3]))
64 |
65 | def test_has_as_agg_fn(self):
66 | sum_aggfn = test_utils._SumMetric().as_agg_fn()
67 | self.assertFalse(base.has_as_agg_fn(sum_aggfn))
68 | self.assertIsInstance(sum_aggfn, base.Aggregatable)
69 | sum_metric = test_utils._SumMetric()
70 | self.assertTrue(base.has_as_agg_fn(sum_metric))
71 |
72 |
73 | if __name__ == '__main__':
74 | absltest.main()
75 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/keras_metric_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Keras metric wrapper."""
15 | from __future__ import annotations
16 |
17 | from collections.abc import Iterable
18 | import dataclasses
19 | from typing import Any, Protocol
20 |
21 | from ml_metrics._src.aggregates import base as agg
22 |
23 |
24 | class KerasMetric(Protocol):
25 | """Base interface for Keras metrics."""
26 |
27 | def update_state(self, *inputs, **named_inputs) -> None:
28 | """Updates the state from a batch of inputs."""
29 |
30 | def reset_state(self) -> None:
31 | """Resets the state."""
32 |
33 | def merge_state(self, other: Iterable[KerasMetric]) -> None:
34 | """Merges the state with another state of the same type."""
35 |
36 | def result(self) -> Any:
37 | """Returns the result of the metric."""
38 |
39 |
40 | def is_keras_metric(metric: Any) -> bool:
41 | """Duck type check for Keras metric."""
42 | return (
43 | hasattr(metric, "update_state")
44 | and hasattr(metric, "reset_state")
45 | and hasattr(metric, "merge_state")
46 | and hasattr(metric, "result")
47 | )
48 |
49 |
50 | @dataclasses.dataclass
51 | class KerasAggregateFn(agg.AggregateFn):
52 | """AggregateFn for Keras metrics."""
53 |
54 | metric: KerasMetric
55 |
56 | def __post_init__(self):
57 | if is_keras_metric(self.metric):
58 | self.metric.reset_state()
59 | self._metric = self.metric
60 | else:
61 | try:
62 | assert hasattr(self.metric, "__call__")
63 | self._metric = self.metric()
64 | if not is_keras_metric(self._metric):
65 | raise TypeError("metric must implement Keras metric base interface.")
66 | except Exception as e:
67 | raise TypeError(
68 | f"Cannot construct a Keras metric from {self.metric}."
69 | ) from e
70 |
71 | def create_state(self) -> KerasMetric:
72 | assert hasattr(self._metric, "reset_state")
73 | self._metric.reset_state()
74 | return self._metric
75 |
76 | def update_state(
77 | self, state: KerasMetric, *inputs: Any, **named_inputs: Any
78 | ) -> KerasMetric:
79 | state.update_state(*inputs, **named_inputs)
80 | return state
81 |
82 | def merge_states(self, states: Iterable[KerasMetric]) -> KerasMetric:
83 | # This in-place merges all the states into the first state and returns it.
84 | iter_states = iter(states)
85 | result = next(iter_states)
86 | result.merge_state(list(iter_states))
87 | return result
88 |
89 | def get_result(self, state: KerasMetric) -> Any:
90 | result = state.result()
91 | if hasattr(result, "numpy"):
92 | return result.numpy()
93 | return result
94 |
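95 | # Illustrative usage (a sketch based on keras_metric_wrapper_test.py; assumes
96 | # TensorFlow is installed): any object exposing update_state/reset_state/
97 | # merge_state/result duck-types as a KerasMetric and can be wrapped directly:
98 | #
99 | #   agg_fn = KerasAggregateFn(tf.keras.metrics.Mean(name="mean"))
100 | #   agg_fn([1, 2, 3])  # -> 2.0, the mean over the batch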
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/keras_metric_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrapper for Keras metrics."""
15 |
16 | import importlib
17 |
18 | from ml_metrics._src.aggregates import keras_metric_wrapper
19 |
20 | from absl.testing import absltest
21 |
22 |
23 | class MockKerasMetric:
24 |
25 | def __init__(self):
26 | self.reset_state()
27 |
28 | def reset_state(self):
29 | self._state = 0
30 |
31 | def update_state(self, inputs):
32 | self._state += sum(inputs)
33 |
34 | def merge_state(self, states):
35 | for state in states:
36 | self._state += state
37 |
38 | def result(self):
39 | return self._state
40 |
41 |
42 | class KerasTest(absltest.TestCase):
43 |
44 | def test_mock_keras_metric(self):
45 | metric = keras_metric_wrapper.KerasAggregateFn(MockKerasMetric())
46 | self.assertEqual(6, metric([1, 2, 3]))
47 |
48 | def test_keras_metric_wrapper_merge(self):
49 | try:
50 | tf = importlib.import_module("tensorflow")
51 | except ImportError:
52 | # Ignores the import error if tensorflow is not installed.
53 | return
54 | metric1 = keras_metric_wrapper.KerasAggregateFn(
55 | tf.keras.metrics.Mean(name="mean")
56 | )
57 | metric2 = keras_metric_wrapper.KerasAggregateFn(
58 | tf.keras.metrics.Mean(name="mean")
59 | )
60 | state1, state2 = metric1.create_state(), metric2.create_state()
61 | merged_state = metric1.merge_states([
62 | metric1.update_state(state1, [1, 2, 3]),
63 | metric2.update_state(state2, [4, 5, 6]),
64 | ])
65 | self.assertEqual(3.5, metric1.get_result(merged_state))
66 |
67 | def test_keras_metric_wrapper(self):
68 | try:
69 | tf = importlib.import_module("tensorflow")
70 | except ImportError:
71 | # Ignores the import error if tensorflow is not installed.
72 | return
73 | metric = keras_metric_wrapper.KerasAggregateFn(
74 | tf.keras.metrics.Mean(name="mean")
75 | )
76 | self.assertEqual(2, metric([1, 2, 3]))
77 |
78 |
79 | if __name__ == "__main__":
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """test_utils."""
15 | from ml_metrics._src.aggregates import base
16 |
17 |
18 | class _SumMetric(base.CallableMetric, base.HasAsAggFn):
19 | """Mock Metric for test."""
20 |
21 | def __init__(self, state=0):
22 | self._state = state
23 |
24 | def as_agg_fn(self):
25 | return base.as_agg_fn(self.__class__)
26 |
27 | @property
28 | def state(self):
29 | return self._state
30 |
31 | def new(self, x):
32 | return _SumMetric(state=sum(x))
33 |
34 | def merge(self, other):
35 | self._state += other.state
36 |
37 | def result(self):
38 | return self._state
39 |
40 |
41 | class _SumAggFn:
42 | """Mock CombineFn for test."""
43 |
44 | def create_state(self):
45 | return 0
46 |
47 | def update_state(self, state, x):
48 | return state + sum(x)
49 |
50 | def merge_states(self, states):
51 | return sum(states)
52 |
53 | def get_result(self, state):
54 | return state
55 |
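56 | # Illustrative usage (sketch): _SumAggFn duck-types base.Aggregatable, so it can
57 | # be wrapped into an aggregate function directly:
58 | #
59 | #   agg = base.UserAggregateFn(_SumAggFn())
60 | #   agg([1, 2, 3])  # -> 6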
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Individual text stats metrics."""
15 |
16 | import collections
17 | from collections.abc import Sequence
18 | import dataclasses
19 | import re
20 |
21 | from ml_metrics._src.aggregates import base
22 | from ml_metrics._src.aggregates import utils
23 | from ml_metrics._src.tools.telemetry import telemetry
24 |
25 | FrequencyState = utils.FrequencyState
26 |
27 |
28 | @dataclasses.dataclass(kw_only=True)
29 | class TopKWordNGrams(base.MergeableMetric, base.HasAsAggFn):
30 | """Top k word n-grams metrics.
31 |
32 | Identifies the top `k` most frequently occurring word n-grams, using a
33 | case-insensitive approach. Each text is first cleaned by removing all
34 | characters other than letters and spaces, then lowercased before computing
35 | the top k n-grams. When multiple n-grams share the same frequency,
36 | alphabetical order is used as a tie-breaker. The result is a list of tuples
37 | containing the n-gram pattern and its corresponding frequency. The list has
38 | at most `k` entries, or the number of distinct n-grams, whichever is less.
39 |
40 |
41 | Attributes:
42 | k:
43 | Number of most frequent word n-grams.
44 | n:
45 | Number of grams.
46 | use_first_ngram_only:
47 | If `True`, only the first n words of each text will be used to form the
48 | n-grams and `count_duplicate` will be ignored. Otherwise, all words
49 | present in each text will be considered for generating the n-grams.
50 | Defaults to `False`.
51 | count_duplicate:
52 | If `True`, duplicate n-grams within a text are included in the total
53 | count. Otherwise, only the first occurrence of each unique n-gram is
54 | counted. Defaults to `True`.
55 | """
56 |
57 | k: int
58 | n: int
59 | use_first_ngram_only: bool = False
60 | count_duplicate: bool = True
61 | _state: FrequencyState = dataclasses.field(
62 | default_factory=FrequencyState, init=False
63 | )
64 |
65 | def __post_init__(self):
66 | telemetry.increment_counter(
67 | api='ml_metrics', category='metric', reference=self.__class__.__name__
68 | )
69 | if self.k <= 0 or self.n <= 0:
70 | raise ValueError(
71 | f'k and n must be positive integers but k={self.k} and n={self.n} was'
72 | ' passed.'
73 | )
74 |
75 | def as_agg_fn(self) -> base.AggregateFn:
76 | return base.as_agg_fn(
77 | self.__class__,
78 | k=self.k,
79 | n=self.n,
80 | use_first_ngram_only=self.use_first_ngram_only,
81 | count_duplicate=self.count_duplicate,
82 | )
83 |
84 | @property
85 | def state(self) -> FrequencyState:
86 | return self._state
87 |
88 | def add(self, texts: Sequence[str]) -> list[tuple[str, float]]:
89 | ngrams_counter = collections.Counter()
90 | for text in texts:
91 | # Remove non-alphabetical and non-space characters
92 | words = re.sub(r'[^a-zA-Z ]+', '', text).lower().split()
93 | if self.n <= len(words):
94 | ngrams = []
95 | if self.use_first_ngram_only:
96 | ngrams.append(' '.join(words[:self.n]))
97 | else:
98 | for idx in range(len(words) - self.n + 1):
99 | ngrams.append(' '.join(words[idx : idx + self.n]))
100 | if not self.count_duplicate:
101 | ngrams = set(ngrams)
102 | ngrams_counter.update(ngrams)
103 | batch_result = FrequencyState(counter=ngrams_counter, count=len(texts))
104 | self._state.merge(batch_result)
105 | 
106 | result = batch_result.result()[:self.k]
107 | return result
108 |
109 | def merge(self, other: 'TopKWordNGrams'):
110 | # TODO(b/331796958): Optimize storage consumption
111 | self._state.merge(other.state)
112 |
113 | def result(self) -> list[tuple[str, float]]:
114 | return self._state.result()[:self.k]
115 |
116 |
117 | @dataclasses.dataclass(kw_only=True)
118 | class PatternFrequency(base.MergeableMetric, base.HasAsAggFn):
119 | """Pattern frequency metric.
120 |
121 | Identify the frequency of occurrence for each pattern found within the given
122 | texts.
123 |
124 |
125 | Attributes:
126 | patterns:
127 | Sequence of text patterns.
128 | count_duplicate:
129 | If `True`, duplicate occurrences of a pattern within a text are included
130 | in the total count. Otherwise, only the first occurrence of a pattern is
131 | counted. Defaults to `True`.
132 | """
133 |
134 | patterns: Sequence[str]
135 | count_duplicate: bool = True
136 | _state: FrequencyState = dataclasses.field(
137 | default_factory=FrequencyState, init=False
138 | )
139 |
140 | def __post_init__(self):
141 | telemetry.increment_counter(
142 | api='ml_metrics', category='metric', reference=self.__class__.__name__
143 | )
144 | if not self.patterns:
145 | raise ValueError('Patterns must not be empty.')
146 |
147 | if len(set(self.patterns)) != len(self.patterns):
148 | raise ValueError(f'Patterns must be unique: {self.patterns}')
149 |
150 | def as_agg_fn(self) -> base.AggregateFn:
151 | return base.as_agg_fn(
152 | self.__class__,
153 | patterns=self.patterns,
154 | count_duplicate=self.count_duplicate,
155 | )
156 |
157 | @property
158 | def state(self) -> FrequencyState:
159 | return self._state
160 |
161 | def add(self, texts: Sequence[str]) -> list[tuple[str, float]]:
162 | batch_frequency_state = FrequencyState()
163 | for pattern in self.patterns:
164 | for text in texts:
165 | num_matches = 0
166 | if self.count_duplicate:
167 | matches = list(
168 | re.finditer(r'(?=({}))'.format(re.escape(pattern)), text)
169 | )
170 | num_matches = len(matches)
171 | elif text.find(pattern) >= 0:
172 | num_matches = 1
173 | batch_frequency_state.counter[pattern] += num_matches
174 | batch_frequency_state.count = len(texts)
175 | self._state.merge(batch_frequency_state)
176 | return batch_frequency_state.result()
177 |
178 | def merge(self, other: 'PatternFrequency'):
179 | self._state.merge(other.state)
180 |
181 | def result(self) -> list[tuple[str, float]]:
182 | return self._state.result()
183 |
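184 | # Illustrative usage (a sketch mirroring text_test.py):
185 | #
186 | #   metric = TopKWordNGrams(k=2, n=2)
187 | #   metric.add(['a b', 'a b', 'c d'])        # -> [('a b', 2/3), ('c d', 1/3)]
188 | #
189 | #   freq = PatternFrequency(patterns=['ab', 'xyx'], count_duplicate=False)
190 | #   sorted(freq.add(['ab ab xyx', 'xyxyx']))  # -> [('ab', 0.5), ('xyx', 1.0)]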
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/text_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for text."""
15 |
16 | from absl.testing import parameterized
17 | from ml_metrics._src.aggregates import text
18 | from absl.testing import absltest
19 |
20 |
21 | class TopKWordNGramsTest(parameterized.TestCase):
22 | """Tests for TopKWordNGrams."""
23 |
24 | @parameterized.named_parameters([
25 | dict(
26 | testcase_name='distinct_ngrams',
27 | k=2,
28 | n=2,
29 | use_first_ngram_only=False,
30 | count_duplicate=False,
31 | expected_result=[
32 | ('a a', 1 / 3),
33 | ('b b', 1 / 3),
34 | # This demonstrates that alphabetical order is used as a
35 | # tie-breaker in the frequency state.
36 | # ('c c', 1 / 3), # Included when k >= 3
37 | # ('d a', 1 / 3), # Included when k >= 4
38 | ],
39 | ),
40 | dict(
41 | testcase_name='count_duplicate',
42 | k=2,
43 | n=2,
44 | use_first_ngram_only=False,
45 | count_duplicate=True,
46 | expected_result=[
47 | ('b b', 2 / 3),
48 | ('a a', 1 / 3),
49 | # This demonstrates that alphabetical order is used as a
50 | # tie-breaker in the frequency state.
51 | # ('c c', 1 / 3), # Included when k >= 3
52 | # ('d a', 1 / 3), # Included when k >= 4
53 | ],
54 | ),
55 | dict(
56 | testcase_name='use_first_ngram_only',
57 | k=2,
58 | n=2,
59 | use_first_ngram_only=True,
60 | # count_duplicate will be ignored. See TopKWordNGrams description.
61 | count_duplicate=True,
62 | expected_result=[
63 | ('b b', 1 / 3),
64 | ('c c', 1 / 3),
65 | # This demonstrates that alphabetical order is used as a
66 | # tie-breaker in the frequency state.
67 | # ('d a', 1 / 3), # Included when k >= 3
68 | ],
69 | ),
70 | dict(
71 | testcase_name='large_k',
72 | k=10,
73 | n=2,
74 | use_first_ngram_only=False,
75 | count_duplicate=True,
76 | expected_result=[
77 | ('b b', 2 / 3),
78 | ('a a', 1 / 3),
79 | ('c c', 1 / 3),
80 | ('d a', 1 / 3),
81 | ],
82 | ),
83 | dict(
84 | testcase_name='large_n',
85 | k=4,
86 | n=10,
87 | use_first_ngram_only=False,
88 | count_duplicate=True,
89 | expected_result=[],
90 | ),
91 | dict(
92 | testcase_name='2k_1n',
93 | k=2,
94 | n=1,
95 | use_first_ngram_only=False,
96 | count_duplicate=True,
97 | expected_result=[
98 | ('b', 3 / 3),
99 | ('a', 2 / 3),
100 | # This demonstrates that alphabetical order is used as a
101 | # tie-breaker in the frequency state.
102 | # ('c', 2 / 3), # Included when k >= 3
103 | # ('d', 1 / 3), # Included when k >= 4
104 | ],
105 | ),
106 | dict(
107 | testcase_name='2k_3n',
108 | k=2,
109 | n=3,
110 | use_first_ngram_only=False,
111 | count_duplicate=True,
112 | expected_result=[
113 | ('b b b', 1 / 3),
114 | ('d a a', 1 / 3),
115 | ],
116 | ),
117 | ])
118 | def test_compute_topkwordngrams(
119 | self, k, n, use_first_ngram_only, count_duplicate, expected_result
120 | ):
121 | batch = [
122 | 'c c',
123 | 'b B b', # Case-insensitive
124 | 'd a a',
125 | ]
126 | metric = text.TopKWordNGrams(
127 | k=k,
128 | n=n,
129 | use_first_ngram_only=use_first_ngram_only,
130 | count_duplicate=count_duplicate,
131 | )
132 | batch_result = metric.add(batch)
133 | self.assertSequenceAlmostEqual(expected_result, batch_result)
134 | self.assertSequenceAlmostEqual(expected_result, metric.result())
135 |
136 | def test_compute_topkwordngrams_empty(self):
137 | metric = text.TopKWordNGrams(k=2, n=2)
138 | batch_result = metric.add([])
139 | self.assertSequenceEqual([], batch_result)
140 |
141 | def test_topkwordngrams_add(self):
142 | metric = text.TopKWordNGrams(k=2, n=2)
143 |
144 | batch_0 = [
145 | 'a b',
146 | 'a b',
147 | 'c d',
148 | ]
149 | batch_1 = [
150 | 'a b',
151 | 'b c',
152 | 'b c',
153 | ]
154 |
155 | batch_0_result = metric.add(batch_0)
156 | expected_batch_0_result = [
157 | ('a b', 2 / 3),
158 | ('c d', 1 / 3),
159 | ]
160 | self.assertSequenceAlmostEqual(expected_batch_0_result, batch_0_result)
161 |
162 | batch_1_result = metric.add(batch_1)
163 | expected_batch_1_result = [
164 | ('b c', 2 / 3),
165 | ('a b', 1 / 3),
166 | ]
167 | self.assertSequenceAlmostEqual(expected_batch_1_result, batch_1_result)
168 |
169 | expected_metric_result = [
170 | ('a b', 3 / 6),
171 | ('b c', 2 / 6),
172 | ]
173 | self.assertSequenceAlmostEqual(expected_metric_result, metric.result())
174 |
175 | def test_topkwordngrams_merge(self):
176 | metric_0 = text.TopKWordNGrams(k=2, n=2)
177 | metric_1 = text.TopKWordNGrams(k=2, n=2)
178 | batch_0 = [
179 | 'a b',
180 | 'c d',
181 | ]
182 | batch_1 = [
183 | 'b c',
184 | 'b c',
185 | ]
186 | metric_0.add(batch_0)
187 | metric_1.add(batch_1)
188 |
189 | metric_0.merge(metric_1)
190 |
191 | expected_result = [
192 | ('b c', 2 / 4),
193 | ('a b', 1 / 4),
194 | ]
195 | merged_result = metric_0.result()
196 | self.assertSequenceAlmostEqual(expected_result, merged_result)
197 |
198 | def test_topkwordngrams_invalid_kn(self):
199 | with self.assertRaisesRegex(
200 | ValueError, 'k and n must be positive integers.'
201 | ):
202 | text.TopKWordNGrams(k=1, n=0)
203 |
204 | with self.assertRaisesRegex(
205 | ValueError, 'k and n must be positive integers.'
206 | ):
207 | text.TopKWordNGrams(k=0, n=1)
208 |
209 |
210 | class PatternFrequencyTest(parameterized.TestCase):
211 | """Tests for PatternFrequency."""
212 |
213 | @parameterized.named_parameters([
214 | dict(
215 | testcase_name='not_count_duplicate',
216 | count_duplicate=False,
217 | batch=['ab ab xyx', 'xyxyx'],
218 | expected_result=[
219 | ('ab', 1 / 2),
220 | ('mmm', 0),
221 | ('xyx', 2 / 2),
222 | ],
223 | ),
224 | dict(
225 | testcase_name='count_duplicate',
226 | count_duplicate=True,
227 | batch=['ab ab xyx', 'xyxyx'],
228 | expected_result=[
229 | ('ab', 2 / 2),
230 | ('mmm', 0),
231 | ('xyx', 3 / 2),
232 | ],
233 | ),
234 | ])
235 | def test_compute_pattern_frequency(
236 | self, batch, count_duplicate, expected_result
237 | ):
238 | metric = text.PatternFrequency(
239 | patterns=['ab', 'xyx', 'mmm'], count_duplicate=count_duplicate
240 | )
241 | batch_result = sorted(metric.add(batch))
242 | self.assertSequenceAlmostEqual(expected_result, batch_result)
243 | self.assertSequenceAlmostEqual(expected_result, sorted(metric.result()))
244 |
245 | def test_compute_pattern_frequency_empty(self):
246 | metric = text.PatternFrequency(
247 | patterns=['ab', 'xyx'], count_duplicate=False
248 | )
249 | batch_result = metric.add([])
250 | self.assertSequenceEqual([], batch_result)
251 |
252 | def test_pattern_frequency_add(self):
253 | metric = text.PatternFrequency(
254 | patterns=['ab', 'xyx', 'mmm'], count_duplicate=False
255 | )
256 |
257 | batch_0 = ['ab ab xyx', 'xyx']
258 | batch_1 = ['mmm']
259 |
260 | batch_0_result = sorted(metric.add(batch_0))
261 | expected_batch_0_result = [('ab', 1 / 2), ('mmm', 0), ('xyx', 2 / 2)]
262 | self.assertSequenceAlmostEqual(expected_batch_0_result, batch_0_result)
263 |
264 | batch_1_result = sorted(metric.add(batch_1))
265 | expected_batch_1_result = [('ab', 0), ('mmm', 1), ('xyx', 0)]
266 | self.assertSequenceAlmostEqual(expected_batch_1_result, batch_1_result)
267 |
268 | expected_metric_result = [('ab', 1 / 3), ('mmm', 1 / 3), ('xyx', 2 / 3)]
269 | self.assertSequenceAlmostEqual(
270 | expected_metric_result, sorted(metric.result())
271 | )
272 |
273 | def test_pattern_frequency_merge(self):
274 | metric_0 = text.PatternFrequency(patterns=['ab'], count_duplicate=False)
275 | metric_1 = text.PatternFrequency(patterns=['ab'], count_duplicate=False)
276 | batch_0 = ['ab ab xyx']
277 | batch_1 = ['ab']
278 | metric_0.add(batch_0)
279 | metric_1.add(batch_1)
280 |
281 | metric_0.merge(metric_1)
282 |
283 | expected_result = [('ab', 2 / 2)]
284 | merged_result = sorted(metric_0.result())
285 | self.assertSequenceAlmostEqual(expected_result, merged_result)
286 |
287 | @parameterized.named_parameters([
288 | dict(
289 | testcase_name='empty_patterns',
290 | patterns=[],
291 | exception_regex='Patterns must not be empty.',
292 | ),
293 | dict(
294 | testcase_name='invalid_patterns',
295 | patterns=['a', 'a', 'b'],
296 | exception_regex='Patterns must be unique',
297 | ),
298 | ])
299 | def test_pattern_frequency_invalid_patterns(self, patterns, exception_regex):
300 | with self.assertRaisesRegex(ValueError, exception_regex):
301 | text.PatternFrequency(patterns=patterns)
302 |
303 |
304 | if __name__ == '__main__':
305 | absltest.main()
306 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common types for the aggregates."""
15 |
16 | import enum
17 | from numpy import typing as npt
18 |
19 | NumbersT = npt.ArrayLike
20 | DefaultDType = float
21 |
22 |
23 | class InputType(enum.StrEnum): # pylint: disable=invalid-enum-extension
24 | """Label prediction encoding types."""
25 |
26 | # 1D array per batch, e.g., [0,1,0,1,0], [-1, 1, -1], or ['Y', 'N']
27 | BINARY = 'binary'
28 | # 1D array of floats typically is the probability for the binary
29 | # classification problem, e.g., [0.2, 0.3, 0.9]
30 | CONTINUOUS = 'continuous'
31 | # 2D array of the floats for the multilabel/multiclass classification problem.
32 | # Dimension: BatchSize x # Class
33 | # e.g., [[0.2, 0.8, 0.9], [0.1, 0.2, 0.7]].
34 | CONTINUOUS_MULTIOUTPUT = 'continuous-multioutput'
35 | # 1D array of class identifiers, e.g., ['a', 'b'] or [1, 29, 12].
36 | MULTICLASS = 'multiclass'
37 | # 2D lists of multiclass encodings of the classes, e.g., [[1,2,0], [3,2,0]]
38 | # The list can be ragged, e.g., [ ['a', 'b'], ['c'] ]
39 | MULTICLASS_MULTIOUTPUT = 'multiclass-multioutput'
40 | # 2D array of one-hot encoding of the classes, e.g., [[0,1,0], [0,0,1]]
41 | # This is a special case for "multilabel-indicator" except that only one
42 | # class is set to positive per example.
43 | MULTICLASS_INDICATOR = 'multiclass-indicator'
44 |
45 |
46 | class AverageType(enum.StrEnum): # pylint: disable=invalid-enum-extension
47 | """Average type of the confusion matrix."""
48 |
49 | # Treats each class as one example and calculates the metrics on the total
50 | # aggregates of the result.
51 | MICRO = 'micro'
52 | # Macro calculates metrics for each class first, then average them across
53 | # classes.
54 | MACRO = 'macro'
55 | # Macro average with explicit weights per class.
56 | WEIGHTED = 'weighted'
57 | # Samples average calculates the metrics per example, and average them across
58 | # all examples.
59 | SAMPLES = 'samples'
60 | # Average for the positive label only.
61 | BINARY = 'binary'
62 |
--------------------------------------------------------------------------------
/ml_metrics/_src/aggregates/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for aggregates."""
15 | from __future__ import annotations
16 |
17 | import collections
18 | import dataclasses
19 |
20 | from ml_metrics._src.aggregates import base
21 | from ml_metrics._src.aggregates import types
22 | from ml_metrics._src.utils import math_utils
23 |
24 |
25 | @dataclasses.dataclass
26 | class MeanState(base.CallableMetric):
27 | """Mergeable states for batch update in an aggregate function."""
28 |
29 | total: types.NumbersT = 0.0
30 | count: types.NumbersT = 0
31 |
32 | def new(self, inputs: types.NumbersT) -> MeanState:
33 | return MeanState(total=sum(inputs), count=len(inputs))
34 |
35 | def merge(self, other: MeanState):
36 | self.total += other.total
37 | self.count += other.count
38 |
39 | def result(self):
40 | return math_utils.safe_divide(self.total, self.count)
41 |
42 |
43 | @dataclasses.dataclass
44 | class TupleMeanState(base.CallableMetric):
45 | """MeanState for a tuple of inputs."""
46 |
47 | states: tuple[MeanState, ...] = ()
48 |
49 | def new(self, *inputs: tuple[types.NumbersT, ...]) -> TupleMeanState:
50 | return TupleMeanState(tuple(MeanState().new(x) for x in inputs))
51 |
52 | def merge(self, other: TupleMeanState):
53 | if not self.states:
54 | self.states = tuple(MeanState() for _ in other.states)
55 | for state, state_other in zip(self.states, other.states, strict=True):
56 | state.merge(state_other)
57 |
58 | def result(self):
59 | return tuple(state.result() for state in self.states)
60 |
61 |
62 | @dataclasses.dataclass
63 | class FrequencyState:
64 | """Mergeable frequency states for batch update in an aggregate function."""
65 |
66 | # TODO(b/331796958): Optimize storage consumption
67 | counter: collections.Counter[str] = dataclasses.field(
68 | default_factory=collections.Counter
69 | )
70 | count: int = 0
71 |
72 | def merge(self, other: 'FrequencyState'):
73 | self.counter.update(other.counter)
74 | self.count += other.count
75 |
76 | def result(self) -> list[tuple[str, float]]:
77 | result = [
78 | (key, math_utils.safe_divide(value, self.count))
79 | for key, value in self.counter.items()
80 | ]
81 | result = sorted(result, key=lambda x: (-x[1], x[0]))
82 | return result
83 |
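84 | # Illustrative usage (sketch): MeanState accumulates a running mean across
85 | # batches; `add` computes the batch statistics and merges them into the state:
86 | #
87 | #   state = MeanState()
88 | #   state.add([1, 2, 3])   # total=6, count=3
89 | #   state.add([5])         # total=11, count=4
90 | #   state.result()         # -> 2.75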
--------------------------------------------------------------------------------
/ml_metrics/_src/chainables/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/ml_metrics/_src/chainables/io.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """I/O utilities for chainables."""
15 | from __future__ import annotations
16 |
17 | from collections.abc import Iterable, Iterator
18 | import dataclasses as dc
19 | from typing import Any, Self, TypeVar
20 |
21 | from ml_metrics._src import types
22 | from ml_metrics._src.utils import iter_utils
23 |
24 | _T = TypeVar('_T')
25 |
26 |
27 | @dc.dataclass(frozen=True, slots=True)
28 | class ShardConfig:
29 | shard_index: int = 0
30 | num_shards: int = 1
31 | start_index: int = 0
32 | parent: ShardConfig | None = dc.field(default=None, kw_only=True)
33 |
34 |
35 | @dc.dataclass(frozen=True)
36 | class SequenceDataSource(types.Recoverable, Iterable[_T]):
37 | """A shardable sequence data source."""
38 | data: types.RandomAccessible[_T]
39 | ignore_error: bool = dc.field(kw_only=True, default=False)
40 | batch_size: dc.InitVar[int] = 0
41 | _shard_state: ShardConfig = dc.field(default_factory=ShardConfig)
42 | _start: int = 0
43 | _end: int | None = None
44 |
45 | def __post_init__(self, batch_size: int):
46 | data = self.data
47 | if not hasattr(data, '__getitem__') or not hasattr(data, '__len__'):
48 | raise TypeError(f'data is not indexable, got {type(data)=}')
49 | # Use MergedSequences even for a single sequence to enforce iterating by
50 | # random access so that the iterator is continuable after exception.
51 | sequences = [self.data]
52 | if isinstance(self.data, iter_utils.MergedSequences):
53 | sequences = self.data.sequences
54 | data = iter_utils.MergedSequences(sequences, batch_size)
55 | object.__setattr__(self, 'data', data)
56 |
57 | @classmethod
58 | def from_sequences(
59 | cls,
60 | sequences: Iterable[types.RandomAccessible[_T]],
61 | batch_size: int = 0,
62 | ignore_error: bool = False,
63 | ) -> Self:
64 | return cls(
65 | iter_utils.MergedSequences(sequences),
66 | ignore_error=ignore_error,
67 | batch_size=batch_size,
68 | )
69 |
70 | def shard(self, shard_index: int, num_shards: int, offset: int = 0) -> Self:
71 | if num_shards < 1:
72 | raise ValueError(f'num_shards must be positive, got {num_shards=}')
73 | interval, remainder = divmod(self.end - self.start, num_shards)
74 | start, adjusted_interval = self.start, 0
75 | for i in range(shard_index + 1):
76 | adjusted_interval = interval + 1 if i < remainder else interval
77 | start += adjusted_interval if i < shard_index else 0
78 | shard_state = ShardConfig(
79 | shard_index, num_shards, offset, parent=self._shard_state
80 | )
81 | return dc.replace(
82 | self,
83 | _shard_state=shard_state,
84 | _start=start + offset,
85 | _end=start + adjusted_interval,
86 | )
87 |
88 | @property
89 | def start(self) -> int:
90 | return self._start
91 |
92 | @property
93 | def end(self) -> int:
94 | return len(self.data) if self._end is None else self._end
95 |
96 | def __len__(self) -> int:
97 | return self.end - self.start
98 |
99 | @property
100 | def state(self) -> ShardConfig:
101 | return self._shard_state
102 |
103 | def from_state(self, shard_state: ShardConfig) -> Self:
104 | """Iterates the data source given a shard index."""
105 | if shard_state.parent is not None:
106 | result = self.from_state(shard_state.parent)
107 | else:
108 | result = SequenceDataSource(self.data, ignore_error=self.ignore_error)
109 | return result.shard(
110 | shard_state.shard_index, shard_state.num_shards, shard_state.start_index
111 | )
112 |
113 | def iterate(self) -> SequenceIterator[_T]:
114 | return SequenceIterator(self)
115 |
116 | def __iter__(self) -> Iterator[_T]:
117 | return self.iterate()
118 |
119 | def __getitem__(self, index: int | Any) -> _T:
120 | """Iterates the data source given a shard index."""
121 | if isinstance(index, slice):
122 | start, stop, step = index.start, index.stop, index.step
123 | start = start or 0 # Convert None to 0.
124 | start += self.start
125 | start = min(start, self.end)
126 | if stop is None:
127 | stop = self.end
128 | else:
129 | stop = self.end + stop if stop < 0 else self.start + stop
130 | stop = min(stop, self.end)
131 | return list(self.data[slice(start, stop, step)])
132 |
133 | index = self.end + index if index < 0 else self.start + index
134 | return self.data[index]
135 |
136 |
137 | class SequenceIterator(types.Recoverable, Iterator[_T]):
138 | """A sharded data source for chainables."""
139 |
140 | config: SequenceDataSource
141 | _index: int
142 |
143 | def __init__(self, config: SequenceDataSource):
144 | self._index = config.start
145 | iter_ = iter_utils.iter_ignore_error if config.ignore_error else iter
146 | self._it = iter_(config.data[config.start : config.end])
147 | self.config = config
148 |
149 | def from_state(self, shard_state: ShardConfig) -> Self:
150 | return self.__class__(self.config.from_state(shard_state))
151 |
152 | @property
153 | def state(self) -> ShardConfig:
154 | start_index = self._index - self.config.start
155 | return dc.replace(self.config.state, start_index=start_index)
156 |
157 | def __next__(self) -> _T:
158 | """Iterates the data source given a shard index."""
159 | result = next(self._it)
160 | self._index += 1
161 | return result
162 |
163 | def __iter__(self) -> Self:
164 |     """Returns self as the iterator."""
165 | return self
166 |
167 |
168 | @dc.dataclass(frozen=True)
169 | class ShardedIterable(types.Recoverable, Iterable[_T]):
170 | """A sharded data source for any iterable."""
171 | data: Iterable[_T]
172 | _shard_state: ShardConfig = dc.field(default_factory=ShardConfig)
173 |
174 | def __post_init__(self):
175 | data = self.data
176 | if not (isinstance(data, Iterable) and not isinstance(data, Iterator)):
177 | raise TypeError(
178 | f'input has to be an iterable but not an iterator, got {type(data)=}'
179 | )
180 | if self._shard_state.num_shards < 1:
181 | raise ValueError(f'num_shards must be positive, got {self._shard_state=}')
182 |
183 | def shard(self, shard_index: int, num_shards: int) -> Self:
184 | return dc.replace(self, _shard_state=ShardConfig(shard_index, num_shards))
185 |
186 | @property
187 | def state(self) -> ShardConfig:
188 | return self._shard_state
189 |
190 | def from_state(self, shard_state: ShardConfig) -> Self:
191 | return dc.replace(self, _shard_state=shard_state)
192 |
193 | def iterate(self) -> DataIterator[_T]:
194 | return DataIterator(self)
195 |
196 | def __iter__(self) -> Iterator[_T]:
197 | return self.iterate()
198 |
199 |
200 | class DataIterator(types.Recoverable, Iterator[_T]):
201 |   """A sharded iterator over an iterable."""
202 |
203 | def __init__(self, config: ShardedIterable):
204 | self.config = config
205 | self._index = 0
206 | self._it = iter(config.data)
207 |
208 | def from_state(self, shard_state: ShardConfig) -> Self:
209 | return self.__class__(self.config.from_state(shard_state))
210 |
211 | @property
212 | def state(self) -> ShardConfig:
213 | return dc.replace(self.config.state, start_index=self._index)
214 |
215 | def __next__(self) -> _T:
216 |     """Returns the next element belonging to this shard."""
217 | while self._index < self.config.state.start_index:
218 | _ = next(self._it)
219 | self._index += 1
220 | shard_index = self.config.state.shard_index
221 | num_shards = self.config.state.num_shards
222 | while self._index % num_shards != shard_index:
223 | _ = next(self._it)
224 | self._index += 1
225 | result = next(self._it)
226 | self._index += 1
227 | return result
228 |
229 | def __iter__(self) -> Self:
230 | return self
231 |
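# --- Editor-added illustrative sketch (not part of the original module) ---
# A minimal usage sketch of the two data sources above, with values mirroring
# io_test.py: SequenceDataSource shards a random-accessible sequence into
# contiguous ranges and can resume from an iterator's state, while
# ShardedIterable round-robins elements of a plain iterable.
#
#   ds = SequenceDataSource(range(10))
#   list(ds.shard(0, 2))                      # -> [0, 1, 2, 3, 4]
#   it = ds.shard(1, 2).iterate()
#   next(it)                                  # -> 5
#   list(ds.iterate().from_state(it.state))   # -> [6, 7, 8, 9]
#
#   list(ShardedIterable(range(3)).shard(0, 2))  # -> [0, 2]
#   list(ShardedIterable(range(3)).shard(1, 2))  # -> [1]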
--------------------------------------------------------------------------------
/ml_metrics/_src/chainables/io_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | from ml_metrics._src.chainables import io
18 | from ml_metrics._src.utils import iter_utils
19 | from ml_metrics._src.utils import test_utils
20 |
21 |
22 | class SequenceDataSourceTest(parameterized.TestCase):
23 |
24 | @parameterized.named_parameters([
25 | dict(
26 | testcase_name='with_default_num_shards',
27 | num_shards=1,
28 | expected=[[0, 1, 2]],
29 | ),
30 | dict(
31 | testcase_name='with_two_shards',
32 | num_shards=2,
33 | expected=[[0, 1], [2]],
34 | ),
35 | dict(
36 | testcase_name='with_three_shards',
37 | num_shards=3,
38 | expected=[[0], [1], [2]],
39 | ),
40 | dict(
41 | testcase_name='with_four_shards',
42 | num_shards=4,
43 | expected=[[0], [1], [2], []],
44 | ),
45 | ])
46 | def test_sharded_sequence(self, num_shards, expected):
47 | ds = io.SequenceDataSource(range(3))
48 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)]
49 | self.assertEqual(expected, actual)
50 |
51 | def test_sharded_sequence_len(self):
52 | ds = io.SequenceDataSource(range(3))
53 | self.assertLen(ds, 3)
54 | self.assertLen(ds.shard(0, 1, offset=1), 2)
55 | self.assertLen(ds.shard(0, 2), 2)
56 | self.assertLen(ds.shard(0, 2, offset=1), 1)
57 | self.assertLen(ds.shard(1, 2), 1)
58 |
59 | def test_sharded_sequence_repeated_shard_len(self):
60 | ds = io.SequenceDataSource(range(10))
61 | self.assertLen(ds, 10)
62 | self.assertLen(ds.shard(0, 2).shard(0, 2), 3)
63 | self.assertLen(ds.shard(0, 2).shard(1, 2), 2)
64 | self.assertLen(ds.shard(1, 2).shard(0, 2), 3)
65 | self.assertLen(ds.shard(1, 2).shard(1, 2), 2)
66 |
67 | def test_sharded_sequence_repeated_shard(self):
68 | ds = io.SequenceDataSource(range(10))
69 | self.assertEqual([0, 1, 2], list(ds.shard(0, 2).shard(0, 2)))
70 | self.assertEqual([3, 4], list(ds.shard(0, 2).shard(1, 2)))
71 | self.assertEqual([5, 6, 7], list(ds.shard(1, 2).shard(0, 2)))
72 | self.assertEqual([8, 9], list(ds.shard(1, 2).shard(1, 2)))
73 |
74 | def test_sharded_sequence_serialization(self):
75 | ds = io.SequenceDataSource(range(3))
76 | it = ds.iterate()
77 | self.assertEqual(0, next(it))
78 | ds = ds.from_state(it.state)
79 | self.assertEqual([1, 2], list(it))
80 | self.assertEqual([1, 2], list(ds))
81 |
82 | def test_sharded_sequence_serialization_after_shard(self):
83 | ds = io.SequenceDataSource(range(4))
84 | it = ds.shard(1, num_shards=2).iterate()
85 | self.assertEqual(2, next(it))
86 | ds = ds.iterate().from_state(it.state)
87 | self.assertEqual([3], list(it))
88 | self.assertEqual([3], list(ds))
89 |
90 | def test_sharded_sequence_serialization_after_shard_twice(self):
91 | ds = io.SequenceDataSource(range(8))
92 | it = ds.shard(0, 2).shard(1, 2).iterate()
93 | self.assertEqual(2, next(it))
94 | ds = ds.iterate().from_state(it.state)
95 | self.assertEqual([3], list(it))
96 | self.assertEqual([3], list(ds))
97 |
98 | def test_merged_sequences(self):
99 | ds = io.SequenceDataSource.from_sequences([range(2), range(2, 6)])
100 | self.assertEqual([0, 1, 2, 3, 4, 5], list(ds.shard(0, 1)))
101 | self.assertEqual([0, 1, 2], list(ds.shard(0, 2)))
102 | self.assertEqual([3, 4, 5], list(ds.shard(1, 2)))
103 |
104 | def test_sharded_sequence_with_non_indexable_data(self):
105 | with self.assertRaisesRegex(TypeError, 'data is not indexable'):
106 | io.SequenceDataSource(0) # pytype: disable=wrong-arg-types
107 |
108 | def test_sharded_sequence_with_invalid_num_shards_raises_error(self):
109 | with self.assertRaisesRegex(ValueError, 'num_shards must be positive'):
110 | _ = io.SequenceDataSource(range(3)).shard(0, 0)
111 |
112 | def test_sequence_ignore_error_single_sequence(self):
113 | ds = io.SequenceDataSource(
114 | test_utils.SequenceWithExc(5, 2), ignore_error=True
115 | )
116 | self.assertEqual([0, 1, 3, 4], list(ds))
117 |
118 | def test_sequence_ignore_error_multiple_sequences(self):
119 | ds = io.SequenceDataSource.from_sequences(
120 | [
121 | test_utils.SequenceWithExc(5, 2),
122 | test_utils.SequenceWithExc(5, 3),
123 | ],
124 | ignore_error=True,
125 | )
126 | self.assertEqual([0, 1, 3, 4, 0, 1, 2, 4], list(ds))
127 |
128 | def test_sequence_batch_size(self):
129 | batch_size = 2
130 | ds = io.SequenceDataSource.from_sequences(
131 | [range(4), range(4, 10)], batch_size=batch_size
132 | )
133 | sliced = ds.data[:3]
134 | self.assertIsInstance(sliced, iter_utils._RangeIterator)
135 | self.assertEqual(sliced._batch_size, batch_size)
136 | self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], list(ds))
137 |
138 | @parameterized.named_parameters([
139 | dict(testcase_name='[0]', at=0, expected=0),
140 | dict(testcase_name='[1]', at=1, expected=1),
141 | dict(testcase_name='[-1]', at=8, expected=8),
142 | dict(testcase_name='[:9]', at=slice(9), expected=list(range(9))),
143 | dict(testcase_name='[:-1]', at=slice(-1), expected=list(range(8))),
144 | dict(testcase_name='[2:-1]', at=slice(2, -1), expected=list(range(2, 8))),
145 | dict(testcase_name='[3:7]', at=slice(3, 7), expected=list(range(3, 7))),
146 | dict(testcase_name='[:]', at=slice(None), expected=list(range(9))),
147 | ])
148 | def test_sequence_getitem_single_shard(self, at, expected):
149 | ds = io.SequenceDataSource.from_sequences([range(4), range(4, 9)])
150 | self.assertEqual(expected, ds[at])
151 |
152 | @parameterized.named_parameters([
153 | dict(testcase_name='shard0[0]', shard=0, at=0, expected=0),
154 | dict(testcase_name='shard0[1]', shard=0, at=1, expected=1),
155 | dict(testcase_name='shard0[-1]', shard=0, at=-1, expected=4),
156 | dict(testcase_name='shard1[0]', shard=1, at=0, expected=5),
157 | dict(testcase_name='shard1[1]', shard=1, at=1, expected=6),
158 | dict(testcase_name='shard1[-1]', shard=1, at=-1, expected=8),
159 | dict(
160 | testcase_name='shard0[:]',
161 | shard=0,
162 | at=slice(None),
163 | expected=list(range(5)),
164 | ),
165 | dict(
166 | testcase_name='shard1[:]',
167 | shard=1,
168 | at=slice(None),
169 | expected=list(range(5, 9)),
170 | ),
171 | dict(
172 | testcase_name='shard0[1:-1]',
173 | shard=0,
174 | at=slice(1, -1),
175 | expected=list(range(1, 4)),
176 | ),
177 | dict(
178 | testcase_name='shard1[1:-1]',
179 | shard=1,
180 | at=slice(1, -1),
181 | expected=list(range(6, 8)),
182 | ),
183 | ])
184 | def test_sequence_getitem_two_shards(self, shard, at, expected):
185 | ds = io.SequenceDataSource.from_sequences([range(4), range(4, 9)])
186 | self.assertEqual(expected, ds.shard(shard, 2)[at])
187 |
188 | def test_sequence_getitem_raise(self):
189 | with self.assertRaisesRegex(IndexError, 'Index 4 is out of range'):
190 | _ = io.SequenceDataSource(range(3))[4]
191 |
192 |
193 | class IterableDataSourceTest(parameterized.TestCase):
194 |
195 | def test_sharded_iterable(self):
196 | ds = io.ShardedIterable(range(3))
197 | self.assertEqual([0, 1, 2], list(ds))
198 |
199 | def test_sharded_iterable_shard(self):
200 | ds = io.ShardedIterable(range(3))
201 | num_shards = 2
202 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)]
203 | expected = [[0, 2], [1]]
204 | self.assertEqual(expected, actual)
205 |
206 | def test_sharded_iterable_num_shards_more_than_data(self):
207 | ds = io.ShardedIterable(range(2))
208 | num_shards = 3
209 | actual = [list(ds.shard(i, num_shards)) for i in range(num_shards)]
210 | expected = [[0], [1], []]
211 | self.assertEqual(expected, actual)
212 |
213 | def test_sharded_iterable_serialization(self):
214 | ds = io.ShardedIterable(range(3))
215 | it = ds.iterate()
216 | self.assertEqual(0, next(it))
217 | ds = ds.iterate().from_state(it.state)
218 | self.assertEqual([1, 2], list(it))
219 | self.assertEqual([1, 2], list(ds))
220 |
221 | def test_sharded_iterable_shard_serialization(self):
222 | ds = io.ShardedIterable(range(6))
223 | it = ds.shard(1, num_shards=2).iterate()
224 | self.assertEqual(1, next(it))
225 | ds = ds.from_state(it.state)
226 | self.assertEqual([3, 5], list(it))
227 | self.assertEqual([3, 5], list(ds))
228 |
229 | def test_sharded_iterator_with_iterator_raises_error(self):
230 | with self.assertRaisesRegex(
231 | TypeError, 'input has to be an iterable but not an iterator'
232 | ):
233 | _ = list(io.ShardedIterable(iter(range(3))))
234 |
235 | def test_sharded_iterator_with_non_iteratable_raises_error(self):
236 | with self.assertRaisesRegex(
237 | TypeError, 'input has to be an iterable but not an iterator'
238 | ):
239 | _ = io.ShardedIterable(3) # pytype: disable=wrong-arg-types
240 |
241 | def test_sharded_iterator_with_invalid_num_shards_raises_error(self):
242 | with self.assertRaisesRegex(ValueError, 'num_shards must be positive'):
243 | _ = list(io.ShardedIterable(range(3), io.ShardConfig(num_shards=0)))
244 |
245 |
246 | if __name__ == '__main__':
247 | absltest.main()
248 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Individual Retrieval based metrics."""
15 |
16 | from collections.abc import Sequence
17 | from ml_metrics._src.aggregates import retrieval
18 | from ml_metrics._src.aggregates import types
19 |
20 |
21 | # TODO: b/368688941 - Remove this alias once all users are migrated to the new
22 | # module structure.
23 | TopKRetrievalAggFn = retrieval.TopKRetrievalAggFn
24 |
25 | _METRIC_PYDOC_POSTFIX = """
26 |
27 | Args:
28 |     y_true: Array of the samples' true labels.
29 |     y_pred: Array of the samples' label predictions.
30 |     k_list: Only applicable for average_type != Samples and
31 |       multiclass/multioutput input types. A list of top-k values, each of
32 |       which slices y_pred to y_pred[:topk], assuming the predictions are
33 |       sorted in descending order. The default 'None' means all outputs in
34 |       the prediction are considered.
35 | input_type: one input type from types.InputType
36 |
37 | Returns:
38 | Tuple with metric value(s)
39 | """
40 |
41 |
42 | def topk_retrieval_metrics(
43 | metrics: Sequence[retrieval.RetrievalMetric],
44 | *,
45 | y_true,
46 | y_pred,
47 | k_list: list[int] | None = None,
48 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
49 | ) -> tuple[tuple[float, ...], ...]:
50 | """Compute multiple metrics together for better efficiency.
51 |
52 | Args:
53 |     metrics: Sequence of retrieval metrics to compute.
54 |     y_true: Array of the samples' true labels.
55 |     y_pred: Array of the samples' label predictions.
56 |     k_list: Only applicable for average_type != Samples and
57 |       multiclass/multioutput input types. A list of top-k values, each of
58 |       which slices y_pred to y_pred[:topk], assuming the predictions are
59 |       sorted in descending order. The default 'None' means all outputs in
60 |       the prediction are considered.
61 | input_type: one input type from types.InputType
62 |
63 | Returns:
64 |     Tuple containing the evaluation metric values, in the same order as the
65 |     metrics given in the metrics list.
66 | """
67 | return retrieval.TopKRetrievalAggFn(
68 | metrics=metrics, k_list=k_list, input_type=input_type
69 | )(y_true, y_pred)
70 |
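# --- Editor-added illustrative sketch ---
# A minimal, hedged example of computing several retrieval metrics in one pass
# with topk_retrieval_metrics. The metric enums are the ones used by the
# wrapper functions below; exact output values depend on the inputs and k_list.
#
#   precision_at_k, recall_at_k = topk_retrieval_metrics(
#       metrics=[
#           retrieval.RetrievalMetric.PRECISION,
#           retrieval.RetrievalMetric.RECALL,
#       ],
#       y_true=[['y'], ['y'], ['n']],
#       y_pred=[['y'], ['n', 'y'], ['y']],
#       k_list=[1, 2],
#   )
#   # precision_at_k and recall_at_k are tuples with one value per k in k_list.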
71 |
72 | def precision(
73 | y_true,
74 | y_pred,
75 | k_list: list[int] | None = None,
76 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
77 | ) -> tuple[float, ...]:
78 | """Compute Precision Retrieval metric."""
79 | return retrieval.TopKRetrieval(
80 | metrics=retrieval.RetrievalMetric.PRECISION,
81 | k_list=k_list,
82 | input_type=input_type,
83 | ).as_agg_fn()(y_true, y_pred)
84 |
85 |
86 | precision.__doc__ += _METRIC_PYDOC_POSTFIX
87 |
88 |
89 | def ppv(
90 | y_true,
91 | y_pred,
92 | k_list: list[int] | None = None,
93 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
94 | ) -> tuple[float, ...]:
95 | """Compute PPV Retrieval metric."""
96 | return retrieval.TopKRetrievalAggFn(
97 | metrics=retrieval.RetrievalMetric.PPV,
98 | k_list=k_list,
99 | input_type=input_type,
100 | )(y_true, y_pred)
101 |
102 |
103 | ppv.__doc__ += _METRIC_PYDOC_POSTFIX
104 |
105 |
106 | def recall(
107 | y_true,
108 | y_pred,
109 | k_list: list[int] | None = None,
110 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
111 | ) -> tuple[float, ...]:
112 | """Compute Recall Retrieval metric."""
113 | return retrieval.TopKRetrievalAggFn(
114 | metrics=retrieval.RetrievalMetric.RECALL,
115 | k_list=k_list,
116 | input_type=input_type,
117 | )(y_true, y_pred)
118 |
119 |
120 | recall.__doc__ += _METRIC_PYDOC_POSTFIX
121 |
122 |
123 | def sensitivity(
124 | y_true,
125 | y_pred,
126 | k_list: list[int] | None = None,
127 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
128 | ) -> tuple[float, ...]:
129 | """Compute Sensitivity Retrieval metric."""
130 | return retrieval.TopKRetrievalAggFn(
131 | metrics=retrieval.RetrievalMetric.SENSITIVITY,
132 | k_list=k_list,
133 | input_type=input_type,
134 | )(y_true, y_pred)
135 | sensitivity.__doc__ += _METRIC_PYDOC_POSTFIX
136 |
137 |
138 | def tpr(
139 | y_true,
140 | y_pred,
141 | k_list: list[int] | None = None,
142 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
143 | ) -> tuple[float, ...]:
144 | """Compute TPR Retrieval metric."""
145 | return retrieval.TopKRetrievalAggFn(
146 | metrics=retrieval.RetrievalMetric.TPR,
147 | k_list=k_list,
148 | input_type=input_type,
149 | )(y_true, y_pred)
150 |
151 |
152 | tpr.__doc__ += _METRIC_PYDOC_POSTFIX
153 |
154 |
155 | def intersection_over_union(
156 | y_true,
157 | y_pred,
158 | k_list: list[int] | None = None,
159 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
160 | ) -> tuple[float, ...]:
161 | """Compute Intersection Over Union Retrieval metric."""
162 | return retrieval.TopKRetrievalAggFn(
163 | metrics=retrieval.RetrievalMetric.INTERSECTION_OVER_UNION,
164 | k_list=k_list,
165 | input_type=input_type,
166 | )(y_true, y_pred)
167 | intersection_over_union.__doc__ += _METRIC_PYDOC_POSTFIX
168 |
169 |
170 | def positive_predictive_value(
171 | y_true,
172 | y_pred,
173 | k_list: list[int] | None = None,
174 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
175 | ) -> tuple[float, ...]:
176 | """Compute Positive Predictive Value Retrieval metric."""
177 | return retrieval.TopKRetrievalAggFn(
178 | metrics=retrieval.RetrievalMetric.POSITIVE_PREDICTIVE_VALUE,
179 | k_list=k_list,
180 | input_type=input_type,
181 | )(y_true, y_pred)
182 | positive_predictive_value.__doc__ += _METRIC_PYDOC_POSTFIX
183 |
184 |
185 | def f1_score(
186 | y_true,
187 | y_pred,
188 | k_list: list[int] | None = None,
189 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
190 | ) -> tuple[float, ...]:
191 | """Compute F1 Score Retrieval metric."""
192 | return retrieval.TopKRetrievalAggFn(
193 | metrics=retrieval.RetrievalMetric.F1_SCORE,
194 | k_list=k_list,
195 | input_type=input_type,
196 | )(y_true, y_pred)
197 |
198 |
199 | f1_score.__doc__ += _METRIC_PYDOC_POSTFIX
200 |
201 |
202 | def miss_rate(
203 | y_true,
204 | y_pred,
205 | k_list: list[int] | None = None,
206 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
207 | ) -> tuple[float, ...]:
208 | """Compute Miss Rate Retrieval metric."""
209 | return retrieval.TopKRetrievalAggFn(
210 | metrics=retrieval.RetrievalMetric.MISS_RATE,
211 | k_list=k_list,
212 | input_type=input_type,
213 | )(y_true, y_pred)
214 |
215 |
216 | miss_rate.__doc__ += _METRIC_PYDOC_POSTFIX
217 |
218 |
219 | def mean_average_precision(
220 | y_true,
221 | y_pred,
222 | k_list: list[int] | None = None,
223 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
224 | ) -> tuple[float, ...]:
225 | """Compute Mean Average Precision Retrieval metric."""
226 | return retrieval.TopKRetrievalAggFn(
227 | metrics=retrieval.RetrievalMetric.MEAN_AVERAGE_PRECISION,
228 | k_list=k_list,
229 | input_type=input_type,
230 | )(y_true, y_pred)
231 | mean_average_precision.__doc__ += _METRIC_PYDOC_POSTFIX
232 |
233 |
234 | def mean_reciprocal_rank(
235 | y_true,
236 | y_pred,
237 | k_list: list[int] | None = None,
238 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
239 | ) -> tuple[float, ...]:
240 | """Compute Mean Reciprocal Rank Retrieval metric."""
241 | return retrieval.TopKRetrievalAggFn(
242 | metrics=retrieval.RetrievalMetric.MEAN_RECIPROCAL_RANK,
243 | k_list=k_list,
244 | input_type=input_type,
245 | )(y_true, y_pred)
246 | mean_reciprocal_rank.__doc__ += _METRIC_PYDOC_POSTFIX
247 |
248 |
249 | def accuracy(
250 | y_true,
251 | y_pred,
252 | k_list: list[int] | None = None,
253 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
254 | ) -> tuple[float, ...]:
255 | """Compute Accuracy Retrieval metric."""
256 | return retrieval.TopKRetrievalAggFn(
257 | metrics=retrieval.RetrievalMetric.ACCURACY,
258 | k_list=k_list,
259 | input_type=input_type,
260 | )(y_true, y_pred)
261 | accuracy.__doc__ += _METRIC_PYDOC_POSTFIX
262 |
263 |
264 | def dcg_score(
265 | y_true,
266 | y_pred,
267 | k_list: list[int] | None = None,
268 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
269 | ) -> tuple[float, ...]:
270 | """Compute DCG Score Retrieval metric."""
271 | return retrieval.TopKRetrievalAggFn(
272 | metrics=retrieval.RetrievalMetric.DCG_SCORE,
273 | k_list=k_list,
274 | input_type=input_type,
275 | )(y_true, y_pred)
276 | dcg_score.__doc__ += _METRIC_PYDOC_POSTFIX
277 |
278 |
279 | def ndcg_score(
280 | y_true,
281 | y_pred,
282 | k_list: list[int] | None = None,
283 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
284 | ) -> tuple[float, ...]:
285 | """Compute NDCG Score Retrieval metric."""
286 | return retrieval.TopKRetrievalAggFn(
287 | metrics=retrieval.RetrievalMetric.NDCG_SCORE,
288 | k_list=k_list,
289 | input_type=input_type,
290 | )(y_true, y_pred)
291 | ndcg_score.__doc__ += _METRIC_PYDOC_POSTFIX
292 |
293 |
294 | def fowlkes_mallows_index(
295 | y_true,
296 | y_pred,
297 | k_list: list[int] | None = None,
298 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
299 | ) -> tuple[float, ...]:
300 | """Compute Fowlkes Mallows Index Retrieval metric."""
301 | return retrieval.TopKRetrievalAggFn(
302 | metrics=retrieval.RetrievalMetric.FOWLKES_MALLOWS_INDEX,
303 | k_list=k_list,
304 | input_type=input_type,
305 | )(y_true, y_pred)
306 | fowlkes_mallows_index.__doc__ += _METRIC_PYDOC_POSTFIX
307 |
308 |
309 | def false_discovery_rate(
310 | y_true,
311 | y_pred,
312 | k_list: list[int] | None = None,
313 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
314 | ) -> tuple[float, ...]:
315 | """Compute False Discovery Rate Retrieval metric."""
316 | return retrieval.TopKRetrievalAggFn(
317 | metrics=retrieval.RetrievalMetric.FALSE_DISCOVERY_RATE,
318 | k_list=k_list,
319 | input_type=input_type,
320 | )(y_true, y_pred)
321 | false_discovery_rate.__doc__ += _METRIC_PYDOC_POSTFIX
322 |
323 |
324 | def threat_score(
325 | y_true,
326 | y_pred,
327 | k_list: list[int] | None = None,
328 | input_type: types.InputType = types.InputType.MULTICLASS_MULTIOUTPUT,
329 | ) -> tuple[float, ...]:
330 | """Compute Threat Score Retrieval metric."""
331 | return retrieval.TopKRetrievalAggFn(
332 | metrics=retrieval.RetrievalMetric.THREAT_SCORE,
333 | k_list=k_list,
334 | input_type=input_type,
335 | )(y_true, y_pred)
336 | threat_score.__doc__ += _METRIC_PYDOC_POSTFIX
337 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/retrieval_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for retrieval."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | from ml_metrics._src.aggregates import types
19 | from ml_metrics._src.metrics import retrieval
20 | import numpy as np
21 |
22 |
23 | class RetrievalTest(parameterized.TestCase):
24 |
25 | @parameterized.named_parameters([
26 | dict(
27 | testcase_name="Precision",
28 | metric_fn=retrieval.precision,
29 | # precision@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8
30 | # precision@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5/8
31 | expected=[5 / 8, 5.5 / 8],
32 | ),
33 | dict(
34 | testcase_name="PPV",
35 | metric_fn=retrieval.ppv,
36 | # ppv@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8
37 | # ppv@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5/8
38 | expected=[5 / 8, 5.5 / 8],
39 | ),
40 | dict(
41 | testcase_name="Recall",
42 | metric_fn=retrieval.recall,
43 | # recall@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8
44 | # recall@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8
45 | expected=[4.5 / 8, 5.5 / 8],
46 | ),
47 | dict(
48 | testcase_name="Sensitivity",
49 | metric_fn=retrieval.sensitivity,
50 | # sensitivity@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8
51 | # sensitivity@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8
52 | expected=[4.5 / 8, 5.5 / 8],
53 | ),
54 | dict(
55 | testcase_name="TPR",
56 | metric_fn=retrieval.tpr,
57 | # tpr@1 = mean([1, 0, 0, 1, 1/2, 1, 0, 1]) = 4.5/8
58 | # tpr@2 = mean([1, 1, 0, 1, 1/2, 1, 0, 1]) = 5.5/8
59 | expected=[4.5 / 8, 5.5 / 8],
60 | ),
61 | dict(
62 | testcase_name="POSITIVE_PREDICTIVE_VALUE",
63 | metric_fn=retrieval.positive_predictive_value,
64 | # positive_predictive_value@1 = mean([1, 0, 0, 1,
65 | # 1, 1, 0, 1]) = 5/8
66 | # positive_predictive_value@2 = mean([1, 1/2, 0, 1,
67 | # 1, 1, 0, 1]) = 5.5/8
68 | expected=[5 / 8, 5.5 / 8],
69 | ),
70 | dict(
71 | testcase_name="INTERSECTION_OVER_UNION",
72 | metric_fn=retrieval.intersection_over_union,
73 | # intersection_over_union@1 = mean([1, 0, 0, 1,
74 | # 1/2, 1, 0, 1]) = 4.5/8
75 | # intersection_over_union@2 = mean([1, 1/2, 0, 1,
76 | # 1/2, 1, 0, 1]) = 5/8
77 | expected=[4.5 / 8, 5 / 8],
78 | ),
79 | dict(
80 | testcase_name="Accuracy",
81 | metric_fn=retrieval.accuracy,
82 | # accuracy@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8
83 | # accuracy@2 = mean([1, 1, 0, 1, 1, 1, 0, 1]) = 6/8
84 | expected=[5 / 8, 6 / 8],
85 | ),
86 | dict(
87 | testcase_name="F1Score",
88 | metric_fn=retrieval.f1_score,
89 | # f1_score@1 = mean([1, 0, 0, 1, 1/1.5, 1, 0, 1]) = (4+2/3)/8
90 | # f1_score@2 = mean([1, 1/1.5, 0, 1, 1/1.5, 1, 0, 1]) = (4+4/3)/8
91 | expected=[(4 + 2 / 3) / 8, (4 + 4 / 3) / 8],
92 | ),
93 | dict(
94 | testcase_name="MissRate",
95 | metric_fn=retrieval.miss_rate,
96 | # miss_rate@1 = mean([0, 1, 1, 0, 1/2, 0, 1, 0]) = 3.5/8
97 | # miss_rate@2 = mean([0, 0, 1, 0, 1/2, 0, 1, 0]) = 2.5/8
98 | expected=[3.5 / 8, 2.5 / 8],
99 | ),
100 | dict(
101 | testcase_name="mAP",
102 | metric_fn=retrieval.mean_average_precision,
103 | # mAP@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5/8
104 | # mAP@2 = mean([1, 1/2, 0, 1, 1/2, 1, 0, 1]) = 5/8
105 | expected=[5 / 8, 5 / 8],
106 | ),
107 | dict(
108 | testcase_name="MRR",
109 | metric_fn=retrieval.mean_reciprocal_rank,
110 | # mRR@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8
111 | # mRR@2 = mean([1, 1/2, 0, 1, 1, 1, 0, 1]) = 5.5 / 8
112 | expected=[5 / 8, 5.5 / 8],
113 | ),
114 | dict(
115 | testcase_name="DCG",
116 | metric_fn=retrieval.dcg_score,
117 | # DCG@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8
118 | # DCG@2 = mean([1, 1/log2(3), 0, 1, 1, 1, 0, 1])
119 | # = (5 + 1 / log2(3)) / 8
120 | expected=[5 / 8, (1 / np.log2(3) + 5) / 8],
121 | ),
122 | dict(
123 | testcase_name="NDCG",
124 | metric_fn=retrieval.ndcg_score,
125 | # NDCG@1 = mean([1, 0, 0, 1, 1, 1, 0, 1]) = 5 / 8
126 | # NDCG@2 = mean([1, 1/log2(3), 0, 1, 1/(1+1/log2(3)), 1, 0, 1])
127 | # = (4 + 1 / log2(3) + 1/(1+1/log2(3))) / 8
128 | expected=[5 / 8, (4 + 1 / np.log2(3) + 1 / (1 + 1 / np.log2(3))) / 8],
129 | ),
130 | dict(
131 | testcase_name="FMI",
132 | metric_fn=retrieval.fowlkes_mallows_index,
133 | # FMI@1 = mean(sqrt([1, 0, 0, 1, 0.5, 1, 0, 1]))
134 | # FMI@2 = mean(sqrt([1, 0.5, 0, 1, 0.5, 1, 0, 1]))
135 | expected=[
136 | np.sqrt([1, 0, 0, 1, 0.5, 1, 0, 1]).mean(),
137 | np.sqrt([1, 0.5, 0, 1, 0.5, 1, 0, 1]).mean(),
138 | ],
139 | ),
140 | dict(
141 | testcase_name="FDR",
142 | metric_fn=retrieval.false_discovery_rate,
143 | # FDR@1 = mean([0, 1, 1, 0, 0, 0, 1, 0]) = 3/8
144 | # FDR@2 = mean([0, 1/2, 1, 0, 0, 0, 1, 0]) = 2.5/8
145 | expected=[3 / 8, 2.5 / 8],
146 | ),
147 | dict(
148 | testcase_name="ThreatScore",
149 | metric_fn=retrieval.threat_score,
150 | # threat_score@1 = mean([1, 0, 0, 1, 0.5, 1, 0, 1]) = 4.5 / 8
151 | # threat_score@2 = mean([0.5, 0.5, 0, 0.5, 1/3, 0.5, 0, 0.5]) =
152 | # = (2.5 + 1/3) / 8
153 | expected=[4.5 / 8, (2.5 + 1 / 3) / 8],
154 | ),
155 | ])
156 | def test_individual_metric(self, metric_fn, expected):
157 | k_list = [1, 2]
158 | y_pred = [["y"], ["n", "y"], ["y"], ["n"], ["y"], ["n"], ["n"], ["u"]]
159 | y_true = [["y"], ["y"], ["n"], ["n"], ["y", "n"], ["n"], ["y"], ["u"]]
160 | metric_doc_details = "\n".join(
161 | metric_fn.__doc__.split("\n")[1:]
162 | ).strip() # ignore the description line for comparison
163 | self.assertEqual(
164 | metric_doc_details, retrieval._METRIC_PYDOC_POSTFIX.strip()
165 | )
166 | np.testing.assert_allclose(
167 | expected,
168 | metric_fn(
169 | y_true,
170 | y_pred,
171 | k_list=k_list,
172 | input_type=types.InputType.MULTICLASS_MULTIOUTPUT,
173 | ),
174 | )
175 |
176 |
177 | if __name__ == "__main__":
178 | absltest.main()
179 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/rolling_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Individual statistics metrics."""
15 |
16 | from ml_metrics._src.aggregates import rolling_stats
17 | from ml_metrics._src.aggregates import types
18 |
19 |
20 | _METRIC_PYDOC_POSTFIX = """
21 |
22 |   The metric is computed based on non-NaN values within the batch.
23 |
24 | Args:
25 | batch: A batch of numbers.
26 |
27 | Returns:
28 | Metric value.
29 | """
30 |
31 |
32 | def var(batch: types.NumbersT) -> float:
33 | """Computes the variance in a batch."""
34 | return rolling_stats.MeanAndVariance().add(batch).var
35 |
36 |
37 | var.__doc__ += _METRIC_PYDOC_POSTFIX
38 |
39 |
40 | def stddev(batch: types.NumbersT) -> float:
41 | """Computes the standard deviation in a batch."""
42 | return rolling_stats.MeanAndVariance().add(batch).stddev
43 |
44 |
45 | stddev.__doc__ += _METRIC_PYDOC_POSTFIX
46 |
47 |
48 | def mean(batch: types.NumbersT) -> float:
49 | """Computes the mean in a batch."""
50 | return rolling_stats.MeanAndVariance().add(batch).mean
51 |
52 |
53 | mean.__doc__ += _METRIC_PYDOC_POSTFIX
54 |
55 |
56 | def count(batch: types.NumbersT) -> int:
57 | """Computes the number of elements in a batch."""
58 | return rolling_stats.MeanAndVariance().add(batch).count
59 |
60 |
61 | count.__doc__ += _METRIC_PYDOC_POSTFIX
62 |
63 |
64 | def total(batch: types.NumbersT) -> float:
65 | """Computes the total sum of a batch."""
66 | return rolling_stats.MeanAndVariance().add(batch).total
67 |
68 |
69 | total.__doc__ += _METRIC_PYDOC_POSTFIX
70 |
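# --- Editor-added illustrative sketch ---
# Each helper builds a fresh MeanAndVariance state from one batch; as stated in
# the docstring postfix above, NaN entries are excluded from the computation.
#
#   batch = [1.0, 2.0, float('nan'), 3.0]
#   mean(batch)   # -> 2.0 (the NaN is ignored)
#   count(batch)  # -> 3
#   total(batch)  # -> 6.0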
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/rolling_stats_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for stats."""
15 |
16 | from absl.testing import parameterized
17 |
18 | from ml_metrics._src.metrics import rolling_stats
19 | import numpy as np
20 |
21 | from absl.testing import absltest
22 |
23 | _BATCH_WITH_NAN = np.concatenate(
24 | [np.random.randn(1000) + 1e7, [np.nan] * 100], axis=0
25 | )
26 |
27 |
28 | class StatsTest(parameterized.TestCase):
29 |
30 | @parameterized.named_parameters([
31 | dict(
32 | testcase_name='variance',
33 | metric_fn=rolling_stats.var,
34 | expected=np.nanvar(_BATCH_WITH_NAN),
35 | ),
36 | dict(
37 | testcase_name='stddev',
38 | metric_fn=rolling_stats.stddev,
39 | expected=np.nanstd(_BATCH_WITH_NAN),
40 | ),
41 | dict(
42 | testcase_name='mean',
43 | metric_fn=rolling_stats.mean,
44 | expected=np.nanmean(_BATCH_WITH_NAN),
45 | ),
46 | dict(
47 | testcase_name='count',
48 | metric_fn=rolling_stats.count,
49 | expected=np.nansum(~np.isnan(_BATCH_WITH_NAN)),
50 | ),
51 | dict(
52 | testcase_name='total',
53 | metric_fn=rolling_stats.total,
54 | expected=np.nansum(_BATCH_WITH_NAN),
55 | ),
56 | ])
57 | def test_rolling_stats_individual_metrics(self, metric_fn, expected):
58 | got = metric_fn(_BATCH_WITH_NAN)
59 | np.testing.assert_almost_equal(got, expected)
60 |
61 | @parameterized.named_parameters([
62 | dict(
63 | testcase_name='variance',
64 | metric_fn=rolling_stats.var,
65 | expected=np.nan,
66 | ),
67 | dict(
68 | testcase_name='stddev',
69 | metric_fn=rolling_stats.stddev,
70 | expected=np.nan,
71 | ),
72 | dict(
73 | testcase_name='mean',
74 | metric_fn=rolling_stats.mean,
75 | expected=np.nan,
76 | ),
77 | dict(
78 | testcase_name='count',
79 | metric_fn=rolling_stats.count,
80 | expected=0,
81 | ),
82 | dict(
83 | testcase_name='total',
84 | metric_fn=rolling_stats.total,
85 | expected=0.0,
86 | ),
87 | ])
88 | def test_rolling_stats_individual_metrics_empty_batch(
89 | self, metric_fn, expected
90 | ):
91 | got = metric_fn([])
92 | np.testing.assert_almost_equal(got, expected)
93 |
94 |
95 | if __name__ == '__main__':
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Individual text stats metrics."""
15 |
16 | from collections.abc import Sequence
17 |
18 | from ml_metrics._src.aggregates import rolling_stats
19 | from ml_metrics._src.aggregates import text
20 | from ml_metrics._src.signals import text as text_scores
21 | from ml_metrics._src.utils import iter_utils
22 | from ml_metrics._src.tools.telemetry import telemetry
23 |
24 |
25 | @telemetry.WithTelemetry(
26 | api='ml_metrics', category='metric', reference='topk_word_ngrams'
27 | )
28 | def topk_word_ngrams(
29 | texts: Sequence[str],
30 | k: int,
31 | n: int,
32 | use_first_ngram_only: bool = False,
33 | count_duplicate: bool = True,
34 | ) -> list[tuple[str, float]]:
35 | """Top k word n-grams metrics.
36 |
37 |   Identify the top `k` most frequently occurring word n-grams with a case-insensitive
38 | approach. The text will first be cleaned by removing non-alphabetic characters
39 | and spaces, and then converted to lowercase before computing the top k
40 | n-grams. When multiple n-grams share the same frequency, alphabetical order
41 | will be used as a tie-breaker. The result is a list of tuples containing the
42 |   n-gram pattern and its corresponding frequency. The list contains at most `k`
43 |   tuples, or one per distinct n-gram, whichever is fewer.
44 |
45 | Args:
46 | texts:
47 | Sequence of texts.
48 | k:
49 | Number of most frequent word n-grams.
50 | n:
51 | Number of grams.
52 | use_first_ngram_only:
53 | If `True`, only the first n words of each text will be used to form the
54 | n-grams and `count_duplicate` will be ignored. Otherwise, all words
55 | present in each text will be considered for generating the n-grams.
56 |       Defaults to `False`.
57 | count_duplicate:
58 | If `True`, duplicate n-grams within the text are included in the total
59 |       count. Otherwise, only the first occurrence of each distinct n-gram is
60 |       counted. Defaults to `True`.
61 |
62 | Returns:
63 | List of tuples of ngram and its frequency of appearance as a pair.
64 | """
65 |
66 | if k <= 0 or n <= 0:
67 | raise ValueError(
68 | f'k and n must be positive integers but k={k} and n={n} was passed.'
69 | )
70 |
71 | return text.TopKWordNGrams(
72 | k=k,
73 | n=n,
74 | use_first_ngram_only=use_first_ngram_only,
75 | count_duplicate=count_duplicate,
76 | ).as_agg_fn()(texts)
77 |
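# --- Editor-added illustrative sketch ---
# Example mirroring the bigram case in text_test.py: frequencies are normalized
# by the number of texts and matching is case-insensitive.
#
#   topk_word_ngrams(['c c', 'b B b', 'd a a'], k=2, n=2)
#   # -> [('b b', 2 / 3), ('a a', 1 / 3)]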
78 |
79 | @telemetry.WithTelemetry(
80 | api='ml_metrics', category='metric', reference='pattern_frequency'
81 | )
82 | def pattern_frequency(
83 | texts: Sequence[str], patterns: Sequence[str], count_duplicate: bool = True
84 | ) -> list[tuple[str, float]]:
85 | """Pattern frequency metric.
86 |
87 | Identify the frequency of occurrence for each pattern found within the given
88 | texts.
89 |
90 |
91 | Args:
92 | texts:
93 | Sequence of texts.
94 | patterns:
95 | Sequence of text patterns.
96 | count_duplicate:
97 |       If `True`, duplicate patterns within the text are included in the total
98 |       count. Otherwise, only the first occurrence of a pattern is counted.
99 |       Defaults to `True`.
100 |
101 | Returns:
102 | List of tuples of pattern and its frequency of appearance as a pair.
103 | """
104 |
105 | if not patterns:
106 | raise ValueError('Patterns must not be empty.')
107 |
108 | if len(set(patterns)) != len(patterns):
109 | raise ValueError(f'Patterns must be unique: {patterns}')
110 |
111 | return text.PatternFrequency(
112 | patterns=patterns, count_duplicate=count_duplicate
113 | ).as_agg_fn()(texts)
114 |
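# --- Editor-added illustrative sketch ---
# Example mirroring text_test.py: each pattern's match count (overlaps included
# when count_duplicate=True) is normalized by the number of texts.
#
#   pattern_frequency(['ab ab xyx', 'xyxyx'], patterns=['ab', 'xyx', 'mmm'])
#   # sorted -> [('ab', 2 / 2), ('mmm', 0), ('xyx', 3 / 2)]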
115 |
116 | @telemetry.WithTelemetry(
117 | api='ml_metrics', category='metric', reference='avg_alphabetical_char_count'
118 | )
119 | def avg_alphabetical_char_count(
120 | texts: Sequence[str],
121 | ) -> rolling_stats.MeanAndVariance:
122 | """Average alphabetical character count metric."""
123 |
124 | if not list(texts):
125 | raise ValueError('`texts` must not be empty.')
126 |
127 | batch_scorer_fn = iter_utils.iterate_fn(text_scores.alphabetical_char_count)
128 | return rolling_stats.MeanAndVariance(
129 | batch_score_fn=batch_scorer_fn
130 | ).as_agg_fn()(texts)
131 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/text_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for text."""
15 |
16 | from absl.testing import parameterized
17 | from ml_metrics._src.aggregates import rolling_stats
18 | from ml_metrics._src.metrics import text
19 | import numpy as np
20 |
21 | from absl.testing import absltest
22 |
23 |
24 | class TopKWordNgramsTest(parameterized.TestCase):
25 |
26 | @parameterized.named_parameters([
27 | dict(
28 | testcase_name='distinct_ngrams',
29 | k=2,
30 | n=2,
31 | use_first_ngram_only=False,
32 | count_duplicate=False,
33 | expected_result=[
34 | ('a a', 1 / 3),
35 | ('b b', 1 / 3),
36 | # ('c c', 1 / 3),
37 | # ('d a', 1 / 3),
38 | ],
39 | ),
40 | dict(
41 | testcase_name='count_duplicate',
42 | k=2,
43 | n=2,
44 | use_first_ngram_only=False,
45 | count_duplicate=True,
46 | expected_result=[
47 | ('b b', 2 / 3),
48 | ('a a', 1 / 3),
49 | # ('c c', 1 / 3),
50 | # ('d a', 1 / 3),
51 | ],
52 | ),
53 | dict(
54 | testcase_name='use_first_ngram_only',
55 | k=2,
56 | n=2,
57 | use_first_ngram_only=True,
58 | # count_duplicate will be ignored. See TopKWordNGrams description.
59 | count_duplicate=True,
60 | expected_result=[
61 | ('b b', 1 / 3),
62 | ('c c', 1 / 3),
63 | # ('d a', 1 / 3),
64 | ],
65 | ),
66 | dict(
67 | testcase_name='large_k',
68 | k=10,
69 | n=2,
70 | use_first_ngram_only=False,
71 | count_duplicate=True,
72 | expected_result=[
73 | ('b b', 2 / 3),
74 | ('a a', 1 / 3),
75 | ('c c', 1 / 3),
76 | ('d a', 1 / 3),
77 | ],
78 | ),
79 | dict(
80 | testcase_name='large_n',
81 | k=4,
82 | n=10,
83 | use_first_ngram_only=False,
84 | count_duplicate=True,
85 | expected_result=[],
86 | ),
87 | dict(
88 | testcase_name='2k_1n',
89 | k=2,
90 | n=1,
91 | use_first_ngram_only=False,
92 | count_duplicate=True,
93 | expected_result=[
94 | ('b', 3 / 3),
95 | ('a', 2 / 3),
96 | # ('c', 2 / 3),
97 | # ('d', 1 / 3),
98 | ],
99 | ),
100 | dict(
101 | testcase_name='2k_3n',
102 | k=2,
103 | n=3,
104 | use_first_ngram_only=False,
105 | count_duplicate=True,
106 | expected_result=[
107 | ('b b b', 1 / 3),
108 | ('d a a', 1 / 3),
109 | ],
110 | ),
111 | ])
112 | def test_topk_word_ngrams(
113 | self, k, n, use_first_ngram_only, count_duplicate, expected_result
114 | ):
115 | texts = [
116 | 'c c',
117 | 'b B b', # Case-insensitive
118 | 'd a a',
119 | ]
120 | result = text.topk_word_ngrams(
121 | texts=texts,
122 | k=k,
123 | n=n,
124 | use_first_ngram_only=use_first_ngram_only,
125 | count_duplicate=count_duplicate
126 | )
127 | self.assertSequenceAlmostEqual(expected_result, result)
128 |
129 | def test_topk_word_ngrams_empty(self):
130 | # Randomly choose k and n.
131 | result = text.topk_word_ngrams(texts=[], k=3, n=2)
132 | self.assertSequenceAlmostEqual([], result)
133 |
134 | def test_topk_word_ngrams_invalid_kn(self):
135 | kn_pairs = [(0, 1), (1, 0)]
136 | for k, n in kn_pairs:
137 | with self.assertRaisesRegex(
138 | ValueError, 'k and n must be positive integers.'
139 | ):
140 | text.topk_word_ngrams(texts=['a'], k=k, n=n)
141 |
142 |
143 | class PatternFrequencyTest(parameterized.TestCase):
144 |
145 | @parameterized.named_parameters([
146 | dict(
147 | testcase_name='not_count_duplicate',
148 | count_duplicate=False,
149 | expected_result=[
150 | ('ab', 1 / 2),
151 | ('mmm', 0),
152 | ('xyx', 2 / 2),
153 | ],
154 | ),
155 | dict(
156 | testcase_name='count_duplicate',
157 | count_duplicate=True,
158 | expected_result=[
159 | ('ab', 2 / 2),
160 | ('mmm', 0),
161 | ('xyx', 3 / 2),
162 | ],
163 | ),
164 | ])
165 | def test_pattern_frequency(self, count_duplicate, expected_result):
166 | batch = ['ab ab xyx', 'xyxyx']
167 | result = text.pattern_frequency(
168 | texts=batch,
169 | patterns=['ab', 'xyx', 'mmm'],
170 | count_duplicate=count_duplicate,
171 | )
172 | self.assertSequenceAlmostEqual(expected_result, sorted(result))
173 |
174 | def test_pattern_frequency_empty(self):
175 | result = text.pattern_frequency(
176 | texts=[], patterns=['ab', 'xyx'], count_duplicate=False
177 | )
178 | self.assertSequenceEqual([], result)
179 |
180 | @parameterized.named_parameters([
181 | dict(
182 | testcase_name='empty_pattern',
183 | patterns=[],
184 | expected_regex='Patterns must not be empty.',
185 | ),
186 | dict(
187 | testcase_name='duplicate_patterns',
188 | patterns=['a', 'a', 'b'],
189 | expected_regex='Patterns must be unique:',
190 | ),
191 | ])
192 | def test_pattern_frequency_invalid_patterns(self, patterns, expected_regex):
193 | with self.assertRaisesRegex(ValueError, expected_regex):
194 | # `texts` is randomly assigned.
195 | text.pattern_frequency(texts=['a'], patterns=patterns)
196 |
197 |
198 | def get_expected_mean_and_variance_state_result(batch):
199 | # result_dict key must match MeanAndVariance properties name.
200 | result_dict = {
201 | 'mean': np.mean(batch),
202 | 'var': np.var(batch),
203 | 'stddev': np.std(batch),
204 | 'count': len(batch),
205 | 'total': np.sum(batch),
206 | }
207 | return result_dict
208 |
209 |
210 | class AvgAlphabeticalCharCountTest(absltest.TestCase):
211 |
212 | def test_avg_alphabetical_char_count(self):
213 | got_state = text.avg_alphabetical_char_count(
214 | texts=['ab', 'a b', '', 'ok?']
215 | )
216 | self.assertIsInstance(got_state, rolling_stats.MeanAndVariance)
217 |
218 | expected_state = get_expected_mean_and_variance_state_result(
219 | [2, 2, 0, 2]
220 | )
221 | for metric_name in expected_state:
222 | np.testing.assert_allclose(
223 | expected_state[metric_name],
224 | getattr(got_state, metric_name),
225 | )
226 |
227 | def test_avg_alphabetical_char_count_empty(self):
228 | with self.assertRaisesRegex(ValueError, '`texts` must not be empty.'):
229 | text.avg_alphabetical_char_count(texts=[])
230 |
231 |
232 | if __name__ == '__main__':
233 | absltest.main()
234 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/utils.py:
--------------------------------------------------------------------------------
1 | """Utility for ML metrics APIs."""
2 |
3 | import itertools
4 | from typing import Any
5 |
6 | from ml_metrics._src.aggregates import classification
7 | from ml_metrics._src.aggregates import types
8 | from ml_metrics._src.tools.telemetry import telemetry
9 |
10 |
11 | @telemetry.WithTelemetry(
12 | api='ml_metrics', category='metric', reference='verify_input'
13 | )
14 | def verify_input(y_true, y_pred, average, input_type, vocab, pos_label):
15 | if (
16 | average == types.AverageType.BINARY
17 | and input_type == types.InputType.BINARY
18 | ):
19 | _validate_pos_label(y_true, y_pred, pos_label, vocab)
20 |
21 |
22 | def _validate_pos_label(
23 | y_true, y_pred, pos_label: Any, vocab: dict[str, int] | None
24 | ):
25 | vocab = vocab or classification.get_vocab(
26 | itertools.chain(y_true, y_pred), False
27 | )
28 | labels = list(vocab.keys())
29 | labels.sort()
30 | if pos_label not in labels:
31 | raise ValueError(
32 | f'Pos label {pos_label} not in labels. Please set a valid pos_label'
33 | f' from: {labels}'
34 | )
35 |
--------------------------------------------------------------------------------
/ml_metrics/_src/metrics/utils_test.py:
--------------------------------------------------------------------------------
1 | """Tests for ML metrics API function utils."""
2 |
3 | from ml_metrics._src.aggregates import types
4 | from ml_metrics._src.metrics import utils
5 |
6 | from absl.testing import absltest
7 |
8 |
9 | InputType = types.InputType
10 | AverageType = types.AverageType
11 |
12 |
13 | class UtilsTest(absltest.TestCase):
14 |
15 | def test_validate_inputs(self):
16 | with self.assertRaisesRegex(
17 | ValueError,
18 | r"Pos label 1 not in labels. Please set a valid pos_label from: \['n',"
19 | r" 'y'\]",
20 | ):
21 | _ = utils.verify_input(
22 | y_true=["y", "n", "n"],
23 | y_pred=["y", "n", "y"],
24 | average=AverageType.BINARY,
25 | input_type=InputType.BINARY,
26 | pos_label=1,
27 | vocab=None,
28 | )
29 |
30 |
31 | if __name__ == "__main__":
32 | absltest.main()
33 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/cg_score.py:
--------------------------------------------------------------------------------
1 | """Code to compute CG Score.
2 |
3 | The Complexity Gap Score (CG Score) quantifies the "influence of individual
4 | instances." CG scores identify observations with large/small influence on
5 | downstream classification performance, potentially flagging label noise issues.
6 | The current implementation works for binary labels only.
7 | More information about the stochastic method implemented here can be found at
8 | https://arxiv.org/pdf/2301.00930.pdf (Section A.2).
9 | """
10 |
11 | from ml_metrics._src.aggregates import types
12 | from ml_metrics._src.tools.telemetry import telemetry
13 | import numpy as np
14 |
15 |
16 | @telemetry.WithTelemetry(
17 | api='ml_metrics', category='signal', reference='complexity_gap_score'
18 | )
19 | def complexity_gap_score(
20 | labels: types.NumbersT,
21 | embeddings: types.NumbersT,
22 | *,
23 | num_repetitions: int = 1,
24 | class_balance_ratio: float = 1.0,
25 | random_seed: int = 0,
26 | ) -> types.NumbersT:
27 | """Calculates the Complexity Gap (CG) score for identifying influential instances.
28 |
29 | Args:
30 | labels: Labels in binary vector representations.
31 | embeddings: Embeddings in vector representations.
32 | num_repetitions: Number of times to repeat the CG score calculation.
33 | class_balance_ratio: Ratio for balancing classes during calculation (e.g.,
34 | 1.0 for perfect balance).
35 | random_seed: Random seed for reproducibility.
36 |
37 | Returns:
38 | A NumPy array containing the CG score for each data point.
39 | """
40 | if (l := len(np.unique(labels))) > 2:
41 | raise ValueError(f'CG score only works for binary labels, got {l} labels.')
42 |
43 | cg_scores = np.zeros(len(labels))
44 |
45 | if l == 1:
46 | return cg_scores
47 |
48 | data_by_label = _group_data_by_label(embeddings, labels)
49 | if random_seed:
50 | np.random.seed(random_seed)
51 |
52 | for _ in range(num_repetitions):
53 | for label, data in data_by_label.items():
54 | data = np.array(data['data'])
55 | other_label = _get_other_label(data_by_label, label)
56 | other_data = np.array(data_by_label[other_label]['data'])
57 | balanced_data = _balance_dataset(data, other_data, class_balance_ratio)
58 | vi_scores = _calculate_influence_scores(balanced_data, data.shape[0])
59 | cg_scores[data_by_label[label]['indices']] += vi_scores
60 |
61 | return cg_scores / np.maximum(num_repetitions, 1)
62 |
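# --- Editor-added illustrative sketch ---
# Example mirroring cg_score_test.py: the point whose embedding coincides with
# an instance of the opposite class receives a larger CG score, flagging it as
# more influential (e.g. potential label noise).
#
#   embeddings = np.array([[1, 2], [2, 3], [5, 4], [5, 4]])
#   labels = np.array([1, 1, 1, 0])
#   scores = complexity_gap_score(labels, embeddings, num_repetitions=1)
#   # len(scores) == 4 and scores[0] < scores[3]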
63 |
64 | def _group_data_by_label(embeddings: np.ndarray, labels: np.ndarray):
65 | data_by_label = {}
66 | for i, (embedding, label) in enumerate(zip(embeddings, labels, strict=True)):
67 | data_by_label.setdefault(label, {'data': [], 'indices': []})
68 | data_float = embedding.astype(np.float64)
69 | data_by_label[label]['data'].append(
70 | data_float / np.linalg.norm(data_float, axis=-1, keepdims=True)
71 | )
72 | data_by_label[label]['indices'].append(i)
73 | return data_by_label
74 |
75 |
76 | def _get_other_label(
77 | data_by_label: dict[int, object], current_label: int
78 | ) -> int:
79 | return [label for label in data_by_label if label != current_label][0]
80 |
81 |
82 | def _balance_dataset(
83 | data: np.ndarray, other_data: np.ndarray, ratio: float
84 | ) -> np.ndarray:
85 | max_size = min(int(len(data) * ratio), len(other_data))
86 | selected_indices = np.random.choice(len(other_data), max_size, replace=False)
87 | balanced_other_data = other_data[selected_indices]
88 | return np.concatenate((data, balanced_other_data))
89 |
90 |
91 | def _calculate_influence_scores(
92 | data: np.ndarray, data_size: int
93 | ) -> types.NumbersT:
94 | """Computes Complexity Gap "influence" scores for each data point."""
95 |
96 | reformatted_data = data
97 | y = np.concatenate([np.ones(data_size), -np.ones(data.shape[0] - data_size)])
98 | hermitian_inner = reformatted_data @ np.transpose(reformatted_data)
99 | hermitian = (hermitian_inner * (np.pi - np.arccos(hermitian_inner))) / (
100 | 2 * np.pi
101 | )
102 | np.fill_diagonal(hermitian, 0.5)
103 | hermitian[np.isnan(hermitian)] = np.nextafter(np.float64(0), np.float64(1))
104 |
105 | inv_hermitian = np.linalg.pinv(hermitian, hermitian=True)
106 | original_error = y @ (inv_hermitian @ y)
107 |
108 | influence_scores = np.zeros(data_size)
109 | for k in range(data_size):
110 | without_col = np.delete(inv_hermitian, k, axis=1)
111 | without_k = np.delete(without_col, k, axis=0)
112 | row_expanded = np.expand_dims(without_col[k, :], 0)
113 | with_k = inv_hermitian[k, k]
114 |
115 | inv_hermitian_except_k = (
116 | without_k - (row_expanded.transpose() @ row_expanded) / with_k
117 | )
118 | y_except_k = np.delete(y, k, axis=0)
119 | influence_scores[k] = original_error - y_except_k @ (
120 | inv_hermitian_except_k @ y_except_k
121 | )
122 |
123 | return influence_scores
124 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/cg_score_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from ml_metrics._src.signals import cg_score
3 | import numpy as np
4 |
5 |
6 | class CgScoreTest(absltest.TestCase):
7 |
8 | def test_validate_nonbinary_input(self):
9 | labels = ['A', 'B', 'C']
10 | with self.assertRaisesRegex(ValueError, 'only works for binary label'):
11 | cg_score.complexity_gap_score(labels, np.array([[1, 2], [2, 3], [5, 4]]))
12 |
13 | def test_group_data_by_label(self):
14 | input_embeddings = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
15 | labels = np.array(['a', 'b', 'b'])
16 | result = cg_score._group_data_by_label(input_embeddings, labels)
17 |
18 | self.assertIn('a', result)
19 | self.assertSequenceAlmostEqual([0], result['a']['indices'])
20 | self.assertLen(result['a']['data'], 1)
21 | self.assertAlmostEqual(1, np.linalg.norm(result['a']['data']))
22 |
23 | self.assertIn('b', result)
24 | self.assertSequenceAlmostEqual([1, 2], result['b']['indices'])
25 | self.assertLen(result['b']['data'], 2)
26 | self.assertAlmostEqual(1, np.linalg.norm(result['b']['data'][0]))
27 | self.assertAlmostEqual(1, np.linalg.norm(result['b']['data'][1]))
28 |
29 | def test_get_other_label(self):
30 | example_dict = {
31 | 'a': {
32 | 'data': np.array([1, 2, 3]),
33 | 'indices': [0, 1],
34 | },
35 | 'b': {
36 | 'data': np.array([4, 5, 6]),
37 | 'indices': [0, 1],
38 | },
39 | }
40 | self.assertEqual(cg_score._get_other_label(example_dict, 'a'), 'b')
41 | self.assertEqual(cg_score._get_other_label(example_dict, 'b'), 'a')
42 |
43 | def test_balance_dataset(self):
44 | data_a = np.array([[1, 2, 3]])
45 | data_b = np.array(np.array([[4, 5, 6], [7, 8, 9]]))
46 | self.assertLen(cg_score._balance_dataset(data_a, data_b, 1), 2)
47 | self.assertLen(cg_score._balance_dataset(data_a, data_b, 2), 3)
48 | self.assertLen(cg_score._balance_dataset(data_b, data_a, 1), 3)
49 |
50 | def test_calculate_complexity_gap_score_result_sanity(self):
51 | embeddings = np.array([[1, 2], [2, 3], [5, 4], [5, 4]])
52 | labels = np.array([1, 1, 1, 0])
53 | cg_scores = cg_score.complexity_gap_score(
54 | labels, embeddings, num_repetitions=1, random_seed=0
55 | )
56 | self.assertLen(cg_scores, 4)
57 | self.assertLess(cg_scores[0], cg_scores[3])
58 |
59 | def test_calculate_complexity_gap_score_simple(self):
60 | embeddings = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])
61 | labels = np.array([1, 1, 0, 0])
62 | cg_scores = cg_score.complexity_gap_score(
63 | labels, embeddings, num_repetitions=1
64 | )
65 | self.assertSequenceAlmostEqual(cg_scores, [0, 0, 0, 0])
66 |
67 | def test_num_repetitions(self):
68 | embeddings = np.array([[1, 2], [1, 2], [5, 4], [5, 4]])
69 | labels = np.array([1, 1, 0, 0])
70 | cg_scores_one = cg_score.complexity_gap_score(
71 | labels, embeddings, num_repetitions=1
72 | )
73 | cg_scores_three = cg_score.complexity_gap_score(
74 | labels, embeddings, num_repetitions=3
75 | )
76 | self.assertLen(cg_scores_one, 4)
77 | self.assertLen(cg_scores_three, 4)
78 | np.testing.assert_allclose(cg_scores_one, cg_scores_three)
79 |
80 | def test_one_single_label(self):
81 | embeddings = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
82 | labels = np.array([1, 1, 1])
83 | cg_scores = cg_score.complexity_gap_score(
84 | labels, embeddings, num_repetitions=1
85 | )
86 | self.assertSequenceAlmostEqual(cg_scores, [0, 0, 0])
87 |
88 |
89 | if __name__ == '__main__':
90 | absltest.main()
91 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module for cross entropy loss functions."""
15 |
16 | from ml_metrics._src.aggregates import types
17 | from ml_metrics._src.tools.telemetry import telemetry
18 | import numpy as np
19 |
20 |
21 | def _check_y_true_contains_only_0_and_1(y_true: types.NumbersT) -> None:
22 | if not all(y == 0 or y == 1 for y in y_true):
23 | raise ValueError(
24 |         'y_true must contain only 0s and 1s, but received: {}'.format(y_true)
25 | )
26 |
27 |
28 | @telemetry.WithTelemetry(
29 | api='ml_metrics', category='signal', reference='binary_cross_entropy'
30 | )
31 | def binary_cross_entropy(
32 | y_true: types.NumbersT,
33 | y_pred: types.NumbersT,
34 | ) -> float:
35 | """Calculates binary cross entropy loss for two lists of labels.
36 |
37 | Args:
38 |     y_true: Truth labels. Each label is either 0 or 1.
39 |     y_pred: Predicted values. Each prediction is the model's predicted
40 |       probability, i.e., a single floating-point value in the open interval
41 |       (0., 1.).
42 |
43 | Returns:
44 | The binary cross-entropy loss between true labels and predicted labels.
45 | """
46 | _check_y_true_contains_only_0_and_1(y_true)
47 |
48 | return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
49 |
50 |
51 | @telemetry.WithTelemetry(
52 | api='ml_metrics', category='signal', reference='categorical_cross_entropy'
53 | )
54 | def categorical_cross_entropy(
55 | y_true: types.NumbersT,
56 | y_pred: types.NumbersT,
57 | ) -> float:
58 | """Calculates categorical cross entropy loss for two lists of labels.
59 |
60 | Args:
61 |     y_true: Truth labels. Each label is either 0 or 1.
62 |     y_pred: Predicted values. Each prediction is the model's predicted
63 |       probability, i.e., a single floating-point value in the closed interval
64 |       [0., 1.].
65 |
66 | Returns:
67 | The categorical cross-entropy loss between true labels and predicted labels.
68 | """
69 | _check_y_true_contains_only_0_and_1(y_true)
70 |
71 | return -np.sum(y_true * np.log(y_pred / np.sum(y_pred)))
72 |
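A minimal usage sketch (illustrative inputs, not part of the source file above): `binary_cross_entropy` is the mean over examples of -[y*log(p) + (1-y)*log(1-p)], so the hand-computed value matches the function's result.

    import numpy as np
    from ml_metrics._src.signals import cross_entropy

    y_true = np.array([0, 1, 1])
    y_pred = np.array([0.2, 0.8, 0.6])
    manual = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    assert np.isclose(cross_entropy.binary_cross_entropy(y_true, y_pred), manual)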
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/cross_entropy_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Cross Entropy."""
15 |
16 | import math
17 |
18 | from ml_metrics._src.signals import cross_entropy
19 | import numpy as np
20 |
21 | from absl.testing import absltest
22 |
23 |
24 | _COMPARED_DECIMAL_PLACES = 6  # Decimal places used when comparing test values.
25 |
26 |
27 | class CrossEntropyTest(absltest.TestCase):
28 |
29 | def test_binary_cross_entropy(self):
30 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 1))
31 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9))
32 |
33 | # How to reproduce expected_result:
34 | # bce = tensorflow.keras.losses.BinaryCrossentropy()
35 | # expected_result = tensorflow.get_static_value(
36 | # bce(y_true=y_true, y_pred=y_pred)
37 | # )
38 | expected_result = 0.9587651091286978
39 |
40 | self.assertAlmostEqual(
41 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred),
42 | expected_result,
43 | places=_COMPARED_DECIMAL_PLACES,
44 | )
45 |
46 | def test_binary_cross_entropy_y_pred_includes_0_and_1(self):
47 | y_true = np.array((0, 1))
48 | y_pred = np.array((0, 1))
49 |
50 | self.assertTrue(
51 | math.isnan(
52 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred)
53 | ),
54 | )
55 |
56 | def test_binary_cross_entropy_invalid_y_true_raises_error(self):
57 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 2))
58 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9))
59 |
60 | with self.assertRaisesRegex(
61 |         ValueError, 'y_true must contain only 0s and 1s, but received: '
62 | ):
63 | cross_entropy.binary_cross_entropy(y_true=y_true, y_pred=y_pred)
64 |
65 | def test_categorical_cross_entropy(self):
66 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 1))
67 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9))
68 |
69 | # How to reproduce expected_result:
70 | # cce = tensorflow.keras.losses.CategoricalCrossentropy()
71 | # expected_result = tensorflow.get_static_value(
72 | # cce(y_true=y_true, y_pred=y_pred)
73 | # )
74 | expected_result = 9.38023940877158
75 |
76 | self.assertAlmostEqual(
77 | cross_entropy.categorical_cross_entropy(y_true=y_true, y_pred=y_pred),
78 | expected_result,
79 | places=_COMPARED_DECIMAL_PLACES,
80 | )
81 |
82 | def test_categorical_cross_entropy_invalid_y_true_raises_error(self):
83 | y_true = np.array((0, 1, 0, 1, 0, 1, 0, 2))
84 | y_pred = np.array((0.1, 0.1, 0.4, 0.4, 0.6, 0.6, 0.9, 0.9))
85 |
86 | with self.assertRaisesRegex(
87 |         ValueError, 'y_true must contain only 0s and 1s, but received: '
88 | ):
89 | cross_entropy.categorical_cross_entropy(y_true=y_true, y_pred=y_pred)
90 |
91 |
92 | if __name__ == '__main__':
93 | absltest.main()
94 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/flip_masks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Flip Masks."""
15 |
16 | from ml_metrics._src.aggregates import types
17 | from ml_metrics._src.tools.telemetry import telemetry
18 | import numpy as np
19 |
20 |
21 | @telemetry.WithTelemetry(
22 | api='ml_metrics', category='signal', reference='binary_flip_mask'
23 | )
24 | def binary_flip_mask(
25 | base_prediction: types.NumbersT,
26 | model_prediction: types.NumbersT,
27 | threshold: types.NumbersT | None = None,
28 | ) -> types.NumbersT:
29 | """AKA symmetric flip mask. Returns a 1 if the predictions don't match."""
30 | if threshold is not None:
31 | base_prediction = base_prediction > threshold
32 | model_prediction = model_prediction > threshold
33 |
34 | return np.logical_xor(base_prediction, model_prediction).astype(int)
35 |
36 |
37 | @telemetry.WithTelemetry(
38 | api='ml_metrics', category='signal', reference='neg_to_pos_flip_mask'
39 | )
40 | def neg_to_pos_flip_mask(
41 | base_prediction: types.NumbersT,
42 | model_prediction: types.NumbersT,
43 | threshold: types.NumbersT | None = None,
44 | ) -> types.NumbersT:
45 | """Returns a 1 if base_prediction <= threshold < model_prediction."""
46 | if threshold is None:
47 | return not base_prediction and model_prediction
48 |
49 | base_under_threshold = base_prediction <= threshold
50 | model_over_threshold = model_prediction > threshold
51 |
52 | return np.logical_and(base_under_threshold, model_over_threshold).astype(int)
53 |
54 |
55 | @telemetry.WithTelemetry(
56 | api='ml_metrics', category='signal', reference='pos_to_neg_flip_mask'
57 | )
58 | def pos_to_neg_flip_mask(
59 | base_prediction: types.NumbersT,
60 | model_prediction: types.NumbersT,
61 | threshold: types.NumbersT | None = None,
62 | ) -> types.NumbersT:
63 | """Returns a 1 if base_prediction > threshold >= model_prediction."""
64 | if threshold is None:
65 | return base_prediction and not model_prediction
66 |
67 | base_over_threshold = base_prediction > threshold
68 | model_under_threshold = model_prediction <= threshold
69 |
70 | return np.logical_and(base_over_threshold, model_under_threshold).astype(int)
71 |
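A usage sketch with the same thresholded inputs used in flip_masks_test.py below: a flip is any disagreement between the base and new model predictions once both are thresholded at 0.5, and the directional variants only count flips in one direction.

    import numpy as np
    from ml_metrics._src.signals import flip_masks

    base = np.array([0.1, 0.1, 0.9, 0.9])
    model = np.array([0.2, 0.9, 0.1, 0.8])
    print(flip_masks.binary_flip_mask(base, model, threshold=0.5))      # [0 1 1 0]
    print(flip_masks.neg_to_pos_flip_mask(base, model, threshold=0.5))  # [0 1 0 0]
    print(flip_masks.pos_to_neg_flip_mask(base, model, threshold=0.5))  # [0 0 1 0]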
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/flip_masks_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Flip Masks."""
15 |
16 | from ml_metrics._src.signals import flip_masks
17 | import numpy as np
18 | from numpy import testing
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 |
23 |
24 | class FlipMasksTest(parameterized.TestCase):
25 |
26 | @parameterized.named_parameters(
27 | dict(
28 | testcase_name='binary',
29 | flip_mask_fn=flip_masks.binary_flip_mask,
30 | expected_mask=(False, True, True, False),
31 | ),
32 | dict(
33 | testcase_name='neg_to_pos',
34 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
35 | expected_mask=(False, True, False, False),
36 | ),
37 | dict(
38 | testcase_name='pos_to_neg',
39 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
40 | expected_mask=(False, False, True, False),
41 | ),
42 | )
43 | def test_flip_masks_bool_predictions(self, flip_mask_fn, expected_mask):
44 | base_predictions = np.array((False, False, True, True))
45 | model_predictions = np.array((False, True, False, True))
46 |
47 | for base_prediction, model_prediction, expected_flip in zip(
48 | base_predictions, model_predictions, expected_mask
49 | ):
50 | self.assertEqual(
51 | flip_mask_fn(base_prediction, model_prediction),
52 | expected_flip,
53 | )
54 |
55 | @parameterized.named_parameters(
56 | dict(
57 | testcase_name='binary',
58 | flip_mask_fn=flip_masks.binary_flip_mask,
59 | expected_mask=(0, 1, 1, 0),
60 | ),
61 | dict(
62 | testcase_name='neg_to_pos',
63 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
64 | expected_mask=(0, 1, 0, 0),
65 | ),
66 | dict(
67 | testcase_name='pos_to_neg',
68 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
69 | expected_mask=(0, 0, 1, 0),
70 | ),
71 | )
72 | def test_flip_masks_int_predictions(self, flip_mask_fn, expected_mask):
73 | base_predictions = np.array((0, 0, 1, 1))
74 | model_predictions = np.array((0, 1, 0, 1))
75 |
76 | for base_prediction, model_prediction, expected_flip in zip(
77 | base_predictions, model_predictions, expected_mask
78 | ):
79 | self.assertEqual(
80 | flip_mask_fn(base_prediction, model_prediction),
81 | expected_flip,
82 | )
83 |
84 | @parameterized.named_parameters(
85 | dict(
86 | testcase_name='binary',
87 | flip_mask_fn=flip_masks.binary_flip_mask,
88 | expected_mask=(0, 1, 1, 0),
89 | ),
90 | dict(
91 | testcase_name='neg_to_pos',
92 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
93 | expected_mask=(0, 1, 0, 0),
94 | ),
95 | dict(
96 | testcase_name='pos_to_neg',
97 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
98 | expected_mask=(0, 0, 1, 0),
99 | ),
100 | )
101 | def test_flip_masks_threshold(self, flip_mask_fn, expected_mask):
102 | base_predictions = np.array((0.1, 0.1, 0.9, 0.9))
103 | model_predictions = np.array((0.2, 0.9, 0.1, 0.8))
104 |
105 | for base_prediction, model_prediction, expected_flip in zip(
106 | base_predictions, model_predictions, expected_mask
107 | ):
108 | self.assertEqual(
109 | flip_mask_fn(base_prediction, model_prediction, threshold=0.5),
110 | expected_flip,
111 | )
112 |
113 | @parameterized.named_parameters(
114 | dict(
115 | testcase_name='binary',
116 | flip_mask_fn=flip_masks.binary_flip_mask,
117 | expected_mask=(False, True, True, False),
118 | ),
119 | dict(
120 | testcase_name='neg_to_pos',
121 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
122 | expected_mask=(False, True, False, False),
123 | ),
124 | dict(
125 | testcase_name='pos_to_neg',
126 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
127 | expected_mask=(False, False, True, False),
128 | ),
129 | )
130 | def test_flip_masks_bool_predictions_batched(
131 | self, flip_mask_fn, expected_mask
132 | ):
133 | base_predictions = np.array((False, False, True, True))
134 | model_predictions = np.array((False, True, False, True))
135 |
136 | testing.assert_array_equal(
137 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5),
138 | expected_mask,
139 | )
140 |
141 | @parameterized.named_parameters(
142 | dict(
143 | testcase_name='binary',
144 | flip_mask_fn=flip_masks.binary_flip_mask,
145 | expected_mask=(0, 1, 1, 0),
146 | ),
147 | dict(
148 | testcase_name='neg_to_pos',
149 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
150 | expected_mask=(0, 1, 0, 0),
151 | ),
152 | dict(
153 | testcase_name='pos_to_neg',
154 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
155 | expected_mask=(0, 0, 1, 0),
156 | ),
157 | )
158 | def test_flip_masks_int_predictions_batched(
159 | self, flip_mask_fn, expected_mask
160 | ):
161 | base_predictions = np.array((0, 0, 1, 1))
162 | model_predictions = np.array((0, 1, 0, 1))
163 |
164 | testing.assert_array_equal(
165 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5),
166 | expected_mask,
167 | )
168 |
169 | @parameterized.named_parameters(
170 | dict(
171 | testcase_name='binary',
172 | flip_mask_fn=flip_masks.binary_flip_mask,
173 | expected_mask=(0, 1, 1, 0),
174 | ),
175 | dict(
176 | testcase_name='neg_to_pos',
177 | flip_mask_fn=flip_masks.neg_to_pos_flip_mask,
178 | expected_mask=(0, 1, 0, 0),
179 | ),
180 | dict(
181 | testcase_name='pos_to_neg',
182 | flip_mask_fn=flip_masks.pos_to_neg_flip_mask,
183 | expected_mask=(0, 0, 1, 0),
184 | ),
185 | )
186 | def test_flip_masks_threshold_batched(self, flip_mask_fn, expected_mask):
187 | base_predictions = np.array((0.1, 0.1, 0.9, 0.9))
188 | model_predictions = np.array((0.2, 0.9, 0.1, 0.8))
189 |
190 | testing.assert_array_equal(
191 | flip_mask_fn(base_predictions, model_predictions, threshold=0.5),
192 | expected_mask,
193 | )
194 |
195 |
196 | if __name__ == '__main__':
197 | absltest.main()
198 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Samplewise scoring metrics for image."""
15 |
16 | import io
17 |
18 | from ml_metrics._src.tools.telemetry import telemetry
19 | from PIL import Image
20 |
21 |
22 | MEGAPIXELS = 1024 * 1024
23 |
24 |
25 | @telemetry.WithTelemetry(
26 | api='ml_metrics', category='signal', reference='content_metadata'
27 | )
28 | def content_metadata(image_bytes: bytes) -> dict[str, int | float | str]:
29 | """Extracts the content metadata of an image."""
30 |
31 | img = Image.open(io.BytesIO(image_bytes))
32 | width, height = img.size
33 | return { # pytype: disable=bad-return-type # pillow-102-upgrade
34 | 'mode': img.mode,
35 | 'format': img.format,
36 | 'pixel_width': width,
37 | 'pixel_height': height,
38 | 'megapixel_resolution': width * height / MEGAPIXELS,
39 | 'aspect_ratio': width / height if height > 0 else 0.0,
40 | }
41 |
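A minimal usage sketch (the 640x480 PNG is an arbitrary illustrative input):

    import io
    from PIL import Image
    from ml_metrics._src.signals import image

    buf = io.BytesIO()
    Image.new('RGB', (640, 480)).save(buf, format='PNG')
    metadata = image.content_metadata(buf.getvalue())
    print(metadata['pixel_width'], metadata['pixel_height'])  # 640 480
    print(round(metadata['aspect_ratio'], 3))                 # 1.333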
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/image_test.py:
--------------------------------------------------------------------------------
1 | """Tests for image."""
2 |
3 | import io
4 |
5 | from ml_metrics._src.signals import image
6 | from PIL import Image
7 |
8 | from absl.testing import absltest
9 |
10 |
11 | class ImageTest(absltest.TestCase):
12 |
13 | def test_content_metadata(self):
14 | single_black_pixel_image = Image.new('RGB', (1, 1))
15 | with io.BytesIO() as output:
16 | single_black_pixel_image.save(output, format='PNG')
17 | single_black_pixel_image_bytes = output.getvalue()
18 |
19 | expected_metadata = {
20 | 'mode': 'RGB',
21 | 'format': 'PNG',
22 | 'pixel_width': 1,
23 | 'pixel_height': 1,
24 | 'megapixel_resolution': 1 / image.MEGAPIXELS,
25 | 'aspect_ratio': 1.0,
26 | }
27 |
28 | metadata = image.content_metadata(single_black_pixel_image_bytes)
29 | self.assertEqual(metadata, expected_metadata)
30 |
31 |
32 | if __name__ == '__main__':
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Samplewise scoring metrics for text."""
15 |
16 | from collections.abc import Callable, Sequence
17 | import re
18 | from typing import Any
19 |
20 | from ml_metrics._src.tools.telemetry import telemetry
21 |
22 |
23 | def _maybe_tuple(
24 | reference: str | Sequence[str],
25 | ) -> Sequence[str]:
26 | """Converts reference to a sequence if it is a single string."""
27 | if isinstance(reference, str):
28 | return (reference,)
29 | return reference
30 |
31 |
32 | @telemetry.WithTelemetry(
33 | api='ml_metrics', category='signal', reference='alphabetical_char_count'
34 | )
35 | def alphabetical_char_count(text: str) -> int:
36 | """Computes the number of alphabetical characters."""
37 | return len(re.sub(r'[^a-zA-Z]', '', text))
38 |
39 |
40 | @telemetry.WithTelemetry(
41 | api='ml_metrics', category='signal', reference='word_count'
42 | )
43 | def word_count(text: str) -> int:
44 | """Computes the number of words.
45 |
46 | Computes the number of words within the text. Characters that are not letters
47 | or spaces are taken out of the text, leaving only spaces between words.
48 | However, each contraction is counted as only one word. For example, "I'm" is
49 | treated as one word, "Im".
50 |
51 | Args:
52 | text: Input text.
53 |
54 | Returns:
55 | Number of words.
56 | """
57 | return len(_get_words(text))
58 |
59 |
60 | @telemetry.WithTelemetry(
61 | api='ml_metrics', category='signal', reference='token_count'
62 | )
63 | def token_count(text: str, tokenizer: Callable[[str], Sequence[Any]]) -> int:
64 | """Computes the number of tokens."""
65 | return len(tokenizer(text))
66 |
67 |
68 | @telemetry.WithTelemetry(
69 | api='ml_metrics', category='signal', reference='token_match_rate'
70 | )
71 | def token_match_rate(
72 | sample: str, reference: str, tokenizer: Callable[[str], Sequence[Any]]
73 | ) -> float:
74 | """Computes the token match rate between sample and reference."""
75 | sample_tokens = tokenizer(sample)
76 | reference_tokens = tokenizer(reference)
77 | matched = 0
78 | for t1, t2 in zip(sample_tokens, reference_tokens):
79 | if t1 == t2:
80 | matched += 1
81 | length = max(len(sample_tokens), len(reference_tokens))
82 | if length == 0:
83 | return 0
84 | return matched / length
85 |
86 |
87 | @telemetry.WithTelemetry(
88 | api='ml_metrics', category='signal', reference='exact_match'
89 | )
90 | def exact_match(sample: str, reference: str | Sequence[str]) -> bool:
91 | """Computes the exact match between sample and reference."""
92 | references = _maybe_tuple(reference)
93 | return any(sample == ref for ref in references)
94 |
95 |
96 | @telemetry.WithTelemetry(
97 | api='ml_metrics',
98 | category='signal',
99 | reference='sample_startswith_reference_match',
100 | )
101 | def sample_startswith_reference_match(
102 | sample: str, reference: str | Sequence[str]
103 | ) -> bool:
104 | """True when the sample starts with reference."""
105 | references = _maybe_tuple(reference)
106 | return any(sample.startswith(ref) for ref in references)
107 |
108 |
109 | @telemetry.WithTelemetry(
110 | api='ml_metrics',
111 | category='signal',
112 | reference='reference_startswith_sample_match',
113 | )
114 | def reference_startswith_sample_match(
115 | sample: str, reference: str | Sequence[str]
116 | ) -> bool:
117 | """True when the reference starts with sample."""
118 | references = _maybe_tuple(reference)
119 | return any(ref.startswith(sample) for ref in references)
120 |
121 |
122 | @telemetry.WithTelemetry(
123 | api='ml_metrics', category='signal', reference='reference_in_sample_match'
124 | )
125 | def reference_in_sample_match(
126 | sample: str, reference: str | Sequence[str]
127 | ) -> bool:
128 |   """True when any reference is contained in the sample."""
129 | references = _maybe_tuple(reference)
130 | return any(ref in sample for ref in references)
131 |
132 |
133 | @telemetry.WithTelemetry(
134 | api='ml_metrics', category='signal', reference='sample_in_reference_match'
135 | )
136 | def sample_in_reference_match(
137 | sample: str, reference: str | Sequence[str]
138 | ) -> bool:
139 |   """True when the sample is contained in any reference."""
140 | references = _maybe_tuple(reference)
141 | return any(sample in ref for ref in references)
142 |
143 |
144 | @telemetry.WithTelemetry(
145 | api='ml_metrics', category='signal', reference='non_ascii_char_count'
146 | )
147 | def non_ascii_char_count(text: str) -> int:
148 |   """Computes the number of non-ascii characters."""
149 |   return len(re.findall(r'[^\x00-\x7F]', text))
150 |
151 |
152 | @telemetry.WithTelemetry(
153 | api='ml_metrics', category='signal', reference='is_all_whitespace'
154 | )
155 | def is_all_whitespace(text: str) -> bool:
156 | r"""Checks if the text is all whitespace.
157 |
158 | Check if string is empty-ish e.g. consisting of whitespace, \n, \t.
159 |
160 | Args:
161 | text: Input text.
162 |
163 | Returns:
164 | True if the text is all whitespace.
165 | """
166 | return not text.strip()
167 |
168 |
169 | @telemetry.WithTelemetry(
170 | api='ml_metrics', category='signal', reference='average_word_length'
171 | )
172 | def average_word_length(text: str) -> float:
173 | """Computes the average word length."""
174 | words = _get_words(text)
175 | if not words:
176 | return 0.0
177 | return sum(len(word) for word in words) / len(words)
178 |
179 |
180 | def _get_words(text: str) -> list[str]:
181 | """Returns the words in the text."""
182 | return re.sub(r'[^a-zA-Z ]', '', text).split()
183 |
184 |
185 | @telemetry.WithTelemetry(
186 | api='ml_metrics', category='signal', reference='unique_word_count'
187 | )
188 | def unique_word_count(text: str) -> int:
189 | """Computes the number of unique words."""
190 | return len(set(_get_words(text)))
191 |
192 |
193 | @telemetry.WithTelemetry(
194 | api='ml_metrics', category='signal', reference='number_of_characters'
195 | )
196 | def number_of_characters(text: str) -> int:
197 | """Computes the number of characters."""
198 | return len(text)
199 |
200 |
201 | @telemetry.WithTelemetry(
202 | api='ml_metrics', category='signal', reference='percentage_all_caps'
203 | )
204 | def percentage_all_caps(text: str) -> float:
205 | """Computes the percentage of all caps."""
206 | words = _get_words(text)
207 | if not words:
208 | return 0
209 | return len([word for word in words if word.isupper()]) / len(words)
210 |
211 |
212 | @telemetry.WithTelemetry(
213 | api='ml_metrics',
214 | category='signal',
215 | reference='percentage_non_ascii_characters',
216 | )
217 | def percentage_non_ascii_characters(text: str) -> float:
218 |   """Computes the percentage of non-ascii characters."""
219 |   if not number_of_characters(text):
220 |     return 0.0
221 |   return non_ascii_char_count(text) / number_of_characters(text)
222 |
223 |
224 | @telemetry.WithTelemetry(
225 | api='ml_metrics', category='signal', reference='type_token_ratio'
226 | )
227 | def type_token_ratio(text: str) -> float:
228 | """Computes the type token ratio.
229 |
230 | Words with the same letters but different lowercase letters are considered
231 | different.
232 |
233 | Args:
234 | text: Input text.
235 |
236 | Returns:
237 | The ratio of unique words to total words.
238 | """
239 | words = _get_words(text)
240 | if not words:
241 | return 0
242 |
243 | return unique_word_count(text) / len(words)
244 |
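A short usage sketch of a few of the signals above (the sample string is illustrative); words are the alphabetic tokens left after stripping non-letter characters.

    from ml_metrics._src.signals import text

    sample = 'The QUICK brown fox'
    print(text.word_count(sample))           # 4
    print(text.unique_word_count(sample))    # 4
    print(text.percentage_all_caps(sample))  # 0.25
    print(text.exact_match(sample, ('The QUICK brown fox', 'other')))  # True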
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/topk_accuracy.py:
--------------------------------------------------------------------------------
1 | """Topk accuracy metric."""
2 |
3 | from ml_metrics._src.aggregates import types
4 | from ml_metrics._src.tools.telemetry import telemetry
5 | import numpy as np
6 |
7 |
8 | @telemetry.WithTelemetry(
9 | api='ml_metrics', category='signal', reference='topk_accuracy'
10 | )
11 | def topk_accurate(
12 | y_pred: types.NumbersT,
13 | label: int,
14 | weights: types.NumbersT = 1.0,
15 | k: int = 1,
16 | ) -> bool:
17 | """Calculate topk accuracy.
18 |
19 | Args:
20 | y_pred: Prediction scores with shape [num_classes].
21 | label: The ground truth label with values in [0, num_classes).
22 | weights: Weight applied to the prediction scores for computing the top-k
23 | accurate metric. Default is 1.0.
24 | k: The top-k predictions to consider.
25 |
26 | Returns:
27 | True if the label is in the top-k predictions, False otherwise.
28 | """
29 | weighted_pred = np.asarray(y_pred) * np.asarray(weights)
30 | # Get indices of top-k predictions
31 | topk_predictions = np.argsort(weighted_pred)[-k:]
32 | return label in topk_predictions
33 |
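A usage sketch mirroring the test cases below: with scores [0.2, 0.7, 0.1], class 1 is the top-1 prediction and class 0 only enters the top-k at k=2.

    from ml_metrics._src.signals import topk_accuracy

    scores = [0.2, 0.7, 0.1]
    print(topk_accuracy.topk_accurate(scores, label=1, k=1))  # True
    print(topk_accuracy.topk_accurate(scores, label=0, k=1))  # False
    print(topk_accuracy.topk_accurate(scores, label=0, k=2))  # True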
--------------------------------------------------------------------------------
/ml_metrics/_src/signals/topk_accuracy_test.py:
--------------------------------------------------------------------------------
1 | """Unit tests for topk_accuracy.py."""
2 |
3 | from ml_metrics._src.signals import topk_accuracy
4 | from absl.testing import absltest
5 | from absl.testing import parameterized
6 |
7 |
8 | class TopkAccuracyTest(parameterized.TestCase):
9 |
10 | @parameterized.named_parameters(
11 | dict(
12 | testcase_name="top_k_accuracy_label1_k1",
13 | y_pred=[0.2, 0.7, 0.1],
14 | label=1,
15 | k=1,
16 | expected=1.0,
17 | ),
18 | dict(
19 | testcase_name="top_k_accuracy_label0_k1",
20 | y_pred=[0.2, 0.7, 0.1],
21 | label=0,
22 | k=1,
23 | expected=0.0,
24 | ),
25 | dict(
26 | testcase_name="top_k_accuracy_label0_k2",
27 | y_pred=[0.2, 0.7, 0.1],
28 | label=0,
29 | k=2,
30 | expected=1.0,
31 | ),
32 | )
33 | def test_topk_accuracy(self, y_pred, label, k, expected):
34 | self.assertEqual(topk_accuracy.topk_accurate(y_pred, label, k=k), expected)
35 |
36 | @parameterized.named_parameters(
37 | dict(
38 | testcase_name="top_k_accuracy_weight_label1_k1_case0",
39 | y_pred=[0.2, 0.7, 0.1],
40 | label=1,
41 | k=1,
42 | weight=[1.0, 1.0/3.49, 1.0],
43 | expected=1.0,
44 | ),
45 | dict(
46 | testcase_name="top_k_accuracy_weight_label1_k1_case1",
47 | y_pred=[0.2, 0.7, 0.1],
48 | label=1,
49 | k=1,
50 | weight=[1.0, 1.0/3.51, 1.0],
51 | expected=0.0,
52 | ),
53 | dict(
54 | testcase_name="top_k_accuracy_weight_label0_k2",
55 | y_pred=[0.2, 0.7, 0.1],
56 | label=0,
57 | k=2,
58 | weight=[1.0, 1.0/3.51, 1.0],
59 | expected=1.0,
60 | ),
61 | )
62 | def test_topk_accuracy_with_weights(
63 | self, y_pred, label, k, weight, expected
64 | ):
65 | self.assertEqual(
66 | topk_accuracy.topk_accurate(y_pred, label, weight, k), expected
67 | )
68 |
69 |
70 | if __name__ == "__main__":
71 | absltest.main()
72 |
--------------------------------------------------------------------------------
/ml_metrics/_src/tools/telemetry/telemetry.py:
--------------------------------------------------------------------------------
1 | """Telemetry for MLAT."""
2 |
3 | from absl import logging
4 |
5 |
6 | def increment_counter(
7 | api: str, category: str, reference: str, execution_succeed: bool
8 | ):
9 | logging.info(
10 | 'Logging counter: api=%s, category=%s, reference=%s,'
11 | ' execution_succeed=%s',
12 | api,
13 | category,
14 | reference,
15 | execution_succeed,
16 | )
17 |
18 |
19 | class WithTelemetry:
20 | """Decorator to log usage."""
21 |
22 | def __init__(
23 | self,
24 | api: str,
25 | category: str,
26 | reference: str,
27 | *,
28 | target_methods: list[str] | str | None = None,
29 | ):
30 | self.api = api
31 | self.category = category
32 | self.reference = reference
33 | self.target_methods = target_methods
34 |
35 | def __call__(self, class_or_func_ref):
36 | increment_counter(self.api, self.category, self.reference, True)
37 | return class_or_func_ref
38 |
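A sketch of how the decorator is applied across the library (the `demo` function is hypothetical); in this open-source stub the counter is logged once, at decoration time, and the wrapped object is returned unchanged.

    from ml_metrics._src.tools.telemetry import telemetry

    @telemetry.WithTelemetry(api='ml_metrics', category='signal', reference='demo')
    def demo(x):
      return x * 2

    print(demo(3))  # 6 (an INFO counter line was logged when `demo` was defined).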
--------------------------------------------------------------------------------
/ml_metrics/_src/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base types used throughout the library."""
15 |
16 | import abc
17 | from typing import Any, Protocol, TypeGuard, TypeVar, runtime_checkable
18 | from numpy import typing as npt
19 |
20 | _T = TypeVar('_T')
21 |
22 |
23 | @runtime_checkable
24 | class SinkT(Protocol[_T]):
25 |   """A sink that data can be written to."""
26 |
27 | def write(self, data: _T) -> None:
28 | """Writes the data to the sink."""
29 |
30 | def close(self) -> None:
31 | """Closes the sink."""
32 |
33 |
34 | @runtime_checkable
35 | class Makeable(Protocol[_T]):
36 | """A config class that can make a Metric class."""
37 |
38 | @abc.abstractmethod
39 | def make(self) -> _T:
40 | """Makes a new Metric."""
41 |
42 |
43 | @runtime_checkable
44 | class Resolvable(Protocol[_T]):
45 |   """All Resolvables implement a `result_` to resolve the underlying value."""
46 |
47 | @abc.abstractmethod
48 | def result_(self) -> _T:
49 | """Interface to get the result of the underlying value."""
50 |
51 |
52 | class Shardable(Protocol):
53 | """A sharded data source for chainables."""
54 |
55 | @abc.abstractmethod
56 | def shard(self, *args, **kwargs):
57 | """Iterates the data source given a shard index and number of shards."""
58 |
59 |
60 | class Serializable(Protocol):
61 | """An object that can be both serialized and deserialized."""
62 |
63 | @abc.abstractmethod
64 | def get_config(self):
65 | """Gets the state of the object that can be used to recover the object."""
66 |
67 | @abc.abstractmethod
68 | def from_config(self, *args, **kwargs):
69 |     """Recovers the object from a config produced by `get_config`."""
70 |
71 |
72 | class Recoverable(Protocol):
73 |   """An object that can be recovered from its state."""
74 |
75 | state: Any
76 |
77 | @abc.abstractmethod
78 | def from_state(self, *args, **kwargs):
79 | """Recover from the state."""
80 |
81 |
82 | @runtime_checkable
83 | class Stoppable(Protocol):
84 | """An object that can be stopped."""
85 |
86 | def maybe_stop(self):
87 | """Stops the concurrent threads if possible."""
88 |
89 |
90 | MaybeResolvable = Resolvable[_T] | _T
91 |
92 |
93 | class RandomAccessible(Protocol[_T]):
94 |
95 | def __getitem__(self, idx: int | slice) -> _T:
96 | """Same as Sequence.__getitem__."""
97 |
98 | def __len__(self) -> int:
99 | """Same as Sequence.__len__."""
100 |
101 |
102 | def obj_has_method(obj: Any, method_name: str) -> bool:
103 | """Checks if the object has a method."""
104 | method = getattr(obj, method_name, False)
105 | return method and getattr(method, '__self__', None) is obj
106 |
107 |
108 | def is_resolvable(obj: Resolvable[_T] | Any) -> TypeGuard[Resolvable[_T]]:
109 | """Checks if the object is a Resolvable."""
110 | return obj_has_method(obj, 'result_')
111 |
112 |
113 | def is_makeable(obj: Makeable[_T] | Any) -> TypeGuard[Makeable[_T]]:
114 | """Checks if the object is a Makeable."""
115 | return obj_has_method(obj, 'make')
116 |
117 |
118 | def is_shardable(obj: Shardable | Any) -> TypeGuard[Shardable]:
119 | """Checks if the object is a Shardable."""
120 | return obj_has_method(obj, 'shard')
121 |
122 |
123 | def is_serializable(obj: Serializable | Any) -> TypeGuard[Serializable]:
124 |   """Checks if the object is a Serializable."""
125 | return obj_has_method(obj, 'get_config') and obj_has_method(
126 | obj, 'from_config'
127 | )
128 |
129 |
130 | def is_recoverable(obj: Recoverable | Any) -> TypeGuard[Recoverable]:
131 |   """Checks if the object is a Recoverable."""
132 | return obj_has_method(obj, 'from_state') and hasattr(obj, 'state')
133 |
134 |
135 | def is_array_like(obj: list[Any] | tuple[Any, ...] | npt.ArrayLike) -> bool:
136 | """Checks if the object is an array-like object."""
137 | return isinstance(obj, (list, tuple)) or (
138 | hasattr(obj, '__array__') and obj.ndim > 0
139 | )
140 |
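A sketch of the difference between the structural `Protocol` check and the `is_*` helpers, which also require the method to be bound to the instance itself (the `LazyValue` class is hypothetical):

    from ml_metrics._src import types

    class LazyValue:

      def result_(self):
        return 42

    v = LazyValue()
    print(isinstance(v, types.Resolvable))  # True: structural Protocol check.
    print(types.is_resolvable(v))           # True: `result_` is bound to `v`.
    print(types.is_resolvable(LazyValue))   # False: unbound on the class itself.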
--------------------------------------------------------------------------------
/ml_metrics/_src/types_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from absl.testing import absltest
15 | from ml_metrics._src import types
16 |
17 |
18 | class BaseTypesTest(absltest.TestCase):
19 |
20 | def test_is_resolvable(self):
21 |
22 | class Foo:
23 |
24 | @classmethod
25 | def result_(cls):
26 | pass
27 |
28 | self.assertIsInstance(Foo(), types.Resolvable)
29 | self.assertFalse(types.is_resolvable(Foo()))
30 |
31 | def test_is_makeable(self):
32 |
33 | class Foo:
34 |
35 | @classmethod
36 | def make(cls):
37 | pass
38 |
39 | self.assertIsInstance(Foo(), types.Makeable)
40 | self.assertFalse(types.is_makeable(Foo()))
41 |
42 | def test_is_shardable(self):
43 |
44 | class Foo:
45 |
46 | @classmethod
47 | def shard(cls, shard_index):
48 | pass
49 |
50 | self.assertFalse(types.is_shardable(Foo()))
51 |
52 | def test_is_serializable(self):
53 |
54 | class Foo:
55 |
56 | def get_config(self):
57 | pass
58 |
59 | def from_config(self, config):
60 | pass
61 |
62 | self.assertTrue(types.is_serializable(Foo()))
63 |
64 | def test_is_recoverable(self):
65 |
66 | class Foo:
67 |
68 | @property
69 | def state(self):
70 | pass
71 |
72 | def from_state(self, state):
73 | pass
74 |
75 | self.assertTrue(types.is_recoverable(Foo()))
76 |
77 | def test_is_not_recoverable(self):
78 |
79 | class Foo1:
80 |
81 | @property
82 | def state(self):
83 | pass
84 |
85 | self.assertFalse(types.is_recoverable(Foo1()))
86 |
87 | class Foo2:
88 |
89 | def from_state(self, state):
90 | pass
91 |
92 | self.assertFalse(types.is_recoverable(Foo2()))
93 |
94 |
95 | if __name__ == "__main__":
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/df_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for generating DataFrames."""
15 |
16 | from __future__ import annotations
17 |
18 | import collections
19 | from typing import Any
20 |
21 | from ml_metrics._src.chainables import transform
22 | import pandas as pd
23 |
24 |
25 | _METRIC_NAME = 'metric_name'
26 | _SLICE = 'slice'
27 | _VALUE = 'value'
28 |
29 |
30 | def _first_or_tuple(x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
31 | if isinstance(x, tuple) and len(x) == 1:
32 | return x[0]
33 | return x
34 |
35 |
36 | _StrOrMetricKey = transform.MetricKey | str
37 |
38 |
39 | def metrics_to_df(metrics: dict[_StrOrMetricKey, Any]) -> pd.DataFrame:
40 | """Converts the aggregation result to a DataFrame.
41 |
42 | This always converts the dict aggregation result to a DataFrame with
43 | the following columns:
44 |
45 | - metric_name: the name of the metric.
46 | - slice: the slice of the metric, if a slice is not specified, it will be
47 | 'overall'.
48 | - value: the value of the metric.
49 |
50 | Args:
51 | metrics: the aggregation result.
52 |
53 | Returns:
54 | A DataFrame with the above columns.
55 | """
56 | sliced_results = collections.defaultdict(list)
57 | for k, v in metrics.items():
58 | if isinstance(k, str):
59 | sliced_results[_METRIC_NAME].append(k)
60 | sliced_results[_SLICE].append('overall')
61 | sliced_results[_VALUE].append(v)
62 | elif isinstance(k, transform.MetricKey):
63 | sliced_results[_METRIC_NAME].append(k.metrics)
64 | slice_name = _first_or_tuple(k.slice.features)
65 | slice_value = _first_or_tuple(k.slice.values)
66 | sliced_results[_SLICE].append(f'{slice_name}={slice_value}')
67 | sliced_results[_VALUE].append(v)
68 | return pd.DataFrame(sliced_results)
69 |
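A usage sketch with one overall metric and one sliced metric (the metric and slice names are illustrative); it mirrors df_utils_test.py below.

    from ml_metrics._src.chainables import transform
    from ml_metrics._src.chainables import tree_fns
    from ml_metrics._src.utils import df_utils

    key = transform.MetricKey(
        metrics='accuracy', slice=tree_fns.SliceKey(('lang',), ('en',))
    )
    df = df_utils.metrics_to_df({'accuracy': 0.9, key: 0.8})
    # df has columns metric_name/slice/value with rows
    # ('accuracy', 'overall', 0.9) and ('accuracy', 'lang=en', 0.8).
    print(df)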
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/df_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from ml_metrics._src.chainables import transform
17 | from ml_metrics._src.chainables import tree_fns
18 | from ml_metrics._src.utils import df_utils
19 | import pandas as pd
20 |
21 |
22 | class DfUtilsTest(absltest.TestCase):
23 |
24 | def test_as_dataframe_default(self):
25 | key = transform.MetricKey(
26 | metrics='m2', slice=tree_fns.SliceKey(('f1',), ('a',))
27 | )
28 | agg_result = {'m1': 1, key: 2}
29 | df = df_utils.metrics_to_df(agg_result)
30 | pd.testing.assert_frame_equal(
31 | df,
32 | pd.DataFrame({
33 | 'metric_name': ['m1', 'm2'],
34 | 'slice': ['overall', 'f1=a'],
35 | 'value': [1, 2],
36 | }),
37 | )
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/func_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Internal function utils, not meant to be used by users."""
15 |
16 | import collections
17 | from collections.abc import Iterable, Iterator, Mapping
18 | from concurrent import futures
19 | import copy
20 | import dataclasses as dc
21 | import functools
22 | import itertools as itt
23 | from typing import TypeVar
24 | import weakref
25 |
26 | from absl import logging
27 | import more_itertools as mit
28 |
29 |
30 | _KeyT = TypeVar('_KeyT')
31 | _ValueT = TypeVar('_ValueT')
32 | _T = TypeVar('_T')
33 |
34 |
35 | @dc.dataclass(slots=True, frozen=True)
36 | class _CacheInfo:
37 | hits: int
38 | misses: int
39 | currsize: int
40 | maxsize: int = 0
41 |
42 |
43 | class LruCache(Mapping[_KeyT, _ValueT]):
44 | """A mapping like object for caching with limited size."""
45 |
46 | def __init__(self, maxsize=128):
47 | self.maxsize = maxsize
48 | self.currsize = 0
49 | self.hits = 0
50 | self.misses = 0
51 | self.data = collections.OrderedDict()
52 |
53 | def __getitem__(self, key):
54 | if key not in self.data:
55 | self.misses += 1
56 | raise KeyError()
57 | self.hits += 1
58 | value = self.data[key]
59 | self.data.move_to_end(key)
60 | return value
61 |
62 | def __setitem__(self, key, value):
63 | key_is_new = key not in self.data
64 | self.data[key] = value
65 | if key_is_new:
66 | self.currsize += 1
67 | self.data.move_to_end(key)
68 | if self.currsize > self.maxsize:
69 | oldest = next(iter(self.data))
70 | del self.data[oldest]
71 | self.currsize -= 1
72 |
73 | def cache_insert(self, key, value):
74 | self.__setitem__(key, value)
75 |
76 | def __contains__(self, key):
77 | return key in self.data
78 |
79 | def __iter__(self) -> Iterator[_KeyT]:
80 | return iter(self.data)
81 |
82 | def __len__(self) -> int:
83 | return self.currsize
84 |
85 | def cache_clear(self):
86 | self.data.clear()
87 | self.currsize = 0
88 | self.hits = 0
89 | self.misses = 0
90 |
91 | def cache_info(self) -> _CacheInfo:
92 | return _CacheInfo(
93 | hits=self.hits,
94 | misses=self.misses,
95 | maxsize=self.maxsize,
96 | currsize=self.currsize,
97 | )
98 |
99 |
100 | def lru_cache(
101 | fn=None,
102 | *,
103 | settable_kwargs: Iterable[str] = (),
104 | maxsize: int = 128,
105 | ):
106 |   """Caches by positional args and keyword args not in `settable_kwargs`."""
107 |
108 | settable_kwargs = set(settable_kwargs)
109 |
110 | def decorator(fn):
111 | cache_ = LruCache(maxsize=maxsize)
112 |
113 | @functools.wraps(fn)
114 | def wrapped(*args, cache_insert_: bool = False, **kwargs):
115 | is_settable = lambda x: x[0] in settable_kwargs
116 | hashed, settables = mit.partition(is_settable, kwargs.items())
117 | key = hash(tuple(itt.chain(args, hashed)))
118 | if not cache_insert_ and key in cache_:
119 | result = cache_[key]
120 | else:
121 | logging.debug('chainable: %s', f'cache miss {fn}: {args}, {kwargs}')
122 | result = fn(*args, **kwargs)
123 | cache_[key] = result
124 | result_new = None
125 | for k, v in settables:
126 | if v != getattr(result, k):
127 | if result_new is None:
128 | result_new = copy.copy(result)
129 | setattr(result_new, k, v)
130 | return result_new if result_new is not None else result
131 |
132 | wrapped.cache_info = cache_.cache_info
133 | wrapped.cache_clear = cache_.cache_clear
134 | return wrapped
135 |
136 | return decorator if fn is None else decorator(fn)
137 |
138 |
139 | class SingletonMeta(type):
140 | """A metaclass that makes a class a singleton of any "equivalent" instance.
141 |
142 | The actual class instance has to be hashable to test equivalence. This also
143 | works with inherited classes.
144 | Example:
145 | ```
146 | @dc.dataclass(frozen=True)
147 | class Foo(metaclass=SingletonMeta):
148 | a: int
149 | b: str = 'b'
150 |
151 | # The following should be true.
152 | assert Foo(1) is Foo(1)
153 | assert Foo(1, 'b') is Foo(1)
154 | assert Foo(1, b='b') is not Foo(1)
155 | assert Foo(2) is not Foo(1)
156 | assert Foo(1, 'a') is not Foo(1)
157 | ```
158 | """
159 |
160 | _instances = weakref.WeakKeyDictionary()
161 |
162 | def __call__(cls: type[_T], *args, **kwargs) -> _T:
163 | obj = super(SingletonMeta, cls).__call__(*args, **kwargs)
164 | if (ref := cls._instances.get(obj, None)) and (result := ref()) is not None:
165 | return result
166 | logging.info('chainable: %s', f'singleton {cls.__name__}, {obj}')
167 | cls._instances[obj] = weakref.ref(obj)
168 | return obj
169 |
170 | @property
171 | def all_instances(cls):
172 | return [obj for ref in cls._instances.values() if (obj := ref())]
173 |
174 |
175 | class SingletonThreadPool(futures.ThreadPoolExecutor, metaclass=SingletonMeta):
176 | """A singleton thread pool executor."""
177 |
178 | def __hash__(self):
179 | return hash((self._max_workers, self._thread_name_prefix))
180 |
181 | def __eq__(self, other):
182 | return (
183 | isinstance(other, SingletonThreadPool)
184 | and self._max_workers == other._max_workers
185 | and self._thread_name_prefix == other._thread_name_prefix
186 | )
187 |
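A sketch of `lru_cache` with `settable_kwargs` (the `Config` dataclass is hypothetical): `b` is excluded from the cache key, so the second call is a hit and `b` is applied to a copy of the cached result instead of recomputing it.

    import dataclasses as dc
    from ml_metrics._src.utils import func_utils

    @dc.dataclass
    class Config:
      a: int
      b: int = 0

    make_config = func_utils.lru_cache(settable_kwargs=['b'])(Config)
    first = make_config(1, b=5)   # Miss: Config(a=1, b=5) is computed and cached.
    second = make_config(1, b=7)  # Hit: a copy of the cached value with b=7.
    print(first.b, second.b)              # 5 7
    print(make_config.cache_info().hits)  # 1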
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/func_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import dataclasses as dc
15 | from absl.testing import absltest
16 | from ml_metrics._src.utils import func_utils
17 |
18 |
19 | @dc.dataclass()
20 | class Foo:
21 | a: int
22 | _b: int = 0
23 | c: int = 0
24 |
25 | @property
26 | def b(self):
27 | return self._b
28 |
29 | @b.setter
30 | def b(self, value):
31 | self._b = value
32 |
33 | def __eq__(self, other):
34 | return self.a == other.a and self.b == other.b and self.c == other.c
35 |
36 |
37 | class CacheByKwargsTest(absltest.TestCase):
38 |
39 | def test_without_kwargs(self):
40 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo)
41 | self.assertEqual(Foo(1, 0, 0), foo_cached(1))
42 | self.assertEqual(Foo(1, 100, 100), foo_cached(1, b=100, c=100))
43 | self.assertEqual(foo_cached.cache_info().hits, 1)
44 |
45 | def test_ignore_kwargs(self):
46 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', '_b', 'c'])(Foo)
47 | self.assertEqual(Foo(1, 10, 0), foo_cached(1, _b=10))
48 | self.assertEqual(Foo(1, 100, 100), foo_cached(1, b=100, c=100))
49 | self.assertEqual(Foo(1, 10, 0), foo_cached(1, b=10))
50 |
51 | def test_cache_partial_kwargs(self):
52 | foo_cached = func_utils.lru_cache(settable_kwargs=['b'])(Foo)
53 | self.assertEqual(Foo(1, 0, 10), foo_cached(1, c=10))
54 | self.assertEqual(Foo(1, 100, 10), foo_cached(1, b=100, c=10))
55 | self.assertEqual(Foo(1, 10, 10), foo_cached(1, b=10, c=10))
56 |
57 | def test_cache_info(self):
58 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo)
59 | foo_cached(1, c=10)
60 | foo_cached(1, b=100, c=100)
61 | foo_cached(1, b=10, c=100)
62 | self.assertEqual(foo_cached.cache_info().hits, 2)
63 | self.assertEqual(foo_cached.cache_info().misses, 0)
64 | self.assertEqual(foo_cached.cache_info().currsize, 1)
65 |
66 | def test_cache_insert(self):
67 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo)
68 | foo_cached(1)
69 | self.assertEqual(Foo(1, 0, 1), foo_cached(1, c=1))
70 | foo_cached(1, c=100, cache_insert_=True)
71 | self.assertEqual(Foo(1, 0, 100), foo_cached(1))
72 |
73 | def test_attribute_error_raises(self):
74 |
75 | def foo(a, b=1):
76 | return (a, b)
77 |
78 | foo_cached = func_utils.lru_cache(settable_kwargs=['b'])(foo)
79 | with self.assertRaises(AttributeError):
80 |       # b is not an attr in the result of foo, thus this kind of caching
81 |       # mechanism (setting the uncached attr afterwards) cannot be used.
82 | foo_cached(1, b=10)
83 |
84 | def test_cache_clear(self):
85 | foo_cached = func_utils.lru_cache(settable_kwargs=['b', 'c'])(Foo)
86 | foo_cached(1)
87 | foo_cached(1, b=10, c=100)
88 | foo_cached.cache_clear()
89 | self.assertEqual(foo_cached.cache_info().hits, 0)
90 | self.assertEqual(foo_cached.cache_info().misses, 0)
91 | self.assertEqual(foo_cached.cache_info().currsize, 0)
92 |
93 |
94 | class LruCacheTest(absltest.TestCase):
95 |
96 | def test_insert_and_get_and_clear(self):
97 | cache = func_utils.LruCache(maxsize=128)
98 | cache['a'] = 1
99 | self.assertEqual(cache['a'], 1)
100 | self.assertLen(cache, 1)
101 | self.assertEqual(1, cache.cache_info().currsize)
102 | self.assertEqual(1, cache.cache_info().hits)
103 | self.assertEqual(128, cache.cache_info().maxsize)
104 | cache.cache_clear()
105 | self.assertEmpty(cache)
106 |
107 | def test_cache_evict(self):
108 | cache = func_utils.LruCache(maxsize=2)
109 | cache['a'] = 1
110 | cache['b'] = 2
111 | # This pushes 'a' to the latest.
112 | self.assertEqual(1, cache['a'])
113 | # This causes 'b' to be evicted.
114 | cache['c'] = 3
115 | self.assertEqual(['a', 'c'], list(cache))
116 |
117 | def test_missing_key(self):
118 | cache = func_utils.LruCache(maxsize=128)
119 | with self.assertRaises(KeyError):
120 | _ = cache['a']
121 |
122 | def test_cache_insert(self):
123 | cache = func_utils.LruCache(maxsize=128)
124 | cache.cache_insert('a', 1)
125 | cache.cache_insert('a', 2)
126 | self.assertEqual(cache['a'], 2)
127 | self.assertLen(cache, 1)
128 |
129 | def test_cache_iter(self):
130 | cache = func_utils.LruCache(maxsize=128)
131 | cache.cache_insert('a', 1)
132 | cache.cache_insert('b', 2)
133 | cache.cache_insert('a', 3)
134 | # The order of the reference here will push 'a' to later than b
135 | self.assertEqual(cache['b'], 2)
136 | self.assertEqual(cache['a'], 3)
137 | self.assertEqual(['b', 'a'], list(cache))
138 |
139 |
140 | @dc.dataclass(frozen=True, eq=True)
141 | class SingletonA(metaclass=func_utils.SingletonMeta):
142 | a: int
143 |
144 | def __eq__(self, other):
145 | return isinstance(other, SingletonA) and self.a == other.a
146 |
147 |
148 | @dc.dataclass(frozen=True, eq=True)
149 | class SingletonB(SingletonA):
150 | # Essentially the same class that will be hashed the same.
151 | b: str = 'b'
152 |
153 | def __eq__(self, other):
154 | return (
155 | super().__eq__(other)
156 | and isinstance(other, SingletonB)
157 | and self.b == other.b
158 | )
159 |
160 |
161 | class SingletonMetaTest(absltest.TestCase):
162 |
163 | def test_singleton(self):
164 | self.assertIs(SingletonA(1), SingletonA(1))
165 | self.assertIsNot(SingletonA(2), SingletonA(1))
166 | self.assertIsNot(SingletonA(1), SingletonB(1))
167 | self.assertIs(SingletonB(1, 'b'), SingletonB(1))
168 | self.assertIs(SingletonB(1, b='b'), SingletonB(1))
169 | self.assertIsNot(SingletonB(1, b='b1'), SingletonB(1))
170 |
171 | def test_singleton_all_instances(self):
172 | a = SingletonA(1)
173 | b = SingletonB(1)
174 | self.assertSameElements(SingletonB.all_instances, SingletonA.all_instances)
175 | instances = SingletonB.all_instances
176 | self.assertLen(instances, 2)
177 | self.assertIn(a, instances)
178 | self.assertIn(b, instances)
179 |
180 |
181 | if __name__ == '__main__':
182 | absltest.main()
183 |
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/math_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Math Utils."""
15 |
16 | from ml_metrics._src import types
17 | from ml_metrics._src.aggregates import types as agg_types
18 | import numpy as np
19 |
20 |
21 | def pos_sqrt(value) -> agg_types.NumbersT:
22 | """Returns sqrt of value or raises ValueError if negative."""
23 | if np.any(value < 0):
24 | raise ValueError('Attempt to take sqrt of negative value: {}'.format(value))
25 | return np.sqrt(value)
26 |
27 |
28 | def safe_divide(a, b) -> agg_types.NumbersT:
29 | """Divide arguments element-wise (a / b), but returns zero(s) if b is 0."""
30 | result = np.divide(
31 | a, b, out=np.zeros_like(a, dtype=agg_types.DefaultDType), where=(b != 0)
32 | )
33 |
34 | return result.item() if result.ndim == 0 else result
35 |
36 |
37 | def safe_to_scalar(arr: agg_types.NumbersT) -> agg_types.NumbersT:
38 | """Returns tuple, list, or np.ndarray as a scalar. Returns 0.0 if empty.
39 |
40 | Originally from tensorflow_model_analysis/metrics/metric_util.py
41 |
42 | Args:
43 | arr: A one element tuple, list, or numpy.ndarray to be converted to a
44 | scalar.
45 |
46 | Returns:
47 | The Python scalar.
48 | """
49 | if isinstance(arr, np.ndarray):
50 | if arr.size == 0:
51 | # 0 elements.
52 | return 0.0
53 | if arr.size == 1:
54 | # 1 element.
55 | return arr.item()
56 | else:
57 | # arr is tuple or list.
58 | if not arr:
59 | # 0 elements.
60 | return 0.0
61 | if len(arr) == 1:
62 | # 1 element.
63 | return arr[0]
64 |
65 | # >1 element.
66 | raise ValueError('Array should have exactly 1 value to a Python scalar')
67 |
68 |
69 | def where(
70 | condition: agg_types.NumbersT, x: agg_types.NumbersT, y: agg_types.NumbersT
71 | ):
72 | """Like np.where, but directly returns itself if condition is scalar."""
73 | if types.is_array_like(condition):
74 | return np.where(condition, x, y)
75 | # If x or y is scalar, return it directly to keep its type (vs. np.where
76 | # that wraps it in a one-element array).
77 | return x if condition else y
78 |
79 |
80 | def nanadd(a: agg_types.NumbersT, b: agg_types.NumbersT) -> agg_types.NumbersT:
81 | """Returns element-wise a + b, but ignores NaN as 0 unless both are NaN."""
82 | a_nan, b_nan = np.isnan(a), np.isnan(b)
83 | result = where(a_nan, 0, a) + where(b_nan, 0, b)
84 |   # Recovers the NaN values if both operands are NaN.
85 | return where(a_nan & b_nan, np.nan, result)
86 |
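A quick sketch of how these helpers behave in practice; the values below mirror the unit tests in `math_utils_test.py` that follow:

```
import numpy as np

from ml_metrics._src.utils import math_utils

# safe_divide returns zero(s) where the denominator is zero instead of raising.
assert math_utils.safe_divide(10, 0) == 0.0
assert list(math_utils.safe_divide([2, 4], [4, 0])) == [0.5, 0.0]

# safe_to_scalar unwraps one-element containers and maps empty ones to 0.0.
assert math_utils.safe_to_scalar([1.5]) == 1.5
assert math_utils.safe_to_scalar(()) == 0.0

# where returns x or y directly (keeping its type) when the condition is scalar.
assert math_utils.where(True, 1, [0, 1]) == 1

# nanadd treats NaN as 0 unless both operands are NaN.
np.testing.assert_allclose(math_utils.nanadd([1, 0], [1, np.nan]), [2, 0])
```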
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/math_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Math Utils."""
15 |
16 | from collections.abc import Sequence
17 |
18 | from ml_metrics._src.utils import math_utils
19 | import numpy as np
20 |
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 |
24 |
25 | class MathUtilsTest(parameterized.TestCase):
26 |
27 | def test_pos_sqrt(self):
28 | self.assertEqual(math_utils.pos_sqrt(4), 2.0)
29 |
30 | def test_pos_sqrt_raises_value_error(self):
31 | with self.assertRaisesRegex(
32 | ValueError, 'Attempt to take sqrt of negative value: -1'
33 | ):
34 | math_utils.pos_sqrt(-1)
35 |
36 | @parameterized.named_parameters(
37 | ('both_zero', 0, 0, 0.0),
38 | ('zero_num', 0, 10, 0.0),
39 | ('zero_denom', 10, 0, 0.0),
40 | ('float_num', 10.5, 3, 3.5),
41 | ('float_denom', 14, 3.5, 4.0),
42 | ('array_num', [2, 4, 6, 8], 2, [1.0, 2.0, 3.0, 4.0]),
43 | ('array_num_denom', [2, 4, 6, 8], [4, 8, 12, 16], [0.5, 0.5, 0.5, 0.5]),
44 | )
45 | def test_safe_divide(self, a, b, expected_result):
46 | result = math_utils.safe_divide(a, b)
47 | if isinstance(result, (Sequence, np.ndarray)):
48 | self.assertSequenceAlmostEqual(result, expected_result)
49 | else:
50 | self.assertAlmostEqual(result, expected_result)
51 |
52 |   # Original safe_to_scalar tests from:
53 | # tensorflow_model_analysis/metrics/metric_util_test.py
54 | @parameterized.named_parameters(
55 | dict(testcase_name='tuple_int', arr=(1,), expected_scalar=1),
56 | dict(testcase_name='tuple_float', arr=(1.0,), expected_scalar=1.0),
57 | dict(
58 | testcase_name='tuple_string',
59 | arr=('string',),
60 | expected_scalar='string',
61 | ),
62 | dict(testcase_name='list_int', arr=[1], expected_scalar=1),
63 | dict(testcase_name='list_float', arr=[1.0], expected_scalar=1.0),
64 | dict(
65 | testcase_name='list_string', arr=['string'], expected_scalar='string'
66 | ),
67 | dict(testcase_name='np_array_int', arr=np.array(1), expected_scalar=1),
68 | dict(
69 | testcase_name='np_array_float', arr=np.array(1.0), expected_scalar=1.0
70 | ),
71 | dict(
72 | testcase_name='np_array_string',
73 | arr=np.array('string'),
74 | expected_scalar='string',
75 | ),
76 | )
77 | def test_safe_to_scalar_unpack_1d_iterable(self, arr, expected_scalar):
78 | self.assertEqual(math_utils.safe_to_scalar(arr), expected_scalar)
79 |
80 | @parameterized.named_parameters(
81 | dict(testcase_name='empty_tuple', arr=()),
82 | dict(testcase_name='empty_list', arr=[]),
83 | dict(testcase_name='empty_np_array', arr=np.array([])),
84 | )
85 | def test_safe_to_scalar_empty_iterable(self, arr):
86 | self.assertEqual(math_utils.safe_to_scalar(arr), 0.0)
87 |
88 | @parameterized.named_parameters(
89 | dict(testcase_name='2_elem_tuple', arr=(1, 2)),
90 | dict(testcase_name='2_elem_list', arr=[1, 2]),
91 | dict(testcase_name='2_elem_np_array', arr=np.array((1, 2))),
92 | )
93 | def test_safe_to_scalar_multi_elem_iterable_raises_value_error(self, arr):
94 | with self.assertRaisesRegex(
95 | ValueError, 'Array should have exactly 1 value to a Python scalar'
96 | ):
97 | math_utils.safe_to_scalar(arr)
98 |
99 | @parameterized.named_parameters(
100 | dict(
101 | testcase_name='partial_nans',
102 | a=[1, 0],
103 | b=[1, np.nan],
104 | expected=[2, 0],
105 | ),
106 | dict(
107 | testcase_name='both_nans',
108 | a=[1, np.nan],
109 | b=[1, np.nan],
110 | expected=[2, np.nan],
111 | ),
112 | dict(
113 | testcase_name='constant_broadcast',
114 | a=[np.nan, 0],
115 | b=3,
116 | expected=[3, 3],
117 | ),
118 | dict(
119 | testcase_name='nan_broadcast',
120 | a=[1, 0],
121 | b=np.nan,
122 | expected=[1, 0],
123 | ),
124 | dict(
125 | testcase_name='scalar',
126 | a=1,
127 | b=np.nan,
128 | expected=1,
129 | ),
130 | dict(
131 | testcase_name='scalar_nans',
132 | a=np.nan,
133 | b=np.nan,
134 | expected=np.nan,
135 | ),
136 | )
137 | def test_nanadd(self, a, b, expected):
138 | actual = math_utils.nanadd(a, b)
139 | np.testing.assert_allclose(actual, expected)
140 |
141 | def test_where(self):
142 | self.assertEqual(1, math_utils.where(True, 1, 0))
143 | self.assertEqual([0, 1], math_utils.where(False, 1, [0, 1]))
144 | actual = math_utils.where([True, False], [1, 1], [0, 0])
145 | np.testing.assert_array_equal([1, 0], actual)
146 |
147 |
148 | if __name__ == '__main__':
149 | absltest.main()
150 |
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/proto_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Proto utils."""
15 | from collections.abc import Iterable
16 | from typing import Any
17 | from absl import logging
18 | import more_itertools as mit
19 | import numpy as np
20 | from tensorflow.core.example import example_pb2
21 |
22 | _ExampleOrBytes = bytes | example_pb2.Example
23 |
24 |
25 | def _maybe_deserialize(ex: _ExampleOrBytes) -> example_pb2.Example:
26 | if isinstance(ex, bytes):
27 | return example_pb2.Example.FromString(ex)
28 | if isinstance(ex, example_pb2.Example):
29 | return ex
30 | raise TypeError('Unsupported type: %s' % type(ex))
31 |
32 |
33 | def tf_examples_to_dict(examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes):
34 | """Parses a serialized tf.train.Example to a dict."""
35 | single_example = False
36 | if isinstance(examples, (bytes, example_pb2.Example)):
37 | single_example = True
38 | examples = [examples]
39 | examples = (_maybe_deserialize(ex) for ex in examples)
40 | examples = mit.peekable(examples)
41 | if (head := examples.peek(None)) is None:
42 | return {}
43 |
44 | result = {k: [] for k in head.features.feature}
45 | for ex in examples:
46 | missing = set(result)
47 | for key, feature in ex.features.feature.items():
48 | missing.remove(key)
49 | value = getattr(feature, feature.WhichOneof('kind')).value
50 | if value and isinstance(value[0], bytes):
51 | try:
52 | value = [v.decode() for v in value]
53 | except UnicodeDecodeError:
54 | logging.info(
55 | 'chainable: %s',
56 |             f'Failed to decode value for {key}, forwarding the raw bytes.',
57 | )
58 | result[key].extend(value)
59 | if missing:
60 | raise ValueError(
61 | f'Missing keys: {missing}, expecting {set(result)}, got {ex=}'
62 | )
63 | result = {k: v for k, v in result.items()}
64 | # Scalar value in a single example will be returned with the scalar directly.
65 | if single_example and all(len(v) == 1 for v in result.values()):
66 | result = {k: v[0] for k, v in result.items()}
67 | return result
68 |
69 |
70 | def dict_to_tf_example(data: dict[str, Any]) -> example_pb2.Example:
71 | """Creates a tf.Example from a dictionary."""
72 | example = example_pb2.Example()
73 | for key, value in data.items():
74 | if isinstance(value, (str, bytes, np.floating, float, int, np.integer)):
75 | value = [value]
76 | feature = example.features.feature
77 | if isinstance(value[0], str):
78 | for v in value:
79 | assert isinstance(v, str), f'bad str type: {value}'
80 | feature[key].bytes_list.value.append(v.encode())
81 | elif isinstance(value[0], bytes):
82 | feature[key].bytes_list.value.extend(value)
83 | elif isinstance(value[0], (int, np.integer)):
84 | feature[key].int64_list.value.extend(value)
85 | elif isinstance(value[0], (float, np.floating)):
86 | feature[key].float_list.value.extend(value)
87 | else:
88 | raise TypeError(f'Value for "{key}" is not a supported type.')
89 | return example
90 |
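A round-trip sketch of the two helpers above, mirroring the tests in `proto_utils_test.py` (the feature names and values here are arbitrary examples):

```
from ml_metrics._src.utils import proto_utils

# Encode a dict as a tf.train.Example; scalar values are wrapped into the
# repeated feature fields and unwrapped again on decode.
data = {'str_key': 'hello', 'int_key': 3, 'float_key': 0.5}
example = proto_utils.dict_to_tf_example(data)
decoded = proto_utils.tf_examples_to_dict(example.SerializeToString())
assert decoded == {'str_key': 'hello', 'int_key': 3, 'float_key': 0.5}

# A list of examples is decoded into one batched dict keyed by feature name.
batch = proto_utils.tf_examples_to_dict([example, example])
assert batch['int_key'] == [3, 3]
```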
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/proto_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from ml_metrics._src.utils import proto_utils
15 | from ml_metrics._src.utils import test_utils
16 | import numpy as np
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from tensorflow.core.example import example_pb2
20 |
21 |
22 | def _get_tf_example(**kwargs):
23 | example = example_pb2.Example()
24 | for k, v in kwargs.items():
25 | example.features.feature[k].bytes_list.value.append(v)
26 | return example
27 |
28 |
29 | class TFExampleTest(parameterized.TestCase):
30 |
31 | def test_single_example(self):
32 | data = {
33 | 'bytes_key': b'\x80abc', # not utf-8 decodable
34 | 'str_key': 'str_test',
35 |         'int_key': 123,
36 | 'np_int': np.int32(123),
37 | 'float_key': 4.56,
38 | 'np_float': np.float32(123),
39 | }
40 | e = proto_utils.dict_to_tf_example(data).SerializeToString()
41 | actual = proto_utils.tf_examples_to_dict(e)
42 | self.assertDictAlmostEqual(data, actual, places=6)
43 |
44 | def test_batch_example(self):
45 | data = {
46 | 'bytes_key': [b'\x80abc', b'\x80def'], # not utf-8 decodable
47 | 'str_key': ['str_test', 'str_test2'],
48 |         'int_key': [123, 456],
49 | 'np_int': [np.int32(123), np.int32(456)],
50 | 'float_key': [4.56, 7.89],
51 | 'np_float': [np.float32(123), np.float32(456)],
52 | }
53 | e = proto_utils.dict_to_tf_example(data)
54 | actual = proto_utils.tf_examples_to_dict(e)
55 | test_utils.assert_nested_container_equal(self, data, actual, places=6)
56 |
57 | @parameterized.named_parameters(
58 | dict(
59 | testcase_name='with_single_example',
60 | num_elems=1,
61 | ),
62 | dict(
63 | testcase_name='multiple_examples',
64 | num_elems=3,
65 | ),
66 | )
67 | def test_multiple_examples_as_batch(self, num_elems):
68 | data = {
69 | 'bytes_key': b'\x80abc', # not utf-8 decodable
70 | 'str_key': 'str_test',
71 |         'int_key': 123,
72 | 'np_int': np.int32(123),
73 | 'float_key': 4.56,
74 | 'np_float': np.float32(123),
75 | }
76 | e = [proto_utils.dict_to_tf_example(data) for _ in range(num_elems)]
77 | actual = proto_utils.tf_examples_to_dict(e)
78 | expected = {k: [v] * num_elems for k, v in data.items()}
79 | test_utils.assert_nested_container_equal(self, expected, actual, places=6)
80 |
81 | def test_empty_example(self):
82 | self.assertEmpty(proto_utils.tf_examples_to_dict([]))
83 |
84 | def test_unsupported_type(self):
85 | with self.assertRaisesRegex(TypeError, 'Unsupported type'):
86 | proto_utils.tf_examples_to_dict('unsupported_type')
87 |
88 | def test_unsupported_value_type(self):
89 | with self.assertRaisesRegex(
90 | TypeError, 'Value for "a" is not a supported type'
91 | ):
92 | proto_utils.dict_to_tf_example({'a': [example_pb2.Example()]})
93 |
94 | def test_multiple_examples_missing_key(self):
95 | data = [{'a': 'a', 'b': 1}, {'b': 2}]
96 | examples = [proto_utils.dict_to_tf_example(d) for d in data]
97 | with self.assertRaisesRegex(ValueError, 'Missing keys'):
98 | _ = proto_utils.tf_examples_to_dict(examples)
99 |
100 |
101 | if __name__ == '__main__':
102 | absltest.main()
103 |
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for testing, internal use only."""
15 |
16 | from collections.abc import Iterable
17 | import time
18 | import unittest
19 |
20 | from ml_metrics._src.aggregates import rolling_stats
21 | from ml_metrics._src.chainables import transform
22 | import numpy as np
23 |
24 |
25 | class TestSink:
26 |
27 | def __init__(self):
28 | self.data = []
29 | self.closed = False
30 |
31 | def write(self, data):
32 | self.data.append(data)
33 | return data
34 |
35 | def close(self):
36 | self.closed = True
37 |
38 |
39 | class Unpickleable:
40 | def __call__(self, x):
41 | return x + 1
42 |
43 | def __getstate__(self):
44 | raise TypeError('Unpickleable')
45 |
46 |
47 | def range_with_exc(x, exc_i):
48 | for i in range(x):
49 | if i == exc_i:
50 | raise ValueError(f'range_with_exc at {i}')
51 | yield i
52 |
53 |
54 | def range_with_sleep(n, sleep: float = 0):
55 | for i in range(n):
56 | if sleep:
57 | time.sleep(sleep)
58 | yield i
59 |
60 |
61 | def range_with_return(n, return_value: int = 0):
62 | return_value = return_value or n
63 | for i in range(n):
64 | yield i
65 | return return_value
66 |
67 |
68 | def inf_range(n):
69 | yield from range(n)
70 | raise ValueError(
71 |       'The runner does not iterate and copy everything into memory, so it'
72 |       ' should never be exhausted.'
73 | )
74 |
75 |
76 | class NoLenIter(Iterable):
77 |
78 | def __init__(self, iterable):
79 |     self._iterable = iterable
80 |
81 | def __len__(self):
82 | raise ValueError('Cannot call len()')
83 |
84 | def __iter__(self):
85 |     return iter(self._iterable)
86 |
87 |
88 | class SequenceWithExc:
89 | """A range that raises an exception at a specific value."""
90 |
91 | def __init__(self, end: int, exc_i: int, error_type=ValueError):
92 | self.end = end
93 | self.exc_i = exc_i
94 | self.error_type = error_type
95 |
96 | def __len__(self):
97 | return self.end
98 |
99 | def __getitem__(self, i):
100 | if i == self.exc_i:
101 | raise self.error_type(f'SequenceWithExc at {i}')
102 | return i
103 |
104 |
105 | def assert_nested_container_equal(
106 | test: unittest.TestCase, a, b, *, strict=False, places=7
107 | ):
108 | """Asserts that two nested containers are equal."""
109 | try:
110 | if strict:
111 | test.assertEqual(type(a), type(b))
112 | if isinstance(a, dict) and isinstance(b, dict):
113 | for (k_a, v_a), (k_b, v_b) in zip(
114 | sorted(a.items()), sorted(b.items()), strict=True
115 | ):
116 | test.assertEqual(k_a, k_b)
117 | assert_nested_container_equal(
118 | test, v_a, v_b, strict=strict, places=places
119 | )
120 | elif isinstance(a, str) and isinstance(b, str):
121 | test.assertEqual(a, b)
122 | elif hasattr(a, '__array__') and hasattr(b, '__array__'):
123 | np.testing.assert_allclose(a, b)
124 | elif isinstance(a, Iterable) and isinstance(b, Iterable):
125 | for a_elem, b_elem in zip(a, b, strict=True):
126 | assert_nested_container_equal(
127 | test, a_elem, b_elem, strict=strict, places=places
128 | )
129 | else:
130 | test.assertAlmostEqual(a, b, places=places)
131 | except Exception as e: # pylint: disable=broad-except
132 | test.fail(f'Failed to compare {a} and {b}: {e}')
133 |
134 |
135 | def sharded_ones(
136 | total_numbers: int,
137 | batch_size: int,
138 | shard_index: int = 0,
139 | num_shards: int = 1,
140 | ):
141 | num_batches, remainder = divmod(total_numbers, batch_size)
142 | for i in range(num_batches):
143 | if i % num_shards == shard_index:
144 | yield batch_size
145 | if not shard_index and remainder:
146 | yield remainder
147 |
148 |
149 | def sharded_pipeline(
150 | total_numbers: int,
151 | batch_size: int,
152 | shard_index: int = 0,
153 | num_shards: int = 1,
154 | fuse_aggregate: bool = True,
155 | num_threads: int = 0,
156 | ):
157 | """A pipeline to calculate the stats of batches of random integers."""
158 | data_pipeline = transform.TreeTransform.new(name='datasource').data_source(
159 | sharded_ones(
160 | total_numbers,
161 | batch_size=batch_size,
162 | shard_index=shard_index,
163 | num_shards=num_shards,
164 | )
165 | )
166 | apply_pipeline = transform.TreeTransform.new(
167 | name='apply', num_threads=num_threads
168 | ).apply(
169 | fn=lambda batch_size: np.random.randint(1, 100, size=batch_size),
170 | )
171 |
172 | if fuse_aggregate:
173 | return data_pipeline.chain(
174 | apply_pipeline.aggregate(
175 | output_keys='stats',
176 | fn=rolling_stats.MeanAndVariance().as_agg_fn(),
177 | )
178 | )
179 | return data_pipeline.chain(apply_pipeline).chain(
180 | transform.TreeTransform.new(name='agg').aggregate(
181 | output_keys='stats',
182 | fn=rolling_stats.MeanAndVariance().as_agg_fn(),
183 | )
184 | )
185 |
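`sharded_ones` above deals batches round-robin across shards and gives any remainder to shard 0; `sharded_pipeline` then feeds those batch sizes through a `TreeTransform` chain. A small concrete sketch of the sharding behavior:

```
from ml_metrics._src.utils import test_utils

# 10 items in batches of 3 over two shards: shard 0 gets batches 0 and 2 plus
# the remainder of 1, shard 1 gets batch 1.
assert list(test_utils.sharded_ones(10, batch_size=3, num_shards=2)) == [3, 3, 1]
assert list(
    test_utils.sharded_ones(10, batch_size=3, shard_index=1, num_shards=2)
) == [3]
```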
--------------------------------------------------------------------------------
/ml_metrics/_src/utils/test_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from ml_metrics._src.utils import test_utils
15 | from absl.testing import absltest
16 |
17 |
18 | class TestUtilsTest(absltest.TestCase):
19 |
20 | def test_inf_range(self):
21 | with self.assertRaises(ValueError):
22 | _ = list(test_utils.inf_range(10))
23 |
24 | def test_nolen_iter(self):
25 | with self.assertRaises(ValueError):
26 | if test_utils.NoLenIter(range(10)):
27 | pass
28 |
29 |
30 | if __name__ == "__main__":
31 | absltest.main()
32 |
--------------------------------------------------------------------------------
/ml_metrics/aggregates.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Aggregation interfaces."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.aggregates import utils
19 | from ml_metrics._src.aggregates.base import AggFnNested
20 | from ml_metrics._src.aggregates.base import Aggregatable
21 | from ml_metrics._src.aggregates.base import AggregateFn
22 | from ml_metrics._src.aggregates.base import as_agg_fn
23 | from ml_metrics._src.aggregates.base import CallableMetric
24 | from ml_metrics._src.aggregates.base import MergeableMetric
25 | from ml_metrics._src.aggregates.base import MergeableMetricAggFn
26 | from ml_metrics._src.aggregates.base import Metric
27 | from ml_metrics._src.aggregates.base import UserAggregateFn
28 | from ml_metrics._src.aggregates.keras_metric_wrapper import is_keras_metric
29 | from ml_metrics._src.aggregates.keras_metric_wrapper import KerasAggregateFn
30 | from ml_metrics._src.aggregates.types import AverageType
31 | from ml_metrics._src.aggregates.types import InputType
32 |
--------------------------------------------------------------------------------
/ml_metrics/chainable.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Chainable interfaces."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.chainables.lazy_fns import cache_info
19 | from ml_metrics._src.chainables.lazy_fns import clear_cache
20 | from ml_metrics._src.chainables.lazy_fns import LazyFn
21 | from ml_metrics._src.chainables.lazy_fns import LazyObject
22 | from ml_metrics._src.chainables.lazy_fns import makeables
23 | from ml_metrics._src.chainables.lazy_fns import maybe_make
24 | from ml_metrics._src.chainables.lazy_fns import pickler
25 | from ml_metrics._src.chainables.lazy_fns import trace
26 | from ml_metrics._src.chainables.transform import AggregateResult
27 | from ml_metrics._src.chainables.transform import AggregateTransform
28 | from ml_metrics._src.chainables.transform import iterate_with_returned
29 | from ml_metrics._src.chainables.transform import MetricKey
30 | from ml_metrics._src.chainables.transform import RunnerMode
31 | from ml_metrics._src.chainables.transform import TreeTransform as Pipeline
32 | from ml_metrics._src.chainables.tree import Key
33 | from ml_metrics._src.chainables.tree import tree_shape
34 | from ml_metrics._src.chainables.tree import TreeMapView
35 | from ml_metrics._src.utils import test_utils
36 | from ml_metrics._src.utils.iter_utils import iterate_fn
37 |
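The public surface re-exports `TreeTransform` as `Pipeline`, so the chain built in `test_utils.sharded_pipeline` above can be written against this module. A construction-only sketch, assuming just the calls used there (the toy data source and lambda are illustrative; executing the pipeline requires a runner not shown in this listing):

```
from ml_metrics import chainable
from ml_metrics.metrics import rolling_stats

# A data source chained into an apply step with a fused aggregation, mirroring
# test_utils.sharded_pipeline.
pipeline = chainable.Pipeline.new(name='datasource').data_source(
    iter([1, 2, 3])
).chain(
    chainable.Pipeline.new(name='apply')
    .apply(fn=lambda x: x * 2)
    .aggregate(
        output_keys='stats',
        fn=rolling_stats.MeanAndVariance().as_agg_fn(),
    )
)
```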
--------------------------------------------------------------------------------
/ml_metrics/data_sinks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """All datasinks for ML Metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics.google.data_sinks.hydra import HydraPropertyWriter
19 |
--------------------------------------------------------------------------------
/ml_metrics/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/ml-metrics/510ebbc9ab70847d8aff3eca7cb4d8a3309e6319/ml_metrics/metrics/__init__.py
--------------------------------------------------------------------------------
/ml_metrics/metrics/classification.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classification metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.aggregates.classification import ConfusionMatrixAggFn
19 | from ml_metrics._src.aggregates.classification import ConfusionMatrixMetric
20 | from ml_metrics._src.aggregates.classification import SamplewiseClassification
21 | from ml_metrics._src.aggregates.classification import TopKConfusionMatrixAggFn
22 | from ml_metrics._src.metrics.classification import accuracy
23 | from ml_metrics._src.metrics.classification import balanced_accuracy
24 | from ml_metrics._src.metrics.classification import binary_accuracy
25 | from ml_metrics._src.metrics.classification import CalibrationHistogram
26 | from ml_metrics._src.metrics.classification import classification_metrics
27 | from ml_metrics._src.metrics.classification import ClassificationAggFn
28 | from ml_metrics._src.metrics.classification import diagnostic_odds_ratio
29 | from ml_metrics._src.metrics.classification import f1_score
30 | from ml_metrics._src.metrics.classification import fall_out
31 | from ml_metrics._src.metrics.classification import false_discovery_rate
32 | from ml_metrics._src.metrics.classification import false_omission_rate
33 | from ml_metrics._src.metrics.classification import fnr
34 | from ml_metrics._src.metrics.classification import fpr
35 | from ml_metrics._src.metrics.classification import informedness
36 | from ml_metrics._src.metrics.classification import intersection_over_union
37 | from ml_metrics._src.metrics.classification import markedness
38 | from ml_metrics._src.metrics.classification import matthews_correlation_coefficient
39 | from ml_metrics._src.metrics.classification import miss_rate
40 | from ml_metrics._src.metrics.classification import negative_likelihood_ratio
41 | from ml_metrics._src.metrics.classification import negative_prediction_value
42 | from ml_metrics._src.metrics.classification import nvp
43 | from ml_metrics._src.metrics.classification import positive_likelihood_ratio
44 | from ml_metrics._src.metrics.classification import positive_predictive_value
45 | from ml_metrics._src.metrics.classification import ppv
46 | from ml_metrics._src.metrics.classification import precision
47 | from ml_metrics._src.metrics.classification import prevalence
48 | from ml_metrics._src.metrics.classification import prevalence_threshold
49 | from ml_metrics._src.metrics.classification import recall
50 | from ml_metrics._src.metrics.classification import sensitivity
51 | from ml_metrics._src.metrics.classification import specificity
52 | from ml_metrics._src.metrics.classification import threat_score
53 | from ml_metrics._src.metrics.classification import tnr
54 | from ml_metrics._src.metrics.classification import tpr
55 |
--------------------------------------------------------------------------------
/ml_metrics/metrics/retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Retrieval metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.aggregates.retrieval import RetrievalMetric
19 | from ml_metrics._src.aggregates.retrieval import TopKRetrieval
20 | from ml_metrics._src.aggregates.retrieval import TopKRetrievalAggFn
21 | from ml_metrics._src.metrics.retrieval import accuracy
22 | from ml_metrics._src.metrics.retrieval import dcg_score
23 | from ml_metrics._src.metrics.retrieval import f1_score
24 | from ml_metrics._src.metrics.retrieval import false_discovery_rate
25 | from ml_metrics._src.metrics.retrieval import fowlkes_mallows_index
26 | from ml_metrics._src.metrics.retrieval import intersection_over_union
27 | from ml_metrics._src.metrics.retrieval import mean_average_precision
28 | from ml_metrics._src.metrics.retrieval import mean_reciprocal_rank
29 | from ml_metrics._src.metrics.retrieval import miss_rate
30 | from ml_metrics._src.metrics.retrieval import ndcg_score
31 | from ml_metrics._src.metrics.retrieval import positive_predictive_value
32 | from ml_metrics._src.metrics.retrieval import ppv
33 | from ml_metrics._src.metrics.retrieval import precision
34 | from ml_metrics._src.metrics.retrieval import recall
35 | from ml_metrics._src.metrics.retrieval import sensitivity
36 | from ml_metrics._src.metrics.retrieval import threat_score
37 | from ml_metrics._src.metrics.retrieval import topk_retrieval_metrics
38 | from ml_metrics._src.metrics.retrieval import tpr
39 |
--------------------------------------------------------------------------------
/ml_metrics/metrics/rolling_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Statistics metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.aggregates.rolling_stats import Counter
19 | from ml_metrics._src.aggregates.rolling_stats import FixedSizeSample
20 | from ml_metrics._src.aggregates.rolling_stats import Histogram
21 | from ml_metrics._src.aggregates.rolling_stats import Mean
22 | from ml_metrics._src.aggregates.rolling_stats import MeanAndVariance
23 | from ml_metrics._src.aggregates.rolling_stats import MinMaxAndCount
24 | from ml_metrics._src.aggregates.rolling_stats import R2Tjur
25 | from ml_metrics._src.aggregates.rolling_stats import R2TjurRelative
26 | from ml_metrics._src.aggregates.rolling_stats import RRegression
27 | from ml_metrics._src.aggregates.rolling_stats import SymmetricPredictionDifference
28 | from ml_metrics._src.aggregates.rolling_stats import UnboundedSampler
29 | from ml_metrics._src.aggregates.rolling_stats import Var
30 | from ml_metrics._src.metrics.rolling_stats import count
31 | from ml_metrics._src.metrics.rolling_stats import mean
32 | from ml_metrics._src.metrics.rolling_stats import stddev
33 | from ml_metrics._src.metrics.rolling_stats import total
34 | from ml_metrics._src.metrics.rolling_stats import var
35 |
--------------------------------------------------------------------------------
/ml_metrics/metrics/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Text metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | # Eventually move things under /aggregates/text.py to metrics/text.py, making
19 | # the classes callable and replacing the functions in metrics/text.py.
20 | from ml_metrics._src.aggregates.text import PatternFrequency
21 | from ml_metrics._src.aggregates.text import TopKWordNGrams
22 | from ml_metrics._src.metrics.text import avg_alphabetical_char_count
23 | from ml_metrics._src.metrics.text import pattern_frequency
24 | from ml_metrics._src.metrics.text import topk_word_ngrams
25 |
--------------------------------------------------------------------------------
/ml_metrics/metrics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Metrics utility."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.metrics.utils import verify_input
19 |
--------------------------------------------------------------------------------
/ml_metrics/oss/README.md:
--------------------------------------------------------------------------------
1 | # Steps to build a new py-ml-metrics pip package
2 |
3 | 1. Update the version number in `pyproject.toml`.
4 |
5 | 2. To build pypi wheel, run:
6 |
7 | ```
8 | cd
9 | sh ml_metrics/oss/runner.sh
10 | ```
11 |
12 | 3. Wheels are in `/tmp/ml_metrics/all_dist`.
13 |
14 | 4. Upload to PyPI:
15 |
16 | ```
17 | python3 -m pip install --upgrade twine
18 | python3 -m twine upload /tmp/ml_metrics/all_dist/*-any.whl
19 | ```
20 |
21 | Authenticate with Twine by following https://pypi.org/help/#apitoken and editing
22 | your `~/.pypirc`.
23 |
24 | 5. Draft the new release on GitHub: https://github.com/google/ml-metrics/releases.
25 | Tag the release commit with the version number.
26 |
27 | # Optional
28 |
29 | * Update workflows in oss/workflows when Python is upgraded.
30 |
31 | * To build for a different python version, change the PYTHON_MINOR_VERSION
32 | and/or PYTHON_MAJOR_VERSION in `ml_metrics/oss/runner.sh`.
33 |
34 | * To use a different docker image, switch it out under
35 | `ml_metrics/oss/build.Dockerfile`.
36 |
37 | * All the dependencies have to be installed manually in
38 | `ml_metrics/oss/build.Dockerfile`.
39 |
40 | * Refer to the dependencies section of `pyproject.toml` for the required dependencies.
41 |
42 | * When you do not need to update the version number, you can manually add or
43 |   increment the build number by renaming the whl files under
44 |   `/tmp/ml_metrics/all_dist` following the format of
45 |   `py_ml_metrics-{version}-[{build number}]-{py version}-none-any.whl`
46 | e.g.,:
47 |
48 | ```
49 | mv py_ml_metrics-0.0.1-py310-none-any.whl py_ml_metrics-0.0.1-1-py310-none-any.whl
50 | ```
51 |
--------------------------------------------------------------------------------
/ml_metrics/oss/build.Dockerfile:
--------------------------------------------------------------------------------
1 | # Constructs the environment within which we will build the ml-metrics pip wheels.
2 | #
3 | # From /tmp/ml_metrics,
4 | # ❯ DOCKER_BUILDKIT=1 docker build \
5 | # --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
6 | # -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile
7 | # ❯ docker run --rm -it -v /tmp/ml_metrics:/tmp/ml_metrics \
8 | # ml_metrics:${PYTHON_VERSION} bash
9 |
10 | FROM quay.io/pypa/manylinux2014_x86_64
11 | LABEL maintainer="ml-metrics team "
12 |
13 | ARG PYTHON_MAJOR_VERSION
14 | ARG PYTHON_MINOR_VERSION
15 | ARG PYTHON_VERSION
16 |
17 | ENV DEBIAN_FRONTEND=noninteractive
18 |
19 | RUN ulimit -n 1024 && yum install -y rsync
20 |
21 | ENV PATH="/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin:${PATH}"
22 |
23 | # Install dependencies needed for ml-metrics
24 | RUN --mount=type=cache,target=/root/.cache \
25 | python${PYTHON_VERSION} -m pip install -U \
26 | absl-py \
27 | build \
28 | cloudpickle \
29 |     more-itertools \
30 | numpy;
31 |
32 | # Install dependencies needed for ml-metrics tests
33 | RUN --mount=type=cache,target=/root/.cache \
34 | python${PYTHON_VERSION} -m pip install -U \
35 | auditwheel;
36 |
37 | WORKDIR "/tmp/ml_metrics"
--------------------------------------------------------------------------------
/ml_metrics/oss/build_whl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Builds the wheel for the Python version specified in $PYTHON_VERSION.
3 |
4 | set -e -x
5 |
6 | CP_VERSION="cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}"
7 | PYTHON_BIN_PATH="/opt/python/${CP_VERSION}-${CP_VERSION}/bin/python"
8 |
9 | function main() {
10 |
11 | DEST="/tmp/ml_metrics/all_dist"
12 | mkdir -p "${DEST}"
13 |
14 | echo "=== Destination directory: ${DEST}"
15 |
16 | if [ "$ML_METRICS_RUN_TESTS" = true ] ; then
17 | python3 -m unittest discover -s ml_metrics -p '*_test.py'
18 | fi
19 |
20 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
21 |
22 | echo $(date) : "=== Using tmpdir: ${TMPDIR}"
23 |
24 | echo "=== Copy ml_metrics files"
25 |
26 | cp ./setup.py "${TMPDIR}"
27 | cp ./pyproject.toml "${TMPDIR}"
28 | cp ./LICENSE "${TMPDIR}"
29 | rsync -avm -L --exclude="__pycache__/*" ./ml_metrics "${TMPDIR}"
30 | # rsync -avm -L --include="*.so" --include="*_pb2.py" \
31 | # --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \
32 | # bazel-bin/ml_metrics "${TMPDIR}"
33 |
34 | pushd ${TMPDIR}
35 | echo $(date) : "=== Building wheel"
36 |
37 | "python${PYTHON_VERSION}" setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION}
38 | cp dist/*.whl "${DEST}"
39 |
40 | echo $(date) : "=== Auditing wheel"
41 | auditwheel repair --plat ${AUDITWHEEL_PLATFORM} -w dist dist/*.whl
42 |
43 | echo $(date) : "=== Listing wheel"
44 | ls -lrt dist/*.whl
45 | cp dist/*.whl "${DEST}"
46 | popd
47 |
48 | echo $(date) : "=== Output wheel file is in: ${DEST}"
49 | }
50 |
51 | main "$@"
--------------------------------------------------------------------------------
/ml_metrics/oss/runner.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script copies ml_metrics from the internal repo, builds a Docker image,
3 | # and builds the pip wheel for the configured Python version.
4 |
5 | set -e -x
6 |
7 | export TMP_FOLDER="/tmp/ml_metrics"
8 |
9 | # Clean previous folders/images.
10 | [ -d $TMP_FOLDER ] && rm -rf $TMP_FOLDER
11 |
12 | PYTHON_MAJOR_VERSION="3"
13 | PYTHON_MINOR_VERSION="11"
14 | PYTHON_VERSION="${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}"
15 | ML_METRICS_RUN_TESTS=true
16 |
17 | AUDITWHEEL_PLATFORM="manylinux2014_x86_64"
18 |
19 |
20 | docker rmi -f ml_metrics:${PYTHON_VERSION}
21 | docker rm -f ml_metrics
22 |
23 | # Copy the (Copybara-synced) source tree into $TMP_FOLDER.
24 | cp -r . $TMP_FOLDER
25 |
26 | cd $TMP_FOLDER
27 |
28 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
29 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
30 | --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
31 | --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
32 | -t ml_metrics:${PYTHON_VERSION} - < ml_metrics/oss/build.Dockerfile
33 |
34 | docker run --rm -a stdin -a stdout -a stderr \
35 | --env PYTHON_VERSION=${PYTHON_VERSION} \
36 | --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
37 | --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
38 | --env ML_METRICS_RUN_TESTS=${ML_METRICS_RUN_TESTS} \
39 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
40 | -v /tmp/ml_metrics:/tmp/ml_metrics \
41 | --name ml_metrics ml_metrics:${PYTHON_VERSION} \
42 | bash ml_metrics/oss/build_whl.sh
43 |
44 | ls $TMP_FOLDER/all_dist/*.whl
--------------------------------------------------------------------------------
/ml_metrics/signals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/ml-metrics/510ebbc9ab70847d8aff3eca7cb4d8a3309e6319/ml_metrics/signals/__init__.py
--------------------------------------------------------------------------------
/ml_metrics/signals/classification.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classification signals."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.signals.cg_score import complexity_gap_score
19 | from ml_metrics._src.signals.cross_entropy import binary_cross_entropy
20 | from ml_metrics._src.signals.cross_entropy import categorical_cross_entropy
21 | from ml_metrics._src.signals.flip_masks import binary_flip_mask
22 | from ml_metrics._src.signals.flip_masks import neg_to_pos_flip_mask
23 | from ml_metrics._src.signals.flip_masks import pos_to_neg_flip_mask
24 | from ml_metrics._src.signals.topk_accuracy import topk_accurate
25 |
--------------------------------------------------------------------------------
/ml_metrics/signals/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Image signals."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.signals.image import content_metadata
19 |
--------------------------------------------------------------------------------
/ml_metrics/signals/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Text signals."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.signals.text import alphabetical_char_count
19 | from ml_metrics._src.signals.text import average_word_length
20 | from ml_metrics._src.signals.text import exact_match
21 | from ml_metrics._src.signals.text import is_all_whitespace
22 | from ml_metrics._src.signals.text import non_ascii_char_count
23 | from ml_metrics._src.signals.text import reference_in_sample_match
24 | from ml_metrics._src.signals.text import reference_startswith_sample_match
25 | from ml_metrics._src.signals.text import sample_in_reference_match
26 | from ml_metrics._src.signals.text import sample_startswith_reference_match
27 | from ml_metrics._src.signals.text import token_count
28 | from ml_metrics._src.signals.text import token_match_rate
29 | from ml_metrics._src.signals.text import word_count
30 |
--------------------------------------------------------------------------------
/ml_metrics/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utitilies."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=unused-import
18 | from ml_metrics._src.utils import func_utils
19 | from ml_metrics._src.utils import iter_utils
20 | from ml_metrics._src.utils import math_utils
21 | from ml_metrics._src.utils import test_utils
22 | from ml_metrics._src.utils.df_utils import metrics_to_df
23 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "py-ml-metrics"
7 | version = "0.0.3"
8 | description = "ML-metrics: A library that provides performant and distributed friendly ML metrics implementations."
9 | keywords = []
10 | authors = [
11 | {name = "ML metrics team", email = "ml-metrics-dev@google.com"},
12 | ]
13 | dependencies = [
14 | 'absl-py',
15 | 'cloudpickle',
16 | 'immutabledict',
17 | 'numpy',
18 | 'more-itertools',
19 | ]
20 | readme = "README.md"
21 | license = { file = "LICENSE" }
22 | requires-python = ">=3.11"
23 | classifiers = [
24 | "Programming Language :: Python :: 3",
25 | "License :: OSI Approved :: Apache Software License",
26 | "Operating System :: POSIX :: Linux",
27 | ]
28 |
29 | [project.optional-dependencies]
30 | signals = ['pillow']
31 | full = ['pillow']
32 |
33 | [project.urls]
34 | homepage = "https://github.com/google/ml-metrics"
35 |
36 | [tool.setuptools.packages.find]
37 | include = ["ml_metrics*"]
38 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup.py file for ml-metrics.
2 |
3 | Most project configs are in `pyproject.toml` -- prefer to modify
4 | `pyproject.toml` over this file if possible.
5 | """
6 |
7 | import setuptools
8 |
9 |
10 | setuptools.setup()
11 |
--------------------------------------------------------------------------------