├── .github
└── workflows
│ ├── deploy-docs.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE.txt
├── MANIFEST.in
├── PSL_catalog.json
├── README.md
├── conda.recipe
├── bld.bat
├── build.sh
└── meta.yaml
├── docs
├── _config.yml
├── _toc.yml
├── api
│ ├── custom-adjust.ipynb
│ ├── custom-types.ipynb
│ ├── extend.ipynb
│ ├── guide.md
│ ├── indexing.ipynb
│ ├── reference.rst
│ └── viewing-data.ipynb
├── intro.ipynb
└── parameters.md
├── environment.yml
├── paramtools
├── __init__.py
├── contrib
│ ├── __init__.py
│ ├── fields.py
│ └── validate.py
├── examples
│ ├── baseball
│ │ └── defaults.json
│ ├── behresp
│ │ └── defaults.json
│ ├── taxparams-demo
│ │ └── defaults.json
│ └── taxparams
│ │ └── defaults.json
├── exceptions.py
├── parameters.py
├── schema.py
├── schema_factory.py
├── select.py
├── sorted_key_list.py
├── tests
│ ├── __init__.py
│ ├── defaults.json
│ ├── extend_ex.json
│ ├── test_examples
│ │ ├── __init__.py
│ │ ├── test_baseball.py
│ │ ├── test_behresp.py
│ │ └── test_tc_ex.py
│ ├── test_fields.py
│ ├── test_parameters.py
│ ├── test_schema.py
│ ├── test_select.py
│ ├── test_sorted_key_list.py
│ ├── test_utils.py
│ ├── test_validate.py
│ └── test_values.py
├── typing.py
├── utils.py
└── values.py
├── pyproject.toml
└── setup.py
/.github/workflows/deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: Build and Deploy Jupyter Book
2 | on:
3 | push:
4 | branches:
5 | - master
6 |
7 | jobs:
8 | build-and-deploy:
9 | if: github.repository == 'PSLmodels/ParamTools'
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout
13 | uses: actions/checkout@v2 # If you're using actions/checkout@v2 you must set persist-credentials to false in most cases for the deployment to work correctly.
14 | with:
15 | persist-credentials: false
16 |
17 | - name: Setup Miniconda
18 | uses: conda-incubator/setup-miniconda@v3
19 | with:
20 | activate-environment: paramtools-dev
21 | environment-file: environment.yml
22 | python-version: 3.12
23 | auto-activate-base: false
24 |
25 | - name: Build # Build Jupyter Book
26 | shell: bash -l {0}
27 | run: |
28 | pip install -e .
29 | conda install pandas --yes
30 | jupyter-book build ./docs
31 |
32 | - name: Deploy
33 | uses: JamesIves/github-pages-deploy-action@releases/v3
34 | with:
35 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
36 | BRANCH: gh-pages # The branch the action should deploy to.
37 | FOLDER: docs/_build/html # The folder the action should deploy.
38 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 |
2 | name: Build Package and Test Source Code [Python 3.10, 3.11, 3.12]
3 |
4 | on:
5 | push:
6 | branches:
7 | - master
8 | pull_request: {}
9 |
10 | jobs:
11 | build:
12 | runs-on: ubuntu-latest
13 | strategy:
14 | matrix:
15 | python-version: ["3.10", "3.11", "3.12"]
16 |
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@master
20 | with:
21 | persist-credentials: false
22 |
23 | - name: Setup Miniconda using Python ${{ matrix.python-version }}
24 | uses: conda-incubator/setup-miniconda@v3
25 | with:
26 | activate-environment: paramtools-dev
27 | environment-file: environment.yml
28 | python-version: ${{ matrix.python-version }}
29 | auto-activate-base: false
30 |
31 | - name: Build
32 | shell: bash -l {0}
33 | run: |
34 | pip install -e .
35 | - name: Test
36 | shell: bash -l {0}
37 | working-directory: ./
38 | run: |
39 | pytest paramtools -v -s
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: 18.9b0
4 | hooks:
5 | - id: black
6 | language_version: python3
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v2.1.0
9 | hooks:
10 | - id: flake8
11 | args: ["--max-line-length=100", "--ignore=E203,W503,E712"]
12 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: xenial
2 |
3 | language: python
4 |
5 | python:
6 | - "3.6"
7 | - "3.7"
8 |
9 | # command to install dependencies
10 | install:
11 | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
12 | - bash miniconda.sh -b -p $HOME/miniconda
13 | - export PATH="$HOME/miniconda/bin:$PATH"
14 | - conda config --set always_yes yes
15 | - conda update -n base -c defaults conda
16 | - conda create -n paramtools-dev python=$TRAVIS_PYTHON_VERSION;
17 | - source activate paramtools-dev
18 | - conda env update -f environment.yml
19 |
20 | # command to run tests
21 | script:
22 | - pytest -v
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions
4 | ------------------------------
5 | Contributions are welcome! Open a [PR][2] with your changes (and tests to go along with them!). In the PR, describe what your change does and link to any relevant issues.
6 |
7 | Feature Requests
8 | ----------------------------------
9 | Please open an [issue][1] describing the feature and its potential use cases.
10 |
11 | Bug Reports
12 | -----------------------
13 | Please open an [issue][1] describing the bug.
14 |
15 | Dev Setup
16 | ------------------------
17 |
18 | Fork the repo and clone it so that you have a copy of the source code. Next, run the following commands in the terminal:
19 |
20 | ```
21 | cd ParamTools
22 | conda env create
23 | conda activate paramtools-dev
24 | pip install -e .
25 | pre-commit install
26 | ```
27 |
28 | Testing
29 | -------------------
30 | ```
31 | pytest -v
32 | ```
33 |
34 |
35 | [1]: https://github.com/PSLmodels/ParamTools/issues
36 | [2]: https://github.com/PSLmodels/ParamTools/pulls
37 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Henry Doupe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include paramtools/examples/taxparams-demo/*.json
2 |
--------------------------------------------------------------------------------
/PSL_catalog.json:
--------------------------------------------------------------------------------
1 | {
2 | "project_one_line": {
3 | "start_header": null,
4 | "end_header": null,
5 | "source": "https://github.com/PSLmodels/ParamTools",
6 | "type": "html",
7 | "data": "
Library for parameter processing and validation with a focus on computational modeling projects
"
8 | },
9 | "key_features": {
10 | "start_header": null,
11 | "end_header": null,
12 | "source": null,
13 | "type": null,
14 | "data": null
15 | },
16 | "project_overview": {
17 | "start_header": "ParamTools",
18 | "end_header": "How to use ParamTools",
19 | "source": "README.md",
20 | "type": "github_file",
21 | "data": null
22 | },
23 | "citation": {
24 | "start_header": null,
25 | "end_header": null,
26 | "source": null,
27 | "data": null,
28 | "type": null
29 | },
30 | "license": {
31 | "start_header": null,
32 | "end_header": null,
33 | "source": "https://github.com/PSLmodels/ParamTools/blob/master/LICENSE.txt",
34 | "data": "MIT
",
35 | "type": "html"
36 | },
37 | "user_documentation": {
38 | "start_header": null,
39 | "end_header": null,
40 | "type": "html",
41 | "source": "https://paramtools.dev",
42 | "data": "https://paramtools.dev"
43 | },
44 | "user_changelog": {
45 | "start_header": null,
46 | "end_header": null,
47 | "source": "https://github.com/PSLmodels/ParamTools/releases",
48 | "type": "html",
49 | "data": "https://github.com/PSLmodels/ParamTools/releases"
50 | },
51 | "user_changelog_recent": {
52 | "start_header": null,
53 | "end_header": null,
54 | "source": "https://github.com/PSLmodels/ParamTools/releases/latest",
55 | "type": "html",
56 | "data": "https://github.com/PSLmodels/ParamTools/releases/latest"
57 | },
58 | "dev_changelog": {
59 | "start_header": null,
60 | "end_header": null,
61 | "source": "https://github.com/PSLmodels/ParamTools/releases",
62 | "type": "html",
63 | "data": "https://github.com/PSLmodels/ParamTools/releases"
64 | },
65 | "disclaimer": {
66 | "start_header": null,
67 | "end_header": null,
68 | "source": null,
69 | "type": null,
70 | "data": null
71 | },
72 | "user_case_studies": {
73 | "start_header": null,
74 | "end_header": null,
75 | "source": null,
76 | "type": null,
77 | "data": null
78 | },
79 | "project_roadmap": {
80 | "start_header": null,
81 | "end_header": null,
82 | "source": null,
83 | "type": null,
84 | "data": null
85 | },
86 | "contributor_overview": {
87 | "start_header": "Contributions",
88 | "end_header": "Dev Setup",
89 | "source": "CONTRIBUTING.md",
90 | "type": "github_file",
91 | "data": null
92 | },
93 | "contributor_guide": {
94 | "start_header": null,
95 | "end_header": null,
96 | "source": "https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md",
97 | "type": "html",
98 | "data": "https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md"
99 | },
100 | "governance_overview": {
101 | "start_header": null,
102 | "end_header": null,
103 | "source": null,
104 | "type": null,
105 | "data": null
106 | },
107 | "public_funding": {
108 | "start_header": null,
109 | "end_header": null,
110 | "source": null,
111 | "type": null,
112 | "data": null
113 | },
114 | "link_to_webapp": {
115 | "data": null,
116 | "source": null,
117 | "type": null,
118 | "start_header": null,
119 | "end_header": null
120 | },
121 | "public_issue_tracker": {
122 | "start_header": null,
123 | "end_header": null,
124 | "data": "https://github.com/PSLmodels/ParamTools/issues",
125 | "source": null,
126 | "type": "html"
127 | },
128 | "public_qanda": {
129 | "start_header": null,
130 | "end_header": null,
131 | "data": "https://github.com/PSLmodels/ParamTools/issues",
132 | "source": null,
133 | "type": "html"
134 | },
135 | "core_maintainers": {
136 | "start_header": null,
137 | "end_header": null,
138 | "data": "",
139 | "source": null,
140 | "type": "html"
141 | },
142 | "unit_test": {
143 | "start_header": null,
144 | "end_header": null,
145 | "data": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests",
146 | "source": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests",
147 | "type": "html"
148 | },
149 | "integration_test": {
150 | "start_header": null,
151 | "end_header": null,
152 | "data": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests",
153 | "source": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests",
154 | "type": "html"
155 | }
156 | }
157 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ParamTools
2 |
3 | **Define, update, and validate your model's parameters.**
4 |
5 | Install using pip:
6 |
7 | ```
8 | pip install paramtools
9 | ```
10 |
11 | Install using conda:
12 |
13 | ```
14 | conda install -c conda-forge paramtools
15 | ```
16 |
17 | ## Usage
18 |
19 | Subclass `paramtools.Parameters` and define your model's [parameters](https://paramtools.dev/parameters):
20 |
21 | ```python
22 | import paramtools
23 |
24 |
25 | class Params(paramtools.Parameters):
26 | defaults = {
27 | "schema": {
28 | "labels": {
29 | "date": {
30 | "type": "date",
31 | "validators": {
32 | "range": {
33 | "min": "2020-01-01",
34 | "max": "2021-01-01",
35 | "step": {"months": 1}
36 | }
37 | }
38 | }
39 | },
40 | },
41 | "a": {
42 | "title": "A",
43 | "type": "int",
44 | "value": [
45 | {"date": "2020-01-01", "value": 2},
46 | {"date": "2020-10-01", "value": 8},
47 | ],
48 | "validators": {
49 | "range" : {
50 | "min": 0, "max": "b"
51 | }
52 | }
53 | },
54 | "b": {
55 | "title": "B",
56 | "type": "float",
57 | "value": [{"date": "2020-01-01", "value": 10.5}]
58 | }
59 | }
60 | ```
61 |
62 | ### Access parameter values
63 |
64 | Access values using `.sel`:
65 |
66 | ```python
67 | params = Params()
68 |
69 | params.sel["a"]
70 | ```
71 |
72 | Values([
73 | {'date': datetime.date(2020, 1, 1), 'value': 2},
74 | {'date': datetime.date(2020, 10, 1), 'value': 8},
75 | ])
76 |
77 | Look up parameter values using a pandas-like API:
78 |
79 | ```python
80 | from datetime import date
81 |
82 | result = params.sel["a"]["date"] == date(2020, 1, 1)
83 | result
84 | ```
85 |
86 | QueryResult([
87 | {'date': datetime.date(2020, 1, 1), 'value': 2}
88 | ])
89 |
90 | ```python
91 | result.isel[0]["value"]
92 | ```
93 |
94 | 2
95 |
96 | ### Adjust and validate parameter values
97 |
98 | Add a new value:
99 |
100 | ```python
101 | params.adjust({"a": [{"date": "2020-11-01", "value": 22}]})
102 |
103 | params.sel["a"]
104 | ```
105 |
106 | Values([
107 | {'date': datetime.date(2020, 1, 1), 'value': 2},
108 | {'date': datetime.date(2020, 10, 1), 'value': 8},
109 | {'date': datetime.date(2020, 11, 1), 'value': 22},
110 | ])
111 |
112 | Update an existing value:
113 |
114 | ```python
115 | params.adjust({"a": [{"date": "2020-01-01", "value": 3}]})
116 |
117 | params.sel["a"]
118 | ```
119 |
120 | Values([
121 | {'date': datetime.date(2020, 1, 1), 'value': 3},
122 | {'date': datetime.date(2020, 10, 1), 'value': 8},
123 | {'date': datetime.date(2020, 11, 1), 'value': 22},
124 | ])
125 |
126 | Update all values:
127 |
128 | ```python
129 | params.adjust({"a": 7})
130 |
131 | params.sel["a"]
132 | ```
133 |
134 | Values([
135 | {'date': datetime.date(2020, 1, 1), 'value': 7},
136 | {'date': datetime.date(2020, 10, 1), 'value': 7},
137 | {'date': datetime.date(2020, 11, 1), 'value': 7},
138 | ])
139 |
140 | Errors on values that are out of range:
141 |
142 | ```python
143 | params.adjust({"a": -1})
144 | ```
145 |
146 | ---------------------------------------------------------------------------
147 |
148 | ValidationError Traceback (most recent call last)
149 |
150 | in
151 | ----> 1 params.adjust({"a": -1})
152 |
153 |
154 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
155 | 253 least one existing value item's corresponding label values.
156 | 254 """
157 | --> 255 return self._adjust(
158 | 256 params_or_path,
159 | 257 ignore_warnings=ignore_warnings,
160 |
161 |
162 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)
163 | 371 not ignore_warnings and has_warnings
164 | 372 ):
165 | --> 373 raise self.validation_error
166 | 374
167 | 375 # Update attrs for params that were adjusted.
168 |
169 |
170 | ValidationError: {
171 | "errors": {
172 | "a": [
173 | "a -1 < min 0 "
174 | ]
175 | }
176 | }
177 |
178 | ```python
179 | params = Params()
180 |
181 | params.adjust({"a": [{"date": "2020-01-01", "value": 11}]})
182 | ```
183 |
184 | ---------------------------------------------------------------------------
185 |
186 | ValidationError Traceback (most recent call last)
187 |
188 | in
189 | 1 params = Params()
190 | 2
191 | ----> 3 params.adjust({"a": [{"date": "2020-01-01", "value": 11}]})
192 |
193 |
194 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
195 | 253 least one existing value item's corresponding label values.
196 | 254 """
197 | --> 255 return self._adjust(
198 | 256 params_or_path,
199 | 257 ignore_warnings=ignore_warnings,
200 |
201 |
202 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)
203 | 371 not ignore_warnings and has_warnings
204 | 372 ):
205 | --> 373 raise self.validation_error
206 | 374
207 | 375 # Update attrs for params that were adjusted.
208 |
209 |
210 | ValidationError: {
211 | "errors": {
212 | "a": [
213 | "a[date=2020-01-01] 11 > max 10.5 b[date=2020-01-01]"
214 | ]
215 | }
216 | }
217 |
218 | Errors on invalid values:
219 |
220 | ```python
221 | params = Params()
222 |
223 | params.adjust({"b": "abc"})
224 | ```
225 |
226 | ---------------------------------------------------------------------------
227 |
228 | ValidationError Traceback (most recent call last)
229 |
230 | in
231 | 1 params = Params()
232 | 2
233 | ----> 3 params.adjust({"b": "abc"})
234 |
235 |
236 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)
237 | 253 least one existing value item's corresponding label values.
238 | 254 """
239 | --> 255 return self._adjust(
240 | 256 params_or_path,
241 | 257 ignore_warnings=ignore_warnings,
242 |
243 |
244 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)
245 | 371 not ignore_warnings and has_warnings
246 | 372 ):
247 | --> 373 raise self.validation_error
248 | 374
249 | 375 # Update attrs for params that were adjusted.
250 |
251 |
252 | ValidationError: {
253 | "errors": {
254 | "b": [
255 | "Not a valid number: abc."
256 | ]
257 | }
258 | }
259 |
260 | ### Extend parameter values using label definitions
261 |
262 | Extend values using `label_to_extend`:
263 |
264 | ```python
265 | params = Params(label_to_extend="date")
266 | ```
267 |
268 | ```python
269 | params.sel["a"]
270 | ```
271 |
272 | Values([
273 | {'date': datetime.date(2020, 1, 1), 'value': 2},
274 | {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True},
275 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True},
276 | {'date': datetime.date(2020, 4, 1), 'value': 2, '_auto': True},
277 | {'date': datetime.date(2020, 5, 1), 'value': 2, '_auto': True},
278 | {'date': datetime.date(2020, 6, 1), 'value': 2, '_auto': True},
279 | {'date': datetime.date(2020, 7, 1), 'value': 2, '_auto': True},
280 | {'date': datetime.date(2020, 8, 1), 'value': 2, '_auto': True},
281 | {'date': datetime.date(2020, 9, 1), 'value': 2, '_auto': True},
282 | {'date': datetime.date(2020, 10, 1), 'value': 8},
283 | {'date': datetime.date(2020, 11, 1), 'value': 8, '_auto': True},
284 | {'date': datetime.date(2020, 12, 1), 'value': 8, '_auto': True},
285 | {'date': datetime.date(2021, 1, 1), 'value': 8, '_auto': True},
286 | ])
287 |
288 | Updates to values are carried through to future dates:
289 |
290 | ```python
291 | params.adjust({"a": [{"date": "2020-4-01", "value": 9}]})
292 |
293 | params.sel["a"]
294 | ```
295 |
296 | Values([
297 | {'date': datetime.date(2020, 1, 1), 'value': 2},
298 | {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True},
299 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True},
300 | {'date': datetime.date(2020, 4, 1), 'value': 9},
301 | {'date': datetime.date(2020, 5, 1), 'value': 9, '_auto': True},
302 | {'date': datetime.date(2020, 6, 1), 'value': 9, '_auto': True},
303 | {'date': datetime.date(2020, 7, 1), 'value': 9, '_auto': True},
304 | {'date': datetime.date(2020, 8, 1), 'value': 9, '_auto': True},
305 | {'date': datetime.date(2020, 9, 1), 'value': 9, '_auto': True},
306 | {'date': datetime.date(2020, 10, 1), 'value': 9, '_auto': True},
307 | {'date': datetime.date(2020, 11, 1), 'value': 9, '_auto': True},
308 | {'date': datetime.date(2020, 12, 1), 'value': 9, '_auto': True},
309 | {'date': datetime.date(2021, 1, 1), 'value': 9, '_auto': True},
310 | ])
311 |
312 | Use `clobber=False` to update only the values that were set automatically:
313 |
314 | ```python
315 | params = Params(label_to_extend="date")
316 | params.adjust(
317 | {"a": [{"date": "2020-4-01", "value": 9}]},
318 | clobber=False,
319 | )
320 |
321 | # Sort parameter values by date for nicer output
322 | params.sort_values()
323 | params.sel["a"]
324 | ```
325 |
326 | Values([
327 | {'date': datetime.date(2020, 1, 1), 'value': 2},
328 | {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True},
329 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True},
330 | {'date': datetime.date(2020, 4, 1), 'value': 9},
331 | {'date': datetime.date(2020, 5, 1), 'value': 9, '_auto': True},
332 | {'date': datetime.date(2020, 6, 1), 'value': 9, '_auto': True},
333 | {'date': datetime.date(2020, 7, 1), 'value': 9, '_auto': True},
334 | {'date': datetime.date(2020, 8, 1), 'value': 9, '_auto': True},
335 | {'date': datetime.date(2020, 9, 1), 'value': 9, '_auto': True},
336 | {'date': datetime.date(2020, 10, 1), 'value': 8},
337 | {'date': datetime.date(2020, 11, 1), 'value': 8, '_auto': True},
338 | {'date': datetime.date(2020, 12, 1), 'value': 8, '_auto': True},
339 | {'date': datetime.date(2021, 1, 1), 'value': 8, '_auto': True},
340 | ])
341 |
342 | ### NumPy integration
343 |
344 | Access values as NumPy arrays with `array_first`:
345 |
346 | ```python
347 | params = Params(label_to_extend="date", array_first=True)
348 |
349 | params.a
350 | ```
351 |
352 | array([2, 2, 2, 2, 2, 2, 2, 2, 2, 8, 8, 8, 8])
353 |
354 | ```python
355 | params.a * params.b
356 | ```
357 |
358 | array([21., 21., 21., 21., 21., 21., 21., 21., 21., 84., 84., 84., 84.])
359 |
360 | Get only the values that you want:
361 |
362 | ```python
363 | arr = params.to_array("a", date=["2020-01-01", "2020-11-01"])
364 | arr
365 | ```
366 |
367 | array([2, 8])
368 |
369 | Go back to a list of dictionaries:
370 |
371 | ```python
372 | params.from_array("a", arr, date=["2020-01-01", "2020-11-01"])
373 | ```
374 |
375 | [{'date': datetime.date(2020, 1, 1), 'value': 2},
376 | {'date': datetime.date(2020, 11, 1), 'value': 8}]
377 |
378 | ## Documentation
379 |
380 | Full documentation available at [paramtools.dev](https://paramtools.dev).
381 |
382 | ## Contributing
383 |
384 | Contributions are welcome! Check out [CONTRIBUTING.md][3] to get started.
385 |
386 | ## Credits
387 |
388 | ParamTools is built on top of the excellent [marshmallow][1] JSON schema and validation framework. I encourage everyone to check out their repo and documentation. ParamTools was modeled after [Tax-Calculator's][2] parameter processing and validation engine due to its maturity and sophisticated capabilities.
389 |
390 | [1]: https://github.com/marshmallow-code/marshmallow
391 | [2]: https://github.com/PSLmodels/Tax-Calculator
392 | [3]: https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md
393 |
--------------------------------------------------------------------------------
/conda.recipe/bld.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | SET BLD_DIR=%CD%
4 | cd /D "%RECIPE_DIR%\.."
5 | "%PYTHON%" setup.py install
6 |
--------------------------------------------------------------------------------
/conda.recipe/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | BLD_DIR=`pwd`
4 |
5 | # Recipe and source are stored together
6 | SRC_DIR=$RECIPE_DIR/..
7 | pushd $SRC_DIR
8 |
9 | $PYTHON setup.py install
10 | popd
11 |
--------------------------------------------------------------------------------
/conda.recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: paramtools
3 | version: 0.0.0
4 |
5 | requirements:
6 | build:
7 | - python
8 | - "marshmallow>=4.0.0"
9 | - "numpy>=1.13"
10 | - "python-dateutil>=2.8.0"
11 |
12 | run:
13 | - python
14 | - "marshmallow>=4.0.0"
15 | - "numpy>=1.13"
16 | - "python-dateutil>=2.8.0"
17 |
18 | test:
19 | imports:
20 | - paramtools
21 |
22 | about:
23 | home: https://github.com/PSLmodels/ParamTools
24 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | title: ParamTools
2 | author: Hank Doupe
3 | copyright: "2020"
4 | # logo: "logo.png"
5 |
6 | repository:
7 | url: https://github.com/PSLmodels/ParamTools
8 | path_to_book: "docs"
9 |
10 | sphinx:
11 | extra_extensions: ["sphinx.ext.autodoc", "sphinx.ext.viewcode"]
12 |
13 | execute:
14 | allow_errors: true
15 |
--------------------------------------------------------------------------------
/docs/_toc.yml:
--------------------------------------------------------------------------------
1 | format: jb-book
2 | root: intro
3 | parts:
4 | - caption: API
5 | chapters:
6 | - file: api/guide
7 | - file: api/viewing-data
8 | - file: api/extend
9 | - file: api/indexing
10 | - file: api/custom-adjust
11 | - file: api/custom-types
12 | - file: api/reference
13 | - caption: Parameters
14 | chapters:
15 | - file: parameters
--------------------------------------------------------------------------------
/docs/api/custom-adjust.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Custom Adjustments\n",
8 | "\n",
9 | "The ParamTools adjustment format and logic can be augmented significantly. This is helpful for projects that need to support a pre-existing data format or require custom adjustment logic. Projects should customize their adjustments by writing their own `adjust` method and then calling the default `adjust` method from there:\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import paramtools\n",
19 | "\n",
20 | "\n",
21 | "class Params(paramtools.Parameters):\n",
22 | " def adjust(self, params_or_path, **kwargs):\n",
23 | " params = self.read_params(params_or_path)\n",
24 | "\n",
25 | " # ... custom logic here\n",
26 | "\n",
27 | " # call default adjust method.\n",
28 | " return super().adjust(params, **kwargs)\n",
29 | "\n"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## Example\n",
37 | "\n",
38 | "Some projects may find it convenient to use CSVs for their adjustment format. That's no problem for ParamTools as long as the CSV is converted to a JSON file or Python dictionary that meets the ParamTools criteria.\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 2,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "import io\n",
48 | "import os\n",
49 | "\n",
50 | "import pandas as pd\n",
51 | "\n",
52 | "import paramtools\n",
53 | "\n",
54 | "\n",
55 | "class CSVParams(paramtools.Parameters):\n",
56 | " defaults = {\n",
57 | " \"schema\": {\n",
58 | " \"labels\": {\n",
59 | " \"year\": {\n",
60 | " \"type\": \"int\",\n",
61 | " \"validators\": {\"range\": {\"min\": 2000, \"max\": 2005}}\n",
62 | " }\n",
63 | " }\n",
64 | " },\n",
65 | " \"a\": {\n",
66 | " \"title\": \"A\",\n",
67 | " \"description\": \"a param\",\n",
68 | " \"type\": \"int\",\n",
69 | " \"value\": [\n",
70 | " {\"year\": 2000, \"value\": 1},\n",
71 | " {\"year\": 2001, \"value\": 2},\n",
72 | " ]\n",
73 | " },\n",
74 | " \"b\": {\n",
75 | " \"title\": \"B\",\n",
76 | " \"description\": \"b param\",\n",
77 | " \"type\": \"int\",\n",
78 | " \"value\": [\n",
79 | " {\"year\": 2000, \"value\": 3},\n",
80 | " {\"year\": 2001, \"value\": 4},\n",
81 | " ]\n",
82 | " }\n",
83 | " }\n",
84 | "\n",
85 | " def adjust(self, params_or_path, **kwargs):\n",
86 | " \"\"\"\n",
87 | " A custom adjust method that converts CSV files to\n",
88 | " ParamTools compliant Python dictionaries.\n",
89 | " \"\"\"\n",
90 | " if os.path.exists(params_or_path):\n",
91 | " paramsdf = pd.read_csv(params_or_path, index_col=\"year\")\n",
92 | " else:\n",
93 | " paramsdf = pd.read_csv(io.StringIO(params_or_path), index_col=\"year\")\n",
94 | "\n",
95 | " dfdict = paramsdf.to_dict()\n",
96 | " params = {\"a\": [], \"b\": []}\n",
97 | " for label in params:\n",
98 | " for year, value in dfdict[label].items():\n",
99 | " params[label] += [{\"year\": year, \"value\": value}]\n",
100 | "\n",
101 | " # call adjust method on paramtools.Parameters which will\n",
102 | " # call _adjust to actually do the update.\n",
103 | " return super().adjust(params, **kwargs)\n",
104 | "\n"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "Now we create an example CSV file. To keep the example self-contained, the CSV is just a string, but this example works with CSV files, too. The values of \"A\" are updated to 5 in 2000 and 6 in 2001, and the values of \"B\" are updated to 6 in 2000 and 7 in 2001.\n"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 3,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "data": {
121 | "text/plain": [
122 | "OrderedDict([('a',\n",
123 | " [OrderedDict([('year', 2000), ('value', 5)]),\n",
124 | " OrderedDict([('year', 2001), ('value', 6)])]),\n",
125 | " ('b',\n",
126 | " [OrderedDict([('year', 2000), ('value', 6)]),\n",
127 | " OrderedDict([('year', 2001), ('value', 7)])])])"
128 | ]
129 | },
130 | "execution_count": 3,
131 | "metadata": {},
132 | "output_type": "execute_result"
133 | }
134 | ],
135 | "source": [
136 | "# this could also be a path to a CSV file.\n",
137 | "csv_string = \"\"\"\n",
138 | "year,a,b\n",
139 | "2000,5,6\\n\n",
140 | "2001,6,7\\n\n",
141 | "\"\"\"\n",
142 | "\n",
143 | "params = CSVParams()\n",
144 | "params.adjust(csv_string)\n"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 4,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "data": {
154 | "text/plain": [
155 | "[OrderedDict([('year', 2000), ('value', 5)]),\n",
156 | " OrderedDict([('year', 2001), ('value', 6)])]"
157 | ]
158 | },
159 | "execution_count": 4,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "params.a"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 5,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "[OrderedDict([('year', 2000), ('value', 6)]),\n",
177 | " OrderedDict([('year', 2001), ('value', 7)])]"
178 | ]
179 | },
180 | "execution_count": 5,
181 | "metadata": {},
182 | "output_type": "execute_result"
183 | }
184 | ],
185 | "source": [
186 | "params.b"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "Now, if we use `array_first` and [`label_to_extend`](/api/extend/), the params instance can be loaded into a Pandas\n",
194 | "DataFrame like this:\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 6,
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/html": [
205 | "\n",
206 | "\n",
219 | "
\n",
220 | " \n",
221 | " \n",
222 | " | \n",
223 | " a | \n",
224 | " b | \n",
225 | "
\n",
226 | " \n",
227 | " \n",
228 | " \n",
229 | " 0 | \n",
230 | " 5 | \n",
231 | " 6 | \n",
232 | "
\n",
233 | " \n",
234 | " 1 | \n",
235 | " 6 | \n",
236 | " 7 | \n",
237 | "
\n",
238 | " \n",
239 | " 2 | \n",
240 | " 6 | \n",
241 | " 7 | \n",
242 | "
\n",
243 | " \n",
244 | " 3 | \n",
245 | " 6 | \n",
246 | " 7 | \n",
247 | "
\n",
248 | " \n",
249 | " 4 | \n",
250 | " 6 | \n",
251 | " 7 | \n",
252 | "
\n",
253 | " \n",
254 | " 5 | \n",
255 | " 6 | \n",
256 | " 7 | \n",
257 | "
\n",
258 | " \n",
259 | "
\n",
260 | "
"
261 | ],
262 | "text/plain": [
263 | " a b\n",
264 | "0 5 6\n",
265 | "1 6 7\n",
266 | "2 6 7\n",
267 | "3 6 7\n",
268 | "4 6 7\n",
269 | "5 6 7"
270 | ]
271 | },
272 | "execution_count": 6,
273 | "metadata": {},
274 | "output_type": "execute_result"
275 | }
276 | ],
277 | "source": [
278 | "csv_string = \"\"\"\n",
279 | "year,a,b\n",
280 | "2000,5,6\\n\n",
281 | "2001,6,7\\n\n",
282 | "\"\"\"\n",
283 | "\n",
284 | "params = CSVParams(array_first=True, label_to_extend=\"year\")\n",
285 | "params.adjust(csv_string)\n",
286 | "\n",
287 | "params_df = pd.DataFrame.from_dict(params.to_dict())\n",
288 | "params_df\n"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 7,
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/html": [
299 | "\n",
300 | "\n",
313 | "
\n",
314 | " \n",
315 | " \n",
316 | " | \n",
317 | " a | \n",
318 | " b | \n",
319 | "
\n",
320 | " \n",
321 | " year | \n",
322 | " | \n",
323 | " | \n",
324 | "
\n",
325 | " \n",
326 | " \n",
327 | " \n",
328 | " 2000 | \n",
329 | " 5 | \n",
330 | " 6 | \n",
331 | "
\n",
332 | " \n",
333 | " 2001 | \n",
334 | " 6 | \n",
335 | " 7 | \n",
336 | "
\n",
337 | " \n",
338 | " 2002 | \n",
339 | " 6 | \n",
340 | " 7 | \n",
341 | "
\n",
342 | " \n",
343 | " 2003 | \n",
344 | " 6 | \n",
345 | " 7 | \n",
346 | "
\n",
347 | " \n",
348 | " 2004 | \n",
349 | " 6 | \n",
350 | " 7 | \n",
351 | "
\n",
352 | " \n",
353 | " 2005 | \n",
354 | " 6 | \n",
355 | " 7 | \n",
356 | "
\n",
357 | " \n",
358 | "
\n",
359 | "
"
360 | ],
361 | "text/plain": [
362 | " a b\n",
363 | "year \n",
364 | "2000 5 6\n",
365 | "2001 6 7\n",
366 | "2002 6 7\n",
367 | "2003 6 7\n",
368 | "2004 6 7\n",
369 | "2005 6 7"
370 | ]
371 | },
372 | "execution_count": 7,
373 | "metadata": {},
374 | "output_type": "execute_result"
375 | }
376 | ],
377 | "source": [
378 | "params_df[\"year\"] = params.label_grid[\"year\"]\n",
379 | "params_df.set_index(\"year\")"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "metadata": {},
386 | "outputs": [],
387 | "source": []
388 | }
389 | ],
390 | "metadata": {
391 | "kernelspec": {
392 | "display_name": "Python 3",
393 | "language": "python",
394 | "name": "python3"
395 | },
396 | "language_info": {
397 | "codemirror_mode": {
398 | "name": "ipython",
399 | "version": 3
400 | },
401 | "file_extension": ".py",
402 | "mimetype": "text/x-python",
403 | "name": "python",
404 | "nbconvert_exporter": "python",
405 | "pygments_lexer": "ipython3",
406 | "version": "3.8.5"
407 | }
408 | },
409 | "nbformat": 4,
410 | "nbformat_minor": 4
411 | }
412 |
--------------------------------------------------------------------------------
/docs/api/custom-types.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Custom Types\n",
8 | "\n",
9 | "Often, the behavior for a field needs to be customized to support a particular shape or validation method that ParamTools does not support out of the box. In this case, you may use the `register_custom_type` function to add your new `type` to the ParamTools type registry. Each `type` has a corresponding `field` that is used for serialization and deserialization. ParamTools will then use this `field` any time it is handling a `value`, `label`, or `member` that is of this `type`.\n",
10 | "\n",
11 | "ParamTools is built on top of [`marshmallow`](https://github.com/marshmallow-code/marshmallow), a general purpose validation library. This means that you must implement a custom `marshmallow` field to go along with your new type. Please refer to the `marshmallow` [docs](https://marshmallow.readthedocs.io/en/stable/) if you have questions about the use of `marshmallow` in the examples below.\n",
12 | "\n",
13 | "\n",
14 | "## 32 Bit Integer Example\n",
15 | "\n",
16 | "ParamTools's default integer field uses NumPy's `int64` type. This example shows you how to define an `int32` type and reference it in your `defaults`.\n",
17 | "\n",
18 | "First, let's define the Marshmallow class:\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import marshmallow as ma\n",
28 | "import numpy as np\n",
29 | "\n",
30 | "class Int32(ma.fields.Field):\n",
31 | " \"\"\"\n",
32 | " A custom type for np.int32.\n",
33 | " https://numpy.org/devdocs/reference/arrays.dtypes.html\n",
34 | " \"\"\"\n",
35 | " # minor detail that makes this play nice with array_first\n",
36 | " np_type = np.int32\n",
37 | "\n",
38 | " def _serialize(self, value, *args, **kwargs):\n",
39 | " \"\"\"Convert np.int32 to basic, serializable Python int.\"\"\"\n",
40 | " return value.tolist()\n",
41 | "\n",
42 | " def _deserialize(self, value, *args, **kwargs):\n",
43 | " \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n",
44 | " converted = np.int32(value)\n",
45 | " return converted\n"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "Now, reference it in our defaults JSON/dict object:\n"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stdout",
62 | "output_type": "stream",
63 | "text": [
64 | "value: 2, type: \n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "import paramtools as pt\n",
70 | "\n",
71 | "\n",
72 | "# add int32 type to the paramtools type registry\n",
73 | "pt.register_custom_type(\"int32\", Int32())\n",
74 | "\n",
75 | "\n",
76 | "class Params(pt.Parameters):\n",
77 | " defaults = {\n",
78 | " \"small_int\": {\n",
79 | " \"title\": \"Small integer\",\n",
80 | " \"description\": \"Demonstrate how to define a custom type\",\n",
81 | " \"type\": \"int32\",\n",
82 | " \"value\": 2\n",
83 | " }\n",
84 | " }\n",
85 | "\n",
86 | "\n",
87 | "params = Params(array_first=True)\n",
88 | "\n",
89 | "\n",
90 | "print(f\"value: {params.small_int}, type: {type(params.small_int)}\")\n",
91 | "\n"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "One problem with this is that we could run into some deserialization issues. Due to integer overflow, our deserialized result is not the number that we passed in--it's negative!\n"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 3,
104 | "metadata": {},
105 | "outputs": [
106 | {
107 | "data": {
108 | "text/plain": [
109 | "OrderedDict([('small_int', [OrderedDict([('value', -2147483648)])])])"
110 | ]
111 | },
112 | "execution_count": 3,
113 | "metadata": {},
114 | "output_type": "execute_result"
115 | }
116 | ],
117 | "source": [
118 | "params.adjust(dict(\n",
119 | " # this number wasn't chosen randomly.\n",
120 | " small_int=2147483647 + 1\n",
121 | "))\n"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "### Marshmallow Validator\n",
129 | "\n",
130 | "Fortunately, you can specify a custom validator with `marshmallow` or ParamTools. Making this work requires modifying the `_deserialize` method to check for overflow like this:\n"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 4,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "class Int32(ma.fields.Field):\n",
140 | " \"\"\"\n",
141 | " A custom type for np.int32.\n",
142 | " https://numpy.org/devdocs/reference/arrays.dtypes.html\n",
143 | " \"\"\"\n",
144 | " # minor detail that makes this play nice with array_first\n",
145 | " np_type = np.int32\n",
146 | "\n",
147 | " def _serialize(self, value, *args, **kwargs):\n",
148 | " \"\"\"Convert np.int32 to basic Python int.\"\"\"\n",
149 | " return value.tolist()\n",
150 | "\n",
151 | " def _deserialize(self, value, *args, **kwargs):\n",
152 | " \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n",
153 | " converted = np.int32(value)\n",
154 | "\n",
155 | " # check for overflow and let range validator\n",
156 | " # display the error message.\n",
157 | " if converted != int(value):\n",
158 | " return int(value)\n",
159 | "\n",
160 | " return converted\n"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "Now, let's see how to use a `marshmallow` validator to fix this problem:\n"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 9,
173 | "metadata": {},
174 | "outputs": [
175 | {
176 | "ename": "ValidationError",
177 | "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}",
178 | "output_type": "error",
179 | "traceback": [
180 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
181 | "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
182 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint64\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_int32\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m ))\n",
183 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
184 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
185 | "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}"
186 | ]
187 | }
188 | ],
189 | "source": [
190 | "import marshmallow as ma\n",
191 | "import paramtools as pt\n",
192 | "\n",
193 | "\n",
194 | "# get the minimum and maxium values for 32 bit integers.\n",
195 | "min_int32 = -2147483648 # = np.iinfo(np.int32).min\n",
196 | "max_int32 = 2147483647 # = np.iinfo(np.int32).max\n",
197 | "\n",
198 | "# add int32 type to the paramtools type registry\n",
199 | "pt.register_custom_type(\n",
200 | " \"int32\",\n",
201 | " Int32(validate=[\n",
202 | " ma.validate.Range(min=min_int32, max=max_int32)\n",
203 | " ])\n",
204 | ")\n",
205 | "\n",
206 | "\n",
207 | "class Params(pt.Parameters):\n",
208 | " defaults = {\n",
209 | " \"small_int\": {\n",
210 | " \"title\": \"Small integer\",\n",
211 | " \"description\": \"Demonstrate how to define a custom type\",\n",
212 | " \"type\": \"int32\",\n",
213 | " \"value\": 2\n",
214 | " }\n",
215 | " }\n",
216 | "\n",
217 | "\n",
218 | "params = Params(array_first=True)\n",
219 | "\n",
220 | "params.adjust(dict(\n",
221 | " small_int=np.int64(max_int32) + 1\n",
222 | "))\n"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {},
228 | "source": [
229 | "### ParamTools Validator\n",
230 | "\n",
231 | "Finally, we will use ParamTools to solve this problem. We need to modify how we create our custom `marshmallow` field so that it's wrapped by ParamTools's `PartialField`. This makes it clear that your field still needs to be initialized, and that your custom field is able to receive validation information from the `defaults` configuration:\n"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 7,
237 | "metadata": {},
238 | "outputs": [
239 | {
240 | "ename": "ValidationError",
241 | "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}",
242 | "output_type": "error",
243 | "traceback": [
244 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
245 | "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
246 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2147483647\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m ))\n",
247 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
248 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
249 | "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}"
250 | ]
251 | }
252 | ],
253 | "source": [
254 | "import paramtools as pt\n",
255 | "\n",
256 | "\n",
257 | "# add int32 type to the paramtools type registry\n",
258 | "pt.register_custom_type(\n",
259 | " \"int32\",\n",
260 | " pt.PartialField(Int32)\n",
261 | ")\n",
262 | "\n",
263 | "\n",
264 | "class Params(pt.Parameters):\n",
265 | " defaults = {\n",
266 | " \"small_int\": {\n",
267 | " \"title\": \"Small integer\",\n",
268 | " \"description\": \"Demonstrate how to define a custom type\",\n",
269 | " \"type\": \"int32\",\n",
270 | " \"value\": 2,\n",
271 | " \"validators\": {\n",
272 | " \"range\": {\"min\": -2147483648, \"max\": 2147483647}\n",
273 | " }\n",
274 | " }\n",
275 | " }\n",
276 | "\n",
277 | "\n",
278 | "params = Params(array_first=True)\n",
279 | "\n",
280 | "params.adjust(dict(\n",
281 | " small_int=2147483647 + 1\n",
282 | "))\n",
283 | "\n"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {},
290 | "outputs": [],
291 | "source": []
292 | }
293 | ],
294 | "metadata": {
295 | "kernelspec": {
296 | "display_name": "Python 3",
297 | "language": "python",
298 | "name": "python3"
299 | },
300 | "language_info": {
301 | "codemirror_mode": {
302 | "name": "ipython",
303 | "version": 3
304 | },
305 | "file_extension": ".py",
306 | "mimetype": "text/x-python",
307 | "name": "python",
308 | "nbconvert_exporter": "python",
309 | "pygments_lexer": "ipython3",
310 | "version": "3.8.5"
311 | }
312 | },
313 | "nbformat": 4,
314 | "nbformat_minor": 4
315 | }
316 |
--------------------------------------------------------------------------------
/docs/api/extend.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Extend\n",
8 | "\n",
9 | "The values of a parameter can be extended along a specified label. This is helpful when a parameter's values are the same for different values of a label and there is some inherent order in that label. The extend feature allows you to simply write down the minimum amount of information needed to fill in a parameter's values and ParamTools will fill in the gaps.\n",
10 | "\n",
11 | "To use the extend feature, set the `label_to_extend` class attribute to the label that should be extended.\n",
12 | "\n",
13 | "## Example\n",
14 | "\n",
15 | "The standard deduction parameter's values only need to be specified when there is a change in the tax law. For the other years, it does not change (unless it's indexed to inflation). It would be annoying to have to manually write out each of its values. Instead, we can more concisely write its values in 2017, its new values in 2018 after the TCJA tax reform was passed, and its values after provisions of the TCJA are phased out in 2026.\n"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "data": {
25 | "text/plain": [
26 | "array([[ 6350., 12700.],\n",
27 | " [ 6350., 12700.],\n",
28 | " [ 6350., 12700.],\n",
29 | " [ 6350., 12700.],\n",
30 | " [ 6350., 12700.],\n",
31 | " [12000., 24000.],\n",
32 | " [12000., 24000.],\n",
33 | " [12000., 24000.],\n",
34 | " [12000., 24000.],\n",
35 | " [12000., 24000.],\n",
36 | " [12000., 24000.],\n",
37 | " [12000., 24000.],\n",
38 | " [12000., 24000.],\n",
39 | " [ 7685., 15369.],\n",
40 | " [ 7685., 15369.]])"
41 | ]
42 | },
43 | "execution_count": 1,
44 | "metadata": {},
45 | "output_type": "execute_result"
46 | }
47 | ],
48 | "source": [
49 | "import paramtools\n",
50 | "\n",
51 | "\n",
52 | "class TaxParams(paramtools.Parameters):\n",
53 | " defaults = {\n",
54 | " \"schema\": {\n",
55 | " \"labels\": {\n",
56 | " \"year\": {\n",
57 | " \"type\": \"int\",\n",
58 | " \"validators\": {\"range\": {\"min\": 2013, \"max\": 2027}}\n",
59 | " },\n",
60 | " \"marital_status\": {\n",
61 | " \"type\": \"str\",\n",
62 | " \"validators\": {\"choice\": {\"choices\": [\"single\", \"joint\"]}}\n",
63 | " },\n",
64 | " }\n",
65 | " },\n",
66 | " \"standard_deduction\": {\n",
67 | " \"title\": \"Standard deduction amount\",\n",
68 | " \"description\": \"Amount filing unit can use as a standard deduction.\",\n",
69 | " \"type\": \"float\",\n",
70 | " \"value\": [\n",
71 | " {\"year\": 2017, \"marital_status\": \"single\", \"value\": 6350},\n",
72 | " {\"year\": 2017, \"marital_status\": \"joint\", \"value\": 12700},\n",
73 | " {\"year\": 2018, \"marital_status\": \"single\", \"value\": 12000},\n",
74 | " {\"year\": 2018, \"marital_status\": \"joint\", \"value\": 24000},\n",
75 | " {\"year\": 2026, \"marital_status\": \"single\", \"value\": 7685},\n",
76 | " {\"year\": 2026, \"marital_status\": \"joint\", \"value\": 15369}],\n",
77 | " \"validators\": {\n",
78 | " \"range\": {\n",
79 | " \"min\": 0,\n",
80 | " \"max\": 9e+99\n",
81 | " }\n",
82 | " }\n",
83 | " },\n",
84 | " }\n",
85 | "\n",
86 | " label_to_extend = \"year\"\n",
87 | " array_first = True\n",
88 | "\n",
89 | "params = TaxParams()\n",
90 | "params.standard_deduction\n"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "Adjustments are also extended along `label_to_extend`. In the example below, `standard_deduction` is set to 10,000 in 2017, increased to 15,000 for single tax units in 2020, and increased to 20,000 for joint tax units in 2021:\n"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 2,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "text/plain": [
108 | "array([[ 6350., 12700.],\n",
109 | " [ 6350., 12700.],\n",
110 | " [ 6350., 12700.],\n",
111 | " [ 6350., 12700.],\n",
112 | " [10000., 10000.],\n",
113 | " [10000., 10000.],\n",
114 | " [10000., 10000.],\n",
115 | " [15000., 10000.],\n",
116 | " [15000., 20000.],\n",
117 | " [15000., 20000.],\n",
118 | " [15000., 20000.],\n",
119 | " [15000., 20000.],\n",
120 | " [15000., 20000.],\n",
121 | " [15000., 20000.],\n",
122 | " [15000., 20000.]])"
123 | ]
124 | },
125 | "execution_count": 2,
126 | "metadata": {},
127 | "output_type": "execute_result"
128 | }
129 | ],
130 | "source": [
131 | "params.adjust(\n",
132 | " {\n",
133 | " \"standard_deduction\": [\n",
134 | " {\"year\": 2017, \"value\": 10000},\n",
135 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n",
136 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n",
137 | " ]\n",
138 | " }\n",
139 | ")\n",
140 | "\n",
141 | "params.standard_deduction\n"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "### Clobber\n",
149 | "\n",
150 | "In the previous example, the new values _clobber_ the existing values in years after they are specified. By setting `clobber` to `False`, only values that were added automatically will be replaced by the new ones. User-defined values such as those in 2026 will not be overwritten by the new values:\n"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 3,
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "data": {
160 | "text/plain": [
161 | "array([[ 6350., 12700.],\n",
162 | " [ 6350., 12700.],\n",
163 | " [ 6350., 12700.],\n",
164 | " [ 6350., 12700.],\n",
165 | " [10000., 10000.],\n",
166 | " [12000., 24000.],\n",
167 | " [12000., 24000.],\n",
168 | " [15000., 24000.],\n",
169 | " [15000., 20000.],\n",
170 | " [15000., 20000.],\n",
171 | " [15000., 20000.],\n",
172 | " [15000., 20000.],\n",
173 | " [15000., 20000.],\n",
174 | " [ 7685., 15369.],\n",
175 | " [ 7685., 15369.]])"
176 | ]
177 | },
178 | "execution_count": 3,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "params = TaxParams()\n",
185 | "params.adjust(\n",
186 | " {\n",
187 | " \"standard_deduction\": [\n",
188 | " {\"year\": 2017, \"value\": 10000},\n",
189 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n",
190 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n",
191 | " ]\n",
192 | " },\n",
193 | " clobber=False,\n",
194 | ")\n",
195 | "\n",
196 | "params.standard_deduction\n"
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "metadata": {},
202 | "source": [
203 | "## Extend behavior by validator\n",
204 | "\n",
205 | "ParamTools uses the validator associated with `label_to_extend` to determine how values should be extended by assuming that there is some order among the range of possible values for the label.\n",
206 | "\n",
207 | "Note: You can view the grid of values for any label by inspecting the `label_grid` attribute of a `paramtools.Parameters` derived instance.\n",
208 | "\n",
209 | "### Range\n",
210 | "\n",
211 | "**Type:** `int`\n",
212 | "\n",
213 | "```json\n",
214 | "{\n",
215 | " \"range\": { \"min\": 0, \"max\": 5 }\n",
216 | "}\n",
217 | "```\n",
218 | "\n",
219 | "_Extend values:_\n",
220 | "\n",
221 | "```python\n",
222 | "[0, 1, 2, 3, 4, 5]\n",
223 | "```\n",
224 | "\n",
225 | "**Type:** `float`\n",
226 | "\n",
227 | "```json\n",
228 | "{\n",
229 | " \"range\": { \"min\": 0, \"max\": 2, \"step\": 0.5 }\n",
230 | "}\n",
231 | "```\n",
232 | "\n",
233 | "_Extend values:_\n",
234 | "\n",
235 | "```python\n",
236 | "[0, 0.5, 1.0, 1.5, 2.0]\n",
237 | "```\n",
238 | "\n",
239 | "**Type:** `date`\n",
240 | "\n",
241 | "```json\n",
242 | "{\n",
243 | " \"range\": { \"min\": \"2019-01-01\", \"max\": \"2019-01-05\", \"step\": { \"days\": 2 } }\n",
244 | "}\n",
245 | "```\n",
246 | "\n",
247 | "_Extend values:_\n",
248 | "\n",
249 | "```python\n",
250 | "[datetime.date(2019, 1, 1),\n",
251 | " datetime.date(2019, 1, 3),\n",
252 | " datetime.date(2019, 1, 5)]\n",
253 | "```\n",
254 | "\n",
255 | "### Choice\n",
256 | "\n",
257 | "**Type:** `int`\n",
258 | "\n",
259 | "```json\n",
260 | "{\n",
261 | " \"choice\": { \"choices\": [-1, -2, -3] }\n",
262 | "}\n",
263 | "```\n",
264 | "\n",
265 | "_Extend values:_\n",
266 | "\n",
267 | "```python\n",
268 | "[-1, -2, -3]\n",
269 | "```\n",
270 | "\n",
271 | "**Type:** `str`\n",
272 | "\n",
273 | "```json\n",
274 | "{\n",
275 | " \"choice\": { \"choices\": [\"january\", \"february\", \"march\"] }\n",
276 | "}\n",
277 | "```\n",
278 | "\n",
279 | "_Extend values:_\n",
280 | "\n",
281 | "```python\n",
282 | "[\"january\", \"february\", \"march\"]\n",
283 | "```\n"
284 | ]
285 | }
286 | ],
287 | "metadata": {
288 | "kernelspec": {
289 | "display_name": "Python 3",
290 | "language": "python",
291 | "name": "python3"
292 | },
293 | "language_info": {
294 | "codemirror_mode": {
295 | "name": "ipython",
296 | "version": 3
297 | },
298 | "file_extension": ".py",
299 | "mimetype": "text/x-python",
300 | "name": "python",
301 | "nbconvert_exporter": "python",
302 | "pygments_lexer": "ipython3",
303 | "version": "3.8.5"
304 | }
305 | },
306 | "nbformat": 4,
307 | "nbformat_minor": 4
308 | }
309 |
--------------------------------------------------------------------------------
/docs/api/guide.md:
--------------------------------------------------------------------------------
1 | # Guide
2 |
3 | **Just getting started?**
4 | Check out the [home page](/intro/) for a quick start guide.
5 |
6 | **Want to learn more about the JSON spec?**
7 | Check out the [parameters spec](/parameters/) to learn more about what types of parameters and validators are supported by ParamTools.
8 |
9 | ## API
10 |
11 | [**Viewing data:**](/api/viewing-data/) Working with parameter values.
12 |
13 | [**Extend:**](/api/extend/) Write more concise JSON schemas with the extend capability.
14 |
15 | [**Extend with Indexing:**](/api/indexing/) Index parameter values.
16 |
17 | [**Custom Adjustment Formats and Logic:**](/api/custom-adjust/) Customize the adjustment format and logic to meet your project's needs.
18 |
19 | [**Custom types:**](/api/custom-types/) Add custom types for your parameters' values, labels, and members.
20 |
--------------------------------------------------------------------------------
/docs/api/indexing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Extend with Indexing\n",
8 | "\n",
9 | "ParamTools provides out-of-the-box parameter indexing. This is helpful for projects that have parameters that change at some rate over time. For example, tax parameters like the standard deduction are often indexed to price inflation. So, the value of the standard deduction actually increases every year by 1 or 2% depending on that year's inflation rate.\n",
10 | "\n",
11 | "The [extend documentation](/api/extend/) may be useful for gaining a better understanding of how ParamTools extends parameter values along `label_to_extend`.\n",
12 | "\n",
13 | "To use the indexing feature:\n",
14 | "\n",
15 | "- Set the `label_to_extend` class attribute to the label that should be extended\n",
16 | "- Set the `index_rates` class attribute to a dictionary of inflation rates where the keys correspond to the value of `label_to_extend` and the values are the indexing rates.\n",
17 | "- Set the `uses_extend_func` class attribute to `True`.\n",
18 | "- In `defaults` or `defaults.json`, set `indexed` to `True` for each parameter that needs to be indexed.\n",
19 | "\n",
20 | "## Example\n",
21 | "\n",
22 | "This is a continuation of the tax parameters example from the [extend documentation](/api/extend/). The differences are that `indexed` is set to `True` for the `standard_deduction` parameter, `uses_extend_func` is set to `True`, and `index_rates` is specified with inflation rates obtained from the open-source tax modeling package, [Tax-Calculator](https://github.com/PSLmodels/Tax-Calculator/), using version 2.5.0.\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "data": {
32 | "text/plain": [
33 | "array([[ 6074.92, 12149.84],\n",
34 | " [ 6164.83, 12329.66],\n",
35 | " [ 6262.85, 12525.7 ],\n",
36 | " [ 6270.37, 12540.73],\n",
37 | " [ 6350. , 12700. ],\n",
38 | " [12000. , 24000. ],\n",
39 | " [12268.8 , 24537.6 ],\n",
40 | " [12497. , 24994. ],\n",
41 | " [12788.18, 25576.36],\n",
42 | " [13081.03, 26162.06],\n",
43 | " [13379.28, 26758.55],\n",
44 | " [13674.96, 27349.91],\n",
45 | " [13963.5 , 27926.99],\n",
46 | " [ 7685. , 15369. ],\n",
47 | " [ 7847.15, 15693.29]])"
48 | ]
49 | },
50 | "execution_count": 1,
51 | "metadata": {},
52 | "output_type": "execute_result"
53 | }
54 | ],
55 | "source": [
56 | "import paramtools\n",
57 | "\n",
58 | "\n",
59 | "class TaxParams(paramtools.Parameters):\n",
60 | " defaults = {\n",
61 | " \"schema\": {\n",
62 | " \"labels\": {\n",
63 | " \"year\": {\n",
64 | " \"type\": \"int\",\n",
65 | " \"validators\": {\"range\": {\"min\": 2013, \"max\": 2027}}\n",
66 | " },\n",
67 | " \"marital_status\": {\n",
68 | " \"type\": \"str\",\n",
69 | " \"validators\": {\"choice\": {\"choices\": [\"single\", \"joint\"]}}\n",
70 | " },\n",
71 | " }\n",
72 | " },\n",
73 | " \"standard_deduction\": {\n",
74 | " \"title\": \"Standard deduction amount\",\n",
75 | " \"description\": \"Amount filing unit can use as a standard deduction.\",\n",
76 | " \"type\": \"float\",\n",
77 | "\n",
78 | " # Set indexed to True to extend standard_deduction with the built-in\n",
79 | " # extension logic.\n",
80 | " \"indexed\": True,\n",
81 | "\n",
82 | " \"value\": [\n",
83 | " {\"year\": 2017, \"marital_status\": \"single\", \"value\": 6350},\n",
84 | " {\"year\": 2017, \"marital_status\": \"joint\", \"value\": 12700},\n",
85 | " {\"year\": 2018, \"marital_status\": \"single\", \"value\": 12000},\n",
86 | " {\"year\": 2018, \"marital_status\": \"joint\", \"value\": 24000},\n",
87 | " {\"year\": 2026, \"marital_status\": \"single\", \"value\": 7685},\n",
88 | " {\"year\": 2026, \"marital_status\": \"joint\", \"value\": 15369}],\n",
89 | " \"validators\": {\n",
90 | " \"range\": {\n",
91 | " \"min\": 0,\n",
92 | " \"max\": 9e+99\n",
93 | " }\n",
94 | " }\n",
95 | " },\n",
96 | " }\n",
97 | " array_first = True\n",
98 | " label_to_extend = \"year\"\n",
99 | " # Activate use of extend_func method.\n",
100 | " uses_extend_func = True\n",
101 | " # inflation rates from Tax-Calculator v2.5.0\n",
102 | " index_rates = {\n",
103 | " 2013: 0.0148,\n",
104 | " 2014: 0.0159,\n",
105 | " 2015: 0.0012,\n",
106 | " 2016: 0.0127,\n",
107 | " 2017: 0.0187,\n",
108 | " 2018: 0.0224,\n",
109 | " 2019: 0.0186,\n",
110 | " 2020: 0.0233,\n",
111 | " 2021: 0.0229,\n",
112 | " 2022: 0.0228,\n",
113 | " 2023: 0.0221,\n",
114 | " 2024: 0.0211,\n",
115 | " 2025: 0.0209,\n",
116 | " 2026: 0.0211,\n",
117 | " 2027: 0.0208,\n",
118 | " 2028: 0.021,\n",
119 | " 2029: 0.021\n",
120 | " }\n",
121 | "\n",
122 | "\n",
123 | "params = TaxParams()\n",
124 | "params.standard_deduction\n"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "Adjustments are also indexed. In the example below, `standard_deduction` is set to 10,000 in 2017, increased to 15,000 for single tax units in 2020, and increased to 20,000 for joint tax units in 2021:\n"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 2,
137 | "metadata": {},
138 | "outputs": [
139 | {
140 | "data": {
141 | "text/plain": [
142 | "array([[ 6074.92, 12149.84],\n",
143 | " [ 6164.83, 12329.66],\n",
144 | " [ 6262.85, 12525.7 ],\n",
145 | " [ 6270.37, 12540.73],\n",
146 | " [10000. , 10000. ],\n",
147 | " [10187. , 10187. ],\n",
148 | " [10415.19, 10415.19],\n",
149 | " [15000. , 10608.91],\n",
150 | " [15349.5 , 20000. ],\n",
151 | " [15701. , 20458. ],\n",
152 | " [16058.98, 20924.44],\n",
153 | " [16413.88, 21386.87],\n",
154 | " [16760.21, 21838.13],\n",
155 | " [17110.5 , 22294.55],\n",
156 | " [17471.53, 22764.97]])"
157 | ]
158 | },
159 | "execution_count": 2,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "params.adjust(\n",
166 | " {\n",
167 | " \"standard_deduction\": [\n",
168 | " {\"year\": 2017, \"value\": 10000},\n",
169 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n",
170 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n",
171 | " ]\n",
172 | " }\n",
173 | ")\n",
174 | "\n",
175 | "params.standard_deduction\n"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "All values that are added automatically via the `extend` method are given an `_auto` attribute. You can select them like this:\n"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 3,
188 | "metadata": {},
189 | "outputs": [
190 | {
191 | "data": {
192 | "text/plain": [
193 | "[{'year': 2013, 'marital_status': 'single', 'value': 6074.92, '_auto': True},\n",
194 | " {'year': 2014, 'marital_status': 'single', 'value': 6164.83, '_auto': True},\n",
195 | " {'year': 2015, 'marital_status': 'single', 'value': 6262.85, '_auto': True},\n",
196 | " {'year': 2016, 'marital_status': 'single', 'value': 6270.37, '_auto': True},\n",
197 | " {'year': 2019, 'marital_status': 'single', 'value': 12268.8, '_auto': True},\n",
198 | " {'year': 2020, 'marital_status': 'single', 'value': 12497.0, '_auto': True},\n",
199 | " {'year': 2021, 'marital_status': 'single', 'value': 12788.18, '_auto': True},\n",
200 | " {'year': 2022, 'marital_status': 'single', 'value': 13081.03, '_auto': True},\n",
201 | " {'year': 2023, 'marital_status': 'single', 'value': 13379.28, '_auto': True},\n",
202 | " {'year': 2024, 'marital_status': 'single', 'value': 13674.96, '_auto': True},\n",
203 | " {'year': 2025, 'marital_status': 'single', 'value': 13963.5, '_auto': True},\n",
204 | " {'year': 2027, 'marital_status': 'single', 'value': 7847.15, '_auto': True},\n",
205 | " {'year': 2013, 'marital_status': 'joint', 'value': 12149.84, '_auto': True},\n",
206 | " {'year': 2014, 'marital_status': 'joint', 'value': 12329.66, '_auto': True},\n",
207 | " {'year': 2015, 'marital_status': 'joint', 'value': 12525.7, '_auto': True},\n",
208 | " {'year': 2016, 'marital_status': 'joint', 'value': 12540.73, '_auto': True},\n",
209 | " {'year': 2019, 'marital_status': 'joint', 'value': 24537.6, '_auto': True},\n",
210 | " {'year': 2020, 'marital_status': 'joint', 'value': 24994.0, '_auto': True},\n",
211 | " {'year': 2021, 'marital_status': 'joint', 'value': 25576.36, '_auto': True},\n",
212 | " {'year': 2022, 'marital_status': 'joint', 'value': 26162.06, '_auto': True},\n",
213 | " {'year': 2023, 'marital_status': 'joint', 'value': 26758.55, '_auto': True},\n",
214 | " {'year': 2024, 'marital_status': 'joint', 'value': 27349.91, '_auto': True},\n",
215 | " {'year': 2025, 'marital_status': 'joint', 'value': 27926.99, '_auto': True},\n",
216 | " {'year': 2027, 'marital_status': 'joint', 'value': 15693.29, '_auto': True}]"
217 | ]
218 | },
219 | "execution_count": 3,
220 | "metadata": {},
221 | "output_type": "execute_result"
222 | }
223 | ],
224 | "source": [
225 | "params = TaxParams()\n",
226 | "\n",
227 | "params.select_eq(\n",
228 | " \"standard_deduction\", strict=True, _auto=True\n",
229 | ")\n"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {},
235 | "source": [
236 | "If you want to update the index rates and apply them to your existing values, then all you need to do is remove the values that were added automatically. ParamTools will fill in the missing values using the updated index rates:\n"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 4,
242 | "metadata": {
243 | "scrolled": false
244 | },
245 | "outputs": [
246 | {
247 | "data": {
248 | "text/plain": [
249 | "array([[ 6015.22, 12030.41],\n",
250 | " [ 6119.28, 12238.54],\n",
251 | " [ 6231.87, 12463.73],\n",
252 | " [ 6254.93, 12509.85],\n",
253 | " [ 6350. , 12700. ],\n",
254 | " [12000. , 24000. ],\n",
255 | " [12298.8 , 24597.6 ],\n",
256 | " [12558.3 , 25116.61],\n",
257 | " [12882.3 , 25764.62],\n",
258 | " [13209.51, 26419.04],\n",
259 | " [13543.71, 27087.44],\n",
260 | " [13876.89, 27753.79],\n",
261 | " [14204.38, 28408.78],\n",
262 | " [ 7685. , 15369. ],\n",
263 | " [ 7866.37, 15731.71]])"
264 | ]
265 | },
266 | "execution_count": 4,
267 | "metadata": {},
268 | "output_type": "execute_result"
269 | }
270 | ],
271 | "source": [
272 | "params = TaxParams()\n",
273 | "\n",
274 | "offset = 0.0025\n",
275 | "for year, rate in params.index_rates.items():\n",
276 | " params.index_rates[year] = rate + offset\n",
277 | "\n",
278 | "automatically_added = params.select_eq(\n",
279 | " \"standard_deduction\", strict=True, _auto=True\n",
280 | ")\n",
281 | "\n",
282 | "params.delete(\n",
283 | " {\n",
284 | " \"standard_deduction\": automatically_added\n",
285 | " }\n",
286 | ")\n",
287 | "\n",
288 | "params.standard_deduction\n"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "metadata": {},
294 | "source": [
295 | "### Code for getting Tax-Calculator index rates\n",
296 | "\n",
297 | "```python\n",
298 | "import taxcalc\n",
299 | "pol = taxcalc.Policy()\n",
300 | "index_rates = {\n",
301 | " year: value\n",
302 | " for year, value in zip(list(range(2013, 2029 + 1)), pol.inflation_rates())\n",
303 | "}\n",
304 | "```\n",
305 | "\n",
306 | "Note that there are some subtle details that are implemented in Tax-Calculator's indexing logic that are not implemented in this example. Tax-Calculator has a parameter called `CPI_offset` that adjusts inflation rates up or down by a fixed amount. The `indexed` property can also be turned on and off for each parameter. Implementing these nuanced features is left as the proverbial \"trivial exercise to the reader.\"\n"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": []
315 | }
316 | ],
317 | "metadata": {
318 | "kernelspec": {
319 | "display_name": "Python 3",
320 | "language": "python",
321 | "name": "python3"
322 | },
323 | "language_info": {
324 | "codemirror_mode": {
325 | "name": "ipython",
326 | "version": 3
327 | },
328 | "file_extension": ".py",
329 | "mimetype": "text/x-python",
330 | "name": "python",
331 | "nbconvert_exporter": "python",
332 | "pygments_lexer": "ipython3",
333 | "version": "3.8.5"
334 | }
335 | },
336 | "nbformat": 4,
337 | "nbformat_minor": 4
338 | }
339 |
--------------------------------------------------------------------------------
/docs/api/reference.rst:
--------------------------------------------------------------------------------
1 | .. _reference:
2 |
3 | API Reference
4 | =================================================
5 |
6 | **parameters**
7 |
8 | Parameters
9 | ------------------------------------------
10 |
11 | .. currentmodule:: paramtools.parameters
12 |
13 | .. autoclass:: Parameters
14 | :members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values, validate, transaction
15 |
16 | Values
17 | ------------------------------------------
18 |
19 | .. currentmodule:: paramtools.values
20 |
21 | .. autoclass:: Values
22 | :members: eq, ne, gt, lt, lte, isin, isel, __and__, __or__
23 |
24 |
--------------------------------------------------------------------------------
/docs/parameters.md:
--------------------------------------------------------------------------------
1 | # Parameters
2 |
3 | Define your default parameters and let ParamTools handle the rest.
4 |
5 | The ParamTools JSON file is split into two components: a component that defines the structure of your default inputs and a component that defines the variables that are used in your model. The first component is a top level member named `schema`. The second component consists of key-value pairs where the key is the parameter's name and the value is its data.
6 |
7 | ```json
8 | {
9 | "schema": {
10 | "labels": {
11 | "year": {
12 | "type": "int",
13 | "validators": {"range": {"min": 2013, "max": 2027}}
14 | },
15 | "marital_status": {
16 | "type": "str",
17 | "validators": {"choice": {"choices": ["single", "joint", "separate",
18 | "headhousehold", "widow"]}}
19 | }
20 | },
21 | "additional_members": {
22 | "cpi_inflatable": {"type": "bool"},
23 | "cpi_inflated": {"type": "bool"}
24 | },
25 | "operators": {
26 | "array_first": true,
27 | "label_to_extend": "year",
28 | "uses_extend_func": true
29 | }
30 | },
31 | "personal_exemption": {
32 | "title": "Personal Exemption",
33 | "description": "A simple version of the personal exemption.",
34 | "cpi_inflatable": true,
35 | "cpi_inflated": true,
36 | "type": "float",
37 | "value": 0,
38 | "validators": {
39 | "range": {
40 | "min": 0
41 | }
42 | }
43 | },
44 | "standard_deduction": {
45 | "title": "Standard deduction amount",
46 | "description": "Amount filing unit can use as a standard deduction.",
47 | "cpi_inflatable": true,
48 | "cpi_inflated": true,
49 | "type": "float",
50 | "value": [
51 | {"year": 2024, "marital_status": "single", "value": 13673.68},
52 | {"year": 2024, "marital_status": "joint", "value": 27347.36},
53 | {"year": 2024, "marital_status": "separate", "value": 13673.68},
54 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52},
55 | {"year": 2024, "marital_status": "widow", "value": 27347.36},
56 | {"year": 2025, "marital_status": "single", "value": 13967.66},
57 | {"year": 2025, "marital_status": "joint", "value": 27935.33},
58 | {"year": 2025, "marital_status": "separate", "value": 13967.66},
59 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49},
60 | {"year": 2025, "marital_status": "widow", "value": 27935.33}],
61 | "validators": {
62 | "range": {
63 | "min": 0,
64 | "level": "warn"
65 | }
66 | }
67 | }
68 | }
69 | ```
70 |
71 |
72 |
73 |
74 |
75 | ## Parameters Schema
76 |
77 | ```json
78 | {
79 | "schema": {
80 | "labels": {
81 | "year": {
82 | "type": "int",
83 | "validators": {"range": {"min": 2013, "max": 2027}}
84 | }
85 | },
86 | "additional_members": {
87 | "cpi_inflatable": {"type": "bool"},
88 | "cpi_inflated": {"type": "bool"}
89 | },
90 | "operators": {
91 | "array_first": true,
92 | "label_to_extend": "year",
93 | "uses_extend_func": true
94 | }
95 | }
96 | }
97 | ```
98 |
99 | - `labels`: Labels are used for defining, accessing, and updating a parameter's values.
100 |
101 | - `additional_members`: Additional Members are parameter level members that are specific to your model. For example, "title" is a parameter level member that is required by ParamTools, but "cpi_inflated" is not. Therefore, "cpi_inflated" needs to be defined in `additional_members`.
102 |
103 | - `operators`: Operators affect how the data is read into and handled by the `Parameters` class:
104 |
105 | - `array_first`: If value is `true`, parameters' values will be accessed as arrays by default.
106 |
107 | - `label_to_extend`: The name of the label along which the missing values of the parameters will be extended. For more information, check out the [extend docs](/api/extend/).
108 |
109 | - `uses_extend_func`: If value is `true`, special logic is applied to the values of the parameters as they are extended. For more information, check out the [indexing docs](/api/indexing/).
110 |
111 |
112 | ## Default Parameters
113 |
114 | ```json
115 | {
116 | "standard_deduction": {
117 | "title": "Standard deduction amount",
118 | "description": "Amount filing unit can use as a standard deduction.",
119 | "cpi_inflatable": true,
120 | "cpi_inflated": true,
121 | "type": "float",
122 | "number_dims": 0,
123 | "value": [
124 | {"year": 2024, "marital_status": "single", "value": 13673.68},
125 | {"year": 2024, "marital_status": "joint", "value": 27347.36},
126 | {"year": 2024, "marital_status": "separate", "value": 13673.68},
127 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52},
128 | {"year": 2024, "marital_status": "widow", "value": 27347.36},
129 | {"year": 2025, "marital_status": "single", "value": 13967.66},
130 | {"year": 2025, "marital_status": "joint", "value": 27935.33},
131 | {"year": 2025, "marital_status": "separate", "value": 13967.66},
132 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49},
133 | {"year": 2025, "marital_status": "widow", "value": 27935.33}],
134 | "validators": {
135 | "range": {
136 | "min": 0,
137 | "max": 9e+99
138 | }
139 | }
140 | }
141 | }
142 | ```
143 |
144 | ### Members:
145 |
146 | - `title`: A human readable name for the parameter.
147 |
148 | - `description`: Describe the parameter.
149 |
150 | - `notes`: (*optional*) Additional advice or information.
151 |
152 | - `type`: Data type of the parameter. Allowed types are `int`, `float`, `bool`, `str` and `date` (YYYY-MM-DD).
153 |
154 | - `number_dims`: (*optional, default is 0*) Number of dimensions for the value, as defined by [`np.ndim`][1].
155 |
156 | - `value`: Value of the parameter and optionally, the corresponding labels. It can be written in two ways:
157 |
158 | - if labels are used: `{"value": [{"value": "my value", **labels}]}`
159 |
160 | - if labels are not used: `{"value": "my value"}`
161 |
162 | - `validators`: Key-value pairs of the validator objects (*the ranges are inclusive*):
163 | - `level`: All validators take a `level` argument which is either "error" or "warn". By default it is set to "error".
164 | - `when`:
165 | - `is` is set to `equal_to` by default but can also be `greater_than` or `less_than`.
166 | - e.g: `"is": {"greater_than": 0}`
167 |
168 | - If the sub-validators refer to the value of another parameter, then the other parameter
169 | must have `number_dims` equal to 0, i.e. be a scalar value and not an array value.
170 |
171 | ```json
172 | {
173 | "validators": {
174 | "range": {"min": "min value", "max": "max value", "level": "warn"},
175 | "choice": {"choices": ["list", "of", "allowed", "values"]},
176 | "date_range": {"min": "2018-01-01", "max": "2018-06-01"},
177 | "when": {
178 | "param": "other parameter",
179 | "is": "equal_value",
180 | "then": {
181 | "range": {
182 | "min": "min value if other parameter is 'equal_value'"
183 | }
184 | },
185 | "otherwise": {
186 | "range": {
187 | "min": "min value if other parameter is not 'equal_value'"
188 | }
189 | }
190 | }
191 | }
192 | }
193 | ```
194 |
195 | [1]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ndim.html
196 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: paramtools-dev
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - "marshmallow>=4.0.0"
6 | - "numpy>=2.1.0"
7 | - "python-dateutil>=2.8.0"
8 | - "pytest>=6.0.0"
9 | - pandas
10 | - fsspec
11 | - pyopenssl
12 | - s3fs # required for tests but not most usage.
13 | - gcsfs # required for tests but not most usage.
14 | - requests # required for tests but not most usage.
15 | - aiohttp # required for tests but not most usage.
16 | - pip
17 | - pip:
18 | - pre-commit
19 | - black
20 | - flake8
21 | - jupyter-book
22 | - sphinx
23 | - autodoc
24 | - ghp-import
25 | - sortedcontainers
26 |
--------------------------------------------------------------------------------
/paramtools/__init__.py:
--------------------------------------------------------------------------------
1 | from paramtools.schema_factory import SchemaFactory
2 | from paramtools.exceptions import (
3 | ParamToolsError,
4 | ParameterUpdateException,
5 | SparseValueObjectsException,
6 | ValidationError,
7 | InconsistentLabelsException,
8 | collision_list,
9 | ParameterNameCollisionException,
10 | UnknownTypeException,
11 | )
12 | from paramtools.parameters import Parameters
13 | from paramtools.schema import (
14 | RangeSchema,
15 | ChoiceSchema,
16 | ValueValidatorSchema,
17 | BaseParamSchema,
18 | EmptySchema,
19 | BaseValidatorSchema,
20 | ALLOWED_TYPES,
21 | FIELD_MAP,
22 | VALIDATOR_MAP,
23 | get_type,
24 | get_param_schema,
25 | register_custom_type,
26 | PartialField,
27 | )
28 | from paramtools.select import (
29 | select,
30 | select_eq,
31 | select_ne,
32 | select_gt,
33 | select_gte,
34 | select_gt_ix,
35 | select_lt,
36 | select_lte,
37 | )
38 | from paramtools.sorted_key_list import SortedKeyList, SortedKeyListResult
39 | from paramtools.typing import ValueObject
40 | from paramtools.utils import (
41 | read_json,
42 | get_example_paths,
43 | LeafGetter,
44 | get_leaves,
45 | ravel,
46 | consistent_labels,
47 | ensure_value_object,
48 | hashable_value_object,
49 | filter_labels,
50 | make_label_str,
51 | )
52 | from paramtools.values import Values, Slice, QueryResult
53 |
54 |
name = "paramtools"
__version__ = "0.20.0"

# Names exported by ``from paramtools import *``. Every entry must be bound
# in this module by one of the imports at the top of the file; an entry that
# is not bound makes the star-import fail with AttributeError.
__all__ = [
    "SchemaFactory",
    "ParamToolsError",
    "ParameterUpdateException",
    "SparseValueObjectsException",
    "ValidationError",
    "InconsistentLabelsException",
    "collision_list",
    "ParameterNameCollisionException",
    "UnknownTypeException",
    "Parameters",
    "RangeSchema",
    "ChoiceSchema",
    "ValueValidatorSchema",
    "BaseParamSchema",
    "EmptySchema",
    "BaseValidatorSchema",
    "ALLOWED_TYPES",
    "FIELD_MAP",
    "VALIDATOR_MAP",
    "get_type",
    "get_param_schema",
    "register_custom_type",
    "PartialField",
    "select",
    "select_eq",
    "select_gt",
    "select_gte",
    "select_gt_ix",
    "select_lt",
    "select_lte",
    "select_ne",
    "read_json",
    "get_example_paths",
    "LeafGetter",
    "get_leaves",
    "ravel",
    "consistent_labels",
    "ensure_value_object",
    "hashable_value_object",
    "filter_labels",
    "make_label_str",
    "SortedKeyList",
    "SortedKeyListResult",
    "ValueObject",
    "Values",
    "Slice",
    "QueryResult",
]
108 |
--------------------------------------------------------------------------------
/paramtools/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | from paramtools.contrib.validate import Range, DateRange, OneOf, When
2 | from paramtools.contrib.fields import (
3 | Float64,
4 | Int64,
5 | Bool_,
6 | MeshFieldMixin,
7 | Str,
8 | Integer,
9 | Float,
10 | Boolean,
11 | Date,
12 | )
13 |
14 |
# Public names re-exported by ``paramtools.contrib``.
__all__ = [
    # Validators from paramtools.contrib.validate.
    "Range", "DateRange", "OneOf", "When",
    # Field types from paramtools.contrib.fields.
    "Float64", "Int64", "Bool_", "MeshFieldMixin",
    "Str", "Integer", "Float", "Boolean", "Date",
]
30 |
--------------------------------------------------------------------------------
/paramtools/contrib/fields.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import datetime
3 | import json
4 |
5 | import marshmallow as ma
6 |
7 |
# Fallback comparison functions, returned by the fields' ``cmp_funcs`` when
# no validator supplies its own. ``key`` is the identity; the remaining
# entries are the plain binary comparison operators.
default_cmps = dict(
    key=lambda item: item,
    gt=lambda left, right: left > right,
    gte=lambda left, right: left >= right,
    lt=lambda left, right: left < right,
    lte=lambda left, right: left <= right,
    ne=lambda left, right: left != right,
    eq=lambda left, right: left == right,
)
17 |
18 |
class NumPySerializeMixin:
    """
    Mixin for marshmallow fields that hold NumPy values.

    Serializes anything exposing ``tolist`` (NumPy scalars and arrays) to
    plain Python, rejects values whose validated form is not 0-dimensional,
    and exposes the comparison functions of the field's single validator.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        # Objects without ``tolist`` are assumed to be serializable already.
        return value.tolist() if hasattr(value, "tolist") else value

    def _validated(self, value):
        validated = super()._validated(value)
        # Only scalars (shape ``()``) are acceptable parameter values here.
        if validated.shape == ():
            return validated
        raise self.make_error("invalid", input=validated)

    def cmp_funcs(self, **kwargs):
        """
        Return the validator's comparison functions, or ``default_cmps``
        when there is no validator or it does not provide any.
        """
        if not self.validators:
            return default_cmps
        assert len(self.validators) == 1
        funcs = self.validators[0].cmp_funcs(**kwargs)
        return default_cmps if funcs is None else funcs
41 |
42 |
class Float64(NumPySerializeMixin, ma.fields.Number):
    """
    Implements "float" :ref:`spec:Type property` for parameter values.

    Values are represented as ``numpy.float64`` scalars.
    """

    np_type = np.float64
    num_type = np_type
51 |
52 |
class Int64(NumPySerializeMixin, ma.fields.Integer):
    """
    Implements "int" :ref:`spec:Type property` for parameter values.

    Values are represented as ``numpy.int64`` scalars.
    """

    np_type = np.int64
    num_type = np_type
60 |
61 |
class Bool_(NumPySerializeMixin, ma.fields.Boolean):
    """
    Implements "bool" :ref:`spec:Type property` for parameter values.

    Values are represented as ``numpy.bool_`` scalars.
    """

    np_type = np.bool_
    num_type = np_type

    def _deserialize(self, value, attr, obj, **kwargs):
        # Let marshmallow's Boolean field parse the input, then wrap the
        # result as a NumPy boolean.
        parsed = super()._deserialize(value, attr, obj, **kwargs)
        return np.bool_(parsed)
72 |
73 |
class MeshFieldMixin:
    """
    Exposes the ``grid`` and ``cmp_funcs`` methods of a field's single
    ``contrib.validate`` validator, with fallbacks when it has none.
    """

    def grid(self):
        """Return the validator's grid, or an empty list without one."""
        if self.validators:
            assert len(self.validators) == 1
            return self.validators[0].grid()
        return []

    def cmp_funcs(self, **kwargs):
        """
        Return the validator's comparison functions, falling back to
        ``default_cmps`` when there is no validator or it returns None.
        """
        if self.validators:
            assert len(self.validators) == 1
            funcs = self.validators[0].cmp_funcs(**kwargs)
            if funcs is not None:
                return funcs
        return default_cmps
95 |
96 |
class Str(MeshFieldMixin, ma.fields.Str):
    """
    Implements "str" :ref:`spec:Type property`.

    The associated NumPy-side type is the generic ``object``.
    """

    np_type = object
103 |
104 |
class Integer(MeshFieldMixin, ma.fields.Integer):
    """
    Implements "int" :ref:`spec:Type property` for every property other
    than parameter values (those use ``Int64``).
    """

    np_type = int
112 |
113 |
class Float(MeshFieldMixin, ma.fields.Float):
    """
    Implements "float" :ref:`spec:Type property` for every property other
    than parameter values (those use ``Float64``).
    """

    np_type = float
121 |
122 |
class Boolean(MeshFieldMixin, ma.fields.Boolean):
    """
    Implements "bool" :ref:`spec:Type property` for every property other
    than parameter values (those use ``Bool_``).
    """

    np_type = bool
130 |
131 |
class Date(MeshFieldMixin, ma.fields.Date):
    """
    Implements "date" :ref:`spec:Type property`.

    ``datetime.date`` and ``datetime.datetime`` instances are passed through
    unchanged; any other input is parsed by marshmallow's ``Date`` field.
    """

    np_type = datetime.date
    default_error_messages = {
        "invalid": "Not a valid {obj_type}: {input}",
        "format": '"{input}" cannot be formatted as a {obj_type}.',
    }

    def _deserialize(self, value, attr=None, data=None, **kwargs):
        already_parsed = isinstance(value, (datetime.datetime, datetime.date))
        if not already_parsed:
            value = super()._deserialize(value, attr, data, **kwargs)
        return value
147 |
148 |
class Nested(ma.fields.Nested):
    """
    Nested field that takes its comparison functions from the nested schema
    when it defines ``cmp_funcs``, and otherwise falls back to comparing
    JSON-string representations.
    """

    def json_cmp_func(self, data):
        """
        Return a JSON string for ``data`` to be used as a comparison key.

        The nested schema's serialization is tried first. If it raises a
        ``ValidationError``, ``data`` is dumped directly instead. If
        ``data`` is not JSON-serializable either, the original
        ``ValidationError`` is re-raised.
        """
        try:
            return json.dumps(self._serialize(data, None, None))
        except ma.ValidationError as ve:
            try:
                return json.dumps(data)
            except (TypeError, ValueError):
                # json.dumps reports unserializable input with TypeError
                # (unsupported type) or ValueError (e.g. circular reference).
                # It never raises JSONDecodeError, which is what this
                # fallback previously caught, so it could never fire.
                raise ve

    def cmp_funcs(self, **kwargs):
        """
        Return the nested schema's comparison functions, or a
        ``{"key": json_cmp_func}`` fallback when it defines none.
        """
        cmp_funcs = getattr(self.nested, "cmp_funcs", None)
        if cmp_funcs is None:
            # This is not a good comparison function, but it is the best
            # available when the nested schema has no cmp_funcs method.
            return {"key": self.json_cmp_func}
        return cmp_funcs(**kwargs)
168 |
--------------------------------------------------------------------------------
/paramtools/contrib/validate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import datetime
3 | import itertools
4 |
5 | from dateutil.relativedelta import relativedelta
6 | import numpy as np
7 | import marshmallow as ma
8 |
9 | from paramtools.typing import ValueObject
10 | from paramtools import utils
11 |
12 |
class ValidationError(ma.ValidationError):
    """marshmallow ``ValidationError`` that also carries a severity level."""

    def __init__(self, *args, level=None, **kwargs):
        # Any falsy level (including the default None) means "error".
        self.level = level if level else "error"
        super().__init__(*args, **kwargs)
17 |
18 |
19 | class When(ma.validate.Validator):
20 | then_message = "When value is {is_val}, the input is invalid: {submsg}"
21 | otherwise_message = (
22 | "When value is {is_val}, the input is invalid: {submsg}"
23 | )
24 | shape_mismatch = "Shape mismatch between parameters: {shape1} {shape2}"
25 |
26 | is_value_evaluators = {
27 | "equal_to": lambda when_value, is_val: when_value == is_val,
28 | "less_than": lambda when_value, is_val: when_value < is_val,
29 | "greater_than": lambda when_value, is_val: when_value > is_val,
30 | }
31 |
32 | def __init__(
33 | self,
34 | is_object,
35 | when_vos: List[ValueObject],
36 | then_validators: List[ma.validate.Validator],
37 | otherwise_validators: List[ma.validate.Validator],
38 | then_message: str = None,
39 | otherwise_message: str = None,
40 | level: str = "error",
41 | type: str = "int",
42 | number_dims: int = 0,
43 | ):
44 | self.is_operator = next(iter(is_object))
45 | self.is_val = is_object[self.is_operator]
46 | self.when_vos = when_vos
47 | self.then_validators = then_validators
48 | self.otherwise_validators = otherwise_validators
49 | self.then_message = then_message or self.then_message
50 | self.otherwise_message = otherwise_message or self.otherwise_message
51 | self.level = level
52 | self.type = type
53 | self.number_dims = number_dims
54 |
55 | def __call__(self, value, is_value_object=False):
56 | if value is None:
57 | return value
58 | if not is_value_object:
59 | value = {"value": value}
60 |
61 | msgs = []
62 | arr = np.array(value["value"])
63 | for when_vo in self.when_vos:
64 | if not isinstance(when_vo["value"], list):
65 | msgs += self.apply_validator(
66 | value, when_vo["value"], value, when_vo, ix=None
67 | )
68 | continue
69 |
70 | when_arr = np.array(when_vo["value"])
71 | if when_arr.shape != arr.shape:
72 | raise ValidationError(
73 | self.shape_mismatch.format(
74 | shape1=when_arr.shape, shape2=arr.shape
75 | )
76 | )
77 | for ix in itertools.product(*(map(range, arr.shape))):
78 | msgs += self.apply_validator(
79 | {"value": arr[ix]}, when_arr[ix], value, when_vo, ix
80 | )
81 |
82 | if msgs:
83 | raise ValidationError(
84 | msgs if len(msgs) > 1 else msgs[0], level=self.level
85 | )
86 |
87 | def evaluate_is_value(self, when_value):
88 | return self.is_value_evaluators[self.is_operator](
89 | when_value, self.is_val
90 | )
91 |
def apply_validator(self, value, when_value, labels, when_labels, ix=None):
    """
    Run either the "then" or the "otherwise" validators against ``value``.

    ``self.evaluate_is_value(when_value)`` selects the branch: when truthy,
    ``self.then_validators`` run and failures are rendered with
    ``self.then_message``; otherwise ``self.otherwise_validators`` run and
    failures use ``self.otherwise_message``. Every validator is called as
    ``validator(value, is_value_object=True)``.

    Returns a list holding one formatted message per validator that raised
    ``ValidationError`` (empty when none did). ``ix``, if given, is an
    element position rendered into the message as ``[index=...]``.
    """
    if self.evaluate_is_value(when_value):
        branch_validators = self.then_validators
        template = self.then_message
    else:
        branch_validators = self.otherwise_validators
        template = self.otherwise_message

    def position_suffix():
        if ix is None:
            return ""
        return f"[index={', '.join(map(str, ix))}]"

    failures = []
    for check in branch_validators:
        try:
            check(value, is_value_object=True)
        except ValidationError as err:
            failures.append(
                template.format(
                    is_val=f"{self.is_operator.replace('_', ' ')} {self.is_val}",
                    submsg=str(err),
                    labels=utils.make_label_str(labels),
                    when_labels=utils.make_label_str(when_labels),
                    ix=position_suffix(),
                )
            )
    return failures
129 |
def grid(self):
    """
    Return the grid of the first "then" validator.

    Only ``self.then_validators[0]`` is consulted; having several "then"
    validators is expected to be rare.
    """
    primary = self.then_validators[0]
    return primary.grid()
136 |
def cmp_funcs(self, **kwargs):
    """
    This validator provides no comparison functions.

    Any keyword arguments are ignored and ``None`` is always returned.
    """
    return None
139 |
140 |
class Range(ma.validate.Range):
    """
    Implements "range" :ref:`spec:Validator object`.

    ``self.min`` and ``self.max`` are lists of value objects (dicts with a
    ``"value"`` key). A literal ``min``/``max`` argument is wrapped into a
    one-element list; otherwise ``min_vo``/``max_vo`` is stored as given.
    """

    error = ""
    message_min = "Input {input} must be {min_op} {min}."
    message_max = "Input {input} must be {max_op} {max}."

    def __init__(
        self,
        min=None,
        max=None,
        min_vo=None,
        max_vo=None,
        error_min=None,
        error_max=None,
        step=None,
        level=None,
    ):
        # A literal bound takes precedence over the value-object form.
        self.min = min_vo if min is None else [{"value": min}]
        self.max = max_vo if max is None else [{"value": max}]

        self.error_min = error_min
        self.error_max = error_max
        # Any falsy step (None or 0) falls back to 1.
        self.step = step if step else 1

        # Set to None; the methods of this class never read them.
        self.min_inclusive = None
        self.max_inclusive = None
        self.level = level if level else "error"

    def __call__(self, value, is_value_object=False):
        """
        Entry point marshmallow calls by default.

        A plain value is wrapped as ``{"value": value}`` before being handed
        to ``validate_value_objects``; a value object is passed straight
        through. ``None`` is accepted and returned unchanged.
        """
        if value is None:
            return None
        value_object = value if is_value_object else {"value": value}
        return self.validate_value_objects(value_object)

    def validate_value_objects(self, value):
        """
        Check ``value["value"]`` (scalar or array-like) against every bound.

        Returns ``value`` when it satisfies all bounds and ``None`` when
        ``value["value"]`` is ``None``. Otherwise raises ``ValidationError``
        (at ``self.level``) carrying one message per violated bound, or the
        single message itself when only one bound failed.
        """
        raw = value["value"]
        if raw is None:
            return None

        failures = []
        if self.min is not None:
            template = self.error_min or self.message_min
            for bound in self.min:
                if np.any(np.array(raw) < bound["value"]):
                    failures.append(
                        template.format(
                            input=raw,
                            min=bound["value"],
                            min_op="greater than",
                            labels=utils.make_label_str(value),
                            oth_labels=utils.make_label_str(bound),
                        )
                    )
        if self.max is not None:
            template = self.error_max or self.message_max
            for bound in self.max:
                if np.any(np.array(raw) > bound["value"]):
                    failures.append(
                        template.format(
                            input=raw,
                            max=bound["value"],
                            max_op="less than",
                            labels=utils.make_label_str(value),
                            oth_labels=utils.make_label_str(bound),
                        )
                    )

        if not failures:
            return value
        raise ValidationError(
            failures[0] if len(failures) == 1 else failures, level=self.level
        )

    def grid(self):
        """
        Return the values from the first min bound to the first max bound.

        The values are spaced by ``self.step``, the max is included, and
        the result is returned as a plain list.
        """
        upper = self.max[0]["value"]
        lower = self.min[0]["value"]
        # Run np.arange one step past ``upper`` so it is included, then trim
        # any overshoot.
        candidates = np.arange(lower, upper + self.step, self.step)
        return candidates[candidates <= upper].tolist()

    def cmp_funcs(self, **kwargs):
        """Range validators provide no comparison functions."""
        return None
231 |
232 |
class DateRange(Range):
    """
    Implements "date_range" :ref:`spec:Validator object`.
    Behaves like ``Range``, except values are ensured to be
    ``datetime.date`` type and ``grid`` has special logic for dates.

    ``step`` is a dict of keyword arguments for ``relativedelta``
    (default ``{"days": 1}``); its keys must come from ``timedelta_args``.
    ``relativedelta`` has no ``milliseconds`` argument, so a
    ``milliseconds`` entry is folded into ``microseconds``.

    Raises ``ValidationError`` with ``step_msg`` when ``step`` is not a
    dict or contains keys outside ``timedelta_args``.
    """

    # check against allowed args:
    # https://docs.python.org/3/library/datetime.html#datetime.timedelta
    timedelta_args = {
        "days",
        "months",
        "seconds",
        "microseconds",
        "milliseconds",
        "minutes",
        "hours",
        "weeks",
    }

    step_msg = (
        f"The step field must be a dictionary with only these keys: {', '.join(timedelta_args)}."
        f"\n\tFor more information, check out the timedelta docs: "
        f"\n\t\thttps://docs.python.org/3/library/datetime.html#datetime.timedelta"
    )

    def __init__(
        self,
        min=None,
        max=None,
        min_vo=None,
        max_vo=None,
        error_min=None,
        error_max=None,
        step=None,
        level=None,
    ):
        # Bounds are normalized to lists of value objects whose "value" is
        # a datetime.date (see safe_deserialize).
        if min is not None:
            self.min = [{"value": self.safe_deserialize(min)}]
        elif min_vo is not None:
            self.min = [
                dict(vo, **{"value": self.safe_deserialize(vo["value"])})
                for vo in min_vo
            ]
        else:
            self.min = None

        if max is not None:
            self.max = [{"value": self.safe_deserialize(max)}]
        elif max_vo is not None:
            self.max = [
                dict(vo, **{"value": self.safe_deserialize(vo["value"])})
                for vo in max_vo
            ]
        else:
            self.max = None

        self.error_min = error_min
        self.error_max = error_max

        if step is None:
            # set to default step.
            step = {"days": 1}

        if not isinstance(step, dict):
            raise ValidationError(self.step_msg)

        if set(step.keys()) - self.timedelta_args:
            raise ValidationError(self.step_msg)

        self.step = relativedelta(**self._relativedelta_kwargs(step))

        self.level = level or "error"

    @staticmethod
    def _relativedelta_kwargs(step):
        """
        Return a copy of ``step`` that ``relativedelta`` accepts.

        ``milliseconds`` is valid for ``datetime.timedelta`` but not for
        ``relativedelta``, so it is converted and added to ``microseconds``.
        """
        kwargs = dict(step)
        milliseconds = kwargs.pop("milliseconds", 0)
        if milliseconds:
            kwargs["microseconds"] = (
                kwargs.get("microseconds", 0) + milliseconds * 1000
            )
        return kwargs

    def safe_deserialize(self, date):
        """Return ``date`` unchanged if it is already a date, else parse it."""
        if isinstance(date, datetime.date):
            return date
        else:
            return ma.fields.Date()._deserialize(date, None, None)

    def grid(self):
        """
        Return the dates from the first min bound through the first max bound.

        Dates advance by ``self.step`` and never exceed the max bound; the
        max itself is included when a step lands on it exactly.
        """
        upper = self.max[0]["value"]

        current = self.min[0]["value"]
        result = []
        # Compare against ``upper`` itself: looping while below
        # ``upper + step`` emitted a date past the max whenever the step did
        # not land exactly on it (e.g. weekly or monthly steps).
        while current <= upper:
            result.append(current)
            following = current + self.step
            if following <= current:
                # A zero or negative step never advances; stop instead of
                # looping forever.
                break
            current = following

        return result
325 |
326 |
class OneOf(ma.validate.OneOf):
    """
    Implements "choice" :ref:`spec:Validator object`.

    Accepts a scalar or a (possibly nested) list. List values are
    flattened with ``utils.ravel`` and every element must be one of
    ``self.choices``.
    """

    default_message = "Input {input} must be one of {choices}."

    def __init__(self, *args, level=None, **kwargs):
        self.level = level if level else "error"
        super().__init__(*args, **kwargs)

    def __call__(self, value, is_value_object=False):
        """
        Validate ``value``, or ``value["value"]`` when ``is_value_object``.

        Returns ``value`` unchanged when every element is a valid choice and
        ``None`` when there is nothing to check. Raises ``ValidationError``
        for the first element that is not a choice, including elements
        whose membership test raises ``TypeError``.
        """
        if value is None:
            return None
        raw = value["value"] if is_value_object else value
        if raw is None:
            return None
        candidates = utils.ravel(raw) if isinstance(raw, list) else [raw]
        for candidate in candidates:
            try:
                allowed = candidate in self.choices
            except TypeError:
                allowed = False
            if not allowed:
                raise ValidationError(
                    self._format_error(candidate), level=self.level
                )
        return value

    def grid(self):
        """Return the allowed choices."""
        return self.choices

    def cmp_funcs(self, choices=None, **kwargs):
        """
        Build ordering helpers based on position in ``choices``.

        ``choices`` defaults to ``self.choices``. "key"/"gt"/"gte"/"lt"/"lte"
        compare list positions; "eq"/"ne" compare values directly.
        """
        ordering = self.choices if choices is None else choices

        def rank(item):
            return ordering.index(item)

        return {
            "key": rank,
            "gt": lambda x, y: rank(x) > rank(y),
            "gte": lambda x, y: rank(x) >= rank(y),
            "lt": lambda x, y: rank(x) < rank(y),
            "lte": lambda x, y: rank(x) <= rank(y),
            "ne": lambda x, y: x != y,
            "eq": lambda x, y: x == y,
        }
376 |
--------------------------------------------------------------------------------
/paramtools/examples/baseball/defaults.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": {
3 | "labels": {
4 | "use_2018": {"type": "bool", "validators": {}}
5 | },
6 | "additional_members": {
7 | "section_1": {"type": "str"},
8 | "section_2": {"type": "str"}
9 | }
10 | },
11 | "pitcher": {
12 | "title": "Pitcher Name",
13 | "description": "Name of pitcher to pull data on",
14 | "section_1": "Parameters",
15 | "section_2": "Pitcher",
16 | "notes": "Make sure the name of the pitcher is correct. A good place to reference this is baseball-reference.com",
17 | "type": "str",
18 | "value": [{"value": "Clayton Kershaw"}],
19 | "validators": {"choice": {"choices": ["Clayton Kershaw", "Julio Teheran", "Max Scherzer"]}}
20 | },
21 | "batter": {
22 | "title": "Batter Name",
23 | "description": "Name of batter for pitching matchup analysis",
24 | "section_1": "Parameters",
25 | "section_2": "Batter",
26 | "notes": "Make sure the name of the batter is correct. A good place to reference this is baseball-reference.com",
27 | "type": "str",
28 | "value": [{"value": "Freddie Freeman"}],
29 | "validators": {"choice": {"choices": ["Freddie Freeman", "Bryce Harper", "Mookie Betts"]}}
30 | },
31 | "start_date": {
32 | "title": "Start Date",
33 | "description": "Date to start pulling statcast information",
34 | "section_1": "Parameters",
35 | "section_2": "Date",
36 | "notes": "If using the 2018 dataset, only use dates in 2018.",
37 | "type": "date",
38 | "value": [{"value": "2018-01-01"}],
39 | "validators": {"date_range": {"min": "2008-01-01", "max": "end_date"}}
40 | },
41 | "end_date": {
42 | "title": "End Date",
43 | "description": "Date to quit pulling statcast information",
44 | "section_1": "Parameters",
45 | "section_2": "Date",
46 | "notes": "If using the 2018 dataset, only use dates in 2018.",
47 | "type": "date",
48 | "value": [{"value": "2018-11-10"}],
49 | "validators": {"date_range": {"min": "2008-01-01", "max": "2018-11-10"}}
50 | }
51 | }
--------------------------------------------------------------------------------
/paramtools/examples/behresp/defaults.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": {
3 | "labels": {},
4 | "additional_members": {}
5 | },
6 | "BE_sub": {
7 | "title": "Substitution elasticity of taxable income",
8 | "description": "Defined as proportional change in taxable income divided by proportional change in marginal net-of-tax rate (1-MTR) on taxpayer earnings caused by the reform. Must be zero or positive.",
9 | "notes": "",
10 | "type": "float",
11 | "value": [{"value": 0.0}],
12 | "validators": {
13 | "range": {"min": 0.0, "max": 9e99}
14 | }
15 | },
16 | "BE_inc": {
17 | "title": "Income elasticity of taxable income",
18 | "description": "Defined as dollar change in taxable income divided by dollar change in after-tax income caused by the reform. Must be zero or negative.",
19 | "notes": "",
20 | "type": "float",
21 | "value": [{"value": 0.0}],
22 | "validators": {
23 | "range": {"min": -9e99, "max": 0}
24 | }
25 | },
26 | "BE_cg": {
27 | "title": "Semi-elasticity of long-term capital gains",
28 | "description": "Defined as change in logarithm of long-term capital gains divided by change in marginal tax rate (MTR) on long-term capital gains caused by the reform. Must be zero or negative. Read response function documentation (see below) for discussion of appropriate values.",
29 | "notes": "",
30 | "type": "float",
31 | "value": [{"value": 0.0}],
32 | "validators": {
33 | "range": {"min": -9e99, "max": 0}
34 | }
35 | }
36 | }
--------------------------------------------------------------------------------
/paramtools/examples/taxparams-demo/defaults.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": {
3 | "labels": {
4 | "year": {
5 | "type": "int",
6 | "validators": {"range": {"min": 2013, "max": 2027}}
7 | },
8 | "marital_status": {
9 | "type": "str",
10 | "validators": {"choice": {"choices": ["single", "joint", "separate",
11 | "headhousehold", "widow"]}}
12 | },
13 | "idedtype": {
14 | "type": "str",
15 | "validators": {"choice": {"choices": ["medical", "statelocal",
16 | "realestate", "casualty",
17 | "misc", "interest", "charity"]}}
18 | },
19 | "EIC": {
20 | "type": "str",
21 | "validators": {"choice": {"choices": ["0kids", "1kid",
22 | "2kids", "3+kids"]}}
23 | }
24 | },
25 | "additional_members": {
26 | "section_1": {"type": "str", "number_dims": 0},
27 | "section_2": {"type": "str", "number_dims": 0},
28 | "section_3": {"type": "str", "number_dims": 0},
29 | "irs_ref": {"type": "str", "number_dims": 0},
30 | "start_year": {"type": "int", "number_dims": 0},
31 | "cpi_inflatable": {"type": "bool", "number_dims": 0},
32 | "cpi_inflated": {"type": "bool", "number_dims": 0}
33 | }
34 | },
35 | "standard_deduction": {
36 | "title": "Standard deduction amount",
37 | "description": "Amount filing unit can use as a standard deduction.",
38 | "section_1": "Standard Deduction",
39 | "section_2": "Standard Deduction Amount",
40 | "irs_ref": "Form 1040, line 8, instructions. ",
41 | "notes": "",
42 | "start_year": 2024,
43 | "cpi_inflatable": true,
44 | "cpi_inflated": true,
45 | "type": "float",
46 | "value": [
47 | {"year": 2024, "marital_status": "single", "value": 13673.68},
48 | {"year": 2024, "marital_status": "joint", "value": 27347.36},
49 | {"year": 2024, "marital_status": "separate", "value": 13673.68},
50 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52},
51 | {"year": 2024, "marital_status": "widow", "value": 27347.36},
52 | {"year": 2025, "marital_status": "single", "value": 13967.66},
53 | {"year": 2025, "marital_status": "joint", "value": 27935.33},
54 | {"year": 2025, "marital_status": "separate", "value": 13967.66},
55 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49},
56 | {"year": 2025, "marital_status": "widow", "value": 27935.33},
57 | {"year": 2026, "marital_status": "single", "value": 7690.0},
58 | {"year": 2026, "marital_status": "joint", "value": 15380.0},
59 | {"year": 2026, "marital_status": "separate", "value": 7690.0},
60 | {"year": 2026, "marital_status": "headhousehold", "value": 11323.0},
61 | {"year": 2026, "marital_status": "widow", "value": 15380.0}],
62 | "validators": {
63 | "range": {
64 | "min": 0,
65 | "max": 9e+99
66 | }
67 | }
68 | },
69 | "social_security_tax_rate": {
70 | "description": "Social Security FICA rate, including both employer and employee.",
71 | "section_1": "Payroll Taxes",
72 | "section_2": "Social Security FICA",
73 | "irs_ref": "",
74 | "notes": "",
75 | "start_year": 2026,
76 | "cpi_inflatable": false,
77 | "cpi_inflated": false,
78 | "value": [
79 | {"year": 2024, "value": 0.124},
80 | {"year": 2025, "value": 0.124},
81 | {"year": 2026, "value": 0.124}
82 | ],
83 | "title": "Social Security payroll tax rate",
84 | "type": "float",
85 | "validators": {
86 | "range": {
87 | "min": 0,
88 | "max": 1
89 | }
90 | }
91 | },
92 | "ii_bracket_1": {
93 | "title": "Personal income (regular/non-AMT/non-pass-through) tax bracket (upper threshold) 1",
94 | "description": "Taxable income below this threshold is taxed at tax rate 1.",
95 | "section_1": "Personal Income",
96 | "section_2": "Regular: Non-AMT, Non-Pass-Through",
97 | "irs_ref": "Form 1040, line 44, instruction (Schedule XYZ).",
98 | "notes": "",
99 | "start_year": 2013,
100 | "cpi_inflatable": true,
101 | "cpi_inflated": true,
102 | "type": "float",
103 | "value": [
104 | {"year": 2024, "marital_status": "single", "value": 10853.48},
105 | {"year": 2024, "marital_status": "joint", "value": 21706.97},
106 | {"year": 2024, "marital_status": "separate", "value": 10853.48},
107 | {"year": 2024, "marital_status": "headhousehold", "value": 15496.84},
108 | {"year": 2024, "marital_status": "widow", "value": 21706.97},
109 | {"year": 2025, "marital_status": "single", "value": 11086.83},
110 | {"year": 2025, "marital_status": "joint", "value": 22173.66},
111 | {"year": 2025, "marital_status": "separate", "value": 11086.83},
112 | {"year": 2025, "marital_status": "headhousehold", "value": 15830.02},
113 | {"year": 2025, "marital_status": "widow", "value": 22173.66},
114 | {"year": 2026, "marital_status": "single", "value": 11293.0},
115 | {"year": 2026, "marital_status": "joint", "value": 22585.0},
116 | {"year": 2026, "marital_status": "separate", "value": 11293.0},
117 | {"year": 2026, "marital_status": "headhousehold", "value": 16167.0},
118 | {"year": 2026, "marital_status": "widow", "value": 22585.0}],
119 | "validators": {
120 | "range": {
121 | "min": 0,
122 | "max": "ii_bracket_2"
123 | }
124 | }
125 | },
126 | "ii_bracket_2": {
127 | "title": "Personal income (regular/non-AMT/non-pass-through) tax bracket (upper threshold) 2",
128 | "description": "Income below this threshold and above tax bracket 1 is taxed at tax rate 2.",
129 | "section_1": "Personal Income",
130 | "section_2": "Regular: Non-AMT, Non-Pass-Through",
131 | "irs_ref": "Form 1040, line 11, instruction (Schedule XYZ).",
132 | "notes": "",
133 | "start_year": 2013,
134 | "cpi_inflatable": true,
135 | "cpi_inflated": true,
136 | "type": "float",
137 | "value": [
138 | {"year": 2024, "marital_status": "single", "value": 44097.61},
139 | {"year": 2024, "marital_status": "joint", "value": 88195.23},
140 | {"year": 2024, "marital_status": "separate", "value": 44097.61},
141 | {"year": 2024, "marital_status": "headhousehold", "value": 59024.71},
142 | {"year": 2024, "marital_status": "widow", "value": 88195.23},
143 | {"year": 2025, "marital_status": "single", "value": 45045.71},
144 | {"year": 2025, "marital_status": "joint", "value": 90091.43},
145 | {"year": 2025, "marital_status": "separate", "value": 45045.71},
146 | {"year": 2025, "marital_status": "headhousehold", "value": 60293.74},
147 | {"year": 2025, "marital_status": "widow", "value": 90091.43},
148 | {"year": 2026, "marital_status": "single", "value": 45957.0},
149 | {"year": 2026, "marital_status": "joint", "value": 91915.0},
150 | {"year": 2026, "marital_status": "separate", "value": 45957.0},
151 | {"year": 2026, "marital_status": "headhousehold", "value": 61519.0},
152 | {"year": 2026, "marital_status": "widow", "value": 91915.0}],
153 | "validators": {
154 | "range": {
155 | "min": "ii_bracket_1",
156 | "max": 9e+99
157 | }
158 | }
159 | }
160 | }
161 |
--------------------------------------------------------------------------------
/paramtools/exceptions.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | from paramtools import utils
5 |
6 |
class ParamToolsError(Exception):
    """Base class for the ParamTools exceptions defined in this module."""


class ParameterUpdateException(ParamToolsError):
    """ParamTools error for parameter updates; see call sites for when it is raised."""


class SparseValueObjectsException(ParamToolsError):
    """ParamTools error concerning sparse value objects; see call sites for exact use."""


class UnknownTypeException(ParamToolsError):
    """ParamTools error named for unknown types; see call sites for exact use."""
21 |
22 |
class ValidationError(ParamToolsError):
    """
    Aggregate validation failure.

    ``messages`` maps an error type to a mapping of parameter name to
    message(s); ``labels`` is stored unchanged. The exception text is an
    indented JSON rendering of ``messages`` in which each per-parameter
    message is flattened with ``utils.ravel``. Error types whose mapping is
    empty do not appear in the rendered text.
    """

    def __init__(self, messages, labels):
        self.messages = messages
        self.labels = labels
        rendered = defaultdict(dict)
        entries = (
            (error_type, param_name, raw_msg)
            for error_type, per_param in self.messages.items()
            for param_name, raw_msg in per_param.items()
        )
        for error_type, param_name, raw_msg in entries:
            rendered[error_type][param_name] = utils.ravel(raw_msg)
        super().__init__(json.dumps(rendered, indent=4))
32 |
33 |
34 | class InconsistentLabelsException(ParamToolsError):
35 | pass
36 |
37 |
# Reserved attribute/method names, in their original order. Presumably a
# parameter whose name appears here is rejected with
# ParameterNameCollisionException (NOTE(review): confirm in parameters.py).
collision_list = [
    # underscore-prefixed state names
    "_data", "_errors", "_warnings",
    # select_* helpers
    "select_eq", "select_ne", "select_gt", "select_gte", "select_lt",
    "select_lte",
    # underscore-prefixed helpers
    "_adjust", "_delete", "_numpy_type", "_parse_errors", "_resolve_order",
    "_schema", "_set_state", "_state", "_stateless_label_grid",
    "_update_param", "_validator_schema", "_defaults_schema", "_select",
    "_defer_validation",
    # public names
    "operators", "adjust", "delete", "validate", "transaction",
    "array_first", "clear_state", "defaults", "dump", "items", "keys",
    "label_grid", "label_validators", "keyfuncs", "errors", "warnings",
    "from_array", "parse_labels", "read_params", "schema", "set_state",
    "specification", "sort_values", "to_array", "validation_error",
    "view_state",
    # extend / index-rate names
    "extend", "extend_func", "uses_extend_func", "label_to_extend",
    "get_index_rate", "index_rates",
    # remaining names
    "to_dict", "_parse_validation_messages", "sel", "get_defaults",
]
98 | ]
99 |
100 |
101 | class ParameterNameCollisionException(ParamToolsError):
102 | pass
103 |
--------------------------------------------------------------------------------
/paramtools/schema_factory.py:
--------------------------------------------------------------------------------
1 | from marshmallow import fields, Schema
2 |
3 | from paramtools.schema import (
4 | BaseValidatorSchema,
5 | ValueObject,
6 | get_type,
7 | get_param_schema,
8 | ParamToolsSchema,
9 | )
10 |
11 |
class SchemaFactory:
    """
    Build marshmallow schemas from a ParamTools defaults specification.

    ``defaults`` is the full specification dict. Its optional ``"schema"``
    entry is loaded with ``ParamToolsSchema``; every other entry is treated
    as a parameter definition. ``get_param_schema`` supplies the base
    parameter schema class and the label validator fields used by
    ``schemas`` to build the per-parameter schemas.
    """

    def __init__(self, defaults):
        self.defaults = {
            name: spec for name, spec in defaults.items() if name != "schema"
        }
        self.schema = ParamToolsSchema().load(defaults.get("schema", {}))
        self.BaseParamSchema, self.label_validators = get_param_schema(
            self.schema
        )

    def schemas(self):
        """
        Create the defaults and validator schemas for every parameter.

        For each parameter this builds:
          - a ``ValidatorItem`` schema with the parameter's value field, a
            load-only ``_auto`` flag and the label validator fields;
          - an ``IndividualParamSchema`` (subclassing ``BaseParamSchema``)
            whose ``value`` field is a list of those items.

        These are combined into ``DefaultsSchema`` (one nested field per
        parameter) and ``ValidatorSchema`` (based on ``BaseValidatorSchema``,
        one list of value objects per parameter).

        Returns ``(defaults_schema, validator_schema, schema,
        loaded_defaults)``, where ``loaded_defaults`` is
        ``defaults_schema.load(self.defaults)``.
        """
        item_schemas = {}
        param_schemas = {}
        for name, spec in self.defaults.items():
            # TODO: what about case where number_dims > 0
            item_schema = type(
                "ValidatorItem",
                (Schema,),
                {
                    "value": get_type(spec),
                    "_auto": fields.Boolean(required=False, load_only=True),
                    **self.label_validators,
                },
            )
            item_schemas[name] = item_schema
            param_schemas[name] = type(
                "IndividualParamSchema",
                (self.BaseParamSchema,),
                {"value": ValueObject(item_schema, many=True)},
            )

        DefaultsSchema = type(
            "DefaultsSchema",
            (Schema,),
            {name: fields.Nested(cls) for name, cls in param_schemas.items()},
        )
        defaults_schema = DefaultsSchema()

        ValidatorSchema = type(
            "ValidatorSchema",
            (BaseValidatorSchema,),
            {
                name: ValueObject(cls, many=True)
                for name, cls in item_schemas.items()
            },
        )
        validator_schema = ValidatorSchema()

        loaded_defaults = defaults_schema.load(self.defaults)
        return defaults_schema, validator_schema, self.schema, loaded_defaults
94 |
--------------------------------------------------------------------------------
/paramtools/select.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterable, List, Callable
2 | import warnings
3 |
4 |
5 | from paramtools.typing import ValueObject, CmpFunc
6 | from paramtools import values
7 |
8 |
def select(
    value_objects: List[ValueObject],
    strict: bool,
    cmp_func: CmpFunc,
    labels: dict,
    tree=None,
    op=None,
) -> List[ValueObject]:
    """
    Deprecated. Use Values instead.

    Wraps ``value_objects`` in ``values.Values`` and, for each
    ``label: value`` pair of ``labels``, runs ``Values._cmp`` with ``op``
    and ``strict``. A list value is expanded element by element. The
    element results are combined with ``values.intersection`` when
    ``op == "ne"`` and with ``values.union`` otherwise. The per-label
    results are intersected and returned as a list.

    ``cmp_func`` and ``tree`` are accepted for compatibility and are unused.
    """
    warnings.warn("The select module is deprecated. Use Values instead.")
    assert op, "Op is required."
    wrapped = values.Values(value_objects)

    def select_label(label, target):
        if not isinstance(target, list):
            return wrapped._cmp(op, strict, **{label: target})
        combine = values.intersection if op == "ne" else values.union
        return combine(
            wrapped._cmp(op, strict, **{label: element}) for element in target
        )

    per_label = [
        select_label(label, target) for label, target in labels.items()
    ]
    return list(values.intersection(per_label))
39 |
40 |
def eq_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is contained in ``y``."""
    return x in y


def ne_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is not contained in ``y``."""
    return not (x in y)


def gt_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x > item`` for every item of ``y`` (True if empty)."""
    for item in y:
        if not x > item:
            return False
    return True


def gte_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x >= item`` for every item of ``y`` (True if empty)."""
    for item in y:
        if not x >= item:
            return False
    return True


def gt_ix_func(cmp_list: list, x: Any, y: Iterable) -> bool:
    """
    Compare by position in ``cmp_list``.

    Returns True when ``x`` sits after every item of ``y`` in ``cmp_list``.
    ``cmp_list.index`` raises ``ValueError`` for values not in the list.
    """
    position = cmp_list.index(x)
    for item in y:
        if not position > cmp_list.index(item):
            return False
    return True


def lt_func(x, y) -> bool:
    """Return True when ``x < item`` for every item of ``y`` (True if empty)."""
    for item in y:
        if not x < item:
            return False
    return True


def lte_func(x, y) -> bool:
    """Return True when ``x <= item`` for every item of ``y`` (True if empty)."""
    for item in y:
        if not x <= item:
            return False
    return True
68 |
69 |
def make_cmp_func(
    cmp: Callable[[Any, Iterable], bool],
    all_or_any: Callable[[Iterable], bool],
) -> CmpFunc:
    """
    Build a two-argument comparison ``f(x, y)``.

    ``f`` applies ``cmp(x, item)`` to each item of ``y`` (lazily, as a
    generator) and reduces the results with ``all_or_any``, for example
    the builtin ``all`` or ``any``.
    """

    def combined(x, y):
        return all_or_any(cmp(x, item) for item in y)

    return combined
75 |
76 |
def select_eq(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="eq"``."""
    return select(value_objects, strict, eq_func, labels, tree=tree, op="eq")


def select_ne(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="ne"``."""
    return select(value_objects, strict, ne_func, labels, tree=tree, op="ne")


def select_gt(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="gt"``."""
    return select(value_objects, strict, gt_func, labels, tree=tree, op="gt")


def select_gte(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="gte"``."""
    return select(
        value_objects, strict, gte_func, labels, tree=tree, op="gte"
    )


def select_gt_ix(
    value_objects: List[ValueObject],
    strict: bool,
    labels: dict,
    tree=None,
) -> List[ValueObject]:
    """Removed: always raises; use ``Values`` instead."""
    raise Exception("select_gt_ix is deprecated. Use Values instead.")


def select_lt(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="lt"``."""
    return select(value_objects, strict, lt_func, labels, tree=tree, op="lt")


def select_lte(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Deprecated wrapper around ``select`` with ``op="lte"``."""
    return select(
        value_objects, strict, lte_func, labels, tree=tree, op="lte"
    )
121 |
--------------------------------------------------------------------------------
/paramtools/sorted_key_list.py:
--------------------------------------------------------------------------------
1 | import sortedcontainers
2 |
3 |
class SortedKeyListException(Exception):
    """
    Raised when a SortedKeyList cannot be built.

    Without positional arguments the exception carries ``default`` as its
    message.
    """

    default = (
        "Unable to create SortedKeyList. It is likely that this label or "
        "value uses a custom data type. In this case, you should define a "
        "cmp_funcs method to make it orderable."
    )

    def __init__(self, *args, **kwargs):
        message_args = args if args else (self.default,)
        super().__init__(*message_args, **kwargs)
16 |
17 |
class SortedKeyListResult:
    """
    Result of a SortedKeyList query.

    Wraps a list of ``(value, index)`` entries. ``values`` and ``index``
    project the first and second element of each entry, and iteration
    yields the values.
    """

    def __init__(self, key_list_values):
        self.key_list_values = key_list_values

    @property
    def values(self):
        return [entry[0] for entry in self.key_list_values]

    @property
    def index(self):
        return [entry[1] for entry in self.key_list_values]

    def __iter__(self):
        yield from self.values
33 |
34 |
class SortedKeyList:
    """
    Sorted key list built on top of sortedcontainers. This adds some
    query methods like lt/gt/eq and keeps track of the original indices
    for the values.

    Entries are stored as ``(value, index)`` tuples in
    ``self.sorted_key_list_2``, ordered by ``keyfunc(value)``. ``eq``,
    ``lt``, ``lte``, ``gt`` and ``gte`` return a ``SortedKeyListResult``,
    or ``None`` when nothing matches. ``ne`` always returns a result, which
    may be empty.
    """

    def __init__(self, values, keyfunc, index=None):
        if index is None:
            index = range(len(values))
        else:
            # Check whenever an index is supplied (even an empty one):
            # otherwise zip() would silently drop unmatched values.
            assert len(values) == len(index)
        entries = [(val, ix) for ix, val in zip(index, values)]
        self.index = set(index)
        self.keyfunc = keyfunc

        try:
            self.sorted_key_list_2 = sortedcontainers.SortedKeyList(
                entries, key=lambda entry: keyfunc(entry[0])
            )
        except TypeError as e:
            # Keys that cannot be compared (see SortedKeyListException.default).
            raise SortedKeyListException() from e

    def __repr__(self):
        # __init__ stores the container as ``sorted_key_list_2``; the old
        # reference to ``self.sorted_key_list`` raised AttributeError.
        return str(self.sorted_key_list_2)

    @staticmethod
    def _result_or_none(entries):
        """Materialize ``entries``; wrap non-empty matches, else return None."""
        matches = list(entries)
        if matches:
            return SortedKeyListResult(matches)
        return None

    def eq(self, value):
        """Entries whose key equals ``keyfunc(value)``."""
        key = self.keyfunc(value)
        return self._result_or_none(
            self.sorted_key_list_2.irange_key(min_key=key, max_key=key)
        )

    def ne(self, value):
        """Entries whose key is strictly below or strictly above ``keyfunc(value)``."""
        result = []
        lt = self.lt(value)
        if lt is not None:
            result += lt.key_list_values
        gt = self.gt(value)
        if gt is not None:
            result += gt.key_list_values

        return SortedKeyListResult(result)

    def lt(self, value):
        """Entries whose key is strictly below ``keyfunc(value)``."""
        key = self.keyfunc(value)
        return self._result_or_none(
            self.sorted_key_list_2.irange_key(
                max_key=key, inclusive=(True, False)
            )
        )

    def lte(self, value):
        """Entries whose key is at most ``keyfunc(value)``."""
        key = self.keyfunc(value)
        return self._result_or_none(
            self.sorted_key_list_2.irange_key(
                max_key=key, inclusive=(True, True)
            )
        )

    def gt(self, value):
        """Entries whose key is strictly above ``keyfunc(value)``."""
        key = self.keyfunc(value)
        return self._result_or_none(
            self.sorted_key_list_2.irange_key(
                min_key=key, inclusive=(False, True)
            )
        )

    def gte(self, value):
        """Entries whose key is at least ``keyfunc(value)``."""
        key = self.keyfunc(value)
        return self._result_or_none(
            self.sorted_key_list_2.irange_key(
                min_key=key, inclusive=(True, True)
            )
        )

    def add(self, value, index=None):
        """
        Insert ``value`` with ``index``.

        When ``index`` is None, one more than the largest known index is
        used, or 0 if there are no indices yet.
        """
        if index is None:
            # default=-1 makes the first auto index 0 instead of letting
            # max() raise ValueError on an empty set.
            index = max(self.index, default=-1) + 1
        self.sorted_key_list_2.add((value, index))
        self.index.add(index)
130 |
--------------------------------------------------------------------------------
/paramtools/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PSLmodels/ParamTools/3a9094f50b45b5ec7324d8ffa9d6ef7070957452/paramtools/tests/__init__.py
--------------------------------------------------------------------------------
/paramtools/tests/defaults.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": {
3 | "labels": {
4 | "label0": {
5 | "type": "str",
6 | "validators": {
7 | "choice": {
8 | "choices": [
9 | "zero",
10 | "one"
11 | ]
12 | }
13 | }
14 | },
15 | "label1": {
16 | "type": "int",
17 | "validators": {
18 | "range": {
19 | "min": 0,
20 | "max": 5
21 | }
22 | }
23 | },
24 | "label2": {
25 | "type": "int",
26 | "validators": {
27 | "range": {
28 | "min": 0,
29 | "max": 2
30 | }
31 | }
32 | }
33 | },
34 | "additional_members": {
35 | "opt0": {
36 | "type": "str"
37 | }
38 | }
39 | },
40 | "when_param": {
41 | "title": "When validator reference param",
42 | "description": "Example for using 'when' validator",
43 | "type": "int",
44 | "value": 0,
45 | "validators": {
46 | "when": {
47 | "param": "str_choice_param",
48 | "is": "value0",
49 | "then": {
50 | "range": {
51 | "min": 0,
52 | "max": 0
53 | }
54 | },
55 | "otherwise": {
56 | "choice": {
57 | "choices": [
58 | 0,
59 | 2,
60 | 5,
61 | 7
62 | ]
63 | }
64 | }
65 | }
66 | }
67 | },
68 | "when_array_param": {
69 | "title": "When validator reference array param",
70 | "description": "Example for using 'when' validator with an array param",
71 | "type": "int",
72 | "number_dims": 1,
73 | "value": [
74 | 0,
75 | 1,
76 | 2,
77 | 5
78 | ],
79 | "validators": {
80 | "when": {
81 | "param": "simple_int_list_param",
82 | "is": 2,
83 | "then": {
84 | "choice": {
85 | "choices": [
86 | 1
87 | ]
88 | }
89 | },
90 | "otherwise": {
91 | "choice": {
92 | "choices": [
93 | 0,
94 | 2,
95 | 5
96 | ]
97 | }
98 | }
99 | }
100 | }
101 | },
102 | "float_param": {
103 | "title": "Float Reference Param",
104 | "description": "Example for a float param.",
105 | "opt0": "an option",
106 | "type": "float",
107 | "value": 2.0,
108 | "validators": {}
109 | },
110 | "bool_param": {
111 | "title": "Boolean Reference Param",
112 | "description": "Example for a bool param.",
113 | "opt0": "an option",
114 | "type": "bool",
115 | "value": true,
116 | "validators": {}
117 | },
118 | "min_int_param": {
119 | "title": "min integer parameter",
120 | "description": "Serves as minimum reference variable.",
121 | "notes": "See max_int_param",
122 | "opt0": "an option",
123 | "type": "int",
124 | "value": [
125 | {
126 | "label0": "zero",
127 | "label1": 1,
128 | "value": 1
129 | },
130 | {
131 | "label0": "one",
132 | "label1": 2,
133 | "value": 2
134 | }
135 | ],
136 | "validators": {
137 | "range": {
138 | "min": 0,
139 | "max": "max_int_param"
140 | }
141 | }
142 | },
143 | "max_int_param": {
144 | "title": "max integer parameter",
145 | "description": "Serves as maximum reference variable.",
146 | "notes": "See min_int_param",
147 | "opt0": "an option",
148 | "type": "int",
149 | "value": [
150 | {
151 | "label0": "zero",
152 | "label1": 1,
153 | "value": 3
154 | },
155 | {
156 | "label0": "one",
157 | "label1": 2,
158 | "value": 4
159 | }
160 | ],
161 | "validators": {
162 | "range": {
163 | "min": "min_int_param",
164 | "max": 10
165 | }
166 | }
167 | },
168 | "str_choice_param": {
169 | "title": "String Choice Param",
170 | "description": "Example for string type params using a choice validator",
171 | "opt0": "another option",
172 | "type": "str",
173 | "value": "value0",
174 | "validators": {
175 | "choice": {
176 | "choices": [
177 | "value0",
178 | "value1"
179 | ]
180 | }
181 | }
182 | },
183 | "date_param": {
184 | "title": "Date parameter",
185 | "description": "Example for a date parameter",
186 | "opt0": "another option",
187 | "type": "date",
188 | "value": [
189 | {
190 | "label0": "zero",
191 | "label1": 1,
192 | "value": "2018-01-15"
193 | }
194 | ],
195 | "validators": {
196 | "date_range": {
197 | "min": "2018-01-01",
198 | "max": "2018-12-31"
199 | }
200 | }
201 | },
202 | "date_min_param": {
203 | "title": "Date Min Param",
204 | "description": "Serves as minimum reference variable.",
205 | "notes": "See date_max_param.",
206 | "opt0": "an option",
207 | "type": "date",
208 | "value": [
209 | {
210 | "label0": "zero",
211 | "label1": 1,
212 | "value": "2018-01-15"
213 | }
214 | ],
215 | "validators": {
216 | "date_range": {
217 | "min": "2018-01-01",
218 | "max": "date_max_param"
219 | }
220 | }
221 | },
222 | "date_max_param": {
223 | "title": "Date max parameter",
224 | "description": "Serves as maximum reference variable.",
225 | "notes": "See date_min_param.",
226 | "opt0": "an option",
227 | "type": "date",
228 | "value": [
229 | {
230 | "label0": "zero",
231 | "label1": 1,
232 | "value": "2018-01-15"
233 | }
234 | ],
235 | "validators": {
236 | "date_range": {
237 | "min": "date_min_param",
238 | "max": "2018-12-31"
239 | }
240 | }
241 | },
242 | "float_list_when_param": {
243 | "title": "Float List When Param",
244 | "description": "Reference for a float list param, using a when validator.",
245 | "opt0": "an option",
246 | "type": "float",
247 | "number_dims": 1,
248 | "value": [
249 | {
250 | "label0": "zero",
251 | "value": [
252 | 0,
253 | 2.0,
254 | 3.0,
255 | 4.0
256 | ]
257 | }
258 | ],
259 | "validators": {
260 | "when": {
261 | "param": "float_list_param",
262 | "is": {
263 | "greater_than": 1
264 | },
265 | "then": {
266 | "range": {
267 | "min": 1
268 | }
269 | },
270 | "otherwise": {
271 | "range": {
272 | "min": 0
273 | }
274 | }
275 | }
276 | }
277 | },
278 | "float_list_param": {
279 | "title": "Float List Param",
280 | "description": "Example for a float, list param.",
281 | "opt0": "an option",
282 | "type": "float",
283 | "number_dims": 1,
284 | "value": [
285 | {
286 | "label0": "zero",
287 | "label1": 1,
288 | "value": [
289 | 1,
290 | 2.0,
291 | 3.5,
292 | 4.6
293 | ]
294 | }
295 | ],
296 | "validators": {
297 | "range": {
298 | "min": 0,
299 | "max": 10
300 | }
301 | }
302 | },
303 | "simple_int_list_param": {
304 | "title": "Simple Int List Param",
305 | "description": "Test case where param is simple and a list.",
306 | "opt0": "an option",
307 | "type": "int",
308 | "number_dims": 1,
309 | "value": [
310 | 1,
311 | 2,
312 | 3,
313 | 4
314 | ],
315 | "validators": {
316 | "range": {
317 | "min": 0,
318 | "max": 10
319 | }
320 | }
321 | },
322 | "int_default_param": {
323 | "title": "Integer Default Reference Param",
324 |         "description": "Example for an int param using a default reference value",
325 | "opt0": "an option",
326 | "type": "int",
327 | "value": 2,
328 | "validators": {
329 | "range": {
330 | "min": "default",
331 | "max": 10
332 | }
333 | }
334 | },
335 | "int_dense_array_param": {
336 | "title": "Integer Dense Array Param",
337 | "description": "Example of using an int type param that supports to/from_array.",
338 | "opt0": "an option",
339 | "type": "int",
340 | "value": [
341 | {
342 | "label0": "zero",
343 | "label1": 0,
344 | "label2": 0,
345 | "value": 1
346 | },
347 | {
348 | "label0": "zero",
349 | "label1": 0,
350 | "label2": 1,
351 | "value": 2
352 | },
353 | {
354 | "label0": "zero",
355 | "label1": 0,
356 | "label2": 2,
357 | "value": 3
358 | },
359 | {
360 | "label0": "zero",
361 | "label1": 1,
362 | "label2": 0,
363 | "value": 4
364 | },
365 | {
366 | "label0": "zero",
367 | "label1": 1,
368 | "label2": 1,
369 | "value": 5
370 | },
371 | {
372 | "label0": "zero",
373 | "label1": 1,
374 | "label2": 2,
375 | "value": 6
376 | },
377 | {
378 | "label0": "zero",
379 | "label1": 2,
380 | "label2": 0,
381 | "value": 7
382 | },
383 | {
384 | "label0": "zero",
385 | "label1": 2,
386 | "label2": 1,
387 | "value": 8
388 | },
389 | {
390 | "label0": "zero",
391 | "label1": 2,
392 | "label2": 2,
393 | "value": 9
394 | },
395 | {
396 | "label0": "zero",
397 | "label1": 3,
398 | "label2": 0,
399 | "value": 10
400 | },
401 | {
402 | "label0": "zero",
403 | "label1": 3,
404 | "label2": 1,
405 | "value": 11
406 | },
407 | {
408 | "label0": "zero",
409 | "label1": 3,
410 | "label2": 2,
411 | "value": 12
412 | },
413 | {
414 | "label0": "zero",
415 | "label1": 4,
416 | "label2": 0,
417 | "value": 13
418 | },
419 | {
420 | "label0": "zero",
421 | "label1": 4,
422 | "label2": 1,
423 | "value": 14
424 | },
425 | {
426 | "label0": "zero",
427 | "label1": 4,
428 | "label2": 2,
429 | "value": 15
430 | },
431 | {
432 | "label0": "zero",
433 | "label1": 5,
434 | "label2": 0,
435 | "value": 16
436 | },
437 | {
438 | "label0": "zero",
439 | "label1": 5,
440 | "label2": 1,
441 | "value": 17
442 | },
443 | {
444 | "label0": "zero",
445 | "label1": 5,
446 | "label2": 2,
447 | "value": 18
448 | },
449 | {
450 | "label0": "one",
451 | "label1": 0,
452 | "label2": 0,
453 | "value": 19
454 | },
455 | {
456 | "label0": "one",
457 | "label1": 0,
458 | "label2": 1,
459 | "value": 20
460 | },
461 | {
462 | "label0": "one",
463 | "label1": 0,
464 | "label2": 2,
465 | "value": 21
466 | },
467 | {
468 | "label0": "one",
469 | "label1": 1,
470 | "label2": 0,
471 | "value": 22
472 | },
473 | {
474 | "label0": "one",
475 | "label1": 1,
476 | "label2": 1,
477 | "value": 23
478 | },
479 | {
480 | "label0": "one",
481 | "label1": 1,
482 | "label2": 2,
483 | "value": 24
484 | },
485 | {
486 | "label0": "one",
487 | "label1": 2,
488 | "label2": 0,
489 | "value": 25
490 | },
491 | {
492 | "label0": "one",
493 | "label1": 2,
494 | "label2": 1,
495 | "value": 26
496 | },
497 | {
498 | "label0": "one",
499 | "label1": 2,
500 | "label2": 2,
501 | "value": 27
502 | },
503 | {
504 | "label0": "one",
505 | "label1": 3,
506 | "label2": 0,
507 | "value": 28
508 | },
509 | {
510 | "label0": "one",
511 | "label1": 3,
512 | "label2": 1,
513 | "value": 29
514 | },
515 | {
516 | "label0": "one",
517 | "label1": 3,
518 | "label2": 2,
519 | "value": 30
520 | },
521 | {
522 | "label0": "one",
523 | "label1": 4,
524 | "label2": 0,
525 | "value": 31
526 | },
527 | {
528 | "label0": "one",
529 | "label1": 4,
530 | "label2": 1,
531 | "value": 32
532 | },
533 | {
534 | "label0": "one",
535 | "label1": 4,
536 | "label2": 2,
537 | "value": 33
538 | },
539 | {
540 | "label0": "one",
541 | "label1": 5,
542 | "label2": 0,
543 | "value": 34
544 | },
545 | {
546 | "label0": "one",
547 | "label1": 5,
548 | "label2": 1,
549 | "value": 35
550 | },
551 | {
552 | "label0": "one",
553 | "label1": 5,
554 | "label2": 2,
555 | "value": 36
556 | }
557 | ],
558 | "validators": {
559 | "range": {
560 | "min": 1,
561 | "max": 36
562 | }
563 | }
564 | },
565 | "str_choice_warn_param": {
566 | "title": "String Choice Warnings Param",
567 | "description": "Example for string type params using a choice validator with warnings",
568 | "opt0": "another option",
569 | "type": "str",
570 | "value": "value0",
571 | "validators": {
572 | "choice": {
573 | "choices": [
574 | "value0",
575 | "value1"
576 | ],
577 | "level": "warn"
578 | }
579 | }
580 | },
581 | "int_warn_param": {
582 | "title": "Integer Parameter that uses warnings",
583 |         "description": "Example for an int param using warnings with a range validator",
584 | "opt0": "an option",
585 | "type": "int",
586 | "value": 2,
587 | "validators": {
588 | "range": {
589 | "min": 0,
590 | "max": 10,
591 | "level": "warn"
592 | }
593 | }
594 | }
595 | }
--------------------------------------------------------------------------------
/paramtools/tests/extend_ex.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": {
3 | "labels": {
4 | "d0": {
5 | "type": "int",
6 | "validators": {"range": {"min": 0, "max": 10}}
7 | },
8 | "d1": {
9 | "type": "str",
10 | "validators": {
11 | "choice": {"choices": ["c1", "c2"]}
12 | }
13 | }
14 | }
15 | },
16 | "extend_param": {
17 | "title": "extend param",
18 | "description": ".",
19 | "type": "int",
20 | "value": [
21 | {"d0": 2, "d1": "c1", "value": 1},
22 | {"d0": 2, "d1": "c2", "value": 2},
23 | {"d0": 3, "d1": "c1", "value": 3},
24 | {"d0": 3, "d1": "c2", "value": 4},
25 | {"d0": 5, "d1": "c1", "value": 5},
26 | {"d0": 5, "d1": "c2", "value": 6},
27 | {"d0": 7, "d1": "c1", "value": 7},
28 | {"d0": 7, "d1": "c2", "value": 8}
29 | ],
30 | "validators": {
31 | "range": {
32 | "min": -100, "max": "related_param"
33 | }
34 | }
35 | },
36 | "indexed_param": {
37 | "title": "indexed param",
38 | "description": ".",
39 | "type": "float",
40 | "indexed": true,
41 | "value": [
42 | {"d0": 2, "d1": "c1", "value": 1},
43 | {"d0": 2, "d1": "c2", "value": 2},
44 | {"d0": 3, "d1": "c1", "value": 3},
45 | {"d0": 3, "d1": "c2", "value": 4},
46 | {"d0": 5, "d1": "c1", "value": 5},
47 | {"d0": 5, "d1": "c2", "value": 6},
48 | {"d0": 7, "d1": "c1", "value": 7},
49 | {"d0": 7, "d1": "c2", "value": 8}
50 | ],
51 | "validators": {
52 | "range": {
53 | "min": -100, "max": "related_param"
54 | }
55 | }
56 | },
57 | "related_param": {
58 | "title": "related param",
59 | "description": "Test error on adjustment extension.",
60 | "type": "int",
61 | "value": [
62 | {"d0": 0, "d1": "c1", "value": 100},
63 | {"d0": 0, "d1": "c2", "value": 101},
64 | {"d0": 7, "d1": "c1", "value": 50},
65 | {"d0": 7, "d1": "c2", "value": 51}
66 | ]
67 | },
68 | "nonextend_param": {
69 | "title": "nonextend param",
70 | "description": "Test error on adjustment extension.",
71 | "type": "int",
72 | "value": 2
73 | }
74 | }
--------------------------------------------------------------------------------
/paramtools/tests/test_examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PSLmodels/ParamTools/3a9094f50b45b5ec7324d8ffa9d6ef7070957452/paramtools/tests/test_examples/__init__.py
--------------------------------------------------------------------------------
/paramtools/tests/test_examples/test_baseball.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from paramtools import parameters
6 |
7 | CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
8 |
9 |
10 | @pytest.fixture
11 | def field_map():
12 | # nothing here for now
13 | return {}
14 |
15 |
16 | @pytest.fixture
17 | def defaults_spec_path():
18 | return os.path.join(CURRENT_PATH, "../../examples/baseball/defaults.json")
19 |
20 |
21 | @pytest.fixture
22 | def BaseballParams(defaults_spec_path):
23 |     """Build a Parameters subclass that loads the baseball example defaults.
24 |
25 |     A new class object is created on every fixture invocation.
26 |     """
27 |     class _BaseballParams(parameters.Parameters):
28 |         defaults = defaults_spec_path
29 |
30 |     return _BaseballParams
27 |
28 |
29 | def test_load_schema(BaseballParams):
30 | params = BaseballParams()
31 | assert params
32 |
--------------------------------------------------------------------------------
/paramtools/tests/test_examples/test_behresp.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from paramtools import Parameters
6 |
7 | CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
8 |
9 |
10 | @pytest.fixture
11 | def field_map():
12 | # nothing here for now
13 | return {}
14 |
15 |
16 | @pytest.fixture
17 | def defaults_spec_path():
18 | return os.path.join(CURRENT_PATH, "../../examples/behresp/defaults.json")
19 |
20 |
21 | @pytest.fixture
22 | def BehrespParams(defaults_spec_path):
23 |     """Build a Parameters subclass that loads the behresp example defaults.
24 |
25 |     A new class object is created on every fixture invocation.
26 |     """
27 |     class _BehrespParams(Parameters):
28 |         defaults = defaults_spec_path
29 |
30 |     return _BehrespParams
27 |
28 |
29 | def test_load_schema(BehrespParams):
30 | params = BehrespParams()
31 | assert params
32 |
--------------------------------------------------------------------------------
/paramtools/tests/test_examples/test_tc_ex.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from marshmallow import fields, Schema
6 |
7 | from paramtools import Parameters, register_custom_type
8 |
9 |
10 | CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
11 |
12 |
13 | class CompatibleDataSchema(Schema):
14 | """
15 | Schema for Compatible data object
16 | {
17 | "compatible_data": {"data1": bool, "data2": bool, ...}
18 | }
19 | """
20 |
21 | puf = fields.Boolean()
22 | cps = fields.Boolean()
23 |
24 |
25 | @pytest.fixture
26 | def register_compatible_data():
27 | register_custom_type(
28 | "compatible_data", fields.Nested(CompatibleDataSchema())
29 | )
30 |
31 |
32 | @pytest.fixture
33 | def defaults_spec_path():
34 | return os.path.join(CURRENT_PATH, "../../examples/taxparams/defaults.json")
35 |
36 |
37 | @pytest.fixture
38 | def TaxcalcParams(defaults_spec_path, register_compatible_data):
39 |     """Build a Parameters subclass that loads the taxparams example defaults.
40 |
41 |     ``register_compatible_data`` is requested only for its side effect of
42 |     registering the ``compatible_data`` custom type; its value is unused.
43 |     """
44 |     class _TaxcalcParams(Parameters):
45 |         defaults = defaults_spec_path
46 |
47 |     return _TaxcalcParams
43 |
44 |
45 | def test_load_schema(TaxcalcParams):
46 | params = TaxcalcParams()
47 | assert params
48 |
49 |
50 | @pytest.fixture
51 | def demo_defaults_spec_path():
52 | return os.path.join(
53 | CURRENT_PATH, "../../examples/taxparams-demo/defaults.json"
54 | )
55 |
56 |
57 | @pytest.fixture
58 | def TaxDemoParams(demo_defaults_spec_path):
59 |     """Build a Parameters subclass that loads the taxparams-demo defaults.
60 |
61 |     Unlike TaxcalcParams, this does not request register_compatible_data
62 |     (presumably the demo spec does not use that custom type).
63 |     """
64 |     class _TaxDemoParams(Parameters):
65 |         defaults = demo_defaults_spec_path
66 |
67 |     return _TaxDemoParams
63 |
64 |
65 | def test_load_demo_schema(TaxDemoParams):
66 | params = TaxDemoParams()
67 | assert params
68 |
--------------------------------------------------------------------------------
/paramtools/tests/test_fields.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import numpy as np
4 |
5 | from paramtools.contrib import fields, validate
6 |
7 |
8 | def test_np_value_fields():
9 | float64 = fields.Float64()
10 | res = float64._deserialize("2", None, None)
11 | assert res == 2.0
12 | assert isinstance(res, np.float64)
13 | assert type(float64._serialize(res, None, None)) == float
14 |
15 | int64 = fields.Int64()
16 | res = int64._deserialize("2", None, None)
17 | assert res == 2
18 | assert isinstance(res, np.int64)
19 | assert type(int64._serialize(res, None, None)) == int
20 |
21 | bool_ = fields.Bool_()
22 | res = bool_._deserialize("true", None, None)
23 | assert res is np.bool_(True)
24 | assert isinstance(res, np.bool_)
25 | assert bool_._serialize(res, None, None) is True
26 |
27 |
28 | def test_contrib_fields():
29 | range_validator = validate.Range(0, 10)
30 | daterange_validator = validate.DateRange(
31 | "2019-01-01", "2019-01-05", step={"days": 2}
32 | )
33 | choice_validator = validate.OneOf(choices=["one", "two"])
34 |
35 | s = fields.Str(validate=[choice_validator])
36 | assert s.grid() == ["one", "two"]
37 | s = fields.Str()
38 | assert s.grid() == []
39 |
40 | s = fields.Integer(validate=[range_validator])
41 | assert s.grid() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
42 | s = fields.Str()
43 | assert s.grid() == []
44 |
45 | # date will need an interval argument.
46 | s = fields.Date(validate=[daterange_validator])
47 | assert s.grid() == [datetime.date(2019, 1, i) for i in range(1, 6, 2)]
48 |
49 | s = fields.Date()
50 | assert s._deserialize(datetime.date(2015, 1, 1), None, None)
51 |
52 |
53 | def test_cmp_funcs():
54 | range_validator = validate.Range(0, 10)
55 | daterange_validator = validate.DateRange(
56 | "2019-01-01", "2019-01-05", step={"days": 2}
57 | )
58 | choice_validator = validate.OneOf(choices=["one", "two"])
59 |
60 | cases = [
61 | ("one", "two", fields.Str(validate=[choice_validator])),
62 | (
63 | datetime.date(2019, 1, 2),
64 | datetime.date(2019, 1, 3),
65 | fields.Date(validate=[daterange_validator]),
66 | ),
67 | (2, 5, fields.Integer(validate=[range_validator])),
68 | ]
69 |
70 | for (min_, max_, field) in cases:
71 | cmp_funcs = field.cmp_funcs()
72 | assert cmp_funcs["gt"](min_, max_) is False
73 | assert cmp_funcs["lt"](min_, max_) is True
74 | assert cmp_funcs["eq"](min_, max_) is False
75 | assert cmp_funcs["eq"](max_, max_) is True
76 | assert cmp_funcs["lte"](max_, max_) is True
77 | assert cmp_funcs["lte"](min_, max_) is True
78 | assert cmp_funcs["gte"](max_, min_) is True
79 |
--------------------------------------------------------------------------------
/paramtools/tests/test_schema.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import pytest
4 | import marshmallow as ma
5 |
6 | from paramtools import (
7 | get_type,
8 | get_param_schema,
9 | register_custom_type,
10 | ALLOWED_TYPES,
11 | UnknownTypeException,
12 | PartialField,
13 | ParamToolsError,
14 | )
15 |
16 |
17 | def test_get_type_with_list():
18 | int_field = get_type({"type": "int"})
19 |
20 | list_int_field = get_type({"type": "int", "number_dims": 1})
21 | assert list_int_field.np_type == int_field.np_type
22 |
23 | list_int_field = get_type({"type": "int", "number_dims": 2})
24 | assert list_int_field.np_type == int_field.np_type
25 |
26 |
27 | def test_register_custom_type():
28 |     """
29 |     Registering marshmallow field instances and PartialField instances is
30 |     allowed; registering an uninitialized field class or an arbitrary
31 |     object raises TypeError.
32 |
33 |     NOTE(review): the first assertion assumes "custom" has not been
34 |     registered yet. Registrations appear to persist in module-level state
35 |     (ALLOWED_TYPES), and test_get_type/test_make_schema also register
36 |     "custom", so this test is sensitive to test ordering.
37 |     """
38 |     custom_type = "custom"
39 |     assert custom_type not in ALLOWED_TYPES
40 |     register_custom_type(custom_type, ma.fields.String())
41 |     assert custom_type in ALLOWED_TYPES
42 |
43 |     register_custom_type("partial-test", PartialField(ma.fields.Str(), {}))
44 |     assert "partial-test" in ALLOWED_TYPES
45 |
46 |     # A field class (not an instance) is rejected.
47 |     with pytest.raises(TypeError):
48 |         register_custom_type("custom", ma.fields.Str)
49 |
50 |     class Whatever:
51 |         pass
52 |
53 |     # Objects that are neither marshmallow fields nor PartialFields are rejected.
54 |     with pytest.raises(TypeError):
55 |         register_custom_type("whatever", Whatever())
48 |
49 |
50 | def test_get_type():
51 | custom_type = "custom"
52 | register_custom_type(custom_type, ma.fields.String())
53 |
54 | assert isinstance(get_type({"type": custom_type}), ma.fields.String)
55 |
56 | with pytest.raises(UnknownTypeException):
57 | get_type({"type": "unknown"})
58 |
59 |
60 | def test_make_schema():
61 | custom_type = "custom"
62 | register_custom_type(custom_type, ma.fields.String())
63 |
64 | schema = {
65 | "labels": {"lab": {"type": custom_type, "validators": {}}},
66 | "additional_members": {"custom": {"type": custom_type}},
67 | }
68 |
69 | assert get_param_schema(schema)
70 |
71 | bad_schema = copy.deepcopy(schema)
72 | bad_schema["labels"]["lab"]["type"] = "unknown"
73 | with pytest.raises(UnknownTypeException):
74 | get_param_schema(bad_schema)
75 |
76 | bad_schema = copy.deepcopy(schema)
77 | bad_schema["additional_members"]["custom"]["type"] = "unknown"
78 | with pytest.raises(UnknownTypeException):
79 | get_param_schema(bad_schema)
80 |
81 | schema = {
82 | "labels": {
83 | "lab": {
84 | "type": custom_type,
85 | "validators": {"choice": {"choices": ["hello"]}},
86 | }
87 | },
88 | "additional_members": {"custom": {"type": custom_type}},
89 | }
90 | with pytest.raises(ParamToolsError):
91 | get_param_schema(schema)
92 |
--------------------------------------------------------------------------------
/paramtools/tests/test_select.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from paramtools.select import (
4 | select_eq,
5 | select_ne,
6 | select_gt,
7 | select_gte,
8 | select_lt,
9 | select_lte,
10 | )
11 |
12 |
13 | @pytest.fixture
14 | def vos():
15 | return [
16 | {"d0": 1, "d1": "hello", "value": 1},
17 | {"d0": 1, "d1": "world", "value": 1},
18 | {"d0": 2, "d1": "hello", "value": 1},
19 | {"d0": 3, "d1": "world", "value": 1},
20 | ]
21 |
22 |
23 | def test_select_eq(vos):
24 | assert list(select_eq(vos, False, labels={"d0": 1, "d1": "hello"})) == [
25 | {"d0": 1, "d1": "hello", "value": 1}
26 | ]
27 |
28 | assert list(
29 | select_eq(vos, False, labels={"d0": [1, 2], "d1": "hello"})
30 | ) == [
31 | {"d0": 1, "d1": "hello", "value": 1},
32 | {"d0": 2, "d1": "hello", "value": 1},
33 | ]
34 |
35 |
36 | def test_select_eq_strict(vos):
37 |     """select_eq with strict=True on fully-labeled value objects.
38 |
39 |     NOTE(review): the last check passes strict=False; per its expected output,
40 |     non-strict selection keeps value objects that lack the queried "_auto"
41 |     label (vos[0] and vos[1]).
42 |     """
43 |     assert list(select_eq(vos, True, labels={"d0": 1, "d1": "hello"})) == [
44 |         {"d0": 1, "d1": "hello", "value": 1}
45 |     ]
46 |
47 |     assert list(
48 |         select_eq(vos, True, labels={"d0": [1, 2], "d1": "hello"})
49 |     ) == [
50 |         {"d0": 1, "d1": "hello", "value": 1},
51 |         {"d0": 2, "d1": "hello", "value": 1},
52 |     ]
53 |
54 |     # Tag only the last two value objects with an "_auto" label. The fixture
55 |     # is function-scoped, so this mutation does not leak into other tests.
56 |     vos[2]["_auto"] = True
57 |     vos[3]["_auto"] = True
58 |     assert list(select_eq(vos, False, labels={"_auto": False})) == [
59 |         {"d0": 1, "d1": "hello", "value": 1},
60 |         {"d0": 1, "d1": "world", "value": 1},
61 |     ]
54 |
55 |
56 | def test_select_ne(vos):
57 | assert list(select_ne(vos, False, labels={"d0": 1, "d1": "hello"})) == [
58 | {"d0": 3, "d1": "world", "value": 1}
59 | ]
60 |
61 | assert list(select_ne(vos, False, labels={"d0": [2, 3]})) == [
62 | {"d0": 1, "d1": "hello", "value": 1},
63 | {"d0": 1, "d1": "world", "value": 1},
64 | ]
65 |
66 |
67 | def test_select_gt(vos):
68 | assert list(select_gt(vos, False, labels={"d0": 1})) == [
69 | {"d0": 2, "d1": "hello", "value": 1},
70 | {"d0": 3, "d1": "world", "value": 1},
71 | ]
72 |
73 |
74 | def test_select_gte(vos):
75 | assert list(select_gte(vos, False, labels={"d0": 2})) == [
76 | {"d0": 2, "d1": "hello", "value": 1},
77 | {"d0": 3, "d1": "world", "value": 1},
78 | ]
79 |
80 |
81 | def test_select_lt(vos):
82 | assert list(select_lt(vos, False, labels={"d0": 3})) == [
83 | {"d0": 1, "d1": "hello", "value": 1},
84 | {"d0": 1, "d1": "world", "value": 1},
85 | {"d0": 2, "d1": "hello", "value": 1},
86 | ]
87 |
88 |
89 | def test_select_lte(vos):
90 | assert list(select_lte(vos, False, labels={"d0": 2})) == [
91 | {"d0": 1, "d1": "hello", "value": 1},
92 | {"d0": 1, "d1": "world", "value": 1},
93 | {"d0": 2, "d1": "hello", "value": 1},
94 | ]
95 |
--------------------------------------------------------------------------------
/paramtools/tests/test_sorted_key_list.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from paramtools.sorted_key_list import SortedKeyList, SortedKeyListException
3 |
4 |
5 | def test_sorted_key_list():
6 |     """Exercise SortedKeyList queries (eq/gt/gte/lt/lte/ne) as items are added."""
7 |     # Sort key for each item. Only four are inserted up front; "black",
8 |     # "white" and "green" are added later to probe both ends and the middle.
9 |     values = {
10 |         "red": 2,
11 |         "blue": 3,
12 |         "orange": 5,
13 |         "white": 6,
14 |         "yellow": 7,
15 |         "green": 9,
16 |         "black": 0,
17 |     }
18 |
19 |     to_add = ["red", "blue", "orange", "yellow"]
20 |
21 |     skl = SortedKeyList(
22 |         to_add, keyfunc=lambda x: values[x], index=list(range(len(to_add)))
23 |     )
24 |
25 |     # "black" (key 0) sorts below every stored item. Queries with no match
26 |     # return None rather than an empty result.
27 |     assert skl.eq("black") is None
28 |     assert skl.gte("black").values[0] == "red"
29 |     assert skl.lte("black") is None
30 |     skl.add("black")
31 |     assert skl.gte("black").values[0] == "black"
32 |     assert skl.lte("black").values[-1] == "black"
33 |     assert skl.eq("black").values == ["black"]
34 |
35 |     # "white" (key 6) falls between "orange" (5) and "yellow" (7).
36 |     assert skl.gte("white").values[0] == "yellow"
37 |     assert skl.lte("white").values[-1] == "orange"
38 |     skl.add("white")
39 |     assert skl.gte("white").values[0] == "white"
40 |     assert skl.gt("white").values[0] == "yellow"
41 |     assert skl.lte("white").values[-1] == "white"
42 |     assert skl.lt("yellow").values[-1] == "white"
43 |
44 |     # "green" (key 9) sorts above every stored item.
45 |     assert skl.gte("green") is None
46 |     assert skl.lte("green").values[-1] == "yellow"
47 |     skl.add("green")
48 |     assert skl.gte("green").values[0] == "green"
49 |     assert skl.lte("green").values[-1] == "green"
50 |
51 |     # Duplicate keys are kept: a second "green" yields two matches.
52 |     skl.add("green")
53 |     assert skl.eq("green").values == ["green", "green"]
54 |
55 |     assert set(skl.ne("green").values) == set(list(values.keys())) - {"green"}
56 |     # The query item only needs a key from keyfunc: "pokadot" (key -1) is
57 |     # never added, so ne() returns every stored item (compared as a set).
58 |     values["pokadot"] = -1
59 |     assert set(skl.ne("pokadot").values) == set(list(values.keys())) - {
60 |         "pokadot"
61 |     }
52 |
53 |
54 | def test_exception():
55 | with pytest.raises(SortedKeyListException):
56 | SortedKeyList(
57 | [
58 | {"really": {"nested": {"field": True}}},
59 | {"really": {"nested": {"field": False}}},
60 | ],
61 | keyfunc=lambda x: x,
62 | )
63 |
--------------------------------------------------------------------------------
/paramtools/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | from paramtools import (
5 | get_leaves,
6 | ravel,
7 | consistent_labels,
8 | ensure_value_object,
9 | hashable_value_object,
10 | filter_labels,
11 | make_label_str,
12 | read_json,
13 | )
14 |
15 |
16 | class TestRead:
17 |     """read_json across input kinds: remote URLs, local paths, raw JSON text,
18 |     comment stripping, and invalid inputs.
19 |
20 |     Tests marked ``network_bound`` (a custom pytest marker) need network access.
21 |     """
22 |
23 |     @pytest.mark.network_bound
24 |     def test_read_s3(self):
25 |         res = read_json("s3://paramtools-test/defaults.json", {"anon": True})
26 |         assert isinstance(res, dict)
27 |
28 |     # NOTE(review): disabled GCS test kept as a comment; consider
29 |     # @pytest.mark.skip(reason=...) so pytest still reports it.
30 |     # @pytest.mark.network_bound
31 |     # def test_read_gcp(self):
32 |     #     res = read_json("gs://paramtools-dev/defaults.json", {"token": "anon"})
33 |     #     assert isinstance(res, dict)
34 |
35 |     @pytest.mark.network_bound
36 |     def test_read_http(self):
37 |         http_path = (
38 |             "https://raw.githubusercontent.com/PSLmodels/ParamTools/master/"
39 |             "paramtools/tests/defaults.json"
40 |         )
41 |         res = read_json(http_path)
42 |         assert isinstance(res, dict)
43 |
44 |     @pytest.mark.network_bound
45 |     def test_read_github(self):
46 |         gh_path = "github://PSLmodels:ParamTools@master/paramtools/tests/defaults.json"
47 |         res = read_json(gh_path)
48 |         assert isinstance(res, dict)
49 |
50 |     def test_read_file_path(self):
51 |         CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
52 |         defaults_path = os.path.join(CURRENT_PATH, "defaults.json")
53 |         res = read_json(defaults_path)
54 |         assert isinstance(res, dict)
55 |
56 |     def test_read_string(self):
57 |         res = read_json('{"hello": "world"}')
58 |         assert isinstance(res, dict)
59 |
60 |     def test_read_invalid(self):
61 |         # Malformed JSON text (missing closing brace).
62 |         with pytest.raises(ValueError):
63 |             read_json('{"hello": "world"')
64 |
65 |         # A long string that is neither valid JSON nor, presumably, a
66 |         # readable path or URL.
67 |         with pytest.raises(ValueError):
68 |             read_json(f":{['a'] * 200}")
69 |
70 |         # Non-string inputs are rejected with TypeError.
71 |         with pytest.raises(TypeError):
72 |             read_json(("hello", "world"))
73 |
74 |         with pytest.raises(TypeError):
75 |             read_json(None)
76 |
77 |     def test_strip_comments_simple(self):
78 |         """``//`` line comments are stripped before the JSON is parsed."""
79 |         params = """
80 | // my comment
81 | // another
82 | {
83 | "hello": "world"
84 | }
85 | """
86 |         assert read_json(params) == {"hello": "world"}
87 |
88 |     def test_strip_comments_multiline(self):
89 |         """``/* ... */`` block comments spanning lines are stripped."""
90 |         params = """
91 | /* my comment
92 | another
93 | */
94 | {
95 | "hello": "world"
96 | }
97 | """
98 |         assert read_json(params) == {"hello": "world"}
99 |
100 |     def test_strip_comments_ignores_url(self):
101 |         """``//`` inside a string value (e.g. a URL) is not treated as a comment."""
102 |         params = """
103 | // my comment
104 | {
105 | "hello": "http://world"
106 | }
107 | """
108 |         assert read_json(params) == {"hello": "http://world"}
97 |
98 |
99 | def test_get_leaves():
100 | t = {
101 | 0: {"value": {0: ["leaf1", "leaf2"]}, 1: {"value": {0: ["leaf3"]}}},
102 | 1: {
103 | "value": {1: ["leaf4", "leaf5"]},
104 | 2: {"value": {0: ["leaf6", ["leaf7", "leaf8"]]}},
105 | },
106 | }
107 |
108 | leaves = get_leaves(t)
109 | assert leaves == [f"leaf{i}" for i in range(1, 9)]
110 |
111 | leaves = get_leaves([t])
112 | assert leaves == [f"leaf{i}" for i in range(1, 9)]
113 |
114 | leaves = get_leaves({})
115 | assert leaves == []
116 |
117 | leaves = get_leaves([])
118 | assert leaves == []
119 |
120 | leaves = get_leaves("leaf")
121 | assert leaves == ["leaf"]
122 |
123 |
124 | def test_ravel():
125 | a = 1
126 | assert ravel(a) == 1
127 |
128 | b = [1, 2, 3]
129 | assert ravel(b) == [1, 2, 3]
130 |
131 | c = [[1], 2, 3]
132 | assert ravel(c) == [1, 2, 3]
133 |
134 | d = [[1, 2, 3], [4, 5, 6]]
135 | assert ravel(d) == [1, 2, 3, 4, 5, 6]
136 |
137 | e = [0, [1, 2, 3], 4, [5, 6, 7], 8]
138 | assert ravel(e) == [0, 1, 2, 3, 4, 5, 6, 7, 8]
139 |
140 |
141 | def test_consistent_labels():
142 | v = [
143 | {"label0": 1, "label1": 2, "value": 3},
144 | {"label0": 4, "label1": 5, "value": 6},
145 | ]
146 | assert consistent_labels(v) == set(["label0", "label1"])
147 |
148 | v = [{"label0": 1, "value": 3}, {"label0": 4, "label1": 5, "value": 6}]
149 | assert consistent_labels(v) is None
150 |
151 | v = [{"label0": 1, "label1": 2, "value": 3}, {"label0": 4, "value": 6}]
152 | assert consistent_labels(v) is None
153 |
154 |
155 | def test_ensure_value_object():
156 | assert ensure_value_object("hello") == [{"value": "hello"}]
157 | assert ensure_value_object([{"value": "hello"}]) == [{"value": "hello"}]
158 | assert ensure_value_object([1, 2, 3]) == [{"value": [1, 2, 3]}]
159 | assert ensure_value_object([[1, 2, 3]]) == [{"value": [[1, 2, 3]]}]
160 | assert ensure_value_object({"hello": "world"}) == [
161 | {"value": {"hello": "world"}}
162 | ]
163 |
164 |
165 | def test_hashable_value_object():
166 | assert hash(hashable_value_object({"value": "hello", "world": "!"}))
167 |
168 |
169 | def test_filter_labels():
170 | assert filter_labels({"hello": "world"}, drop=["hello"]) == {}
171 | assert filter_labels({"hello": "world"}, keep=["hello"]) == {
172 | "hello": "world"
173 | }
174 | assert filter_labels({"hello": "world"}) == {"hello": "world"}
175 | assert filter_labels(
176 | {"hello": "world", "world": "hello"}, drop=["world"]
177 | ) == {"hello": "world"}
178 |
179 |
180 | def test_make_label_str():
181 | assert make_label_str({"hello": "world", "value": 0}) == "[hello=world]"
182 | assert make_label_str({"value": 0}) == ""
183 | assert make_label_str({}) == ""
184 | assert make_label_str({"b": 0, "c": 1, "a": 2}) == "[a=2, b=0, c=1]"
185 |
--------------------------------------------------------------------------------
/paramtools/tests/test_validate.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import pytest
4 | from marshmallow import ValidationError
5 |
6 | from paramtools.contrib import OneOf, Range, DateRange, When
7 |
8 |
def test_OneOf():
    """OneOf accepts allowed scalars and 1-D/2-D lists and rejects anything else."""
    allowed = ["allowed1", "allowed2"]
    validator = OneOf(choices=allowed)

    assert validator("allowed1") == "allowed1"
    assert validator(allowed) == allowed
    assert validator([allowed]) == [allowed]

    rejected_inputs = [
        "notallowed",
        ["notallowed", "allowed1"],
        [[allowed]],  # no support for 3-D arrays yet.
    ]
    for bad in rejected_inputs:
        with pytest.raises(ValidationError):
            validator(bad)

    assert validator("allowed1")
    assert validator({"value": "allowed1"}, is_value_object=True)

    with pytest.raises(ValidationError):
        validator("notallowed")

    with pytest.raises(ValidationError):
        validator({"value": "notallowed"}, is_value_object=True)
35 |
36 |
def test_Range_errors():
    """Range rejects out-of-range values and formats labelled error messages."""
    plain_ranges = (
        Range(0, 10),
        Range(min_vo=[{"value": 0}], max_vo=[{"value": 10}]),
    )
    for range_ in plain_ranges:
        with pytest.raises(ValidationError):
            range_(11)
        with pytest.raises(ValidationError):
            range_({"value": 11}, is_value_object=True)

    labelled = Range(
        min_vo=[{"lab0": 1, "value": 0}, {"lab0": 2, "value": 2}],
        max_vo=[{"lab0": 1, "value": 10}, {"lab0": 2, "value": 9}],
        error_min="param{labels} {input} < min {min} oth_param{oth_labels}",
        error_max="param{labels} {input} > max {max} max_oth_param{oth_labels}",
    )

    # A labelled input is only compared against the matching bound.
    with pytest.raises(ValidationError) as excinfo:
        labelled({"lab0": 1, "value": 11}, is_value_object=True)
    first_message = excinfo.value.args[0][0]
    assert first_message == "param[lab0=1] 11 > max 10 max_oth_param[lab0=1]"

    # An unlabelled input is compared against every bound.
    with pytest.raises(ValidationError) as excinfo:
        labelled({"value": 11}, is_value_object=True)
    expected_messages = [
        "param 11 > max 10 max_oth_param[lab0=1]",
        "param 11 > max 9 max_oth_param[lab0=2]",
    ]
    assert excinfo.value.args[0] == expected_messages
71 |
72 |
def test_Range_grid():
    """Range.grid enumerates the range, honouring an optional step."""
    assert Range(0, 10).grid() == list(range(0, 11))
    assert Range(0, 10, step=3).grid() == [0, 3, 6, 9]
79 |
80 |
def test_DateRange():
    """DateRange builds day grids and validates dates inside/outside the range."""
    every_day = [datetime.date(2019, 1, day) for day in range(1, 11)]
    grid_cases = [
        ({"step": {"days": 1}}, every_day),
        ({}, every_day),  # default step is one day
        ({"step": {"days": 3}}, every_day[::3]),
    ]
    for kwargs, expected in grid_cases:
        assert DateRange("2019-01-01", "2019-01-10", **kwargs).grid() == expected

    validators = [
        DateRange("2019-01-01", "2019-01-10", step={"days": 1}),
        DateRange(
            min_vo=[{"value": "2019-01-01"}],
            max_vo=[{"value": "2019-01-10"}],
            step={"days": 1},
        ),
    ]
    inside = datetime.date(2019, 1, 2)
    outside = datetime.date(2020, 1, 2)
    for drange in validators:
        assert drange(inside)
        assert drange({"value": inside}, is_value_object=True)

        with pytest.raises(ValidationError):
            drange(outside)

        with pytest.raises(ValidationError):
            drange({"value": outside}, is_value_object=True)
113 |
114 |
def test_When():
    """When applies then_validators or otherwise_validators depending on the condition."""
    range_ = Range(0, 10)
    oneof = OneOf(choices=[12, 15])

    def make_when(equal_to):
        return When(
            {"equal_to": equal_to},
            when_vos=[{"value": "hello"}],
            then_validators=[range_],
            otherwise_validators=[oneof],
        )

    # "hello" does not equal "world": the OneOf (otherwise) validator applies.
    otherwise_branch = make_when("world")
    otherwise_branch(12)
    with pytest.raises(ValidationError):
        otherwise_branch(3)

    # "hello" equals "hello": the Range (then) validator applies.
    then_branch = make_when("hello")
    then_branch(3)
    with pytest.raises(ValidationError):
        then_branch(12)

    assert then_branch.grid() == list(range(10 + 1))
144 |
145 |
def test_level():
    """The validator's level is exposed and propagated to the ValidationError."""
    cases = [
        (OneOf(choices=["allowed1", "allowed2"], level="warn"), "notachoice"),
        (Range(0, 10, level="warn"), 11),
    ]
    for validator, bad_value in cases:
        assert validator.level == "warn"
        with pytest.raises(ValidationError) as excinfo:
            validator(bad_value)
        assert excinfo.value.level == "warn"
158 |
--------------------------------------------------------------------------------
/paramtools/tests/test_values.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import copy
3 |
4 | from paramtools.values import Values, Slice, QueryResult, ValueItem
5 |
6 |
@pytest.fixture
def keyfuncs():
    """Key functions: d0 sorts naturally, d1 sorts by position in d1_order."""
    d1_order = ["hello", "world"]
    return {"d0": lambda x: x, "d1": d1_order.index}
10 |
11 |
@pytest.fixture
def _values():
    """Four raw value objects labelled by d0 and d1, all with value 1."""
    label_pairs = [(1, "hello"), (1, "world"), (2, "hello"), (3, "world")]
    return [{"d0": d0, "d1": d1, "value": 1} for d0, d1 in label_pairs]
20 |
21 |
@pytest.fixture
def values(_values, keyfuncs):
    """A Values object built from the _values fixture and keyfuncs."""
    return Values(values=_values, keyfuncs=keyfuncs)
25 |
26 |
class TestValues:
    def test_values(self, values):
        assert len(values) == 4

    def test_key_error(self, values):
        with pytest.raises(KeyError):
            values["heyo"]

    def test_types(self, values):
        assert isinstance(values["d0"], Slice)

        result = values["d0"] > 1
        assert isinstance(result, QueryResult)
        assert isinstance(result.isel[0], dict)
        assert isinstance(result.isel[:], list)

        assert isinstance(values.isel, ValueItem)

        assert isinstance(result.as_values(), Values)
47 |
48 |
class TestQuery:
    def test_select_eq(self, values):
        d0, d1 = values["d0"], values["d1"]
        assert list((d0 == 1) & (d1 == "hello")) == [
            {"d0": 1, "d1": "hello", "value": 1}
        ]

        d0_is_1_or_2 = (d0 == 1) | (d0 == 2)
        assert list(d0_is_1_or_2 & (d1 == "hello")) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 2, "d1": "hello", "value": 1},
        ]

    def test_select_eq_strict(self, _values, keyfuncs):
        # Only the last two rows carry the _auto label.
        for row in _values[2:]:
            row["_auto"] = True
        values = Values(_values, keyfuncs)

        not_auto = (values["_auto"] == False) | values.missing(  # noqa: E712
            "_auto"
        )
        assert list(not_auto) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 1, "d1": "world", "value": 1},
        ]

    def test_select_ne(self, values):
        d0, d1 = values["d0"], values["d1"]
        assert list((d0 != 1) & (d1 != "hello")) == [
            {"d0": 3, "d1": "world", "value": 1}
        ]

        assert list((d0 != 2) & (d0 != 3)) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 1, "d1": "world", "value": 1},
        ]

    def test_select_gt(self, values):
        expected = [
            {"d0": 2, "d1": "hello", "value": 1},
            {"d0": 3, "d1": "world", "value": 1},
        ]
        assert list(values["d0"] > 1) == expected

    def test_select_gte(self, values):
        expected = [
            {"d0": 2, "d1": "hello", "value": 1},
            {"d0": 3, "d1": "world", "value": 1},
        ]
        assert list(values["d0"] >= 2) == expected

    def test_select_lt(self, values):
        expected = [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 1, "d1": "world", "value": 1},
            {"d0": 2, "d1": "hello", "value": 1},
        ]
        assert list(values["d0"] < 3) == expected

    def test_select_lte(self, values):
        expected = [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 1, "d1": "world", "value": 1},
            {"d0": 2, "d1": "hello", "value": 1},
        ]
        assert list(values["d0"] <= 2) == expected

    def test_isin(self, values):
        matches = values["d0"].isin([1, 2]) & (values["d1"] == "hello")
        assert list(matches) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 2, "d1": "hello", "value": 1},
        ]
117 |
118 |
class TestOperations:
    def test_add(self, values):
        n_before = len(copy.deepcopy(values.values))

        new_vals = values.add([{"d0": 3, "d1": "hello", "value": 1}])

        # add() defaults to inplace=False: the original object is untouched.
        assert len(values.values) == n_before
        assert len(values.index) == n_before

        assert len(new_vals.values) == n_before + 1
        assert len(new_vals.index) == n_before + 1
        assert new_vals.index == [0, 1, 2, 3, 4]

        added = (new_vals["d0"] == 3) & (new_vals["d1"] == "hello")
        assert list(added) == [{"d0": 3, "d1": "hello", "value": 1}]

    def test_delete(self, values):
        n_before = len(copy.deepcopy(values.values))

        new_vals = values.delete(0, inplace=False)

        # The original object is untouched by a non-inplace delete.
        assert len(values.values) == n_before
        assert len(values.index) == n_before

        assert len(new_vals.values) == n_before - 1
        assert len(new_vals.index) == n_before - 1
        assert new_vals.index == [1, 2, 3]

        new_vals.delete(1, inplace=True)
        assert len(new_vals.index) == n_before - 2
        assert len(new_vals.values) == n_before - 2
        assert new_vals.index == [2, 3]

    def test_as_values(self, values):
        d0 = values["d0"]
        queryset = (d0 == 1) | (d0 == 3)
        assert list(queryset) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 1, "d1": "world", "value": 1},
            {"d0": 3, "d1": "world", "value": 1},
        ]

        # The query result can be promoted to Values and queried again.
        new_values = queryset.as_values()
        assert list(new_values["d1"] == "hello") == [
            {"d0": 1, "d1": "hello", "value": 1}
        ]
        either = (new_values["d1"] == "hello") | (new_values["d0"] == 3)
        assert list(either) == [
            {"d0": 1, "d1": "hello", "value": 1},
            {"d0": 3, "d1": "world", "value": 1},
        ]
171 |
172 |
class TestIndexing:
    def test_Values(self, values, _values):
        for ix, value in enumerate(_values):
            assert values.isel[ix] == value

    def test_Slice(self, values, _values):
        d0 = values["d0"]
        for ix, value in enumerate(_values):
            assert d0[ix] == value["d0"]

    def test_QueryResult(self, values):
        first = {"d0": 1, "d1": "hello", "value": 1}
        second = {"d0": 2, "d1": "hello", "value": 1}

        res1 = (values["d0"] == 1) & (values["d1"] == "hello")
        assert res1.isel[0] == first

        res2 = ((values["d0"] == 1) | (values["d0"] == 2)) & (
            values["d1"] == "hello"
        )
        assert list(res2) == res2.isel[:2]
        assert res2.isel[:2] == [first, second]
        assert res2.isel[0] == first
        assert res2.isel[1] == second

    def test_not_implemented(self, values):
        with pytest.raises(NotImplementedError):
            values["d0"].isel[0]

        with pytest.raises(NotImplementedError):
            (values["d0"] == 1)[0]
204 |
--------------------------------------------------------------------------------
/paramtools/typing.py:
--------------------------------------------------------------------------------
1 | from typing import NewType, Dict, Any, Callable, Iterable, Union, IO, AnyStr
2 | from pathlib import Path
3 |
# A single value object: a dict mapping label names (and "value") to values.
ValueObject = NewType("ValueObject", Dict[str, Any])
# A comparison callable taking a value and an iterable and returning a bool.
CmpFunc = NewType("CmpFunc", Callable[[Any, Iterable], bool])

# Input types that the JSON-reading helpers are annotated to accept: a path,
# URL or JSON string, a pathlib.Path, an open file object, or a parsed dict.
# Uses typing.Any (not the builtin function `any`) for the dict values.
FileDictStringLike = Union[str, Path, IO[AnyStr], Dict[str, Any]]
8 |
--------------------------------------------------------------------------------
/paramtools/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | from collections import OrderedDict
5 | from typing import Optional, List, Dict, Any
6 |
7 | import fsspec
8 | from fsspec.registry import known_implementations
9 | import marshmallow as ma
10 |
11 | from paramtools.typing import ValueObject, FileDictStringLike
12 |
13 |
def _is_url(maybe_url):
    """
    Return True if ``maybe_url`` validates as a URL, False otherwise.

    Validation uses marshmallow's URL validator (TLD not required), with the
    accepted schemes being "http" plus every protocol known to fsspec.
    """
    supported_schemes = (
        {"http"} | known_implementations.keys() | set(fsspec.registry)
    )
    validator = ma.validate.URL(schemes=supported_schemes, require_tld=False)
    try:
        validator(maybe_url)
    except ma.exceptions.ValidationError:
        return False
    return True
29 |
30 |
def _read(
    params_or_path: FileDictStringLike,
    storage_options: Optional[Dict[str, Any]] = None,
):
    """
    Read files of the form:
    - Local file path, as a ``str`` or an ``os.PathLike`` (e.g. pathlib.Path).
    - Any URL readable by fsspec. For example:
        - s3: s3://paramtools-test/defaults.json
        - gcs: gs://paramtools-dev/defaults.json
        - http: https://somedomain.com/defaults.json
        - github: github://PSLmodels:ParamTools@master/paramtools/tests/defaults.json
    - An open file-like object (anything with a ``read`` method); bytes
      content is decoded as UTF-8.

    A string that is neither an existing path nor a URL is returned
    unchanged (it is treated as raw content). Dicts are returned unchanged.

    :param params_or_path: path, URL, string content, dict, or file object.
    :param storage_options: extra keyword arguments forwarded to
        ``fsspec.open`` for URLs.
    :raises TypeError: for any other input type.
    """
    source = params_or_path
    # FileDictStringLike advertises pathlib.Path support; normalise
    # path-like objects so they take the same local-file branch as str.
    if isinstance(source, os.PathLike):
        source = os.fspath(source)

    if isinstance(source, str):
        if os.path.exists(source):
            # JSON text is UTF-8; don't depend on the platform's locale.
            with open(source, "r", encoding="utf-8") as f:
                return f.read()
        if _is_url(source):
            with fsspec.open(source, "r", **(storage_options or {})) as f:
                return f.read()
        return source

    if isinstance(source, dict):
        return source

    # FileDictStringLike also advertises IO objects.
    if hasattr(source, "read"):
        content = source.read()
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        return content

    raise TypeError(
        f"Unable to read data of type: {type(params_or_path)}\n"
        " Data must be a File Path, URL, String, or Dict."
    )
64 |
65 |
def remove_comments(string):
    """
    Strip ``// line`` and ``/* block */`` comments from JSON text.

    Quoted strings (single or double quotes) are matched first and kept
    verbatim, so comment markers inside string literals are left alone.
    Each removed comment is replaced by a single newline.

    Based on https://stackoverflow.com/a/18381470/9100772
    """
    # Group 1: a quoted string. Group 2: a // or /* */ comment.
    comment_re = re.compile(
        r"(\".*?\"|\'.*?\')|(/\*.*?\*/|//[^\r\n]*$)",
        re.MULTILINE | re.DOTALL,
    )

    def _substitute(m):
        quoted, comment = m.group(1), m.group(2)
        return quoted if comment is None else "\n"

    return comment_re.sub(_substitute, string)
87 |
88 |
def read_json(
    params_or_path: FileDictStringLike,
    storage_options: Optional[Dict[str, Any]] = None,
):
    """
    Read JSON data of the form:
    - Dict.
    - JSON string.
    - Local file path.
    - Any URL readable by fsspec. For example:
        - s3: s3://paramtools-test/defaults.json
        - gcs: gs://paramtools-dev/defaults.json
        - http: https://somedomain.com/defaults.json
        - github: github://PSLmodels:ParamTools@master/paramtools/tests/defaults.json

    Comments are stripped from string content before parsing and objects are
    loaded as OrderedDicts. Dicts are returned as-is.

    :raises ValueError: when string content is not valid JSON (the message
        includes a truncated copy of the content).
    :raises TypeError: when the input cannot be read.
    """
    raw = _read(params_or_path, storage_options)

    if isinstance(raw, dict):
        return raw

    if isinstance(raw, str):
        cleaned = remove_comments(raw)
        try:
            return json.loads(cleaned, object_pairs_hook=OrderedDict)
        except json.JSONDecodeError as err:
            snippet = (
                cleaned
                if len(cleaned) <= 100
                else cleaned[:100] + "..." + cleaned[-10:]
            )
            raise ValueError(
                f"Unable to decode JSON string: {snippet}"
            ) from err

    # _read should already have rejected unsupported input types.
    raise TypeError(f"Unknown type: {type(raw)}")
120 |
121 |
def get_example_paths(name):
    """
    Return the absolute path of the bundled example ``defaults.json`` for
    ``name``. Only "taxparams-demo" is accepted (enforced with an assert).
    """
    assert name in ("taxparams-demo",)
    package_dir = os.path.abspath(os.path.dirname(__file__))
    return os.path.join(package_dir, f"examples/{name}/defaults.json")
129 |
130 |
class LeafGetter:
    """
    Collect every non-dict, non-list item of a (possibly nested) object.

    Dict values and list elements are visited recursively in iteration
    order; everything else is appended to ``self.leaves``.

    A functional approach was considered instead of this class, but storing
    all leaves without keeping "leaf" state was not straightforward.
    """

    def __init__(self):
        # Leaves found so far, accumulated across calls to get().
        self.leaves = []

    def get(self, item):
        if isinstance(item, dict):
            children = item.values()
        elif isinstance(item, list):
            children = item
        else:
            self.leaves.append(item)
            return
        for child in children:
            self.get(child)
154 |
155 |
def get_leaves(item):
    """Return the list of non-dict, non-list items found anywhere in ``item``."""
    collector = LeafGetter()
    collector.get(item)
    return collector.leaves
160 |
161 |
def ravel(nlabel_list):
    """
    Flatten one level of nesting: ``[[1, 2], 3]`` -> ``[1, 2, 3]``.

    Only up to 2D for now. Non-list input is returned unchanged.
    """
    if not isinstance(nlabel_list, list):
        return nlabel_list
    flat = []
    for entry in nlabel_list:
        if isinstance(entry, list):
            flat.extend(entry)
        else:
            flat.append(entry)
    return flat
174 |
175 |
def consistent_labels(value_items: List[ValueObject]):
    """
    Get labels used consistently across all value objects.

    Returns the set of label names ("value" and "_auto" excluded) when every
    value object uses exactly the same labels, None when labels are omitted
    or added for some value object(s), and an empty set for empty input.
    """
    if not value_items:
        return set([])

    def _labels_of(vo):
        return {k for k in vo if k not in ("value", "_auto")}

    expected = _labels_of(value_items[0])
    if all(_labels_of(vo) == expected for vo in value_items):
        return expected
    return None
189 |
190 |
def ensure_value_object(vo) -> ValueObject:
    """
    Return ``vo`` unchanged when it is a list whose first item is a dict
    (a list of value objects); otherwise wrap it as ``[{"value": vo}]``.

    NOTE(review): an empty list raises IndexError because its first item is
    inspected.
    """
    is_value_object_list = isinstance(vo, list) and isinstance(vo[0], dict)
    return vo if is_value_object_list else [{"value": vo}]
197 |
198 |
def hashable_value_object(vo: ValueObject) -> tuple:
    """
    Helper function converting a value object into a format that can be
    stored in a set.

    The result is a tuple of ``(label, value)`` pairs sorted by label, with
    the internal "_auto" entry dropped. It is only hashable if every value
    in ``vo`` is hashable.
    """
    return tuple(
        pair for pair in sorted(vo.items()) if pair[0] not in ("_auto",)
    )
209 |
210 |
def filter_labels(vo: ValueObject, drop=None, keep=None) -> ValueObject:
    """
    Return a new dict with a subset of the value object's entries.

    Entries whose label is in ``drop`` are removed. If ``keep`` is given
    (non-empty), only labels in ``keep`` are retained.
    """
    drop = drop or ()
    keep = keep or ()
    filtered = {}
    for label, label_value in vo.items():
        if label in drop:
            continue
        if keep and label not in keep:
            continue
        filtered[label] = label_value
    return filtered
223 |
224 |
def make_label_str(vo: ValueObject) -> str:
    """
    Create string from labels. This is used to create error messages.

    Labels (everything except "value" and "_auto") are rendered sorted by
    name as ``[a=1, b=2]``; an object with no labels yields ``""``.
    """
    parts = [
        f"{lab}={vo[lab]}"
        for lab in sorted(vo)
        if lab not in ("value", "_auto")
    ]
    return f"[{', '.join(parts)}]" if parts else ""
240 |
241 |
def grid_sort(vos, label_to_extend, grid):
    """
    Sort value objects by the position of their ``label_to_extend`` value
    in ``grid``.

    Value objects that lack the label are sorted as if they held the first
    grid value (position 0). ``sorted`` is stable, so ties keep their input
    order.

    :raises ValueError: if a label value is not present in ``grid``.
    """

    def key(v):
        if label_to_extend in v:
            return grid.index(v[label_to_extend])
        # Return the *position* of the first grid element, not the element
        # itself. Mixing grid values (dates, years, ...) with integer
        # positions either raises TypeError or mis-orders the result.
        return 0

    return sorted(vos, key=key)
250 |
--------------------------------------------------------------------------------
/paramtools/values.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import defaultdict
3 | from typing import List, Dict, Any, Union, Generator
4 |
5 | from paramtools.sorted_key_list import SortedKeyList
6 | from paramtools.typing import ValueObject
7 |
8 |
def default_cmp_func(x):
    """Identity key function, used when no key function is registered for a label."""
    return x
11 |
12 |
class ValueItem:
    """
    Handles index-based look-ups on the Values class.

    Positions are translated into keys of ``values.values`` through
    ``index`` when one is given; without an index, the position itself is
    used as the key. Every look-up returns shallow ``dict`` copies.
    """

    def __init__(self, values: "Values", index: List[int] = None):
        self.values = values
        # Copy so later changes to the caller's index don't affect us.
        self.index = None if index is None else list(index)

    def __getitem__(self, item):
        rows = self.values.values
        if isinstance(item, slice):
            if self.index is None:
                positions = range(*item.indices(len(self.values)))
                return [dict(rows[pos]) for pos in positions]
            positions = range(*item.indices(len(self.index)))
            return [dict(rows[self.index[pos]]) for pos in positions]
        key = item if self.index is None else self.index[item]
        return dict(rows[key])
37 |
38 |
class ValueBase:
    """
    Shared query interface for Values, QueryResult and Slice.

    Subclasses provide ``cmp_attr`` (the object that actually runs queries)
    and a ``label`` attribute. Comparison operators and the named query
    methods forward to ``cmp_attr`` with ``{self.label: value}`` as keyword
    arguments; the named methods also forward ``strict`` positionally.
    """

    @property
    def cmp_attr(self):
        raise NotImplementedError()

    def _apply(self, op_name, value, *leading_args):
        # Look up the query method on cmp_attr and call it with this
        # object's label bound to ``value``.
        query = getattr(self.cmp_attr, op_name)
        return query(*leading_args, **{self.label: value})

    def __eq__(self, value=None, **labels):
        return self._apply("eq", value)

    def __ne__(self, value):
        return self._apply("ne", value)

    def __gt__(self, value):
        return self._apply("gt", value)

    def __ge__(self, value):
        return self._apply("gte", value)

    def __lt__(self, value):
        return self._apply("lt", value)

    def __le__(self, value):
        return self._apply("lte", value)

    def __len__(self):
        return sum(1 for _ in self.cmp_attr)

    def __iter__(self):
        return iter(self.cmp_attr)

    def __getitem__(self, item):
        return self.cmp_attr[item]

    def eq(self, value, strict=True):
        return self._apply("eq", value, strict)

    def ne(self, value, strict=True):
        return self._apply("ne", value, strict)

    def gt(self, value, strict=True):
        return self._apply("gt", value, strict)

    def gte(self, value, strict=True):
        return self._apply("gte", value, strict)

    def lt(self, value, strict=True):
        return self._apply("lt", value, strict)

    def lte(self, value, strict=True):
        return self._apply("lte", value, strict)

    def isin(self, value, strict=True):
        return self._apply("isin", value, strict)
91 |
92 |
class QueryResult(ValueBase):
    """
    The rows of a Values object that matched a query.

    Holds the queried Values object plus the index keys of the matching
    rows. Results combine with ``&`` / ``|``. Use ``isel`` for positional
    look-ups and ``as_values`` to build a new Values object that can be
    queried further. Comparison operators are disabled on query results.
    """

    def __init__(self, values: "Values", index: List[Any]):
        self.values = values
        self.index = index

    def __and__(self, queryresult: "QueryResult"):
        common = set(self.index) & set(queryresult.index)
        return QueryResult(self.values, common)

    def __or__(self, queryresult: "QueryResult"):
        combined = set(self.index) | set(queryresult.index)
        return QueryResult(self.values, combined)

    def __repr__(self):
        rows = self.values.values
        vo_repr = "\n ".join(str(dict(rows[i])) for i in (self.index or []))
        return f"QueryResult([\n {vo_repr}\n])"

    def __iter__(self):
        rows = self.values.values
        for i in self.index:
            yield rows[i]

    def __getitem__(self, item):
        raise NotImplementedError(
            "Use .isel to do index-based look ups or as_values to chain queries."
        )

    @property
    def isel(self):
        return ValueItem(self.values, self.index)

    def tolist(self):
        rows = self.values.values
        return [rows[i] for i in self.index]

    def _delegate(self, op_name, strict, labels):
        # NOTE(review): queries run against the full underlying Values
        # object, not just the rows in this result.
        return getattr(self.cmp_attr, op_name)(strict, **labels)

    def eq(self, strict=True, **labels):
        return self._delegate("eq", strict, labels)

    def ne(self, strict=True, **labels):
        return self._delegate("ne", strict, labels)

    def gt(self, strict=True, **labels):
        return self._delegate("gt", strict, labels)

    def gte(self, strict=True, **labels):
        return self._delegate("gte", strict, labels)

    def lt(self, strict=True, **labels):
        return self._delegate("lt", strict, labels)

    def lte(self, strict=True, **labels):
        return self._delegate("lte", strict, labels)

    def isin(self, strict=True, **labels):
        return self._delegate("isin", strict, labels)

    def __eq__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ne__(self, *args, **kwargs):
        raise NotImplementedError()

    def __gt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ge__(self, *args, **kwargs):
        raise NotImplementedError()

    def __lt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __le__(self, *args, **kwargs):
        raise NotImplementedError()

    def as_values(self):
        """Build a new Values object holding only the matched rows."""
        return Values(
            values=list(self), index=self.index, keyfuncs=self.values.keyfuncs
        )

    def delete(self):
        """Delete the matched rows from the underlying Values object in place."""
        self.values.delete(*self.index, inplace=True)

    @property
    def cmp_attr(self):
        return self.values
180 |
181 |
class Slice(ValueBase):
    """
    A single label viewed across every row of a Values object.

    Comparison operators on a Slice query the underlying Values object for
    this label, e.g. ``values["d0"] > 1``.
    """

    def __init__(self, values: "Values", label: str):
        self.values = values
        self.label = label

    @property
    def cmp_attr(self):
        return self.values

    def __getitem__(self, item):
        """
        Return this label's value(s).

        With a slice, positions are mapped onto the row keys in their stored
        order, and rows lacking the label yield None. With any other key,
        ``item`` is used directly as a row key of the underlying Values
        object; that coincides with the position when the index is
        0..n-1. Missing labels then raise KeyError.
        """
        rows = self.values.values
        if isinstance(item, slice):
            # Slicing the list of keys gives exactly the positions
            # range(*item.indices(len(rows))). Using positions as keys
            # directly raised KeyError once the index stopped being 0..n-1,
            # e.g. after rows were deleted.
            keys = list(rows)
            return [rows[key].get(self.label, None) for key in keys[item]]
        return rows[item][self.label]

    @property
    def isel(self):
        raise NotImplementedError(
            "Access values of a Slice object directly: parameters['label'][1]"
        )

    def __repr__(self):
        vo_repr = "\n ".join(
            str(dict(self.values.values[i])) for i in self.values.values
        )
        return f"Slice([\n {vo_repr}\n], \nlabel={self.label})"
212 |
213 |
class Values(ValueBase):
    """
    The Values class is used to query and update parameter values.

    Rows (value objects) are stored in ``self.values``, a dict keyed by the
    entries of ``self.index``. For each label, ``self.skls`` holds a
    SortedKeyList of that label's values, which backs the comparison
    queries (``eq``, ``gt``, ``isin``, ...).

    For more information, see the "Viewing Data" page of the ParamTools
    documentation.
    """

    def __init__(
        self,
        values: List[ValueObject],
        keyfuncs: Dict[str, Any] = None,
        skls: Dict[str, SortedKeyList] = None,
        index: List[Any] = None,
    ):
        # Copy the index so that in-place add()/delete() never mutate a
        # list owned by the caller (e.g. the index handed over by
        # QueryResult.as_values). A set index also becomes a list, which
        # add(..., inplace=True) can extend.
        self.index = list(index) if index else list(range(len(values)))
        self.values = {ix: value for ix, value in zip(self.index, values)}
        self.keyfuncs = keyfuncs
        # Label used by the ValueBase comparison operators.
        self.label = "value"

        if skls is not None:
            self.skls = skls
        else:
            self.skls = self.build_skls(self.values, keyfuncs or {})

    def build_skls(self, values, keyfuncs):
        """
        Build one SortedKeyList per label from ``values`` (a dict mapping
        index key -> value object). Each list stores that label's values
        together with the index keys of the rows they came from.
        """
        label_values = defaultdict(list)
        label_index = defaultdict(list)
        for ix, vo in values.items():
            for label, value in vo.items():
                label_values[label].append(value)
                label_index[label].append(ix)

        skls = {}
        for label in label_values:
            keyfunc = self.get_keyfunc(label, keyfuncs)
            skls[label] = SortedKeyList(
                label_values[label], keyfunc, label_index[label]
            )

        return skls

    def update_skls(self, values):
        """Add the rows in ``values`` (index key -> value object) to self.skls."""
        # TODO: remove existing values with clashing index
        for ix, vo in values.items():
            for label, value in vo.items():
                if self.skls.get(label, None) is not None:
                    self.skls[label].add(value, index=ix)
                else:
                    self.skls[label] = SortedKeyList(
                        [value],
                        keyfunc=self.get_keyfunc(label, self.keyfuncs),
                        index=[ix],
                    )

    def get_keyfunc(self, label, keyfuncs):
        """
        Return the key function registered for ``label`` in ``keyfuncs``,
        or ``default_cmp_func`` if there is none. ``keyfuncs`` may be None.
        """
        # self.keyfuncs is None when no keyfuncs were given, and add(),
        # delete() and update_skls() pass it straight through; tolerate
        # None instead of failing with AttributeError on None.get.
        keyfunc = (keyfuncs or {}).get(label)
        return keyfunc or default_cmp_func

    def _cmp(self, op, strict, **labels):
        """
        Run comparison ``op`` (e.g. "eq", "gt") for the first label in
        ``labels``.

        With ``strict`` an unknown label raises KeyError. Without it, rows
        that lack the label are included in the result.
        """
        label, value = list(labels.items())[0]
        skl = self.skls.get(label, None)

        if skl is None and strict:
            raise KeyError(f"Unknown label: {label}.")
        elif skl is None and not strict:
            return QueryResult(self, list(self.index))

        skl_result = getattr(self.skls[label], op)(value)
        if not strict:
            match_index = skl_result.index if skl_result else []
            missing = self.missing(label)
            match_index = set(match_index + missing.index)
        elif skl_result is None:
            match_index = []
        else:
            match_index = skl_result.index
        return QueryResult(self, match_index)

    def __getitem__(self, label):
        if label not in self.skls:
            raise KeyError(f"Unknown label: {label}")
        return Slice(self, label)

    def missing(self, label: str):
        """Return a QueryResult of the rows that have no value for ``label``."""
        index = list(set(self.index) - self.skls[label].index)
        return QueryResult(self, index)

    def eq(self, strict=True, **labels):
        """
        Returns values that match the given label:

        .. code-block:: Python

            params.sel["my_param"].eq(my_label=5)
            params.sel["my_param"]["my_label"] == 5
        """
        return self._cmp("eq", strict, **labels)

    def ne(self, strict=True, **labels):
        """
        Returns values that do not match the given label:

        .. code-block:: Python

            params.sel["my_param"].ne(my_label=5)
            params.sel["my_param"]["my_label"] != 5
        """

        return self._cmp("ne", strict, **labels)

    def gt(self, strict=True, **labels):
        """
        Returns values that have label values greater than the label value:

        .. code-block:: Python

            params.sel["my_param"].gt(my_label=5)
            params.sel["my_param"]["my_label"] > 5

        """

        return self._cmp("gt", strict, **labels)

    def gte(self, strict=True, **labels):
        """
        Returns values that have label values greater than or equal to the label value:

        .. code-block:: Python

            params.sel["my_param"].gte(my_label=5)
            params.sel["my_param"]["my_label"] >= 5

        """
        return self._cmp("gte", strict, **labels)

    def lt(self, strict=True, **labels):
        """
        Returns values that have label values less than the label value:

        .. code-block:: Python

            params.sel["my_param"].lt(my_label=5)
            params.sel["my_param"]["my_label"] < 5

        """

        return self._cmp("lt", strict, **labels)

    def lte(self, strict=True, **labels):
        """
        Returns values that have label values less than or equal to the label value:

        .. code-block:: Python

            params.sel["my_param"].lte(my_label=5)
            params.sel["my_param"]["my_label"] <= 5

        """

        return self._cmp("lte", strict, **labels)

    def isin(self, strict=True, **labels):
        """
        Returns values whose label value equals any of the given values:

        .. code-block:: Python

            params.sel["my_param"].isin(my_label=[5, 6])

        """

        label, values = list(labels.items())[0]
        return union(
            self.eq(strict=strict, **{label: value}) for value in values
        )

    def add(
        self, values: List[ValueObject], index: List[Any] = None, inplace=False
    ):
        """
        Add value objects.

        :param values: value objects to add.
        :param index: optional index keys for the new rows (same length as
            ``values``); by default keys continue after ``max(self.index)``.
        :param inplace: modify this object (returns None) instead of
            returning a new Values object.
        """
        if index is not None:
            assert len(index) == len(values)
            new_index = list(index)
        else:
            max_index = max(self.index) if self.index else 0
            new_index = [max_index + i + 1 for i in range(len(values))]

        new_values = {ix: value for ix, value in zip(new_index, values)}

        if inplace:
            self.update_skls(new_values)
            self.values.update(new_values)
            self.index += new_index
        else:
            current_index = list(self.index)
            updated_values = dict(self.values)
            updated_values.update(new_values)
            return Values(
                [value for value in updated_values.values()],
                # Keep the key functions so later queries/updates on the new
                # object sort labels the same way.
                keyfuncs=self.keyfuncs,
                skls=self.build_skls(updated_values, self.keyfuncs),
                index=current_index + new_index,
            )

    def delete(self, *index, inplace=False):
        """
        Delete rows by index key; with no keys given, delete every row.

        Returns a new Values object unless ``inplace`` is True.
        """
        if not index:
            index = list(self.index)
        if inplace:
            for ix in index:
                self.values.pop(ix)
                self.index.remove(ix)
            self.skls = self.build_skls(self.values, self.keyfuncs)
        else:
            new_index = list(self.index)
            new_values = copy.deepcopy(self.values)
            for ix in index:
                new_values.pop(ix)
                new_index.remove(ix)

            return Values(
                [value for value in new_values.values()],
                keyfuncs=self.keyfuncs,
                index=new_index,
            )

    @property
    def cmp_attr(self):
        return self

    @property
    def isel(self):
        """
        Select values by their index:

        .. code-block:: Python

            params.sel["my_param"].isel[0]
            params.sel["my_param"].isel[:5]

        """

        return ValueItem(self, self.index)

    @property
    def labels(self):
        """Names of all labels that appear in at least one row."""
        return list(self.skls.keys())

    def __eq__(self, other):
        if isinstance(other, ValueBase):
            return list(self) == list(other)
        elif isinstance(other, list):
            return list(self) == other
        else:
            raise TypeError(f"Unable to compare Values against {type(other)}")

    def __and__(self, queryresult: "QueryResult"):
        """
        Combine queries with logical 'and':

        .. code-block:: Python

            my_param = params.sel["my_param"]
            (my_param["my_label"] == 5) & (my_param["oth_label"] == "hello")
        """

        res = set(self.index) & set(queryresult.index)

        return QueryResult(self, res)

    def __or__(self, queryresult: "QueryResult"):
        """
        Combine queries with logical 'or':

        .. code-block:: Python

            my_param = params.sel["my_param"]
            (my_param["my_label"] == 5) | (my_param["oth_label"] == "hello")
        """

        res = set(self.index) | set(queryresult.index)

        return QueryResult(self, res)

    def __iter__(self):
        for value in self.values.values():
            yield value

    def __repr__(self):
        vo_repr = (
            ",\n ".join(str(dict(self.values[i])) for i in self.index) + ","
        )
        return f"Values([\n {vo_repr}\n])"
505 |
506 |
507 | def union(
508 | queryresults: Union[List[ValueBase], Generator[ValueBase, None, None]]
509 | ):
510 | result = None
511 | for queryresult in queryresults:
512 | if result is None:
513 | result = queryresult
514 | else:
515 | result |= queryresult
516 |
517 | return result or QueryResult(None, [])
518 |
519 |
520 | def intersection(
521 | queryresults: Union[List[ValueBase], Generator[ValueBase, None, None]]
522 | ):
523 | result = None
524 | for queryresult in queryresults:
525 | if result is None:
526 | result = queryresult
527 | else:
528 | result &= queryresult
529 |
530 | return result or QueryResult(None, [])
531 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 79
3 |
4 | [tool.pytest.ini_options]
5 | markers = ["network_bound"]
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | import os
3 |
4 | with open("README.md", "r") as f:
5 | long_description = f.read()
6 |
7 | setuptools.setup(
8 | name="paramtools",
9 | version=os.environ.get("VERSION", "0.20.0"),
10 | author="Hank Doupe",
11 | author_email="henrymdoupe@gmail.com",
12 | description=(
13 | "Library for parameter processing and validation with a focus "
14 | "on computational modeling projects"
15 | ),
16 | long_description=long_description,
17 | long_description_content_type="text/markdown",
18 | url="https://github.com/hdoupe/ParamTools",
19 | packages=setuptools.find_packages(),
20 | install_requires=[
21 | "marshmallow>=4.0.0",
22 | "numpy",
23 | "python-dateutil>=2.8.0",
24 | "fsspec",
25 | "sortedcontainers",
26 | ],
27 | include_package_data=True,
28 | classifiers=[
29 | "Programming Language :: Python :: 3",
30 | "License :: OSI Approved :: MIT License",
31 | "Operating System :: OS Independent",
32 | ],
33 | )
34 |
--------------------------------------------------------------------------------