``
14 | """
15 |
16 | from ._composite import *
17 |
--------------------------------------------------------------------------------
/src/pytools/expression/composite/base/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Supporting classes for composite expressions.
3 |
4 | Rarely used outside of module :mod:`pytools.expression.composite`.
5 | """
6 |
7 | from ._base import *
8 |
--------------------------------------------------------------------------------
/src/pytools/expression/composite/base/_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of :mod:`pytools.expression.composite.base`.
3 | """
4 |
5 | import logging
6 | from typing import Any
7 |
8 | from ....api import AllTracker, inheritdoc
9 | from ... import Expression
10 | from ...atomic import Epsilon, Id
11 | from ...base import SimplePrefixExpression
12 | from ...operator import BinaryOperator, UnaryOperator
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | #
18 | # Exported names
19 | #
20 |
21 | __all__ = [
22 | "KeywordArgument",
23 | "DictEntry",
24 | "LambdaDefinition",
25 | ]
26 |
27 |
28 | #
29 | # Ensure all symbols introduced below are included in __all__
30 | #
31 |
32 | __tracker = AllTracker(globals())
33 |
34 |
35 | #
36 | # Class definitions
37 | #
38 |
39 |
@inheritdoc(match="[see superclass]")
class KeywordArgument(SimplePrefixExpression):
    """
    A keyword argument, used by functions.

    The keyword name is the prefix and the value is the body of this expression,
    joined by the ``=`` separator.
    """

    _PRECEDENCE = BinaryOperator.EQ.precedence

    def __init__(self, name: str, value: Any) -> None:
        """
        :param name: the name of the keyword
        :param value: the value for the keyword
        """
        # the keyword name becomes the prefix, the value becomes the body
        super().__init__(prefix=Id(name), body=value)
        self._name = name

    @property
    def name_(self) -> str:
        """
        The name of this keyword argument.
        """
        return self._name

    @property
    def separator_(self) -> str:
        """[see superclass]"""
        return "="

    @property
    def value_(self) -> Expression:
        """
        The value of this keyword argument; identical with the expression body.
        """
        return self.body_

    @property
    def precedence_(self) -> int:
        """[see superclass]"""
        return self._PRECEDENCE
79 |
80 |
@inheritdoc(match="[see superclass]")
class DictEntry(SimplePrefixExpression):
    """
    A single dictionary entry: a key expression and a value expression, separated by
    a colon.
    """

    _PRECEDENCE = BinaryOperator.COLON.precedence

    def __init__(self, key: Any, value: Any) -> None:
        """
        :param key: the key of the dictionary entry
        :param value: the value of the dictionary entry
        """
        # the key is rendered before the separator, the value after it
        super().__init__(prefix=key, body=value)

    @property
    def precedence_(self) -> int:
        """[see superclass]"""
        return DictEntry._PRECEDENCE

    @property
    def separator_(self) -> str:
        """A colon, followed by a single space: ``": "``."""
        return ": "

    @property
    def key_(self) -> Expression:
        """
        The key of this dictionary entry, i.e., the prefix of this expression.
        """
        return self.prefix_

    @property
    def value_(self) -> Expression:
        """
        The value of this dictionary entry, i.e., the body of this expression.
        """
        return self.body_
119 |
120 |
@inheritdoc(match="[see superclass]")
class LambdaDefinition(SimplePrefixExpression):
    """
    The parameters and the body of a lambda expression, separated by a colon.

    The parameters form the prefix of this expression; the lambda body forms its body.
    """

    _PRECEDENCE = UnaryOperator.LAMBDA.precedence

    def __init__(self, *params: Id, body: Any) -> None:
        """
        :param params: the parameters of the lambda expression
        :param body: the body of the lambda expression
        """
        params_expression: Expression

        if len(params) > 1:
            # several parameters are combined into one comma-separated expression;
            # imported locally, presumably to avoid a circular import
            from .. import BinaryOperation

            params_expression = BinaryOperation(BinaryOperator.COMMA, *params)
        elif params:
            # a single parameter is used directly as the prefix
            params_expression = params[0]
        else:
            # no parameters: the prefix is an Epsilon expression
            params_expression = Epsilon()

        super().__init__(prefix=params_expression, body=body)

    @property
    def params_(self) -> Expression:
        """
        The parameters of the lambda expression, i.e., the prefix of this expression.
        """
        return self.prefix_

    @property
    def separator_(self) -> str:
        """A colon, followed by a single space: ``": "``."""
        return ": "

    @property
    def precedence_(self) -> int:
        """[see superclass]"""
        return LambdaDefinition._PRECEDENCE
163 |
--------------------------------------------------------------------------------
/src/pytools/expression/formatter/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Expression formatters for converting :class:`.Expression` objects into text
3 | representations.
4 | """
5 |
6 | from ._python import *
7 |
--------------------------------------------------------------------------------
/src/pytools/expression/operator/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Operators used in expressions.
3 | """
4 |
5 | from ._operator import *
6 |
--------------------------------------------------------------------------------
/src/pytools/expression/repr/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Standard Python collection classes, enhanced for expression representations.
3 | """
4 |
5 | from ._repr import *
6 |
--------------------------------------------------------------------------------
/src/pytools/expression/repr/_repr.py:
--------------------------------------------------------------------------------
1 | """
2 | Native classes, enhanced with mixin class :class:`.HasExpressionRepr`.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | from typing import Generic, TypeVar
9 |
10 | from typing_extensions import TypeVarTuple, Unpack
11 |
12 | from pytools.api import inheritdoc
13 | from pytools.expression import Expression, HasExpressionRepr
14 | from pytools.expression.atomic import Id
15 | from pytools.expression.composite import (
16 | DictLiteral,
17 | ListLiteral,
18 | SetLiteral,
19 | TupleLiteral,
20 | )
21 |
22 | log = logging.getLogger(__name__)
23 |
24 | __all__ = [
25 | "DictWithExpressionRepr",
26 | "ListWithExpressionRepr",
27 | "SetWithExpressionRepr",
28 | "TupleWithExpressionRepr",
29 | ]
30 |
31 | #
32 | # Type variables
33 | #
34 |
35 | T = TypeVar("T")
36 | Ts = TypeVarTuple("Ts")
37 | KT = TypeVar("KT")
38 | VT = TypeVar("VT")
39 |
40 |
41 | #
42 | # Classes
43 | #
44 |
45 |
@inheritdoc(match="[see superclass]")
class ListWithExpressionRepr(HasExpressionRepr, list[T], Generic[T]):
    """
    A :class:`list` whose string representation is rendered as a list literal
    expression.
    """

    def to_expression(self) -> Expression:
        """[see superclass]"""
        elements = list(self)
        return ListLiteral(*elements)
55 |
56 |
@inheritdoc(match="[see superclass]")
class TupleWithExpressionRepr(
    HasExpressionRepr, tuple[Unpack[Ts]], Generic[Unpack[Ts]]
):
    """
    A :class:`tuple` whose string representation is rendered as a tuple literal
    expression.
    """

    def to_expression(self) -> Expression:
        """[see superclass]"""
        items = tuple(self)
        return TupleLiteral(*items)
68 |
69 |
@inheritdoc(match="[see superclass]")
class SetWithExpressionRepr(HasExpressionRepr, set[T], Generic[T]):
    """
    A :class:`set` whose string representation is rendered as a set expression.
    """

    def to_expression(self) -> Expression:
        """[see superclass]"""
        if not self:
            # an empty set has no literal syntax in Python (``{}`` is a dict), so it
            # is rendered as a call to ``set``
            return Id(set)()
        return SetLiteral(*self)
79 |
80 |
@inheritdoc(match="[see superclass]")
class DictWithExpressionRepr(HasExpressionRepr, dict[KT, VT], Generic[KT, VT]):
    """
    A :class:`dict` whose string representation is rendered as a dictionary literal
    expression.
    """

    def to_expression(self) -> Expression:
        """[see superclass]"""
        entries = list(self.items())
        return DictLiteral(*entries)
90 |
--------------------------------------------------------------------------------
/src/pytools/fit/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Support for fitting objects to data, and testing whether an object is fitted.
3 | """
4 |
5 | from ._fit import *
6 |
--------------------------------------------------------------------------------
/src/pytools/fit/_fit.py:
--------------------------------------------------------------------------------
1 | """
2 | Core implementation of :mod:`pytools.fit`.
3 | """
4 |
5 | import functools
6 | import logging
7 | from abc import ABCMeta, abstractmethod
8 | from collections.abc import Callable
9 | from typing import Any, Generic, TypeVar, cast, overload
10 |
11 | from ..api import AllTracker
12 |
13 | log = logging.getLogger(__name__)
14 |
15 | #
16 | # Exported names
17 | #
18 |
19 | __all__ = [
20 | "NotFittedError",
21 | "FittableMixin",
22 | "fitted_only",
23 | ]
24 |
25 |
26 | #
27 | # Type variables
28 | #
29 |
30 | T_Self = TypeVar("T_Self")
31 | T_Data = TypeVar("T_Data")
32 | T_Callable = TypeVar("T_Callable", bound=Callable[..., Any])
33 |
34 |
35 | #
36 | # Ensure all symbols introduced below are included in __all__
37 | #
38 |
39 | __tracker = AllTracker(globals())
40 |
41 |
42 | #
43 | # Classes
44 | #
45 |
46 |
class NotFittedError(Exception):
    """
    Signals that an object had to be fitted for an operation, but was not fitted.

    Raised by default by methods decorated with :obj:`fitted_only`; see also
    :class:`FittableMixin`.
    """
53 |
54 |
class FittableMixin(Generic[T_Data], metaclass=ABCMeta):
    # noinspection GrazieInspection
    """
    Mix-in class that supports fitting the object to data.

    See also the :obj:`fitted_only` decorator.

    Usage:

    .. code-block:: python

        class MyFittable(FittableMixin[MyData]):
            def fit(self, data: MyData) -> "MyFittable":
                # fit object to data
                ...

                return self

            @property
            def is_fitted(self) -> bool:
                # Return True if the object is fitted, False otherwise
                ...

            @fitted_only
            def some_method(self, ...) -> ...:
                # This method may only be called if the object is fitted
                ...

    ``is_fitted`` must be implemented as a property, not as a plain method: the
    :obj:`fitted_only` decorator evaluates ``self.is_fitted`` for truthiness, and a
    bound method would always count as fitted.

    .. note::
        This class is not meant to be instantiated directly. Instead, it is
        meant to be used as a mix-in class for other classes.
    """

    @abstractmethod
    def fit(self: T_Self, __x: T_Data, **fit_params: Any) -> T_Self:
        """
        Fit this object to the given data.

        :param __x: the data to fit this object to
        :param fit_params: optional fitting parameters
        :return: self
        """
        pass

    @property
    @abstractmethod
    def is_fitted(self) -> bool:
        """
        ``True`` if this object is fitted, ``False`` otherwise.
        """
        pass
105 |
106 |
@overload
def fitted_only(__method: T_Callable) -> T_Callable:
    """[overloaded]"""


@overload
def fitted_only(
    *,
    not_fitted_error: type[Exception] = NotFittedError,
) -> Callable[[T_Callable], T_Callable]:
    """[overloaded]"""


def fitted_only(
    __method: T_Callable | None = None,
    *,
    not_fitted_error: type[Exception] = NotFittedError,
) -> T_Callable | Callable[[T_Callable], T_Callable]:
    # noinspection GrazieInspection
    """
    Decorator ensuring that a method is only invoked on a fitted object.

    The decorated method must belong to a class that inherits from
    :class:`FittableMixin`, or that otherwise implements a boolean property
    ``is_fitted``. Calling the method while ``is_fitted`` is falsy raises an
    exception instead of running the method.

    The decorator can be used with or without keyword arguments:

    .. code-block:: python

        class MyFittable(FittableMixin[MyData]):
            def __init__(self) -> None:
                self._is_fitted = False

            def fit(self, data: MyData) -> "MyFittable":
                # fit object to data
                ...
                self._is_fitted = True
                return self

            @property
            def is_fitted(self) -> bool:
                return self._is_fitted

            @fitted_only
            def my_method(self) -> None:
                # this method may only be called if the object is fitted
                ...

            @fitted_only(not_fitted_error=RuntimeError)
            def my_other_method(self) -> None:
                # raises a RuntimeError if the object is not fitted
                ...

    :param __method: the method to decorate; omit it to obtain a decorator that
        uses the given keyword arguments
    :param not_fitted_error: the type of exception to raise if the object is not fitted;
        defaults to :class:`.NotFittedError`
    :return: the decorated method, or a decorator if no method was given
    """
    if __method is None:
        # used as ``@fitted_only(...)``: return a decorator that remembers the
        # keyword arguments and is applied to the method in a second step
        return functools.partial(fitted_only, not_fitted_error=not_fitted_error)

    func: T_Callable = __method

    @functools.wraps(func)
    def _fitted_only_wrapper(
        self: FittableMixin[Any], *args: Any, **kwargs: Any
    ) -> Any:
        if self.is_fitted:
            return func(self, *args, **kwargs)
        raise not_fitted_error(f"{type(self).__name__} is not fitted")

    return cast(T_Callable, _fitted_only_wrapper)
166 |
167 |
168 | __tracker.validate()
169 |
--------------------------------------------------------------------------------
/src/pytools/http/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple HTTP client and server tools.
3 | """
4 |
5 | from ._http import *
6 |
--------------------------------------------------------------------------------
/src/pytools/http/_http.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ``fetch_url``.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import http.client
8 | import logging
9 | import urllib.parse
10 |
11 | log = logging.getLogger(__name__)
12 |
13 | __all__ = [
14 | "fetch_url",
15 | ]
16 |
17 |
def fetch_url(url: str) -> bytes:
    """
    Fetch the contents of a URL using an HTTP ``GET`` request.

    The path and the query string of the URL are sent to the server; a fragment
    (``#...``) is not part of an HTTP request and is dropped.

    :param url: the URL to fetch; the scheme must be ``http`` or ``https``
    :return: the body of the response
    :raises ValueError: if the URL scheme is neither ``http`` nor ``https``
    :raises http.client.HTTPException: if the server responds with a status code
        other than ``200 OK``
    """
    # Parse the URL; urlsplit keeps any ";params" as part of the path
    parsed_url = urllib.parse.urlsplit(url)

    # Determine the connection type based on the URL scheme
    conn: http.client.HTTPConnection
    if parsed_url.scheme == "http":
        conn = http.client.HTTPConnection(parsed_url.netloc)
    elif parsed_url.scheme == "https":
        conn = http.client.HTTPSConnection(parsed_url.netloc)
    else:
        raise ValueError("Unsupported URL scheme")

    # Build the request target: the path (defaulting to "/") plus the query string,
    # which would otherwise be silently dropped
    path = parsed_url.path or "/"
    if parsed_url.query:
        path = f"{path}?{parsed_url.query}"

    try:
        # Send a GET request and read the complete response
        conn.request("GET", path)
        response = conn.getresponse()
        data = response.read()
        status_code = response.status
    finally:
        # Always release the connection, also if the request fails
        conn.close()

    # If the response code is not OK, raise an error
    if status_code != http.client.OK:
        raise http.client.HTTPException(
            f"Request failed with status code {status_code}: {url}"
        )

    return data
60 |
--------------------------------------------------------------------------------
/src/pytools/meta/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Useful meta-classes.
3 | """
4 |
5 | from ._meta import *
6 |
--------------------------------------------------------------------------------
/src/pytools/meta/_meta.py:
--------------------------------------------------------------------------------
1 | """
2 | Core implementation of :mod:`pytools.meta`.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | from abc import ABCMeta
9 | from typing import Any
10 | from weakref import ReferenceType
11 |
12 | from ..api import AllTracker
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | #
18 | # Exported names
19 | #
20 |
21 | __all__ = [
22 | "SingletonABCMeta",
23 | "SingletonMeta",
24 | ]
25 |
26 |
27 | #
28 | # Ensure all symbols introduced below are included in __all__
29 | #
30 |
31 | __tracker = AllTracker(globals())
32 |
33 |
34 | #
35 | # Classes
36 | #
37 |
38 |
class SingletonMeta(type):
    """
    Metaclass for singleton classes.

    Every instantiation of a singleton class yields the very same object, for as long
    as that object is alive. Singleton classes cannot be instantiated with arguments.
    """

    __instance_ref: ReferenceType | None  # type: ignore

    def __init__(cls: SingletonMeta, *args: Any, **kwargs: Any) -> None:
        """
        :param args: positional arguments, forwarded to the initializer of the base
            metaclass
        :param kwargs: keyword arguments, forwarded to the initializer of the base
            metaclass
        """
        super().__init__(*args, **kwargs)
        # the newly created class has no instance yet
        cls.__instance_ref = None

    def __call__(cls: SingletonMeta, *args: Any, **kwargs: Any) -> Any:
        """
        Get the singleton instance, creating it first if it does not exist (anymore).

        Only a weak reference to the instance is kept, so that the instance can be
        garbage collected once it is no longer in use.

        Singletons must be instantiated without any parameters.

        :return: the singleton instance
        :raises ValueError: if called with parameters
        """
        if args or kwargs:
            raise ValueError("singleton classes may not take any arguments")

        existing_ref = cls.__instance_ref
        existing: Any | None = existing_ref() if existing_ref else None
        if existing is not None:
            return existing

        # no live instance: create one and remember it weakly
        new_instance: Any = super().__call__()
        cls.__instance_ref = ReferenceType(new_instance)
        return new_instance


class SingletonABCMeta(SingletonMeta, ABCMeta):
    """
    Metaclass for abstract singleton classes, combining :class:`.SingletonMeta` with
    :class:`~abc.ABCMeta`.
    """
87 |
88 |
89 | __tracker.validate()
90 |
--------------------------------------------------------------------------------
/src/pytools/parallelization/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Parallelization support based on the :mod:`joblib` package.
3 | """
4 |
5 | from ._parallelization import *
6 |
--------------------------------------------------------------------------------
/src/pytools/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BCG-X-Official/pytools/9d6d37280b72724bd64f69fe7c98d687cbfa5317/src/pytools/py.typed
--------------------------------------------------------------------------------
/src/pytools/repr/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Alternative representations of objects.
3 |
4 | This module provides a set of utilities for creating alternative representations of
5 | objects. This can be useful for debugging, logging, and serialization.
6 | """
7 |
8 | from ._dict import *
9 |
--------------------------------------------------------------------------------
/src/pytools/repr/_dict.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ``HasDictRepr``.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | from abc import ABCMeta
9 | from collections.abc import Iterable, Mapping
10 | from importlib import import_module
11 | from typing import Any, TypeVar, final
12 |
13 | from ..api import get_init_params, inheritdoc
14 | from ..expression import Expression, HasExpressionRepr
15 | from ..expression.atomic import Id
16 |
17 | log = logging.getLogger(__name__)
18 |
19 | __all__ = [
20 | "HasDictRepr",
21 | ]
22 |
23 | #
24 | # Type variables
25 | #
26 |
27 |
28 | T_Class = TypeVar("T_Class", bound="HasDictRepr")
29 |
30 |
31 | #
32 | # Constants
33 | #
34 |
35 | KEY_CLASS = "cls"
36 | KEY_PARAMS = "params"
37 | KEY_DICT = "dict"
38 |
39 |
@inheritdoc(match="""[see superclass]""")
class HasDictRepr(HasExpressionRepr, metaclass=ABCMeta):
    """
    A class that can be represented as a JSON-serializable dictionary.

    Method :meth:`to_dict` creates the dictionary representation of an object, and
    class method :meth:`from_dict` re-creates an object from its dictionary
    representation.
    """

    @final
    def to_dict(self) -> dict[str, Any]:
        """
        Convert this object to a dictionary that can be serialized to JSON.

        Calls private method :meth:`._get_params` to get the parameters that were
        used to initialize this object. By default, this uses class introspection to
        determine the class initializer parameters and access attributes of the same
        name. Subclasses can override :meth:`._get_params` to provide a custom
        implementation.

        Creates a dictionary with keys ``{KEY_CLASS}`` and ``{KEY_PARAMS}``,
        with the fully qualified class name and a dictionary of parameters,
        respectively.

        Parameter values are recursively converted to JSON-serializable forms using the
        following rules:

        - If the value is ``None``, a string, bytes, int, float, or bool, leave it as is
        - If the value is an instance of :class:`HasDictRepr`, call its
          :meth:`to_dict` method and use the result
        - If the value is a mapping, recursively convert its keys and values,
          collate the result into a dictionary and wrap it in a dictionary with a
          single key ``{KEY_DICT}`` to distinguish it from dictionary representations
          of classes
        - If the value is an iterable, recursively convert its elements and collate the
          result into a list
        - Otherwise, raise a :class:`ValueError`

        :return: the dictionary representation of the object
        :raises ValueError: if a parameter value cannot be converted to a
            JSON-serializable form
        """

        return {
            # fully qualified class name, used by from_dict to look up the class
            KEY_CLASS: f"{self.__module__}.{type(self).__qualname__}",
            KEY_PARAMS: {
                # recursively convert all parameter values to JSON-serializable forms
                name: _to_json_like(value)
                for name, value in self._get_params().items()
            },
        }

    # substitute the key constants into the docstring; skip this if docstrings have
    # been stripped (python -OO), which would otherwise set the docstring to "None"
    if to_dict.__doc__:
        to_dict.__doc__ = to_dict.__doc__.format(
            KEY_CLASS=KEY_CLASS, KEY_PARAMS=KEY_PARAMS, KEY_DICT=KEY_DICT
        )

    @classmethod
    @final
    def from_dict(
        cls: type[T_Class], data: Mapping[str, Any], **kwargs: Any
    ) -> T_Class:
        """
        Create a new instance of this class from a dictionary.

        This method is the inverse of :meth:`to_dict`. It creates an instance of the
        class from a dictionary representation.

        .. note::
            The module named in the dictionary representation is imported to look up
            the class, which executes that module's code. Only pass dictionary
            representations from trusted sources.

        :param data: the dictionary representation of the object
        :param kwargs: additional keyword arguments for pre-processing the parameters;
            passed on to :meth:`._params_from_dict`
        :return: the new object
        :raises ValueError: if the dictionary representation does not consist of
            exactly the expected keys
        :raises TypeError: if the class of the object in the dictionary representation
            is not a subclass of this class
        """

        # the dict should only have the keys we expect
        unexpected_keys: set[str] = data.keys() - [KEY_CLASS, KEY_PARAMS]
        if unexpected_keys:
            raise ValueError(
                f"Unexpected keys in object representation: {unexpected_keys}"
            )

        # get the class name and parameters from the dictionary
        try:
            class_full_name: str = data[KEY_CLASS]
            params: dict[str, Any] = data[KEY_PARAMS]
        except KeyError as e:
            raise ValueError(
                f"Expected keys {KEY_CLASS!r} and {KEY_PARAMS!r} in object "
                f"representation but got: {data!r}"
            ) from e

        # get the class from the name
        # NOTE: this imports a module named in the (possibly external) data
        module_name, _, class_name = class_full_name.rpartition(".")
        cls_target: Any = getattr(import_module(module_name), class_name)

        # check for a class first: issubclass() would raise an unhelpful TypeError
        # if the name refers to a non-class object
        if not (isinstance(cls_target, type) and issubclass(cls_target, cls)):
            raise TypeError(f"Expected a subclass of {cls}, but got {cls_target}")

        # create an instance of the class with the parameters
        return cls_target(**cls_target._params_from_dict(params, **kwargs))

    def _get_params(self) -> dict[str, Any]:
        """
        Get the parameters that were used to initialize this object.

        :return: a dictionary mapping initializer parameter names to their values
        """
        return get_init_params(self)

    @classmethod
    def _params_from_dict(
        cls, params: Mapping[str, Any], **kwargs: Any
    ) -> dict[str, Any]:
        """
        Process the parameters from a dict representation prior to creating a new
        object.

        This method is called by :meth:`from_dict`. By default, this method
        recursively de-serializes dictionary representations into instances of their
        original classes, and returns built-in data structures and data types as-is.

        Subclasses can override this method to provide a custom implementation.

        :param params: the parameters to process
        :param kwargs: additional keyword arguments passed on from :meth:`from_dict`
        :return: the processed parameters
        """
        return {
            name: _from_json_like(value, **kwargs) for name, value in params.items()
        }

    def to_expression(self) -> Expression:
        """[see superclass]"""
        return Id(type(self))(**self._get_params())
167 |
168 |
def _to_json_like(obj: Any) -> Any:
    """
    Convert an object to a JSON-serializable form.

    Instances of :class:`HasDictRepr` are converted via their ``to_dict`` method even
    if they are also mappings or iterables, so that their class can be restored
    when de-serializing.

    :param obj: the object to convert
    :return: the JSON-serializable form of the object
    :raises ValueError: if the object cannot be converted
    """
    if obj is None or isinstance(obj, (str, bytes, int, float, bool)):
        return obj
    elif isinstance(obj, HasDictRepr):
        # must be checked before Mapping/Iterable: a HasDictRepr that is also a
        # collection would otherwise lose its class information
        return obj.to_dict()
    elif isinstance(obj, Mapping):
        return {KEY_DICT: {_to_json_like(k): _to_json_like(v) for k, v in obj.items()}}
    elif isinstance(obj, Iterable):
        return [_to_json_like(element) for element in obj]
    else:
        raise ValueError(f"Object does not implement {HasDictRepr.__name__}: {obj!r}")
181 |
182 |
def _from_json_like(obj: Any, **kwargs: Any) -> Any:
    """
    Convert a JSON-serializable object representation back to the original object.

    :param obj: the JSON-serializable representation
    :param kwargs: additional keyword arguments for pre-processing the parameters of
        de-serialized objects; passed on to all nested levels, including elements of
        dictionaries and lists
    :return: the de-serialized object
    :raises ValueError: if the representation is invalid
    """
    if obj is None or isinstance(obj, (str, bytes, int, float, bool)):
        return obj
    elif isinstance(obj, Mapping):
        if len(obj) == 1:
            # a single key marks a wrapped dictionary, see _to_json_like
            try:
                dict_ = obj[KEY_DICT]
            except KeyError as e:
                raise ValueError(
                    f"Expected key {KEY_DICT!r} in object representation but got: "
                    f"{obj!r}"
                ) from e
            return {
                _from_json_like(k, **kwargs): _from_json_like(v, **kwargs)
                for k, v in dict_.items()
            }
        elif len(obj) == 2:
            # two keys mark the dictionary representation of a HasDictRepr object
            return HasDictRepr.from_dict(obj, **kwargs)
        else:
            raise ValueError(f"Invalid object representation: {obj}")
    elif isinstance(obj, Iterable):
        return [_from_json_like(element, **kwargs) for element in obj]
    else:
        raise ValueError(f"Invalid object representation: {obj}")
206 |
--------------------------------------------------------------------------------
/src/pytools/sphinx/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Supporting tools and autodoc enhancements for generating Sphinx documentation.
3 | """
4 |
5 | from ._sphinx import *
6 |
--------------------------------------------------------------------------------
/src/pytools/sphinx/util/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Supporting tools and autodoc enhancements for generating Sphinx documentation.
3 | """
4 |
5 | from ._util import *
6 |
--------------------------------------------------------------------------------
/src/pytools/text/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for rendering and manipulating text.
3 | """
4 |
5 | from ._strings import *
6 | from ._template import *
7 | from ._text import *
8 |
--------------------------------------------------------------------------------
/src/pytools/text/_strings.py:
--------------------------------------------------------------------------------
1 | """
2 | String manipulation functions.
3 | """
4 |
5 | import logging
6 | import re
7 |
8 | from pytools.api import AllTracker
9 |
10 | log = logging.getLogger(__name__)
11 |
12 |
13 | #
14 | # Exported names
15 | #
16 |
17 | __all__ = ["camel_case_to_snake_case"]
18 |
19 |
20 | #
21 | # Constants
22 | #
23 |
# Regular expression splitting a camel-case string into chunks; with a single capture
# group, re.findall returns one string per chunk
RE_CAMEL_CASE_CHUNKS = re.compile(
    r"("
    r"(?:"
    # a run of upper-case letters not followed by a lower-case letter,
    # e.g. "HTTP" in "HTTPServer"
    r"[A-Z]+(?=[^a-z]|\b)"
    # or: a letter or digit followed by lower-case letters or digits,
    # e.g. "Server" in "HTTPServer"
    r"|[A-Za-z0-9][a-z0-9]*"
    # or: a run of non-word characters
    r"|\W+"
    r")"
    # followed by every underscore that is itself followed by another underscore,
    # i.e., all but the last underscore of a run of underscores
    r"(?:_(?=_))*"
    r")",
    re.ASCII,
)
35 |
36 |
37 | #
38 | # Ensure all symbols introduced below are included in __all__
39 | #
40 |
41 | __tracker = AllTracker(globals())
42 |
43 |
44 | #
45 | # Functions
46 | #
47 |
48 |
def camel_case_to_snake_case(camel: str) -> str:
    """
    Convert a string from ``CamelCase`` to ``snake_case``.

    The string is split into chunks using ``RE_CAMEL_CASE_CHUNKS``; the chunks are
    joined with underscores, and the result is converted to lower case.

    :param camel: a string in camel case
    :return: the string converted to snake case
    """
    chunks = RE_CAMEL_CASE_CHUNKS.findall(camel)
    snake = "_".join(chunks)
    return snake.lower()
57 |
58 |
59 | __tracker.validate()
60 |
--------------------------------------------------------------------------------
/src/pytools/text/_template.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of TextTemplate.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | import string
9 | from collections.abc import Iterable, Sized
10 | from typing import Any
11 |
12 | from pytools.api import as_set, inheritdoc
13 | from pytools.expression import (
14 | Expression,
15 | HasExpressionRepr,
16 | expression_from_init_params,
17 | )
18 |
19 | log = logging.getLogger(__name__)
20 |
21 | __all__ = [
22 | "TextTemplate",
23 | ]
24 |
25 |
26 | #
27 | # Class declarations
28 | #
29 |
30 |
31 | @inheritdoc(match="""[see superclass]""")
32 | class TextTemplate(HasExpressionRepr):
33 | """
34 | A template for generating text by substituting format keys in a format string with
35 | actual values.
36 |
37 | The format string must contain the required formatting keys, and no unexpected keys.
38 |
39 | Method :meth:`format_with_attributes` formats the text by substituting the format
40 | keys with the given attributes passed as keyword arguments.
41 |
42 | If the template is `strict`, an error is raised if not all attributes have a
43 | corresponding key in the format string. If the template is not strict, attributes
44 | must be provided to substitute all keys in the format string, but attributes not
45 | present in the formatting keys will be ignored.
46 | """
47 |
48 | #: A format string with formatting keys that will be substituted with values.
49 | format_string: str
50 |
51 | #: The formatting keys used in the format string.
52 | formatting_keys: set[str]
53 |
54 | #: If ``False``, an error is raised if the format string contains keys other than
55 | # the required keys; if ``True``, additional keys are allowed.
56 | allow_additional_keys: bool
57 |
58 | #: If ``False```, the template is strict and an error is raised if not all
59 | #: attributes have a corresponding key in the format string; if ``True``, attributes
60 | #: not present in the formatting keys will be ignored.
61 | ignore_unmatched_attributes: bool
62 |
63 | def __init__(
64 | self,
65 | *,
66 | format_string: str,
67 | required_keys: Iterable[str],
68 | allow_additional_keys: bool = False,
69 | ignore_unmatched_attributes: bool = False,
70 | ) -> None:
71 | """
72 | :param format_string: a format string with formatting keys that will be
73 | substituted with values
74 | :param required_keys: the names of the formatting keys required in the format
75 | string
76 | :param allow_additional_keys: if ``False``, an error is raised if the format
77 | string contains keys other than the required keys; if ``True``, additional
78 | keys are allowed (default: ``False``)
79 | :param ignore_unmatched_attributes: if ``False``, the template is strict and an
80 | error is raised if not all attributes have a corresponding key in the format
81 | string; if ``True``, attributes not present in the format string will be
82 | ignored (default: ``False``)
83 | """
84 | super().__init__()
85 | required_keys = as_set(
86 | required_keys,
87 | element_type=str,
88 | arg_name="formatting_keys",
89 | )
90 | self.formatting_keys = _validate_format_string(
91 | format_string,
92 | required_keys=required_keys,
93 | allow_additional_keys=allow_additional_keys,
94 | )
95 |
96 | self.format_string = format_string
97 | self.allow_additional_keys = allow_additional_keys
98 | self.ignore_unmatched_attributes = ignore_unmatched_attributes
99 |
def format_with_attributes(self, **attributes: Any) -> str:
    """
    Format the text by substituting the given keyword arguments into the format
    string.

    In strict mode (``ignore_unmatched_attributes`` is ``False``), the names of the
    attributes must match the formatting keys exactly. Otherwise, attributes without
    a matching formatting key are ignored, but every formatting key must still be
    given a value.

    :param attributes: the values to substitute, passed as keyword arguments
    :return: the formatted text
    :raises ValueError: if the attributes do not match the formatting keys as
        described above
    """
    given_keys = set(attributes)

    if self.ignore_unmatched_attributes:
        # lenient mode: extra attributes are fine, missing ones are not
        missing_attributes = self.formatting_keys - given_keys
        if missing_attributes:
            raise ValueError(
                f"No values provided for formatting key"
                f"{_plural(missing_attributes)}: "
                + ", ".join(map(repr, sorted(missing_attributes)))
            )
    elif given_keys != self.formatting_keys:
        # strict mode: attribute names and formatting keys must be identical
        raise ValueError(
            f"Provided attributes must have the same keys as formatting keys "
            f"{self.formatting_keys!r}, but got {attributes!r}"
        )

    return self.format_string.format(**attributes)
130 |
def to_expression(self) -> Expression:
    """[see superclass]"""
    # delegate to the shared helper that builds an expression from this object's
    # initialization parameters
    expression: Expression = expression_from_init_params(self)
    return expression
134 |
135 |
def _validate_format_string(
    format_string: str, *, required_keys: Iterable[str], allow_additional_keys: bool
) -> set[str]:
    """
    Validate that the given format string contains the required keys, and that it does
    not contain any unexpected keys.

    For replacement fields with attribute or index access, such as ``{name.attr}`` or
    ``{name[0]}``, the formatting key is the part before the first ``.`` or ``[``
    (here: ``name``), as this is the keyword argument that :meth:`str.format` looks
    up.

    :param format_string: the format string to validate
    :param required_keys: the required keys
    :param allow_additional_keys: if ``False``, an error is raised if the format string
        contains keys other than the required keys; if ``True``, additional keys are
        allowed
    :return: all formatting keys in the format string
    :raises TypeError: if the given format string is not a string
    :raises ValueError: if the format string is missing any of the required keys, or
        contains keys other than the required keys while ``allow_additional_keys`` is
        ``False``
    """

    if not isinstance(format_string, str):
        raise TypeError(f"Format string must be a string, but got: {format_string!r}")

    # ensure arg required_keys is a set
    required_keys = as_set(required_keys, element_type=str, arg_name="required_keys")

    # get all keys from the format string
    actual_keys = {
        # str.format resolves "{name.attr}" and "{name[index]}" through the keyword
        # argument "name", so strip any attribute or index access from the field
        field_name.partition(".")[0].partition("[")[0]
        for _, field_name, _, _ in string.Formatter().parse(format_string)
        if field_name is not None
    }

    # check that the format string contains the required keys
    missing_keys = required_keys - actual_keys
    if missing_keys:
        raise ValueError(
            f"Format string is missing required key{_plural(missing_keys)}: "
            + ", ".join(map(repr, sorted(missing_keys)))
        )

    if not allow_additional_keys:
        # check that the format string does not contain unexpected keys
        unexpected_keys = actual_keys - required_keys
        if unexpected_keys:
            raise ValueError(
                f"Format string contains unexpected key{_plural(unexpected_keys)}: "
                + ", ".join(map(repr, sorted(unexpected_keys)))
            )

    return actual_keys
183 |
184 |
185 | def _plural(items: Sized) -> str:
186 | return "" if len(items) == 1 else "s"
187 |
--------------------------------------------------------------------------------
/src/pytools/text/_text.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for rendering text.
3 | """
4 |
5 | import logging
6 | from collections.abc import Iterable, Iterator, Sequence
7 | from typing import Any, Union
8 |
9 | import numpy.typing as npt
10 | import pandas as pd
11 |
12 | from pytools.api import AllTracker
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | #
18 | # Exported names
19 | #
20 |
21 | __all__ = ["CharacterMatrix", "format_table"]
22 |
23 |
24 | #
25 | # Ensure all symbols introduced below are included in __all__
26 | #
27 |
28 | __tracker = AllTracker(globals())
29 |
30 |
31 | #
32 | # Type definitions
33 | #
34 |
35 | _TextCoordinates = tuple[Union[int, slice], Union[int, slice]]
36 |
37 |
38 | #
39 | # Classes
40 | #
41 |
42 |
43 | class CharacterMatrix:
44 | """
45 | A matrix of characters, indexed by rows and columns.
46 |
47 | The matrix is initialised with space characters (`" "`).
48 |
49 | Characters can be "painted" in the matrix using 2D index assignments:
50 |
51 | - ``matrix[r, c] = chr`` assigns character `chr` at position `(r, c)`
52 | - ``matrix[r, c1:c2] = str`` writes string `str` at positions
53 | `(r, c1) … (r, c2 – 1)`; excess characters in `str` are clipped if ``len(s)`` is
54 | greater than `c2 – c1`
55 | - ``matrix[r, c1:c2] = chr`` (where `chr` is a single character) repeats `chr` at
56 | every position `(r, c1) … (r, c2 – 1)`
57 | - ``matrix[r1:r2, …] = …`` applies the same insertion at each of rows `r1 … r2 – 1`
58 | - full slice notation is supported so even slices of shape ``start:stop:step`` work
59 | as expected
60 | """
61 |
62 | def __init__(self, n_rows: int, n_columns: int) -> None:
63 | """
64 | :param n_rows: the matrix height
65 | :param n_columns: the matrix width
66 | """
67 | if n_columns <= 0:
68 | raise ValueError(f"arg width must be positive but is {n_columns}")
69 | if n_rows <= 0:
70 | raise ValueError(f"arg height must be positive but is {n_rows}")
71 | self._n_columns = n_columns
72 | self._matrix = [[" " for _ in range(n_columns)] for _ in range(n_rows)]
73 |
74 | @property
75 | def n_rows(self) -> int:
76 | """
77 | The height of this matrix.
78 |
79 | Same as ``len(self)``.
80 | """
81 | return len(self._matrix)
82 |
83 | @property
84 | def n_columns(self) -> int:
85 | """
86 | The width of this matrix.
87 | """
88 | return self._n_columns
89 |
90 | def lines(self, subset: Iterable[int] | None = None) -> Iterable[str]:
91 | """
92 | Get this character matrix as strings representing the matrix rows, stripping
93 | trailing whitespace.
94 |
95 | :param subset: indices of rows to return
96 | :return: the matrix rows as strings
97 | """
98 | matrix = self._matrix
99 | return (
100 | "".join(matrix[line]).rstrip()
101 | for line in (range(len(matrix)) if subset is None else subset)
102 | )
103 |
104 | @staticmethod
105 | def __key_as_slices(key: _TextCoordinates) -> tuple[slice, slice]:
106 | def _to_slice(index: int | slice) -> slice:
107 | if isinstance(index, int):
108 | return slice(index, index + 1)
109 | else:
110 | return index
111 |
112 | if not isinstance(key, tuple) or len(key) != 2:
113 | raise ValueError(f"expected (row, column) tuple but got {key}")
114 |
115 | row, column = key
116 | return _to_slice(row), _to_slice(column)
117 |
118 | def __str__(self) -> str:
119 | return "\n".join(self.lines())
120 |
121 | def __len__(self) -> int:
122 | return self.n_rows
123 |
124 | def __getitem__(self, key: _TextCoordinates) -> str:
125 | rows, columns = self.__key_as_slices(key)
126 | return "\n".join("".join(line[columns]) for line in self._matrix[rows])
127 |
128 | def __setitem__(self, key: _TextCoordinates, value: Any) -> None:
129 | rows, columns = self.__key_as_slices(key)
130 | value = str(value)
131 | single_char = len(value) == 1
132 | positions = range(*columns.indices(self.n_columns))
133 | for line in self._matrix[rows]:
134 | if single_char:
135 | for pos in positions:
136 | line[pos] = value
137 | else:
138 | for pos, char in zip(positions, value):
139 | line[pos] = char
140 |
141 |
142 | _ALIGNMENT_OPTIONS = ["<", "^", ">"]
143 |
144 |
145 | def format_table(
146 | headings: Sequence[str],
147 | data: pd.DataFrame | npt.NDArray[Any] | Sequence[Sequence[Any]],
148 | formats: Sequence[str | None] | None = None,
149 | alignment: Sequence[str | None] | None = None,
150 | ) -> str:
151 | """
152 | Print a formatted text table.
153 |
154 | :param headings: the table headings
155 | :param data: the table data, as an array-like with shape `[n_rows, n_columns]`
156 | :param formats: formatting strings for data in each row (optional);
157 | uses ``str()`` conversion for any formatting strings stated as ``None``
158 | :param alignment: text alignment for each column (optional); use ``"<"`` to align
159 | left, ``"="`` to center, ``">"`` to align right (defaults to left alignment)
160 | :return: the formatted table as a multi-line string
161 | """
162 | n_columns = len(headings)
163 |
164 | formats_seq: Sequence[str | None]
165 |
166 | if formats is None:
167 | formats_seq = [None] * n_columns
168 | elif len(formats) != n_columns:
169 | raise ValueError("arg formats must have the same length as arg headings")
170 | else:
171 | formats_seq = formats
172 |
173 | alignment_seq: Sequence[str | None]
174 |
175 | if alignment is None:
176 | alignment_seq = ["<"] * n_columns
177 | elif len(alignment) != n_columns:
178 | raise ValueError("arg alignment must have the same length as arg headings")
179 | elif not all(align in _ALIGNMENT_OPTIONS for align in alignment):
180 | raise ValueError(
181 | f"arg alignment must only contain alignment options "
182 | f'{", ".join(_ALIGNMENT_OPTIONS)}'
183 | )
184 | else:
185 | alignment_seq = alignment
186 |
187 | def _formatted(item: Any, format_string: str | None) -> str:
188 | if format_string is None:
189 | return str(item)
190 | else:
191 | return f"{item:{format_string}}"
192 |
193 | def _iterate_row_data() -> Iterable[Sequence[Any]] | Iterable[pd.Series]:
194 | if isinstance(data, pd.DataFrame):
195 | return (row for _, row in data.iterrows())
196 | else:
197 | return iter(data)
198 |
199 | def _make_row(items: Sequence[Any] | pd.Series) -> list[str]:
200 | if len(items) != n_columns:
201 | raise ValueError(
202 | "rows in data matrix must have the same length as arg headings"
203 | )
204 | return [
205 | _formatted(item, format_string)
206 | for item, format_string in zip(items, formats_seq)
207 | ]
208 |
209 | body_rows: list[list[str]] = [_make_row(items) for items in _iterate_row_data()]
210 |
211 | column_widths: list[int] = [
212 | max(column_lengths)
213 | for column_lengths in zip(
214 | *(
215 | (len(item) for item in row)
216 | for row in (
217 | headings,
218 | *[_make_row(items) for items in _iterate_row_data()],
219 | )
220 | )
221 | )
222 | ]
223 |
224 | dividers = ["=" * column_width for column_width in column_widths]
225 |
226 | def _format_rows(rows: Sequence[Sequence[str]], align: bool) -> Iterator[str]:
227 | return (
228 | " ".join(
229 | (
230 | f'{item:{align_char if align else ""}{column_width}s}'
231 | for item, align_char, column_width in zip(
232 | row, alignment_seq, column_widths
233 | )
234 | )
235 | )
236 | for row in rows
237 | )
238 |
239 | return "\n".join(
240 | (
241 | *(_format_rows(rows=[headings, dividers], align=False)),
242 | *(_format_rows(rows=body_rows, align=True)),
243 | "",
244 | )
245 | )
246 |
247 |
248 | __tracker.validate()
249 |
--------------------------------------------------------------------------------
/src/pytools/typing/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for run-time type inference.
3 | """
4 |
5 | from ._typing import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A lean MVC framework for rendering basic visualizations in different styles, e.g.,
3 | as `matplotlib` charts, HTML, or as plain text.
4 | """
5 |
6 | from ._html import *
7 | from ._matplot import *
8 | from ._notebook import *
9 | from ._text import *
10 | from ._viz import *
11 |
--------------------------------------------------------------------------------
/src/pytools/viz/_html.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ``HTMLStyle``.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | import sys
9 | from abc import ABCMeta
10 | from io import StringIO
11 | from typing import Any, Generic, TextIO, TypeVar, cast
12 |
13 | from ..api import appenddoc, inheritdoc
14 | from ._notebook import is_running_in_notebook
15 | from ._viz import ColoredStyle
16 | from .color import ColorScheme, RgbColor
17 |
18 | log = logging.getLogger(__name__)
19 |
20 | __all__ = [
21 | "HTMLStyle",
22 | ]
23 | #
24 | # Type variables
25 | #
26 |
27 | T_ColorScheme = TypeVar("T_ColorScheme", bound=ColorScheme)
28 |
29 |
30 | #
31 | # Classes
32 | #
33 |
34 |
@inheritdoc(match="[see superclass]")
class HTMLStyle(ColoredStyle[T_ColorScheme], Generic[T_ColorScheme], metaclass=ABCMeta):
    """
    Abstract base class for styles for rendering output as HTML.

    Supports color schemes, and is able to display output in a notebook (if running in
    one), ``stdout``, or a given output stream.

    Each drawing is written as a section that is opened in :meth:`start_drawing` and
    closed in :meth:`finalize_drawing`; when displaying in a notebook, the output is
    buffered and only shown once the drawing is finalized.
    """

    #: The output stream this style instance writes to; if no stream was given and
    #: code is running in a Jupyter notebook, this is an in-memory buffer whose
    #: contents are displayed in the notebook when the drawing is finalized
    out: TextIO | None

    #: Whether the buffered output should be displayed in a Jupyter notebook when the
    #: drawing is finalized
    _send_to_notebook: bool = False

    @appenddoc(to=ColoredStyle.__init__)
    def __init__(
        self, *, colors: T_ColorScheme | None = None, out: TextIO | None = None
    ) -> None:
        """
        :param out: the output stream this style instance writes to; if ``None`` and
            running in a Jupyter notebook, output will be displayed in the notebook,
            otherwise it will be written to ``stdout``
        """
        super().__init__(colors=colors)

        if out is None:  # pragma: no cover
            if is_running_in_notebook():
                # buffer the HTML in memory; finalize_drawing() sends the buffered
                # contents to the notebook in one piece
                self.out = StringIO()
                self._send_to_notebook = True
            else:
                self.out = sys.stdout
                self._send_to_notebook = False
        else:
            self.out = out
            self._send_to_notebook = False

    @classmethod
    def get_default_style_name(cls) -> str:
        """[see superclass]"""
        return "html"

    @staticmethod
    def rgb_to_css(rgb: RgbColor) -> str:
        """
        Convert an RGB color to its CSS representation in the form ``rgb(r,g,b)``,
        where ``r``, ``g``, and ``b`` are integers in the range 0-255.

        :param rgb: the RGB color
        :return: the CSS representation of the color
        """
        # NOTE(review): int() truncates (e.g. 0.5 -> 127), whereas RgbColor.hex
        # rounds (0.5 -> 128), so the CSS and hex forms of the same color can differ
        # by one step per channel; confirm whether this is intended
        rgb_0_to_255 = ",".join(str(int(luminance * 255)) for luminance in rgb)
        return f"rgb({rgb_0_to_255})"

    def start_drawing(self, *, title: str, **kwargs: Any) -> None:
        """[see superclass]"""
        super().start_drawing(title=title, **kwargs)

        # we start a section, setting the colors
        # NOTE(review): both adjacent string literals below are empty, so as written
        # this prints only a line break; the markup that opens the section and sets
        # the colors appears to be missing from this copy of the code; restore it
        # from version control
        print(
            ''
            f'',
            file=self.out,
        )

        # print the title
        print(self.render_title(title=title), file=self.out)

    def finalize_drawing(self, **kwargs: Any) -> None:
        """[see superclass]"""
        # NOTE(review): unlike start_drawing(), the keyword arguments are not
        # forwarded to the superclass here; confirm this is intended
        super().finalize_drawing()
        # we close the section
        # NOTE(review): the string literal below spans a line break, which is a
        # syntax error; the closing markup appears to be missing from this copy
        print("
", file=self.out)

        # if we are in a notebook, display the HTML
        if self._send_to_notebook:
            from IPython.display import HTML, display

            # in notebook mode, self.out is the StringIO buffer created in __init__
            display(HTML(cast(StringIO, self.out).getvalue()))

    # noinspection PyMethodMayBeStatic
    def render_title(self, title: str) -> str:
        """
        Render the title of the drawing as HTML.

        :param title: the title of the drawing
        :return: the HTML code of the title
        """
        # NOTE(review): this f-string spans a line break and has no markup around the
        # title; the HTML tags appear to be missing from this copy of the code
        return f"{title}
"
129 |
--------------------------------------------------------------------------------
/src/pytools/viz/_notebook.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ``is_running_in_notebook``.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 |
9 | log = logging.getLogger(__name__)
10 |
11 | __all__ = [
12 | "is_running_in_notebook",
13 | ]
14 |
15 |
def is_running_in_notebook() -> bool:
    """
    Check if the code is running in a notebook like Jupyter or Colab.

    Useful to determine whether to display plots inline or not.

    :return: ``True`` if running in a Jupyter notebook (an IPython kernel using a
        ``ZMQInteractiveShell``) or in Google Colab; ``False`` otherwise
    """
    # get_ipython() is only defined when running under IPython
    try:
        # noinspection PyUnresolvedReferences
        shell_name: str | None = (
            get_ipython().__class__.__name__  # type: ignore[name-defined]
        )
    except NameError:
        shell_name = None

    if shell_name == "ZMQInteractiveShell":
        # Jupyter notebook or JupyterLab kernel
        return True

    # Google Colab runs IPython with its own shell class rather than
    # ZMQInteractiveShell, so check for the Colab module whenever the shell is not
    # Jupyter's (previously this check was only reached without IPython at all)
    try:
        # noinspection PyUnresolvedReferences
        import google.colab  # noqa: F401
    except ImportError:
        # we're not in a notebook
        return False

    return True
42 |
--------------------------------------------------------------------------------
/src/pytools/viz/_text.py:
--------------------------------------------------------------------------------
1 | """
2 | Text styles for the visualization library.
3 | """
4 |
5 | import logging
6 | import sys
7 | from abc import ABCMeta
8 | from typing import Any, TextIO
9 |
10 | from ..api import AllTracker, inheritdoc
11 | from ._viz import DrawingStyle
12 |
13 | log = logging.getLogger(__name__)
14 |
15 |
16 | #
17 | # Exported names
18 | #
19 |
20 | __all__ = ["TextStyle"]
21 |
22 |
23 | #
24 | # Ensure all symbols introduced below are included in __all__
25 | #
26 |
27 | __tracker = AllTracker(globals())
28 |
29 |
30 | #
31 | # Classes
32 | #
33 |
34 |
@inheritdoc(match="[see superclass]")
class TextStyle(DrawingStyle, metaclass=ABCMeta):
    """
    Base class of drawing styles that write plain text to an output stream.
    """

    #: The stream that this style writes its text output to.
    out: TextIO

    #: The maximum width of the text to be produced, in characters.
    width: int

    def __init__(self, out: TextIO | None = None, width: int = 80) -> None:
        """
        :param out: the stream to write the text output to
            (defaults to :obj:`sys.stdout`)
        :param width: the maximum width available to render the text, defaults to 80
        :raises ValueError: if the width is not positive
        """
        super().__init__()

        if width <= 0:
            raise ValueError(
                f"arg width expected to be positive integer but is {width}"
            )

        self.out = out if out is not None else sys.stdout
        self.width = width

    @classmethod
    def get_default_style_name(cls) -> str:
        """[see superclass]"""
        return "text"

    def start_drawing(self, *, title: str, **kwargs: Any) -> None:
        """
        Write the title to :attr:`out`, padded on both sides with ``=`` characters
        up to :attr:`width`, followed by a blank line.

        :param title: the title of the drawing
        :param kwargs: additional drawer-specific arguments
        """
        super().start_drawing(title=title, **kwargs)

        padded_title = f" {title} "
        print(f"{padded_title:=^{self.width}s}\n", file=self.out)

    def finalize_drawing(self, **kwargs: Any) -> None:
        """
        Write a blank line to :attr:`out` to end the text output.

        The superclass is finalized even if writing the blank line fails.

        :param kwargs: additional drawer-specific arguments
        """
        try:
            # print() without arguments writes just the line terminator
            print(file=self.out)
        finally:
            super().finalize_drawing(**kwargs)
91 |
92 |
93 | __tracker.validate()
94 |
--------------------------------------------------------------------------------
/src/pytools/viz/color/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Color definitions.
3 | """
4 |
5 | from ._color import *
6 | from ._rgb import *
7 |
--------------------------------------------------------------------------------
/src/pytools/viz/color/_rgb.py:
--------------------------------------------------------------------------------
1 | """
2 | Core implementation of :mod:`pytools.viz.color`
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | from typing import Any, TypeAlias, cast, overload
9 |
10 | from matplotlib.colors import to_rgb
11 |
12 | from pytools.api import AllTracker
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | #
18 | # Constants
19 | #
20 |
21 | ALPHA_DEFAULT = 1.0
22 |
23 |
24 | #
25 | # Exported names
26 | #
27 |
28 | __all__ = [
29 | "RgbColor",
30 | "RgbaColor",
31 | ]
32 |
33 |
34 | #
35 | # Type aliases
36 | #
37 | TupleRgb: TypeAlias = tuple[float, float, float]
38 |
39 | #
40 | # Ensure all symbols introduced below are included in __all__
41 | #
42 |
43 | __tracker = AllTracker(globals())
44 |
45 |
46 | #
47 | # Classes
48 | #
49 |
50 |
51 | class _RgbBase(tuple): # type: ignore
52 | @property
53 | def r(self: tuple[float, ...]) -> float:
54 | """
55 | The luminosity value for the *red* channel.
56 | """
57 | return self[0]
58 |
59 | @property
60 | def g(self: tuple[float, ...]) -> float:
61 | """
62 | The luminosity value for the *green* channel.
63 | """
64 | return self[1]
65 |
66 | @property
67 | def b(self: tuple[float, ...]) -> float:
68 | """
69 | The luminosity value for the *blue* channel.
70 | """
71 | return self[2]
72 |
73 | @classmethod
74 | def _check_arg_count(
75 | cls, args: tuple[Any, ...], kwargs: dict[str, Any], max_allowed: int
76 | ) -> None:
77 | if len(args) + len(kwargs) > max_allowed:
78 | args_list = ", ".join(
79 | (
80 | *map(repr, args),
81 | *(f"{name}={value!r}" for name, value in kwargs.items()),
82 | )
83 | )
84 | raise ValueError(
85 | f"{cls.__name__} expects at most {max_allowed} arguments but got: "
86 | f"{args_list}"
87 | )
88 |
89 | @property
90 | def hex(self) -> str:
91 | """
92 | The hexadecimal representation of this color.
93 | """
94 | # convert floats to bytes in the range 0-255
95 | return "#" + "".join(f"{int(round(c * 255)):02x}" for c in self)
96 |
97 |
class RgbColor(_RgbBase):
    """
    RGB color type for use in color schemas and colored drawing styles.
    """

    @overload
    def __new__(cls, r: float, g: float, b: float) -> RgbColor:
        ...

    @overload
    def __new__(cls, c: str) -> RgbColor:
        ...

    def __new__(cls, *args: Any, **kwargs: Any) -> RgbColor:
        """
        :param r: the luminosity value for the *red* channel
        :param g: the luminosity value for the *green* channel
        :param b: the luminosity value for the *blue* channel
        :param c: a named color, or any other color specification accepted by
            :func:`matplotlib.colors.to_rgb`
        """
        # at most three arguments: the three channels, or a single named color
        cls._check_arg_count(args, kwargs, 3)

        channels, opacity = _to_rgba(*args, **kwargs)
        if opacity is None:
            return cast(_RgbBase, super()).__new__(cls, channels)

        raise ValueError(
            "alpha channel is not supported by RgbColor, use RgbaColor instead"
        )
130 |
131 |
class RgbaColor(_RgbBase):
    """
    RGB + Alpha color type for use in color schemas and colored drawing styles.
    """

    @overload
    def __new__(
        cls, r: float, g: float, b: float, alpha: float | None = None
    ) -> RgbaColor:
        ...

    @overload
    def __new__(cls, c: str, alpha: float | None = None) -> RgbaColor:
        ...

    def __new__(cls, *args: Any, **kwargs: Any) -> RgbaColor:
        """
        :param r: the luminosity value for the *red* channel
        :param g: the luminosity value for the *green* channel
        :param b: the luminosity value for the *blue* channel
        :param alpha: the opacity value for the *alpha* channel (defaults to
            ``1.0`` if omitted)
        :param c: a named color, or any other color specification accepted by
            :func:`matplotlib.colors.to_rgb`
        """
        # at most four arguments: three channels plus alpha, or a named color
        # plus alpha
        cls._check_arg_count(args, kwargs, 4)

        channels, opacity = _to_rgba(*args, **kwargs)
        if opacity is None:
            opacity = ALPHA_DEFAULT

        return cast(_RgbBase, super()).__new__(cls, (*channels, opacity))

    @property
    def alpha(self: tuple[float, ...]) -> float:
        """
        The opacity value for the *alpha* channel (fourth element).
        """
        return self[3]
171 |
172 |
173 | @overload
174 | def _to_rgba(
175 | r: float, g: float, b: float, alpha: float | None = None
176 | ) -> tuple[TupleRgb, float | None]:
177 | pass
178 |
179 |
180 | @overload
181 | def _to_rgba(c: str, alpha: float | None = None) -> tuple[TupleRgb, float | None]:
182 | pass
183 |
184 |
185 | def _to_rgba(
186 | *args: Any,
187 | r: float | None = None,
188 | g: float | None = None,
189 | b: float | None = None,
190 | alpha: float | None = None,
191 | c: str | None = None,
192 | ) -> tuple[TupleRgb, float | None]:
193 | n_rgb_kwargs = (r is not None) + (g is not None) + (b is not None)
194 | if n_rgb_kwargs in (1, 2):
195 | raise ValueError(
196 | "incomplete RGB keyword arguments: need to provide r, g, and b"
197 | )
198 |
199 | has_rgb_kwargs = n_rgb_kwargs > 0
200 | has_c = c is not None
201 |
202 | if args and (has_rgb_kwargs or has_c):
203 | raise ValueError(
204 | "mixed use of positional and keyword arguments for color arguments"
205 | )
206 | if has_rgb_kwargs and has_c:
207 | raise ValueError("mixed use of named color and color channels")
208 |
209 | # case 1: named color
210 |
211 | if args and isinstance(args[0], str):
212 | if len(args) == 1:
213 | c = args[0]
214 | elif len(args) == 2 and alpha is None:
215 | c, alpha = args
216 |
217 | if not (alpha is None or isinstance(alpha, (float, int))):
218 | raise ValueError(f"alpha must be numeric but is: {alpha!r}")
219 |
220 | if isinstance(c, str):
221 | return to_rgb(c), alpha
222 | elif c is not None:
223 | raise ValueError(f"single color argument must be a string but is: {c!r}")
224 |
225 | # case 2: color channels
226 |
227 | rgb: TupleRgb
228 |
229 | if has_rgb_kwargs:
230 | assert not (r is None or g is None or b is None)
231 | rgb = (r, g, b)
232 | else:
233 | if not all(isinstance(x, (float, int)) for x in args):
234 | raise ValueError(f"all color arguments must be numeric, but are: {args}")
235 |
236 | if len(args) == 3:
237 | rgb = cast(TupleRgb, args)
238 | elif len(args) == 4:
239 | rgb = cast(TupleRgb, args[:3])
240 | if alpha is None:
241 | alpha = args[3]
242 | else:
243 | raise ValueError(f"need 3 RGB values but got {args}")
244 | else:
245 | raise ValueError(f"need 3 RGB values or 4 RGBA values but got: {args}")
246 |
247 | if not all(isinstance(x, (float, int)) and 0.0 <= x <= 1.0 for x in rgb):
248 | raise ValueError(f"invalid RGB values: {rgb}")
249 |
250 | if not (alpha is None or 0.0 <= alpha <= 1.0):
251 | raise ValueError(f"invalid alpha value: {alpha}")
252 |
253 | assert len(rgb) == 3
254 |
255 | return rgb, alpha
256 |
257 |
258 | # check consistency of __all__
259 |
260 | __tracker.validate()
261 |
262 |
263 | #
264 | # helper methods
265 | #
266 |
--------------------------------------------------------------------------------
/src/pytools/viz/dendrogram/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Drawer and styles for dendrogram representations of linkage trees.
3 | """
4 |
5 | from ._draw import *
6 | from ._style import *
7 |
--------------------------------------------------------------------------------
/src/pytools/viz/dendrogram/_draw.py:
--------------------------------------------------------------------------------
1 | """
2 | Drawing dendrograms.
3 | """
4 |
5 | import logging
6 | from collections.abc import Iterable
7 | from typing import Any, NamedTuple
8 |
9 | from ...api import AllTracker, inheritdoc
10 | from ...data import LinkageTree
11 | from ...data.linkage import Node
12 | from .. import Drawer
13 | from ._style import DendrogramMatplotStyle, DendrogramReportStyle
14 | from .base import DendrogramStyle
15 |
16 | log = logging.getLogger(__name__)
17 |
18 |
19 | #
20 | # Exported names
21 | #
22 |
23 | __all__ = ["DendrogramDrawer"]
24 |
25 |
26 | #
27 | # Ensure all symbols introduced below are included in __all__
28 | #
29 |
30 | __tracker = AllTracker(globals())
31 |
32 |
33 | #
34 | # Classes
35 | #
36 |
37 |
38 | class _SubtreeInfo(NamedTuple):
39 | names: list[str]
40 | weights: list[float]
41 | weight_total: float
42 |
43 |
@inheritdoc(match="[see superclass]")
class DendrogramDrawer(Drawer[LinkageTree, DendrogramStyle]):
    """
    Draws dendrogram representations of :class:`.LinkageTree` objects.

    Leaves are laid out in depth-first order, visiting the left child of every inner
    node before its right child.
    """

    # defined in superclass, repeated here for Sphinx
    style: DendrogramStyle

    def __init__(self, style: DendrogramStyle | str | None = None) -> None:
        """[see superclass]"""
        super().__init__(style=style)

    @classmethod
    def get_style_classes(cls) -> Iterable[type[DendrogramStyle]]:
        """[see superclass]"""
        style_classes: list[type[DendrogramStyle]] = [
            DendrogramMatplotStyle,
            DendrogramReportStyle,
        ]
        return style_classes

    @classmethod
    def get_default_style(cls) -> DendrogramStyle:
        """[see superclass]"""
        return DendrogramMatplotStyle()

    def get_style_kwargs(self, data: LinkageTree) -> dict[str, Any]:
        """[see superclass]"""
        # dendrogram-specific arguments, followed by those of the superclass
        return dict(
            leaf_label=data.leaf_label,
            distance_label=data.distance_label,
            weight_label=data.weight_label,
            max_distance=data.max_distance,
            leaf_names=tuple(leaf.name for leaf in data.iter_nodes(inner=False)),
            **super().get_style_kwargs(data=data),
        )

    def _draw(self, data: LinkageTree) -> None:
        # draw the linkage tree from the root down, then label the leaves

        def _draw_subtree(
            node: Node, node_idx: int, weight_cumulative: float, width: float
        ) -> _SubtreeInfo:
            # Recursively draw the part of the dendrogram below a node.
            #
            # node_idx: position of the node's left-most leaf among all leaves
            # weight_cumulative: total weight of all leaves positioned before it
            # width: the top coordinate of the link drawn above the node
            #
            # Returns the names and weights of the subtree's leaves, in drawing
            # order, together with their total weight.

            if node.is_leaf:
                self.style.draw_link_leg(
                    bottom=0.0,
                    top=width,
                    leaf=node_idx,
                    weight=node.weight,
                    weight_cumulative=weight_cumulative,
                    tree_height=data.max_distance,
                )
                return _SubtreeInfo(
                    names=[node.name], weights=[node.weight], weight_total=node.weight
                )

            children = data.children(node=node)
            assert children is not None, "children of non-leaf node are defined"
            left, right = children

            # the left subtree starts at this node's position; the right subtree
            # follows directly after all leaves of the left subtree
            left_info = _draw_subtree(
                node=left,
                node_idx=node_idx,
                weight_cumulative=weight_cumulative,
                width=node.children_distance,
            )
            n_leaves_left = len(left_info.names)
            right_info = _draw_subtree(
                node=right,
                node_idx=node_idx + n_leaves_left,
                weight_cumulative=weight_cumulative + left_info.weight_total,
                width=node.children_distance,
            )

            merged_info = _SubtreeInfo(
                names=[*left_info.names, *right_info.names],
                weights=[*left_info.weights, *right_info.weights],
                weight_total=left_info.weight_total + right_info.weight_total,
            )

            self.style.draw_link_connector(
                bottom=node.children_distance,
                top=width,
                first_leaf=node_idx,
                n_leaves_left=n_leaves_left,
                n_leaves_right=len(right_info.names),
                weight=merged_info.weight_total,
                weight_cumulative=weight_cumulative,
                tree_height=data.max_distance,
            )

            return merged_info

        root_info = _draw_subtree(
            node=data.root, node_idx=0, weight_cumulative=0.0, width=data.max_distance
        )
        self.style.draw_leaf_labels(names=root_info.names, weights=root_info.weights)
158 |
159 |
160 | __tracker.validate()
161 |
--------------------------------------------------------------------------------
/src/pytools/viz/dendrogram/base/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for dendrogram representations.
3 | """
4 |
5 | from ._style import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/dendrogram/base/_style.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for dendrogram styles.
3 | """
4 |
5 | import logging
6 | from abc import ABCMeta, abstractmethod
7 | from collections.abc import Sequence
8 | from typing import Any
9 |
10 | from pytools.api import AllTracker
11 | from pytools.viz import DrawingStyle
12 |
13 | log = logging.getLogger(__name__)
14 |
15 | #
16 | # Exported names
17 | #
18 |
19 | __all__ = ["DendrogramStyle"]
20 |
21 |
22 | #
23 | # Ensure all symbols introduced below are included in __all__
24 | #
25 |
26 | __tracker = AllTracker(globals())
27 |
28 |
29 | #
30 | # Classes
31 | #
32 |
33 |
class DendrogramStyle(DrawingStyle, metaclass=ABCMeta):
    """
    Abstract base class for styles that render dendrograms.

    Concrete styles supply the drawing primitives: leaf labels, link legs, and
    link connectors.
    """

    def start_drawing(
        self,
        *,
        title: str,
        leaf_label: str | None = None,
        distance_label: str | None = None,
        weight_label: str | None = None,
        max_distance: float | None = None,
        leaf_names: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Begin a new dendrogram with the given title.

        Although declared with ``None`` defaults, every dendrogram-specific argument
        must be passed with a value other than ``None``.

        :param title: the title of the chart
        :param leaf_label: the label for the leaf axis
        :param distance_label: the label for the distance axis
        :param weight_label: the label for the weight scale
        :param max_distance: the height of the dendrogram, i.e., the maximum possible
            distance
        :param leaf_names: the names of the leaf nodes of the dendrogram
        :param kwargs: additional drawer-specific arguments, passed on to the
            superclass
        :raises ValueError: if any of the dendrogram-specific arguments is ``None``
        """
        required_args = (
            ("leaf_label", leaf_label),
            ("distance_label", distance_label),
            ("weight_label", weight_label),
            ("max_distance", max_distance),
            ("leaf_names", leaf_names),
        )
        # comma-separated names of all arguments left at None, in declaration order
        missing = ", ".join(arg for arg, value in required_args if value is None)
        if missing:
            raise ValueError("keyword arguments must not be None: " + missing)

        super().start_drawing(title=title, **kwargs)

    def finalize_drawing(
        self,
        *,
        leaf_label: str | None = None,
        distance_label: str | None = None,
        weight_label: str | None = None,
        max_distance: float | None = None,
        leaf_names: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Complete the dendrogram.

        This base implementation ignores the dendrogram-specific arguments and only
        forwards ``kwargs`` to the superclass; subclasses may override it, e.g., to
        label the axes.

        :param leaf_label: the label for the leaf axis
        :param distance_label: the label for the distance axis
        :param weight_label: the label for the weight scale
        :param max_distance: the height of the dendrogram, i.e., the maximum possible
            distance
        :param leaf_names: the names of the leaf nodes of the dendrogram
        :param kwargs: additional drawer-specific arguments
        """
        super().finalize_drawing(**kwargs)

    @abstractmethod
    def draw_leaf_labels(
        self, *, names: Sequence[str], weights: Sequence[float]
    ) -> None:
        """
        Render the labels of all leaves.

        :param names: the name of every leaf
        :param weights: the weight of every leaf
        """

    @abstractmethod
    def draw_link_leg(
        self,
        *,
        bottom: float,
        top: float,
        leaf: float,
        weight: float,
        weight_cumulative: float,
        tree_height: float,
    ) -> None:
        """
        Draw a "leg" linking two levels of the linkage tree hierarchy.

        :param bottom: the height of the child node within the linkage tree
        :param top: the height of the parent node within the linkage tree
        :param leaf: the leaf index at which to draw the leg; a ``float`` denotes a
            position between two leaves
        :param weight: the weight of the child node
        :param weight_cumulative: the summed weight of all nodes positioned before
            the current one
        :param tree_height: the overall height of the linkage tree
        """

    @abstractmethod
    def draw_link_connector(
        self,
        *,
        bottom: float,
        top: float,
        first_leaf: int,
        n_leaves_left: int,
        n_leaves_right: int,
        weight: float,
        weight_cumulative: float,
        tree_height: float,
    ) -> None:
        """
        Draw the connector joining two sub-trees to their parent node.

        :param bottom: the height (i.e., cluster distance) of the sub-trees
        :param top: the height of the parent node
        :param first_leaf: the index of the first leaf of the left sub-tree
        :param n_leaves_left: how many leaves the left sub-tree has
        :param n_leaves_right: how many leaves the right sub-tree has
        :param weight: the weight of the parent node
        :param weight_cumulative: the summed weight of all nodes positioned before
            the current one
        :param tree_height: the overall height of the linkage tree
        """
166 |
167 |
168 | __tracker.validate()
169 |
--------------------------------------------------------------------------------
/src/pytools/viz/distribution/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting distributions for exploratory data visualization.
3 | """
4 |
5 | from ._distribution import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/distribution/base/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for distribution styles.
3 | """
4 |
5 | from ._base import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/distribution/base/_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for distribution styles.
3 | """
4 |
5 | import logging
6 | from abc import ABCMeta, abstractmethod
7 | from collections.abc import Sequence
8 |
9 | from pytools.api import AllTracker
10 | from pytools.viz import DrawingStyle
11 |
12 | log = logging.getLogger(__name__)
13 |
14 | #
15 | # Exported names
16 | #
17 |
18 | __all__ = ["XYSeries", "ECDF", "ECDFStyle"]
19 |
20 |
21 | #
22 | # Ensure all symbols introduced below are included in __all__
23 | #
24 |
25 | __tracker = AllTracker(globals())
26 |
27 |
28 | #
29 | # Classes
30 | #
31 |
32 |
class XYSeries:
    """
    Series of `x` and `y` coordinates for plotting; `x` and `y` values are stored in two
    separate sequences of the same length.
    """

    #: series of all `x` coordinate values
    x: Sequence[float]

    #: series of all `y` coordinate values
    y: Sequence[float]

    def __init__(self, x: Sequence[float], y: Sequence[float]) -> None:
        """
        :param x: series of all `x` coordinate values
        :param y: series of all `y` coordinate values
        :raises ValueError: if ``x`` and ``y`` differ in length
        """
        # validate explicitly: an ``assert`` would be stripped when Python runs with
        # optimizations enabled (``-O``), silently accepting mismatched series
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, "
                f"but got lengths {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y
53 |
54 |
class ECDF:
    """
    Coordinates for plotting an empirical cumulative distribution function (ECDF),
    split into three series: inliers, outliers, and far outliers.
    """

    def __init__(
        self, inliers: XYSeries, outliers: XYSeries, far_outliers: XYSeries
    ) -> None:
        """
        :param inliers: the ECDF coordinates of the inliers
        :param outliers: the ECDF coordinates of the outliers
        :param far_outliers: the ECDF coordinates of the far outliers
        """
        self._inliers, self._outliers, self._far_outliers = (
            inliers,
            outliers,
            far_outliers,
        )

    @property
    def inliers(self) -> XYSeries:
        """
        The ECDF coordinates of the inliers.
        """
        return self._inliers

    @property
    def outliers(self) -> XYSeries:
        """
        The ECDF coordinates of the outliers.
        """
        return self._outliers

    @property
    def far_outliers(self) -> XYSeries:
        """
        The ECDF coordinates of the far outliers.
        """
        return self._far_outliers
93 |
94 |
class ECDFStyle(DrawingStyle, metaclass=ABCMeta):
    """
    Abstract drawing style for rendering ECDFs.
    """

    @abstractmethod
    def _draw_ecdf(
        self,
        ecdf: ECDF,
        x_label: str,
        iqr_multiple: float | None,
        iqr_multiple_far: float | None,
    ) -> None:
        """
        Render the given ECDF; to be implemented by concrete styles.

        :param ecdf: the inlier, outlier, and far outlier coordinates to draw
        :param x_label: the label for the `x` axis
        :param iqr_multiple: multiple of the inter-quartile range, presumably the
            threshold separating inliers from outliers; may be ``None``
            (NOTE(review): confirm against the concrete styles)
        :param iqr_multiple_far: multiple of the inter-quartile range, presumably the
            threshold for far outliers; may be ``None``
            (NOTE(review): confirm against the concrete styles)
        """
109 |
110 |
111 | __tracker.validate()
112 |
--------------------------------------------------------------------------------
/src/pytools/viz/matrix/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting matrices for exploratory data visualization.
3 | """
4 |
5 | from ._matrix import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/matrix/base/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for matrix styles.
3 | """
4 |
5 | from ._base import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/matrix/base/_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base classes for matrix styles.
3 | """
4 |
5 | import logging
6 | from abc import ABCMeta, abstractmethod
7 | from typing import Any
8 |
9 | import numpy as np
10 | import numpy.typing as npt
11 |
12 | from pytools.api import AllTracker
13 | from pytools.viz import DrawingStyle
14 |
15 | log = logging.getLogger(__name__)
16 |
17 | #
18 | # Exported names
19 | #
20 |
21 | __all__ = ["MatrixStyle"]
22 |
23 |
24 | #
25 | # Ensure all symbols introduced below are included in __all__
26 | #
27 |
28 | __tracker = AllTracker(globals())
29 |
30 |
31 | #
32 | # Classes
33 | #
34 |
35 |
class MatrixStyle(DrawingStyle, metaclass=ABCMeta):
    """
    Abstract base class for styles that render matrices.

    Concrete styles implement :meth:`draw_matrix`; the start and finalize hooks
    accept matrix-specific labels and pass all other keyword arguments on to the
    superclass.
    """

    def start_drawing(
        self,
        *,
        title: str,
        name_labels: tuple[str | None, str | None] = (None, None),
        weight_label: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Set up a new matrix plot.

        This base implementation does not use ``name_labels`` or ``weight_label``;
        it forwards the title and any remaining keyword arguments to the superclass.

        :param title: the title of the matrix plot
        :param name_labels: the labels for the row axis and the column axis
        :param weight_label: the label for the `weight` axis
        :param kwargs: further drawer-specific arguments
        """
        super().start_drawing(title=title, **kwargs)

    @abstractmethod
    def draw_matrix(
        self,
        data: npt.NDArray[Any],
        *,
        names: tuple[
            npt.NDArray[Any] | None,
            npt.NDArray[Any] | None,
        ],
        weights: tuple[
            npt.NDArray[np.float64] | None,
            npt.NDArray[np.float64] | None,
        ],
    ) -> None:
        """
        Render the matrix.

        :param data: the cell values of the matrix, as a 2d array
        :param names: the row names and the column names, each optional
        :param weights: the row weights and the column weights, each optional
        """

    def finalize_drawing(
        self,
        name_labels: tuple[str | None, str | None] | None = None,
        weight_label: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Complete the matrix plot.

        This base implementation does not use ``name_labels`` or ``weight_label``;
        it forwards only the remaining keyword arguments to the superclass.

        :param name_labels: the labels for the row axis and the column axis
        :param weight_label: the label for the `weight` axis
        :param kwargs: further drawer-specific arguments
        """
        super().finalize_drawing(**kwargs)
98 |
99 |
100 | __tracker.validate()
101 |
--------------------------------------------------------------------------------
/src/pytools/viz/util/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Base and auxiliary classes for visualizations.
3 | """
4 |
5 | from ._matplot import *
6 |
--------------------------------------------------------------------------------
/src/pytools/viz/util/_matplot.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities related to matplotlib.
3 | """
4 |
5 | import logging
6 | from typing import Any
7 |
8 | from matplotlib.backend_bases import RendererBase
9 | from matplotlib.text import Text
10 | from matplotlib.ticker import Formatter
11 |
12 | from pytools.api import AllTracker
13 | from pytools.meta import SingletonMeta
14 |
15 | log = logging.getLogger(__name__)
16 |
17 |
18 | #
19 | # Exported names
20 | #
21 |
22 | __all__ = [
23 | "FittedText",
24 | "PercentageFormatter",
25 | ]
26 |
27 |
28 | #
29 | # Ensure all symbols introduced below are included in __all__
30 | #
31 |
32 | __tracker = AllTracker(globals())
33 |
34 |
35 | #
36 | # Classes
37 | #
38 |
39 |
class PercentageFormatter(
    Formatter,  # type: ignore
    metaclass=SingletonMeta,
):
    """
    Formats floats as percentages, where ``1.0`` corresponds to 100%.

    Values below ``1.0`` (including negative values) are formatted with 3 significant
    digits, omitting trailing zeros.

    Values of ``1.0`` and above are first rounded to the nearest whole percentage
    using :func:`round`, then formatted with up to 5 significant digits; percentages
    of 100,000% and above are therefore shown in scientific notation.

    Formatting examples:

    - ``0.00005`` is formatted as ``0.005%``
    - ``0.0005`` is formatted as ``0.05%``
    - ``0.0`` is formatted as ``0%``
    - ``0.1`` is formatted as ``10%``
    - ``1.0`` is formatted as ``100%``
    - ``0.01234`` is formatted as ``1.23%``
    - ``0.1234`` is formatted as ``12.3%``
    - ``1.5555`` is formatted as ``156%``
    - ``12.34`` is formatted as ``1234%``
    - ``1234`` is formatted as ``1.234e+05%``

    Floats are stored in binary, so values that look like exact decimal ties (such
    as ``0.01555``) may be rounded down rather than up.
    """

    def __call__(self, x: float, pos: int | None = None) -> str:
        """
        Format a single value as a percentage.

        :param x: the value to format, where ``1.0`` corresponds to 100%
        :param pos: the tick position; not used, but part of the call signature of
            matplotlib formatters
        :return: the formatted percentage, including the ``%`` sign
        """
        if x < 1.0:
            return f"{x * 100.0:.3g}%"
        else:
            return f"{round(x * 100.0):.5g}%"
68 |
69 |
class FittedText(
    Text,  # type: ignore
):
    """
    Handle storing and drawing of text in window or data coordinates;
    only render text that does not exceed the given width and height in data
    coordinates.
    """

    def __init__(
        self,
        *,
        x: int | float = 0,
        y: int | float = 0,
        width: int | float | None = None,
        height: int | float | None = None,
        text: str = "",
        **kwargs: Any,
    ) -> None:
        """
        :param x: the `x` coordinate of the text
        :param y: the `y` coordinate of the text
        :param width: the maximum allowed width for this text, in data coordinates;
            if ``None``, width is unrestricted
        :param height: the maximum allowed height for this text, in data coordinates;
            if ``None``, height is unrestricted
        :param text: the text to be rendered
        :param kwargs: additional keyword arguments of class
            :class:`matplotlib.text.Text`
        """
        super().__init__(x=x, y=y, text=text, **kwargs)
        self._width = width
        self._height = height

    def set_width(self, width: int | float | None) -> None:
        """
        Set the maximum allowed width for this text, in data coordinates.

        :param width: the maximum allowed width; ``None`` if width is unrestricted
        """
        # Only ever raise the stale flag. Assigning ``False`` when the width is
        # unchanged would clear a redraw request made by an earlier, unrelated
        # property change.
        if width != self._width:
            self._width = width
            self.stale = True

    def get_width(self) -> int | float | None:
        """
        Get the maximum allowed width for this text, in data coordinates.

        :return: the maximum allowed width; ``None`` if width is unrestricted
        """
        return self._width

    def set_height(self, height: int | float | None) -> None:
        """
        Set the maximum allowed height for this text, in data coordinates.

        :param height: the maximum allowed height; ``None`` if height is unrestricted
        """
        # see set_width: never reset a pending redraw request
        if height != self._height:
            self._height = height
            self.stale = True

    def get_height(self) -> int | float | None:
        """
        Get the maximum allowed height for this text, in data coordinates.

        :return: the maximum allowed height; ``None`` if height is unrestricted
        """
        return self._height

    def draw(self, renderer: RendererBase) -> None:
        """
        Draw the text if it is visible, and if it does not exceed the maximum
        width and height.

        Text that is not attached to an axes has no data coordinate system to
        measure against; it is drawn without size restrictions.

        See also :meth:`~matplotlib.artist.Artist.draw`.

        :param renderer: the renderer used for drawing
        """
        width = self.get_width()
        height = self.get_height()
        axes = self.axes

        if (width is None and height is None) or axes is None:
            super().draw(renderer)
            return

        # convert the rendered extent from display to data coordinates
        (x0, y0), (x1, y1) = axes.transData.inverted().transform(
            self.get_window_extent(renderer)
        )

        fits_width = width is None or abs(x1 - x0) <= width
        fits_height = height is None or abs(y1 - y0) <= height
        if fits_width and fits_height:
            super().draw(renderer)

    def set(self, **kwargs: Any) -> None:
        """
        Set multiple properties.

        The ``width`` and ``height`` properties are handled by this class; all
        other properties are passed on to the superclass.

        :param kwargs: the properties to set
        """

        if "width" in kwargs:
            self.set_width(kwargs.pop("width"))
        if "height" in kwargs:
            self.set_height(kwargs.pop("height"))
        super().set(**kwargs)
174 |
175 |
176 | # check consistency of __all__
177 |
178 | __tracker.validate()
179 |
--------------------------------------------------------------------------------
/test/test/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Test initialisation
3 | """
4 |
--------------------------------------------------------------------------------
/test/test/conftest.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 |
5 | from pytools.parallelization import Job
6 |
7 | logging.basicConfig(level=logging.DEBUG)
8 | log = logging.getLogger(__name__)
9 |
10 |
@pytest.fixture
def jobs() -> list[Job[int]]:
    """
    Eight jobs, each an instance of a :class:`Job` subclass whose ``run`` method
    returns the job's index.
    """

    class TestJob(Job[int]):
        def __init__(self, x: int) -> None:
            self.x = x

        def run(self) -> int:
            return self.x

    return list(map(TestJob, range(8)))
23 |
24 |
@pytest.fixture
def jobs_delayed() -> list[Job[int]]:
    """
    Four jobs created with :meth:`Job.delayed`, wrapping a function that adds 2 to
    the job's index.
    """

    def plus_2(x: int) -> int:
        return 2 + x

    return [Job.delayed(plus_2)(index) for index in range(4)]
32 |
--------------------------------------------------------------------------------
/test/test/paths.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | log = logging.getLogger(__name__)
4 |
5 | # directory paths
6 | DIR_DATA = "data"
7 | DIR_CONFIG = "config"
8 |
--------------------------------------------------------------------------------
/test/test/pytools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Test initialisation
3 | """
4 |
--------------------------------------------------------------------------------
/test/test/pytools/test_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic test cases for the `pytools.data` module
3 | """
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pytest
8 | from pandas.testing import assert_frame_equal
9 |
10 | from pytools.data import Matrix
11 |
12 | MSG_GOT_A_3_TUPLE = r"got a 3-tuple"
13 | MSG_GOT_A_3D_ARRAY = r"got a 3d array"
14 | MSG_GOT_A_STR = r"got a str"
15 |
16 |
def test_matrix_validation() -> None:
    """
    Invalid constructor arguments of :class:`Matrix` raise descriptive errors.
    """

    # the values themselves must form a 2d array
    with pytest.raises(ValueError, match=MSG_GOT_A_3D_ARRAY):
        Matrix(np.arange(20).reshape((4, 5, 1)))

    # each case: expected exception type, expected message pattern, and the
    # keyword arguments passed along with a fresh, valid 4x5 value array
    invalid_cases: list[tuple[type[Exception], str, dict[str, object]]] = [
        (TypeError, MSG_GOT_A_STR, dict(names="invalid")),
        (
            ValueError,
            MSG_GOT_A_3_TUPLE,
            dict(names=("invalid", "invalid", "invalid")),
        ),
        (
            ValueError,
            r"arg names\[0\] must be a 1d array, but has shape \(\)",
            dict(names=("invalid", "invalid")),
        ),
        (ValueError, MSG_GOT_A_3_TUPLE, dict(weights=(1, 2, 3))),
        (
            ValueError,
            r"arg weights\[0\] must be a 1d array, but has shape \(\)",
            dict(weights=(1, [2, 4])),
        ),
        (
            ValueError,
            (
                r"arg weights\[1\] must have same length as arg values.shape\[1\]=5, "
                r"but has length 2"
            ),
            dict(weights=(None, [2, 4])),
        ),
        (
            ValueError,
            r"arg weights\[1\] should be all positive, but contains negative weights",
            dict(weights=(None, [2, 4, -3, 2, 1])),
        ),
        (ValueError, MSG_GOT_A_3_TUPLE, dict(name_labels=(1, 2, 3))),
        (
            TypeError,
            (
                r"^arg value_label requires an instance of one of \{str, NoneType\} "
                r"but got: float$"
            ),
            dict(value_label=1.0),
        ),
        (
            TypeError,
            (
                r"^arg weight_label requires an instance of one of \{str, NoneType\} "
                r"but got: int$"
            ),
            dict(weight_label=1),
        ),
    ]

    for expected_error, message_pattern, kwargs in invalid_cases:
        with pytest.raises(expected_error, match=message_pattern):
            # noinspection PyTypeChecker
            Matrix(np.arange(20).reshape((4, 5)), **kwargs)  # type: ignore
91 |
92 |
def test_matrix_from_frame() -> None:
    """
    Converting a data frame to a :class:`Matrix` and back reproduces the frame.
    """
    rows, columns = list("ABCD"), list("abcde")
    values = np.arange(20).reshape((4, 5))
    frame = pd.DataFrame(values, index=rows, columns=columns)

    matrix_from_frame: Matrix[np.int_] = Matrix.from_frame(frame)
    assert matrix_from_frame == Matrix(values, names=(rows, columns))

    # the round trip back to a data frame must be lossless
    assert_frame_equal(matrix_from_frame.to_frame(), frame)
104 |
105 |
def test_matrix_resize() -> None:
    """
    Resizing a :class:`Matrix` by row/column counts or by fractions yields the
    expected sub-matrices, and invalid sizes are rejected.
    """

    def labelled(
        values: list[list[int]],
        names: tuple[list[str], list[str]],
        weights: tuple[list[int], list[int]],
    ) -> "Matrix[np.int_]":
        # every matrix in this test carries the same value, name, and weight labels
        return Matrix(
            values=np.array(values),
            names=names,
            weights=weights,
            value_label="value",
            name_labels=("row", "column"),
            weight_label="weight",
        )

    m: Matrix[np.int_] = Matrix(
        np.arange(20).reshape((4, 5)),
        names=(list("ABCD"), list("abcde")),
        weights=([2, 4, 2, 4], [1, 5, 4, 1, 5]),
        value_label="value",
        name_labels=("row", "column"),
        weight_label="weight",
    )

    assert m.resize(None) == m

    assert m.resize((1, None)) == labelled(
        [[5, 6, 7, 8, 9]], (["B"], list("abcde")), ([4], [1, 5, 4, 1, 5])
    )
    assert m.resize((None, 1)) == labelled(
        [[1], [6], [11], [16]], (list("ABCD"), ["b"]), ([2, 4, 2, 4], [5])
    )

    # a scalar size applies to rows and columns alike
    single_cell = labelled([[6]], (["B"], ["b"]), ([4], [5]))
    assert m.resize((1, 1)) == single_cell
    assert m.resize(1) == single_cell

    assert m.resize((3, 4)) == labelled(
        [[0, 1, 2, 4], [5, 6, 7, 9], [15, 16, 17, 19]],
        (list("ABD"), list("abce")),
        ([2, 4, 4], [1, 5, 4, 5]),
    )
    assert m.resize((0.8, 0.0001)) == labelled(
        [[1], [6], [16]], (list("ABD"), ["b"]), ([2, 4, 4], [5])
    )

    assert m.resize((4, 5)) == m
    assert m.resize((3, 3)) != m

    # each case: an invalid size and the expected error message pattern
    # NOTE(review): the expected message for an oversized column count mentions
    # "rows"; this mirrors the current implementation - confirm it is intended
    invalid_sizes = [
        ((1, 2, 3), r"arg size=\(1, 2, 3\) must be a number or a pair of numbers"),
        ((1, "5"), r"arg size=\(1, '5'\) must be a number or a pair of numbers"),
        ("5", r"arg size='5' must be a number or a pair of numbers"),
        (5, "row size must not be greater than the current number of rows, but is 5"),
        (
            (4, 6),
            (
                "column size must not be greater than the current number of rows, "
                "but is 6"
            ),
        ),
        ((-4, 5), "row size must not be negative, but is -4"),
        ((None, 1.5), "column size must not be greater than 1.0, but is 1.5"),
    ]
    for size, message_pattern in invalid_sizes:
        with pytest.raises(ValueError, match=message_pattern):
            # noinspection PyTypeChecker
            m.resize(size)  # type: ignore
219 |
--------------------------------------------------------------------------------
/test/test/pytools/test_dict_repr.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for dictionary representations of objects.
3 | """
4 |
5 | import logging
6 | from collections.abc import Iterable
7 |
8 | import pytest
9 | from typing_extensions import Self
10 |
11 | from pytools.data.taxonomy import Category
12 | from pytools.repr import HasDictRepr
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
class A(HasDictRepr):
    """
    A persona: a minimal test class holding a single string attribute ``x``.
    """

    def __init__(self, x: str) -> None:
        self.x = x

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, A):
            return False
        return self.x == other.x
27 |
28 |
class B(HasDictRepr):
    """
    A challenge generated from the perspective of a persona: holds a string ``y``
    and a nested :class:`A` instance ``a``.
    """

    def __init__(self, y: str, *, a: A) -> None:
        self.y = y
        self.a = a

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, B):
            return False
        return self.y == other.y and self.a == other.a
40 |
41 |
class TestCategory(Category):
    """
    Minimal concrete :class:`Category` identified by a name, used by the tests in
    this module.
    """

    # The class name starts with "Test", so pytest would try to collect it as a test
    # class and emit a collection warning because it defines __init__; opt out.
    __test__ = False

    #: The name of the category.
    _name: str

    def __init__(
        self, name: str, *, children: Self | Iterable[Self] | None = None
    ) -> None:
        """
        :param name: the name of the category
        :param children: the child categories, if any
        """
        super().__init__(children=children)
        self._name = name

    @property
    def name(self) -> str:
        """
        The name of the category.
        """
        return self._name
55 |
56 |
def test_invalid_dict_repr() -> None:
    """
    Dictionary representations with missing, unexpected, or invalid entries are
    rejected with a ValueError.
    """

    # dictionaries lacking one of the required keys
    for incomplete in ({"cls": "unknown_module.UnknownClass"}, {"params": "value"}):
        with pytest.raises(ValueError):
            HasDictRepr.from_dict(incomplete)

    # dictionary with an unexpected key
    with pytest.raises(ValueError):
        HasDictRepr.from_dict({"invalid": "value"})

    # invalid values for a nested parameter
    # NOTE(review): unlike "unknown_module.UnknownClass" above, "TestCategory" is not
    # qualified with a module name, and TestCategory.__init__ accepts no
    # "description" argument; these cases may fail for those reasons before the
    # invalid names are ever examined - confirm they test what they intend to
    invalid_names = ({"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}, object())
    for invalid_name in invalid_names:
        with pytest.raises(ValueError):
            HasDictRepr.from_dict(
                {
                    "cls": "TestCategory",
                    "params": dict(
                        name=invalid_name,
                        description="description",
                        children=[],
                    ),
                }
            )
121 |
122 |
def test_invalid_source_object() -> None:
    """
    Test that attempts to convert objects that do not have a dictionary representation
    raise a ValueError.
    """

    class NoDictRepr:
        """
        A plain class without a dictionary representation.
        """

    class WithDictRepr(HasDictRepr):
        """
        Holds a :class:`NoDictRepr` instance as an attribute.
        """

        def __init__(self, no_dict_repr: NoDictRepr) -> None:
            self.no_dict_repr = no_dict_repr

    with pytest.raises(ValueError):
        WithDictRepr(NoDictRepr()).to_dict()
138 |
--------------------------------------------------------------------------------
/test/test/pytools/test_http.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests of the pytools.http module.
3 | """
4 |
5 | import http.client
6 | import logging
7 | from unittest.mock import Mock, patch
8 |
9 | import pytest
10 |
11 | from pytools.http import fetch_url
12 |
13 | log = logging.getLogger(__name__)
14 |
15 |
16 | # noinspection HttpUrlsUsage
17 | def test_fetch_url() -> None:
18 | # Example URL
19 | example_url = "http://www.example.com/my_test"
20 |
21 | with patch("http.client.HTTPConnection") as MockHTTPConnection:
22 | # Mock the response object
23 | mock_response = Mock()
24 | mock_response.read.return_value = b"...\n"
25 | mock_response.status = http.client.OK
26 |
27 | # Set up the mock connection object
28 | mock_conn = Mock()
29 | mock_conn.getresponse.return_value = mock_response
30 | MockHTTPConnection.return_value = mock_conn
31 |
32 | # Configure the mock connection to return the mock response based on the URL
33 | def mock_request() -> Mock | None:
34 | if mock_conn.request.call_args[0][1] == "/my_test":
35 | return mock_response
36 | else:
37 | return None
38 |
39 | mock_conn.getresponse.side_effect = mock_request
40 |
41 | # Call fetch_url and run assertion on returned data
42 | data = fetch_url(example_url)
43 | assert data is not None
44 | assert len(data) > 0
45 | assert data.startswith(b"")
46 | assert data.endswith(b"