├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── mamimo
│   ├── __init__.py
│   ├── analysis.py
│   ├── carryover.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── _load_fake_mmm.py
│   ├── linear_model.py
│   ├── saturation.py
│   └── time_utils.py
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── test_analysis.py
    ├── test_carryover.py
    ├── test_linear_model.py
    └── test_saturation.py

/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | ignore = E203, W503
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | __pycache__/
3 | dist/
4 | .mypy_cache/
5 | .pytest_cache/
6 | .idea/
7 | .hypothesis/
8 | docs/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | fail_fast: false
 2 | repos:
 3 |   - repo: local
 4 |     hooks:
 5 |       - id: black
 6 |         name: black
 7 |         entry: poetry run black
 8 |         language: system
 9 |         types: [python]
10 |   - repo: local
11 |     hooks:
12 |       - id: flake8
13 |         name: flake8
14 |         entry: poetry run flake8
15 |         language: system
16 |         types: [python]
17 |   - repo: local
18 |     hooks:
19 |       - id: isort
20 |         name: isort
21 |         entry: poetry run isort .
22 |         language: system
23 |         types: [python]
24 |   - repo: local
25 |     hooks:
26 |       - id: mypy
27 |         name: mypy
28 |         entry: poetry run mypy .
29 |         language: system
30 |         types: [python]
31 |         args: [--no-strict-optional, --ignore-missing-imports]
32 |         pass_filenames: false
33 |   - repo: local
34 |     hooks:
35 |       - id: pydocstyle
36 |         name: pydocstyle
37 |         entry: poetry run pydocstyle .
38 |         language: system
39 |         types: [python]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # MaMiMo
 2 | This is a small library that helps you with your everyday **Ma**rketing **Mi**x **Mo**delling. It contains a few saturation functions, carryovers and some utilities for creating time features. You can also read my article about it here: [>>>Click<<<](https://towardsdatascience.com/a-small-python-library-for-marketing-mix-modeling-mamimo-100f31666e18).
 3 | 
 4 | Give it a try via `pip install mamimo`!
 5 | 
 6 | # Small Example
 7 | You can create a marketing mix model using different components from MaMiMo as well as [scikit-learn](https://scikit-learn.org/stable/). First, we can create a dataset via
 8 | ```python
 9 | from mamimo.datasets import load_fake_mmm
10 | 
11 | data = load_fake_mmm()
12 | 
13 | X = data.drop(columns=['Sales'])
14 | y = data['Sales']
15 | ```
16 | 
17 | `X` contains only media spends for now, but you can enrich it with more information.
18 | 
19 | ## Feature Engineering
20 | 
21 | MaMiMo lets you add time features, for example, via
22 | 
23 | ```python
24 | from mamimo.time_utils import add_time_features, add_date_indicators
25 | 
26 | 
27 | X = (X
28 |      .pipe(add_time_features, month=True)
29 |      .pipe(add_date_indicators, special_date=["2020-01-05"])
30 |      .assign(trend=range(200))
31 | )
32 | ```
33 | 
34 | This adds
35 | 
36 | - a month column (integers between 1 and 12),
37 | - a binary column named `special_date` that is 1 on the 5th of January 2020 and 0 everywhere else, and
38 | - a (so far linear) trend which is only counting up from 0 to 199.
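Besides `month`, `add_time_features` can derive more columns from the index; each one is switched on with a boolean flag. A quick sketch using the flags defined in `mamimo.time_utils` (the name `X_rich` is just for illustration):

```python
from mamimo.time_utils import add_time_features

# Every keyword enables one derived column; all of them default to False.
X_rich = add_time_features(
    X,
    day_of_week=True,   # 1 (Monday) to 7 (Sunday)
    week_of_year=True,  # ISO calendar week
    month=True,         # 1 to 12
    year=True,          # e.g. 2020
)
```

For this example, we stick with the `month`, `special_date`, and `trend` columns created above.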
 39 | 
 40 | `X` looks like this now:
 41 | 
 42 | ![1_iPkUH70amWOZijv6LVhM3A](https://user-images.githubusercontent.com/932327/169354994-624c5608-8dcf-49ae-94e2-5195f019d596.png)
 43 | 
 44 | ## Building a Model
 45 | 
 46 | We can now build a final model like this:
 47 | ```python
 48 | from mamimo.time_utils import PowerTrend
 49 | from mamimo.carryover import ExponentialCarryover
 50 | from mamimo.saturation import ExponentialSaturation
 51 | from sklearn.linear_model import LinearRegression
 52 | from sklearn.preprocessing import OneHotEncoder
 53 | from sklearn.compose import ColumnTransformer
 54 | from sklearn.pipeline import Pipeline
 55 | 
 56 | cats = [list(range(1, 13))]  # different months, known beforehand
 57 | 
 58 | preprocess = ColumnTransformer(
 59 |     [
 60 |         ('tv_pipe', Pipeline([
 61 |             ('carryover', ExponentialCarryover()),
 62 |             ('saturation', ExponentialSaturation())
 63 |         ]), ['TV']),
 64 |         ('radio_pipe', Pipeline([
 65 |             ('carryover', ExponentialCarryover()),
 66 |             ('saturation', ExponentialSaturation())
 67 |         ]), ['Radio']),
 68 |         ('banners_pipe', Pipeline([
 69 |             ('carryover', ExponentialCarryover()),
 70 |             ('saturation', ExponentialSaturation())
 71 |         ]), ['Banners']),
 72 |         ('month', OneHotEncoder(sparse=False, categories=cats), ['month']),
 73 |         ('trend', PowerTrend(), ['trend']),
 74 |         ('special_date', ExponentialCarryover(), ['special_date'])
 75 |     ]
 76 | )
 77 | 
 78 | model = Pipeline([
 79 |     ('preprocess', preprocess),
 80 |     ('regression', LinearRegression(
 81 |         positive=True,
 82 |         fit_intercept=False  # no intercept because of the months
 83 |         )
 84 |     )
 85 | ])
 86 | ```
 87 | 
 88 | This builds a model that does the following:
 89 | - the media channels are preprocessed using the [adstock transformation](https://en.wikipedia.org/wiki/Advertising_adstock), i.e. a carryover effect and a saturation are added
 90 | - the month is one-hot (dummy) encoded
 91 | - the trend is changed from linear to something like t^a, with some exponent a to be optimized
 92 | - the special_date 2020-01-05 gets a carryover effect as well, meaning that the special effect on the sales appears not only in this special week but also in the weeks after it
 93 | 
 94 | ## Training The Model
 95 | We can then tune the model's hyperparameters via
 96 | ```python
 97 | from scipy.stats import randint, uniform
 98 | from sklearn.model_selection import RandomizedSearchCV, TimeSeriesSplit
 99 | 
100 | tuned_model = RandomizedSearchCV(
101 |     model,
102 |     param_distributions={
103 |         'preprocess__tv_pipe__carryover__window': randint(1, 10),
104 |         'preprocess__tv_pipe__carryover__strength': uniform(0, 1),
105 |         'preprocess__tv_pipe__saturation__exponent': uniform(0, 1),
106 |         'preprocess__radio_pipe__carryover__window': randint(1, 10),
107 |         'preprocess__radio_pipe__carryover__strength': uniform(0, 1),
108 |         'preprocess__radio_pipe__saturation__exponent': uniform(0, 1),
109 |         'preprocess__banners_pipe__carryover__window': randint(1, 10),
110 |         'preprocess__banners_pipe__carryover__strength': uniform(0, 1),
111 |         'preprocess__banners_pipe__saturation__exponent': uniform(0, 1),
112 |         'preprocess__trend__power': uniform(0, 2),
113 |         'preprocess__special_date__window': randint(1, 10),
114 |         'preprocess__special_date__strength': uniform(0, 1),
115 |     },
116 |     cv=TimeSeriesSplit(),
117 |     random_state=0,
118 |     n_iter=1000,  # can take some time, use a lower number for faster results
119 | )
120 | 
121 | tuned_model.fit(X, y)
122 | ```
123 | 
124 | You can also use `GridSearchCV`, Optuna, or other hyperparameter tuning methods and packages here, as long as they are compatible with scikit-learn.
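For example, a minimal Optuna sketch could look like this (assuming `optuna` is installed; only the TV parameters are shown, the other transformers follow the same pattern):

```python
import optuna
from sklearn.model_selection import TimeSeriesSplit, cross_val_score


def objective(trial):
    # Roughly the same search space as in the RandomizedSearchCV example above.
    model.set_params(**{
        'preprocess__tv_pipe__carryover__window': trial.suggest_int('tv_window', 1, 10),
        'preprocess__tv_pipe__carryover__strength': trial.suggest_float('tv_strength', 0, 1),
        'preprocess__tv_pipe__saturation__exponent': trial.suggest_float('tv_exponent', 0, 1),
        # ... same pattern for radio_pipe, banners_pipe, trend and special_date
    })
    return cross_val_score(model, X, y, cv=TimeSeriesSplit()).mean()


study = optuna.create_study(direction='maximize')  # cross_val_score defaults to R² for regressors
study.optimize(objective, n_trials=100)
print(study.best_params)
```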
125 | 
126 | ## Analyzing
127 | With `tuned_model.predict(X)` and some plotting, we get
128 | 
129 | ![1_Bf4NKiUPNVVH87-7PNNZGw](https://user-images.githubusercontent.com/932327/169356818-158a322e-c18c-4404-a32f-ee69778c4d22.png)
130 | 
131 | You can get the best hyperparameters found using `print(tuned_model.best_params_)`.
132 | 
133 | ### Plotting
134 | You can compute the channel contributions via
135 | ```python
136 | from mamimo.analysis import breakdown
137 | 
138 | contributions = breakdown(tuned_model.best_estimator_, X, y)
139 | ```
140 | 
141 | This returns a dataframe with the contributions of each channel for each time step, summing to the historical values present in `y`. You can get a nice plot via
142 | ```python
143 | ax = contributions.plot.area(
144 |     figsize=(16, 10),
145 |     linewidth=1,
146 |     title="Predicted Sales and Breakdown",
147 |     ylabel="Sales",
148 |     xlabel="Date",
149 | )
150 | handles, labels = ax.get_legend_handles_labels()
151 | ax.legend(
152 |     handles[::-1],
153 |     labels[::-1],
154 |     title="Channels",
155 |     loc="center left",
156 |     bbox_to_anchor=(1.01, 0.5),
157 | )
158 | ```
159 | 
160 | ![1_SIlnsYXxRjhSZf-1jE4aDQ](https://user-images.githubusercontent.com/932327/169357525-c4f79fa0-a2fd-46b2-8331-47e534737d81.png)
161 | 
162 | Wow, that's a lot of channels. Let us group some of them together.
163 | 
164 | ```python
165 | group_channels = {'Baseline': [f'month__month_{i}' for i in range(1, 13)] + ['Base', 'trend__trend']}
166 | # read: 'Baseline consists of the months, base and trend.'
167 | # You can add more groups!
168 | 
169 | contributions = breakdown(
170 |     tuned_model.best_estimator_,
171 |     X,
172 |     y,
173 |     group_channels
174 | )
175 | ```
176 | 
177 | If we plot again, we get
178 | 
179 | ![1_xHzrUMMTKGxo7dvKpebjNg](https://user-images.githubusercontent.com/932327/169357648-13ae9097-d45b-4690-b3dd-63139da020b7.png)
180 | 
181 | Yay!
182 | 
183 | -----------------
184 | [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/G2G7EBKVH)
185 | 
186 | 
--------------------------------------------------------------------------------
/mamimo/__init__.py:
--------------------------------------------------------------------------------
1 | """A package to compute a marketing mix model."""
2 | 
3 | __version__ = "0.4.3"
--------------------------------------------------------------------------------
/mamimo/analysis.py:
--------------------------------------------------------------------------------
 1 | """Analyze trained marketing mix models."""
 2 | 
 3 | from typing import Dict, List, Optional
 4 | 
 5 | import numpy as np
 6 | import pandas as pd
 7 | from sklearn.pipeline import Pipeline
 8 | 
 9 | 
10 | def breakdown(
11 |     model: Pipeline,
12 |     X: pd.DataFrame,
13 |     y: np.ndarray,
14 |     group_channels: Optional[Dict[str, List[str]]] = None,
15 | ):
16 |     """
17 |     Compute the contributions for each channel.
18 | 
19 |     Parameters
20 |     ----------
21 |     model : sklearn.pipeline.Pipeline
22 |         The trained marketing mix model. Should be a pipeline consisting of two steps:
23 |         1. preprocessing (e.g. adstock transformations)
24 |         2. regression via a linear model.
25 | 
26 |     X : pd.DataFrame of shape (n_samples, n_features)
27 |         The training data.
28 | 
29 |     y : np.ndarray, 1-dimensional
30 |         The training labels.
31 | 
32 |     group_channels : Dict[str, List[str]], default=None
33 |         Create new channels by grouping (i.e. summing) the channels in the input.
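        For example, ``{'TV (all)': ['TV_morning', 'TV_evening']}`` would replace
        the two old channels with a single summed channel named 'TV (all)' (the
        channel names here are purely illustrative).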
 34 | 
 35 |     Returns
 36 |     -------
 37 |     pd.DataFrame
 38 |         A table consisting of the contributions of each channel in each timestep.
 39 |         The row-wise sum of this dataframe equals `y`.
 40 | 
 41 |     """
 42 |     preprocessing = model.steps[-2][1]
 43 |     regression = model.steps[-1][1]
 44 |     channel_names = preprocessing.get_feature_names_out()
 45 | 
 46 |     after_preprocessing = pd.DataFrame(
 47 |         preprocessing.transform(X), columns=channel_names, index=X.index
 48 |     )
 49 | 
 50 |     regression_weights = pd.Series(regression.coef_, index=channel_names)
 51 | 
 52 |     base = regression.intercept_
 53 | 
 54 |     unadjusted_breakdown = after_preprocessing.mul(regression_weights).assign(Base=base)
 55 |     adjusted_breakdown = unadjusted_breakdown.div(
 56 |         unadjusted_breakdown.sum(axis=1), axis=0
 57 |     ).mul(y, axis=0)
 58 | 
 59 |     if group_channels is not None:
 60 |         for new_channel, old_channels in group_channels.items():
 61 |             adjusted_breakdown[new_channel] = sum(
 62 |                 adjusted_breakdown.pop(old_channel)
 63 |                 for old_channel in old_channels
 64 |             )
 65 | 
 66 |     return adjusted_breakdown
--------------------------------------------------------------------------------
/mamimo/carryover.py:
--------------------------------------------------------------------------------
  1 | """Smooth time series data."""
  2 | 
  3 | from __future__ import annotations
  4 | 
  5 | from abc import ABC, abstractmethod
  6 | from typing import List, Optional
  7 | 
  8 | import numpy as np
  9 | from scipy.signal import convolve2d
 10 | from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
 11 | from sklearn.utils import check_array
 12 | from sklearn.utils.validation import (
 13 |     FLOAT_DTYPES,
 14 |     _check_feature_names_in,
 15 |     check_is_fitted,
 16 | )
 17 | 
 18 | 
 19 | class Carryover(OneToOneFeatureMixin, BaseEstimator, TransformerMixin, ABC):
 20 |     """
 21 |     Smooth the columns of an array by applying a convolution.
 22 | 
 23 |     Parameters
 24 |     ----------
 25 |     window : int
 26 |         Size of the sliding window, i.e. the number of time steps the
 27 |         carryover effect spans. Depending on `mode`, the window is centered
 28 |         around each observation or reaches into the future only.
 29 | 
 30 |     mode : str
 31 |         Which convolution mode to use. Can be one of
 32 |         - "full": The output is the full discrete linear convolution of the inputs.
 33 |         - "valid": The output consists only of those elements that do not rely on
 34 |           the zero-padding.
 35 |         - "same": The output is the same size as the first input, centered with
 36 |           respect to the 'full' output.
 37 | 
 38 |     """
 39 | 
 40 |     def __init__(
 41 |         self,
 42 |         window: int,
 43 |         mode: str,
 44 |     ) -> None:
 45 |         """Initialize."""
 46 |         self.window = window
 47 |         self.mode = mode
 48 | 
 49 |     @abstractmethod
 50 |     def _get_sliding_window(self) -> np.ndarray:
 51 |         """
 52 |         Calculate the sliding window.
 53 | 
 54 |         Returns
 55 |         -------
 56 |         sliding_window : np.array
 57 |             The sliding window.
 58 | 
 59 |         """
 60 | 
 61 |     def fit(self, X: np.ndarray, y: None = None) -> Carryover:
 62 |         """
 63 |         Fit the estimator.
 64 | 
 65 |         The sliding window is created and normalized.
 66 | 
 67 |         Parameters
 68 |         ----------
 69 |         X : np.ndarray
 70 |             Used for input validation only; the window depends just on the hyperparameters.
 71 | 
 72 |         y : Ignored
 73 |             Not used, present here for API consistency by convention.
 74 | 
 75 |         Returns
 76 |         -------
 77 |         Carryover
 78 |             Fitted transformer.
 79 | 
 80 |         """
 81 |         _ = self._validate_data(X, dtype=FLOAT_DTYPES)
 82 | 
 83 |         self.sliding_window_ = self._get_sliding_window()
 84 |         self.sliding_window_ = (
 85 |             self.sliding_window_.reshape(-1, 1) / self.sliding_window_.sum()
 86 |         )
 87 | 
 88 |         return self
 89 | 
 90 |     def transform(self, X: np.ndarray) -> np.ndarray:
 91 |         """
 92 |         Smooth the input data by convolving it with the sliding window.
 93 | 
 94 |         Parameters
 95 |         ----------
 96 |         X : np.ndarray
 97 |             The data to be smoothed.
 98 | 
 99 |         Returns
100 |         -------
101 |         np.ndarray
102 |             The smoothed data.
103 | 
104 |         """
105 |         check_is_fitted(self)
106 |         X = check_array(X)
107 |         self._check_n_features(X, reset=False)
108 | 
109 |         convolution = convolve2d(X, self.sliding_window_, mode=self.mode)
110 | 
111 |         if self.mode == "full" and self.window > 1:
112 |             convolution = convolution[: -self.window + 1]
113 | 
114 |         return convolution
115 | 
116 |     def get_feature_names_out(self, input_features: Optional[List] = None):
117 |         """
118 |         Get the output feature names.
119 | 
120 |         Parameters
121 |         ----------
122 |         input_features : list (optional), default=None
123 |             Input feature names.
124 | 
125 |         Returns
126 |         -------
127 |         np.ndarray
128 |             Output feature names.
129 | 
130 |         """
131 |         input_features = _check_feature_names_in(self, input_features)
132 | 
133 |         return np.array(input_features, dtype=object)
134 | 
135 | 
136 | class GeneralGaussianCarryover(Carryover):
137 |     """
138 |     Smooths time series data with a Gaussian window.
139 | 
140 |     Smooth the columns of an array by applying a convolution with a generalized
141 |     Gaussian curve.
142 | 
143 |     Parameters
144 |     ----------
145 |     window : int, default=1
146 |         Size of the sliding window, i.e. the number of time steps the smoothing
147 |         effect spans. The window is centered around each observation, unless
148 |         `tails` restricts it to one side.
149 | 
150 |     p : float, default=1
151 |         Parameter for the shape of the curve. p=1 yields a typical Gaussian curve
152 |         while p=0.5 yields a Laplace curve, for example.
153 | 
154 |     sig : float, default=1
155 |         Parameter for the standard deviation of the bell-shaped curve.
156 | 
157 |     tails : str, default="both"
158 |         Which tails to use. Can be one of
159 |         - "left"
160 |         - "right"
161 |         - "both"
162 | 
163 |     Examples
164 |     --------
165 |     >>> import numpy as np
166 |     >>> X = np.array([0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)
167 |     >>> GeneralGaussianCarryover().fit_transform(X)
168 |     array([[0.],
169 |            [0.],
170 |            [0.],
171 |            [1.],
172 |            [0.],
173 |            [0.],
174 |            [0.]])
175 | 
176 |     >>> GeneralGaussianCarryover(window=5, p=1, sig=1).fit_transform(X)
177 |     array([[0.        ],
178 |            [0.05448868],
179 |            [0.24420134],
180 |            [0.40261995],
181 |            [0.24420134],
182 |            [0.05448868],
183 |            [0.        ]])
184 | 
185 |     >>> GeneralGaussianCarryover(window=7, tails="right").fit_transform(X)
186 |     array([[0.        ],
187 |            [0.        ],
188 |            [0.        ],
189 |            [0.57045881],
190 |            [0.34600076],
191 |            [0.0772032 ],
192 |            [0.00633722]])
193 | 
194 |     """
195 | 
196 |     def __init__(
197 |         self,
198 |         window: int = 1,
199 |         p: float = 1,
200 |         sig: float = 1,
201 |         tails: str = "both",
202 |     ) -> None:
203 |         """Initialize."""
204 |         super().__init__(window, mode="same")
205 |         self.p = p
206 |         self.sig = sig
207 |         self.tails = tails
208 | 
209 |     def _get_sliding_window(self) -> np.ndarray:
210 |         """
211 |         Calculate the sliding window.
212 | 
213 |         Returns
214 |         -------
215 |         sliding_window : np.array
216 |             The sliding window.
217 | 
218 |         Raises
219 |         ------
220 |         ValueError
221 |             If the provided value for `tails` is not "left", "right" or "both".
222 | 
223 |         """
224 |         sliding_window = np.exp(
225 |             -0.5
226 |             * np.abs(np.arange(-self.window // 2 + 1, self.window // 2 + 1) / self.sig)
227 |             ** (2 * self.p)
228 |         )
229 |         if self.tails == "left":
230 |             sliding_window[self.window // 2 + 1 :] = 0
231 |         elif self.tails == "right":
232 |             sliding_window[: self.window // 2] = 0
233 |         elif self.tails != "both":
234 |             raise ValueError(
235 |                 "tails keyword has to be one of 'both', 'left' or 'right'."
236 |             )
237 |         return sliding_window
238 | 
239 | 
240 | class ExponentialCarryover(Carryover):
241 |     """
242 |     Smooths time series data with an exponential window.
243 | 
244 |     Smooth the columns of an array by applying a convolution with an exponentially
245 |     decaying curve. This class can be used for modelling carryover effects in
246 |     marketing mix models.
247 | 
248 |     Parameters
249 |     ----------
250 |     window : int, default=1
251 |         Size of the sliding window, i.e. the number of time steps the carryover
252 |         effect spans. The effect starts at the time of spending and reaches
253 |         `window - 1` time steps into the future.
254 | 
255 |     strength : float, default=0.0
256 |         Fraction of the spending effect that is carried over.
257 | 
258 |     peak : float, default=0.0
259 |         Where the carryover effect peaks.
260 | 
261 |     exponent : float, default=1.0
262 |         To further widen or narrow the carryover curve. A value of 1.0 yields a normal
263 |         exponential decay. With values larger than 1.0, a super exponential decay can
264 |         be achieved.
265 | 
266 |     Examples
267 |     --------
268 |     >>> import numpy as np
269 |     >>> X = np.array([0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)
270 |     >>> ExponentialCarryover().fit_transform(X)
271 |     array([[0.],
272 |            [0.],
273 |            [0.],
274 |            [1.],
275 |            [0.],
276 |            [0.],
277 |            [0.]])
278 | 
279 |     >>> ExponentialCarryover(window=3, strength=0.5).fit_transform(X)
280 |     array([[0.        ],
281 |            [0.        ],
282 |            [0.        ],
283 |            [0.57142857],
284 |            [0.28571429],
285 |            [0.14285714],
286 |            [0.        ]])
287 | 
288 |     >>> ExponentialCarryover(window=3, strength=0.5, peak=1).fit_transform(X)
289 |     array([[0.  ],
290 |            [0.  ],
291 |            [0.  ],
292 |            [0.25],
293 |            [0.5 ],
294 |            [0.25],
295 |            [0.  ]])
296 | 
297 |     """
298 | 
299 |     def __init__(
300 |         self,
301 |         window: int = 1,
302 |         strength: float = 0.0,
303 |         peak: float = 0.0,
304 |         exponent: float = 1.0,
305 |     ) -> None:
306 |         """Initialize."""
307 |         super().__init__(window, mode="full")
308 |         self.strength = strength
309 |         self.peak = peak
310 |         self.exponent = exponent
311 | 
312 |     def _get_sliding_window(self) -> np.ndarray:
313 |         """
314 |         Calculate the sliding window.
315 | 
316 |         Returns
317 |         -------
318 |         sliding_window : np.array
319 |             The sliding window.
320 | 321 | """ 322 | return self.strength ** ( 323 | np.abs(np.arange(self.window) - self.peak) ** self.exponent 324 | ) 325 | -------------------------------------------------------------------------------- /mamimo/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Location for marketing mix modeling datasets.""" 2 | 3 | from ._load_fake_mmm import load_fake_mmm 4 | 5 | __all__ = ["load_fake_mmm"] 6 | -------------------------------------------------------------------------------- /mamimo/datasets/_load_fake_mmm.py: -------------------------------------------------------------------------------- 1 | """Just some fake data for marketing mix modeling.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.pipeline import make_pipeline 6 | 7 | from mamimo.carryover import ExponentialCarryover 8 | from mamimo.saturation import ExponentialSaturation 9 | from mamimo.time_utils import add_date_indicators 10 | 11 | 12 | def load_fake_mmm(): 13 | """Load the data.""" 14 | np.random.seed(0) 15 | 16 | data = pd.DataFrame( 17 | { 18 | "TV": np.random.normal(loc=10000, scale=2000, size=200) 19 | * np.random.binomial(n=1, p=0.3, size=200), 20 | "Radio": np.random.normal(loc=5000, scale=1000, size=200) 21 | * np.random.binomial(n=1, p=0.5, size=200), 22 | "Banners": np.random.normal(loc=2000, scale=200, size=200) 23 | * np.random.binomial(n=1, p=0.8, size=200), 24 | }, 25 | index=pd.date_range(start="2018-01-01", periods=200, freq="w"), 26 | ).clip(0, np.inf) 27 | 28 | tv_pipe = make_pipeline( 29 | ExponentialCarryover(window=4, strength=0.5), 30 | ExponentialSaturation(exponent=0.0001), 31 | ) 32 | radio_pipe = make_pipeline( 33 | ExponentialCarryover(window=2, strength=0.2), 34 | ExponentialSaturation(exponent=0.0001), 35 | ) 36 | banners_pipe = make_pipeline(ExponentialSaturation(exponent=0.0001)) 37 | date_carryover = ExponentialCarryover(window=10, strength=0.6) 38 | 39 | adstock_data = data.copy().pipe( 40 | add_date_indicators, some_special_date=["2020-01-05"] 41 | ) 42 | adstock_data["TV"] = tv_pipe.fit_transform(adstock_data[["TV"]]) 43 | adstock_data["Radio"] = radio_pipe.fit_transform(adstock_data[["Radio"]]) 44 | adstock_data["Banners"] = banners_pipe.fit_transform(adstock_data[["Banners"]]) 45 | adstock_data["some_special_date"] = date_carryover.fit_transform( 46 | adstock_data[["some_special_date"]] 47 | ) 48 | 49 | sales = ( 50 | 10000 * adstock_data["TV"] 51 | + 8000 * adstock_data["Radio"] 52 | + 14000 * adstock_data["Banners"] 53 | + 1000 * np.sin(np.arange(200) * 2 * np.pi / 52) 54 | + 40 * np.arange(200) ** 1.2 55 | + 80000 * adstock_data["some_special_date"] 56 | + 500 * np.random.randn(200) 57 | ) 58 | 59 | data["Sales"] = sales 60 | 61 | return data.rename_axis(index="Date").round(2).clip(0, np.inf) 62 | -------------------------------------------------------------------------------- /mamimo/linear_model.py: -------------------------------------------------------------------------------- 1 | """Perform the actual regression using linear models.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Callable, List, Optional, Tuple 7 | 8 | import numpy as np 9 | from scipy.optimize import minimize 10 | from sklearn.base import BaseEstimator, RegressorMixin 11 | from sklearn.utils.validation import ( 12 | _check_sample_weight, 13 | check_array, 14 | check_is_fitted, 15 | check_X_y, 16 | ) 17 | 18 | 19 | class BaseScipyMinimizeRegressor(BaseEstimator, 
RegressorMixin, ABC):
 20 |     """
 21 |     Base class for regressors relying on scipy's minimize method.
 22 | 
 23 |     Derive a class from this one and give it the function to be minimized.
 24 | 
 25 |     Parameters
 26 |     ----------
 27 |     alpha : float, default=0.0
 28 |         Constant that multiplies the penalty terms.
 29 | 
 30 |     l1_ratio : float, default=0.0
 31 |         The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
 32 |         ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
 33 |         is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
 34 |         combination of L1 and L2.
 35 | 
 36 |     fit_intercept : bool, default=True
 37 |         Whether to calculate the intercept for this model. If set
 38 |         to False, no intercept will be used in calculations
 39 |         (i.e. data is expected to be centered).
 40 | 
 41 |     copy_X : bool, default=True
 42 |         If True, X will be copied; else, it may be overwritten.
 43 | 
 44 |     positive : bool, default=False
 45 |         When set to True, forces the coefficients to be positive.
 46 | 
 47 |     monotone_constraints : list (optional), default=None
 48 |         A list containing as many numbers as there are features. The numbers should be
 49 |         - -1 to indicate that the coefficient in this position should be negative
 50 |         - 0 if the coefficient is unrestricted, and
 51 |         - 1 if the coefficient should be positive.
 52 | 
 53 |     Attributes
 54 |     ----------
 55 |     coef_ : np.ndarray of shape (n_features,)
 56 |         Estimated coefficients of the model.
 57 | 
 58 |     intercept_ : float
 59 |         Independent term in the linear model. Set to 0.0 if fit_intercept = False.
 60 | 
 61 |     Notes
 62 |     -----
 63 |     This implementation uses scipy.optimize.minimize, see
 64 |     https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html.
 65 | 
 66 |     """
 67 | 
 68 |     def __init__(
 69 |         self,
 70 |         alpha: float = 0.0,
 71 |         l1_ratio: float = 0.0,
 72 |         fit_intercept: bool = True,
 73 |         copy_X: bool = True,
 74 |         positive: bool = False,
 75 |         monotone_constraints: Optional[List[int]] = None,
 76 |     ) -> None:
 77 |         """Initialize."""
 78 |         self.alpha = alpha
 79 |         self.l1_ratio = l1_ratio
 80 |         self.fit_intercept = fit_intercept
 81 |         self.copy_X = copy_X
 82 |         self.positive = positive
 83 |         self.monotone_constraints = monotone_constraints
 84 | 
 85 |     @abstractmethod
 86 |     def _get_objective(
 87 |         self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray
 88 |     ) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]:
 89 |         """
 90 |         Produce the loss function to be minimized.
 91 | 
 92 |         Also outputs its gradient to speed up computations.
 93 | 
 94 |         Parameters
 95 |         ----------
 96 |         X : np.ndarray of shape (n_samples, n_features)
 97 |             The training data.
 98 | 
 99 |         y : np.ndarray, 1-dimensional
100 |             The target values.
101 | 
102 |         sample_weight : Optional[np.ndarray], default=None
103 |             Individual weights for each sample.
104 | 
105 |         Returns
106 |         -------
107 |         loss : Callable[[np.ndarray], float]
108 |             The loss function to be minimized.
109 | 
110 |         grad_loss : Callable[[np.ndarray], np.ndarray]
111 |             The gradient of the loss function. Speeds up finding the minimum.
112 | 113 | """ 114 | 115 | def _loss_regularize(self, loss): 116 | def regularized_loss(params): 117 | return ( 118 | loss(params) 119 | + self.alpha * self.l1_ratio * np.sum(np.abs(params)) 120 | + 0.5 * self.alpha * (1 - self.l1_ratio) * np.sum(params**2) 121 | ) 122 | 123 | return regularized_loss 124 | 125 | def _grad_loss_regularize(self, grad_loss): 126 | def regularized_grad_loss(params): 127 | return ( 128 | grad_loss(params) 129 | + self.alpha * self.l1_ratio * np.sign(params) 130 | + self.alpha * (1 - self.l1_ratio) * params 131 | ) 132 | 133 | return regularized_grad_loss 134 | 135 | def fit( 136 | self, 137 | X: np.ndarray, 138 | y: np.ndarray, 139 | sample_weight: Optional[np.ndarray] = None, 140 | ) -> BaseScipyMinimizeRegressor: 141 | """ 142 | Fit the model using the SLSQP algorithm. 143 | 144 | Parameters 145 | ---------- 146 | X : np.ndarray of shape (n_samples, n_features) 147 | The training data. 148 | 149 | y : np.ndarray, 1-dimensional 150 | The target values. 151 | 152 | sample_weight : Optional[np.ndarray], default=None 153 | Individual weights for each sample. 154 | 155 | Returns 156 | ------- 157 | Fitted regressor. 158 | 159 | """ 160 | X_, grad_loss, loss = self._prepare_inputs(X, sample_weight, y) 161 | 162 | d = X_.shape[1] - self.n_features_in_ # This is either zero or one. 163 | 164 | if self.monotone_constraints is not None: 165 | monotone_constraints = self.monotone_constraints[:] 166 | elif self.positive: 167 | monotone_constraints = self.n_features_in_ * [1] 168 | else: 169 | monotone_constraints = self.n_features_in_ * [0] 170 | bounds = [ 171 | (0, np.inf) if c == 1 else (-np.inf, 0) if c == -1 else (-np.inf, np.inf) 172 | for c in monotone_constraints 173 | ] + d * [(-np.inf, np.inf)] 174 | 175 | minimize_result = minimize( 176 | loss, 177 | x0=np.zeros(self.n_features_in_ + d), 178 | bounds=bounds, 179 | method="SLSQP", 180 | jac=grad_loss, 181 | tol=1e-20, 182 | ) 183 | self.convergence_status_ = minimize_result.message 184 | 185 | if self.fit_intercept: 186 | *self.coef_, self.intercept_ = minimize_result.x 187 | else: 188 | self.coef_ = minimize_result.x 189 | self.intercept_ = 0.0 190 | 191 | self.coef_ = np.array(self.coef_) 192 | 193 | return self 194 | 195 | def _prepare_inputs(self, X, sample_weight, y): 196 | X, y = check_X_y(X, y) 197 | sample_weight = _check_sample_weight(sample_weight, X) 198 | self._check_n_features(X, reset=True) 199 | 200 | n = X.shape[0] 201 | 202 | X_ = X.copy() if self.copy_X else X 203 | if self.fit_intercept: 204 | X_ = np.hstack([X_, np.ones(shape=(n, 1))]) 205 | 206 | loss, grad_loss = self._get_objective(X_, y, sample_weight) 207 | 208 | return X_, grad_loss, loss 209 | 210 | def predict(self, X: np.ndarray) -> np.ndarray: 211 | """ 212 | Predict using the linear model. 213 | 214 | Parameters 215 | ---------- 216 | X : np.ndarray, shape (n_samples, n_features) 217 | Samples to get predictions of. 218 | 219 | Returns 220 | ------- 221 | y : np.ndarray, shape (n_samples,) 222 | The predicted values. 223 | 224 | """ 225 | check_is_fitted(self) 226 | X = check_array(X) 227 | self._check_n_features(X, reset=False) 228 | 229 | return X @ self.coef_ + self.intercept_ 230 | 231 | 232 | class LADRegression(BaseScipyMinimizeRegressor): 233 | """ 234 | Least absolute deviation Regression. 235 | 236 | `LADRegression` fits a linear model to minimize the residual sum of absolute 237 | deviations between the observed targets in the dataset, and the targets 238 | predicted by the linear approximation, i.e. 
239 | 
240 |         1 / (2 * n_samples) * ||y - Xw||_1
241 |         + alpha * l1_ratio * ||w||_1
242 |         + 0.5 * alpha * (1 - l1_ratio) * ||w||_2 ** 2
243 | 
244 |     Compared to linear regression, this approach is robust to outliers. You can even
245 |     optimize for the lowest MAPE (Mean Absolute Percentage Error), if you pass in
246 |     np.abs(1/y_train) for the `sample_weight` keyword when fitting the regressor.
247 | 
248 |     Parameters
249 |     ----------
250 |     alpha : float, default=0.0
251 |         Constant that multiplies the penalty terms.
252 | 
253 |     l1_ratio : float, default=0.0
254 |         The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
255 |         ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
256 |         is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
257 |         combination of L1 and L2.
258 | 
259 |     fit_intercept : bool, default=True
260 |         Whether to calculate the intercept for this model. If set
261 |         to False, no intercept will be used in calculations
262 |         (i.e. data is expected to be centered).
263 | 
264 |     copy_X : bool, default=True
265 |         If True, X will be copied; else, it may be overwritten.
266 | 
267 |     positive : bool, default=False
268 |         When set to True, forces the coefficients to be positive.
269 | 
270 |     monotone_constraints : list (optional), default=None
271 |         A list containing as many numbers as there are features. The numbers should be
272 |         - -1 to indicate that the coefficient in this position should be negative
273 |         - 0 if the coefficient is unrestricted, and
274 |         - 1 if the coefficient should be positive.
275 | 
276 |     Attributes
277 |     ----------
278 |     coef_ : np.ndarray of shape (n_features,)
279 |         Estimated coefficients of the model.
280 | 
281 |     intercept_ : float
282 |         Independent term in the linear model. Set to 0.0 if fit_intercept = False.
283 | 
284 |     Notes
285 |     -----
286 |     This implementation uses scipy.optimize.minimize, see
287 |     https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html.
288 | 
289 |     Examples
290 |     --------
291 |     >>> import numpy as np
292 |     >>> np.random.seed(0)
293 |     >>> X = np.random.randn(100, 4)
294 |     >>> y = X @ np.array([1, 2, 3, 4])
295 |     >>> l = LADRegression().fit(X, y)
296 |     >>> l.coef_
297 |     array([1., 2., 3., 4.])
298 | 
299 |     >>> import numpy as np
300 |     >>> np.random.seed(0)
301 |     >>> X = np.random.randn(100, 4)
302 |     >>> y = X @ np.array([-1, 2, -3, 4])
303 |     >>> l = LADRegression(positive=True).fit(X, y)
304 |     >>> l.coef_
305 |     array([8.44480086e-17, 1.42423304e+00, 1.97135192e-16, 4.29789588e+00])
306 | 
307 |     """
308 | 
309 |     def _get_objective(
310 |         self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray
311 |     ) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]:
312 |         @self._loss_regularize
313 |         def mae_loss(params):
314 |             return np.mean(sample_weight * np.abs(y - X @ params))
315 | 
316 |         @self._grad_loss_regularize
317 |         def grad_mae_loss(params):
318 |             return -(sample_weight * np.sign(y - X @ params)) @ X / X.shape[0]
319 | 
320 |         return mae_loss, grad_mae_loss
321 | 
322 | 
323 | class QuantileRegression(BaseScipyMinimizeRegressor):
324 |     """
325 |     Compute Quantile Regression.
326 | 
327 |     This can be used for computing confidence intervals of linear regressions.
328 |     `QuantileRegression` fits a linear model to minimize a weighted residual sum of
329 |     absolute deviations between the observed targets in the dataset and the targets
330 |     predicted by the linear approximation, i.e.
331 | 
332 |         1 / (2 * n_samples) * switch * ||y - Xw||_1
333 |         + alpha * l1_ratio * ||w||_1
334 |         + 0.5 * alpha * (1 - l1_ratio) * ||w||_2 ** 2
335 | 
336 |     where switch is a vector with value `quantile` if y - Xw > 0, else `1 - quantile`.
337 |     The regressor defaults to `LADRegression` for its default value of `quantile=0.5`.
338 | 
339 |     Compared to linear regression, this approach is robust to outliers.
340 | 
341 |     Parameters
342 |     ----------
343 |     alpha : float, default=0.0
344 |         Constant that multiplies the penalty terms.
345 | 
346 |     l1_ratio : float, default=0.0
347 |         The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
348 |         ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
349 |         is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
350 |         combination of L1 and L2.
351 | 
352 |     fit_intercept : bool, default=True
353 |         Whether to calculate the intercept for this model. If set
354 |         to False, no intercept will be used in calculations
355 |         (i.e. data is expected to be centered).
356 | 
357 |     copy_X : bool, default=True
358 |         If True, X will be copied; else, it may be overwritten.
359 | 
360 |     positive : bool, default=False
361 |         When set to True, forces the coefficients to be positive.
362 | 
363 |     monotone_constraints : list (optional), default=None
364 |         A list containing as many numbers as there are features. The numbers should be
365 |         - -1 to indicate that the coefficient in this position should be negative
366 |         - 0 if the coefficient is unrestricted, and
367 |         - 1 if the coefficient should be positive.
368 | 
369 |     quantile : float, between 0 and 1, default=0.5
370 |         The line output by the model will have a share of approximately `quantile`
371 |         data points under it. A value of `quantile=1` outputs a line that is above
372 |         each data point, for example. `quantile=0.5` corresponds to LADRegression.
373 | 
374 |     Attributes
375 |     ----------
376 |     coef_ : np.ndarray of shape (n_features,)
377 |         Estimated coefficients of the model.
378 | 
379 |     intercept_ : float
380 |         Independent term in the linear model. Set to 0.0 if fit_intercept = False.
381 | 
382 |     Notes
383 |     -----
384 |     This implementation uses scipy.optimize.minimize, see
385 |     https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html.
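    A rough 80% prediction band can be obtained, for instance, by fitting two
    models with ``quantile=0.1`` and ``quantile=0.9`` and predicting with both.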
386 | 387 | Examples 388 | -------- 389 | >>> import numpy as np 390 | >>> np.random.seed(0) 391 | >>> X = np.random.randn(100, 4) 392 | >>> y = X @ np.array([1, 2, 3, 4]) 393 | >>> l = QuantileRegression().fit(X, y) 394 | >>> l.coef_ 395 | array([1., 2., 3., 4.]) 396 | 397 | >>> import numpy as np 398 | >>> np.random.seed(0) 399 | >>> X = np.random.randn(100, 4) 400 | >>> y = X @ np.array([-1, 2, -3, 4]) 401 | >>> l = QuantileRegression(quantile=0.8).fit(X, y) 402 | >>> l.coef_ 403 | array([-1., 2., -3., 4.]) 404 | 405 | """ 406 | 407 | def __init__( 408 | self, 409 | alpha: float = 0.0, 410 | l1_ratio: float = 0.0, 411 | fit_intercept: bool = True, 412 | copy_X: bool = True, 413 | positive: bool = False, 414 | monotone_constraints: Optional[List[int]] = None, 415 | quantile: float = 0.5, 416 | ) -> None: 417 | """Initialize.""" 418 | super().__init__( 419 | alpha, l1_ratio, fit_intercept, copy_X, positive, monotone_constraints 420 | ) 421 | self.quantile = quantile 422 | 423 | def _get_objective( 424 | self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray 425 | ) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]: 426 | @self._loss_regularize 427 | def imbalanced_loss(params): 428 | return np.mean( 429 | sample_weight 430 | * np.where(X @ params < y, self.quantile, 1 - self.quantile) 431 | * np.abs(y - X @ params) 432 | ) 433 | 434 | @self._grad_loss_regularize 435 | def grad_imbalanced_loss(params): 436 | return ( 437 | -( 438 | sample_weight 439 | * np.where(X @ params < y, self.quantile, 1 - self.quantile) 440 | * np.sign(y - X @ params) 441 | ) 442 | @ X 443 | / X.shape[0] 444 | ) 445 | 446 | return imbalanced_loss, grad_imbalanced_loss 447 | 448 | def fit( 449 | self, 450 | X: np.ndarray, 451 | y: np.ndarray, 452 | sample_weight: Optional[np.ndarray] = None, 453 | ) -> "QuantileRegression": 454 | """ 455 | Fit the model using the SLSQP algorithm. 456 | 457 | Parameters 458 | ---------- 459 | X : np.ndarray of shape (n_samples, n_features) 460 | The training data. 461 | 462 | y : np.ndarray, 1-dimensional 463 | The target values. 464 | 465 | sample_weight : Optional[np.ndarray], default=None 466 | Individual weights for each sample. 467 | 468 | Returns 469 | ------- 470 | Fitted regressor. 471 | 472 | """ 473 | if 0 <= self.quantile <= 1: 474 | super().fit(X, y, sample_weight) 475 | else: 476 | raise ValueError("Parameter quantile should be between zero and one.") 477 | 478 | return self 479 | 480 | 481 | class ImbalancedLinearRegression(BaseScipyMinimizeRegressor): 482 | """ 483 | Linear regression where over and underestimating are treated differently. 484 | 485 | A value of `overestimation_punishment_factor=5` implies that overestimations by the 486 | model are penalized with a factor of 5 while underestimations have a default factor 487 | of 1. The formula optimized for is 488 | 489 | 1 / (2 * n_samples) * switch * ||y - Xw||_2 ** 2 490 | + alpha * l1_ratio * ||w||_1 491 | + 0.5 * alpha * (1 - l1_ratio) * ||w||_2 ** 2 492 | 493 | where switch is a vector with value `overestimation_punishment_factor` 494 | if y - Xw < 0, else 1. 495 | 496 | ImbalancedLinearRegression fits a linear model to minimize the residual sum of 497 | squares between the observed targets in the dataset, and the targets predicted 498 | by the linear approximation. Compared to normal linear regression, this approach 499 | allows for a different treatment of over or under estimations. 
500 | 501 | Parameters 502 | ---------- 503 | alpha : float, default=0.0 504 | Constant that multiplies the penalty terms. 505 | 506 | l1_ratio : float, default=0.0 507 | The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For 508 | ``l1_ratio = 0`` the penalty is an L2 penalty. ``For l1_ratio = 1`` it 509 | is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a 510 | combination of L1 and L2. 511 | 512 | fit_intercept : bool, default=True 513 | Whether to calculate the intercept for this model. If set 514 | to False, no intercept will be used in calculations 515 | (i.e. data is expected to be centered). 516 | 517 | copy_X : bool, default=True 518 | If True, X will be copied; else, it may be overwritten. 519 | 520 | positive : bool, default=False 521 | When set to True, forces the coefficients to be positive. 522 | 523 | monotone_constraints : list (optional), default=None 524 | A list containing as many numbers as there are features. The numbers should be 525 | - -1 to indicate that the coefficient in this position should be negative 526 | - 0 if the coefficient is unrestricted, and 527 | - 1 if the coefficient should be positive. 528 | 529 | overestimation_punishment_factor : float, default=1 530 | Factor to punish overestimations more (if the value is larger than 1) or less 531 | (if the value is between 0 and 1). 532 | 533 | Attributes 534 | ---------- 535 | coef_ : np.ndarray of shape (n_features,) 536 | Estimated coefficients of the model. 537 | 538 | intercept_ : float 539 | Independent term in the linear model. Set to 0.0 if fit_intercept = False. 540 | 541 | Notes 542 | ----- 543 | This implementation uses scipy.optimize.minimize, see 544 | https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html. 545 | 546 | Examples 547 | -------- 548 | >>> import numpy as np 549 | >>> np.random.seed(0) 550 | >>> X = np.random.randn(100, 4) 551 | >>> y = X @ np.array([1, 2, 3, 4]) + 2*np.random.randn(100) 552 | >>> over_bad = ImbalancedLinearRegression(overestimation_punishment_factor=50) 553 | >>> over_bad.fit(X, y) 554 | ImbalancedLinearRegression(overestimation_punishment_factor=50) 555 | >>> over_bad.coef_ 556 | array([0.36267036, 1.39526844, 3.4247146 , 3.93679175]) 557 | 558 | >>> under_bad = ImbalancedLinearRegression(overestimation_punishment_factor=0.01) 559 | >>> under_bad.fit(X, y) 560 | ImbalancedLinearRegression(overestimation_punishment_factor=0.01) 561 | >>> under_bad.coef_ 562 | array([0.73519586, 1.28698197, 2.61362614, 4.35989806]) 563 | 564 | """ 565 | 566 | def __init__( 567 | self, 568 | alpha: float = 0.0, 569 | l1_ratio: float = 0.0, 570 | fit_intercept: bool = True, 571 | copy_X: bool = True, 572 | positive: bool = False, 573 | monotone_constraints: Optional[List[int]] = None, 574 | overestimation_punishment_factor: float = 1.0, 575 | ) -> None: 576 | """Initialize.""" 577 | super().__init__( 578 | alpha, l1_ratio, fit_intercept, copy_X, positive, monotone_constraints 579 | ) 580 | self.overestimation_punishment_factor = overestimation_punishment_factor 581 | 582 | def _get_objective( 583 | self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray 584 | ) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]: 585 | @self._loss_regularize 586 | def imbalanced_loss(params): 587 | return 0.5 * np.mean( 588 | sample_weight 589 | * np.where(X @ params > y, self.overestimation_punishment_factor, 1) 590 | * np.square(y - X @ params) 591 | ) 592 | 593 | @self._grad_loss_regularize 594 | def 
grad_imbalanced_loss(params):
595 |             return (
596 |                 -(
597 |                     sample_weight
598 |                     * np.where(X @ params > y, self.overestimation_punishment_factor, 1)
599 |                     * (y - X @ params)
600 |                 )
601 |                 @ X
602 |                 / X.shape[0]
603 |             )
604 | 
605 |         return imbalanced_loss, grad_imbalanced_loss
606 | 
607 | 
608 | class LinearRegression(BaseScipyMinimizeRegressor):
609 |     """
610 |     Just plain and simple linear regression.
611 | 
612 |     The formula optimized for is
613 | 
614 |         1 / (2 * n_samples) * ||y - Xw||_2 ** 2
615 |         + alpha * l1_ratio * ||w||_1
616 |         + 0.5 * alpha * (1 - l1_ratio) * ||w||_2 ** 2
617 | 
618 |     Parameters
619 |     ----------
620 |     alpha : float, default=0.0
621 |         Constant that multiplies the penalty terms.
622 | 
623 |     l1_ratio : float, default=0.0
624 |         The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
625 |         ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
626 |         is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
627 |         combination of L1 and L2.
628 | 
629 |     fit_intercept : bool, default=True
630 |         Whether to calculate the intercept for this model. If set
631 |         to False, no intercept will be used in calculations
632 |         (i.e. data is expected to be centered).
633 | 
634 |     copy_X : bool, default=True
635 |         If True, X will be copied; else, it may be overwritten.
636 | 
637 |     positive : bool, default=False
638 |         When set to True, forces the coefficients to be positive.
639 | 
640 |     monotone_constraints : list (optional), default=None
641 |         A list containing as many numbers as there are features. The numbers should be
642 |         - -1 to indicate that the coefficient in this position should be negative
643 |         - 0 if the coefficient is unrestricted, and
644 |         - 1 if the coefficient should be positive.
645 | 
646 |     Attributes
647 |     ----------
648 |     coef_ : np.ndarray of shape (n_features,)
649 |         Estimated coefficients of the model.
650 | 
651 |     intercept_ : float
652 |         Independent term in the linear model. Set to 0.0 if fit_intercept = False.
653 | 
654 |     Notes
655 |     -----
656 |     This implementation uses scipy.optimize.minimize, see
657 |     https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html.
658 | 
659 |     Examples
660 |     --------
661 |     >>> import numpy as np
662 |     >>> np.random.seed(0)
663 |     >>> X = np.random.randn(100, 4)
664 |     >>> y = X @ np.array([1, 2, 3, 4]) + 2*np.random.randn(100)
665 |     >>> lr = LinearRegression().fit(X, y)
666 |     >>> lr.coef_
667 |     array([0.73202377, 1.75186186, 2.92983272, 3.96578532])
668 | 
669 |     >>> lr = LinearRegression(monotone_constraints=[1, 0, -1, -1]).fit(X, y)
670 |     >>> lr.coef_
671 |     array([1.37824805, 2.57277907, 0.        , 0.        ])
]) 672 | 673 | """ 674 | 675 | def _get_objective( 676 | self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray 677 | ) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]: 678 | @self._loss_regularize 679 | def ols_loss(params): 680 | return 0.5 * np.mean(sample_weight * np.square(y - X @ params)) 681 | 682 | @self._grad_loss_regularize 683 | def grad_ols_loss(params): 684 | return -(sample_weight * (y - X @ params)) @ X / X.shape[0] 685 | 686 | return ols_loss, grad_ols_loss 687 | -------------------------------------------------------------------------------- /mamimo/saturation.py: -------------------------------------------------------------------------------- 1 | """Saturation classes.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import List, Optional 7 | 8 | import numpy as np 9 | from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin 10 | from sklearn.utils.validation import ( 11 | FLOAT_DTYPES, 12 | _check_feature_names_in, 13 | check_array, 14 | check_is_fitted, 15 | ) 16 | 17 | 18 | class Saturation(OneToOneFeatureMixin, BaseEstimator, TransformerMixin, ABC): 19 | """Base class for all saturations, such as Box-Cox, Adbudg, ...""" 20 | 21 | def fit(self, X: np.ndarray, y: None = None) -> Saturation: 22 | """ 23 | Fit the transformer. 24 | 25 | In this special case, nothing is done. 26 | 27 | Parameters 28 | ---------- 29 | X : Ignored 30 | Not used, present here for API consistency by convention. 31 | 32 | y : Ignored 33 | Not used, present here for API consistency by convention. 34 | 35 | Returns 36 | ------- 37 | Saturation 38 | Fitted transformer. 39 | 40 | """ 41 | _ = self._validate_data(X, dtype=FLOAT_DTYPES) 42 | 43 | return self 44 | 45 | def transform(self, X: np.ndarray) -> np.ndarray: 46 | """ 47 | Apply the saturation effect. 48 | 49 | Parameters 50 | ---------- 51 | X : np.ndarray 52 | Data to be transformed. 53 | 54 | Returns 55 | ------- 56 | np.ndarray 57 | Data with saturation effect applied. 58 | 59 | """ 60 | check_is_fitted(self) 61 | X = check_array(X) 62 | self._check_n_features(X, reset=False) 63 | 64 | return self._transformation(X) 65 | 66 | @abstractmethod 67 | def _transformation(self, X: np.ndarray) -> np.ndarray: 68 | """Generate the transformation formula.""" 69 | 70 | def get_feature_names_out(self, input_features: Optional[List] = None): 71 | """ 72 | Get the output feature names. 73 | 74 | Parameters 75 | ---------- 76 | input_features : list (optional), default0None 77 | Input feature names. 78 | 79 | Returns 80 | ------- 81 | np.ndarray 82 | Output feature names. 83 | 84 | """ 85 | input_features = _check_feature_names_in(self, input_features) 86 | 87 | return np.array(input_features, dtype=object) 88 | 89 | 90 | class BoxCoxSaturation(Saturation): 91 | """ 92 | Apply the Box-Cox saturation. 93 | 94 | The formula is ((x + shift) ** exponent-1) / exponent if exponent!=0, 95 | else ln(x+shift). 96 | 97 | Parameters 98 | ---------- 99 | exponent: float, default=1.0 100 | The exponent. 101 | 102 | shift : float, default=1.0 103 | The shift. 104 | 105 | Examples 106 | -------- 107 | >>> import numpy as np 108 | >>> X = np.array([[1, 1000], [2, 1000], [3, 1000]]) 109 | >>> BoxCoxSaturation(exponent=0.5).fit_transform(X) 110 | array([[ 0.82842712, 61.27716808], 111 | [ 1.46410162, 61.27716808], 112 | [ 2. 
, 61.27716808]]) 113 | 114 | """ 115 | 116 | def __init__(self, exponent: float = 1.0, shift: float = 1.0) -> None: 117 | """Initialize.""" 118 | self.exponent = exponent 119 | self.shift = shift 120 | 121 | def _transformation(self, X: np.ndarray) -> np.ndarray: 122 | """Generate the transformation formula.""" 123 | if self.exponent != 0: 124 | return ((X + self.shift) ** self.exponent - 1) / self.exponent 125 | else: 126 | return np.log(X + self.shift) 127 | 128 | 129 | class AdbudgSaturation(Saturation): 130 | """ 131 | Apply the Adbudg saturation. 132 | 133 | The formula is x ** exponent / (denominator_shift + x ** exponent). 134 | 135 | Parameters 136 | ---------- 137 | exponent : float, default=1.0 138 | The exponent. 139 | 140 | denominator_shift : float, default=1.0 141 | The shift in the denominator. 142 | 143 | Examples 144 | -------- 145 | >>> import numpy as np 146 | >>> X = np.array([[1, 1000], [2, 1000], [3, 1000]]) 147 | >>> AdbudgSaturation().fit_transform(X) 148 | array([[0.5 , 0.999001 ], 149 | [0.66666667, 0.999001 ], 150 | [0.75 , 0.999001 ]]) 151 | 152 | """ 153 | 154 | def __init__(self, exponent: float = 1.0, denominator_shift: float = 1.0) -> None: 155 | """Initialize.""" 156 | self.exponent = exponent 157 | self.denominator_shift = denominator_shift 158 | 159 | def _transformation(self, X: np.ndarray) -> np.ndarray: 160 | """Generate the transformation formula.""" 161 | return X**self.exponent / (self.denominator_shift + X**self.exponent) 162 | 163 | 164 | class HillSaturation(Saturation): 165 | """ 166 | Apply the Hill saturation. 167 | 168 | The formula is 1 / (1 + (half_saturation / x) ** exponent). 169 | 170 | Parameters 171 | ---------- 172 | exponent : float, default=1.0 173 | The exponent. 174 | 175 | half_saturation : float, default=1.0 176 | The point of half saturation, i.e. Hill(half_saturation) = 0.5. 177 | 178 | Examples 179 | -------- 180 | >>> import numpy as np 181 | >>> X = np.array([[1, 1000], [2, 1000], [3, 1000]]) 182 | >>> HillSaturation().fit_transform(X) 183 | array([[0.5 , 0.999001 ], 184 | [0.66666667, 0.999001 ], 185 | [0.75 , 0.999001 ]]) 186 | 187 | """ 188 | 189 | def __init__(self, exponent: float = 1.0, half_saturation: float = 1.0) -> None: 190 | """Initialize.""" 191 | self.half_saturation = half_saturation 192 | self.exponent = exponent 193 | 194 | def _transformation(self, X: np.ndarray) -> np.ndarray: 195 | """Generate the transformation formula.""" 196 | eps = np.finfo(np.float64).eps 197 | return 1 / (1 + (self.half_saturation / (X + eps)) ** self.exponent) 198 | 199 | 200 | class ExponentialSaturation(Saturation): 201 | """ 202 | Apply exponential saturation. 203 | 204 | The formula is 1 - exp(-exponent * x). 205 | 206 | Parameters 207 | ---------- 208 | exponent : float, default=1.0 209 | The exponent. 210 | 211 | Examples 212 | -------- 213 | >>> import numpy as np 214 | >>> X = np.array([[1, 1000], [2, 1000], [3, 1000]]) 215 | >>> ExponentialSaturation().fit_transform(X) 216 | array([[0.63212056, 1. ], 217 | [0.86466472, 1. ], 218 | [0.95021293, 1. 
]])
219 | 
220 |     """
221 | 
222 |     def __init__(self, exponent: float = 1.0) -> None:
223 |         """Initialize."""
224 |         self.exponent = exponent
225 | 
226 |     def _transformation(self, X: np.ndarray) -> np.ndarray:
227 |         """Generate the transformation formula."""
228 |         return 1 - np.exp(-self.exponent * X)
--------------------------------------------------------------------------------
/mamimo/time_utils.py:
--------------------------------------------------------------------------------
  1 | """Deal with time features in dataframes."""
  2 | 
  3 | from __future__ import annotations
  4 | 
  5 | from typing import List, Optional
  6 | 
  7 | import numpy as np
  8 | import pandas as pd
  9 | from sklearn.base import BaseEstimator, TransformerMixin
 10 | from sklearn.utils.validation import (
 11 |     FLOAT_DTYPES,
 12 |     _check_feature_names_in,
 13 |     check_array,
 14 |     check_is_fitted,
 15 | )
 16 | 
 17 | 
 18 | def add_date_indicators(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
 19 |     """
 20 |     Enrich a pandas dataframe with new columns indicating special dates.
 21 | 
 22 |     Each new column will contain a one for each date specified in the corresponding
 23 |     keyword argument, and a zero otherwise.
 24 | 
 25 |     Parameters
 26 |     ----------
 27 |     df : pd.DataFrame
 28 |         Input dataframe with a DateTime index.
 29 | 
 30 |     **kwargs : List[str]
 31 |         As many inputs as you want of the form date_name=[date_1, date_2, ...], e.g.
 32 |         christmas=['2020-12-24']. See the example below for more information.
 33 | 
 34 |     Returns
 35 |     -------
 36 |     pd.DataFrame
 37 |         A dataframe with date indicator columns.
 38 | 
 39 |     Examples
 40 |     --------
 41 |     >>> import pandas as pd
 42 |     >>> df = pd.DataFrame(
 43 |     ...     {"A": range(7)},
 44 |     ...     index=pd.date_range(start="2019-12-29", periods=7)
 45 |     ... )
 46 |     >>> add_date_indicators(
 47 |     ...     df,
 48 |     ...     around_new_year_2020=["2019-12-31", "2020-01-01", "2020-01-02"],
 49 |     ...     other_date_1=["2019-12-29"],
 50 |     ...     other_date_2=["2018-01-01"]
 51 |     ... )
 52 |                 A  around_new_year_2020  other_date_1  other_date_2
 53 |     2019-12-29  0                     0             1             0
 54 |     2019-12-30  1                     0             0             0
 55 |     2019-12-31  2                     1             0             0
 56 |     2020-01-01  3                     1             0             0
 57 |     2020-01-02  4                     1             0             0
 58 |     2020-01-03  5                     0             0             0
 59 |     2020-01-04  6                     0             0             0
 60 | 
 61 |     """
 62 |     return df.assign(
 63 |         **{name: df.index.isin(dates).astype(int) for name, dates in kwargs.items()}
 64 |     )
 65 | 
 66 | 
 67 | def add_time_features(
 68 |     df: pd.DataFrame,
 69 |     second: bool = False,
 70 |     minute: bool = False,
 71 |     hour: bool = False,
 72 |     day_of_week: bool = False,
 73 |     day_of_month: bool = False,
 74 |     day_of_year: bool = False,
 75 |     week_of_month: bool = False,
 76 |     week_of_year: bool = False,
 77 |     month: bool = False,
 78 |     year: bool = False,
 79 | ) -> pd.DataFrame:
 80 |     """
 81 |     Enrich pandas dataframes with new time feature columns.
 82 | 
 83 |     These features are easy derivations from the dataframe's
 84 |     DatetimeIndex, such as the day of week or the month.
 85 | 
 86 |     Parameters
 87 |     ----------
 88 |     df: pd.DataFrame
 89 |         Input dataframe with a DateTime index.
 90 | 
 91 |     second : bool, default=False
 92 |         Whether to extract the second from the index and add it as a new column.
 93 | 
 94 |     minute : bool, default=False
 95 |         Whether to extract the minute from the index and add it as a new column.
 96 | 
 97 |     hour : bool, default=False
 98 |         Whether to extract the hour from the index and add it as a new column.
 99 | 
100 |     day_of_week : bool, default=False
101 |         Whether to extract the day of week from the index and add it as a new column.
102 | 103 | day_of_month : bool, default=False 104 | Whether to extract the day of month from the index and add it as a new column. 105 | 106 | day_of_year : bool, default=False 107 | Whether to extract the day of year from the index and add it as a new column. 108 | 109 | week_of_month : bool, default=False 110 | Whether to extract the week of month from the index and add it as a new column. 111 | 112 | week_of_year : bool, default=False 113 | Whether to extract the week of year from the index and add it as a new column. 114 | 115 | month : bool, default=False 116 | Whether to extract the month from the index and add it as a new column. 117 | 118 | year : bool, default=False 119 | Whether to extract the year from the index and add it as a new column. 120 | 121 | Examples 122 | -------- 123 | >>> import pandas as pd 124 | >>> df = pd.DataFrame( 125 | ... {"A": ["a", "b", "c"]}, 126 | ... index=[ 127 | ... pd.Timestamp("1988-08-08"), 128 | ... pd.Timestamp("2000-01-01"), 129 | ... pd.Timestamp("1950-12-31"), 130 | ... ]) 131 | >>> add_time_features(df, day_of_month=True, month=True, year=True) 132 | A day_of_month month year 133 | 1988-08-08 a 8 8 1988 134 | 2000-01-01 b 1 1 2000 135 | 1950-12-31 c 31 12 1950 136 | 137 | """ 138 | 139 | def _add_second(df: pd.DataFrame) -> pd.DataFrame: 140 | return df.assign(second=df.index.second) if second else df 141 | 142 | def _add_minute(df: pd.DataFrame) -> pd.DataFrame: 143 | return df.assign(minute=df.index.minute) if minute else df 144 | 145 | def _add_hour(df: pd.DataFrame) -> pd.DataFrame: 146 | return df.assign(hour=df.index.hour) if hour else df 147 | 148 | def _add_day_of_week(df: pd.DataFrame) -> pd.DataFrame: 149 | return df.assign(day_of_week=df.index.weekday + 1) if day_of_week else df 150 | 151 | def _add_day_of_month(df: pd.DataFrame) -> pd.DataFrame: 152 | return df.assign(day_of_month=df.index.day) if day_of_month else df 153 | 154 | def _add_day_of_year(df: pd.DataFrame) -> pd.DataFrame: 155 | return df.assign(day_of_year=df.index.dayofyear) if day_of_year else df 156 | 157 | def _add_week_of_month(df: pd.DataFrame) -> pd.DataFrame: 158 | return ( 159 | df.assign(week_of_month=np.ceil(df.index.day / 7).astype(int)) 160 | if week_of_month 161 | else df 162 | ) 163 | 164 | def _add_week_of_year(df: pd.DataFrame) -> pd.DataFrame: 165 | return ( 166 | df.assign(week_of_year=df.index.isocalendar().week) if week_of_year else df 167 | ) 168 | 169 | def _add_month(df: pd.DataFrame) -> pd.DataFrame: 170 | return df.assign(month=df.index.month) if month else df 171 | 172 | def _add_year(df: pd.DataFrame) -> pd.DataFrame: 173 | return df.assign(year=df.index.year) if year else df 174 | 175 | return ( 176 | df.pipe(_add_second) 177 | .pipe(_add_minute) 178 | .pipe(_add_hour) 179 | .pipe(_add_day_of_week) 180 | .pipe(_add_day_of_month) 181 | .pipe(_add_day_of_year) 182 | .pipe(_add_week_of_month) 183 | .pipe(_add_week_of_year) 184 | .pipe(_add_month) 185 | .pipe(_add_year) 186 | ) 187 | 188 | 189 | class PowerTrend(BaseEstimator, TransformerMixin): 190 | """ 191 | Apply a power function to a trend. 192 | 193 | This takes an x and computes x ^ power from it. 194 | 195 | Parameters 196 | ---------- 197 | power : float, default=1.0 198 | The power. 199 | 200 | Examples 201 | -------- 202 | >>> import numpy as np 203 | >>> X = np.array([[1], [2], [3]]) 204 | >>> PowerTrend(power=1.5).fit_transform(X) 205 | array([[1. 
], 206 | [2.82842712], 207 | [5.19615242]]) 208 | 209 | """ 210 | 211 | def __init__(self, power: float = 1.0) -> None: 212 | """Initialize.""" 213 | self.power = power 214 | 215 | def fit(self, X: np.ndarray, y: None = None) -> PowerTrend: 216 | """ 217 | Fit the transformer. 218 | 219 | This only validates the input data; the actual exponentiation with `power` happens in `transform`. 220 | 221 | Parameters 222 | ---------- 223 | X : np.ndarray 224 | Data to be transformed. This is usually just an integer range from a to b. 225 | 226 | y : Ignored 227 | Not used, present here for API consistency by convention. 228 | 229 | Returns 230 | ------- 231 | PowerTrend 232 | Fitted transformer. 233 | 234 | """ 235 | _ = self._validate_data(X, dtype=FLOAT_DTYPES) 236 | 237 | return self 238 | 239 | def transform(self, X: np.ndarray) -> np.ndarray: 240 | """ 241 | Apply the power function. 242 | 243 | Parameters 244 | ---------- 245 | X : np.ndarray of shape (n_samples, n_features) 246 | Data to be transformed. This is usually just an integer range from a to b. 247 | 248 | Returns 249 | ------- 250 | np.ndarray 251 | Data with power trend applied. 252 | 253 | """ 254 | check_is_fitted(self) 255 | X = check_array(X) 256 | self._check_n_features(X, reset=False) 257 | 258 | return X**self.power 259 | 260 | def get_feature_names_out(self, input_features: Optional[List] = None): 261 | """ 262 | Get the output feature names. 263 | 264 | Parameters 265 | ---------- 266 | input_features : list (optional), default=None 267 | Input feature names. 268 | 269 | Returns 270 | ------- 271 | np.ndarray 272 | Output feature names. 273 | 274 | """ 275 | input_features = _check_feature_names_in(self, input_features) 276 | 277 | return np.array(input_features, dtype=object) 278 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "alabaster" 3 | version = "0.7.13" 4 | description = "A configurable sidebar-enabled Sphinx theme" 5 | category = "dev" 6 | optional = false 7 | python-versions = ">=3.6" 8 | 9 | [[package]] 10 | name = "attrs" 11 | version = "23.1.0" 12 | description = "Classes Without Boilerplate" 13 | category = "dev" 14 | optional = false 15 | python-versions = ">=3.7" 16 | 17 | [package.extras] 18 | cov = ["attrs", "coverage[toml] (>=5.3)"] 19 | dev = ["attrs", "pre-commit"] 20 | docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] 21 | tests = ["attrs", "zope-interface"] 22 | tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest-mypy-plugins", "pytest-xdist", "pytest (>=4.3.0)"] 23 | 24 | [[package]] 25 | name = "babel" 26 | version = "2.12.1" 27 | description = "Internationalization utilities" 28 | category = "dev" 29 | optional = false 30 | python-versions = ">=3.7" 31 | 32 | [[package]] 33 | name = "black" 34 | version = "22.12.0" 35 | description = "The uncompromising code formatter."
36 | category = "dev" 37 | optional = false 38 | python-versions = ">=3.7" 39 | 40 | [package.dependencies] 41 | click = ">=8.0.0" 42 | mypy-extensions = ">=0.4.3" 43 | pathspec = ">=0.9.0" 44 | platformdirs = ">=2" 45 | tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} 46 | typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} 47 | 48 | [package.extras] 49 | colorama = ["colorama (>=0.4.3)"] 50 | d = ["aiohttp (>=3.7.4)"] 51 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] 52 | uvloop = ["uvloop (>=0.15.2)"] 53 | 54 | [[package]] 55 | name = "certifi" 56 | version = "2023.5.7" 57 | description = "Python package for providing Mozilla's CA Bundle." 58 | category = "dev" 59 | optional = false 60 | python-versions = ">=3.6" 61 | 62 | [[package]] 63 | name = "cfgv" 64 | version = "3.3.1" 65 | description = "Validate configuration and produce human readable error messages." 66 | category = "dev" 67 | optional = false 68 | python-versions = ">=3.6.1" 69 | 70 | [[package]] 71 | name = "charset-normalizer" 72 | version = "3.1.0" 73 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 74 | category = "dev" 75 | optional = false 76 | python-versions = ">=3.7.0" 77 | 78 | [[package]] 79 | name = "click" 80 | version = "8.1.3" 81 | description = "Composable command line interface toolkit" 82 | category = "dev" 83 | optional = false 84 | python-versions = ">=3.7" 85 | 86 | [package.dependencies] 87 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 88 | 89 | [[package]] 90 | name = "colorama" 91 | version = "0.4.6" 92 | description = "Cross-platform colored terminal text." 93 | category = "dev" 94 | optional = false 95 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" 96 | 97 | [[package]] 98 | name = "distlib" 99 | version = "0.3.6" 100 | description = "Distribution utilities" 101 | category = "dev" 102 | optional = false 103 | python-versions = "*" 104 | 105 | [[package]] 106 | name = "docutils" 107 | version = "0.19" 108 | description = "Docutils -- Python Documentation Utilities" 109 | category = "dev" 110 | optional = false 111 | python-versions = ">=3.7" 112 | 113 | [[package]] 114 | name = "exceptiongroup" 115 | version = "1.1.1" 116 | description = "Backport of PEP 654 (exception groups)" 117 | category = "dev" 118 | optional = false 119 | python-versions = ">=3.7" 120 | 121 | [package.extras] 122 | test = ["pytest (>=6)"] 123 | 124 | [[package]] 125 | name = "filelock" 126 | version = "3.12.0" 127 | description = "A platform independent file lock." 
128 | category = "dev" 129 | optional = false 130 | python-versions = ">=3.7" 131 | 132 | [package.extras] 133 | docs = ["furo (>=2023.3.27)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)", "sphinx (>=6.1.3)"] 134 | testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)", "pytest (>=7.3.1)"] 135 | 136 | [[package]] 137 | name = "flake8" 138 | version = "4.0.1" 139 | description = "the modular source code checker: pep8 pyflakes and co" 140 | category = "dev" 141 | optional = false 142 | python-versions = ">=3.6" 143 | 144 | [package.dependencies] 145 | mccabe = ">=0.6.0,<0.7.0" 146 | pycodestyle = ">=2.8.0,<2.9.0" 147 | pyflakes = ">=2.4.0,<2.5.0" 148 | 149 | [[package]] 150 | name = "hypothesis" 151 | version = "6.75.9" 152 | description = "A library for property-based testing" 153 | category = "dev" 154 | optional = false 155 | python-versions = ">=3.7" 156 | 157 | [package.dependencies] 158 | attrs = ">=19.2.0" 159 | exceptiongroup = {version = ">=1.0.0", markers = "python_version < \"3.11\""} 160 | sortedcontainers = ">=2.1.0,<3.0.0" 161 | 162 | [package.extras] 163 | all = ["black (>=19.10b0)", "click (>=7.0)", "django (>=3.2)", "dpcontracts (>=0.4)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.16.0)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "importlib-metadata (>=3.6)", "backports.zoneinfo (>=0.2.1)", "tzdata (>=2023.3)"] 164 | cli = ["click (>=7.0)", "black (>=19.10b0)", "rich (>=9.0.0)"] 165 | codemods = ["libcst (>=0.3.16)"] 166 | dateutil = ["python-dateutil (>=1.4)"] 167 | django = ["django (>=3.2)"] 168 | dpcontracts = ["dpcontracts (>=0.4)"] 169 | ghostwriter = ["black (>=19.10b0)"] 170 | lark = ["lark (>=0.10.1)"] 171 | numpy = ["numpy (>=1.16.0)"] 172 | pandas = ["pandas (>=1.1)"] 173 | pytest = ["pytest (>=4.6)"] 174 | pytz = ["pytz (>=2014.1)"] 175 | redis = ["redis (>=3.0.0)"] 176 | zoneinfo = ["backports.zoneinfo (>=0.2.1)", "tzdata (>=2023.3)"] 177 | 178 | [[package]] 179 | name = "identify" 180 | version = "2.5.24" 181 | description = "File identification library for Python" 182 | category = "dev" 183 | optional = false 184 | python-versions = ">=3.7" 185 | 186 | [package.extras] 187 | license = ["ukkonen"] 188 | 189 | [[package]] 190 | name = "idna" 191 | version = "3.4" 192 | description = "Internationalized Domain Names in Applications (IDNA)" 193 | category = "dev" 194 | optional = false 195 | python-versions = ">=3.5" 196 | 197 | [[package]] 198 | name = "imagesize" 199 | version = "1.4.1" 200 | description = "Getting image size from png/jpeg/jpeg2000/gif file" 201 | category = "dev" 202 | optional = false 203 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 204 | 205 | [[package]] 206 | name = "importlib-metadata" 207 | version = "6.6.0" 208 | description = "Read metadata from Python packages" 209 | category = "dev" 210 | optional = false 211 | python-versions = ">=3.7" 212 | 213 | [package.dependencies] 214 | zipp = ">=0.5" 215 | 216 | [package.extras] 217 | docs = ["sphinx (>=3.5)", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "furo", "sphinx-lint", "jaraco.tidelift (>=1.4)"] 218 | perf = ["ipython"] 219 | testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "pytest-flake8", "importlib-resources 
(>=1.3)"] 220 | 221 | [[package]] 222 | name = "iniconfig" 223 | version = "2.0.0" 224 | description = "brain-dead simple config-ini parsing" 225 | category = "dev" 226 | optional = false 227 | python-versions = ">=3.7" 228 | 229 | [[package]] 230 | name = "isort" 231 | version = "5.12.0" 232 | description = "A Python utility / library to sort Python imports." 233 | category = "dev" 234 | optional = false 235 | python-versions = ">=3.8.0" 236 | 237 | [package.extras] 238 | colors = ["colorama (>=0.4.3)"] 239 | requirements-deprecated-finder = ["pip-api", "pipreqs"] 240 | pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] 241 | plugins = ["setuptools"] 242 | 243 | [[package]] 244 | name = "jinja2" 245 | version = "3.1.2" 246 | description = "A very fast and expressive template engine." 247 | category = "dev" 248 | optional = false 249 | python-versions = ">=3.7" 250 | 251 | [package.dependencies] 252 | MarkupSafe = ">=2.0" 253 | 254 | [package.extras] 255 | i18n = ["Babel (>=2.7)"] 256 | 257 | [[package]] 258 | name = "joblib" 259 | version = "1.2.0" 260 | description = "Lightweight pipelining with Python functions" 261 | category = "main" 262 | optional = false 263 | python-versions = ">=3.7" 264 | 265 | [[package]] 266 | name = "markupsafe" 267 | version = "2.1.3" 268 | description = "Safely add untrusted strings to HTML/XML markup." 269 | category = "dev" 270 | optional = false 271 | python-versions = ">=3.7" 272 | 273 | [[package]] 274 | name = "mccabe" 275 | version = "0.6.1" 276 | description = "McCabe checker, plugin for flake8" 277 | category = "dev" 278 | optional = false 279 | python-versions = "*" 280 | 281 | [[package]] 282 | name = "mypy" 283 | version = "0.950" 284 | description = "Optional static typing for Python" 285 | category = "dev" 286 | optional = false 287 | python-versions = ">=3.6" 288 | 289 | [package.dependencies] 290 | mypy-extensions = ">=0.4.3" 291 | tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} 292 | typing-extensions = ">=3.10" 293 | 294 | [package.extras] 295 | dmypy = ["psutil (>=4.0)"] 296 | python2 = ["typed-ast (>=1.4.0,<2)"] 297 | reports = ["lxml"] 298 | 299 | [[package]] 300 | name = "mypy-extensions" 301 | version = "1.0.0" 302 | description = "Type system extensions for programs checked with the mypy type checker." 
303 | category = "dev" 304 | optional = false 305 | python-versions = ">=3.5" 306 | 307 | [[package]] 308 | name = "nodeenv" 309 | version = "1.8.0" 310 | description = "Node.js virtual environment builder" 311 | category = "dev" 312 | optional = false 313 | python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" 314 | 315 | [[package]] 316 | name = "numpy" 317 | version = "1.24.3" 318 | description = "Fundamental package for array computing in Python" 319 | category = "main" 320 | optional = false 321 | python-versions = ">=3.8" 322 | 323 | [[package]] 324 | name = "packaging" 325 | version = "23.1" 326 | description = "Core utilities for Python packages" 327 | category = "dev" 328 | optional = false 329 | python-versions = ">=3.7" 330 | 331 | [[package]] 332 | name = "pandas" 333 | version = "1.5.3" 334 | description = "Powerful data structures for data analysis, time series, and statistics" 335 | category = "main" 336 | optional = false 337 | python-versions = ">=3.8" 338 | 339 | [package.dependencies] 340 | numpy = [ 341 | {version = ">=1.20.3", markers = "python_version < \"3.10\""}, 342 | {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, 343 | {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, 344 | ] 345 | python-dateutil = ">=2.8.1" 346 | pytz = ">=2020.1" 347 | 348 | [package.extras] 349 | test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] 350 | 351 | [[package]] 352 | name = "pathspec" 353 | version = "0.11.1" 354 | description = "Utility library for gitignore style pattern matching of file paths." 355 | category = "dev" 356 | optional = false 357 | python-versions = ">=3.7" 358 | 359 | [[package]] 360 | name = "platformdirs" 361 | version = "3.5.1" 362 | description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 363 | category = "dev" 364 | optional = false 365 | python-versions = ">=3.7" 366 | 367 | [package.extras] 368 | docs = ["furo (>=2023.3.27)", "proselint (>=0.13)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)", "sphinx (>=6.2.1)"] 369 | test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest (>=7.3.1)"] 370 | 371 | [[package]] 372 | name = "pluggy" 373 | version = "1.0.0" 374 | description = "plugin and hook calling mechanisms for python" 375 | category = "dev" 376 | optional = false 377 | python-versions = ">=3.6" 378 | 379 | [package.extras] 380 | dev = ["pre-commit", "tox"] 381 | testing = ["pytest", "pytest-benchmark"] 382 | 383 | [[package]] 384 | name = "pre-commit" 385 | version = "2.21.0" 386 | description = "A framework for managing and maintaining multi-language pre-commit hooks." 
387 | category = "dev" 388 | optional = false 389 | python-versions = ">=3.7" 390 | 391 | [package.dependencies] 392 | cfgv = ">=2.0.0" 393 | identify = ">=1.0.0" 394 | nodeenv = ">=0.11.1" 395 | pyyaml = ">=5.1" 396 | virtualenv = ">=20.10.0" 397 | 398 | [[package]] 399 | name = "pycodestyle" 400 | version = "2.8.0" 401 | description = "Python style guide checker" 402 | category = "dev" 403 | optional = false 404 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 405 | 406 | [[package]] 407 | name = "pydocstyle" 408 | version = "6.3.0" 409 | description = "Python docstring style checker" 410 | category = "dev" 411 | optional = false 412 | python-versions = ">=3.6" 413 | 414 | [package.dependencies] 415 | snowballstemmer = ">=2.2.0" 416 | 417 | [package.extras] 418 | toml = ["tomli (>=1.2.3)"] 419 | 420 | [[package]] 421 | name = "pyflakes" 422 | version = "2.4.0" 423 | description = "passive checker of Python programs" 424 | category = "dev" 425 | optional = false 426 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 427 | 428 | [[package]] 429 | name = "pygments" 430 | version = "2.15.1" 431 | description = "Pygments is a syntax highlighting package written in Python." 432 | category = "dev" 433 | optional = false 434 | python-versions = ">=3.7" 435 | 436 | [package.extras] 437 | plugins = ["importlib-metadata"] 438 | 439 | [[package]] 440 | name = "pytest" 441 | version = "7.3.1" 442 | description = "pytest: simple powerful testing with Python" 443 | category = "dev" 444 | optional = false 445 | python-versions = ">=3.7" 446 | 447 | [package.dependencies] 448 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 449 | exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} 450 | iniconfig = "*" 451 | packaging = "*" 452 | pluggy = ">=0.12,<2.0" 453 | tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} 454 | 455 | [package.extras] 456 | testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] 457 | 458 | [[package]] 459 | name = "python-dateutil" 460 | version = "2.8.2" 461 | description = "Extensions to the standard Python datetime module" 462 | category = "main" 463 | optional = false 464 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 465 | 466 | [package.dependencies] 467 | six = ">=1.5" 468 | 469 | [[package]] 470 | name = "pytz" 471 | version = "2023.3" 472 | description = "World timezone definitions, modern and historical" 473 | category = "main" 474 | optional = false 475 | python-versions = "*" 476 | 477 | [[package]] 478 | name = "pyyaml" 479 | version = "6.0" 480 | description = "YAML parser and emitter for Python" 481 | category = "dev" 482 | optional = false 483 | python-versions = ">=3.6" 484 | 485 | [[package]] 486 | name = "requests" 487 | version = "2.31.0" 488 | description = "Python HTTP for Humans." 
489 | category = "dev" 490 | optional = false 491 | python-versions = ">=3.7" 492 | 493 | [package.dependencies] 494 | certifi = ">=2017.4.17" 495 | charset-normalizer = ">=2,<4" 496 | idna = ">=2.5,<4" 497 | urllib3 = ">=1.21.1,<3" 498 | 499 | [package.extras] 500 | socks = ["PySocks (>=1.5.6,!=1.5.7)"] 501 | use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] 502 | 503 | [[package]] 504 | name = "scikit-learn" 505 | version = "1.2.2" 506 | description = "A set of python modules for machine learning and data mining" 507 | category = "main" 508 | optional = false 509 | python-versions = ">=3.8" 510 | 511 | [package.dependencies] 512 | joblib = ">=1.1.1" 513 | numpy = ">=1.17.3" 514 | scipy = ">=1.3.2" 515 | threadpoolctl = ">=2.0.0" 516 | 517 | [package.extras] 518 | benchmark = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "memory-profiler (>=0.57.0)"] 519 | docs = ["matplotlib (>=3.1.3)", "scikit-image (>=0.16.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.2.0)", "Pillow (>=7.1.2)", "pooch (>=1.6.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)", "plotly (>=5.10.0)"] 520 | examples = ["matplotlib (>=3.1.3)", "scikit-image (>=0.16.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "pooch (>=1.6.0)", "plotly (>=5.10.0)"] 521 | tests = ["matplotlib (>=3.1.3)", "scikit-image (>=0.16.2)", "pandas (>=1.0.5)", "pytest (>=5.3.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=22.3.0)", "mypy (>=0.961)", "pyamg (>=4.0.0)", "numpydoc (>=1.2.0)", "pooch (>=1.6.0)"] 522 | 523 | [[package]] 524 | name = "scipy" 525 | version = "1.9.3" 526 | description = "Fundamental algorithms for scientific computing in Python" 527 | category = "main" 528 | optional = false 529 | python-versions = ">=3.8" 530 | 531 | [package.dependencies] 532 | numpy = ">=1.18.5,<1.26.0" 533 | 534 | [package.extras] 535 | test = ["pytest", "pytest-cov", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack"] 536 | doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-panels (>=0.5.2)", "matplotlib (>2)", "numpydoc", "sphinx-tabs"] 537 | dev = ["mypy", "typing-extensions", "pycodestyle", "flake8"] 538 | 539 | [[package]] 540 | name = "six" 541 | version = "1.16.0" 542 | description = "Python 2 and 3 compatibility utilities" 543 | category = "main" 544 | optional = false 545 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 546 | 547 | [[package]] 548 | name = "snowballstemmer" 549 | version = "2.2.0" 550 | description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." 
551 | category = "dev" 552 | optional = false 553 | python-versions = "*" 554 | 555 | [[package]] 556 | name = "sortedcontainers" 557 | version = "2.4.0" 558 | description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" 559 | category = "dev" 560 | optional = false 561 | python-versions = "*" 562 | 563 | [[package]] 564 | name = "sphinx" 565 | version = "5.3.0" 566 | description = "Python documentation generator" 567 | category = "dev" 568 | optional = false 569 | python-versions = ">=3.6" 570 | 571 | [package.dependencies] 572 | alabaster = ">=0.7,<0.8" 573 | babel = ">=2.9" 574 | colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} 575 | docutils = ">=0.14,<0.20" 576 | imagesize = ">=1.3" 577 | importlib-metadata = {version = ">=4.8", markers = "python_version < \"3.10\""} 578 | Jinja2 = ">=3.0" 579 | packaging = ">=21.0" 580 | Pygments = ">=2.12" 581 | requests = ">=2.5.0" 582 | snowballstemmer = ">=2.0" 583 | sphinxcontrib-applehelp = "*" 584 | sphinxcontrib-devhelp = "*" 585 | sphinxcontrib-htmlhelp = ">=2.0.0" 586 | sphinxcontrib-jsmath = "*" 587 | sphinxcontrib-qthelp = "*" 588 | sphinxcontrib-serializinghtml = ">=1.1.5" 589 | 590 | [package.extras] 591 | docs = ["sphinxcontrib-websupport"] 592 | lint = ["flake8 (>=3.5.0)", "flake8-comprehensions", "flake8-bugbear", "flake8-simplify", "isort", "mypy (>=0.981)", "sphinx-lint", "docutils-stubs", "types-typed-ast", "types-requests"] 593 | test = ["pytest (>=4.6)", "html5lib", "typed-ast", "cython"] 594 | 595 | [[package]] 596 | name = "sphinxcontrib-applehelp" 597 | version = "1.0.4" 598 | description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" 599 | category = "dev" 600 | optional = false 601 | python-versions = ">=3.8" 602 | 603 | [package.extras] 604 | lint = ["flake8", "mypy", "docutils-stubs"] 605 | test = ["pytest"] 606 | 607 | [[package]] 608 | name = "sphinxcontrib-devhelp" 609 | version = "1.0.2" 610 | description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." 611 | category = "dev" 612 | optional = false 613 | python-versions = ">=3.5" 614 | 615 | [package.extras] 616 | test = ["pytest"] 617 | lint = ["docutils-stubs", "mypy", "flake8"] 618 | 619 | [[package]] 620 | name = "sphinxcontrib-htmlhelp" 621 | version = "2.0.1" 622 | description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" 623 | category = "dev" 624 | optional = false 625 | python-versions = ">=3.8" 626 | 627 | [package.extras] 628 | lint = ["flake8", "mypy", "docutils-stubs"] 629 | test = ["pytest", "html5lib"] 630 | 631 | [[package]] 632 | name = "sphinxcontrib-jsmath" 633 | version = "1.0.1" 634 | description = "A sphinx extension which renders display math in HTML via JavaScript" 635 | category = "dev" 636 | optional = false 637 | python-versions = ">=3.5" 638 | 639 | [package.extras] 640 | test = ["mypy", "flake8", "pytest"] 641 | 642 | [[package]] 643 | name = "sphinxcontrib-qthelp" 644 | version = "1.0.3" 645 | description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." 646 | category = "dev" 647 | optional = false 648 | python-versions = ">=3.5" 649 | 650 | [package.extras] 651 | test = ["pytest"] 652 | lint = ["docutils-stubs", "mypy", "flake8"] 653 | 654 | [[package]] 655 | name = "sphinxcontrib-serializinghtml" 656 | version = "1.1.5" 657 | description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
658 | category = "dev" 659 | optional = false 660 | python-versions = ">=3.5" 661 | 662 | [package.extras] 663 | lint = ["flake8", "mypy", "docutils-stubs"] 664 | test = ["pytest"] 665 | 666 | [[package]] 667 | name = "threadpoolctl" 668 | version = "3.1.0" 669 | description = "threadpoolctl" 670 | category = "main" 671 | optional = false 672 | python-versions = ">=3.6" 673 | 674 | [[package]] 675 | name = "tomli" 676 | version = "2.0.1" 677 | description = "A lil' TOML parser" 678 | category = "dev" 679 | optional = false 680 | python-versions = ">=3.7" 681 | 682 | [[package]] 683 | name = "typing-extensions" 684 | version = "4.6.3" 685 | description = "Backported and Experimental Type Hints for Python 3.7+" 686 | category = "dev" 687 | optional = false 688 | python-versions = ">=3.7" 689 | 690 | [[package]] 691 | name = "urllib3" 692 | version = "2.0.2" 693 | description = "HTTP library with thread-safe connection pooling, file post, and more." 694 | category = "dev" 695 | optional = false 696 | python-versions = ">=3.7" 697 | 698 | [package.extras] 699 | brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] 700 | secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] 701 | socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] 702 | zstd = ["zstandard (>=0.18.0)"] 703 | 704 | [[package]] 705 | name = "virtualenv" 706 | version = "20.23.0" 707 | description = "Virtual Python Environment builder" 708 | category = "dev" 709 | optional = false 710 | python-versions = ">=3.7" 711 | 712 | [package.dependencies] 713 | distlib = ">=0.3.6,<1" 714 | filelock = ">=3.11,<4" 715 | platformdirs = ">=3.2,<4" 716 | 717 | [package.extras] 718 | docs = ["furo (>=2023.3.27)", "proselint (>=0.13)", "sphinx-argparse (>=0.4)", "sphinx (>=6.1.3)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=22.12)"] 719 | test = ["covdefaults (>=2.3)", "coverage-enable-subprocess (>=1)", "coverage (>=7.2.3)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest-env (>=0.8.1)", "pytest-freezegun (>=0.4.2)", "pytest-mock (>=3.10)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "pytest (>=7.3.1)", "setuptools (>=67.7.1)", "time-machine (>=2.9)"] 720 | 721 | [[package]] 722 | name = "zipp" 723 | version = "3.15.0" 724 | description = "Backport of pathlib-compatible object wrapper for zip files" 725 | category = "dev" 726 | optional = false 727 | python-versions = ">=3.7" 728 | 729 | [package.extras] 730 | docs = ["sphinx (>=3.5)", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "furo", "sphinx-lint", "jaraco.tidelift (>=1.4)"] 731 | testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "jaraco.itertools", "jaraco.functools", "more-itertools", "big-o", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "pytest-flake8"] 732 | 733 | [metadata] 734 | lock-version = "1.1" 735 | python-versions = "^3.9" 736 | content-hash = "75b3581ea890f422559ba08feda7dec9d229ec23291cebf501b8ff55b8330450" 737 | 738 | [metadata.files] 739 | alabaster = [] 740 | attrs = [] 741 | babel = [] 742 | black = [] 743 | certifi = [] 744 | cfgv = [ 745 | {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"}, 746 | {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"}, 747 | ] 748 | charset-normalizer = [] 749 | click = [ 750 | {file = "click-8.1.3-py3-none-any.whl", hash = 
"sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, 751 | {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, 752 | ] 753 | colorama = [] 754 | distlib = [] 755 | docutils = [] 756 | exceptiongroup = [] 757 | filelock = [] 758 | flake8 = [ 759 | {file = "flake8-4.0.1-py2.py3-none-any.whl", hash = "sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d"}, 760 | {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"}, 761 | ] 762 | hypothesis = [] 763 | identify = [] 764 | idna = [] 765 | imagesize = [] 766 | importlib-metadata = [] 767 | iniconfig = [] 768 | isort = [] 769 | jinja2 = [] 770 | joblib = [] 771 | markupsafe = [] 772 | mccabe = [ 773 | {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, 774 | {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, 775 | ] 776 | mypy = [ 777 | {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, 778 | {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, 779 | {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, 780 | {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, 781 | {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, 782 | {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, 783 | {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, 784 | {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, 785 | {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, 786 | {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, 787 | {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, 788 | {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, 789 | {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, 790 | {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, 791 | {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, 792 | {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = 
"sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, 793 | {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, 794 | {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, 795 | {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, 796 | {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, 797 | {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, 798 | {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, 799 | {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, 800 | ] 801 | mypy-extensions = [] 802 | nodeenv = [] 803 | numpy = [] 804 | packaging = [] 805 | pandas = [] 806 | pathspec = [] 807 | platformdirs = [] 808 | pluggy = [ 809 | {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, 810 | {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, 811 | ] 812 | pre-commit = [] 813 | pycodestyle = [ 814 | {file = "pycodestyle-2.8.0-py2.py3-none-any.whl", hash = "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20"}, 815 | {file = "pycodestyle-2.8.0.tar.gz", hash = "sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f"}, 816 | ] 817 | pydocstyle = [] 818 | pyflakes = [ 819 | {file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"}, 820 | {file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"}, 821 | ] 822 | pygments = [] 823 | pytest = [] 824 | python-dateutil = [ 825 | {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, 826 | {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, 827 | ] 828 | pytz = [] 829 | pyyaml = [ 830 | {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, 831 | {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, 832 | {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, 833 | {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, 834 | {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, 835 | {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, 836 | {file = 
"PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, 837 | {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, 838 | {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, 839 | {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, 840 | {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, 841 | {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, 842 | {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, 843 | {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, 844 | {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, 845 | {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, 846 | {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, 847 | {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, 848 | {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, 849 | {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, 850 | {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, 851 | {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, 852 | {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, 853 | {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, 854 | {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, 855 | {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, 856 | {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, 857 | {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, 858 | {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, 859 | {file = "PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, 860 | {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, 861 | {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, 862 | {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, 863 | ] 864 | requests = [] 865 | scikit-learn = [] 866 | scipy = [] 867 | six = [ 868 | {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, 869 | {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, 870 | ] 871 | snowballstemmer = [ 872 | {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, 873 | {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, 874 | ] 875 | sortedcontainers = [] 876 | sphinx = [] 877 | sphinxcontrib-applehelp = [] 878 | sphinxcontrib-devhelp = [] 879 | sphinxcontrib-htmlhelp = [] 880 | sphinxcontrib-jsmath = [] 881 | sphinxcontrib-qthelp = [] 882 | sphinxcontrib-serializinghtml = [] 883 | threadpoolctl = [ 884 | {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, 885 | {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, 886 | ] 887 | tomli = [ 888 | {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, 889 | {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, 890 | ] 891 | typing-extensions = [] 892 | urllib3 = [] 893 | virtualenv = [] 894 | zipp = [] 895 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mamimo" 3 | version = "0.4.3" 4 | description = "A package to create marketing mix models." 
5 | authors = ["Robert Kübler "] 6 | repository = "https://github.com/Garve/mamimo" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | scikit-learn = "^1.0.2" 12 | numpy = "^1.22.3" 13 | pandas = "^1.4.2" 14 | 15 | [tool.poetry.dev-dependencies] 16 | black = "^22.3.0" 17 | flake8 = "^4.0.1" 18 | isort = "^5.10.1" 19 | mypy = "^0.950" 20 | pre-commit = "^2.18.1" 21 | pydocstyle = "^6.1.1" 22 | pytest = "^7.1.2" 23 | Sphinx = "^5.2.2" 24 | hypothesis = "^6.59.0" 25 | 26 | [build-system] 27 | requires = ["poetry-core>=1.0.0"] 28 | build-backend = "poetry.core.masonry.api" 29 | 30 | [tool.isort] 31 | profile = "black" -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Testing.""" 2 | -------------------------------------------------------------------------------- /tests/test_analysis.py: -------------------------------------------------------------------------------- 1 | """Test analysis.""" 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.compose import make_column_transformer 6 | from sklearn.pipeline import make_pipeline 7 | from sklearn.preprocessing import OneHotEncoder 8 | 9 | from mamimo.analysis import breakdown 10 | from mamimo.carryover import ExponentialCarryover 11 | from mamimo.datasets import load_fake_mmm 12 | from mamimo.linear_model import LinearRegression 13 | from mamimo.saturation import ExponentialSaturation 14 | from mamimo.time_utils import PowerTrend, add_date_indicators, add_time_features 15 | 16 | 17 | @pytest.fixture() 18 | def create_model(): 19 | """Create a model to test the breakdown function.""" 20 | data = load_fake_mmm() 21 | 22 | X = data.drop(columns=["Sales"]) 23 | y = data.Sales 24 | 25 | X = X.pipe(add_time_features, month=True).pipe( 26 | add_date_indicators, special_date=["2020-01-05"] 27 | ) 28 | X["Trend"] = range(200) 29 | 30 | preprocessing = make_column_transformer( 31 | ( 32 | make_pipeline( 33 | ExponentialCarryover(window=4, strength=0.5), 34 | ExponentialSaturation(exponent=0.0001), 35 | ), 36 | ["TV"], 37 | ), 38 | ( 39 | make_pipeline( 40 | ExponentialCarryover(window=2, strength=0.2), 41 | ExponentialSaturation(exponent=0.0001), 42 | ), 43 | ["Radio"], 44 | ), 45 | ( 46 | make_pipeline( 47 | ExponentialCarryover(), ExponentialSaturation(exponent=0.0001) 48 | ), 49 | ["Banners"], 50 | ), 51 | (OneHotEncoder(sparse_output=False), ["month"]), 52 | (PowerTrend(power=1.2), ["Trend"]), 53 | (ExponentialCarryover(window=10, strength=0.6), ["special_date"]), 54 | ) 55 | 56 | model = make_pipeline( 57 | preprocessing, LinearRegression(positive=True, fit_intercept=False) 58 | ) 59 | 60 | return model.fit(X, y), X, y 61 | 62 | 63 | def test_breakdown(create_model): 64 | """Tests if the sum of channel contribution equals the observed targets.""" 65 | model, X, y = create_model 66 | breakdowns = breakdown(model, X, y) 67 | 68 | np.testing.assert_array_almost_equal(breakdowns.sum(axis=1), y) 69 | 70 | 71 | def test_group(create_model): 72 | """Checks if grouping together channels works.""" 73 | model, X, y = create_model 74 | 75 | breakdowns = breakdown( 76 | model, 77 | X, 78 | y, 79 | group_channels={ 80 | "Base": [f"onehotencoder__month_{i}" for i in range(1, 13)] 81 | + ["powertrend__Trend"], 82 | "Media": ["pipeline-1__TV", "pipeline-2__Radio", "pipeline-3__Banners"], 83 | }, 84 | ) 85 | 86 | assert breakdowns.columns.tolist() == [ 87 | 
"exponentialcarryover__special_date", 88 | "Base", 89 | "Media", 90 | ] 91 | -------------------------------------------------------------------------------- /tests/test_carryover.py: -------------------------------------------------------------------------------- 1 | """Test carryover.""" 2 | 3 | import numpy as np 4 | import pytest 5 | from hypothesis import given 6 | from hypothesis.strategies import floats, lists 7 | from sklearn.utils.estimator_checks import check_estimator 8 | 9 | from mamimo.carryover import ExponentialCarryover, GeneralGaussianCarryover 10 | 11 | numpy_arrays = lists(floats(min_value=0, max_value=1e30), min_size=1).map( 12 | lambda x: np.array(x).reshape(-1, 1) 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "estimator", 18 | [ 19 | ExponentialCarryover(), 20 | GeneralGaussianCarryover(), 21 | ], 22 | ) 23 | def test_check_estimator(estimator): 24 | """Test if check_estimator passes.""" 25 | check_estimator(estimator) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "estimator", 30 | [ 31 | ExponentialCarryover(), 32 | GeneralGaussianCarryover(), 33 | ], 34 | ) 35 | @given(inputs=numpy_arrays) 36 | def test_output_is_greater_than_input(inputs, estimator): 37 | """Test if carryover increases the values of the original input array.""" 38 | outputs = estimator.fit_transform(inputs) 39 | assert (outputs >= inputs).all() 40 | 41 | 42 | @pytest.mark.parametrize( 43 | "estimator", 44 | [ 45 | ExponentialCarryover(), 46 | GeneralGaussianCarryover(), 47 | ], 48 | ) 49 | @given(inputs=numpy_arrays) 50 | def test_output_is_equal_to_input_in_the_first_component(inputs, estimator): 51 | """Test if carryover is equal to the input array in the first component.""" 52 | outputs = estimator.fit_transform(inputs) 53 | assert outputs[0] == inputs[0] 54 | -------------------------------------------------------------------------------- /tests/test_linear_model.py: -------------------------------------------------------------------------------- 1 | """Test the linear models.""" 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.utils.estimator_checks import check_estimator 6 | 7 | from mamimo.linear_model import ( 8 | ImbalancedLinearRegression, 9 | LADRegression, 10 | LinearRegression, 11 | QuantileRegression, 12 | ) 13 | 14 | test_batch = [ 15 | (np.array([0, 0, 3, 0, 6]), 3), 16 | (np.array([1, 0, -2, 0, 4, 0, -5, 0, 6]), 2), 17 | (np.array([4, -4]), 0), 18 | (np.array([0.1]), 1000), 19 | ] 20 | 21 | 22 | def _create_dataset(coefs, intercept, noise=0.0): 23 | np.random.seed(0) 24 | X = np.random.randn(1000, coefs.shape[0]) 25 | y = X @ coefs + intercept + noise * np.random.randn(1000) 26 | 27 | return X, y 28 | 29 | 30 | @pytest.mark.parametrize("coefs, intercept", test_batch) 31 | @pytest.mark.parametrize( 32 | "model", 33 | [LADRegression, QuantileRegression, ImbalancedLinearRegression, LinearRegression], 34 | ) 35 | def test_coefs_and_intercept__no_noise(coefs, intercept, model): 36 | """Regression problems without noise.""" 37 | X, y = _create_dataset(coefs, intercept) 38 | regressor = model() 39 | regressor.fit(X, y) 40 | 41 | assert regressor.score(X, y) > 0.99 42 | 43 | 44 | @pytest.mark.parametrize("coefs, intercept", test_batch) 45 | @pytest.mark.parametrize( 46 | "model", 47 | [LADRegression, QuantileRegression, ImbalancedLinearRegression, LinearRegression], 48 | ) 49 | def test_score(coefs, intercept, model): 50 | """Tests with noise on an easy problem. 
Parameter reconstruction should be easy.""" 51 | X, y = _create_dataset(coefs, intercept, noise=1.0) 52 | regressor = model() 53 | regressor.fit(X, y) 54 | 55 | np.testing.assert_almost_equal(regressor.coef_, coefs, decimal=1) 56 | np.testing.assert_almost_equal(regressor.intercept_, intercept, decimal=1) 57 | 58 | 59 | @pytest.mark.parametrize("coefs, intercept", test_batch) 60 | @pytest.mark.parametrize( 61 | "model", 62 | [LADRegression, QuantileRegression, ImbalancedLinearRegression, LinearRegression], 63 | ) 64 | def test_coefs_and_intercept__no_noise_positive(coefs, intercept, model): 65 | """Test with only positive coefficients.""" 66 | X, y = _create_dataset(coefs, intercept, noise=0.0) 67 | regressor = model(positive=True) 68 | regressor.fit(X, y) 69 | 70 | assert all(regressor.coef_ >= 0) 71 | assert regressor.score(X, y) > 0.3 72 | 73 | 74 | @pytest.mark.parametrize("coefs, intercept", test_batch) 75 | @pytest.mark.parametrize( 76 | "model", 77 | [LADRegression, QuantileRegression, ImbalancedLinearRegression, LinearRegression], 78 | ) 79 | def test_fit_intercept_and_copy(coefs, intercept, model): 80 | """Test if fit_intercept and copy_X work.""" 81 | X, y = _create_dataset(coefs, intercept, noise=2.0) 82 | regressor = model(fit_intercept=False, copy_X=False) 83 | regressor.fit(X, y) 84 | 85 | assert regressor.intercept_ == 0.0 86 | 87 | 88 | @pytest.mark.parametrize( 89 | "model", 90 | [LADRegression, QuantileRegression, ImbalancedLinearRegression, LinearRegression], 91 | ) 92 | def test_check_estimator(model): 93 | """Conduct all scikit-learn estimator tests.""" 94 | regressor = model() 95 | 96 | check_estimator(regressor) 97 | -------------------------------------------------------------------------------- /tests/test_saturation.py: -------------------------------------------------------------------------------- 1 | """Test saturation.""" 2 | 3 | import pytest 4 | from sklearn.utils.estimator_checks import check_estimator 5 | 6 | from mamimo.saturation import ( 7 | AdbudgSaturation, 8 | BoxCoxSaturation, 9 | ExponentialSaturation, 10 | HillSaturation, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "estimator", 16 | [ 17 | BoxCoxSaturation(), 18 | AdbudgSaturation(), 19 | HillSaturation(), 20 | ExponentialSaturation(), 21 | ], 22 | ) 23 | def test_check_estimator(estimator): 24 | """Test if check_estimator passes.""" 25 | check_estimator(estimator) 26 | --------------------------------------------------------------------------------