├── .coveragerc
├── .github
└── workflows
│ ├── ci.yaml
│ ├── docgen_test.yaml
│ ├── pypi-nightly.yaml
│ └── pypi.yaml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── _static
│ ├── favicon.svg
│ ├── logo.svg
│ ├── logo_dark.svg
│ ├── logo_light.svg
│ └── style.css
├── _templates
│ └── layout.html
├── api
│ ├── _templates
│ │ ├── _default
│ │ │ ├── class.rst
│ │ │ ├── function.rst
│ │ │ ├── module.rst
│ │ │ └── object.rst
│ │ ├── core
│ │ │ └── symbolic
│ │ │ │ ├── dict.rst
│ │ │ │ ├── list.rst
│ │ │ │ ├── object.rst
│ │ │ │ └── symbolic.rst
│ │ └── index.rst
│ └── docgen.py
├── conf.py
├── generate_docs.sh
├── guide
│ ├── automl
│ │ ├── index.rst
│ │ ├── industrial.rst
│ │ ├── research.rst
│ │ └── toy.rst
│ ├── evolution
│ │ └── index.rst
│ ├── general
│ │ └── index.rst
│ ├── index.rst
│ └── ml
│ │ └── index.rst
├── index.rst
├── learn
│ ├── evolution
│ │ └── index.rst
│ ├── how_pyglove_works.svg
│ ├── index.rst
│ └── soop
│ │ ├── detour.rst
│ │ ├── index.rst
│ │ └── som
│ │ ├── definition.rst
│ │ ├── duality.svg
│ │ ├── events.rst
│ │ ├── object_layout.svg
│ │ ├── operations.rst
│ │ ├── placeholding.rst
│ │ ├── types.rst
│ │ └── validation.rst
├── notebooks
│ ├── evolution
│ │ ├── function_regression.ipynb
│ │ ├── onemax.ipynb
│ │ └── tsp.ipynb
│ ├── gui
│ │ └── html_view.ipynb
│ ├── intro
│ │ ├── basics
│ │ │ ├── runtime_typing.ipynb
│ │ │ └── symbolic_function.ipynb
│ │ ├── birdview.ipynb
│ │ └── search
│ │ │ ├── evolution_algorithm.ipynb
│ │ │ ├── evolution_ops.ipynb
│ │ │ └── evolution_scheduling.ipynb
│ ├── ml
│ │ ├── efficiently_exchange_ml_ideas_as_code.ipynb
│ │ ├── neural_modeling.ipynb
│ │ └── symbolic_ml.ipynb
│ └── python
│ │ ├── interactive_svg.ipynb
│ │ ├── sticky_notes.ipynb
│ │ └── where_is_the_duck.ipynb
└── requirements.txt
├── examples
├── automl
│ ├── mnist
│ │ ├── README.md
│ │ ├── mnist_train.py
│ │ ├── mnist_tune.py
│ │ ├── mnist_tune_eagerly.py
│ │ ├── mnist_tune_hparams.py
│ │ └── pytorch
│ │ │ ├── mnist_train.py
│ │ │ └── mnist_tune_eagerly.py
│ ├── nasbench
│ │ └── nasbench.py
│ └── natsbench
│ │ └── natsbench.py
├── evolution
│ ├── onemax.py
│ └── tsp.py
├── ml
│ └── symbolic_modeling.py
└── python
│ └── runtime_typing.py
├── pyglove
├── __init__.py
├── core
│ ├── __init__.py
│ ├── coding
│ │ ├── __init__.py
│ │ ├── errors.py
│ │ ├── errors_test.py
│ │ ├── execution.py
│ │ ├── execution_test.py
│ │ ├── function_generation.py
│ │ ├── function_generation_test.py
│ │ ├── parsing.py
│ │ ├── parsing_test.py
│ │ ├── permissions.py
│ │ └── permissions_test.py
│ ├── detouring
│ │ ├── __init__.py
│ │ ├── class_detour.py
│ │ └── class_detour_test.py
│ ├── geno
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── base_test.py
│ │ ├── categorical.py
│ │ ├── categorical_test.py
│ │ ├── custom.py
│ │ ├── custom_test.py
│ │ ├── deduping.py
│ │ ├── deduping_test.py
│ │ ├── dna_generator.py
│ │ ├── dna_generator_test.py
│ │ ├── numerical.py
│ │ ├── numerical_test.py
│ │ ├── random.py
│ │ ├── random_test.py
│ │ ├── space.py
│ │ ├── space_test.py
│ │ ├── sweeping.py
│ │ └── sweeping_test.py
│ ├── hyper
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── categorical.py
│ │ ├── categorical_test.py
│ │ ├── custom.py
│ │ ├── custom_test.py
│ │ ├── derived.py
│ │ ├── derived_test.py
│ │ ├── dynamic_evaluation.py
│ │ ├── dynamic_evaluation_test.py
│ │ ├── evolvable.py
│ │ ├── evolvable_test.py
│ │ ├── iter.py
│ │ ├── iter_test.py
│ │ ├── numerical.py
│ │ ├── numerical_test.py
│ │ ├── object_template.py
│ │ └── object_template_test.py
│ ├── io
│ │ ├── __init__.py
│ │ ├── file_system.py
│ │ ├── file_system_test.py
│ │ ├── sequence.py
│ │ └── sequence_test.py
│ ├── logging.py
│ ├── logging_test.py
│ ├── patching
│ │ ├── __init__.py
│ │ ├── object_factory.py
│ │ ├── object_factory_test.py
│ │ ├── pattern_based.py
│ │ ├── pattern_based_test.py
│ │ ├── rule_based.py
│ │ └── rule_based_test.py
│ ├── symbolic
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── base_test.py
│ │ ├── boilerplate.py
│ │ ├── boilerplate_test.py
│ │ ├── class_wrapper.py
│ │ ├── class_wrapper_test.py
│ │ ├── compounding.py
│ │ ├── compounding_test.py
│ │ ├── contextual_object.py
│ │ ├── contextual_object_test.py
│ │ ├── dict.py
│ │ ├── dict_test.py
│ │ ├── diff.py
│ │ ├── diff_test.py
│ │ ├── error_info.py
│ │ ├── error_info_test.py
│ │ ├── flags.py
│ │ ├── flags_test.py
│ │ ├── functor.py
│ │ ├── functor_test.py
│ │ ├── inferred.py
│ │ ├── inferred_test.py
│ │ ├── list.py
│ │ ├── list_test.py
│ │ ├── object.py
│ │ ├── object_test.py
│ │ ├── origin.py
│ │ ├── origin_test.py
│ │ ├── pure_symbolic.py
│ │ ├── ref.py
│ │ ├── ref_test.py
│ │ ├── symbolize.py
│ │ └── symbolize_test.py
│ ├── tuning
│ │ ├── __init__.py
│ │ ├── backend.py
│ │ ├── backend_test.py
│ │ ├── early_stopping.py
│ │ ├── local_backend.py
│ │ ├── protocols.py
│ │ ├── protocols_test.py
│ │ ├── sample.py
│ │ └── sample_test.py
│ ├── typing
│ │ ├── __init__.py
│ │ ├── annotated.py
│ │ ├── annotated_test.py
│ │ ├── annotation_conversion.py
│ │ ├── annotation_conversion_test.py
│ │ ├── annotation_future_test.py
│ │ ├── callable_ext.py
│ │ ├── callable_ext_test.py
│ │ ├── callable_signature.py
│ │ ├── callable_signature_test.py
│ │ ├── class_schema.py
│ │ ├── class_schema_test.py
│ │ ├── custom_typing.py
│ │ ├── inspect.py
│ │ ├── inspect_test.py
│ │ ├── json_schema.py
│ │ ├── json_schema_test.py
│ │ ├── key_specs.py
│ │ ├── key_specs_test.py
│ │ ├── pytype_support.py
│ │ ├── type_conversion.py
│ │ ├── type_conversion_test.py
│ │ ├── typed_missing.py
│ │ ├── typed_missing_test.py
│ │ ├── value_specs.py
│ │ └── value_specs_test.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── common_traits.py
│ │ ├── common_traits_test.py
│ │ ├── contextual.py
│ │ ├── contextual_test.py
│ │ ├── docstr_utils.py
│ │ ├── docstr_utils_test.py
│ │ ├── error_utils.py
│ │ ├── error_utils_test.py
│ │ ├── formatting.py
│ │ ├── formatting_test.py
│ │ ├── hierarchical.py
│ │ ├── hierarchical_test.py
│ │ ├── json_conversion.py
│ │ ├── json_conversion_test.py
│ │ ├── missing.py
│ │ ├── missing_test.py
│ │ ├── text_color.py
│ │ ├── text_color_test.py
│ │ ├── thread_local.py
│ │ ├── thread_local_test.py
│ │ ├── timing.py
│ │ ├── timing_test.py
│ │ ├── value_location.py
│ │ └── value_location_test.py
│ └── views
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── base_test.py
│ │ └── html
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── base_test.py
│ │ ├── controls
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── label.py
│ │ ├── label_test.py
│ │ ├── progress_bar.py
│ │ ├── progress_bar_test.py
│ │ ├── tab.py
│ │ ├── tab_test.py
│ │ ├── tooltip.py
│ │ └── tooltip_test.py
│ │ ├── tree_view.py
│ │ └── tree_view_test.py
└── ext
│ ├── __init__.py
│ ├── early_stopping
│ ├── __init__.py
│ ├── base.py
│ ├── base_test.py
│ ├── step_wise.py
│ └── step_wise_test.py
│ ├── evolution
│ ├── __init__.py
│ ├── base.py
│ ├── base_test.py
│ ├── hill_climb.py
│ ├── hill_climb_test.py
│ ├── mutators.py
│ ├── mutators_test.py
│ ├── neat.py
│ ├── neat_test.py
│ ├── nsga2.py
│ ├── nsga2_test.py
│ ├── recombinators.py
│ ├── recombinators_test.py
│ ├── regularized_evolution.py
│ ├── regularized_evolution_test.py
│ ├── selectors.py
│ ├── selectors_test.py
│ ├── where.py
│ └── where_test.py
│ ├── mutfun
│ ├── __init__.py
│ ├── base.py
│ ├── base_test.py
│ ├── basic_ops.py
│ └── basic_ops_test.py
│ └── scalars
│ ├── __init__.py
│ ├── base.py
│ ├── base_test.py
│ ├── maths.py
│ ├── maths_test.py
│ ├── randoms.py
│ ├── randoms_test.py
│ ├── step_wise.py
│ └── step_wise_test.py
├── requirements.txt
└── setup.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = pyglove
3 | branch = True
4 |
5 | [report]
6 | precision = 2
7 | show_missing = True
8 | omit = **/*_test.py
9 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test-ubuntu:
13 | name: "pytest on ${{ matrix.python-version }} on ${{ matrix.os }}"
14 | runs-on: "${{ matrix.os }}"
15 | strategy:
16 | matrix:
17 | python-version: ["3.9", "3.10", "3.11"]
18 | os: [ubuntu-latest]
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v1
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | pip install pytest
28 | pip install pytest-xdist
29 | pip install pytest-cov
30 | pip install -r requirements.txt
31 | - name: Test with pytest and generate coverage report
32 | run: |
33 | pytest -n auto --cov=pyglove --cov-report=xml
34 | - name: Upload coverage to Codecov
35 | uses: codecov/codecov-action@v1
36 | with:
37 | file: ./coverage.xml
38 | # The below step just reports the success or failure of tests as a "commit status".
39 | # This is needed for copybara integration.
40 | - name: Report success or failure as github status
41 | if: always()
42 | shell: bash
43 | run: |
44 | status="${{ job.status }}"
45 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
46 | curl -sS --request POST \
47 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
48 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
49 | --header 'content-type: application/json' \
50 | --data '{
51 | "state": "'$lowercase_status'",
52 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
53 | "description": "'$status'",
54 | "context": "github-actions/build"
55 | }'
56 |
--------------------------------------------------------------------------------
/.github/workflows/docgen_test.yaml:
--------------------------------------------------------------------------------
1 | name: docgen_test
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test-ubuntu:
13 | name: "documentation generation on ${{ matrix.python-version }} on ${{ matrix.os }}"
14 | runs-on: "${{ matrix.os }}"
15 | strategy:
16 | matrix:
17 | python-version: ["3.9"]
18 | os: [ubuntu-latest]
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install -r docs/requirements.txt
29 | - name: Print installed dependencies
30 | run: |
31 | pip freeze
32 | - name: Make Sphinx Docs to HTML (Test)
33 | run: |
34 | cd docs
35 | bash ./generate_docs.sh html
36 | # The below step just reports the success or failure of tests as a "commit status".
37 | # This is needed for copybara integration.
38 | - name: Report success or failure as github status
39 | if: always()
40 | shell: bash
41 | run: |
42 | status="${{ job.status }}"
43 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
44 | curl -sS --request POST \
45 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
46 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
47 | --header 'content-type: application/json' \
48 | --data '{
49 | "state": "'$lowercase_status'",
50 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
51 | "description": "'$status'",
52 | "context": "github-actions/docgen"
53 | }'
54 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-nightly.yaml:
--------------------------------------------------------------------------------
1 | # This workflow builds a nightly Python package and uploads it using Twine on a daily schedule (or when manually dispatched).
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: pypi-nightly
5 |
6 | on:
7 | workflow_dispatch:
8 | schedule:
9 | # Every day at 08:00 UTC (12:00 AM PST / 1:00 AM PDT)
10 | - cron: "0 8 * * *"
11 |
12 | jobs:
13 | deploy:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v1
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | pip install -r requirements.txt
26 | - name: Build and publish
27 | env:
28 | TWINE_USERNAME: __token__
29 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
30 | run: |
31 | # Remove the parallel image for Github dark mode in README.md
32 | sed -i -e 's$
$$' README.md
33 | python setup.py sdist bdist_wheel -- --nightly
34 | twine upload dist/* --skip-existing
35 |
--------------------------------------------------------------------------------
/.github/workflows/pypi.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created.
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: pypi
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python
16 | uses: actions/setup-python@v1
17 | with:
18 | python-version: '3.x'
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install setuptools wheel twine
23 | pip install -r requirements.txt
24 | - name: Build and publish
25 | env:
26 | TWINE_USERNAME: __token__
27 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
28 | run: |
29 | # Remove the parallel image for Github dark mode in README.md
30 | sed -i -e 's$
$$' README.md
31 | python setup.py sdist bdist_wheel
32 | twine upload dist/*
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .coverage
3 | .vscode
4 | *.egg-info
5 | __pycache__/
6 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.9"
13 | # You can also specify other tool versions:
14 | # nodejs: "19"
15 | # rust: "1.64"
16 | # golang: "1.19"
17 | apt_packages:
18 | - graphviz
19 |
20 | # Build documentation in the docs/ directory with Sphinx
21 | sphinx:
22 | configuration: docs/conf.py
23 |
24 | # If using Sphinx, optionally build your docs in additional formats such as PDF
25 | # formats:
26 | # - pdf
27 |
28 | # Optionally declare the Python requirements required to build your docs
29 | python:
30 | install:
31 | - requirements: docs/requirements.txt
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/docs/_static/favicon.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | @import url("theme.css");
2 |
3 | .wy-side-nav-search {
4 | background-color: #000;
5 | }
6 |
7 | .wy-table-responsive table td,
8 | .wy-table-responsive table th {
9 | white-space: normal;
10 | }
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% set css_files = css_files + ["_static/style.css"] %}
--------------------------------------------------------------------------------
/docs/api/_templates/_default/class.rst:
--------------------------------------------------------------------------------
1 | .. _{{cls.rst_label}}:
2 |
3 | {{cls.preferred_path}}
4 | ======================
5 |
6 | Accessible via {{cls.rst_access_paths}}.
7 |
8 | .. autoclass:: {{cls.rst_import_name}}
9 | :members:
10 | :show-inheritance:
11 | :autosummary:
--------------------------------------------------------------------------------
/docs/api/_templates/_default/function.rst:
--------------------------------------------------------------------------------
1 | .. _{{func.rst_label}}:
2 |
3 | {{func.preferred_path}}
4 | =====================
5 |
6 | Accessible via {{func.rst_access_paths}}.
7 |
8 | .. autofunction:: {{func.rst_import_name}}
9 |
--------------------------------------------------------------------------------
/docs/api/_templates/_default/module.rst:
--------------------------------------------------------------------------------
1 | .. _{{module.rst_label}}:
2 |
3 | {{module.preferred_path}}
4 | =====================
5 |
6 | .. automodule:: {{module.rst_import_name}}
7 |
8 | {% if module.modules %}
9 | Modules
10 | *******
11 | .. toctree::
12 | :maxdepth: 1
13 | {% for entry in module.modules %}
14 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %}
15 | {% endif %}
16 |
17 | {% if module.objects %}
18 | Objects
19 | *******
20 |
21 | .. toctree::
22 | :maxdepth: 1
23 | {% for entry in module.objects %}
24 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %}
25 | {% endif %}
26 |
27 | {% if module.classes %}
28 | Classes
29 | *******
30 | .. toctree::
31 | :maxdepth: 1
32 | {% for entry in module.classes %}
33 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %}
34 | {% endif %}
35 |
36 | {% if module.functions %}
37 | Functions
38 | *********
39 | .. toctree::
40 | :maxdepth: 1
41 | {% for entry in module.functions %}
42 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %}
43 | {% endif %}
--------------------------------------------------------------------------------
/docs/api/_templates/_default/object.rst:
--------------------------------------------------------------------------------
1 | .. _{{obj.rst_label}}:
2 |
3 | {{obj.preferred_path}}
4 | ================
5 |
6 | Accessible via {{obj.rst_access_paths}}.
7 |
8 | .. autodata:: {{obj.rst_import_name}}
9 |
--------------------------------------------------------------------------------
/docs/api/_templates/core/symbolic/dict.rst:
--------------------------------------------------------------------------------
1 | .. _{{cls.rst_label}}:
2 |
3 | {{cls.preferred_path}}
4 | ================
5 |
6 | Accessible via {{cls.rst_access_paths}}.
7 |
8 | .. autoclass:: {{cls.rst_import_name}}
9 | :members:
10 | :show-inheritance:
11 | :autosummary:
12 | :inherited-members:
13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__
14 |
--------------------------------------------------------------------------------
/docs/api/_templates/core/symbolic/list.rst:
--------------------------------------------------------------------------------
1 | .. _{{cls.rst_label}}:
2 |
3 | {{cls.preferred_path}}
4 | ================
5 |
6 | Accessible via {{cls.rst_access_paths}}.
7 |
8 | .. autoclass:: {{cls.rst_import_name}}
9 | :members:
10 | :show-inheritance:
11 | :autosummary:
12 | :inherited-members:
13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__
14 |
--------------------------------------------------------------------------------
/docs/api/_templates/core/symbolic/object.rst:
--------------------------------------------------------------------------------
1 | .. _{{cls.rst_label}}:
2 |
3 | {{cls.preferred_path}}
4 | ================
5 |
6 | Accessible via {{cls.rst_access_paths}}.
7 |
8 | .. autoclass:: {{cls.rst_import_name}}
9 | :members:
10 | :show-inheritance:
11 | :autosummary:
12 | :inherited-members:
13 | :private-members: _on_init, _on_change, _on_bound, _on_parent_change, _on_path_change
14 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__
15 |
--------------------------------------------------------------------------------
/docs/api/_templates/core/symbolic/symbolic.rst:
--------------------------------------------------------------------------------
1 | .. _{{cls.rst_label}}:
2 |
3 | {{cls.preferred_path}}
4 | ================
5 |
6 | Accessible via {{cls.rst_access_paths}}.
7 |
8 | .. autoclass:: {{cls.rst_import_name}}
9 | :members:
10 | :show-inheritance:
11 | :autosummary:
12 | :inherited-members:
13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__
14 |
--------------------------------------------------------------------------------
/docs/api/_templates/index.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: pyglove
2 |
3 | Public API: pyglove
4 | ===================
5 |
6 | Modules
7 | -------
8 |
9 | core
10 | ^^^^
11 |
12 | .. toctree::
13 | :maxdepth: 1
14 | {% for e in module.modules -%}
15 | {%- if e.api.source_category == 'core' %}
16 | {{e.api.relative_handle(module.doc_dir)}}
17 | {%- endif -%}
18 | {%- endfor %}
19 |
20 | ext
21 | ^^^
22 | .. toctree::
23 | :maxdepth: 1
24 | {% for e in module.modules -%}
25 | {%- if e.api.source_category == 'ext' %}
26 | {{e.api.relative_handle(module.doc_dir)}}
27 | {%- endif -%}
28 | {%- endfor %}
29 |
30 | .. toctree::
31 | :maxdepth: 1
32 |
33 | {% for e in module.modules -%}
34 | {%- if e.api.source_category == 'generators' %}
35 | {{e.api.relative_handle(module.doc_dir)}}
36 | {%- endif -%}
37 | {%- endfor %}
38 |
39 |
40 | Top-level shortcuts
41 | --------------------
42 |
43 | Objects
44 | ^^^^^^^
45 |
46 | {% for e in module.objects %}
47 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>`
48 | {%- endfor %}
49 |
50 | Classes
51 | ^^^^^^^
52 |
53 | {% for e in module.classes %}
54 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>`
55 | {%- endfor %}
56 |
57 | Functions
58 | ^^^^^^^^^
59 | {% for e in module.functions %}
60 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>`
61 | {%- endfor %}
62 |
--------------------------------------------------------------------------------
/docs/generate_docs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Minimal script to manually generate Sphinx documentation, used by Github
3 | # integration testing as well. Installs necessary Sphinx packages via
4 | # `requirements.txt`, and then uses `conf.py` to build the docs with the Sphinx builder named by the first argument (e.g. `html`). This script
5 | # is not actually needed for generating the official website at
6 | # (https://pyglove.readthedocs.io/en/latest/).
7 |
8 | # Define output folder for build files.
9 | OUTPUT_FOLDER=_build
10 |
11 | # Install Sphinx.
12 | sudo apt-get install python3-sphinx
13 |
14 | # Installs relevant Sphinx packages.
15 | pip install -r requirements.txt --use-deprecated=legacy-resolver
16 |
17 | # Build files (HTML, doctests, etc.) into `OUTPUT_FOLDER` directory.
18 | rm -rf ${OUTPUT_FOLDER} # Clear out original folder
19 | sphinx-build -b $1 -a . ${OUTPUT_FOLDER}
20 |
21 | # Optionally host the HTML folder, then open `http://localhost:5000/` in a browser.
22 | # python -m http.server --directory ${OUTPUT_FOLDER} 5000
23 |
--------------------------------------------------------------------------------
/docs/guide/automl/index.rst:
--------------------------------------------------------------------------------
1 | For AutoML
2 | ==========
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | toy
8 | industrial
9 | research
10 |
--------------------------------------------------------------------------------
/docs/guide/automl/industrial.rst:
--------------------------------------------------------------------------------
1 | Industrial Uses
2 | ===============
3 |
4 | Google Cloud Vertex AI
5 | **********************
6 |
7 | PyGlove is the `driving force `_
8 | behind `Vertex AI NAS `_, which
9 | has brought `advanced AutoML technologies `_
10 | to industries with significant impact. Companies such as
11 | `Qualcomm `_,
12 | `Varian `_, and
13 | `Oppo `_
14 | have all benefited from this groundbreaking technology.
15 |
16 | Vertex AI has published a few search spaces for computer vision as examples:
17 |
18 | * EfficientNet [`paper `_][`code `_]
19 | * SpineNet [`paper `_][`code `_]
20 | * MNASNet [`paper `_][`code `_]
21 | * NASFpn [`paper `_][`code `_]
22 | * AutoAugment [`paper `_][`code `_]
23 |
24 | Pax
25 | ***
26 |
27 | Pax is a powerful machine learning framework developed by Google, based on JAX, that is designed for training large-scale models.
28 | It utilizes PyGlove to enable its hyperparameter tuning and AutoML capabilities. Pax serves as a good example of how PyGlove can
29 | be seamlessly integrated into a large-scale ML codebase based on dynamic evaluation:
30 |
31 | * `Inspecting the search space `_
32 | * `Implementing the tuning loop `_
33 |
34 | Vizier
35 | ******
36 |
37 | `Vizier `_
38 | is the distributed tuning solution at Alphabet. PyGlove uses it as a backend for serving distributed AutoML scenarios at Google.
39 | The `open-source Vizier `_ has shown how PyGlove can be
40 | `used together with Vizier `_,
41 | and it also serves as an example of how a PyGlove backend can be developed.
42 |
43 | * `Implementing the Backend interface `_
44 | * `Implementing the Feedback interface `_
45 |
46 |
--------------------------------------------------------------------------------
/docs/guide/automl/research.rst:
--------------------------------------------------------------------------------
1 | AutoML Research
2 | ---------------
3 |
4 | PyGlove is a versatile library that not only simplifies the implementation of AutoML applications but is also powerful enough to facilitate complex AutoML/ML research.
5 | Its capability and flexibility have been demonstrated in several academic papers, including the following:
6 |
7 | Papers with code:
8 |
9 |
10 | * `Evolving Reinforcement Learning algorithms, ICLR 2021 `_ [`code `_]
11 | * `PyGlove: Efficiently Exchanging ML Ideas as Code, 2022 `_ [`code `_]
12 |
13 | Papers only:
14 |
15 | * `PyGlove: Symbolic Programming for Automated Machine Learning, NeurIPS 2020 `_
16 | * `AutoHAS: Efficient Hyperparameter and Architecture Search, ICLR 2021 NAS workshop `_
17 | * `Towards the Co-design of Neural Networks and Accelerators, MLSys 2022 `_
18 | * `Deepfusion: Lidar-camera Deep Fusion for Multi-modal 3D Object Detection, CVPR 2022 `_
19 | * `ES-ENAS: Combining Evolution Strategies with Neural Architecture Search at No Extra Cost for Reinforcement Learning, CoRR 2021 `_
20 |
--------------------------------------------------------------------------------
/docs/guide/automl/toy.rst:
--------------------------------------------------------------------------------
1 | Toy Problems
2 | ------------
3 |
4 | * `Neural Architecture Search on MNIST `_
5 | * `NAS-Bench-101 `_
6 | * `NATS-Bench `_
7 |
--------------------------------------------------------------------------------
/docs/guide/evolution/index.rst:
--------------------------------------------------------------------------------
1 | For Evolution
2 | =============
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | ../../notebooks/evolution/onemax
8 | ../../notebooks/evolution/tsp
9 | Function Regression <../../notebooks/evolution/function_regression>
--------------------------------------------------------------------------------
/docs/guide/general/index.rst:
--------------------------------------------------------------------------------
1 | For Python
2 | ==========
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | Flexible Function Bindings <../../notebooks/intro/basics/symbolic_function>
8 | Runtime Typing <../../notebooks/intro/basics/runtime_typing>
9 | Direct Manipulation <../../notebooks/python/interactive_svg>
10 | Domain-Specific Languages <../../notebooks/python/sticky_notes>
11 | Context-aware Components <../../notebooks/python/where_is_the_duck>
12 |
--------------------------------------------------------------------------------
/docs/guide/index.rst:
--------------------------------------------------------------------------------
1 | Getting Started
2 | ===============
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | ../notebooks/intro/birdview
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | general/index.rst
13 | ml/index.rst
14 | automl/index.rst
15 | evolution/index.rst
16 |
--------------------------------------------------------------------------------
/docs/guide/ml/index.rst:
--------------------------------------------------------------------------------
1 | For Machine Learning
2 | ====================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | Symbolic Modeling <../../notebooks/ml/neural_modeling>
8 | Symbolic Machine Learning <../../notebooks/ml/symbolic_ml>
9 | Patching: Efficiently Exchanging ML Ideas <../../notebooks/ml/efficiently_exchange_ml_ideas_as_code>
10 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. BUILD: sphinx-build -b html -a . ../build
2 |
3 | Welcome to PyGlove
4 | ##################
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | guide/index
10 | learn/index
11 |
12 | .. toctree::
13 | :maxdepth: 2
14 | :caption: API References
15 |
16 | api/index
17 | API Index (A-Z) <genindex>
18 |
--------------------------------------------------------------------------------
/docs/learn/evolution/index.rst:
--------------------------------------------------------------------------------
1 | Evolution
2 | #########
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | Algorithm <../../notebooks/intro/search/evolution_algorithm>
8 | Operations <../../notebooks/intro/search/evolution_ops>
9 | Fine Control <../../notebooks/intro/search/evolution_scheduling>
--------------------------------------------------------------------------------
/docs/learn/index.rst:
--------------------------------------------------------------------------------
1 | Learning PyGlove
2 | ################
3 |
4 | .. PyGlove is organized into three layers:
5 |
6 | .. * At the bottom, the layer of symbolic object-oriented programming, enables a mutable
7 | .. programming model with objects, which allows unknown program parts to be expressed
8 | .. side-by-side with the expression of known program parts, and enables dynamic interpretations
9 | .. on them.
10 | .. * At the middle, the layer of intelligent programs, provides the representation
11 | .. and operations to convert between symbols and numbers. It also introduces the expression for
12 | .. feedback loops, as well as a framework for building algorithms that evolve the program.
13 | .. * At the top, the layer of distributed symbolic computing, introduces API to allow
14 | .. feedback loops to be distributed so that intelligent programs can run at scale. This layer
15 | .. also provides the interfaces for users to plug their own infrastructure into PyGlove.
16 |
17 | PyGlove is a Python library for manipulating Python programs. It is built on the
18 | concept of symbolic object-oriented programming (SOOP), which forms its core foundation.
19 | On top of that, PyGlove includes multiple layers of components that enhance its capabilities and
20 | enable the development of intelligent systems.
21 |
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 |
26 | SOOP <soop/index>
27 | evolution/index
28 |
29 | .. .. image:: how_pyglove_works.svg
30 |
31 |
--------------------------------------------------------------------------------
/docs/learn/soop/index.rst:
--------------------------------------------------------------------------------
1 | Symbolic Object-Oriented Programming
2 | ####################################
3 |
4 | PyGlove offers two key capabilities for symbolic object-oriented programming: the
5 | *Symbolic Object Model (SOM)* and *Symbolic Detour (SD)*. SOM implements *dynamic representation*,
6 | allowing for flexible manipulation of symbolic objects. SD enables *dynamic interpretation*,
7 | allowing symbols to be interpreted in different ways at runtime.
8 |
9 |
10 | Symbolic Object Model
11 | *********************
12 |
13 | The Symbolic Object Model (SOM) is the core of symbolic object-oriented programming,
14 | providing dynamic representation through symbolic objects. SOM stores initialization
15 | arguments as symbolic attributes and allows inspection and manipulation of them.
16 | It also includes a symbolic schema system for validation and a messaging system for
17 | handling mutations. Symbolic placeholders are also supported for representing unknown
18 | program parts.
19 |
20 | .. .. _`dynamic representation`: ../../overview/what_and_why.html#programming-the-unknown
21 |
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 |
26 | som/definition
27 | som/types
28 | som/operations
29 | som/events
30 | som/validation
31 | som/placeholding
32 |
33 | Symbolic Detour
34 | ***************
35 |
36 | Symbolic Detour (SD) is independent of the SOM, allowing users to alter
37 | class mappings without modifying the source code that instantiates the classes. This is
38 | particularly useful when the source code cannot be symbolized for various reasons.
39 | SD complements SOM.
40 |
41 | .. toctree::
42 | :maxdepth: 1
43 |
44 | detour
45 |
46 |
47 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # Required Sphinx-related packages.
2 |
3 | absl-py
4 | autodocsumm
5 | furo
6 | jinja2
7 | sphinx_autodoc_typehints
8 | sphinx-copybutton
9 | myst_nb
10 |
11 | # Required packages for PyGlove.
12 |
13 | docstring-parser>=0.12
14 |
--------------------------------------------------------------------------------
/examples/automl/mnist/mnist_train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Train MNIST.
15 |
16 | This is a basic working ML program which trains MNIST.
17 | The code is modified from the tf.keras tutorial here:
18 | https://www.tensorflow.org/tutorials/keras/classification
19 |
20 | (The tutorial uses Fashion-MNIST,
21 | but we just use "regular" MNIST for these tutorials.)
22 |
23 | """
24 |
25 | from typing import Tuple
26 |
27 | from absl import app
28 | import numpy as np
29 | import tensorflow.google.compat.v2 as tf
30 |
31 |
def download_and_prep_data() -> Tuple[
    np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  """Loads MNIST and rescales pixel intensities into the [0, 1] range.

  Returns:
    A 4-tuple `(tr_x, tr_y, te_x, te_y)`: training images, training labels,
    test images and test labels. Only the image arrays are rescaled (divided
    by 255.0); labels are returned as loaded.
  """
  (train_images, train_labels), (test_images, test_labels) = (
      tf.keras.datasets.mnist.load_data())
  return train_images / 255.0, train_labels, test_images / 255.0, test_labels
49 |
50 |
def create_model() -> tf.keras.Model:
  """Builds the classifier to train.

  The network flattens each 28x28 input, applies one 128-unit ReLU hidden
  layer and ends with a 10-way softmax output layer.

  Returns:
    An uncompiled `tf.keras.Sequential` model; compilation happens in
    `train_and_eval`.
  """
  layers = tf.keras.layers
  return tf.keras.Sequential([
      layers.Flatten(input_shape=(28, 28)),
      layers.Dense(128, activation='relu'),
      layers.Dense(10, activation='softmax'),
  ])
65 |
66 |
def train_and_eval() -> None:
  """Prepares the data, trains for 10 epochs and prints the test metrics."""
  train_x, train_y, test_x, test_y = download_and_prep_data()
  model = create_model()
  model.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
  )
  model.fit(train_x, train_y, epochs=10)
  loss, accuracy = model.evaluate(test_x, test_y, verbose=2)
  print('Test loss: {}, accuracy: {}'.format(loss, accuracy))
80 |
81 |
def main(argv):
  """absl entry point; the program accepts no positional arguments."""
  if len(argv) <= 1:
    train_and_eval()
    return
  raise app.UsageError('Too many command-line arguments.')


if __name__ == '__main__':
  app.run(main)
90 |
--------------------------------------------------------------------------------
/examples/evolution/onemax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Solving the One-Max problem with custom encoding.
15 |
16 | Reference: https://tracer.lcc.uma.es/problems/onemax/onemax.html#SE91
17 | For more details, see:
18 | https://colab.research.google.com/github/google/pyglove/blob/main/docs/notebooks/evolution/onemax.ipynb
19 | """
20 |
21 | import random
22 | import pyglove as pg
23 |
24 |
def one_max(search_space, search_algorithm, num_trials=200):
  """Solves the One-Max problem by running a search.

  Each sampled sequence is scored by `sum(sequence)`, and that score is fed
  back to `search_algorithm` as the reward.

  Args:
    search_space: The hyper value to sample sequences from.
    search_algorithm: The algorithm that proposes candidates and receives the
      rewards.
    num_trials: Number of examples to sample.

  Returns:
    A `(best_sequence, best_reward)` tuple, or `(None, None)` if no example was
    sampled (e.g. `num_trials=0`).
  """
  best_sequence, best_reward = None, None
  for sequence, feedback in pg.sample(
      search_space, search_algorithm, num_examples=num_trials):
    reward = sum(sequence)
    if best_reward is None or reward > best_reward:
      best_sequence, best_reward = sequence, reward
    feedback(reward)
  # Guard against `list(None)` when the search produced no example at all.
  if best_sequence is None:
    print('No sequence was sampled.')
  else:
    print(f'Best sequence: {list(best_sequence)} (sum={best_reward})')
  return best_sequence, best_reward
36 |
37 |
def one_max_with_builtin_primitive(n: int):
  """Solves One-Max with a search space built from the `pg.oneof` primitive.

  Args:
    n: Length of the 0/1 sequence to optimize.
  """
  bit = pg.oneof([0, 1])
  one_max(
      pg.List([bit] * n),
      pg.evolution.regularized_evolution(
          population_size=20, tournament_size=10))
44 |
45 |
def one_max_with_custom_primitive(n: int):
  """Solves One-Max using a user-defined hyper primitive and mutator.

  The genome is a DNA whose value is a string of '0'/'1' characters; it is
  decoded into a list of ints before being scored.

  Args:
    n: Number of bits in each randomly initialized bit string.
  """

  class BitString(pg.hyper.CustomHyper):
    """Custom hyper primitive that represents a bit string of size n."""

    def custom_decode(self, dna: pg.DNA):
      # The DNA value is expected to be a string of digit characters.
      assert isinstance(dna.value, str)
      return list(map(int, dna.value))

  class MutateOneBit(pg.evolution.Mutator):
    """Flips one randomly chosen character of the parent's bit string."""

    def mutate(self, dna: pg.DNA):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      bits = dna.value
      pos = random.randint(0, len(bits) - 1)
      flipped = '0' if bits[pos] == '1' else '1'
      return pg.DNA(bits[:pos] + flipped + bits[pos + 1:])

  def init_population(population_size):
    """Returns a DNA generator yielding `population_size` random bit strings."""

    @pg.geno.dna_generator
    def initializer(dna_spec):
      del dna_spec  # Unused: the bit strings are generated independently.
      for _ in range(population_size):
        yield pg.DNA(''.join(str(random.randint(0, 1)) for _ in range(n)))

    return initializer()  # pylint: disable=no-value-for-parameter

  space = BitString()
  algorithm = pg.evolution.Evolution(
      pg.evolution.selectors.Random(10)
      >> pg.evolution.selectors.Top(1)
      >> MutateOneBit(),
      population_init=init_population(10),
      population_update=pg.evolution.selectors.Last(20))
  one_max(space, algorithm)
85 |
86 |
def main():
  """Runs both One-Max variants on 10-bit sequences."""
  for solve in (one_max_with_builtin_primitive, one_max_with_custom_primitive):
    solve(10)


if __name__ == '__main__':
  main()
94 |
--------------------------------------------------------------------------------
/examples/evolution/tsp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Solving the travelling salesman problem (TSP) with evolution.
15 |
16 | Reference: https://en.wikipedia.org/wiki/Travelling_salesman_problem
17 |
18 | A more detailed tutorial explaining how to use PyGlove to solve TSP can be
19 | found here:
20 | https://colab.research.google.com/github/google/pyglove/blob/main/docs/notebooks/evolution/tsp.ipynb
21 | """
22 |
23 | import math
24 | from typing import List
25 | import pyglove as pg
26 |
27 |
@pg.symbolize
class City:
  """Represents a city with location (x, y) on the map."""

  def __init__(self, x: int, y: int):
    self.x = x
    self.y = y

  def distance(self, other: 'City') -> float:
    """Returns the Euclidean distance between this city and `other`."""
    # `math.hypot` is the stdlib form of sqrt(dx**2 + dy**2) and avoids
    # intermediate overflow/underflow.
    return math.hypot(self.x - other.x, self.y - other.y)
38 |
39 |
@pg.symbolize
class Route:
  """Represents a route that traverses the cities in their appearing order."""

  def __init__(self, cities: List[City]):
    self.cities = cities

  def length(self) -> float:
    """Returns the total length of the closed tour.

    The tour visits the cities in order and then returns from the last city
    to the first one. An empty route has length 0.
    """
    stops = list(self.cities)
    # Pair every city with its successor, wrapping the last back to the first.
    successors = stops[1:] + stops[:1]
    return sum(a.distance(b) for a, b in zip(stops, successors))
52 |
53 |
def tsp(cities: List[City], num_trials: int = 500) -> Route:
  """Searches for a short closed route through `cities` using evolution.

  Args:
    cities: The cities to visit.
    num_trials: Number of candidate routes to sample and evaluate.

  Returns:
    The shortest route seen during the search, or None if no route was
    sampled.
  """
  # The search space is a `Route` over every permutation of `cities`.
  route_space = Route(pg.permutate(cities))

  def make_evolution(op, population_size=50, tournament_size=20, seed=None):
    # Draw `tournament_size` individuals at random, keep the top 2, then
    # apply `op` to them.
    selection = (
        pg.evolution.selectors.Random(tournament_size, seed=seed)
        >> pg.evolution.selectors.Top(2))
    return pg.evolution.Evolution(
        selection >> op,
        population_init=(pg.geno.Random(seed=seed), population_size),
        population_update=pg.evolution.selectors.Last(population_size))

  algorithm = make_evolution(
      pg.evolution.recombinators.PartiallyMapped()
      >> pg.evolution.mutators.Swap())

  best_route, shortest = None, None
  # Each sampled route comes with a `feedback` callable that reports its
  # reward back to the search algorithm.
  for route, feedback in pg.sample(
      route_space, algorithm, num_examples=num_trials):
    route_len = route.length()
    if shortest is None or route_len < shortest:
      best_route, shortest = route, route_len
    # The algorithm maximizes reward, so shorter routes get higher rewards.
    feedback(-route_len)

  print(f'Best route length: {shortest}.')
  print(best_route)
  return best_route
93 |
94 |
def main():
  """Samples 25 cities (seeded) on a 100x100 grid and solves TSP over them."""
  city_space = City(x=pg.oneof(range(100)), y=pg.oneof(range(100)))
  tsp(list(pg.random_sample(city_space, 25, seed=1)))


if __name__ == '__main__':
  main()
104 |
--------------------------------------------------------------------------------
/examples/ml/symbolic_modeling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Symbolic neural modeling with PyGlove.
15 |
16 | For more details, see:
17 | https://colab.research.google.com/github/google/pyglove/blob/main/docs/notebooks/ml/neural_modeling.ipynb
18 | """
19 | import pyglove as pg
20 | import tensorflow as tf
21 |
22 |
23 | # Symbolizing Keras layers so their instances can be symbolically manipulated.
24 |
# Each name below is the `pg.symbolize`-wrapped version of the Keras class of
# the same name; instances built from these aliases can be inspected and
# edited symbolically (e.g. via `rebind` / `pg.query` further below).
Sequential = pg.symbolize(tf.keras.Sequential)
Conv2D = pg.symbolize(tf.keras.layers.Conv2D)
Dense = pg.symbolize(tf.keras.layers.Dense)
Flatten = pg.symbolize(tf.keras.layers.Flatten)
ReLU = pg.symbolize(tf.keras.layers.ReLU)
30 |
31 |
def create_model() -> tf.keras.layers.Layer:
  """Builds a small symbolic conv net: two Conv2D+ReLU stages and a Dense head.

  Returns:
    A symbolic `Sequential` model whose layers can be edited via `rebind`.
  """
  layers = [
      Conv2D(16, (5, 5)), ReLU(),
      Conv2D(32, (3, 3)), ReLU(),
      Flatten(),
      Dense(10),
  ]
  return Sequential(layers)
41 |
42 |
def scale_model(model) -> None:
  """Scales the model up in place by doubling the filters of Conv2D layers.

  Args:
    model: A symbolic model (such as the one from `create_model`); it is
      modified in place through `rebind`.
  """

  def double_width(k, v, p):
    """Rebind rule: doubles `filters` of Conv2D layers, keeps other values.

    Args:
      k: A `pg.KeyPath` object representing the location of current node.
      v: The value of current node.
      p: The parent of current node.

    Returns:
      The output value for current node.
    """
    targets_conv_filters = isinstance(p, Conv2D) and k.key == 'filters'
    return 2 * v if targets_conv_filters else v

  # `rebind` applies the rule to every node of the symbolic tree.
  model.rebind(double_width)
63 |
64 |
def remove_relus(model) -> None:
  """Removes every ReLU layer from the model in place."""

  def remove_activations(k, v, p):
    del k, p
    # Returning `pg.MISSING_VALUE` deletes the node from its parent container;
    # every other node is returned unchanged.
    return pg.MISSING_VALUE if isinstance(v, ReLU) else v

  model.rebind(remove_activations)
75 |
76 |
def change_classification_head_width(model, width: int) -> None:
  """Sets `units` of the model's last Dense layer (the classification head).

  Args:
    model: A symbolic model, modified in place through `rebind`.
    width: The new number of output units.

  Raises:
    ValueError: If the model contains no Dense layer.
  """
  dense_layers = pg.query(model, where=lambda v: isinstance(v, Dense))
  # Fail with a clear message instead of an IndexError on `[-1]` below.
  if not dense_layers:
    raise ValueError(
        'The model has no Dense layer to use as the classification head.')
  # The last match is taken as the head; this assumes `pg.query` returns
  # matches in traversal order (TODO confirm).
  head_location = list(dense_layers.keys())[-1]
  model.rebind({f'{head_location}.units': width})
84 |
85 |
def main() -> None:
  """Builds a symbolic Keras model and prints it after each symbolic edit."""
  model = create_model()

  def dump():
    # Hide default values so only explicitly set arguments are printed.
    print(model.format(hide_default_values=True))

  print('Original model.')
  dump()

  scale_model(model)
  print('After doubling the width.')
  dump()

  remove_relus(model)
  print('After removing the ReLUs.')
  dump()

  print('After changing the classification head width.')
  change_classification_head_width(model, 100)
  dump()


if __name__ == '__main__':
  main()
108 |
--------------------------------------------------------------------------------
/examples/python/runtime_typing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """PyGlove runtime typing example."""
15 |
16 | import pyglove as pg
17 |
18 |
class Foo(pg.Object):
  """A symbolic class whose fields are declared via `pg.typing` annotations."""

  # Optional list of ints. NOTE: the `X | None` syntax requires Python 3.10+.
  x: pg.typing.List[int] | None
  # Dict with str keys and values of any type.
  y: pg.typing.Dict[str, pg.typing.Any]
  # Int constrained by `min_value=0`, with a field docstring and metadata.
  z: pg.typing.Annotated[
      pg.typing.Int(min_value=0),  # Field value spec.
      'Field z',  # Field docstring.
      dict(p=1)  # Meta-data
  ]
  # One of 'foo', 'bar', 'baz'; defaults to 'baz'.
  p: pg.typing.Enum[['foo', 'bar', 'baz']] = 'baz'
  # Dict with a fixed key schema ('a', 'b', 'c') and a default value.
  q: pg.typing.Dict[{
      'a': int | None,
      'b': pg.typing.Int[0, 100],
      'c': pg.typing.Tuple[int, ...]
  }] = dict(a=1, b=10, c=(1, 2))
34 | }] = dict(a=1, b=10, c=(1, 2))
35 |
36 |
37 | # `pg.typing` could also be used for static type analysis.
def add(
    x: pg.typing.Int[0, None],
    y: pg.typing.Float[None, 1.0]) -> pg.typing.Any:
  """Returns `x + y`.

  The `pg.typing` annotations declare value constraints, but they are not
  checked when this plain function is called directly.
  """
  total = x + y
  return total
42 |
43 |
def main() -> None:
  """Demonstrates runtime validation of symbolic objects and functions."""
  print(Foo([0, 1], dict(x=1), 1))

  # `z` is declared with `min_value=0`, so -1 is rejected at construction.
  try:
    Foo(None, dict(y=1), -1)
  except ValueError as e:
    print('Expected error', e)

  # A regular Python function performs no runtime check, even though type
  # annotations are given.
  print(add(-1, 2.0))

  # A symbolic version of the same function does check values at runtime.
  symbolic_add = pg.symbolize(add)
  try:
    symbolic_add(-1, 2.0)()
  except ValueError as e:
    print('Expected error', e)


if __name__ == '__main__':
  main()
67 |
68 |
--------------------------------------------------------------------------------
/pyglove/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Package pyglove.
16 |
17 | This package is the facade of the public PyGlove library, which includes modules
18 | for symbolic type definition and manipulation, symbolic typing and constraints,
19 | symbolic value generation, etc. It has only a handful of dependencies, such
20 | as enum, six, and yaml.
21 | """
22 |
23 | # NOTE(daiyip): We disable bad-import-order to preserve the relation of
24 | # imported symbols
25 | # pylint: disable=g-bad-import-order
26 | # pylint: disable=unused-import
27 | # pylint: disable=reimported
28 | # pylint: disable=g-import-not-at-top
29 |
30 | from pyglove.core import *
31 | from pyglove.ext import *
32 |
33 | # Placeholder for Google-internal imports.
34 |
35 | # pylint: enable=g-import-not-at-top
36 | # pylint: enable=reimported
37 | # pylint: enable=unused-import
38 | # pylint: enable=g-bad-import-order
39 |
# Version string of the PyGlove package.
__version__ = "0.4.5"
41 |
--------------------------------------------------------------------------------
/pyglove/core/coding/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # pylint: disable=line-too-long
15 | """Code generation utilities."""
16 |
17 | # pylint: enable=line-too-long
18 | # pylint: disable=g-bad-import-order
19 | # pylint: disable=g-importing-member
20 |
21 | from pyglove.core.coding.errors import CodeError
22 | from pyglove.core.coding.errors import SerializationError
23 |
24 | from pyglove.core.coding.permissions import CodePermission
25 | from pyglove.core.coding.permissions import permission
26 | from pyglove.core.coding.permissions import get_permission
27 |
28 | from pyglove.core.coding.parsing import parse
29 |
30 | from pyglove.core.coding.execution import context
31 | from pyglove.core.coding.execution import get_context
32 | from pyglove.core.coding.execution import evaluate
33 | from pyglove.core.coding.execution import sandbox_call
34 | from pyglove.core.coding.execution import maybe_sandbox_call
35 | from pyglove.core.coding.execution import run
36 |
37 | from pyglove.core.coding.function_generation import NO_TYPE_ANNOTATION
38 | from pyglove.core.coding.function_generation import make_function
39 |
40 | # pylint: disable=line-too-long
41 | # pylint: enable=g-bad-import-order
42 | # pylint: enable=g-importing-member
43 |
--------------------------------------------------------------------------------
/pyglove/core/coding/errors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python code errors."""
15 |
16 | import io
17 | import sys
18 | import textwrap
19 | import traceback
20 | from typing import Optional
21 |
22 | from pyglove.core import utils
23 |
24 |
class CodeError(RuntimeError):
  """Python code error.

  Wraps an exception raised while parsing or running a piece of Python source
  code, and records which lines of that code are at fault so the error can be
  rendered together with the offending lines.

  Attributes:
    code: The source code being parsed or run.
    cause: The original exception.
    lineno: 1-based line in `code` where the error starts, or None if unknown.
    end_lineno: 1-based line in `code` where the error ends, or None.
  """

  def __init__(
      self,
      code: str,
      cause: BaseException,
  ):
    # Populate `self.args`: exceptions are re-created from `args` when
    # unpickled, which fails if `args` is left empty.
    super().__init__(code, cause)
    self.code = code
    self.cause = cause

    # Figure out the starting and ending line numbers of the erratic code.
    lineno = None
    end_lineno = None
    if isinstance(cause, SyntaxError):
      lineno = cause.lineno
      # `end_lineno` is unavailable on Python 3.9 and below, and may be None
      # even when present; fall back to `lineno` in both cases.
      end_lineno = getattr(cause, 'end_lineno', None) or lineno
    elif not isinstance(cause, TimeoutError):
      # Prefer the traceback carried by `cause`; fall back to the exception
      # currently being handled.
      tb = cause.__traceback__
      if tb is None:
        tb = sys.exc_info()[2]
      for frame in traceback.extract_tb(tb, limit=5):
        # Code compiled from a string is conventionally labeled '<string>'.
        if not frame.filename or frame.filename == '<string>':
          lineno = frame.lineno
          end_lineno = lineno
          break
    self.lineno = lineno
    self.end_lineno = end_lineno

  def __str__(self):
    return self.format(include_complete_code=True)

  def code_lines(self, start_line: int, end_line: int):
    """Returns `code` lines in the 0-based slice [start_line:end_line]."""
    return '\n'.join(self.code.split('\n')[start_line:end_line])

  def format(self, include_complete_code: bool = True):
    """Formats the code error.

    Args:
      include_complete_code: If True, the complete source code is appended
        after the error message and the offending lines.

    Returns:
      A colored, multi-line description of the error.
    """
    r = io.StringIO()
    error_message = str(self.cause).rstrip()
    # Mirror Python's own "(<string>, line N)" suffix when the message does
    # not already mention a line.
    if 'line' not in error_message and self.lineno is not None:
      error_message += f' (<string>, line {self.lineno})'
    r.write(
        utils.colored(
            f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))

    if self.lineno is not None:
      r.write('\n\n')
      r.write(textwrap.indent(
          utils.colored(
              self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
          ' ' * 2
      ))
      r.write('\n')

    if include_complete_code:
      r.write('\n')
      r.write(utils.colored('[Code]', 'green', styles=['bold']))
      r.write('\n\n')
      r.write(utils.colored(' ```python\n', 'green'))
      r.write(textwrap.indent(
          utils.colored(self.code, 'green'),
          ' ' * 2
      ))
      r.write(utils.colored('\n ```\n', 'green'))
    return r.getvalue()
91 |
92 |
class SerializationError(RuntimeError):
  """Object serialization error.

  Attributes:
    message: Optional context message, shown before the cause.
    cause: The exception raised during serialization.
  """

  def __init__(self, message: Optional[str], cause: BaseException):
    # Populate `self.args` so the error can be re-created when unpickled.
    super().__init__(message, cause)
    self.message = message
    self.cause = cause

  def __str__(self):
    r = io.StringIO()
    cause_message = str(self.cause).rstrip()
    if self.message:
      r.write(utils.colored(self.message, 'magenta'))
      r.write('\n\n')
    r.write(
        utils.colored(
            f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
        )
    )
    return r.getvalue()
112 |
--------------------------------------------------------------------------------
/pyglove/core/coding/errors_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | import unittest
16 |
17 | from pyglove.core.coding import errors
18 | from pyglove.core.coding import execution
19 |
20 |
def code_error(code: str) -> errors.CodeError:
  """Runs `code` and returns the `errors.CodeError` it raises.

  Args:
    code: A source snippet; it is normalized with `inspect.cleandoc` and then
      run via `execution.run` with a 2-second timeout.

  Returns:
    The `errors.CodeError` raised while running the snippet.

  Raises:
    AssertionError: If running the snippet does not raise `errors.CodeError`.
  """
  try:
    execution.run(inspect.cleandoc(code), timeout=2)
  except errors.CodeError as e:
    return e
  # An explicit raise (rather than `assert False`) still fires under -O.
  raise AssertionError('should not reach here')
27 |
28 |
class CodeErrorsTest(unittest.TestCase):
  """Tests for `errors.CodeError` formatting and line tracking."""

  def test_format(self):
    # `str()` includes the "[Code]" section; `format(...=False)` omits it.
    e = code_error(
        """
        x = y + 1
        """
    )
    self.assertIn('[Code]', str(e))
    self.assertNotIn(
        '[Code]', e.format(include_complete_code=False))

  def test_lineno(self):
    # Runtime error (undefined `y`) on line 1.
    self.assertEqual(
        code_error(
            """
            x = y + 1
            """
        ).lineno, 1)
    # Syntax error (`for ... of`) on line 2.
    self.assertEqual(
        code_error(
            """
            x = 1
            for i of x:
              y = i
            """
        ).lineno, 2)
    # Explicit raise on line 3.
    self.assertEqual(
        code_error(
            """
            x = 1
            y = 2
            raise ValueError
            """
        ).lineno, 3)

  def test_lineno_in_error_message(self):
    # Checks that the formatted message (without the complete code) mentions
    # a line.
    def assert_lineno(code):
      e = code_error(code)
      self.assertIn('line', e.format(include_complete_code=False))

    assert_lineno(
        """
        x = y + 1
        """
    )
    # NOTE(review): this snippet looks like valid code, but `code_error` only
    # returns when running it raises `errors.CodeError`; confirm it is meant
    # to fail (otherwise `code_error` raises AssertionError).
    assert_lineno(
        """
        x = 1
        y = 2
        """
    )
    assert_lineno(
        """
        raise ValueError()
        """
    )
86 |
87 |
class SerializationErrorTest(unittest.TestCase):
  """Tests for `errors.SerializationError`."""

  def test_str(self):
    text = str(errors.SerializationError(
        'Output cannot be serialized.', ValueError('abc')))
    # Both the context message and the "<Type>: <message>" cause must appear.
    for expected in ('Output cannot be serialized', 'ValueError: abc'):
      self.assertIn(expected, text)


if __name__ == '__main__':
  unittest.main()
99 |
--------------------------------------------------------------------------------
/pyglove/core/coding/function_generation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for function generation."""
15 |
16 | from typing import Any, Dict, List, Optional
17 |
18 |
class _NoTypeAnnotation:
  """Placeholder for no type annotation."""

  __slots__ = ()

  def __repr__(self) -> str:
    return 'NO_TYPE_ANNOTATION'


# Sentinel meaning "emit no return annotation". Compare it with `is`, since
# None (or any other object) may itself be a legitimate annotation.
NO_TYPE_ANNOTATION = _NoTypeAnnotation()


def make_function(
    name: str,
    args: List[str],
    body: List[str],
    *,
    exec_globals: Optional[Dict[str, Any]] = None,
    exec_locals: Optional[Dict[str, Any]] = None,
    return_type: Any = NO_TYPE_ANNOTATION):
  """Creates a function dynamically from source.

  Args:
    name: Name of the function to create.
    args: Argument declarations as source text (e.g. 'y: int = 0').
    body: Lines of the function body, without leading indentation.
    exec_globals: Globals passed to `exec`; None makes `exec` use this
      module's globals.
    exec_locals: Names made visible to the generated function (e.g. modules or
      types referenced by annotations). The dict itself is not modified.
    return_type: Return annotation for the function; `NO_TYPE_ANNOTATION`
      means no return annotation is emitted.

  Returns:
    The newly created function.
  """
  # Copy so injecting `_return_type` never leaks into the caller's dict.
  exec_locals = dict(exec_locals) if exec_locals else {}
  # Identity check: `!=` would call an arbitrary `__ne__` on `return_type`.
  if return_type is not NO_TYPE_ANNOTATION:
    exec_locals['_return_type'] = return_type
    return_annotation = '->_return_type'
  else:
    return_annotation = ''
  args = ', '.join(args)
  body = '\n'.join(f'    {line}' for line in body)
  fn_def = f'  def {name}({args}){return_annotation}:\n{body}'
  # Wrap the definition in a factory whose parameters are the `exec_locals`
  # names, so the generated function closes over those values.
  local_vars = ', '.join(exec_locals.keys())
  fn_def = f'def _make_fn({local_vars}):\n{fn_def}\n  return {name}'
  ns = {}
  exec(fn_def, exec_globals, ns)  # pylint: disable=exec-used
  return ns['_make_fn'](**exec_locals)
50 |
--------------------------------------------------------------------------------
/pyglove/core/coding/function_generation_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | import typing
16 | import unittest
17 |
18 | from pyglove.core.coding import function_generation
19 |
20 |
class MakeFunctionTest(unittest.TestCase):
  """Covers function_generation.make_function with and without annotations."""

  def test_make_function_with_type_annotations(self):
    foo = function_generation.make_function(
        'foo',
        ['x: typing.Optional[int]', 'y: int = 0'],
        ['return x + y'],
        exec_globals=None,
        exec_locals={'typing': typing},
        return_type=int,
    )

    sig = inspect.signature(foo)
    params = sig.parameters
    self.assertEqual(list(params), ['x', 'y'])
    self.assertEqual(params['x'].annotation, typing.Optional[int])
    self.assertEqual(params['y'].annotation, int)
    self.assertEqual(params['y'].default, 0)
    self.assertIs(sig.return_annotation, int)
    self.assertEqual(foo(1, 2), 3)

  def test_make_function_without_type_annotations(self):
    foo = function_generation.make_function(
        'foo', ['x', 'y'], ['return x + y'])

    sig = inspect.signature(foo)
    params = sig.parameters
    empty = inspect.Signature.empty
    self.assertEqual(list(params), ['x', 'y'])
    # Neither parameter should carry an annotation.
    for arg_name in ('x', 'y'):
      self.assertEqual(params[arg_name].annotation, empty)
    self.assertEqual(params['y'].default, empty)
    self.assertIs(sig.return_annotation, empty)
    self.assertEqual(foo(1, 2), 3)


if __name__ == '__main__':
  unittest.main()
59 |
--------------------------------------------------------------------------------
/pyglove/core/coding/permissions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python code permissions."""
15 |
16 | import contextlib
17 | import enum
18 | from typing import Optional
19 |
20 | from pyglove.core import utils
21 |
22 |
class _ClassProperty:
  """Read-only property evaluated against the class rather than an instance.

  Replaces chaining `@classmethod` on top of `@property`, which was deprecated
  in Python 3.11 and no longer produces a class property in Python 3.13.
  Because instances define `__get__`, `enum` treats them as descriptors and
  does not turn them into members.
  """

  def __init__(self, fget):
    self._fget = fget
    self.__doc__ = fget.__doc__

  def __get__(self, instance, owner=None):
    if owner is None:
      owner = type(instance)
    return self._fget(owner)


class CodePermission(enum.Flag):
  """Permissions for code execution."""

  # Allows assignment.
  ASSIGN = enum.auto()

  # Allows conditions.
  CONDITION = enum.auto()

  # Allows loops.
  LOOP = enum.auto()

  # Call functions or methods.
  CALL = enum.auto()

  # Allows exception.
  EXCEPTION = enum.auto()

  # Allows class definitions.
  CLASS_DEFINITION = enum.auto()

  # Allows function definitions.
  FUNCTION_DEFINITION = enum.auto()

  # Allows import.
  IMPORT = enum.auto()

  @_ClassProperty
  def BASIC(cls) -> 'CodePermission':  # pylint: disable=invalid-name
    """Returns basic permissions: assignment and calls."""
    return cls.ASSIGN | cls.CALL

  @_ClassProperty
  def ALL(cls) -> 'CodePermission':  # pylint: disable=invalid-name
    """Returns all permissions."""
    return (
        cls.BASIC | cls.CONDITION | cls.LOOP |
        cls.EXCEPTION | cls.CLASS_DEFINITION |
        cls.FUNCTION_DEFINITION | cls.IMPORT)
64 |
65 |
# Thread-local key under which the active code execution permission is stored.
_TLS_CODE_RUN_PERMISSION = '__code_run_permission__'


@contextlib.contextmanager
def permission(perm: CodePermission):
  """Context manager for controlling the permission for code execution.

  When `permission` contexts are nested, the outermost one wins: an inner
  request is ignored and the outer permission remains in effect. This lets
  users pin the permission at the top level.

  Args:
    perm: Requested code execution permission.

  Yields:
    The permission actually in effect within the context.
  """
  outer_perm = utils.thread_local_get(_TLS_CODE_RUN_PERMISSION, None)
  is_outermost = outer_perm is None
  effective_perm = perm if is_outermost else outer_perm

  utils.thread_local_set(_TLS_CODE_RUN_PERMISSION, effective_perm)
  try:
    yield effective_perm
  finally:
    # Only the context that installed the permission removes it.
    if is_outermost:
      utils.thread_local_del(_TLS_CODE_RUN_PERMISSION)
96 |
97 |
def get_permission() -> Optional[CodePermission]:
  """Returns the code execution permission in effect, or None if unset."""
  current = utils.thread_local_get(_TLS_CODE_RUN_PERMISSION, None)
  return current
101 |
--------------------------------------------------------------------------------
/pyglove/core/coding/permissions_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import unittest
15 | from pyglove.core.coding import permissions
16 |
17 |
class CodePermissionTest(unittest.TestCase):
  """Tests for `permissions.CodePermission` and `permissions.permission`."""

  def assert_set(
      self,
      permission: permissions.CodePermission,
      flag: permissions.CodePermission,
  ):
    """Asserts that every bit of `flag` is present in `permission`."""
    self.assertEqual(permission & flag, flag)

  def assert_not_set(
      self,
      permission: permissions.CodePermission,
      flag: permissions.CodePermission,
  ):
    """Asserts that no bit of `flag` is present in `permission`."""
    self.assertFalse(permission & flag)

  def test_basic(self):
    cp = permissions.CodePermission
    self.assert_set(cp.BASIC, cp.ASSIGN)
    self.assert_set(cp.BASIC, cp.CALL)
    # BASIC grants nothing beyond assignment and calls.
    for flag in (cp.CONDITION, cp.LOOP, cp.EXCEPTION, cp.CLASS_DEFINITION,
                 cp.FUNCTION_DEFINITION, cp.IMPORT):
      self.assert_not_set(cp.BASIC, flag)

  def test_all(self):
    cp = permissions.CodePermission
    for flag in (cp.BASIC, cp.CONDITION, cp.LOOP, cp.EXCEPTION,
                 cp.CLASS_DEFINITION, cp.FUNCTION_DEFINITION, cp.IMPORT):
      self.assert_set(cp.ALL, flag)

  def test_xor(self):
    cp = permissions.CodePermission
    self.assert_not_set(cp.ALL ^ cp.BASIC, cp.BASIC)
    self.assert_set(cp.ALL ^ cp.BASIC, cp.CONDITION)

  def test_permission_control(self):
    cp = permissions.CodePermission
    self.assertIsNone(permissions.get_permission())
    with permissions.permission(cp.BASIC) as outer:
      self.assertEqual(outer, cp.BASIC)
      self.assertEqual(permissions.get_permission(), cp.BASIC)
      # The outermost permission wins over nested requests.
      with permissions.permission(cp.ALL) as inner:
        self.assertEqual(inner, cp.BASIC)
        self.assertEqual(permissions.get_permission(), cp.BASIC)
      # Leaving a nested context must not clear the outer permission.
      self.assertEqual(permissions.get_permission(), cp.BASIC)
    # Leaving the outermost context clears the permission.
    self.assertIsNone(permissions.get_permission())


if __name__ == '__main__':
  unittest.main()
94 |
--------------------------------------------------------------------------------
/pyglove/core/detouring/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Symbolic detour.
15 |
16 | It is straightforward to symbolize existing classes and functions, but in order
17 | to use them, we need to replace the classes that were used in existing code
18 | with the symbolic ones. Sometimes, we just cannot modify existing source code.
19 | Or in some other cases, objects created within a function or a class method are
20 | not exposed to the external, therefore we cannot manipulate them as a part of
21 | the symbolic tree. For example::
22 |
23 | @pg.symbolize
24 | def foo():
25 | # Object `a` is not a part of `foo`'s interface,
26 | # therefore it cannot be seen from the symbolic tree
27 | # that contains a `foo` object.
28 | a = A(1)
29 | return a.do_something()
30 |
31 | Symbolic detour is introduced to address these use cases, which redirects
32 | the ``__new__`` method of a class to another class or function when it’s
33 | evaluated under a context manager. Symbolic detour is not dependent on
34 | symbolization, so in theory it can be used for detouring any classes.
35 | Therefore, it does not require the presence of symbolic objects for mutating
36 | the program.
37 | """
38 |
39 | # pylint: disable=g-bad-import-order
40 |
41 | from pyglove.core.detouring.class_detour import detour
42 | from pyglove.core.detouring.class_detour import current_mappings
43 | from pyglove.core.detouring.class_detour import undetoured_new
44 |
45 | # pylint: enable=g-bad-import-order
46 |
--------------------------------------------------------------------------------
/pyglove/core/geno/random.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Random DNA generator."""
15 |
16 | import random
17 | import types
18 | from typing import Any, Optional, Union
19 |
20 | from pyglove.core import symbolic
21 | from pyglove.core import typing as pg_typing
22 | from pyglove.core.geno.base import DNA
23 | from pyglove.core.geno.base import DNASpec
24 | from pyglove.core.geno.dna_generator import DNAGenerator
25 |
26 |
@symbolic.members([
    ('seed', pg_typing.Int().noneable(), 'Random seed.')
])
class Random(DNAGenerator):
  """Random DNA generator.

  When `seed` is None, proposals draw from the global `random` module state;
  otherwise a private `random.Random(seed)` instance is used, so the sequence
  of proposals is reproducible.
  """

  def _setup(self):
    """Creates the source of randomness used for proposals."""
    if self.seed is not None:
      self._random = random.Random(self.seed)
    else:
      self._random = random

  def _propose(self) -> DNA:
    """Samples a DNA from the DNA spec with the configured generator."""
    return random_dna(self._dna_spec, self._random)

  def _replay(self, trial_id: int, dna: DNA, reward: Any) -> None:
    """Replays one historical proposal.

    With a fixed seed, repeating one generation step per replayed trial
    reproduces the original sequence of random draws, so later proposals
    continue where the history left off. Unseeded generators have no
    reproducible state and are left untouched.
    """
    if self.seed is None:
      return
    random_dna(self._dna_spec, self._random)
51 |
52 |
def random_dna(
    dna_spec: DNASpec,
    random_generator: Union[None, types.ModuleType, random.Random] = None,
    attach_spec: bool = True,
    previous_dna: Optional[DNA] = None
) -> DNA:
  """Generates a random DNA from a DNASpec.

  Example::

    spec = pg.geno.space([
        pg.geno.oneof([
            pg.geno.constant(),
            pg.geno.constant(),
            pg.geno.constant()
        ]),
        pg.geno.floatv(0.1, 0.2)
    ])

    print(pg.random_dna(spec))
    # DNA([2, 0.1123])

  Args:
    dna_spec: a DNASpec object.
    random_generator: a Python random generator (the `random` module or a
      `random.Random` instance). Falls back to the `random` module when not
      provided.
    attach_spec: If True, attach the DNASpec to generated DNA.
    previous_dna: An optional DNA representing previous DNA. This field might
      be useful for generating stateful random DNAs.

  Returns:
    A DNA object.
  """
  generator = random_generator if random_generator else random
  return dna_spec.random_dna(generator, attach_spec, previous_dna)
87 |
88 |
--------------------------------------------------------------------------------
/pyglove/core/geno/sweeping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Sweeping DNA generator."""
15 |
16 | from typing import Any
17 |
18 | from pyglove.core.geno.base import DNA
19 | from pyglove.core.geno.dna_generator import DNAGenerator
20 |
21 |
class Sweeping(DNAGenerator):
  """Sweeping (Grid Search) DNA generator.

  Enumerates the search space in the order defined by `DNASpec.next_dna`:
  each proposal is the successor of the previously proposed DNA (the first one
  is `next_dna(None)`), and `StopIteration` is raised once there is no
  successor.
  """

  def _setup(self):
    """Resets the enumeration state."""
    # None means nothing has been proposed yet.
    self._last_proposed_dna = None

  def _propose(self) -> DNA:
    """Proposes the next DNA in enumeration order.

    Returns:
      The successor of the last proposed DNA.

    Raises:
      StopIteration: When the search space has been exhausted.
    """
    next_dna = self.dna_spec.next_dna(self._last_proposed_dna)
    if next_dna is None:
      raise StopIteration()
    self._last_proposed_dna = next_dna
    return next_dna

  def _replay(self, trial_id: int, dna: DNA, reward: Any) -> None:
    """Resumes enumeration after a historical DNA.

    Assumes history is replayed in proposal order, so the last replayed DNA
    becomes the enumeration cursor.
    """
    del trial_id, reward
    self._last_proposed_dna = dna
41 |
--------------------------------------------------------------------------------
/pyglove/core/geno/sweeping_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.geno.Sweeping."""
15 |
16 | import unittest
17 |
18 | from pyglove.core.geno.base import DNA
19 | from pyglove.core.geno.categorical import manyof
20 | from pyglove.core.geno.categorical import oneof
21 | from pyglove.core.geno.space import constant
22 | from pyglove.core.geno.space import Space
23 | from pyglove.core.geno.sweeping import Sweeping
24 |
25 |
class SweepingTest(unittest.TestCase):
  """Tests for `pg.geno.Sweeping`."""

  def _dna_spec(self):
    # A choice whose candidates are themselves choices (an ordered 2-of-3 and
    # a 1-of-2), followed by a sorted 2-of-3 multi-choice.
    nested_choice = oneof([
        manyof(2, [constant(), constant(), constant()]),
        oneof([constant(), constant()]),
    ])
    sorted_subset = manyof(
        2, [constant(), constant(), constant()], sorted=True)
    return Space([nested_choice, sorted_subset])

  def test_propose(self):
    algo = Sweeping()
    algo.setup(self._dna_spec())
    proposed = []
    try:
      while True:
        proposed.append(algo.propose())
    except StopIteration:
      pass

    # The sweep varies the last decision fastest.
    first_choices = [
        (0, [0, 1]), (0, [0, 2]), (0, [1, 0]),
        (0, [1, 2]), (0, [2, 0]), (0, [2, 1]),
        (1, 0), (1, 1),
    ]
    second_choices = [[0, 1], [0, 2], [1, 2]]
    expected = [
        DNA([first, second])
        for first in first_choices
        for second in second_choices
    ]
    self.assertEqual(proposed, expected)

  def test_recover(self):
    original = Sweeping()
    original.setup(self._dna_spec())
    history = [original.propose() for _ in range(10)]

    restored = original.clone(deep=True)
    restored.setup(self._dna_spec())
    restored.recover([(dna, 0.) for dna in history])
    self.assertEqual(original.propose(), restored.propose())


if __name__ == '__main__':
  unittest.main()
88 |
--------------------------------------------------------------------------------
/pyglove/core/hyper/custom_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import random
15 | import unittest
16 |
17 | from pyglove.core import geno
18 | from pyglove.core import symbolic
19 | from pyglove.core import utils
20 | from pyglove.core.hyper.categorical import oneof
21 | from pyglove.core.hyper.custom import CustomHyper
22 | from pyglove.core.hyper.iter import iterate
23 | from pyglove.core.hyper.object_template import materialize
24 |
25 |
class IntSequence(CustomHyper):
  """Custom hyper value decoded from a comma-separated string DNA."""

  def custom_decode(self, dna):
    # e.g. DNA('0,1,2') -> [0, 1, 2]
    return list(map(int, dna.value.split(',')))
30 |
31 |
class IntSequenceWithEncode(IntSequence):
  """IntSequence that also supports encoding, iteration and random sampling."""

  def custom_encode(self, value):
    return self._create_dna(value)

  def next_dna(self, dna):
    # The first DNA is '0,1,2,3,4'; each successor appends its own length.
    if dna is None:
      return self._create_dna(range(5))
    numbers = self.custom_decode(dna)
    return self._create_dna(numbers + [len(numbers)])

  def random_dna(self, random_generator, previous_dna):
    del previous_dna
    length = random_generator.randint(0, 10)
    digits = random_generator.choices(list(range(10)), k=length)
    return self._create_dna(digits)

  def _create_dna(self, numbers):
    return geno.DNA(','.join(map(str, numbers)))
52 |
53 |
class CustomHyperTest(unittest.TestCase):
  """Tests for CustomHyper."""

  def test_dna_spec(self):
    actual = IntSequence(hints='x').dna_spec('a')
    expected = geno.CustomDecisionPoint(
        hyper_type='IntSequence', location=utils.KeyPath('a'), hints='x')
    self.assertTrue(symbolic.eq(actual, expected))

  def test_decode(self):
    for dna_value, expected in (('0,1,2', [0, 1, 2]), ('0', [0])):
      self.assertEqual(IntSequence().decode(geno.DNA(dna_value)), expected)
    with self.assertRaisesRegex(ValueError, '.* expects string type DNA'):
      IntSequence().decode(geno.DNA(1))

  def test_encode(self):
    self.assertEqual(
        IntSequenceWithEncode().encode([0, 1, 2]), geno.DNA('0,1,2'))
    # Plain IntSequence does not implement `custom_encode`.
    with self.assertRaisesRegex(
        NotImplementedError, "'custom_encode' is not supported by"):
      _ = IntSequence().encode([0, 1, 2])

  def test_random_dna(self):
    spec = IntSequenceWithEncode().dna_spec('a')
    self.assertEqual(
        geno.random_dna(spec, random.Random(1)), geno.DNA('5,8'))
    with self.assertRaisesRegex(
        NotImplementedError, '`random_dna` is not implemented in .*'):
      geno.random_dna(IntSequence().dna_spec('a'))

  def test_iter(self):
    self.assertEqual(
        IntSequenceWithEncode().first_dna(), geno.DNA('0,1,2,3,4'))
    self.assertEqual(
        list(iterate(IntSequenceWithEncode(), 3)),
        [list(range(n)) for n in (5, 6, 7)])
    with self.assertRaisesRegex(
        NotImplementedError, '`next_dna` is not implemented in .*'):
      next(iterate(IntSequence()))

  def test_interop_with_other_primitives(self):
    choice = oneof([IntSequence(), 1, 2])
    self.assertEqual(materialize(choice, geno.DNA(1)), 1)
    self.assertEqual(materialize(choice, geno.DNA((0, '3,4'))), [3, 4])


if __name__ == '__main__':
  unittest.main()
111 |
--------------------------------------------------------------------------------
/pyglove/core/io/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Pluggable IO with unified interfaces."""
15 |
16 | # pylint: disable=g-bad-import-order
17 |
18 | from pyglove.core.io.file_system import *
19 | from pyglove.core.io.sequence import *
20 |
21 | # pylint: enable=g-bad-import-order
22 |
--------------------------------------------------------------------------------
/pyglove/core/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Logging for PyGlove.
15 |
16 | This module allows PyGlove to use external created logger for logging PyGlove
17 | events without introducing library dependencies in PyGlove.
18 | """
19 |
20 | import inspect
21 | import logging
22 | from typing import Any, Callable, List, Union
23 |
24 |
_DEFAULT_LOGGER = logging.getLogger()


def register_frame_to_skip(
    method: Union[Callable[..., Any], List[Callable[..., Any]]]
) -> bool:
  """Skips the source of the given method(s) when logging.

  This only has an effect when the current logger's class exposes a
  `register_frame_to_skip(file_name, function_name)` hook; otherwise nothing
  is registered.

  Args:
    method: The method to skip. Can be a single method, or a list (or tuple)
      of methods.

  Returns:
    True if the method(s) were registered to skip; False if the current
    logger does not support frame skipping.

  Raises:
    TypeError: The source file of the method cannot be inspected.
  """
  register_fn = getattr(
      _DEFAULT_LOGGER.__class__, 'register_frame_to_skip', None
  )
  if register_fn is None:
    return False
  # Accept tuples as well as lists; anything else is a single callable.
  methods = method if isinstance(method, (list, tuple)) else [method]
  for m in methods:
    register_fn(inspect.getsourcefile(m), m.__name__)
  return True
51 |
52 |
def set_logger(logger: logging.Logger) -> None:
  """Makes `logger` the logger used by PyGlove.

  Also asks the new logger (if it supports it) to skip the frames of this
  module's logging wrappers when reporting the call site.
  """
  global _DEFAULT_LOGGER
  _DEFAULT_LOGGER = logger

  wrappers = [debug, info, warning, error, critical]
  register_frame_to_skip(wrappers)


def get_logger() -> logging.Logger:
  """Returns the logger currently used by PyGlove."""
  current = _DEFAULT_LOGGER
  return current
65 |
66 |
# The wrappers below forward directly to the current logger. `set_logger`
# registers exactly these functions with `register_frame_to_skip`, so they
# should stay thin functions that call the logger themselves: a shared helper
# would not be in that registration list and would show up as the call site.
def debug(msg: str, *args, **kwargs) -> None:
  """Logs a message at DEBUG level via the current PyGlove logger.

  Args:
    msg: Message with possible format string.
    *args: Values for variables in the format string.
    **kwargs: Keyword arguments for the logger.
  """
  _DEFAULT_LOGGER.debug(msg, *args, **kwargs)


def info(msg: str, *args, **kwargs) -> None:
  """Logs a message at INFO level via the current PyGlove logger.

  Args:
    msg: Message with possible format string.
    *args: Values for variables in the format string.
    **kwargs: Keyword arguments for the logger.
  """
  _DEFAULT_LOGGER.info(msg, *args, **kwargs)


def warning(msg: str, *args, **kwargs) -> None:
  """Logs a message at WARNING level via the current PyGlove logger.

  Args:
    msg: Message with possible format string.
    *args: Values for variables in the format string.
    **kwargs: Keyword arguments for the logger.
  """
  _DEFAULT_LOGGER.warning(msg, *args, **kwargs)


def error(msg: str, *args, **kwargs) -> None:
  """Logs a message at ERROR level via the current PyGlove logger.

  Args:
    msg: Message with possible format string.
    *args: Values for variables in the format string.
    **kwargs: Keyword arguments for the logger.
  """
  _DEFAULT_LOGGER.error(msg, *args, **kwargs)


def critical(msg: str, *args, **kwargs) -> None:
  """Logs a message at CRITICAL level via the current PyGlove logger.

  Args:
    msg: Message with possible format string.
    *args: Values for variables in the format string.
    **kwargs: Keyword arguments for the logger.
  """
  _DEFAULT_LOGGER.critical(msg, *args, **kwargs)
120 |
--------------------------------------------------------------------------------
/pyglove/core/logging_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import io
15 | import logging
16 | import unittest
17 | from pyglove.core import logging as pg_logging
18 |
19 |
class LoggingTest(unittest.TestCase):
  """Tests for pg.logging."""

  def _make_stream_logger(self, name, level):
    """Returns a named logger at `level` that writes to a fresh StringIO."""
    string_io = io.StringIO()
    logger = logging.getLogger(name)
    logger.setLevel(level)
    console_handler = logging.StreamHandler(stream=string_io)
    console_handler.setLevel(level)
    logger.addHandler(console_handler)
    # Named loggers are process-wide singletons; detach the handler afterwards
    # so repeated runs do not also write to stale streams.
    self.addCleanup(logger.removeHandler, console_handler)
    return logger, string_io

  def testLogging(self):
    self.assertIs(pg_logging.get_logger(), logging.getLogger())
    # `set_logger` mutates module-level state; restore it for other tests.
    self.addCleanup(pg_logging.set_logger, pg_logging.get_logger())

    logger, string_io = self._make_stream_logger('logger1', logging.INFO)
    pg_logging.set_logger(logger)
    self.assertIs(pg_logging.get_logger(), logger)

    pg_logging.debug('x=%s', 1)
    pg_logging.info('y=%s', 2)
    pg_logging.warning('z=%s', 3)
    pg_logging.error('p=%s', 4)
    pg_logging.critical('q=%s', 5)

    # DEBUG is below the INFO threshold and must be filtered out.
    self.assertEqual(string_io.getvalue(), '\n'.join([
        'y=2',
        'z=3',
        'p=4',
        'q=5',
    ]) + '\n')

    logger, string_io = self._make_stream_logger('logger2', logging.DEBUG)
    pg_logging.set_logger(logger)

    pg_logging.debug('x=%s', 6)
    self.assertEqual(string_io.getvalue(), 'x=6\n')


if __name__ == '__main__':
  unittest.main()
65 |
--------------------------------------------------------------------------------
/pyglove/core/patching/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Systematic patching on symbolic values.
15 |
16 | As :meth:`pyglove.Symbolic.rebind` provides a flexible programming
17 | interface for modifying symbolic values, why bother to have this module?
18 | Here are the motivations:
19 |
20 | * Provide user friendly methods for addressing the most common patching
21 | patterns.
22 |
23 | * Provide a systematic solution for
24 |
25 | * Patch semantic groups.
26 | * Enable combination of these groups.
27 | * Provide an interface that patching can be invoked from the command line.
28 | """
29 |
30 | # pylint: disable=g-bad-import-order
31 |
32 | # Pattern-based patching.
33 |
34 | from pyglove.core.patching.pattern_based import patch_on_key
35 | from pyglove.core.patching.pattern_based import patch_on_path
36 | from pyglove.core.patching.pattern_based import patch_on_type
37 | from pyglove.core.patching.pattern_based import patch_on_value
38 | from pyglove.core.patching.pattern_based import patch_on_member
39 |
40 | # Patcher: modular rule-based patching.
41 | from pyglove.core.patching.rule_based import patcher
42 | from pyglove.core.patching.rule_based import patch
43 |
44 | from pyglove.core.patching.rule_based import Patcher
45 | from pyglove.core.patching.rule_based import from_uri
46 |
47 | from pyglove.core.patching.rule_based import patcher_names
48 | from pyglove.core.patching.rule_based import allow_repeated_patcher_registration
49 |
50 | # Object factory based on patchers.
51 | from pyglove.core.patching.object_factory import ObjectFactory
52 | from pyglove.core.patching.object_factory import from_maybe_serialized
53 |
54 |
55 | # pylint: enable=g-bad-import-order
56 |
--------------------------------------------------------------------------------
/pyglove/core/patching/object_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Object factory based on patchers."""
15 |
16 | from typing import Any, Callable, Dict, Optional, Type, Union
17 | from pyglove.core import symbolic
18 | from pyglove.core import utils
19 | from pyglove.core.patching import rule_based
20 |
21 |
def from_maybe_serialized(
    source: Union[Any, str],
    value_type: Optional[Type[Any]] = None) -> Any:
  """Loads a value that may be given directly or in serialized form.

  Args:
    source: The value itself (any non-string object), a path to a file whose
      name ends with '.json', or a JSON string to deserialize.
    value_type: If given, the loaded value must be an instance of this type.

  Returns:
    The value obtained from `source`.

  Raises:
    TypeError: If `value_type` is given and the loaded value is not an
      instance of it.
  """
  if not isinstance(source, str):
    loaded = source
  elif source.endswith('.json'):
    # Strings ending with '.json' are treated as file paths.
    loaded = symbolic.load(source)
  else:
    # Any other string is parsed as serialized JSON content.
    loaded = symbolic.from_json_str(source)

  if value_type is None or isinstance(loaded, value_type):
    return loaded
  raise TypeError(
      f'Loaded value {loaded!r} is not an instance of {value_type!r}.')
46 |
47 |
@symbolic.functor()
def ObjectFactory(  # pylint: disable=invalid-name
    value_type: Type[symbolic.Symbolic],
    base_value: Union[symbolic.Symbolic,
                      Callable[[], symbolic.Symbolic],
                      str],
    patches: Optional[rule_based.PatchType] = None,
    params_override: Optional[Union[Dict[str, Any], str]] = None) -> Any:
  """A factory to create symbolic object from a base value and patches.

  The object is produced in three steps: the base value is materialized,
  `patches` (if any) are applied, then `params_override` (if non-empty) is
  applied on top as a rebind.

  Args:
    value_type: Type of return value.
    base_value: An instance of `value_type`, or a callable object that
      produces an instance of `value_type`, or a string as the path to the
      serialized value.
    patches: Optional patching rules. See :func:`patch` for details.
    params_override: A rebind dict (or its serialized form: a JSON string or
      a path to a '.json' file) as an additional patch to the value.

  Returns:
    Value after applying `patches` and `params_override` based on
    `base_value`.

  Raises:
    TypeError: If `base_value` does not yield an instance of `value_type`, or
      if `params_override` does not deserialize into a dict.
  """
  # A callable that is not already an instance of `value_type` is treated as a
  # factory; a string is treated as the path of a serialized value; anything
  # else is used as-is.
  if isinstance(base_value, value_type) or not callable(base_value):
    value = (symbolic.load(base_value) if isinstance(base_value, str)
             else base_value)
  else:
    value = base_value()

  if not isinstance(value, value_type):
    raise TypeError(
        f'{base_value!r} is neither an instance of {value_type!r}, '
        f'nor a factory or a path of JSON file that produces an '
        f'instance of {value_type!r}.')

  if patches is not None:
    value = rule_based.patch(value, patches)

  # The override dict (possibly deserialized first) is passed through
  # `utils.flatten` before rebinding; overrides that change nothing are
  # tolerated rather than treated as errors.
  if params_override:
    override_dict = from_maybe_serialized(params_override, dict)
    value = value.rebind(
        utils.flatten(override_dict), raise_on_no_change=False)
  return value
95 |
--------------------------------------------------------------------------------
/pyglove/core/patching/pattern_based_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pattern-based patching."""
15 |
16 | import unittest
17 | from pyglove.core import symbolic
18 | from pyglove.core import typing as pg_typing
19 | from pyglove.core.patching import pattern_based
20 |
21 |
class PatternBasedPatchingTest(unittest.TestCase):
  """Tests for the pattern-based patching helpers."""

  def test_patch_on_key(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    # A constant replaces every field keyed 'a', at any depth.
    pattern_based.patch_on_key(tree, 'a', 3)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 3, 'b': 1}})

    # `value_fn` computes the replacement from the current value.
    pattern_based.patch_on_key(tree, 'a', value_fn=lambda old: old + 1)
    self.assertEqual(tree, {'a': 4, 'b': {'a': 4, 'b': 1}})

    # Supplying both `value` and `value_fn` is rejected.
    with self.assertRaisesRegex(
        ValueError, 'Either `value` or `value_fn` should be specified'):
      pattern_based.patch_on_key(
          tree, 'a', value=1, value_fn=lambda old: old + 1)

  def test_patch_on_path(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    # Only 'b.b' is expected to match '.+b'.
    pattern_based.patch_on_path(tree, '.+b', 3)
    self.assertEqual(tree, {'a': 1, 'b': {'a': 2, 'b': 3}})

    # Both 'a' and 'b.a' are expected to match '.*a'.
    pattern_based.patch_on_path(tree, '.*a', value_fn=lambda old: old + 1)
    self.assertEqual(tree, {'a': 2, 'b': {'a': 3, 'b': 3}})

  def test_patch_on_value(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    # Every leaf equal to 1 is replaced.
    pattern_based.patch_on_value(tree, 1, 3)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 2, 'b': 3}})

    pattern_based.patch_on_value(tree, 2, value_fn=lambda old: old * 2)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 4, 'b': 3}})

  def test_patch_on_type(self):
    tree = symbolic.Dict(a='abc', b={'a': 2, 'b': 'def'})

    # Every string leaf is replaced; ints are left alone.
    pattern_based.patch_on_type(tree, str, 'foo')
    self.assertEqual(tree, {'a': 'foo', 'b': {'a': 2, 'b': 'foo'}})

    pattern_based.patch_on_type(tree, int, value_fn=lambda old: old * 2)
    self.assertEqual(tree, {'a': 'foo', 'b': {'a': 4, 'b': 'foo'}})

  def test_patch_on_member(self):

    @symbolic.members([
        ('x', pg_typing.Int()),
        ('y', pg_typing.Int()),
    ])
    class A(symbolic.Object):
      pass

    tree = symbolic.Dict(a=A(x=1, y=2), x=1)

    # Only member `x` of `A` instances is patched; the top-level 'x' is kept.
    pattern_based.patch_on_member(tree, A, 'x', 2)
    self.assertEqual(tree, {'a': A(x=2, y=2), 'x': 1})

    pattern_based.patch_on_member(tree, A, 'y', value_fn=lambda old: old * 2)
    self.assertEqual(tree, {'a': A(x=2, y=4), 'x': 1})


if __name__ == '__main__':
  unittest.main()
80 |
--------------------------------------------------------------------------------
/pyglove/core/symbolic/error_info_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | import unittest
17 | from pyglove.core.symbolic import base
18 | from pyglove.core.symbolic import error_info as error_info_lib # pylint: disable=g-bad-import-order
19 |
20 |
class ErrorInfoTest(unittest.TestCase):
  """Tests for ErrorInfo."""

  @staticmethod
  def _sample_error_info():
    """Returns the ErrorInfo fixture shared by several tests below."""
    return error_info_lib.ErrorInfo(
        tag='ValueError.ZeroDivisionError',
        description='Bad call to `foo`',
        stacktrace='Traceback (most recent call last)',
    )

  def test_from_exception(self):

    def foo():
      return 1 / 0

    def bar():
      try:
        foo()
      except ZeroDivisionError as e:
        raise ValueError('Bad call to `foo`') from e

    captured = None
    try:
      bar()
    except ValueError as e:
      # Build the ErrorInfo while the exception is being handled.
      captured = error_info_lib.ErrorInfo.from_exception(e)

    self.assertIsNotNone(captured)
    # The tag chains the outer exception type with its cause's type.
    self.assertEqual(captured.tag, 'ValueError.ZeroDivisionError')
    self.assertEqual(captured.description, 'Bad call to `foo`')
    self.assertIn('Traceback (most recent call last)', captured.stacktrace)

  def test_to_json(self):
    original = self._sample_error_info()
    serialized = original.to_json()

    restored = base.from_json(serialized)
    self.assertIsNot(restored, original)
    self.assertEqual(restored, original)

    # An alternative `_type` name (presumably a previous module location of
    # ErrorInfo) must deserialize to the same value.
    serialized['_type'] = 'pyglove.core.utils.error_utils.ErrorInfo'
    self.assertEqual(base.from_json(serialized), original)

  def test_format(self):
    self.assertEqual(
        self._sample_error_info().format(compact=False),
        inspect.cleandoc(
            """
            ErrorInfo(
            tag = 'ValueError.ZeroDivisionError',
            description = 'Bad call to `foo`',
            stacktrace = 'Traceback (most recent call last)'
            )
            """
        )
    )

  def test_to_html(self):
    content = self._sample_error_info().to_html().content
    for expected in ('ErrorInfo',
                     'ValueError.ZeroDivisionError',
                     'Bad call to `foo`',
                     'Traceback (most recent call last)'):
      self.assertIn(expected, content)


if __name__ == '__main__':
  unittest.main()
93 |
--------------------------------------------------------------------------------
/pyglove/core/symbolic/inferred.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common inferential values."""
15 |
16 | from typing import Any, Tuple
17 | from pyglove.core import typing as pg_typing
18 | from pyglove.core import utils
19 | from pyglove.core.symbolic import base
20 | from pyglove.core.symbolic.object import Object
21 |
22 |
class InferredValue(Object, base.Inferential):
  """Base class for values that are inferred rather than stored directly."""

  def custom_apply(self, *args, **kwargs: Any) -> Tuple[bool, Any]:
    # Skip value-spec validation entirely and keep this object as the final
    # value, so an `InferredValue` can be assigned to any symbolic attribute
    # regardless of the attribute's declared type.
    del args, kwargs
    return False, self
30 |
31 |
class ValueFromParentChain(InferredValue):
  """A value that could be inferred from the parent chain.

  The value is looked up by its own key (see `inference_key`) on the ancestors
  along its symbolic parent chain, skipping the immediate parent, until an
  ancestor provides it.

  For example::

    class A(pg.Object):
      x: int
      y: int = pg.symbolic.ValueFromParentChain()

    # Not okay: `x` is not inferential and is not specified.
    A()

    # Okay: both `x` and `y` are specified.
    A(x=1, y=2)

    # Okay: `y` is inferential, hence optional.
    a = A(x=1)

    # Raises: `y` is neither specified during __init__
    # nor provided from the context.
    a.y

    d = pg.Dict(y=2, z=pg.Dict(a=a))

    # `a.y` now refers to `d.y` since `d` is in its symbolic parent chain,
    # aka. context.
    assert a.y == 2
  """

  def infer(self, **kwargs) -> Any:
    """Returns the value provided by the nearest ancestor.

    Args:
      **kwargs: Forwarded to `value_from` for each visited ancestor.

    Returns:
      The first value that is not MISSING_VALUE returned by `value_from` while
      walking up the symbolic parent chain.

    Raises:
      AttributeError: If no ancestor provides the value.
    """
    parent = self.sym_parent
    while True:
      v = self.value_from(parent, **kwargs)
      # Keep MISSING_VALUE on the left so Python tries the sentinel's
      # `__eq__` first instead of `v.__eq__`, which may not return a plain
      # bool for some values (e.g. array-like objects).
      if pg_typing.MISSING_VALUE == v:
        if parent is None:
          raise AttributeError(
              utils.message_on_path(
                  (
                      f'`{self.inference_key}` is not found under its context '
                      '(along its symbolic parent chain).'
                  ),
                  self.sym_path,
              )
          )
        parent = parent.sym_parent
      else:
        return v

  @property
  def inference_key(self) -> str:
    """Returns the key for attribute inference from parents.

    This is the last key of this value's symbolic path. It is an int when this
    value is an element of a sequence.
    """
    return self.sym_path.key

  def value_from(self, parent: base.Symbolic, **kwargs) -> Any:
    """Looks up `inference_key` on `parent`.

    Args:
      parent: The ancestor to look up, or None once the traversal has gone
        past the root of the symbolic tree.
      **kwargs: Unused; accepted for subclass compatibility.

    Returns:
      The value found on `parent`, or MISSING_VALUE if `parent` does not
      provide it.
    """
    del kwargs
    if parent is self.sym_parent or parent is None:
      # NOTE(daiyip): The inferred value could not be read from the immediate
      # parent since its key points to current inferential value.
      # We should also return MISSING_VALUE when the traversal has gone beyond
      # of the symbolic tree root.
      return pg_typing.MISSING_VALUE

    # Use current key to lookup from the parent.
    key = self.inference_key
    if isinstance(key, int):
      # Integer keys only apply to sequence ancestors. An ancestor that is too
      # short simply does not provide the value: keep searching upward
      # instead of raising IndexError.
      if isinstance(parent, (list, tuple)) and -len(parent) <= key < len(parent):
        return parent[key]
      return pg_typing.MISSING_VALUE
    return getattr(parent, key, pg_typing.MISSING_VALUE)
103 |
--------------------------------------------------------------------------------
/pyglove/core/symbolic/inferred_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.symbolic.inferred."""
15 |
16 | import unittest
17 |
18 | from pyglove.core import typing as pg_typing
19 | from pyglove.core.symbolic import inferred
20 | from pyglove.core.symbolic.dict import Dict
21 |
22 |
class ValueFromParentChain(unittest.TestCase):
  """Tests for `pg.symbolic.ValueFromParentChain`."""

  def test_inference(self):
    root = Dict(y=1, x=Dict(x=1, y=inferred.ValueFromParentChain()))
    # `root.x.y` is resolved by looking up `y` on its grandparent `root`.
    self.assertEqual(root.x.y, 1)

    # Inference is live: rebinding the ancestor changes the inferred value.
    root.rebind(y=2)
    self.assertEqual(root.x.y, 2)

  def test_custom_typing(self):
    placeholder = inferred.ValueFromParentChain()
    # The inferred value passes through any value spec unchanged.
    for spec in (pg_typing.Int(), pg_typing.Str()):
      self.assertIs(spec.apply(placeholder), placeholder)


if __name__ == '__main__':
  unittest.main()
41 |
--------------------------------------------------------------------------------
/pyglove/core/symbolic/origin_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.symbolic.Origin."""
15 |
16 | import unittest
17 |
18 | from pyglove.core.symbolic import flags
19 | from pyglove.core.symbolic.dict import Dict
20 | from pyglove.core.symbolic.origin import Origin
21 |
22 |
class OriginTest(unittest.TestCase):
  """Tests for `pg.symbolic.Origin`."""

  def test_basics(self):
    a = Dict(a=1)
    o = Origin(a, '__init__')
    self.assertIs(o.source, a)
    self.assertEqual(o.tag, '__init__')
    # No stack information is collected unless requested.
    self.assertIsNone(o.stack)
    self.assertIsNone(o.stacktrace)

    with self.assertRaisesRegex(ValueError, '`tag` must be a string'):
      _ = Origin(a, 1)

  def test_include_stacktrace(self):
    # This test changes the process-wide origin stacktrace limit. Restore the
    # current value afterwards (even if an assertion fails) so other tests
    # are not affected by it.
    self.addCleanup(
        flags.set_origin_stacktrace_limit,
        flags.get_origin_stacktrace_limit())

    a = Dict(a=1)
    o = Origin(a, '__init__', stacktrace=True, stacklimit=3)
    self.assertIs(o.source, a)
    self.assertEqual(o.tag, '__init__')
    self.assertEqual(len(o.stack), 3)
    self.assertIsNotNone(o.stacktrace)

    # Without an explicit `stacklimit`, the global flag is used.
    flags.set_origin_stacktrace_limit(2)
    o = Origin(a, '__init__', stacktrace=True)
    self.assertEqual(len(o.stack), 2)

  def test_root(self):
    a = Dict(a=1)
    b = Dict(b=2)
    c = Dict(c=3)

    c.sym_setorigin(b, 'foo')
    b.sym_setorigin(a, 'bar')
    self.assertIs(c.sym_origin.root.source, a)

  def test_history(self):
    a = Dict(a=1)
    b = Dict(b=2)
    c = Dict(c=3)

    c.sym_setorigin(b, 'foo')
    b.sym_setorigin(a, 'bar')
    self.assertEqual(
        c.sym_origin.history(),
        [
            Origin(a, 'bar'),
            Origin(b, 'foo'),
        ])

    self.assertEqual(
        c.sym_origin.history(lambda o: o.tag == 'foo'),
        [
            Origin(b, 'foo'),
        ])

    self.assertEqual(
        c.sym_origin.history(lambda o: o.tag == 'bar'),
        [
            Origin(a, 'bar'),
        ])

  def test_eq_ne(self):
    a = Dict(a=1)
    self.assertEqual(Origin(None, '__init__'), Origin(None, '__init__'))
    self.assertEqual(Origin(a, 'builder'), Origin(a, 'builder'))
    self.assertNotEqual(Origin(a, 'builder'), a)
    self.assertNotEqual(Origin(a, 'builder'), Origin(a, 'return'))
    self.assertNotEqual(Origin(a, 'builder'), Origin(Dict(a=1), 'builder'))

  def test_format(self):
    a = Dict(a=1)
    o = Origin(None, '__init__')
    self.assertEqual(o.format(compact=True), 'Origin(tag=\'__init__\')')

    o = Origin('/path/to/file', 'load')
    self.assertEqual(
        o.format(compact=False),
        "Origin(\n tag='load',\n source='/path/to/file'\n)"
    )

    o = Origin(a, 'builder')
    self.assertEqual(
        o.format(compact=True),
        'Origin(tag=\'builder\', source={a=1} at 0x%x)' % id(a))


if __name__ == '__main__':
  unittest.main()
111 |
--------------------------------------------------------------------------------
/pyglove/core/symbolic/pure_symbolic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interfaces for pure symbolic objects."""
15 |
16 | from typing import Any, Callable, Optional, Tuple
17 | from pyglove.core import typing as pg_typing
18 | from pyglove.core import utils
19 |
20 |
class PureSymbolic(pg_typing.CustomTyping):
  """Base class for classes whose objects are considered pure symbolic.

  Pure symbolic objects can be used for representing abstract concepts - for
  example, a search space of objects - which cannot be executed but are solely
  representational.

  Having pure symbolic objects is a key differentiator of symbolic OOP from
  regular OOP: they can placehold values in an object as a high-level
  expression of ideas. Later, through symbolic manipulation, the pure symbolic
  objects are replaced with material values so the object can be evaluated.
  This effectively decouples the expression of ideas from their
  implementation. For example: ``pg.oneof(['a', 'b', 'c'])`` will be
  manipulated into 'a', 'b' or 'c' based on the decision of a search
  algorithm, letting the program evolve itself.
  """

  def custom_apply(
      self,
      path: utils.KeyPath,
      value_spec: pg_typing.ValueSpec,
      allow_partial: bool,
      child_transform: Optional[
          Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
      ] = None,
  ) -> Tuple[bool, Any]:
    """Custom apply on a value based on its original value spec.

    This implements ``pg.typing.CustomTyping``. The base implementation
    accepts a pure symbolic value for any field as-is. To customize this
    behavior, override this method in subclasses.

    Args:
      path: KeyPath of current object under its object tree.
      value_spec: Original value spec for this field.
      allow_partial: Whether allow partial object to be created.
      child_transform: Function to transform child node values into their
        final values. Transform function is called on leaf nodes first, then
        on their parents, recursively.

    Returns:
      A tuple (proceed_with_standard_apply, value_to_proceed).
      If proceed_with_standard_apply is set to False, value_to_proceed
      will be used as final value.

    Raises:
      Error when the value is not compatible with the value spec (the base
      implementation never raises; overrides may).
    """
    # All inputs are intentionally ignored: standard validation is skipped
    # and this object itself becomes the final value.
    _ = (path, value_spec, allow_partial, child_transform)
    proceed_with_standard_apply = False
    return proceed_with_standard_apply, self
71 |
72 |
class NonDeterministic(PureSymbolic):
  """Base class to mark a class whose objects are considered non-deterministic.

  A non-deterministic value represents a value that will be decided later.
  In PyGlove system, `pg.one_of`, `pg.sublist_of`, `pg.float_value` are
  non-deterministic values. Please search `NonDeterministic` subclasses for more
  details.

  This is a marker class with no members of its own. As a `PureSymbolic`
  subclass it inherits `custom_apply`, so its instances can be assigned to
  symbolic fields regardless of the fields' value specs, unless a subclass
  overrides that method.
  """
81 |
--------------------------------------------------------------------------------
/pyglove/core/tuning/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Distributed tuning with pluggable backends.
15 |
16 | :func:`pyglove.iter` provides an interface for sampling examples from a search
17 | space within a process. To support distributed tuning, PyGlove introduces
18 | :func:`pyglove.sample`, which is almost identical but with more features:
19 |
20 | * Allow multiple worker processes (aka. workers) to collaborate on a search
21 | with failover handling.
22 | * Each worker can process different trials, or can cowork on the same trials
23 | via work groups.
24 | * Provide APIs for communicating between the co-workers.
25 | * Provide APIs for retrieving the search results.
26 | * Provide a pluggable backend system for supporting user infrastructures.
27 |
28 | """
29 |
30 | # pylint: disable=g-bad-import-order
31 |
32 | # User facing APIs for tuning.
33 | from pyglove.core.tuning.sample import sample
34 | from pyglove.core.tuning.backend import poll_result
35 |
36 | from pyglove.core.tuning.backend import default_backend
37 | from pyglove.core.tuning.backend import set_default_backend
38 |
39 | # Tuning protocols.
40 | from pyglove.core.tuning.protocols import Measurement
41 | from pyglove.core.tuning.protocols import Trial
42 | from pyglove.core.tuning.protocols import Result
43 | from pyglove.core.tuning.protocols import Feedback
44 | from pyglove.core.tuning.protocols import RaceConditionError
45 |
46 | # Interface for early stopping.
47 | from pyglove.core.tuning.early_stopping import EarlyStoppingPolicy
48 |
49 | # Interfaces for tuning backend developers.
50 | from pyglove.core.tuning.backend import Backend
51 | from pyglove.core.tuning.backend import add_backend
52 | from pyglove.core.tuning.backend import available_backends
53 |
54 | # Importing local backend.
55 | import pyglove.core.tuning.local_backend
56 |
57 | # pylint: enable=g-bad-import-order
58 |
--------------------------------------------------------------------------------
/pyglove/core/tuning/backend_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.core.tuning.backend."""
15 |
16 | import unittest
17 | from pyglove.core.tuning import backend
18 | from pyglove.core.tuning import local_backend # pylint: disable=unused-import
19 |
20 |
class BackendTest(unittest.TestCase):
  """Tests for pluggable backend."""

  def test_pluggable_backend(self):
    # The default backend is process-wide state. Register its restoration up
    # front so it happens even if an assertion below fails midway; otherwise a
    # failure would leave 'test' as the default for later tests.
    # NOTE: the 'test' backend registration itself is not undone, since no
    # unregistration API is used here.
    self.addCleanup(backend.set_default_backend, backend.default_backend())

    self.assertEqual(backend.available_backends(), ['in-memory'])

    @backend.add_backend('test')
    class TestBackend(backend.Backend):  # pylint: disable=unused-variable
      """A fake backend factory for testing."""

      def __init__(self, **kwargs):
        pass

      @classmethod
      def poll_result(cls, name):
        return None

      def next(self):
        return None

    # Registered backends are listed and can become the default.
    self.assertEqual(backend.available_backends(), ['in-memory', 'test'])
    self.assertEqual(backend.default_backend(), 'in-memory')
    backend.set_default_backend('test')
    self.assertEqual(backend.default_backend(), 'test')

    # Unknown backends cannot become the default.
    with self.assertRaisesRegex(
        ValueError, 'Backend .* does not exist'):
      backend.set_default_backend('non-exist-backend')

    # Only `Backend` subclasses can be registered.
    with self.assertRaisesRegex(
        TypeError, '.* is not a `pg.tuning.Backend` subclass'):

      @backend.add_backend('bad')
      class BadBackend:  # pylint: disable=unused-variable
        pass

    backend.set_default_backend('in-memory')
    self.assertEqual(backend.default_backend(), 'in-memory')


if __name__ == '__main__':
  unittest.main()
62 |
--------------------------------------------------------------------------------
/pyglove/core/tuning/early_stopping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interface for early stopping policies."""
15 |
16 | import abc
17 | from typing import Iterable, Optional
18 |
19 | from pyglove.core import geno
20 | from pyglove.core import symbolic
21 | from pyglove.core.tuning.protocols import Trial
22 |
23 |
class EarlyStoppingPolicy(symbolic.Object):
  """Interface for early stopping policy.

  Subclasses must implement `should_stop_early`; `setup` and `recover` have
  default implementations that subclasses may override.
  """

  def setup(self, dna_spec: geno.DNASpec) -> None:
    """Setup states of an early stopping policy based on dna_spec.

    The base implementation only records `dna_spec`, which is afterwards
    exposed through the `dna_spec` property.

    Args:
      dna_spec: DNASpec for DNA to propose.

    Raises:
      RuntimeError: if dna_spec is not supported (raised by subclasses; the
        base implementation does not raise).
    """
    self._dna_spec = dna_spec

  @property
  def dna_spec(self) -> Optional[geno.DNASpec]:
    """Returns the DNASpec recorded by `setup`, or None before `setup`."""
    return getattr(self, '_dna_spec', None)

  @abc.abstractmethod
  def should_stop_early(self, trial: Trial) -> bool:
    """Should stop the input trial early based on its measurements."""

  def recover(self, history: Iterable[Trial]) -> None:
    """Recover states by replaying the trial history.

    Subclass can override.

    NOTE: `recover` will always be called before the first `should_stop_early`
    is called. It could be called multiple times if there are multiple source
    of history, e.g: trials from a previous study and existing trials from
    current study.

    The default behavior is to replay `should_stop_early`, in history order,
    on every trial whose status is 'COMPLETED', 'PENDING' or 'STOPPING'.

    Args:
      history: An iterable object of trials.
    """
    # TODO(daiyip): Stopped trials do not need to be fed back.
    replayable_statuses = ('COMPLETED', 'PENDING', 'STOPPING')
    for past_trial in (t for t in history if t.status in replayable_statuses):
      self.should_stop_early(past_trial)
66 |
--------------------------------------------------------------------------------
/pyglove/core/tuning/protocols_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.core.tuning.protocols."""
15 |
16 | import time
17 | import unittest
18 |
19 | from pyglove.core import geno
20 | from pyglove.core.tuning.protocols import Measurement
21 | from pyglove.core.tuning.protocols import Trial
22 |
23 |
class TrialTest(unittest.TestCase):
  """Test for Trial class."""

  def test_get_reward_for_feedback(self):

    def make_trial(status, **extra_fields):
      """Creates a trial with the fields shared by all cases below."""
      return Trial(
          id=0, dna=geno.DNA(0),
          status=status,
          created_time=int(time.time()),
          **extra_fields)

    # A pending trial has no reward to feed back.
    self.assertIsNone(make_trial('PENDING').get_reward_for_feedback())

    # Neither does a completed but infeasible trial.
    infeasible = make_trial('COMPLETED', infeasible=True)
    self.assertIsNone(infeasible.get_reward_for_feedback())

    completed = make_trial(
        'COMPLETED',
        infeasible=False,
        final_measurement=Measurement(
            step=1, elapse_secs=0.1, reward=1.0,
            metrics=dict(accuracy=0.9, latency=750.0)))
    # Without metric names, the final reward is returned ...
    self.assertEqual(completed.get_reward_for_feedback(), 1.0)
    # ... with metric names, the tuple of those metric values.
    self.assertEqual(
        completed.get_reward_for_feedback(['accuracy', 'latency']),
        (0.9, 750.0))
    with self.assertRaisesRegex(
        ValueError, 'Metric \'foo\' does not exist'):
      completed.get_reward_for_feedback(['foo'])


if __name__ == '__main__':
  unittest.main()
59 |
--------------------------------------------------------------------------------
/pyglove/core/typing/annotated.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """pg.typing.Annotated: A drop-in replacement of typing.Annotated."""
15 |
16 | import typing
17 | from typing import Any, Dict, Optional, Tuple
18 |
19 | from pyglove.core.typing import class_schema
20 | from pyglove.core.typing import value_specs as vs
21 |
22 |
class Annotated(vs.Generic):
  """The PyGlove enhanced `typing.Annotated` for defining a class field.

  Example::

    class A(pg.Object):
      x = pg.typing.Annotated[
          int,                           # Field type or value spec.
          'Docstring for field `x`.',    # Field docstring.
          dict(foo=1, bar=2)             # Field metadata
      ]
  """

  def __init__(
      self,
      t: vs.ValueSpecOrAnnotation,
      docstring: Optional[str] = None,
      metadata: Optional[Dict[str, Any]] = None
  ):
    """Constructor.

    Args:
      t: The field type: a value spec, or an annotation that is converted into
        one via `ValueSpec.from_annotation(..., auto_typing=True)`.
      docstring: Optional docstring of the field.
      metadata: Optional metadata of the field. A falsy value (None or an
        empty dict) is stored as a fresh empty dict.
    """
    super().__init__()
    self._value_spec = class_schema.ValueSpec.from_annotation(
        t, auto_typing=True)
    self._docstring = docstring
    self._metadata = metadata or {}

  @property
  def value_spec(self) -> class_schema.ValueSpec:
    """Returns the value spec for the field."""
    return self._value_spec

  @property
  def docstring(self) -> Optional[str]:
    """Returns the docstring for the field."""
    return self._docstring

  @property
  def metadata(self) -> Dict[str, Any]:
    """Returns the metadata for the field."""
    return self._metadata

  @classmethod
  def with_type_args(cls, type_args: Tuple[Any, ...]) -> Any:
    """Creates the field annotation from `Annotated[...]` subscription args.

    Args:
      type_args: 1 to 3 subscription arguments, in order: the field type (a
        value spec or an annotation), an optional field docstring (str) and
        optional field metadata (dict).

    Returns:
      An `Annotated` object. When `typing.TYPE_CHECKING` is True, the
      annotation of the resolved value spec is returned instead, so that
      third-party type checkers see a plain type.

    Raises:
      TypeError: If the number of arguments is not 1 to 3, if the docstring is
        not a str, if the metadata is not a dict, or if the field type cannot
        be converted into a value spec.
    """
    if not 1 <= len(type_args) <= 3:
      # The pieces below are implicitly concatenated into ONE string, so the
      # exception carries a single readable message rather than a tuple.
      raise TypeError(
          '`pg.typing.Annotated` accepts 1 to 3 type arguments '
          '(<field type>, [field docstring], [field metadata]). '
          f'Encountered: {type_args!r}')

    # Pad the omitted optional arguments (docstring, metadata) with None.
    t, docstring, metadata = (
        tuple(type_args) + (None,) * (3 - len(type_args)))

    if docstring is not None and not isinstance(docstring, str):
      raise TypeError(
          'The second type argument (`docstring`) must be a str. '
          f'Encountered: {docstring!r}')
    if metadata is not None and not isinstance(metadata, dict):
      raise TypeError(
          'The third type argument (`metadata`) must be a dict with str keys. '
          f'Encountered: {metadata!r}')

    t = class_schema.ValueSpec.from_annotation(t, auto_typing=True)
    annotated = cls(t, docstring=docstring, metadata=metadata)

    # This makes `pg.typing.Annotated` compatible with third-party type checking
    # solutions.
    if typing.TYPE_CHECKING:
      return annotated.value_spec.annotation
    return annotated
96 |
--------------------------------------------------------------------------------
/pyglove/core/typing/annotated_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for pyglove.core.typing.annotated."""
15 |
16 | import typing
17 | import unittest
18 |
19 | from pyglove.core.typing import annotated
20 | from pyglove.core.typing import annotation_conversion # pylint: disable=unused-import
21 | from pyglove.core.typing import value_specs as vs
22 |
23 |
class AnnotatedTest(unittest.TestCase):

  def test_subscription(self):
    # Field type only.
    x = annotated.Annotated[int]
    self.assertEqual(x.value_spec, vs.Int())
    self.assertIsNone(x.docstring)
    self.assertEqual(x.metadata, {})

    # Field type and docstring
    x = annotated.Annotated[int, 'hello']
    self.assertEqual(x.value_spec, vs.Int())
    self.assertEqual(x.docstring, 'hello')
    self.assertEqual(x.metadata, {})

    # Field type, docstring and metadata
    x = annotated.Annotated[int, 'hello', dict(foo=1)]
    self.assertEqual(x.value_spec, vs.Int())
    self.assertEqual(x.docstring, 'hello')
    self.assertEqual(x.metadata, dict(foo=1))

  def test_bad_subscription(self):
    with self.assertRaisesRegex(
        TypeError, '`pg.typing.Annotated` accepts 1 to 3 type arguments'):
      _ = annotated.Annotated[int, 'hello', dict(foo=1), 1]

    with self.assertRaisesRegex(
        TypeError, 'Cannot convert 1'):
      _ = annotated.Annotated[1]

    with self.assertRaisesRegex(
        TypeError, 'The second type argument .* must be a str'):
      _ = annotated.Annotated[int, 1]

    with self.assertRaisesRegex(
        TypeError, 'The third type argument .* must be a dict with str keys'):
      _ = annotated.Annotated[int, 'foo', [1, 2]]

  def test_type_checking(self):
    # `typing.TYPE_CHECKING` is process-wide state. Restore its previous value
    # even when an assertion fails, so that other tests are not affected.
    saved_type_checking = typing.TYPE_CHECKING
    typing.TYPE_CHECKING = True
    try:
      self.assertIs(annotated.Annotated[str], str)
      self.assertIs(annotated.Annotated[str, 'hello'], str)
      self.assertIs(
          annotated.Annotated[vs.Str().noneable(), 'hello'],
          typing.Optional[str])
    finally:
      typing.TYPE_CHECKING = saved_type_checking


if __name__ == '__main__':
  unittest.main()
75 |
--------------------------------------------------------------------------------
/pyglove/core/typing/custom_typing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interface for intercepting typing checking."""
15 |
16 | import abc
17 | from typing import Any, Callable, Optional, Tuple
18 |
19 | from pyglove.core import utils
20 | from pyglove.core.typing import class_schema
21 |
22 |
class CustomTyping(metaclass=abc.ABCMeta):
  """Base class for values that customize how they are applied to a spec.

  An instance of a `CustomTyping` subclass may be assigned to a field of any
  `ValueSpec`. Instead of running the standard `apply`, the instance takes over
  through its `custom_apply` method.

  This keeps the schema system open for extension without touching existing
  value specs. Value generators, for instance, subclass `CustomTyping` so that
  they can be assigned to arbitrary fields.
  """

  @abc.abstractmethod
  def custom_apply(
      self,
      path: utils.KeyPath,
      value_spec: class_schema.ValueSpec,
      allow_partial: bool,
      child_transform: Optional[
          Callable[[utils.KeyPath, class_schema.Field, Any], Any]] = None,
  ) -> Tuple[bool, Any]:
    """Applies this value against the field's original value spec.

    Args:
      path: Key path of the current object within its object tree.
      value_spec: The original value spec of the field.
      allow_partial: Whether a partially constructed object is allowed.
      child_transform: Optional function that turns child node values into
        their final values. It is invoked on leaves first and then on their
        parents, recursively.

    Returns:
      A `(proceed_with_standard_apply, value_to_proceed)` tuple. When
      `proceed_with_standard_apply` is False, `value_to_proceed` is taken as
      the final value.

    Raises:
      Error when the value is incompatible with the value spec.
    """
62 |
--------------------------------------------------------------------------------
/pyglove/core/typing/pytype_support.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Pytype support."""
15 |
16 | import typing
17 |
if typing.TYPE_CHECKING:

  _GenericCallable = typing.TypeVar('_GenericCallable')

  class Decorator(object):
    """A type annotation for decorators that do not change signatures.

    This is a stand-in for using `Callable[[T], T]` to represent a decorator.

    Given a decorator function, which takes in a callable and returns a callable
    with the same signature, apply this class as a decorator to that function.
    This can also be used for decorator factories.

    Examples:

    Plain decorator (decorator matches Callable[[T], T]):

    >>> @pg.typing.Decorator
    ... def my_decorator(func):
    ...   def wrapper(...):
    ...     ...
    ...   return wrapper

    Decorator factory (factory matches Callable[..., Callable[[T], T]]):

    >>> def my_decorator_factory(foo: int):
    ...
    ...   @pg.typing.Decorator
    ...   def my_decorator(func):
    ...     ...
    ...   return my_decorator

    This class only exists at build time, for typechecking. At runtime, the
    'Decorator' member of this module is a simple identity function.
    """

    def __init__(
        self,
        decorator: typing.Callable[[_GenericCallable], _GenericCallable]):  # pylint: disable=unused-argument
      ...  # pylint: disable=pointless-statement

    def __call__(self, func: _GenericCallable) -> _GenericCallable:
      ...  # pytype: disable=bad-return-type  # pylint: disable=pointless-statement

else:

  def Decorator(d):  # pylint: disable=invalid-name
    """Runtime stand-in for the type-checking `Decorator`: returns `d` as is."""
    return d
64 |
--------------------------------------------------------------------------------
/pyglove/core/typing/typed_missing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Typed value placeholders."""
15 |
16 | from typing import Any
17 | from pyglove.core import utils
18 | from pyglove.core.typing import class_schema
19 |
20 |
# Non-typed missing value.
MISSING_VALUE = utils.MISSING_VALUE


class MissingValue(utils.MissingValue, utils.Formattable):
  """Class represents missing value **for a specific value spec**."""

  def __init__(self, value_spec: class_schema.ValueSpec):
    """Constructor.

    Args:
      value_spec: The value spec of the field whose value is missing.
    """
    self._value_spec = value_spec

  @property
  def value_spec(self) -> class_schema.ValueSpec:
    """Returns value spec of current missing value."""
    return self._value_spec

  def __eq__(self, other: Any) -> bool:
    """Operator ==.

    NOTE: `MissingValue(value_spec)` and `utils.MissingValue` are considered
    equal, but `MissingValue(value_spec1)` and `MissingValue(value_spec2)` are
    considered different. That being said, the 'eq' operation is not
    transitive.

    However in practice this is not a problem, since users always compare
    against `MISSING_VALUE`, which is a `utils.MissingValue`. Therefore the
    `__hash__` function returns the same value as `utils.MissingValue`.

    Args:
      other: the value to compare against.

    Returns:
      True if the other value is a general MissingValue or MissingValue of the
      same value spec.
    """
    if self is other:
      return True
    if isinstance(other, MissingValue):
      return self._value_spec == other.value_spec
    return MISSING_VALUE == other

  def __hash__(self) -> int:
    """Overridden hashing to make all MissingValue return the same value."""
    return hash(MISSING_VALUE)

  def format(self,
             compact: bool = False,
             verbose: bool = True,
             root_indent: int = 0,
             **kwargs) -> str:
    """Formats as 'MISSING_VALUE', plus the value spec when not compact."""
    if compact:
      return 'MISSING_VALUE'
    else:
      spec_str = self._value_spec.format(
          compact=compact, verbose=verbose, root_indent=root_indent, **kwargs)
      return f'MISSING_VALUE({spec_str})'

  def __deepcopy__(self, memo):
    """Avoid deep copy by copying value_spec by reference."""
    return MissingValue(self.value_spec)
83 |
--------------------------------------------------------------------------------
/pyglove/core/typing/typed_missing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import unittest
15 |
16 | from pyglove.core import utils
17 | from pyglove.core.typing import typed_missing
18 | from pyglove.core.typing import value_specs
19 |
20 |
class MissingValueTest(unittest.TestCase):
  """Tests for the value-spec-aware `MissingValue`."""

  def test_eq(self):
    new_missing = typed_missing.MissingValue

    # Equality with the untyped sentinel holds in both directions.
    self.assertEqual(new_missing(value_specs.Int()), utils.MISSING_VALUE)
    self.assertEqual(utils.MISSING_VALUE, new_missing(value_specs.Int()))

    # Distinct instances are equal iff their value specs are equal.
    self.assertEqual(
        new_missing(value_specs.Int()), new_missing(value_specs.Int()))
    self.assertNotEqual(
        new_missing(value_specs.Int()),
        new_missing(value_specs.Int(max_value=1)))
    self.assertNotEqual(
        new_missing(value_specs.Int()), new_missing(value_specs.Str()))

    # An instance is always equal to itself.
    same = new_missing(value_specs.Int())
    self.assertEqual(same, same)

  def test_hash(self):
    int_hash = hash(typed_missing.MissingValue(value_specs.Int()))
    self.assertEqual(
        int_hash, hash(typed_missing.MissingValue(value_specs.Float())))
    self.assertEqual(int_hash, hash(utils.MISSING_VALUE))
    self.assertNotEqual(int_hash, hash(1))

  def test_format(self):
    """Test MissingValue.format."""
    value = typed_missing.MissingValue(value_specs.Int())
    self.assertEqual(value.format(compact=True), 'MISSING_VALUE')
    self.assertEqual(value.format(compact=False), 'MISSING_VALUE(Int())')


if __name__ == '__main__':
  unittest.main()
75 |
--------------------------------------------------------------------------------
/pyglove/core/utils/common_traits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common traits for Python objects.
15 |
16 | This file defines interfaces for describing the common traits of a Python
17 | object, for example, partiality (MaybePartial), functor (Functor).
18 | """
19 |
20 | import abc
21 | from typing import Any, Dict, Optional, Union
22 |
23 |
class MaybePartial(metaclass=abc.ABCMeta):
  """Interface for objects that may be constructed with missing values.

  The ``__init__`` of a ``MaybePartial`` object accepts ``pg.MISSING_VALUE``
  in place of real argument values. Every symbolic type (see
  :class:`pyglove.Symbolic`) implements this interface, since its symbolic
  attributes may be only partially filled.

  Example::

    d = pg.Dict(x=pg.MISSING_VALUE, y=1)
    assert d.is_partial
    assert 'x' in d.missing_values()
  """

  @property
  def is_partial(self) -> bool:
    """Whether this object is partial.

    By default, an object is partial exactly when ``missing_values()`` reports
    at least one entry, i.e. a required field is missing or a member is itself
    partial. Subclasses may override this with a cheaper check.
    """
    missing = self.missing_values()
    return bool(missing)

  @abc.abstractmethod
  def missing_values(self, flatten: bool = True) -> Dict[Union[str, int], Any]:  # pylint: disable=redefined-outer-name
    """Collects the missing values of this object.

    Args:
      flatten: When True, nested structures are flattened into a single dict
        keyed by key path (delimited by '.' and '[]').

    Returns:
      A dict mapping each key to MISSING_VALUE.
    """
60 |
61 |
class Functor(metaclass=abc.ABCMeta):
  """Abstract interface for callable objects (functors)."""

  @abc.abstractmethod
  def __call__(self, *args, **kwargs) -> Any:
    """Invokes the functor.

    Args:
      *args: Positional arguments of the call.
      **kwargs: Keyword arguments of the call.

    Returns:
      The call result, which may be of any type.
    """
76 |
77 |
def explicit_method_override(method):
  """Decorator that marks a member method as explicitly overridden.

  In PyGlove, many methods are managed by the framework - for example -
  ``pg.Object.__init__``. It's easy for users to override these methods
  unconsciously. Therefore, we introduce this decorator so that accidental
  overrides are caught as early as possible, while advanced users can still
  override such methods on purpose.

  Usage::

    class Foo(pg.Object):

      @pg.explicit_method_override
      def __init__(self, *args, **kwargs):
        ...

  Args:
    method: The method to mark as explicitly overridden.

  Returns:
    The same method object, with its ``__explicit_override__`` attribute set
    to True.
  """
  setattr(method, '__explicit_override__', True)
  return method
103 |
104 |
def ensure_explicit_method_override(
    method, error_message: Optional[str] = None) -> None:
  """Ensures a method was marked with `explicit_method_override`.

  Args:
    method: The method to check.
    error_message: Optional message for the raised error. If None, a default
      message naming `method` is used.

  Raises:
    TypeError: If `method` does not carry a truthy `__explicit_override__`
      attribute.
  """
  if not getattr(method, '__explicit_override__', False):
    if error_message is None:
      error_message = (
          f'{method} is a PyGlove managed method. If you do need to override '
          'it, please decorate the method with `@pg.explicit_method_override`.')
    raise TypeError(error_message)
114 |
--------------------------------------------------------------------------------
/pyglove/core/utils/common_traits_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import unittest
15 | from pyglove.core.utils import common_traits
16 |
17 |
class ExplicitlyOverrideTest(unittest.TestCase):

  def test_explicitly_override(self):

    class Foo:

      @common_traits.explicit_method_override
      def __init__(self, x, y):
        del x, y

      def not_marked(self):
        pass

    # A marked method passes the check silently.
    common_traits.ensure_explicit_method_override(Foo.__init__)

    # An unmarked method is rejected with the default error message.
    with self.assertRaisesRegex(TypeError, '.* is a PyGlove managed method'):
      common_traits.ensure_explicit_method_override(Foo.not_marked)


if __name__ == '__main__':
  unittest.main()
37 |
--------------------------------------------------------------------------------
/pyglove/core/utils/contextual_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import concurrent.futures
15 | import unittest
16 | from pyglove.core.utils import contextual
17 |
18 |
class ContextualTest(unittest.TestCase):

  def test_contextual_override(self):
    with contextual.contextual_override(x=3, y=3, z=3) as parent_override:
      self.assertEqual(
          parent_override,
          dict(
              x=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
              y=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
              z=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
          ),
      )
      self.assertEqual(
          contextual.get_contextual_override('y'),
          contextual.ContextualOverride(3, cascade=False, override_attrs=False),
      )
      self.assertEqual(contextual.contextual_value('x'), 3)
      self.assertIsNone(contextual.contextual_value('f', None))
      with self.assertRaisesRegex(KeyError, '.* does not exist'):
        contextual.contextual_value('f')

      self.assertEqual(contextual.all_contextual_values(), dict(x=3, y=3, z=3))

      # Test nested contextual override with an explicit override_attrs=True
      # (the default is False, as the parent override above shows).
      with contextual.contextual_override(
          y=4, z=4, override_attrs=True) as nested_override:
        self.assertEqual(
            nested_override,
            dict(
                x=contextual.ContextualOverride(
                    3, cascade=False, override_attrs=False
                ),
                y=contextual.ContextualOverride(
                    4, cascade=False, override_attrs=True
                ),
                z=contextual.ContextualOverride(
                    4, cascade=False, override_attrs=True
                ),
            ),
        )

    # Test nested contextual override with cascade=True: the outer values win
    # over the inner ones.
    with contextual.contextual_override(x=3, y=3, z=3, cascade=True):
      with contextual.contextual_override(y=4, z=4, cascade=True):
        self.assertEqual(contextual.contextual_value('x'), 3)
        self.assertEqual(contextual.contextual_value('y'), 3)
        self.assertEqual(contextual.contextual_value('z'), 3)

  def test_with_contextual_override(self):
    def func(i):
      del i
      return contextual.contextual_value('x')

    # Use the executor as a context manager so its worker threads are shut
    # down when the test finishes.
    with concurrent.futures.ThreadPoolExecutor() as pool:
      with contextual.contextual_override(x=3):
        self.assertEqual(contextual.with_contextual_override(func)(0), 3)
        self.assertEqual(
            list(
                pool.map(contextual.with_contextual_override(func), range(1))
            ),
            [3],
        )


if __name__ == '__main__':
  unittest.main()
89 |
--------------------------------------------------------------------------------
/pyglove/core/utils/error_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import unittest
15 | from pyglove.core.utils import error_utils
16 |
17 |
class CatchErrorsTest(unittest.TestCase):

  def assert_caught_error(self, func, errors_to_catch):
    """Asserts that `catch_errors(errors_to_catch)` swallows `func`'s error."""
    with error_utils.catch_errors(errors_to_catch) as context:
      func()
    self.assertIsNotNone(context.error)

  def assert_propagate_error(
      self, func, errors_to_catch, expected_error=Exception):
    """Asserts that `func`'s error escapes `catch_errors(errors_to_catch)`.

    Args:
      func: A function that raises.
      errors_to_catch: Error specification passed to `catch_errors`.
      expected_error: The exception type expected to propagate. Pass the exact
        type `func` raises so that the check does not also pass when
        `catch_errors` itself raises (e.g. a TypeError for a bad spec).
    """
    with self.assertRaises(expected_error):
      with error_utils.catch_errors(errors_to_catch):
        func()

  def test_catch_errors(self):
    def foo():
      raise ValueError('this is an error')

    self.assert_caught_error(foo, ValueError)
    self.assert_caught_error(foo, (ValueError,))
    self.assert_caught_error(foo, (Exception, 'ValueError'))
    self.assert_caught_error(foo, (KeyError, ValueError))
    self.assert_caught_error(foo, (ValueError, 'an error'))
    self.assert_caught_error(foo, (KeyError, (ValueError, 'an error'),))

    self.assert_propagate_error(foo, KeyError, expected_error=ValueError)
    self.assert_propagate_error(
        foo, (ValueError, '^an error'), expected_error=ValueError)
    self.assert_propagate_error(
        foo, (ValueError, 'something else'), expected_error=ValueError)

  def test_catch_errors_with_error_handler(self):
    errors = []
    def handle_error(error):
      errors.append(error)

    def foo():
      raise ValueError()

    with error_utils.catch_errors(ValueError, handle_error) as context:
      foo()
    self.assertEqual(errors, [context.error])

  def test_catch_errors_bad_inputs(self):
    with self.assertRaisesRegex(
        TypeError, 'Each error specification should be either .*'):
      with error_utils.catch_errors([(ValueError, 'abc', 'abc')]):
        pass

    with self.assertRaisesRegex(
        TypeError, 'Each error specification should be either .*'):
      with error_utils.catch_errors([(ValueError, 1)]):
        pass

    with self.assertRaisesRegex(
        TypeError, 'Exception contains non-except types'):
      with error_utils.catch_errors([str, ValueError]):
        pass


if __name__ == '__main__':
  unittest.main()
77 |
--------------------------------------------------------------------------------
/pyglove/core/utils/missing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Representing missing value for a field."""
15 |
16 | from typing import Any, Dict
17 | from pyglove.core.utils import formatting
18 | from pyglove.core.utils import json_conversion
19 |
20 |
class MissingValue(formatting.Formattable, json_conversion.JSONConvertible):
  """Placeholder denoting that an attribute has not been assigned a value."""

  def format(self, *args, **kwargs):  # pytype: disable=signature-mismatch
    """Renders as 'MISSING_VALUE' regardless of formatting options."""
    del args, kwargs
    return 'MISSING_VALUE'

  def __eq__(self, other: Any) -> bool:
    """Every `MissingValue` instance (including subclasses) compares equal."""
    return isinstance(other, MissingValue)

  def __ne__(self, other: Any) -> bool:
    """Negation of `__eq__`, dispatched on the runtime type of `self`."""
    return not self.__eq__(other)

  def __hash__(self) -> int:
    """Hashes a constant string derived from this base class's module/name."""
    identity = f'{MissingValue.__module__}{MissingValue.__name__}'
    return hash(identity)

  def to_json(self, **kwargs) -> Dict[str, Any]:
    """Serializes as a JSON dict carrying no fields."""
    return self.to_json_dict(fields={}, **kwargs)


# Module-level singleton to reference instead of constructing MissingValue().
MISSING_VALUE = MissingValue()
42 |
--------------------------------------------------------------------------------
/pyglove/core/utils/missing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import unittest
15 | from pyglove.core.utils import json_conversion
16 | from pyglove.core.utils import missing
17 |
18 |
class MissingValueTest(unittest.TestCase):
  """Tests for class MissingValue."""

  def test_basics(self):
    # Distinct MissingValue instances are equal to each other...
    value = missing.MissingValue()
    self.assertEqual(value, missing.MissingValue())
    # ...but not to ordinary values.
    for other in (1, {}):
      self.assertNotEqual(missing.MissingValue(), other)
    # Both str() and repr() render the same marker text.
    for to_text in (str, repr):
      self.assertEqual(to_text(missing.MissingValue()), 'MISSING_VALUE')

  def test_to_json(self):
    # A MissingValue survives a JSON round trip.
    json_value = json_conversion.to_json(missing.MissingValue())
    restored = json_conversion.from_json(json_value)
    self.assertEqual(restored, missing.MissingValue())
34 |
35 |
if __name__ == '__main__':
  # Run this module's tests when executed directly.
  unittest.main()
38 |
--------------------------------------------------------------------------------
/pyglove/core/utils/text_color_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | import os
16 | import unittest
17 | from pyglove.core.utils import text_color
18 |
19 |
class TextColorTest(unittest.TestCase):
  """Tests for `text_color.colored`, `colored_block` and `decolor`."""

  # Environment variables this test overrides to force colored output.
  _COLOR_ENV_VARS = ('ANSI_COLORS_DISABLED', 'NO_COLOR', 'FORCE_COLOR')

  def setUp(self):
    super().setUp()
    # Snapshot the variables first so the overrides below are undone after
    # each test and do not leak into other tests in the same process.
    for name in self._COLOR_ENV_VARS:
      self.addCleanup(self._restore_env_var, name, os.environ.get(name))
    os.environ.pop('ANSI_COLORS_DISABLED', None)
    os.environ.pop('NO_COLOR', None)
    os.environ['FORCE_COLOR'] = '1'

  @staticmethod
  def _restore_env_var(name, value):
    """Restores env var `name` to `value`, removing it if `value` is None."""
    if value is None:
      os.environ.pop(name, None)
    else:
      os.environ[name] = value

  def test_colored_block_without_colors_and_styles(self):
    # With no color, background or styles the text comes back unchanged.
    self.assertEqual(text_color.colored_block('foo', '{{', '}}'), 'foo')

  def test_colored_block(self):
    original_text = inspect.cleandoc("""
        Hi << foo >>
        <# print x if x is present #>
        <% if x %>
        << x >>
        <% endif %>
        """)

    colored_text = text_color.colored_block(
        text_color.colored(original_text, color='blue'),
        '<<', '>>',
        color='white',
        background='blue',
    )
    origin_color = '\x1b[34m'
    reset = '\x1b[0m'
    block_color = text_color.colored(
        'TEXT', color='white', background='blue'
    ).split('TEXT')[0]
    # Each `<< ... >>` block switches to the block color and then resumes
    # the surrounding (blue) color.
    expected = (
        f'{origin_color}Hi {block_color}<< foo >>{reset}{origin_color}\n'
        '<# print x if x is present #>\n<% if x %>\n'
        f'{block_color}<< x >>{reset}{origin_color}\n'
        f'<% endif %>{reset}'
    )
    # On some termcolor versions, the color codes are not applied.
    # `assertIn` reports the actual value on failure, unlike `assertTrue`.
    self.assertIn(colored_text, (expected, original_text))
    self.assertEqual(text_color.decolor(colored_text), original_text)

  def test_colored_block_without_full_match(self):
    # An unterminated block is left as-is.
    self.assertEqual(
        text_color.colored_block(
            'Hi {{ foo',
            '{{', '}}',
            color='white',
            background='blue',
        ),
        'Hi {{ foo'
    )

  def test_colored_block_without_termcolor(self):
    # Simulate termcolor being unavailable. The cleanup restores the module
    # attribute even if an assertion below fails.
    self.addCleanup(setattr, text_color, 'termcolor', text_color.termcolor)
    text_color.termcolor = None
    original_text = inspect.cleandoc("""
        Hi {{ foo }}
        {# print x if x is present #}
        {% if x %}
        {{ x }}
        {% endif %}
        """)

    colored_text = text_color.colored_block(
        text_color.colored(original_text, color='blue'),
        '{{', '}}',
        color='white',
        background='blue',
    )
    self.assertEqual(colored_text, original_text)
    self.assertEqual(text_color.decolor(colored_text), original_text)
94 |
95 |
if __name__ == '__main__':
  # Run this module's tests when executed directly.
  unittest.main()
98 |
--------------------------------------------------------------------------------
/pyglove/core/views/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """PyGlove views."""
15 |
16 | from pyglove.core.views import base
17 | from pyglove.core.views import html
18 |
# Generic view interfaces and helpers.
View = base.View
view_options = base.view_options
view = base.view

# Type alias exported for pytype annotations.
NodeFilter = base.NodeFilter

# HTML view classes.
Html = html.Html
HtmlConvertible = html.HtmlConvertible
HtmlView = html.HtmlView
HtmlTreeView = html.HtmlTreeView

# HTML conversion functions.
to_html_str = html.to_html_str
to_html = html.to_html
33 |
--------------------------------------------------------------------------------
/pyglove/core/views/html/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """HTML views for PyGlove objects."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=g-bad-import-order
18 |
19 | from pyglove.core.views.html.base import Html
20 | from pyglove.core.views.html.base import HtmlConvertible
21 | from pyglove.core.views.html.base import HtmlView
22 | from pyglove.core.views.html.base import to_html
23 | from pyglove.core.views.html.base import to_html_str
24 | from pyglove.core.views.html.tree_view import HtmlTreeView
25 |
26 | # pylint: enable=g-bad-import-order
27 | # pylint: enable=g-importing-member
28 |
--------------------------------------------------------------------------------
/pyglove/core/views/html/controls/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common HTML controls."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=g-bad-import-order
18 |
19 | from pyglove.core.views.html.controls.base import HtmlControl
20 |
21 |
22 | from pyglove.core.views.html.controls.label import Label
23 | from pyglove.core.views.html.controls.label import LabelGroup
24 | from pyglove.core.views.html.controls.label import Badge
25 |
26 | from pyglove.core.views.html.controls.tooltip import Tooltip
27 |
28 | from pyglove.core.views.html.controls.tab import Tab
29 | from pyglove.core.views.html.controls.tab import TabControl
30 |
31 | from pyglove.core.views.html.controls.progress_bar import ProgressBar
32 | from pyglove.core.views.html.controls.progress_bar import SubProgress
33 |
34 | # pylint: enable=g-bad-import-order
35 | # pylint: enable=g-importing-member
36 |
--------------------------------------------------------------------------------
/pyglove/core/views/html/controls/progress_bar_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | import unittest
16 |
17 | from pyglove.core import symbolic # pylint: disable=unused-import
18 | from pyglove.core.views.html.controls import progress_bar
19 |
20 |
class ProgressBarTest(unittest.TestCase):
  """Tests for `progress_bar.ProgressBar` and `progress_bar.SubProgress`."""

  def assert_html_content(self, control, expected):
    # Compares the rendered HTML content of `control` with `expected`, after
    # removing the common indentation of `expected` and outer whitespace.
    expected = inspect.cleandoc(expected).strip()
    actual = control.to_html().content.strip()
    if actual != expected:
      # Print the actual output to help update the expectation on failure.
      print(actual)
    self.assertEqual(actual, expected)

  def test_basic(self):
    # Two sub-progresses: 'foo' without a value and 'bar' at 20; no total.
    bar = progress_bar.ProgressBar(
        subprogresses=[
            progress_bar.SubProgress('foo'),
            progress_bar.SubProgress('bar', 20),
        ],
        total=None,
    )
    # NOTE(review): this expected-HTML literal appears to have lost its
    # markup (source lines are missing and a string literal is split across
    # lines); restore it from ProgressBar's actual rendered output.
    self.assert_html_content(
        bar,
        (
            ''
            '
 n/a'
            'Not started
'
        )
    )
    # Sub-progresses are looked up by name; unknown names raise KeyError.
    self.assertEqual(bar['foo'], progress_bar.SubProgress('foo'))
    self.assertEqual(bar['bar'], progress_bar.SubProgress('bar', 20))
    with self.assertRaisesRegex(KeyError, 'Sub progress bar .* not found'):
      _ = bar['baz']
    # Without a total, a sub-progress has neither a total nor a width.
    self.assertIsNone(bar['foo'].total)
    self.assertIsNone(bar['foo'].width)

    # Setting the total propagates it to sub-progresses and defines widths.
    bar.update(total=100)
    self.assertEqual(bar.total, 100)
    self.assertEqual(bar['foo'].total, 100)
    self.assertEqual(bar['foo'].width, '0%')
    self.assertEqual(bar['bar'].width, '20%')
    # Incrementing a sub-progress emits scripts that update its width, the
    # overall label text, and the label's tooltip text and CSS class.
    with bar.track_scripts() as scripts:
      bar['foo'].increment()
    self.assertEqual(
        scripts,
        [
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{bar['foo'].element_id()}");
                elem.style = "width:1%;";
                """
            ),
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{bar._progress_label.element_id()}");
                elem.textContent = " 21.0% (21/100)";
                """
            ),
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{bar._progress_label.tooltip.element_id()}");
                elem.textContent = "foo: 1.0% (1/100)\\nbar: 20.0% (20/100)";
                """
            ),
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{bar._progress_label.tooltip.element_id()}");
                elem.classList.remove("html-content");
                """
            ),
        ]
    )
94 |
95 |
if __name__ == '__main__':
  # Run this module's tests when executed directly.
  unittest.main()
98 |
--------------------------------------------------------------------------------
/pyglove/core/views/html/controls/tooltip.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Html tooltip."""
15 |
16 | from typing import Optional, Union
17 | from pyglove.core import typing as pg_typing
18 | from pyglove.core.symbolic import object as pg_object
19 | # pylint: disable=g-importing-member
20 | from pyglove.core.views.html.base import Html
21 | from pyglove.core.views.html.controls.base import HtmlControl
22 | # pylint: enable=g-importing-member
23 |
24 |
@pg_object.use_init_args(
    ['content', 'for_element', 'id', 'css_classes', 'styles']
)
class Tooltip(HtmlControl):
  """A tooltip control.

  The tooltip renders as a `span` with the `tooltip` CSS class, which is
  hidden by default. The generated hover rule uses the adjacent-sibling
  combinator (`+`), so the tooltip is shown only when its span immediately
  follows an element matching `for_element` and that element is hovered.

  Attributes:
    content: The content of the tooltip. It could be a string (passed through
      `Html.escape` when rendered) or an HTML object (rendered with the extra
      `html-content` CSS class).
    for_element: The CSS selector for the element to attach the tooltip to,
      e.g. '.my-element' or '#my-element'. Required for rendering.
    id: The id of the tooltip.
    css_classes: Additional CSS classes for the tooltip.
    styles: CSS styles passed to the tooltip's `span` element.
  """

  content: Union[str, Html]
  # CSS selector of the element that reveals the tooltip on hover.
  for_element: Optional[str] = None

  def _to_html(self, **kwargs):
    # Renders the tooltip span together with the CSS that toggles it.
    if self.for_element is None:
      raise ValueError(
          'CSS selector `for_element` is required for tooltip to display.'
      )
    content = self.content
    if isinstance(self.content, str):
      # Plain strings go through `Html.escape` before being embedded.
      content = Html.escape(self.content)
    return Html.element(
        'span',
        [content],
        id=self.element_id(),
        css_classes=[
            'tooltip',
            # HTML content is styled by the `.tooltip.html-content` rule.
            'html-content' if isinstance(content, Html) else None,
        ] + self.css_classes,
        styles=self.styles,
    ).add_style(
        """
        span.tooltip {
          visibility: hidden;
          white-space: pre-wrap;
          font-weight: normal;
          background-color: #484848;
          color: #fff;
          padding: 10px;
          border-radius: 6px;
          position: absolute;
          z-index: 1;
        }
        span.tooltip:hover {
          visibility: visible;
        }
        .tooltip.html-content {
          white-space: inherit;
          background-color: white;
          color: inherit;
          box-shadow: rgba(0, 0, 0, 0.16) 0px 1px 4px;
        }
        """,
        # Per-instance rule: reveal the tooltip while `for_element` is hovered.
        f"""
        {self.for_element}:hover + .tooltip {{
          visibility: visible;
        }}
        """
    )

  def update(self, content: Union[str, Html]) -> None:
    """Replaces the tooltip content and toggles the `html-content` class.

    Args:
      content: The new content. An `Html` object adds the `html-content` CSS
        class; any other value (a string) removes it.
    """
    self._sync_members(content=self._update_content(content))
    if isinstance(content, Html):
      self._add_css_class('html-content')
    else:
      self._remove_css_class('html-content')
95 |
96 |
# Register converters so that `str` and `Html` values can be converted into a
# `Tooltip` automatically.
for _source_type in (str, Html):
  pg_typing.register_converter(_source_type, Tooltip, Tooltip)
del _source_type
100 |
--------------------------------------------------------------------------------
/pyglove/core/views/html/controls/tooltip_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | import unittest
16 |
17 | from pyglove.core import symbolic # pylint: disable=unused-import
18 | from pyglove.core.views.html import base
19 | from pyglove.core.views.html.controls import tooltip as tooltip_lib
20 |
21 |
class TooltipTest(unittest.TestCase):
  """Tests for `tooltip_lib.Tooltip`."""

  def assert_html_content(self, control, expected):
    # Compares the rendered HTML content of `control` with `expected`, after
    # removing the common indentation of `expected` and outer whitespace.
    expected = inspect.cleandoc(expected).strip()
    actual = control.to_html().content.strip()
    if actual != expected:
      # Print the actual output to help update the expectation on failure.
      print(actual)
    self.assertEqual(actual, expected)

  def test_basic(self):
    # Rendering without `for_element` is rejected.
    tooltip = tooltip_lib.Tooltip('foo')
    with self.assertRaisesRegex(
        ValueError, 'CSS selector `for_element` is required'
    ):
      tooltip.to_html()

    tooltip = tooltip_lib.Tooltip('foo', for_element='.bar')
    self.assertEqual(tooltip.for_element, '.bar')
    # NOTE(review): this expectation looks like it lost its HTML markup (a
    # bare 'foo' rather than the rendered tooltip span); confirm.
    self.assert_html_content(
        tooltip,
        'foo'
    )
    # The hover rule is generated from the `for_element` selector.
    self.assertIn(
        inspect.cleandoc(
            """
            .bar:hover + .tooltip {
              visibility: visible;
            }
            """
        ),
        tooltip.to_html().style_section,
    )

  def test_update(self):
    tooltip = tooltip_lib.Tooltip('foo', for_element='.bar', interactive=True)
    self.assertIn('id="control-', tooltip.to_html_str(content_only=True))
    # Updating with plain text sets `textContent` and drops `html-content`.
    with tooltip.track_scripts() as scripts:
      tooltip.update('normal text')
    self.assertEqual(tooltip.content, 'normal text')
    self.assertEqual(
        scripts,
        [
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{tooltip.element_id()}");
                elem.textContent = "normal text";
                """
            ),
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{tooltip.element_id()}");
                elem.classList.remove("html-content");
                """
            ),
        ]
    )
    # Updating with Html sets `innerHTML` and adds `html-content`.
    # NOTE(review): the Html literal and the expected `innerHTML` below look
    # like they lost their markup (e.g. bold tags); confirm.
    with tooltip.track_scripts() as scripts:
      tooltip.update(base.Html('bold text'))
    self.assertEqual(
        scripts,
        [
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{tooltip.element_id()}");
                elem.innerHTML = "bold text";
                """
            ),
            inspect.cleandoc(
                f"""
                elem = document.getElementById("{tooltip.element_id()}");
                elem.classList.add("html-content");
                """
            ),
        ]
    )
97 |
if __name__ == '__main__':
  # Run this module's tests when executed directly.
  unittest.main()
100 |
--------------------------------------------------------------------------------
/pyglove/ext/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Package pyglove.ext."""
15 |
16 | from pyglove.ext import early_stopping
17 | from pyglove.ext import evolution
18 | from pyglove.ext import mutfun
19 | from pyglove.ext import scalars
20 |
--------------------------------------------------------------------------------
/pyglove/ext/early_stopping/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """PyGlove's built-in early stopping policies."""
15 |
16 | # pylint: disable=g-bad-import-order
17 |
18 | from pyglove.ext.early_stopping.base import EarlyStopingPolicyBase
19 | from pyglove.ext.early_stopping.base import And
20 | from pyglove.ext.early_stopping.base import Or
21 | from pyglove.ext.early_stopping.base import Not
22 |
23 | from pyglove.ext.early_stopping.step_wise import early_stop_by_rank
24 | from pyglove.ext.early_stopping.step_wise import early_stop_by_value
25 | from pyglove.ext.early_stopping.step_wise import StepWise
26 |
27 | # pylint: enable=g-bad-import-order
28 |
--------------------------------------------------------------------------------
/pyglove/ext/early_stopping/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base early stopping policy that is friendly to composition."""
15 |
16 | from typing import Iterable
17 | import pyglove.core as pg
18 |
19 |
class EarlyStopingPolicyBase(pg.tuning.EarlyStoppingPolicy):
  """Early stopping policy base composable with `&`, `|`, `-` and `~`."""

  def __and__(self, other) -> pg.tuning.EarlyStoppingPolicy:
    """Operator &: returns `And(self, other)`."""
    conjunction = And(self, other)
    return conjunction

  def __or__(self, other) -> pg.tuning.EarlyStoppingPolicy:
    """Operator |: returns `Or(self, other)`."""
    disjunction = Or(self, other)
    return disjunction

  def __neg__(self) -> pg.tuning.EarlyStoppingPolicy:
    """Operator -: returns `Not(self)`."""
    negation = Not(self)
    return negation

  def __invert__(self) -> pg.tuning.EarlyStoppingPolicy:
    """Operator ~: returns `Not(self)`, same as unary minus."""
    negation = Not(self)
    return negation
37 |
38 |
@pg.members([
    ('children',
     pg.typing.List(pg.typing.Object(pg.tuning.EarlyStoppingPolicy))),
], init_arg_list=['*children'])
class Composite(EarlyStopingPolicyBase):
  """Base class for composite early stopping policies.

  A composite policy combines the decisions of its `children` policies.
  """

  def recover(self, history: Iterable[pg.tuning.Trial]):
    """Recovers every child policy from the same trial history.

    Args:
      history: Previously seen trials. This may be a one-shot iterable such as
        a generator. It is materialized once so that every child sees the full
        history, rather than the first child exhausting it.
    """
    history = list(history)
    for child in self.children:
      child.recover(history)
49 |
50 |
class And(Composite):
  """Logical AND: stops early only if every child policy says to stop."""

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    # `all` stops at the first child that votes to continue, exactly like an
    # explicit loop with an early `return False`. It yields True when empty.
    return all(child.should_stop_early(trial) for child in self.children)
59 |
60 |
class Or(Composite):
  """Logical OR: stops early if any child policy says to stop."""

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    # `any` stops at the first child that votes to stop, exactly like an
    # explicit loop with an early `return True`. It yields False when empty.
    return any(child.should_stop_early(trial) for child in self.children)
69 |
70 |
class Not(Composite):
  """Logical NOT as a composite early stopping policy.

  Wraps exactly one child policy and inverts its decision.
  """

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    """Returns the negation of the single child's decision.

    Raises:
      ValueError: If the policy does not have exactly one child.
    """
    # An explicit check, unlike `assert`, is not stripped under `python -O`.
    if len(self.children) != 1:
      raise ValueError(
          '`Not` expects exactly one child policy, '
          f'got {len(self.children)}: {self.children!r}'
      )
    return not self.children[0].should_stop_early(trial)
77 |
--------------------------------------------------------------------------------
/pyglove/ext/early_stopping/base_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Iterable
16 | import unittest
17 | import pyglove.core as pg
18 | from pyglove.ext.early_stopping import base
19 |
20 |
@pg.members([
    ('decision', pg.typing.Bool())
])
class ConstantPolicy(base.EarlyStopingPolicyBase):
  """Test policy that always returns `decision` and records its calls."""

  def _on_bound(self):
    super()._on_bound()
    # Bookkeeping inspected by the tests.
    self.recovered = False
    self.requested_trials = []

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    self.requested_trials.append(trial)
    decision = self.decision
    return decision

  def recover(self, history: Iterable[pg.tuning.Trial]) -> None:
    del history  # Only the fact that recovery happened is recorded.
    self.recovered = True
37 |
38 |
class EarlyStoppingPolicyComposabilityTest(unittest.TestCase):
  """Test the composability of early stopping policies."""

  def test_logical_and(self):
    t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0)
    x = ConstantPolicy(True)
    self.assertTrue(x.should_stop_early(t))

    y = ConstantPolicy(False)
    self.assertFalse(y.should_stop_early(t))

    # `&` stops only when both operands want to stop.
    self.assertTrue((x & x).should_stop_early(t))
    self.assertFalse((x & y).should_stop_early(t))
    self.assertFalse((y & x).should_stop_early(t))
    self.assertFalse((y & y).should_stop_early(t))

  def test_logical_or(self):
    t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0)
    # `|` stops when either operand wants to stop.
    self.assertTrue(
        (ConstantPolicy(True) | ConstantPolicy(True)).should_stop_early(t))
    self.assertTrue(
        (ConstantPolicy(True) | ConstantPolicy(False)).should_stop_early(t))
    self.assertTrue(
        (ConstantPolicy(False) | ConstantPolicy(True)).should_stop_early(t))
    self.assertFalse(
        (ConstantPolicy(False) | ConstantPolicy(False)).should_stop_early(t))

  def test_logical_not(self):
    t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0)
    # Both `~` and unary `-` invert the wrapped policy's decision.
    self.assertFalse((~ConstantPolicy(True)).should_stop_early(t))  # pylint: disable=invalid-unary-operand-type
    self.assertTrue((-ConstantPolicy(False)).should_stop_early(t))  # pylint: disable=invalid-unary-operand-type

  def test_recover(self):
    t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0)
    x = ConstantPolicy(True)
    y = ConstantPolicy(False)
    z = x & y
    # Recovering a composite forwards the history to every child.
    z.recover([t])
    self.assertTrue(x.recovered)
    self.assertTrue(y.recovered)
79 |
80 |
if __name__ == '__main__':
  # Run this module's tests when executed directly.
  unittest.main()
83 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """PyGlove's genetic computing framework."""
15 |
16 | from pyglove.ext import scalars
17 | from pyglove.ext.evolution import base
18 | from pyglove.ext.evolution import hill_climb as hill_climb_lib
19 | from pyglove.ext.evolution import mutators
20 | from pyglove.ext.evolution import neat as neat_lib
21 | from pyglove.ext.evolution import nsga2 as nsga2_lib
22 | from pyglove.ext.evolution import recombinators
23 | from pyglove.ext.evolution import regularized_evolution as regularized_evolution_lib
24 | from pyglove.ext.evolution import selectors
25 | from pyglove.ext.evolution.hill_climb import hill_climb
26 | from pyglove.ext.evolution.neat import neat
27 | from pyglove.ext.evolution.nsga2 import nsga2
28 | from pyglove.ext.evolution.regularized_evolution import regularized_evolution
29 |
# Backward-compatibility alias; remove once dependents have been migrated.
RegularizedEvolution = regularized_evolution

# Core interfaces.
Operation = base.Operation
DNAOperation = base.DNAOperation
Selector = base.Selector
Recombinator = base.Recombinator
Mutator = base.Mutator
Scalar = scalars.Scalar

# The compositional evolution algorithm.
Evolution = base.Evolution

# Compositional operations that have no operator shorthand.
Lambda = base.Lambda
Identity = base.Identity

# Compositional operations created through Python operators.
Pipeline = base.Pipeline                        # operator >>
Power = base.Power                              # operator **
Concatenation = base.Concatenation              # operator +
Slice = base.Slice                              # operator []
Repeat = base.Repeat                            # operator *
Union = base.Union                              # operator |
Intersection = base.Intersection                # operator &
Difference = base.Difference                    # operator -
SymmetricDifference = base.SymmetricDifference  # operator ^
Inversion = base.Inversion                      # operator ~

# Compositional operations created through methods.
Choice = base.Choice                        # .with_prob
Conditional = base.Conditional              # .if_true/.if_false
ElementWise = base.ElementWise              # .for_each
Flatten = base.Flatten                      # .flatten
UntilChange = base.UntilChange              # .until_change
GlobalStateGetter = base.GlobalStateGetter  # .global_state
GlobalStateSetter = base.GlobalStateSetter  # .as_global_state

# Scalar helpers.
scalar_spec = scalars.scalar_spec
scalar_value = scalars.scalar_value

# Fitness helpers.
set_fitness = base.set_fitness
get_fitness = base.get_fitness

# Other helper functions.
get_generation_id = base.get_generation_id
get_feedback_sequence_number = base.get_feedback_sequence_number
is_initial_population = base.is_initial_population
83 |
84 |
85 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/hill_climb.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Generic Hill-Climbing Algorithm."""
15 |
16 | from typing import Optional
17 |
18 | import pyglove.core as pg
19 | from pyglove.ext.evolution import base
20 | from pyglove.ext.evolution import mutators
21 | from pyglove.ext.evolution import selectors
22 |
23 |
def hill_climb(mutator: Optional[base.Mutator] = None,
               batch_size: int = 1,
               init_population_size: int = 1,
               seed: Optional[int] = None) -> base.Evolution:
  """Hill-Climbing algorithm, with an extra batched setting.

  Batched setting was shown to be effective in
  https://arxiv.org/pdf/1911.06317.pdf and https://arxiv.org/pdf/2003.01239.pdf,
  especially in noisy objective settings.

  Each step selects the current best (top-1) individual and produces
  `batch_size` mutations of it. The population is then reduced back to its
  top-1 individual.

  Args:
    mutator: Mutator to use. If None, a new `mutators.Uniform()` is created
      for this call.
    batch_size: Number of mutations of the current best.
    init_population_size: Initial population size (randomly generated).
    seed: Random seed for generating the initial population.

  Returns:
    An `Evolution` object.
  """
  if mutator is None:
    # Build the default per call. A default argument value would be created
    # once at import time and shared by every returned `Evolution`.
    mutator = mutators.Uniform()
  return base.Evolution(
      selectors.Top(1) >> (mutator * batch_size),  # pytype: disable=unsupported-operands
      population_init=(pg.geno.Random(seed), init_population_size),
      population_update=selectors.Top(1))
47 |
48 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/hill_climb_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for hill climbing algorithm."""
15 |
16 | import random
17 | import time
18 | import unittest
19 |
20 | import pyglove.core as pg
21 | from pyglove.ext.evolution import base
22 | from pyglove.ext.evolution import hill_climb_lib as hill_climb
23 |
24 |
def get_trivial_search_space():
  """Builds the one-dimensional search space used by these tests.

  The space is a single float hyper-value whose points lie in [0, 1].

  Returns:
    A tunable value.
  """
  lower_bound, upper_bound = 0.0, 1.0
  return pg.float_value(lower_bound, upper_bound)
34 |
35 |
class TrivialMutator(base.Mutator):
  """Mutator for trivial search space.

  Mutations can only change the value by a small amount: a uniform offset in
  [-0.01, 0.01] is added and the result is clamped to [0, 1].
  """

  def mutate(self, dna, step):
    """Returns a new DNA whose value is a small perturbation of `dna.value`.

    The incoming `dna` is not modified, so a parent that the caller may still
    reference is not changed as a side effect of mutation.
    """
    del step
    value = dna.value + random.uniform(-0.01, 0.01)
    # Clamp to the bounds of the trivial search space.
    value = min(max(value, 0.0), 1.0)
    return pg.DNA(value)
50 |
51 |
def trivial_reward(example):
  """Scores a materialized point of the trivial search space.

  Each point is its own fitness, so maximizing the reward drives the search
  towards the value 1.

  Args:
    example: a materialized value.

  Returns:
    The corresponding reward, i.e. `example` itself.
  """
  reward = example
  return reward
def get_trivial_hash(search_space, algo):
  """Folds the first 30 proposals of `algo` into an integer fingerprint.

  Each proposed value is scaled by 1e6, truncated to an int and XOR-ed into
  the running result.

  NOTE(review): no feedback is reported for the proposals here, so `algo`
  receives no rewards while hashing (the regularized-evolution test's variant
  does report feedback). Confirm this difference is intended.

  Args:
    search_space: The search space to sample from.
    algo: The search algorithm proposing points.

  Returns:
    The XOR of the scaled, truncated proposals.
  """
  fingerprint = 0
  for proposal, unused_feedback in pg.iter(search_space, 30, algo):
    fingerprint ^= int(proposal * 1000000)
  return fingerprint
class HillClimbingTest(unittest.TestCase):
  """Integration and reproducibility tests for `hill_climb`."""

  def test_integration(self):
    # Set up search space.
    search_space = get_trivial_search_space()

    # Search algorithm.
    algo = hill_climb.hill_climb(mutator=TrivialMutator(), batch_size=1)

    # Search in batches of 500 proposals until the optimum (1.0) is reached,
    # failing once the time budget is exceeded. A flag is used instead of
    # reading the loop variable after the loop, which would be unbound if a
    # batch yielded nothing.
    solved = False
    start_time = time.time()
    while not solved:
      for example, feedback in pg.iter(search_space, 500, algo):
        reward = trivial_reward(example)
        feedback(reward)
        if reward >= 1.0:
          solved = True
          break
      if not solved and time.time() - start_time > 600.0:
        self.fail('Took too long to find a solution.')

  def test_permanence(self):
    search_space = get_trivial_search_space()
    algo = hill_climb.hill_climb(mutator=TrivialMutator(), batch_size=1, seed=0)

    # If a CL causes the following assert to fail, it means that the CL is
    # causing a difference in the behavior of the evolutionary algorithms. If
    # this is expected (e.g. a change in the random number generator), then
    # simply update the hash to the new value.
    self.assertEqual(get_trivial_hash(search_space, algo), 789108)


if __name__ == '__main__':
  unittest.main()
114 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/regularized_evolution.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Regularized Evolution: https://arxiv.org/abs/1802.01548."""
15 |
16 | from typing import Optional
17 |
18 | import pyglove.core as pg
19 | from pyglove.ext.evolution import base
20 | from pyglove.ext.evolution import mutators
21 | from pyglove.ext.evolution import selectors
22 |
23 |
def regularized_evolution(
    mutator=None,
    population_size: int = 100,
    tournament_size: int = 10,
    seed: Optional[int] = None) -> base.Evolution:
  """Regularized Evolution algorithm.

  https://www.aaai.org/ojs/index.php/AAAI/article/view/4405.

  Each step samples `tournament_size` individuals (`selectors.Random`), keeps
  the best of them (`selectors.Top(1)`) and mutates it. The population update
  uses `selectors.Last(population_size)`.

  Args:
    mutator: Mutator to use. If None, a fresh `mutators.Uniform()` is created
      for this call, so separately built algorithms never share one mutator
      instance.
    population_size: Population size. Must be no less than `tournament_size`.
    tournament_size: Tournament size. Must be no less than 2.
    seed: Random seed, passed to both the tournament selector and the initial
      population generator. If None, the current system time is used as seed.

  Returns:
    An `Evolution` object.

  Raises:
    ValueError: If `tournament_size` is less than 2, or if `population_size`
      is less than `tournament_size`.
  """
  if tournament_size < 2:
    raise ValueError(
        f'`tournament_size` must be no less than 2. '
        f'Encountered: {tournament_size}')
  if population_size < tournament_size:
    raise ValueError(
        f'The value of `population_size` ({population_size}) must be no '
        f'less than the value of `tournament_size` ({tournament_size}).')
  if mutator is None:
    # Created per call: a default-argument instance would be built once at
    # import time and shared by every algorithm created without a mutator.
    mutator = mutators.Uniform()
  return base.Evolution(
      selectors.Random(
          tournament_size, seed=seed) >> selectors.Top(1) >> mutator,
      population_init=(pg.geno.Random(seed=seed), population_size),
      population_update=selectors.Last(population_size))
55 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/regularized_evolution_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for evolutionary algorithms."""
15 |
16 | import random
17 | import time
18 | import unittest
19 |
20 | import pyglove.core as pg
21 | from pyglove.ext.evolution import base
22 | from pyglove.ext.evolution import regularized_evolution_lib as regularized_evolution
23 |
24 |
def get_trivial_search_space():
  """Builds the one-dimensional search space used by these tests.

  The space is a single float hyper-value whose points lie in [0, 1].

  Returns:
    A tunable value.
  """
  lower_bound, upper_bound = 0.0, 1.0
  return pg.floatv(lower_bound, upper_bound)
34 |
35 |
@pg.members([
    ('seed', pg.typing.Int().noneable()),
])
class TrivialMutator(base.Mutator):
  """Mutator for trivial search space.

  Mutations can only change the value by a small amount: a uniform offset in
  [-0.01, 0.01] is added and the result is clamped to [0, 1], the bounds of
  the trivial search space.
  """

  def _on_bound(self):
    super()._on_bound()
    # A private seeded generator makes mutations reproducible; without a seed
    # the global `random` module is used.
    self._random = random if self.seed is None else random.Random(self.seed)

  def mutate(self, dna, step):
    """Returns a new DNA with a slightly perturbed value; `dna` is unchanged."""
    del step
    value = dna.value + self._random.uniform(-0.01, 0.01)
    # Clamp to [0, 1] so the mutated value stays inside the search space.
    if value < 0.0:
      value = 0.0
    if value > 1.0:
      value = 1.0
    return pg.DNA(value)
57 |
58 |
def trivial_reward(example):
  """Scores a materialized point of the trivial search space.

  Each point is its own fitness, so maximizing the reward drives the search
  towards the value 1.

  Args:
    example: a materialized value.

  Returns:
    The corresponding reward, i.e. `example` itself.
  """
  reward = example
  return reward
72 |
73 |
def get_trivial_hash(search_space, algo):
  """Folds the first 30 proposals of `algo` into an integer fingerprint.

  Each proposed value is scaled by 1e6, truncated to an int and XOR-ed into
  the running result. Each proposal is then reported back with its own value
  as the reward.

  Args:
    search_space: The search space to sample from.
    algo: The search algorithm proposing points.

  Returns:
    The XOR of the scaled, truncated proposals.
  """
  fingerprint = 0
  for proposal, report in pg.iter(search_space, 30, algo):
    fingerprint ^= int(proposal * 1000000)
    report(proposal)
  return fingerprint
80 |
81 |
class RegularizedEvolutionTest(unittest.TestCase):
  """Integration and reproducibility tests for `regularized_evolution`."""

  def test_integration(self):
    # Set up search space.
    search_space = get_trivial_search_space()

    # Search algorithm.
    algo = regularized_evolution.regularized_evolution(
        population_size=10, tournament_size=2, mutator=TrivialMutator())

    # Search in batches of 100 proposals until the optimum (1.0) is reached,
    # failing once the time budget is exceeded. A flag is used instead of
    # reading the loop variable after the loop, which would be unbound if a
    # batch yielded nothing.
    solved = False
    start_time = time.time()
    while not solved:
      for example, feedback in pg.iter(search_space, 100, algo):
        reward = trivial_reward(example)
        feedback(reward)
        if reward >= 1.0:
          solved = True
          break
      if not solved and time.time() - start_time > 300.0:
        self.fail('Took too long to find a solution.')

  def test_permanence(self):
    search_space = get_trivial_search_space()
    algo = regularized_evolution.regularized_evolution(
        population_size=10, tournament_size=2, mutator=TrivialMutator(seed=1),
        seed=1)

    # If a CL causes the following assert to fail, it means that the CL is
    # causing a difference in the behavior of the evolutionary algorithms. If
    # this is expected (e.g. a change in the random number generator), then
    # simply update the hash to the new value.
    self.assertEqual(get_trivial_hash(search_space, algo), 385892)


if __name__ == '__main__':
  unittest.main()
125 |
--------------------------------------------------------------------------------
/pyglove/ext/evolution/where_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test for evolution decision point filters."""
15 |
16 | import unittest
17 | import pyglove.core as pg
18 | from pyglove.ext.evolution import where
19 |
20 |
class DecisionPointFiltersTest(unittest.TestCase):
  """Tests for DecisionPointFilter subclasses."""

  def _float_specs(self):
    """Returns five float specs with max values 1.0, 2.0, ..., 5.0."""
    return [pg.geno.Float(0., float(upper)) for upper in range(1, 6)]

  def test_lambda(self):
    keep_first = where.Lambda(lambda specs: specs[:1])
    self.assertEqual(len(keep_first(self._float_specs())), 1)

  def test_all(self):
    specs = self._float_specs()
    self.assertIs(specs, where.ALL(specs))

  def test_any(self):
    specs = self._float_specs()

    single = where.Any(seed=1)(specs)
    self.assertEqual(len(single), 1)
    self.assertEqual(single[0].max_value, 2.0)

    subset = where.Any(4, seed=2)(specs)
    self.assertEqual(len(subset), 4)
    # Check output is sorted.
    self.assertEqual([spec.max_value for spec in subset], [1., 2., 4., 5.])

    # Requesting more items than available hands back the input list itself.
    self.assertIs(specs, where.Any(10, seed=1)(specs))

  def test_automatic_conversion(self):

    @pg.members([
        ('where', where.where_spec())
    ])
    class MyFilter(pg.Object):
      pass

    holder = MyFilter(lambda specs: specs[:2])
    self.assertIsInstance(holder.where, where.Lambda)


if __name__ == '__main__':
  unittest.main()
66 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Step-based scalars used as hyper-parameters for PyGlove algorithms."""
15 |
16 | # pylint: disable=g-bad-import-order
17 |
18 | # Helper methods for creating and retrieving scalars.
19 | from pyglove.ext.scalars.base import scalar_spec
20 | from pyglove.ext.scalars.base import scalar_value
21 | from pyglove.ext.scalars.base import make_scalar
22 |
23 | # Interface and common scalars.
24 | from pyglove.ext.scalars.base import Scalar
25 |
26 | from pyglove.ext.scalars.base import Lambda
27 | from pyglove.ext.scalars.base import Constant
28 | from pyglove.ext.scalars.base import STEP
29 |
30 | from pyglove.ext.scalars.base import UnaryOp
31 | from pyglove.ext.scalars.base import Negation
32 | from pyglove.ext.scalars.base import Absolute
33 | from pyglove.ext.scalars.base import Floor
34 | from pyglove.ext.scalars.base import Ceiling
35 |
36 | from pyglove.ext.scalars.base import BinaryOp
37 | from pyglove.ext.scalars.base import Addition
38 | from pyglove.ext.scalars.base import Substraction
39 | from pyglove.ext.scalars.base import Multiplication
40 | from pyglove.ext.scalars.base import Division
41 | from pyglove.ext.scalars.base import Mod
42 | from pyglove.ext.scalars.base import Power
43 |
44 | # Common math functions.
45 | from pyglove.ext.scalars.maths import linear
46 | from pyglove.ext.scalars.maths import cosine_decay
47 | from pyglove.ext.scalars.maths import exponential_decay
48 | from pyglove.ext.scalars.maths import cyclic
49 | from pyglove.ext.scalars.maths import sqrt
50 | from pyglove.ext.scalars.maths import exp
51 | from pyglove.ext.scalars.maths import log
52 | from pyglove.ext.scalars.maths import cos
53 | from pyglove.ext.scalars.maths import sin
54 |
55 | from pyglove.ext.scalars.maths import SquareRoot
56 | from pyglove.ext.scalars.maths import Exp
57 | from pyglove.ext.scalars.maths import Log
58 | from pyglove.ext.scalars.maths import Cosine
59 | from pyglove.ext.scalars.maths import Sine
60 |
61 | # Common random scalars.
62 | from pyglove.ext.scalars.randoms import RandomScalar
63 | from pyglove.ext.scalars.randoms import Uniform
64 | from pyglove.ext.scalars.randoms import Triangular
65 | from pyglove.ext.scalars.randoms import Gaussian
66 | from pyglove.ext.scalars.randoms import Normal
67 | from pyglove.ext.scalars.randoms import LogNormal
68 |
69 | # Step-wise scalar.
70 | from pyglove.ext.scalars.step_wise import StepWise
71 |
72 | # pylint: enable=g-bad-import-order
73 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/base_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for base scalars."""
15 |
16 | import unittest
17 | from pyglove.ext.scalars import base as scalars
18 |
19 |
class BasicScalarTest(unittest.TestCase):
  """Test basic scalars."""

  def test_make_scalar(self):
    # An existing scalar is accepted as-is.
    from_scalar = scalars.make_scalar(scalars.Constant(1))
    self.assertIsInstance(from_scalar, scalars.Scalar)
    self.assertEqual([from_scalar(0), from_scalar(10)], [1, 1])

    # A plain int becomes a constant scalar that keeps the int type.
    from_int = scalars.make_scalar(1)
    self.assertIsInstance(from_int, scalars.Scalar)
    self.assertIsInstance(from_int(0), int)
    self.assertEqual([from_int(0), from_int(10)], [1, 1])

    # A callable is invoked with the step.
    from_callable = scalars.make_scalar(lambda step: step)
    self.assertIsInstance(from_callable, scalars.Scalar)
    self.assertEqual([from_callable(1), from_callable(10)], [1, 10])

  def test_step(self):
    doubled = scalars.STEP * 2
    self.assertEqual([doubled(0), doubled(10)], [0, 20])
44 |
45 |
class UnaryOpTest(unittest.TestCase):
  """Tests for unary scalar operators."""

  def test_negation(self):
    negated = -scalars.STEP
    self.assertEqual([negated(1), negated(2)], [-1, -2])

  def test_floor(self):
    floored = scalars.Constant(1.6).floor()
    self.assertEqual(floored(0), 1)

  def test_ceil(self):
    ceiled = scalars.Constant(1.6).ceil()
    self.assertEqual(ceiled(0), 2)

  def test_abs(self):
    magnitude = abs(scalars.Constant(-1))
    self.assertEqual(magnitude(0), 1)
65 |
66 |
class BinaryOpTest(unittest.TestCase):
  """Tests for binary scalar operators."""

  def _check_at_step_zero(self, cases):
    """Asserts each `(scalar, expected)` pair evaluates to `expected` at 0."""
    for scalar, expected in cases:
      self.assertEqual(scalar(0), expected)

  def test_add(self):
    self._check_at_step_zero([
        (scalars.Constant(1) + 2, 3),
        (2 + scalars.Constant(1), 3),
    ])

  def test_substract(self):
    self._check_at_step_zero([
        (scalars.Constant(1) - 2, -1),
        (2 - scalars.Constant(1), 1),
    ])

  def test_multiply(self):
    self._check_at_step_zero([
        (scalars.Constant(1) * 2, 2),
        (2 * scalars.Constant(1), 2),
    ])

  def test_divide(self):
    self._check_at_step_zero([
        (scalars.Constant(1) / 2, 0.5),
        (2 / scalars.Constant(1), 2),
    ])

  def test_floor_divide(self):
    self._check_at_step_zero([
        (scalars.Constant(1) // 2, 0),
        (2 // scalars.Constant(1), 2),
    ])

  def test_mod(self):
    self._check_at_step_zero([
        (scalars.Constant(2) % 3, 2),
        (3 % scalars.Constant(2), 1),
    ])

  def test_power(self):
    self._check_at_step_zero([
        (scalars.Constant(2) ** 3, 8),
        (3 ** scalars.Constant(2), 9),
    ])


if __name__ == '__main__':
  unittest.main()
122 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/maths.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Step-based scalars used as evolution hyper-parameter values."""
15 |
16 | import math
17 | import pyglove.core as pg
18 | from pyglove.ext.scalars import base
19 |
20 |
class SquareRoot(base.UnaryOp):
  """Scalar whose value is the square root of its operand."""

  def operate(self, x: float) -> float:
    """Returns `math.sqrt(x)`; `math.sqrt` raises ValueError for negative x."""
    root = math.sqrt(x)
    return root

sqrt = SquareRoot  # pylint: disable=invalid-name
28 |
29 |
@pg.members([])
class Exp(base.UnaryOp):
  """Scalar computing e ** x via `math.exp`, more accurate than math.e ** x."""

  def operate(self, x: float) -> float:
    power = math.exp(x)
    return power

exp = Exp  # pylint: disable=invalid-name
38 |
39 |
@pg.members([
    ('x', base.scalar_spec(pg.typing.Union([
        pg.typing.Int(min_value=1),
        pg.typing.Float(min_value=0.0)])),
     'Value whose logarithm is taken.'),
    ('base', base.scalar_spec(pg.typing.Union([
        pg.typing.Int(min_value=2),
        pg.typing.Float(min_value=0.0)])).set_default(math.e),
     'Base of the log function.'),
])
class Log(base.Scalar):
  """A log scheduled float: `math.log(x(step), base(step))`.

  `x` and `base` may each be a number or a scalar. Both are converted with
  `make_scalar` and evaluated at the same step. An int `x` of 1 is accepted,
  since log(1) == 0 is well defined.

  NOTE: the float specs still admit `x == 0.0` and `base` in {0.0, 1.0}.
  For these, `math.log` raises when the scalar is evaluated.
  """

  def _on_bound(self):
    super()._on_bound()
    self._x = base.make_scalar(self.x)
    self._base = base.make_scalar(self.base)

  def call(self, step: int) -> float:
    return math.log(self._x(step), self._base(step))

log = Log  # pylint: disable=invalid-name
61 |
62 |
@pg.members([])
class Cosine(base.UnaryOp):
  """Scalar whose value is the cosine of its operand (in radians)."""

  def operate(self, x: float) -> float:
    result = math.cos(x)
    return result

cos = Cosine  # pylint: disable=invalid-name
71 |
72 |
@pg.members([])
class Sine(base.UnaryOp):
  """Scalar whose value is the sine of its operand (in radians)."""

  def operate(self, x: float) -> float:
    result = math.sin(x)
    return result

sin = Sine  # pylint: disable=invalid-name
81 |
82 |
83 | #
# Helper functions for creating popular scalar schedules.
85 | #
86 |
87 |
def linear(total_steps: int, start: float = 1.0, end: float = 0.0):
  """Returns a scalar that moves linearly from `start` towards `end`.

  The value is `start` at step 0 and `end` at step `total_steps`. It is not
  clamped, so it keeps changing at the same rate after `total_steps`.
  """
  slope = (end - start) / total_steps
  return start + base.STEP * slope
91 |
92 |
def cosine_decay(total_steps: int, start: float = 1.0, end: float = 0.0):
  """Returns a scalar following half a cosine wave from `start` to `end`.

  The value is `start` at step 0 and `end` at step `total_steps`. It is not
  clamped, so the cosine keeps oscillating after `total_steps`.
  """
  angle = math.pi * base.STEP / total_steps
  amplitude = 0.5 * (start - end)
  return amplitude * (1 + cos(angle)) + end
97 |
98 |
def exponential_decay(
    decay_rate: float, decay_interval: int,
    start: float = 1.0, staircase: bool = True):
  """Returns the scalar `start * decay_rate ** (step / decay_interval)`.

  Args:
    decay_rate: Factor the value is multiplied by per `decay_interval` steps.
    decay_interval: Number of steps in one decay period.
    start: Value at step 0.
    staircase: If True, the exponent is floored. The value then changes in
      discrete jumps every `decay_interval` steps instead of continuously.

  Returns:
    A scalar of the decayed value.
  """
  intervals_elapsed = base.STEP / float(decay_interval)
  exponent = intervals_elapsed.floor() if staircase else intervals_elapsed
  return start * (decay_rate ** exponent)
107 |
108 |
def cyclic(cycle: int, initial_radiant: float = 0.0,
           high: float = 1.0, low: float = 0.0):
  """Returns a scalar oscillating between `low` and `high` along a cosine.

  Args:
    cycle: Number of steps in one full period.
    initial_radiant: Phase offset, in radians, added to the cosine argument.
    high: Value reached where the cosine equals 1.
    low: Value reached where the cosine equals -1.

  Returns:
    A cyclic scalar.
  """
  phase = initial_radiant + math.pi * 2 * base.STEP / cycle
  half_range = 0.5 * (high - low)
  return half_range * (1 + cos(phase)) + low
114 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/randoms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for random scalars."""
15 |
16 | import unittest
17 | from pyglove.ext.scalars import randoms as scalars
18 |
19 |
class RandomScalarsTest(unittest.TestCase):
  """Random scalars tests."""

  def _assert_draws(self, scalar, expected_draws):
    """Asserts that successive evaluations at step 0 yield `expected_draws`."""
    for expected in expected_draws:
      self.assertEqual(scalar(0), expected)

  def test_uniform(self):
    self._assert_draws(
        scalars.Uniform(seed=1), [0.13436424411240122, 0.8474337369372327])
    self._assert_draws(scalars.Uniform(1, 10, seed=1), [3, 10])

    with self.assertRaisesRegex(
        ValueError,
        '`low` must be less or equal than `high`.'):
      scalars.Uniform(10, 1, seed=1)

  def test_triangular(self):
    self._assert_draws(
        scalars.Triangular(0.0, 1.0, 0.9, seed=1),
        [0.34774677525630787, 0.8733214547023962])
    self._assert_draws(scalars.Triangular(10, 20, seed=1), [12, 17])

  def test_gaussian(self):
    self._assert_draws(
        scalars.Gaussian(1.0, 0.2, seed=1),
        [1.2576369506310927, 1.2898891217399542])

  def test_normal(self):
    self._assert_draws(
        scalars.Normal(1.0, 0.2, seed=1),
        [1.1214911715287412, 0.997154910897843])

  def test_log_normal(self):
    self._assert_draws(
        scalars.LogNormal(1.0, 0.2, seed=1),
        [3.0694278358084994, 2.710559065635824])


if __name__ == '__main__':
  unittest.main()
64 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/step_wise.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Step-based scalars used as evolution hyper-parameter values."""
15 |
16 |
17 | from typing import Any
18 | import pyglove.core as pg
19 | from pyglove.ext.scalars import base
20 |
21 |
22 | #
23 | # Scheduled values that can be designed in multiple phases.
24 | #
25 |
26 |
@pg.members([
    ('phases', pg.typing.List(
        pg.typing.Tuple([
            pg.typing.Union([pg.typing.Int(min_value=0),
                             pg.typing.Float(min_value=0.)]),
            base.scalar_spec(pg.typing.Any())
        ]), min_size=1),
     ('All the phases in the schedule. Each item in the list is a tuple of '
      '`(length of phase, scheduled value)`. The length of phase can be an '
      'integer representing number of steps used for that phase, or a float as '
      'the proportion of that phase if `total_steps` is specified. All items '
      'in the list should use the same type (integer or float) for the length '
      'of phase. When a proportion is used, their sum does not have to sum up '
      'to 1.')),
    ('total_steps', pg.typing.Int(min_value=1).noneable(),
     ('Total number of steps for the schedule. If None, the length of each '
      'phase must be an integer.'))
])
class StepWise(base.Scalar):
  """A step-wise schedule that is specified via multiple phases.

  Each phase's scalar is evaluated with a step relative to the start of that
  phase, so its first step is 0. After the last phase ends, the last computed
  value is returned for all later steps.

  The schedule is stateful: it remembers the current phase and assumes the
  steps it is called with do not decrease. Phases of zero length are skipped,
  and a step that jumps past a phase boundary still advances the schedule.
  """

  def _on_bound(self):
    super()._on_bound()

    last_step = 0
    phase_ending_steps = []
    if self.total_steps is None:
      for phase_len, phase_value in self.phases:
        if isinstance(phase_len, float):
          raise ValueError(
              f'`total_steps` must be specified when float is used as the '
              f'value for phase length. '
              f'Encountered: ({phase_len}, {phase_value}).')
        last_step += phase_len
        phase_ending_steps.append(last_step - 1)
    else:
      proportion_sum = 0.
      for proportion, phase_value in self.phases:
        if isinstance(proportion, int):
          raise ValueError(
              f'The phase length should be a float as a proportion of the '
              f'entire schedule when `total_steps` is specified. '
              f'Encountered: ({proportion}, {phase_value}).')
        proportion_sum += proportion

      if proportion_sum == 0:
        raise ValueError(
            f'The sum of all proportions must be greater than 0. '
            f'Encountered: {self.phases!r}')

      for proportion, _ in self.phases:
        # Truncation may give a phase zero steps; `call` skips such phases.
        phase_len = int(proportion / proportion_sum * self.total_steps)
        last_step += phase_len
        phase_ending_steps.append(last_step - 1)
    # Phase ending step is the step AFTER which the next phase will start.
    self._phase_ending_steps = phase_ending_steps
    self._phases = [base.make_scalar(value) for _, value in self.phases]
    self._current_phase = 0
    self._last_value = None

  def call(self, step: int) -> Any:
    num_phases = len(self._phases)
    # Move past every phase that ended before `step`. A loop is needed (rather
    # than an equality check against the current phase's ending step) so that
    # zero-length phases are skipped. A step that jumps over a boundary also
    # advances the schedule.
    while (self._current_phase < num_phases
           and step > self._phase_ending_steps[self._current_phase]):
      self._current_phase += 1
    if self._current_phase < num_phases:
      if self._current_phase > 0:
        phase_start = self._phase_ending_steps[self._current_phase - 1] + 1
      else:
        phase_start = 0
      self._last_value = self._phases[self._current_phase](step - phase_start)
    return self._last_value
98 |
--------------------------------------------------------------------------------
/pyglove/ext/scalars/step_wise_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for step-wise scalars."""
15 |
16 | import unittest
17 | from pyglove.ext.scalars import base
18 | from pyglove.ext.scalars import step_wise as scalars
19 |
20 |
class StepWiseScalarTest(unittest.TestCase):
  """Test step-wise scalar schedule."""

  # For each phase, base.STEP is evaluated to 0 when the phase starts. Once all
  # phases are over, the last value is reused.
  _EXPECTED_FIRST_TEN = [
      1, 1,      # Phase 1
      0, 1,      # Phase 2
      0, 1, 4,   # Phase 3
      4, 4, 4,   # Use the last value for the rest.
  ]

  def _first_ten(self, schedule):
    """Evaluates `schedule` at steps 0..9, in order."""
    return [schedule(step) for step in range(10)]

  def test_step_as_phase_length(self):
    schedule = scalars.StepWise([
        (2, 1),
        (2, base.STEP),
        (3, base.STEP ** 2)
    ])
    self.assertEqual(self._first_ten(schedule), self._EXPECTED_FIRST_TEN)

  def test_proportion_as_phase_length(self):
    schedule = scalars.StepWise([
        (0.2, 1),
        (0.2, base.STEP),
        (0.3, base.STEP ** 2)
    ], total_steps=8)
    self.assertEqual(self._first_ten(schedule), self._EXPECTED_FIRST_TEN)

  def test_bad_specification(self):
    bad_specs = [
        ('`total_steps` must be specified when float is used as the value',
         [(0.2, 1), (0.2, base.STEP), (0.3, base.STEP ** 2)],
         None),
        ('The sum of all proportions must be greater than 0',
         [(0.0, 1), (0.0, base.STEP), (0.0, base.STEP ** 2)],
         10),
        ('The phase length should be a float as a proportion of the',
         [(1, 1), (2, base.STEP), (3, base.STEP ** 2)],
         10),
    ]
    for pattern, phases, total_steps in bad_specs:
      with self.assertRaisesRegex(ValueError, pattern):
        if total_steps is None:
          _ = scalars.StepWise(phases)
        else:
          _ = scalars.StepWise(phases, total_steps=total_steps)


if __name__ == '__main__':
  unittest.main()
84 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | docstring-parser>=0.12
2 | termcolor>=1.1.0
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PyGlove Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Setup for pip package."""
15 |
import datetime
import re
import sys

from setuptools import find_namespace_packages
from setuptools import setup
20 |
21 |
def _get_version():
  """Gets the version string of the PyGlove package being built.

  The version is read from the first line of `pyglove/__init__.py` (resolved
  relative to the current working directory) that starts with `__version__`.
  When `--nightly` is present in `sys.argv`, a `.devYYYYMMDDHHMM` suffix
  derived from the current local time is appended, and the flag is removed
  from `sys.argv`.

  Returns:
    The version string, e.g. `'0.4.4'` or `'0.4.4.dev202401011200'`.

  Raises:
    ValueError: If no line of `pyglove/__init__.py` starts with `__version__`.
  """
  version = None
  # Python source files are UTF-8 by default (PEP 3120); decode explicitly
  # instead of relying on the platform's locale encoding.
  with open('pyglove/__init__.py', encoding='utf-8') as fp:
    for line in fp:
      if line.startswith('__version__'):
        # Evaluate just the assignment line in an isolated namespace. The file
        # is part of this repository, so executing it is trusted.
        namespace = {}
        exec(line, namespace)  # pylint: disable=exec-used
        version = namespace['__version__']
        break
  if version is None:
    raise ValueError('`__version__` not defined in `pyglove/__init__.py`')
  if '--nightly' in sys.argv:
    nightly_label = datetime.datetime.now().strftime('%Y%m%d%H%M')
    version = f'{version}.dev{nightly_label}'
    # Consume the custom flag so the later `setup()` call does not see it.
    sys.argv.remove('--nightly')
  return version
39 |
40 |
def _parse_requirements(requirements_txt_path: str) -> list[str]:
  """Returns a list of dependencies for setup() from requirements.txt.

  The requirements file is the single source of truth for dependencies, so
  they do not have to be specified in two places.

  Args:
    requirements_txt_path: Path to a pip-style requirements file.

  Returns:
    The requirement strings in file order, with comments and surrounding
    whitespace removed. Empty lines and lines mentioning `github.com` (direct
    repository dependencies, which PyPI uploads do not allow) are dropped.
  """

  def _strip_comments_from_line(s: str) -> str:
    """Removes a trailing comment and surrounding whitespace from a line."""
    # Follow pip's requirements-file rule: `#` starts a comment only at the
    # beginning of a line or after whitespace. Splitting on every `#` would
    # also cut URL fragments such as `pkg @ https://host/pkg.whl#sha256=...`.
    return re.sub(r'(^|\s+)#.*$', '', s).strip()

  # Decode explicitly so parsing does not depend on the platform locale.
  with open(requirements_txt_path, encoding='utf-8') as fp:
    lines = [_strip_comments_from_line(line) for line in fp.read().splitlines()]
  # Remove empty lines and direct github repos (not allowed in PyPI setups).
  return [req for req in lines if req and 'github.com' not in req]
57 |
58 |
_VERSION = _get_version()

# Read the README inside a context manager so the file handle is closed
# deterministically (a bare `open(...).read()` leaves it to the garbage
# collector), and decode it as UTF-8 regardless of the platform locale.
with open('README.md', encoding='utf-8') as _readme_fp:
  _LONG_DESCRIPTION = _readme_fp.read()

setup(
    name='pyglove',
    version=_VERSION,
    url='https://github.com/google/pyglove',
    license='Apache License 2.0',
    author='PyGlove Authors',
    description='PyGlove: A library for manipulating Python objects.',
    long_description=_LONG_DESCRIPTION,
    long_description_content_type='text/markdown',
    author_email='pyglove-authors@google.com',
    # Contained modules and scripts.
    packages=find_namespace_packages(include=['pyglove*']),
    install_requires=_parse_requirements('requirements.txt'),
    extras_require={},
    # `python_requires` is the setuptools keyword that is emitted as the
    # `Requires-Python` metadata checked by pip. The former spelling
    # `requires_python` is not a recognized option and was silently ignored.
    python_requires='>=3.9',
    include_package_data=True,
    # PyPI package information.
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.9',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Scientific/Engineering :: Human Machine Interfaces',
        'Topic :: Software Development :: Code Generators',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Topic :: Software Development :: Libraries',
    ],
    keywords=(
        'ai machine learning automl mutable symbolic '
        'framework meta-programming'),
)
96 |
--------------------------------------------------------------------------------