├── .coveragerc ├── .github └── workflows │ ├── ci.yaml │ ├── docgen_test.yaml │ ├── pypi-nightly.yaml │ └── pypi.yaml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── _static │ ├── favicon.svg │ ├── logo.svg │ ├── logo_dark.svg │ ├── logo_light.svg │ └── style.css ├── _templates │ └── layout.html ├── api │ ├── _templates │ │ ├── _default │ │ │ ├── class.rst │ │ │ ├── function.rst │ │ │ ├── module.rst │ │ │ └── object.rst │ │ ├── core │ │ │ └── symbolic │ │ │ │ ├── dict.rst │ │ │ │ ├── list.rst │ │ │ │ ├── object.rst │ │ │ │ └── symbolic.rst │ │ └── index.rst │ └── docgen.py ├── conf.py ├── generate_docs.sh ├── guide │ ├── automl │ │ ├── index.rst │ │ ├── industrial.rst │ │ ├── research.rst │ │ └── toy.rst │ ├── evolution │ │ └── index.rst │ ├── general │ │ └── index.rst │ ├── index.rst │ └── ml │ │ └── index.rst ├── index.rst ├── learn │ ├── evolution │ │ └── index.rst │ ├── how_pyglove_works.svg │ ├── index.rst │ └── soop │ │ ├── detour.rst │ │ ├── index.rst │ │ └── som │ │ ├── definition.rst │ │ ├── duality.svg │ │ ├── events.rst │ │ ├── object_layout.svg │ │ ├── operations.rst │ │ ├── placeholding.rst │ │ ├── types.rst │ │ └── validation.rst ├── notebooks │ ├── evolution │ │ ├── function_regression.ipynb │ │ ├── onemax.ipynb │ │ └── tsp.ipynb │ ├── gui │ │ └── html_view.ipynb │ ├── intro │ │ ├── basics │ │ │ ├── runtime_typing.ipynb │ │ │ └── symbolic_function.ipynb │ │ ├── birdview.ipynb │ │ └── search │ │ │ ├── evolution_algorithm.ipynb │ │ │ ├── evolution_ops.ipynb │ │ │ └── evolution_scheduling.ipynb │ ├── ml │ │ ├── efficiently_exchange_ml_ideas_as_code.ipynb │ │ ├── neural_modeling.ipynb │ │ └── symbolic_ml.ipynb │ └── python │ │ ├── interactive_svg.ipynb │ │ ├── sticky_notes.ipynb │ │ └── where_is_the_duck.ipynb └── requirements.txt ├── examples ├── automl │ ├── mnist │ │ ├── README.md │ │ ├── mnist_train.py │ │ ├── mnist_tune.py │ │ ├── mnist_tune_eagerly.py │ │ ├── mnist_tune_hparams.py │ │ └── pytorch │ │ 
│ ├── mnist_train.py │ │ │ └── mnist_tune_eagerly.py │ ├── nasbench │ │ └── nasbench.py │ └── natsbench │ │ └── natsbench.py ├── evolution │ ├── onemax.py │ └── tsp.py ├── ml │ └── symbolic_modeling.py └── python │ └── runtime_typing.py ├── pyglove ├── __init__.py ├── core │ ├── __init__.py │ ├── coding │ │ ├── __init__.py │ │ ├── errors.py │ │ ├── errors_test.py │ │ ├── execution.py │ │ ├── execution_test.py │ │ ├── function_generation.py │ │ ├── function_generation_test.py │ │ ├── parsing.py │ │ ├── parsing_test.py │ │ ├── permissions.py │ │ └── permissions_test.py │ ├── detouring │ │ ├── __init__.py │ │ ├── class_detour.py │ │ └── class_detour_test.py │ ├── geno │ │ ├── __init__.py │ │ ├── base.py │ │ ├── base_test.py │ │ ├── categorical.py │ │ ├── categorical_test.py │ │ ├── custom.py │ │ ├── custom_test.py │ │ ├── deduping.py │ │ ├── deduping_test.py │ │ ├── dna_generator.py │ │ ├── dna_generator_test.py │ │ ├── numerical.py │ │ ├── numerical_test.py │ │ ├── random.py │ │ ├── random_test.py │ │ ├── space.py │ │ ├── space_test.py │ │ ├── sweeping.py │ │ └── sweeping_test.py │ ├── hyper │ │ ├── __init__.py │ │ ├── base.py │ │ ├── categorical.py │ │ ├── categorical_test.py │ │ ├── custom.py │ │ ├── custom_test.py │ │ ├── derived.py │ │ ├── derived_test.py │ │ ├── dynamic_evaluation.py │ │ ├── dynamic_evaluation_test.py │ │ ├── evolvable.py │ │ ├── evolvable_test.py │ │ ├── iter.py │ │ ├── iter_test.py │ │ ├── numerical.py │ │ ├── numerical_test.py │ │ ├── object_template.py │ │ └── object_template_test.py │ ├── io │ │ ├── __init__.py │ │ ├── file_system.py │ │ ├── file_system_test.py │ │ ├── sequence.py │ │ └── sequence_test.py │ ├── logging.py │ ├── logging_test.py │ ├── patching │ │ ├── __init__.py │ │ ├── object_factory.py │ │ ├── object_factory_test.py │ │ ├── pattern_based.py │ │ ├── pattern_based_test.py │ │ ├── rule_based.py │ │ └── rule_based_test.py │ ├── symbolic │ │ ├── __init__.py │ │ ├── base.py │ │ ├── base_test.py │ │ ├── boilerplate.py │ │ ├── 
boilerplate_test.py │ │ ├── class_wrapper.py │ │ ├── class_wrapper_test.py │ │ ├── compounding.py │ │ ├── compounding_test.py │ │ ├── contextual_object.py │ │ ├── contextual_object_test.py │ │ ├── dict.py │ │ ├── dict_test.py │ │ ├── diff.py │ │ ├── diff_test.py │ │ ├── error_info.py │ │ ├── error_info_test.py │ │ ├── flags.py │ │ ├── flags_test.py │ │ ├── functor.py │ │ ├── functor_test.py │ │ ├── inferred.py │ │ ├── inferred_test.py │ │ ├── list.py │ │ ├── list_test.py │ │ ├── object.py │ │ ├── object_test.py │ │ ├── origin.py │ │ ├── origin_test.py │ │ ├── pure_symbolic.py │ │ ├── ref.py │ │ ├── ref_test.py │ │ ├── symbolize.py │ │ └── symbolize_test.py │ ├── tuning │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── backend_test.py │ │ ├── early_stopping.py │ │ ├── local_backend.py │ │ ├── protocols.py │ │ ├── protocols_test.py │ │ ├── sample.py │ │ └── sample_test.py │ ├── typing │ │ ├── __init__.py │ │ ├── annotated.py │ │ ├── annotated_test.py │ │ ├── annotation_conversion.py │ │ ├── annotation_conversion_test.py │ │ ├── annotation_future_test.py │ │ ├── callable_ext.py │ │ ├── callable_ext_test.py │ │ ├── callable_signature.py │ │ ├── callable_signature_test.py │ │ ├── class_schema.py │ │ ├── class_schema_test.py │ │ ├── custom_typing.py │ │ ├── inspect.py │ │ ├── inspect_test.py │ │ ├── json_schema.py │ │ ├── json_schema_test.py │ │ ├── key_specs.py │ │ ├── key_specs_test.py │ │ ├── pytype_support.py │ │ ├── type_conversion.py │ │ ├── type_conversion_test.py │ │ ├── typed_missing.py │ │ ├── typed_missing_test.py │ │ ├── value_specs.py │ │ └── value_specs_test.py │ ├── utils │ │ ├── __init__.py │ │ ├── common_traits.py │ │ ├── common_traits_test.py │ │ ├── contextual.py │ │ ├── contextual_test.py │ │ ├── docstr_utils.py │ │ ├── docstr_utils_test.py │ │ ├── error_utils.py │ │ ├── error_utils_test.py │ │ ├── formatting.py │ │ ├── formatting_test.py │ │ ├── hierarchical.py │ │ ├── hierarchical_test.py │ │ ├── json_conversion.py │ │ ├── json_conversion_test.py │ │ 
├── missing.py │ │ ├── missing_test.py │ │ ├── text_color.py │ │ ├── text_color_test.py │ │ ├── thread_local.py │ │ ├── thread_local_test.py │ │ ├── timing.py │ │ ├── timing_test.py │ │ ├── value_location.py │ │ └── value_location_test.py │ └── views │ │ ├── __init__.py │ │ ├── base.py │ │ ├── base_test.py │ │ └── html │ │ ├── __init__.py │ │ ├── base.py │ │ ├── base_test.py │ │ ├── controls │ │ ├── __init__.py │ │ ├── base.py │ │ ├── label.py │ │ ├── label_test.py │ │ ├── progress_bar.py │ │ ├── progress_bar_test.py │ │ ├── tab.py │ │ ├── tab_test.py │ │ ├── tooltip.py │ │ └── tooltip_test.py │ │ ├── tree_view.py │ │ └── tree_view_test.py └── ext │ ├── __init__.py │ ├── early_stopping │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── step_wise.py │ └── step_wise_test.py │ ├── evolution │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── hill_climb.py │ ├── hill_climb_test.py │ ├── mutators.py │ ├── mutators_test.py │ ├── neat.py │ ├── neat_test.py │ ├── nsga2.py │ ├── nsga2_test.py │ ├── recombinators.py │ ├── recombinators_test.py │ ├── regularized_evolution.py │ ├── regularized_evolution_test.py │ ├── selectors.py │ ├── selectors_test.py │ ├── where.py │ └── where_test.py │ ├── mutfun │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── basic_ops.py │ └── basic_ops_test.py │ └── scalars │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── maths.py │ ├── maths_test.py │ ├── randoms.py │ ├── randoms_test.py │ ├── step_wise.py │ └── step_wise_test.py ├── requirements.txt └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = pyglove 3 | branch = True 4 | 5 | [report] 6 | precision = 2 7 | show_missing = True 8 | omit = **/*_test.py 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | 
branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test-ubuntu: 13 | name: "pytest on ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11"] 18 | os: [ubuntu-latest] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | pip install pytest 28 | pip install pytest-xdist 29 | pip install pytest-cov 30 | pip install -r requirements.txt 31 | - name: Test with pytest and generate coverage report 32 | run: | 33 | pytest -n auto --cov=pyglove --cov-report=xml 34 | - name: Upload coverage to Codecov 35 | uses: codecov/codecov-action@v1 36 | with: 37 | file: ./coverage.xml 38 | # The below step just reports the success or failure of tests as a "commit status". 39 | # This is needed for copybara integration. 
40 | - name: Report success or failure as github status 41 | if: always() 42 | shell: bash 43 | run: | 44 | status="${{ job.status }}" 45 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') 46 | curl -sS --request POST \ 47 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ 48 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ 49 | --header 'content-type: application/json' \ 50 | --data '{ 51 | "state": "'$lowercase_status'", 52 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", 53 | "description": "'$status'", 54 | "context": "github-actions/build" 55 | }' 56 | -------------------------------------------------------------------------------- /.github/workflows/docgen_test.yaml: -------------------------------------------------------------------------------- 1 | name: docgen_test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test-ubuntu: 13 | name: "documentation generation on ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | strategy: 16 | matrix: 17 | python-version: ["3.9"] 18 | os: [ubuntu-latest] 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r docs/requirements.txt 29 | - name: Print installed dependencies 30 | run: | 31 | pip freeze 32 | - name: Make Sphinx Docs to HTML (Test) 33 | run: | 34 | cd docs 35 | bash ./generate_docs.sh html 36 | # The below step just reports the success or failure of tests as a "commit status". 37 | # This is needed for copybara integration. 
38 | - name: Report success or failure as github status 39 | if: always() 40 | shell: bash 41 | run: | 42 | status="${{ job.status }}" 43 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') 44 | curl -sS --request POST \ 45 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ 46 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ 47 | --header 'content-type: application/json' \ 48 | --data '{ 49 | "state": "'$lowercase_status'", 50 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", 51 | "description": "'$status'", 52 | "context": "github-actions/docgen" 53 | }' 54 | -------------------------------------------------------------------------------- /.github/workflows/pypi-nightly.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created. 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: pypi-nightly 5 | 6 | on: 7 | workflow_dispatch: 8 | schedule: 9 | # Everyday@1:00AM PST 10 | - cron: "0 8 * * *" 11 | 12 | jobs: 13 | deploy: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | pip install -r requirements.txt 26 | - name: Build and publish 27 | env: 28 | TWINE_USERNAME: __token__ 29 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 30 | run: | 31 | # Remove the parallel image for Github dark mode in README.md 32 | sed -i -e 's$logo$$' README.md 33 | python setup.py sdist bdist_wheel -- --nightly 34 | twine upload dist/* --skip-existing 35 | 
-------------------------------------------------------------------------------- /.github/workflows/pypi.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created. 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: pypi 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: '3.x' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine 23 | pip install -r requirements.txt 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: __token__ 27 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 28 | run: | 29 | # Remove the parallel image for Github dark mode in README.md 30 | sed -i -e 's$logo$$' README.md 31 | python setup.py sdist bdist_wheel 32 | twine upload dist/* 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .coverage 3 | .vscode 4 | *.egg-info 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: 
"1.64" 16 | # golang: "1.19" 17 | apt_packages: 18 | - graphviz 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: docs/conf.py 23 | 24 | # If using Sphinx, optionally build your docs in additional formats such as PDF 25 | # formats: 26 | # - pdf 27 | 28 | # Optionally declare the Python requirements required to build your docs 29 | python: 30 | install: 31 | - requirements: docs/requirements.txt 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 
29 | -------------------------------------------------------------------------------- /docs/_static/favicon.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-side-nav-search { 4 | background-color: #000; 5 | } 6 | 7 | .wy-table-responsive table td, 8 | .wy-table-responsive table th { 9 | white-space: normal; 10 | } -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set css_files = css_files + ["_static/style.css"] %} -------------------------------------------------------------------------------- /docs/api/_templates/_default/class.rst: -------------------------------------------------------------------------------- 1 | .. _{{cls.rst_label}}: 2 | 3 | {{cls.preferred_path}} 4 | ====================== 5 | 6 | Accessible via {{cls.rst_access_paths}}. 7 | 8 | .. autoclass:: {{cls.rst_import_name}} 9 | :members: 10 | :show-inheritance: 11 | :autosummary: -------------------------------------------------------------------------------- /docs/api/_templates/_default/function.rst: -------------------------------------------------------------------------------- 1 | .. _{{func.rst_label}}: 2 | 3 | {{func.preferred_path}} 4 | ===================== 5 | 6 | Accessible via {{func.rst_access_paths}}. 7 | 8 | .. autofunction:: {{func.rst_import_name}} 9 | -------------------------------------------------------------------------------- /docs/api/_templates/_default/module.rst: -------------------------------------------------------------------------------- 1 | .. 
_{{module.rst_label}}: 2 | 3 | {{module.preferred_path}} 4 | ===================== 5 | 6 | .. automodule:: {{module.rst_import_name}} 7 | 8 | {% if module.modules %} 9 | Modules 10 | ******* 11 | .. toctree:: 12 | :maxdepth: 1 13 | {% for entry in module.modules %} 14 | {{entry.name}}<{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %} 15 | {% endif %} 16 | 17 | {% if module.objects %} 18 | Objects 19 | ******* 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | {% for entry in module.objects %} 24 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %} 25 | {% endif %} 26 | 27 | {% if module.classes %} 28 | Classes 29 | ******* 30 | .. toctree:: 31 | :maxdepth: 1 32 | {% for entry in module.classes %} 33 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %} 34 | {% endif %} 35 | 36 | {% if module.functions %} 37 | Functions 38 | ********* 39 | .. toctree:: 40 | :maxdepth: 1 41 | {% for entry in module.functions %} 42 | {{entry.name}} <{{entry.api.relative_handle(module.doc_dir)}}>{% endfor %} 43 | {% endif %} -------------------------------------------------------------------------------- /docs/api/_templates/_default/object.rst: -------------------------------------------------------------------------------- 1 | .. _{{obj.rst_label}}: 2 | 3 | {{obj.preferred_path}} 4 | ================ 5 | 6 | Accessible via {{obj.rst_access_paths}}. 7 | 8 | .. autodata:: {{obj.rst_import_name}} 9 | -------------------------------------------------------------------------------- /docs/api/_templates/core/symbolic/dict.rst: -------------------------------------------------------------------------------- 1 | .. _{{cls.rst_label}}: 2 | 3 | {{cls.preferred_path}} 4 | ================ 5 | 6 | Accessible via {{cls.rst_access_paths}}. 7 | 8 | .. 
autoclass:: {{cls.rst_import_name}} 9 | :members: 10 | :show-inheritance: 11 | :autosummary: 12 | :inherited-members: 13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__ 14 | -------------------------------------------------------------------------------- /docs/api/_templates/core/symbolic/list.rst: -------------------------------------------------------------------------------- 1 | .. _{{cls.rst_label}}: 2 | 3 | {{cls.preferred_path}} 4 | ================ 5 | 6 | Accessible via {{cls.rst_access_paths}}. 7 | 8 | .. autoclass:: {{cls.rst_import_name}} 9 | :members: 10 | :show-inheritance: 11 | :autosummary: 12 | :inherited-members: 13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__ 14 | -------------------------------------------------------------------------------- /docs/api/_templates/core/symbolic/object.rst: -------------------------------------------------------------------------------- 1 | .. _{{cls.rst_label}}: 2 | 3 | {{cls.preferred_path}} 4 | ================ 5 | 6 | Accessible via {{cls.rst_access_paths}}. 7 | 8 | .. autoclass:: {{cls.rst_import_name}} 9 | :members: 10 | :show-inheritance: 11 | :autosummary: 12 | :inherited-members: 13 | :private-members: _on_init, _on_change, _on_bound, _on_parent_change, _on_path_change 14 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__ 15 | -------------------------------------------------------------------------------- /docs/api/_templates/core/symbolic/symbolic.rst: -------------------------------------------------------------------------------- 1 | .. _{{cls.rst_label}}: 2 | 3 | {{cls.preferred_path}} 4 | ================ 5 | 6 | Accessible via {{cls.rst_access_paths}}. 7 | 8 | .. 
autoclass:: {{cls.rst_import_name}} 9 | :members: 10 | :show-inheritance: 11 | :autosummary: 12 | :inherited-members: 13 | :special-members: __copy__, __deepcopy__, __eq__, __ne__, __hash__ 14 | -------------------------------------------------------------------------------- /docs/api/_templates/index.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: pyglove 2 | 3 | Public API: pyglove 4 | =================== 5 | 6 | Modules 7 | ------- 8 | 9 | core 10 | ^^^^ 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | {% for e in module.modules -%} 15 | {%- if e.api.source_category == 'core' %} 16 | {{e.api.relative_handle(module.doc_dir)}} 17 | {%- endif -%} 18 | {%- endfor %} 19 | 20 | ext 21 | ^^^ 22 | .. toctree:: 23 | :maxdepth: 1 24 | {% for e in module.modules -%} 25 | {%- if e.api.source_category == 'ext' %} 26 | {{e.api.relative_handle(module.doc_dir)}} 27 | {%- endif -%} 28 | {%- endfor %} 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | 33 | {% for e in module.modules -%} 34 | {%- if e.api.source_category == 'generators' %} 35 | {{e.api.relative_handle(module.doc_dir)}} 36 | {%- endif -%} 37 | {%- endfor %} 38 | 39 | 40 | Top-level shortcurts 41 | -------------------- 42 | 43 | Objects 44 | ^^^^^^^ 45 | 46 | {% for e in module.objects %} 47 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>` 48 | {%- endfor %} 49 | 50 | Classes 51 | ^^^^^^^ 52 | 53 | {% for e in module.classes %} 54 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>` 55 | {%- endfor %} 56 | 57 | Functions 58 | ^^^^^^^^^ 59 | {% for e in module.functions %} 60 | * :ref:`pg.{{e.name}}<{{e.api.rst_label}}>` 61 | {%- endfor %} 62 | -------------------------------------------------------------------------------- /docs/generate_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Minimal script to manually generate Sphinx documentation, used by Github 3 | # integration testing as well. 
Installs necessary Sphinx packages via 4 | # `requirements.txt`, and then uses `conf.py` to output HTML files. This script 5 | # is not actually needed for generating the official website at 6 | # (https://pyglove.readthedocs.io/en/latest/). 7 | 8 | # Define output folder for build files. 9 | OUTPUT_FOLDER=_build 10 | 11 | # Install Sphinx. 12 | sudo apt-get install python3-sphinx 13 | 14 | # Installs relevant Sphinx packages. 15 | pip install -r requirements.txt --use-deprecated=legacy-resolver 16 | 17 | # Build files (HTML, doctests, etc.) into `OUTPUT_FOLDER` directory. 18 | rm -rf ${OUTPUT_FOLDER} # Clear out original folder 19 | sphinx-build -b $1 -a . ${OUTPUT_FOLDER} 20 | 21 | # Optionally host the HTML folder. Access on browser `https://localhost:5000/`. 22 | # python -m http.server --directory ${OUTPUT_FOLDER} 5000 23 | -------------------------------------------------------------------------------- /docs/guide/automl/index.rst: -------------------------------------------------------------------------------- 1 | For AutoML 2 | ========== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | toy 8 | industrial 9 | Research 10 | -------------------------------------------------------------------------------- /docs/guide/automl/industrial.rst: -------------------------------------------------------------------------------- 1 | Industrial Uses 2 | =============== 3 | 4 | Google Cloud Vertex AI 5 | ********************** 6 | 7 | PyGlove is the `driving force `_ 8 | behind `Vertex AI NAS `_, which 9 | has brought `advanced AutoML technologies `_ 10 | to industries with significant impact. Companies such as 11 | `Qualcomm `_, 12 | `Varian `_, and 13 | `Oppo `_ 14 | have all benefited from this groundbreaking technology. 
15 | 16 | Vertex AI has published a few search spaces for computer vision as examples: 17 | 18 | * EfficientNet [`paper `_][`code `_] 19 | * SpineNet [`paper `_][`code `_] 20 | * MNASNet [`paper `_][`code `_] 21 | * NASFpn [`paper `_][`code `_] 22 | * AutoAugment [`paper `_][`code `_] 23 | 24 | Pax 25 | *** 26 | 27 | Pax is a powerful machine learning framework developed by Google, based on Jax, that is designed for training large-scale models. 28 | It utilizes PyGlove to enable its hyperparameter tuning and AutoML capabilities. Pax serves as a good example of how PyGlove can 29 | be seamlessly integrated into a large scale ML codebase based on dynamic evaluation: 30 | 31 | * `Inspecting the search space `_ 32 | * `Implementing the tuning loop `_ 33 | 34 | Vizier 35 | ****** 36 | 37 | `Vizier `_ 38 | is the distributed tuning solution at Alphabet. PyGlove uses it as a backend for serving distributed AutoML scenarios at Google. 39 | The `open-source Vizier `_ has shown how PyGlove can be 40 | `used together with Vizier `_, 41 | it also serve as an example on how PyGlove backend could be developed. 42 | 43 | * `Implementing the Backend interface `_ 44 | * `Implementing the Feedback interface `_ 45 | 46 | -------------------------------------------------------------------------------- /docs/guide/automl/research.rst: -------------------------------------------------------------------------------- 1 | AutoML Research 2 | --------------- 3 | 4 | PyGlove is a versatile library that not only simplifies the implementation of AutoML applications but is also powerful enough to facilitate complex AutoML/ML research. 
5 | Its capability and flexibility have been demonstrated in several academic papers, including the following: 6 | 7 | Papers with code: 8 | 9 | 10 | * `Evolving Reinforcement Learning algorithms, ICLR 2021 `_ [`code `_] 11 | * `PyGlove: Efficiently Exchanging ML Ideas as Code, 2022 `_ [`code `_] 12 | 13 | Papers only: 14 | 15 | * `PyGlove: Symbolic Programming for Automated Machine Learning, NeurIPS 2020 `_ 16 | * `AutoHAS: Efficient Hyperparameter and Architecture Search, ICLR 2021 NAS workshop `_ 17 | * `Towards the Co-design of Neural Networks and Accelerators, MLSys 2022 `_ 18 | * `Deepfusion: Lidar-camera Deep Fusion for Multi-modal 3D Object Detection, CVPR 2022 `_ 19 | * `ES-ENAS: Combining Evolution Strategies with Neural Architecture Search at No Extra Cost for Reinforcement Learning, CoRR 2021 `_ 20 | -------------------------------------------------------------------------------- /docs/guide/automl/toy.rst: -------------------------------------------------------------------------------- 1 | Toy Problems 2 | ------------ 3 | 4 | * `Neural Architecture Search on MNIST `_ 5 | * `NAS-Bench-101 `_ 6 | * `NATS-Bench `_ 7 | -------------------------------------------------------------------------------- /docs/guide/evolution/index.rst: -------------------------------------------------------------------------------- 1 | For Evolution 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | ../../notebooks/evolution/onemax 8 | ../../notebooks/evolution/tsp 9 | Function Regression <../../notebooks/evolution/function_regression> -------------------------------------------------------------------------------- /docs/guide/general/index.rst: -------------------------------------------------------------------------------- 1 | For Python 2 | ========== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | Flexible Function Bindings <../../notebooks/intro/basics/symbolic_function> 8 | Runtime Typing <../../notebooks/intro/basics/runtime_typing> 9 | Direct Manipulation <../../notebooks/python/interactive_svg> 10 | Domain-Specific Languages <../../notebooks/python/sticky_notes> 11 | Context-aware Components <../../notebooks/python/where_is_the_duck> 12 | -------------------------------------------------------------------------------- /docs/guide/index.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | ../notebooks/intro/birdview 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | 12 | general/index.rst 13 | ml/index.rst 14 | automl/index.rst 15 | evolution/index.rst 16 | -------------------------------------------------------------------------------- /docs/guide/ml/index.rst: -------------------------------------------------------------------------------- 1 | For Machine Learning 2 | ==================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | Symbolic Modeling <../../notebooks/ml/neural_modeling> 8 | Symbolic Machine Learning <../../notebooks/ml/symbolic_ml> 9 | Patching: Efficiently Exchangeing ML Ideas <../../notebooks/ml/efficiently_exchange_ml_ideas_as_code> 10 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. BUILD: sphinx-build -b html -a . ../build 2 | 3 | Welcome to PyGlove 4 | ################## 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | guide/index 10 | learn/index 11 | 12 | .. 
toctree:: 13 | :maxdepth: 2 14 | :caption: API References 15 | 16 | api/index 17 | API Index (A-Z) 18 | -------------------------------------------------------------------------------- /docs/learn/evolution/index.rst: -------------------------------------------------------------------------------- 1 | Evolution 2 | ######### 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | Algorithm <../../notebooks/intro/search/evolution_algorithm> 8 | Operations <../../notebooks/intro/search/evolution_ops> 9 | Fine Control <../../notebooks/intro/search/evolution_scheduling> -------------------------------------------------------------------------------- /docs/learn/index.rst: -------------------------------------------------------------------------------- 1 | Learning PyGlove 2 | ################ 3 | 4 | .. PyGlove is organized into three layers: 5 | 6 | .. * At the bottom, the layer of symbolic object-oriented programming, enables a mutable 7 | .. programming model with objects, which allows unknown program parts to be expressed 8 | .. side-by-side with the expression of known program parts, and enables dynamic interpretations 9 | .. on them. 10 | .. * At the middle, the layer of intelligent programs, provides the representation 11 | .. and operations to convert between symbols and numbers. It also introduces the expression for 12 | .. feedback loops, as well as a framework for building algorithms that evolve the program. 13 | .. * At the top, the layer of distributed symbolic computing, introduces API to allow 14 | .. feedback loops to be distributed so that intelligent programs can run at scale. This layer 15 | .. also provide the interfaces for the user to plug in their own infrastructures into PyGlove. 16 | 17 | PyGlove is a Python library for manipulating Python programs. It is built on the 18 | concept of symbolic object-oriented programming (SOOP), which forms its core foundation. 
19 | On top of that, PyGlove includes multiple layers of components that enhance its capabilities and 20 | enable the development of intelligent systems. 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | 26 | SOOP <soop/index> 27 | evolution/index 28 | 29 | .. .. image:: how_pyglove_works.svg 30 | 31 | -------------------------------------------------------------------------------- /docs/learn/soop/index.rst: -------------------------------------------------------------------------------- 1 | Symbolic Object-Oriented Programming 2 | #################################### 3 | 4 | PyGlove offers two key capabilities for symbolic object-oriented programming: the 5 | *Symbolic Object Model (SOM)* and *Symbolic Detour (SD)*. SOM implements *dynamic representation*, 6 | allowing for flexible manipulation of symbolic objects. SD enables *dynamic interpretation*, 7 | allowing symbols to be interpreted in different ways at runtime. 8 | 9 | 10 | Symbolic Object Model 11 | ********************* 12 | 13 | The Symbolic Object Model (SOM) is the core of symbolic object-oriented programming, 14 | providing dynamic representation through symbolic objects. SOM stores initialization 15 | arguments as symbolic attributes and allows inspection and manipulation of them. 16 | It also includes a symbolic schema system for validation and a messaging system for 17 | handling mutations. Symbolic placeholders are also supported for representing unknown 18 | program parts. 19 | 20 | .. .. _`dynamic representation`: ../../overview/what_and_why.html#programming-the-unknown 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | 26 | som/definition 27 | som/types 28 | som/operations 29 | som/events 30 | som/validation 31 | som/placeholding 32 | 33 | Symbolic Detour 34 | *************** 35 | 36 | Symbolic Detour (SD) is independent of the SOM, allowing users to alter 37 | class mapping without modifying the source code that instantiates the classes.
This is 38 | particularly useful when the source code cannot be symbolized for various reasons. 39 | SD complements SOM. 40 | 41 | .. toctree:: 42 | :maxdepth: 1 43 | 44 | detour 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Required Sphinx-related packages. 2 | 3 | absl-py 4 | autodocsumm 5 | furo 6 | jinja2 7 | sphinx_autodoc_typehints 8 | sphinx-copybutton 9 | myst_nb 10 | 11 | # Required packages for PyGlove. 12 | 13 | docstring-parser>=0.12 14 | -------------------------------------------------------------------------------- /examples/automl/mnist/mnist_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Train MNIST. 15 | 16 | This is a basic working ML program which trains MNIST. 17 | The code is modified from the tf.keras tutorial here: 18 | https://www.tensorflow.org/tutorials/keras/classification 19 | 20 | (The tutorial uses Fashion-MNIST, 21 | but we just use "regular" MNIST for these tutorials.)
def download_and_prep_data() -> Tuple[np.ndarray,
                                      np.ndarray,
                                      np.ndarray,
                                      np.ndarray]:
  """Loads MNIST through `tf.keras` and rescales the images.

  Returns:
    A `(tr_x, tr_y, te_x, te_y)` tuple holding training images, training
    labels, test images and test labels. Both image arrays are divided by
    255.0; the labels are returned exactly as loaded.
  """
  (train_images, train_labels), (test_images, test_labels) = (
      tf.keras.datasets.mnist.load_data())
  return (train_images / 255.0, train_labels,
          test_images / 255.0, test_labels)


def create_model() -> tf.keras.Model:
  """Builds the classifier used for training.

  Returns:
    A `tf.keras.Sequential` stacking a flatten layer over 28x28 inputs, a
    128-unit ReLU dense layer and a 10-unit softmax dense layer.
  """
  layers = [
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax'),
  ]
  return tf.keras.Sequential(layers)


def train_and_eval() -> None:
  """Prepares the data, fits the model for 10 epochs and prints test metrics."""
  train_x, train_y, test_x, test_y = download_and_prep_data()
  classifier = create_model()
  classifier.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])
  classifier.fit(train_x, train_y, epochs=10)
  loss, accuracy = classifier.evaluate(test_x, test_y, verbose=2)
  print('Test loss: {}, accuracy: {}'.format(loss, accuracy))


def main(argv):
  """absl entry point: rejects extra command-line arguments, then trains."""
  if len(argv) <= 1:
    train_and_eval()
    return
  raise app.UsageError('Too many command-line arguments.')


if __name__ == '__main__':
  app.run(main)
def one_max(search_space, search_algorithm, num_trials=200):
  """Runs a search over `search_space` and prints the best bit sequence.

  Args:
    search_space: The space handed to `pg.sample`.
    search_algorithm: The algorithm handed to `pg.sample`.
    num_trials: Number of examples to sample.
  """
  winner, winner_reward = None, None
  trials = pg.sample(search_space, search_algorithm, num_examples=num_trials)
  for candidate, feedback in trials:
    # The reward of a candidate is simply the number of ones it contains.
    reward = sum(candidate)
    if winner_reward is None or reward > winner_reward:
      winner, winner_reward = candidate, reward
    feedback(reward)
  print(f'Best sequence: {list(winner)} (sum={winner_reward})')


def one_max_with_builtin_primitive(n: int):
  """Solves the One-Max problem with `n` built-in `pg.oneof` decisions."""
  algorithm = pg.evolution.regularized_evolution(
      population_size=20, tournament_size=10)
  space = pg.List([pg.oneof([0, 1])] * n)
  one_max(space, algorithm)


def one_max_with_custom_primitive(n: int):
  """Solves the One-Max problem with a user-defined hyper primitive."""

  class BitString(pg.hyper.CustomHyper):
    """Custom hyper primitive that represents a bit string of size n."""

    def custom_decode(self, dna: pg.DNA):
      # The DNA value is the bit string itself; decode it digit by digit.
      assert isinstance(dna.value, str)
      return [int(bit) for bit in dna.value]

  class MutateOneBit(pg.evolution.Mutator):
    """Mutator that flips one randomly chosen bit of the DNA value."""

    def mutate(self, dna: pg.DNA):    # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      bits = dna.value
      pos = random.randint(0, len(dna.value) - 1)
      flipped = '0' if bits[pos] == '1' else '1'
      return pg.DNA(bits[:pos] + flipped + bits[pos + 1:])

  def init_population(population_size):
    """Returns a DNA generator that yields random bit strings of size n."""
    @pg.geno.dna_generator
    def initializer(dna_spec):
      del dna_spec
      for _ in range(population_size):
        yield pg.DNA(''.join(str(random.randint(0, 1)) for _ in range(n)))
    return initializer()  # pylint: disable=no-value-for-parameter

  reproduction = (
      pg.evolution.selectors.Random(10)
      >> pg.evolution.selectors.Top(1)
      >> MutateOneBit())
  algorithm = pg.evolution.Evolution(
      reproduction,
      population_init=init_population(10),
      population_update=pg.evolution.selectors.Last(20))
  one_max(BitString(), algorithm)


def main():
  """Runs both One-Max variants on bit strings of length 10."""
  one_max_with_builtin_primitive(10)
  one_max_with_custom_primitive(10)


if __name__ == '__main__':
  main()
@pg.symbolize
class City:
  """A city located at integer coordinates (x, y) on the map."""

  def __init__(self, x: int, y: int):
    self.x = x
    self.y = y

  def distance(self, other: 'City') -> float:
    """Returns the Euclidean distance between this city and `other`."""
    dx = self.x - other.x
    dy = self.y - other.y
    return math.sqrt(dx ** 2 + dy ** 2)


@pg.symbolize
class Route:
  """A closed route that visits the cities in the order they are listed."""

  def __init__(self, cities: List[City]):
    self.cities = cities

  def length(self) -> float:
    """Returns the total length, including the leg back to the first city."""
    total = 0
    count = len(self.cities)
    for index, city in enumerate(self.cities):
      # The modulo wraps the last city around to the first one.
      total += city.distance(self.cities[(index + 1) % count])
    return total


def tsp(cities: List[City], num_trials: int = 500) -> Route:
  """Searches for a short route over `cities` and returns the best one found.

  Args:
    cities: The cities to visit.
    num_trials: Number of candidate routes to sample.

  Returns:
    The sampled route with the smallest length.
  """
  # Every permutation of the given cities is a candidate route.
  route_space = Route(pg.permutate(cities))

  def evolution(op, population_size=50, tournament_size=20, seed=None):
    """Builds an evolution that breeds from the top 2 of a random tournament."""
    parents = (
        pg.evolution.selectors.Random(tournament_size, seed=seed)
        >> pg.evolution.selectors.Top(2))
    return pg.evolution.Evolution(
        parents >> op,
        population_init=(pg.geno.Random(seed=seed), population_size),
        population_update=pg.evolution.selectors.Last(population_size))

  search_algorithm = evolution(
      pg.evolution.recombinators.PartiallyMapped()
      >> pg.evolution.mutators.Swap())

  shortest_route, shortest_len = None, None
  # Each iteration of `pg.sample` yields an example route together with a
  # `feedback` callable used to report the reward back to the algorithm.
  samples = pg.sample(route_space, search_algorithm, num_examples=num_trials)
  for candidate, feedback in samples:
    candidate_len = candidate.length()
    if shortest_len is None or shortest_len > candidate_len:
      shortest_route, shortest_len = candidate, candidate_len
    # The algorithm maximizes the reward, so the negated length is reported.
    feedback(-candidate_len)

  print(f'Best route length: {shortest_len}.')
  print(shortest_route)
  return shortest_route


def main():
  """Samples 25 random cities on a 100x100 grid and solves TSP over them."""
  city_space = City(x=pg.oneof(range(100)), y=pg.oneof(range(100)))
  tsp(list(pg.random_sample(city_space, 25, seed=1)))


if __name__ == '__main__':
  main()
# Symbolized counterparts of the Keras classes used below, so that their
# instances can be inspected and rewritten symbolically.
Sequential = pg.symbolize(tf.keras.Sequential)
Conv2D = pg.symbolize(tf.keras.layers.Conv2D)
Dense = pg.symbolize(tf.keras.layers.Dense)
Flatten = pg.symbolize(tf.keras.layers.Flatten)
ReLU = pg.symbolize(tf.keras.layers.ReLU)


def create_model() -> tf.keras.layers.Layer:
  """Returns a symbolic two-convolution model with a 10-unit dense head."""
  layers = [
      Conv2D(16, (5, 5)),
      ReLU(),
      Conv2D(32, (3, 3)),
      ReLU(),
      Flatten(),
      Dense(10),
  ]
  return Sequential(layers)


def scale_model(model) -> None:
  """Doubles the `filters` of every Conv2D layer in `model`, in place."""

  def double_width(k, v, p):
    """Rebind rule that doubles `filters` under a Conv2D parent.

    Args:
      k: A `pg.KeyPath` object representing the location of current node.
      v: The value of current node.
      p: The parent of current node.

    Returns:
      The output value for current node.
    """
    is_conv_filters = isinstance(p, Conv2D) and k.key == 'filters'
    return 2 * v if is_conv_filters else v

  # `rebind` applies the rule to the nodes of the symbolic object.
  model.rebind(double_width)


def remove_relus(model) -> None:
  """Deletes every ReLU layer from `model`, in place."""

  def remove_activations(k, v, p):
    del k, p
    # Returning `pg.MISSING_VALUE` deletes the value from its container.
    return pg.MISSING_VALUE if isinstance(v, ReLU) else v

  model.rebind(remove_activations)


def change_classification_head_width(model, width: int) -> None:
  """Sets `units` of the last Dense layer found in `model` to `width`."""
  dense_layers = pg.query(model, where=lambda v: isinstance(v, Dense))
  classification_head_location = list(dense_layers.keys())[-1]
  model.rebind({f'{classification_head_location}.units': width})


def main() -> None:
  """Prints the model before and after each symbolic manipulation."""
  model = create_model()

  def show(title):
    # Default values are hidden to keep the printed layers readable.
    print(title)
    print(model.format(hide_default_values=True))

  show('Original model.')

  scale_model(model)
  show('After doubling the width.')

  remove_relus(model)
  show('After removing the ReLUs.')

  # The original prints this title before applying the change.
  print('After changing the classification head width.')
  change_classification_head_width(model, 100)
  print(model.format(hide_default_values=True))


if __name__ == '__main__':
  main()
class Foo(pg.Object):
  """A symbolic class whose fields are declared with `pg.typing` annotations.

  `x`, `y` and `z` have no defaults and are passed positionally in `main`;
  `p` and `q` carry default values.
  """

  # Either a list of ints or None.
  x: pg.typing.List[int] | None
  # A dict with str keys and values of any type.
  y: pg.typing.Dict[str, pg.typing.Any]
  z: pg.typing.Annotated[
      pg.typing.Int(min_value=0),  # Field value spec.
      'Field z',  # Field docstring.
      dict(p=1)  # Meta-data
  ]
  # One of the listed string choices; defaults to 'baz'.
  p: pg.typing.Enum[['foo', 'bar', 'baz']] = 'baz'
  # A dict with a fixed set of keys, each with its own value spec.
  # NOTE(review): `Int[0, 100]` presumably bounds the value to 0..100 -- confirm
  # against the `pg.typing` documentation.
  q: pg.typing.Dict[{
      'a': int | None,
      'b': pg.typing.Int[0, 100],
      'c': pg.typing.Tuple[int, ...]
  }] = dict(a=1, b=10, c=(1, 2))


# `pg.typing` could also be used for static type analysis.
def add(
    x: pg.typing.Int[0, None],
    y: pg.typing.Float[None, 1.0]) -> pg.typing.Any:
  """Returns `x + y`; the annotations are not checked by plain Python."""
  return x + y


def main() -> None:
  """Demonstrates which calls are value-checked at runtime and which are not."""
  foo = Foo([0, 1], dict(x=1), 1)
  print(foo)

  # This construction is expected to be rejected with a ValueError.
  # NOTE(review): presumably because z=-1 violates `min_value=0` -- confirm.
  try:
    Foo(None, dict(y=1), -1)
  except ValueError as e:
    print('Expected error', e)

  # There is no runtime check for regular Python function even type annotation
  # is given.
  print(add(-1, 2.0))

  # But we can create a symbolic function which does runtime value check.
  prime_add = pg.symbolize(add)
  try:
    # The symbolized function is bound to its arguments first, then invoked.
    prime_add(-1, 2.0)()
  except ValueError as e:
    print('Expected error', e)


if __name__ == '__main__':
  main()
21 | """ 22 | 23 | # NOTE(daiyip): We disable bad-import-order to preserve the relation of 24 | # imported symbols 25 | # pylint: disable=g-bad-import-order 26 | # pylint: disable=unused-import 27 | # pylint: disable=reimported 28 | # pylint: disable=g-import-not-at-top 29 | 30 | from pyglove.core import * 31 | from pyglove.ext import * 32 | 33 | # Placeholder for Google-internal imports. 34 | 35 | # pylint: enable=g-import-not-at-top 36 | # pylint: enable=reimported 37 | # pylint: enable=unused-import 38 | # pylint: enable=g-bad-import-order 39 | 40 | __version__ = "0.4.5" 41 | -------------------------------------------------------------------------------- /pyglove/core/coding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # pylint: disable=line-too-long 15 | """Code generation utilities.""" 16 | 17 | # pylint: enable=line-too-long 18 | # pylint: disable=g-bad-import-order 19 | # pylint: disable=g-importing-member 20 | 21 | from pyglove.core.coding.errors import CodeError 22 | from pyglove.core.coding.errors import SerializationError 23 | 24 | from pyglove.core.coding.permissions import CodePermission 25 | from pyglove.core.coding.permissions import permission 26 | from pyglove.core.coding.permissions import get_permission 27 | 28 | from pyglove.core.coding.parsing import parse 29 | 30 | from pyglove.core.coding.execution import context 31 | from pyglove.core.coding.execution import get_context 32 | from pyglove.core.coding.execution import evaluate 33 | from pyglove.core.coding.execution import sandbox_call 34 | from pyglove.core.coding.execution import maybe_sandbox_call 35 | from pyglove.core.coding.execution import run 36 | 37 | from pyglove.core.coding.function_generation import NO_TYPE_ANNOTATION 38 | from pyglove.core.coding.function_generation import make_function 39 | 40 | # pylint: disable=line-too-long 41 | # pylint: enable=g-bad-import-order 42 | # pylint: enable=g-importing-member 43 | -------------------------------------------------------------------------------- /pyglove/core/coding/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class CodeError(RuntimeError):
  """Python code error.

  Wraps the exception raised while parsing or running a piece of source code,
  and records which lines of that code the error points at.

  Attributes:
    code: The source code that failed.
    cause: The exception that was raised for `code`.
    lineno: 1-based first line of the offending code, or None when unknown.
    end_lineno: 1-based last line of the offending code, or None when unknown.
  """

  def __init__(
      self,
      code: str,
      cause: BaseException,
      ):
    """Initializes the error and locates the offending lines.

    Args:
      code: The source code that failed.
      cause: The exception raised for `code`.
    """
    self.code = code
    self.cause = cause

    # Figure out the starting and ending line numbers of the erratic code.
    lineno = None
    end_lineno = None
    if isinstance(cause, SyntaxError):
      lineno = cause.lineno
      # `end_lineno` is absent on Python 3.9 and below, and may be None on
      # newer versions; fall back to `lineno` so the excerpt stays one line.
      end_lineno = getattr(cause, 'end_lineno', None) or lineno
    elif not isinstance(cause, TimeoutError):
      # Prefer the traceback attached to `cause`; fall back to the exception
      # currently being handled.
      tb = cause.__traceback__ or sys.exc_info()[2]
      for frame in traceback.extract_tb(tb, limit=5):
        # Code run from a string has no real file: its frames are reported
        # with the '<string>' pseudo filename.
        if not frame.filename or frame.filename == '<string>':
          lineno = frame.lineno
          end_lineno = lineno
          break
    self.lineno = lineno
    self.end_lineno = end_lineno

  def __str__(self):
    return self.format(include_complete_code=True)

  def code_lines(self, start_line: int, end_line: int):
    """Returns the code lines in the 0-based range [start_line, end_line)."""
    return '\n'.join(self.code.split('\n')[start_line:end_line])

  def format(self, include_complete_code: bool = True):
    """Formats the code error.

    Args:
      include_complete_code: If True, appends the complete code in a fenced
        block after the error message and the offending lines.

    Returns:
      The colored, human-readable error report.
    """
    r = io.StringIO()
    error_message = str(self.cause).rstrip()
    # Only add the location when the cause's message does not mention a line.
    if 'line' not in error_message and self.lineno is not None:
      error_message += f' (<unknown>, line {self.lineno})'
    r.write(
        utils.colored(
            f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))

    if self.lineno is not None:
      r.write('\n\n')
      # `lineno` is 1-based and inclusive, hence the `- 1` for slicing.
      r.write(textwrap.indent(
          utils.colored(
              self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
          ' ' * 2
      ))
      r.write('\n')

    if include_complete_code:
      r.write('\n')
      r.write(utils.colored('[Code]', 'green', styles=['bold']))
      r.write('\n\n')
      r.write(utils.colored('  ```python\n', 'green'))
      r.write(textwrap.indent(
          utils.colored(self.code, 'green'),
          ' ' * 2
      ))
      r.write(utils.colored('\n  ```\n', 'green'))
    return r.getvalue()


class SerializationError(RuntimeError):
  """Object serialization error.

  Attributes:
    message: Optional description shown before the cause.
    cause: The exception that made serialization fail.
  """

  def __init__(self, message: Optional[str], cause: BaseException):
    self.message = message
    self.cause = cause

  def __str__(self):
    """Returns the optional message followed by '<CauseType>: <cause text>'."""
    r = io.StringIO()
    cause_message = str(self.cause).rstrip()
    if self.message:
      r.write(utils.colored(self.message, 'magenta'))
      r.write('\n\n')
    r.write(
        utils.colored(
            f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
        )
    )
    return r.getvalue()
def code_error(code: str) -> errors.CodeError:
  """Runs `code` (after `inspect.cleandoc`) and returns the CodeError raised."""
  try:
    execution.run(inspect.cleandoc(code), timeout=2)
    # Reaching this line means the code unexpectedly ran without error.
    assert False, 'should not reach here'
  except errors.CodeError as e:
    return e


class CodeErrorsTest(unittest.TestCase):
  """Tests for `errors.CodeError`."""

  def test_format(self):
    # The '[Code]' section appears only when the complete code is included.
    e = code_error(
        """
        x = y + 1
        """
    )
    self.assertIn('[Code]', str(e))
    self.assertNotIn(
        '[Code]', e.format(include_complete_code=False))

  def test_lineno(self):
    # Runtime error (undefined name) on the first line.
    self.assertEqual(
        code_error(
            """
            x = y + 1
            """
        ).lineno, 1)
    # Syntax error on the second line.
    self.assertEqual(
        code_error(
            """
            x = 1
            for i of x:
              y = i
            """
        ).lineno, 2)
    # Explicitly raised exception on the third line.
    self.assertEqual(
        code_error(
            """
            x = 1
            y = 2
            raise ValueError
            """
        ).lineno, 3)

  def test_lineno_in_error_message(self):
    def assert_lineno(code):
      # The short format must mention a line whatever the cause's message is.
      e = code_error(code)
      self.assertIn('line', e.format(include_complete_code=False))

    assert_lineno(
        """
        x = y + 1
        """
    )
    assert_lineno(
        """
        x = 1
        y = 2
        """
    )
    assert_lineno(
        """
        raise ValueError()
        """
    )


class SerializationErrorTest(unittest.TestCase):
  """Tests for `errors.SerializationError`."""

  def test_str(self):
    # Both the message and the '<CauseType>: <text>' part must be rendered.
    e = errors.SerializationError(
        'Output cannot be serialized.', ValueError('abc'))
    self.assertIn('Output cannot be serialized', str(e))
    self.assertIn('ValueError: abc', str(e))


if __name__ == '__main__':
  unittest.main()
class _NoTypeAnnotation:
  """Placeholder for no type annotation."""


NO_TYPE_ANNOTATION = _NoTypeAnnotation()


def make_function(
    name: str,
    args: List[str],
    body: List[str],
    *,
    exec_globals: Optional[Dict[str, Any]] = None,
    exec_locals: Optional[Dict[str, Any]] = None,
    return_type: Any = NO_TYPE_ANNOTATION):
  """Creates a function dynamically from source.

  The function is defined inside a generated `_make_fn` factory whose
  parameters are the names in `exec_locals`, so the argument annotations,
  default values and body of the new function can refer to them.

  Args:
    name: Name of the function to create.
    args: Source text of each parameter, e.g. `'y: int = 0'`.
    body: Source lines of the function body, without indentation.
    exec_globals: Globals passed to `exec` for the generated source.
    exec_locals: Name-to-value mapping made visible to the new function. It is
      not modified.
    return_type: Return annotation for the new function. When left as
      `NO_TYPE_ANNOTATION`, no return annotation is emitted.

  Returns:
    The newly created function.
  """
  # Work on a copy so the caller's mapping never gains `_return_type`.
  closure_vars = dict(exec_locals) if exec_locals else {}
  # The sentinel is compared by identity so that return types with a custom
  # `__ne__` cannot affect the check.
  if return_type is not NO_TYPE_ANNOTATION:
    closure_vars['_return_type'] = return_type
    return_annotation = '->_return_type'
  else:
    return_annotation = ''
  arg_list = ', '.join(args)
  fn_body = '\n'.join(f'  {line}' for line in body)
  fn_def = f' def {name}({arg_list}){return_annotation}:\n{fn_body}'
  factory_params = ', '.join(closure_vars.keys())
  fn_def = f'def _make_fn({factory_params}):\n{fn_def}\n return {name}'
  ns = {}
  exec(fn_def, exec_globals, ns)  # pylint: disable=exec-used
  return ns['_make_fn'](**closure_vars)
class MakeFunctionTest(unittest.TestCase):
  """Tests for function_generation.make_function."""

  def test_make_function_with_type_annotations(self):
    # `typing` is supplied through `exec_locals` so that the argument
    # annotations given as source text can refer to it.
    func = function_generation.make_function(
        'foo',
        ['x: typing.Optional[int]', 'y: int = 0'],
        ['return x + y'],
        exec_globals=None,
        exec_locals={'typing': typing},
        return_type=int)

    signature = inspect.signature(func)
    self.assertEqual(list(signature.parameters.keys()), ['x', 'y'])
    self.assertEqual(signature.parameters['x'].annotation, typing.Optional[int])
    self.assertEqual(signature.parameters['y'].annotation, int)
    self.assertEqual(signature.parameters['y'].default, 0)
    self.assertIs(signature.return_annotation, int)
    self.assertEqual(func(1, 2), 3)

  def test_make_function_without_type_annotations(self):
    # Without annotations or defaults, every signature slot must be empty.
    func = function_generation.make_function(
        'foo',
        ['x', 'y'],
        ['return x + y'])
    signature = inspect.signature(func)
    self.assertEqual(list(signature.parameters.keys()), ['x', 'y'])
    self.assertEqual(
        signature.parameters['x'].annotation, inspect.Signature.empty)
    self.assertEqual(
        signature.parameters['y'].annotation, inspect.Signature.empty)
    self.assertEqual(signature.parameters['y'].default, inspect.Signature.empty)
    self.assertIs(signature.return_annotation, inspect.Signature.empty)
    self.assertEqual(func(1, 2), 3)


if __name__ == '__main__':
  unittest.main()
48 | IMPORT = enum.auto() 49 | 50 | @classmethod 51 | @property 52 | def BASIC(cls) -> 'CodePermission': # pylint: disable=invalid-name 53 | """Returns basic permissions.""" 54 | return CodePermission.ASSIGN | CodePermission.CALL 55 | 56 | @classmethod 57 | @property 58 | def ALL(cls) -> 'CodePermission': # pylint: disable=invalid-name 59 | """Returns all permissions.""" 60 | return ( 61 | CodePermission.BASIC | CodePermission.CONDITION | CodePermission.LOOP | 62 | CodePermission.EXCEPTION | CodePermission.CLASS_DEFINITION | 63 | CodePermission.FUNCTION_DEFINITION | CodePermission.IMPORT) 64 | 65 | 66 | _TLS_CODE_RUN_PERMISSION = '__code_run_permission__' 67 | 68 | 69 | @contextlib.contextmanager 70 | def permission(perm: CodePermission): 71 | """Context manager for controlling the permission for code execution. 72 | 73 | When the `permission` context manager is nested, the outermost permission 74 | will be used. This design allows users to control permission at the top level. 75 | 76 | Args: 77 | perm: Code execution permission. 78 | 79 | Yields: 80 | Actual permission applied. 
81 | """ 82 | 83 | outter_perm = utils.thread_local_get(_TLS_CODE_RUN_PERMISSION, None) 84 | 85 | # Use the top-level permission as the actual permission 86 | if outter_perm is not None: 87 | perm = outter_perm 88 | 89 | utils.thread_local_set(_TLS_CODE_RUN_PERMISSION, perm) 90 | 91 | try: 92 | yield perm 93 | finally: 94 | if outter_perm is None: 95 | utils.thread_local_del(_TLS_CODE_RUN_PERMISSION) 96 | 97 | 98 | def get_permission() -> Optional[CodePermission]: 99 | """Gets the current permission for code execution.""" 100 | return utils.thread_local_get(_TLS_CODE_RUN_PERMISSION, None) 101 | -------------------------------------------------------------------------------- /pyglove/core/coding/permissions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import unittest 15 | from pyglove.core.coding import permissions 16 | 17 | 18 | class CodePermissionTest(unittest.TestCase): 19 | 20 | def assert_set( 21 | self, 22 | permission: permissions.CodePermission, 23 | flag: permissions.CodePermission, 24 | ): 25 | self.assertEqual(permission & flag, flag) 26 | 27 | def assert_not_set( 28 | self, 29 | permission: permissions.CodePermission, 30 | flag: permissions.CodePermission, 31 | ): 32 | self.assertFalse(permission & flag) 33 | 34 | def test_basic(self): 35 | self.assert_set( 36 | permissions.CodePermission.BASIC, permissions.CodePermission.ASSIGN 37 | ) 38 | self.assert_set( 39 | permissions.CodePermission.BASIC, permissions.CodePermission.CALL 40 | ) 41 | self.assert_set( 42 | permissions.CodePermission.BASIC, permissions.CodePermission.CALL 43 | ) 44 | 45 | def test_all(self): 46 | self.assert_set( 47 | permissions.CodePermission.ALL, permissions.CodePermission.BASIC 48 | ) 49 | self.assert_set( 50 | permissions.CodePermission.ALL, permissions.CodePermission.CONDITION 51 | ) 52 | self.assert_set( 53 | permissions.CodePermission.ALL, permissions.CodePermission.LOOP 54 | ) 55 | self.assert_set( 56 | permissions.CodePermission.ALL, permissions.CodePermission.EXCEPTION 57 | ) 58 | self.assert_set( 59 | permissions.CodePermission.ALL, 60 | permissions.CodePermission.CLASS_DEFINITION, 61 | ) 62 | self.assert_set( 63 | permissions.CodePermission.ALL, 64 | permissions.CodePermission.FUNCTION_DEFINITION, 65 | ) 66 | self.assert_set( 67 | permissions.CodePermission.ALL, permissions.CodePermission.IMPORT 68 | ) 69 | 70 | def test_xor(self): 71 | self.assert_not_set( 72 | permissions.CodePermission.ALL ^ permissions.CodePermission.BASIC, 73 | permissions.CodePermission.BASIC, 74 | ) 75 | self.assert_set( 76 | permissions.CodePermission.ALL ^ permissions.CodePermission.BASIC, 77 | permissions.CodePermission.CONDITION, 78 | ) 79 | 80 | def test_permission_control(self): 81 | self.assertIsNone(permissions.get_permission()) 
82 | with permissions.permission(permissions.CodePermission.BASIC): 83 | self.assertEqual( 84 | permissions.get_permission(), permissions.CodePermission.BASIC 85 | ) 86 | with permissions.permission(permissions.CodePermission.ALL): 87 | self.assertEqual( 88 | permissions.get_permission(), permissions.CodePermission.BASIC 89 | ) 90 | 91 | 92 | if __name__ == '__main__': 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /pyglove/core/detouring/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Symbolic detour. 15 | 16 | It is straightforward to symbolize existing classes and functions, but in order 17 | to use them, we need to replace the classes that were used in existing code 18 | with the symbolic ones. Sometimes, we just cannot modify existing source code. 19 | Or in some other cases, objects created within a function or a class method are 20 | not exposed to the external, therefore we cannot manipulate them as a part of 21 | the symbolic tree. For example:: 22 | 23 | @pg.symbolize 24 | def foo(): 25 | # Object `a` is not a part of `foo`'s interface, 26 | # therefore it cannot be seen from the symbolic tree 27 | # that contains a `foo` object. 
28 | a = A(1) 29 | return a.do_something() 30 | 31 | Symbolic detour is introduced to address these use cases, which redirects 32 | the ``__new__`` method of a class to another class or function when it’s 33 | evaluated under a context manager. Symbolic detour is not dependent on 34 | symbolization, so in theory it can be used for detouring any classes. 35 | Therefore, it does not require the presence of symbolic objects for mutating 36 | the program. 37 | """ 38 | 39 | # pylint: disable=g-bad-import-order 40 | 41 | from pyglove.core.detouring.class_detour import detour 42 | from pyglove.core.detouring.class_detour import current_mappings 43 | from pyglove.core.detouring.class_detour import undetoured_new 44 | 45 | # pylint: enable=g-bad-import-order 46 | -------------------------------------------------------------------------------- /pyglove/core/geno/random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Random DNA generator.""" 15 | 16 | import random 17 | import types 18 | from typing import Any, Optional, Union 19 | 20 | from pyglove.core import symbolic 21 | from pyglove.core import typing as pg_typing 22 | from pyglove.core.geno.base import DNA 23 | from pyglove.core.geno.base import DNASpec 24 | from pyglove.core.geno.dna_generator import DNAGenerator 25 | 26 | 27 | @symbolic.members([ 28 | ('seed', pg_typing.Int().noneable(), 'Random seed.') 29 | ]) 30 | class Random(DNAGenerator): 31 | """Random DNA generator.""" 32 | 33 | def _setup(self): 34 | """Setup DNA spec.""" 35 | if self.seed is None: 36 | self._random = random 37 | else: 38 | self._random = random.Random(self.seed) 39 | 40 | def _propose(self) -> DNA: 41 | """Propose a random DNA.""" 42 | return random_dna(self._dna_spec, self._random) 43 | 44 | def _replay(self, trial_id: int, dna: DNA, reward: Any) -> None: 45 | """Replay the history to recover the last proposed DNA.""" 46 | # NOTE(daiyip): If the seed is fixed, we want to reproduce the same 47 | # sequence of random examples, we can do this simply by repeating previous 48 | # generation process. 49 | if self.seed is not None: 50 | random_dna(self._dna_spec, self._random) 51 | 52 | 53 | def random_dna( 54 | dna_spec: DNASpec, 55 | random_generator: Union[None, types.ModuleType, random.Random] = None, 56 | attach_spec: bool = True, 57 | previous_dna: Optional[DNA] = None 58 | ) -> DNA: 59 | """Generates a random DNA from a DNASpec. 60 | 61 | Example:: 62 | 63 | spec = pg.geno.space([ 64 | pg.geno.oneof([ 65 | pg.geno.constant(), 66 | pg.geno.constant(), 67 | pg.geno.constant() 68 | ]), 69 | pg.geno.floatv(0.1, 0.2) 70 | ]) 71 | 72 | print(pg.random_dna(spec)) 73 | # DNA([2, 0.1123]) 74 | 75 | Args: 76 | dna_spec: a DNASpec object. 77 | random_generator: a Python random generator. 78 | attach_spec: If True, attach the DNASpec to generated DNA. 79 | previous_dna: An optional DNA representing previous DNA. 
This field might 80 | be useful for generating stateful random DNAs. 81 | 82 | Returns: 83 | A DNA object. 84 | """ 85 | return dna_spec.random_dna( 86 | random_generator or random, attach_spec, previous_dna) 87 | 88 | -------------------------------------------------------------------------------- /pyglove/core/geno/sweeping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Sweeping DNA generator.""" 15 | 16 | from typing import Any 17 | 18 | from pyglove.core.geno.base import DNA 19 | from pyglove.core.geno.dna_generator import DNAGenerator 20 | 21 | 22 | class Sweeping(DNAGenerator): 23 | """Sweeping (Grid Search) DNA generator.""" 24 | 25 | def _setup(self): 26 | """Setup DNA spec.""" 27 | self._last_proposed_dna = None 28 | 29 | def _propose(self) -> DNA: 30 | """Propose the next DNA following the last proposed one.""" 31 | next_dna = self.dna_spec.next_dna(self._last_proposed_dna) 32 | if next_dna is None: 33 | raise StopIteration() 34 | self._last_proposed_dna = next_dna 35 | return next_dna 36 | 37 | def _replay(self, trial_id: int, dna: DNA, reward: Any) -> None: 38 | """Replay the history to recover the last proposed DNA.""" 39 | del trial_id, reward 40 | self._last_proposed_dna = dna 41 | -------------------------------------------------------------------------------- /pyglove/core/geno/sweeping_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for pyglove.geno.Sweeping.""" 15 | 16 | import unittest 17 | 18 | from pyglove.core.geno.base import DNA 19 | from pyglove.core.geno.categorical import manyof 20 | from pyglove.core.geno.categorical import oneof 21 | from pyglove.core.geno.space import constant 22 | from pyglove.core.geno.space import Space 23 | from pyglove.core.geno.sweeping import Sweeping 24 | 25 | 26 | class SweepingTest(unittest.TestCase): 27 | """Test the `pg.geno.Sweeping`.""" 28 | 29 | def _dna_spec(self): 30 | return Space([ 31 | # Single choice. 32 | oneof([ 33 | manyof(2, [constant(), constant(), constant()]), 34 | oneof([constant(), constant()]) 35 | ]), 36 | manyof(2, [constant(), constant(), constant()], sorted=True), 37 | ]) 38 | 39 | def test_propose(self): 40 | algo = Sweeping() 41 | algo.setup(self._dna_spec()) 42 | results = [] 43 | while True: 44 | try: 45 | results.append(algo.propose()) 46 | except StopIteration: 47 | break 48 | 49 | self.assertEqual(results, [ 50 | DNA([(0, [0, 1]), [0, 1]]), 51 | DNA([(0, [0, 1]), [0, 2]]), 52 | DNA([(0, [0, 1]), [1, 2]]), 53 | DNA([(0, [0, 2]), [0, 1]]), 54 | DNA([(0, [0, 2]), [0, 2]]), 55 | DNA([(0, [0, 2]), [1, 2]]), 56 | DNA([(0, [1, 0]), [0, 1]]), 57 | DNA([(0, [1, 0]), [0, 2]]), 58 | DNA([(0, [1, 0]), [1, 2]]), 59 | DNA([(0, [1, 2]), [0, 1]]), 60 | DNA([(0, [1, 2]), [0, 2]]), 61 | DNA([(0, [1, 2]), [1, 2]]), 62 | DNA([(0, [2, 0]), [0, 1]]), 63 | DNA([(0, [2, 0]), [0, 2]]), 64 | DNA([(0, [2, 0]), [1, 2]]), 65 | DNA([(0, [2, 1]), [0, 1]]), 66 | DNA([(0, [2, 1]), [0, 2]]), 67 | DNA([(0, [2, 1]), [1, 2]]), 68 | DNA([(1, 0), [0, 1]]), 69 | DNA([(1, 0), [0, 2]]), 70 | DNA([(1, 0), [1, 2]]), 71 | DNA([(1, 1), [0, 1]]), 72 | DNA([(1, 1), [0, 2]]), 73 | DNA([(1, 1), [1, 2]]) 74 | ]) 75 | 76 | def test_recover(self): 77 | algo1 = Sweeping() 78 | algo1.setup(self._dna_spec()) 79 | dna_list = [algo1.propose() for _ in range(10)] 80 | algo2 = algo1.clone(deep=True) 81 | algo2.setup(self._dna_spec()) 82 | algo2.recover([(dna, 0.) 
for dna in dna_list]) 83 | self.assertEqual(algo1.propose(), algo2.propose()) 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /pyglove/core/hyper/custom_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import random 15 | import unittest 16 | 17 | from pyglove.core import geno 18 | from pyglove.core import symbolic 19 | from pyglove.core import utils 20 | from pyglove.core.hyper.categorical import oneof 21 | from pyglove.core.hyper.custom import CustomHyper 22 | from pyglove.core.hyper.iter import iterate 23 | from pyglove.core.hyper.object_template import materialize 24 | 25 | 26 | class IntSequence(CustomHyper): 27 | 28 | def custom_decode(self, dna): 29 | return [int(v) for v in dna.value.split(',')] 30 | 31 | 32 | class IntSequenceWithEncode(IntSequence): 33 | 34 | def custom_encode(self, value): 35 | return geno.DNA(','.join([str(v) for v in value])) 36 | 37 | def next_dna(self, dna): 38 | if dna is None: 39 | return geno.DNA(','.join([str(i) for i in range(5)])) 40 | v = self.custom_decode(dna) 41 | v.append(len(v)) 42 | return self._create_dna(v) 43 | 44 | def random_dna(self, random_generator, previous_dna): 45 | del previous_dna 46 | k = random_generator.randint(0, 10) 47 | v = 
random_generator.choices(list(range(10)), k=k) 48 | return self._create_dna(v) 49 | 50 | def _create_dna(self, numbers): 51 | return geno.DNA(','.join([str(n) for n in numbers])) 52 | 53 | 54 | class CustomHyperTest(unittest.TestCase): 55 | """Test for CustomHyper.""" 56 | 57 | def test_dna_spec(self): 58 | self.assertTrue( 59 | symbolic.eq( 60 | IntSequence(hints='x').dna_spec('a'), 61 | geno.CustomDecisionPoint( 62 | hyper_type='IntSequence', location=utils.KeyPath('a'), hints='x' 63 | ), 64 | ) 65 | ) 66 | 67 | def test_decode(self): 68 | self.assertEqual(IntSequence().decode(geno.DNA('0,1,2')), [0, 1, 2]) 69 | self.assertEqual(IntSequence().decode(geno.DNA('0')), [0]) 70 | with self.assertRaisesRegex(ValueError, '.* expects string type DNA'): 71 | IntSequence().decode(geno.DNA(1)) 72 | 73 | def test_encode(self): 74 | self.assertEqual( 75 | IntSequenceWithEncode().encode([0, 1, 2]), geno.DNA('0,1,2')) 76 | 77 | with self.assertRaisesRegex( 78 | NotImplementedError, '\'custom_encode\' is not supported by'): 79 | _ = IntSequence().encode([0, 1, 2]) 80 | 81 | def test_random_dna(self): 82 | self.assertEqual( 83 | geno.random_dna( 84 | IntSequenceWithEncode().dna_spec('a'), random.Random(1)), 85 | geno.DNA('5,8')) 86 | 87 | with self.assertRaisesRegex( 88 | NotImplementedError, '`random_dna` is not implemented in .*'): 89 | geno.random_dna(IntSequence().dna_spec('a')) 90 | 91 | def test_iter(self): 92 | self.assertEqual(IntSequenceWithEncode().first_dna(), geno.DNA('0,1,2,3,4')) 93 | self.assertEqual( 94 | list(iterate(IntSequenceWithEncode(), 3)), 95 | [[0, 1, 2, 3, 4], 96 | [0, 1, 2, 3, 4, 5], 97 | [0, 1, 2, 3, 4, 5, 6]]) 98 | 99 | with self.assertRaisesRegex( 100 | NotImplementedError, '`next_dna` is not implemented in .*'): 101 | next(iterate(IntSequence())) 102 | 103 | def test_interop_with_other_primitives(self): 104 | v = oneof([IntSequence(), 1, 2]) 105 | self.assertEqual(materialize(v, geno.DNA(1)), 1) 106 | self.assertEqual(materialize(v, geno.DNA((0, 
'3,4'))), [3, 4]) 107 | 108 | 109 | if __name__ == '__main__': 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /pyglove/core/io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Pluggable IO with unified interfaces.""" 15 | 16 | # pylint: disable=g-bad-import-order 17 | 18 | from pyglove.core.io.file_system import * 19 | from pyglove.core.io.sequence import * 20 | 21 | # pylint: enable=g-bad-import-order 22 | -------------------------------------------------------------------------------- /pyglove/core/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Logging for PyGlove. 15 | 16 | This module allows PyGlove to use an externally created logger for logging PyGlove 17 | events without introducing library dependencies in PyGlove. 18 | """ 19 | 20 | import inspect 21 | import logging 22 | from typing import Any, Callable, List, Union 23 | 24 | 25 | _DEFAULT_LOGGER = logging.getLogger() 26 | 27 | 28 | def register_frame_to_skip( 29 | method: Union[Callable[..., Any], List[Callable[..., Any]]] 30 | ) -> bool: 31 | """Skips the source of the given method when logging. 32 | 33 | Args: 34 | method: The method to skip. Can be a single method or a list of methods. 35 | 36 | Returns: 37 | True if the method is registered to skip. 38 | 39 | Raises: 40 | TypeError: The source file of the method cannot be inspected. 41 | """ 42 | register_fn = getattr( 43 | _DEFAULT_LOGGER.__class__, 'register_frame_to_skip', None 44 | ) 45 | if register_fn is None: 46 | return False 47 | methods = [method] if not isinstance(method, list) else method 48 | for m in methods: 49 | register_fn(inspect.getsourcefile(m), m.__name__) 50 | return True 51 | 52 | 53 | def set_logger(logger: logging.Logger) -> None: 54 | """Sets current logger.""" 55 | global _DEFAULT_LOGGER 56 | _DEFAULT_LOGGER = logger 57 | 58 | # Skip logging frames in pyglove.logging. 59 | register_frame_to_skip([debug, info, warning, error, critical]) 60 | 61 | 62 | def get_logger() -> logging.Logger: 63 | """Gets the current logger.""" 64 | return _DEFAULT_LOGGER 65 | 66 | 67 | def debug(msg: str, *args, **kwargs) -> None: 68 | """Logs debug message. 69 | 70 | Args: 71 | msg: Message with possible format string. 72 | *args: Values for variables in the format string. 73 | **kwargs: Keyword arguments for the logger. 74 | """ 75 | _DEFAULT_LOGGER.debug(msg, *args, **kwargs) 76 | 77 | 78 | def info(msg: str, *args, **kwargs) -> None: 79 | """Logs info message. 80 | 81 | Args: 82 | msg: Message with possible format string. 83 | *args: Values for variables in the format string. 
84 | **kwargs: Keyword arguments for the logger. 85 | """ 86 | _DEFAULT_LOGGER.info(msg, *args, **kwargs) 87 | 88 | 89 | def warning(msg: str, *args, **kwargs) -> None: 90 | """Logs warning message. 91 | 92 | Args: 93 | msg: Message with possible format string. 94 | *args: Values for variables in the format string. 95 | **kwargs: Keyword arguments for the logger. 96 | """ 97 | _DEFAULT_LOGGER.warning(msg, *args, **kwargs) 98 | 99 | 100 | def error(msg: str, *args, **kwargs) -> None: 101 | """Logs error message. 102 | 103 | Args: 104 | msg: Message with possible format string. 105 | *args: Values for variables in the format string. 106 | **kwargs: Keyword arguments for the logger. 107 | """ 108 | _DEFAULT_LOGGER.error(msg, *args, **kwargs) 109 | 110 | 111 | def critical(msg: str, *args, **kwargs) -> None: 112 | """Logs critical message. 113 | 114 | Args: 115 | msg: Message with possible format string. 116 | *args: Values for variables in the format string. 117 | **kwargs: Keyword arguments for the logger. 118 | """ 119 | _DEFAULT_LOGGER.critical(msg, *args, **kwargs) 120 | -------------------------------------------------------------------------------- /pyglove/core/logging_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import io 15 | import logging 16 | import unittest 17 | from pyglove.core import logging as pg_logging 18 | 19 | 20 | class LoggingTest(unittest.TestCase): 21 | """Tests for pg.logging.""" 22 | 23 | def testLogging(self): 24 | string_io = io.StringIO() 25 | logger = logging.getLogger('logger1') 26 | logger.setLevel(logging.INFO) 27 | console_handler = logging.StreamHandler(stream=string_io) 28 | console_handler.setLevel(logging.INFO) 29 | logger.addHandler(console_handler) 30 | 31 | self.assertIs(pg_logging.get_logger(), logging.getLogger()) 32 | pg_logging.set_logger(logger) 33 | self.assertIs(pg_logging.get_logger(), logger) 34 | 35 | pg_logging.debug('x=%s', 1) 36 | pg_logging.info('y=%s', 2) 37 | pg_logging.warning('z=%s', 3) 38 | pg_logging.error('p=%s', 4) 39 | pg_logging.critical('q=%s', 5) 40 | 41 | self.assertEqual(string_io.getvalue(), '\n'.join([ 42 | 'y=2', 43 | 'z=3', 44 | 'p=4', 45 | 'q=5', 46 | ]) + '\n') 47 | 48 | string_io = io.StringIO() 49 | logger = logging.getLogger('logger2') 50 | logger.setLevel(logging.DEBUG) 51 | console_handler = logging.StreamHandler(stream=string_io) 52 | console_handler.setLevel(logging.DEBUG) 53 | logger.addHandler(console_handler) 54 | 55 | pg_logging.set_logger(logger) 56 | 57 | pg_logging.debug('x=%s', 6) 58 | self.assertEqual(string_io.getvalue(), '\n'.join([ 59 | 'x=6', 60 | ]) + '\n') 61 | 62 | 63 | if __name__ == '__main__': 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /pyglove/core/patching/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Systematic patching on symbolic values. 15 | 16 | As :meth:`pyglove.Symbolic.rebind` provides a flexible programming 17 | interface for modifying symbolic values, why bother to have this module? 18 | Here are the motivations: 19 | 20 | * Provide user friendly methods for addressing the most common patching 21 | patterns. 22 | 23 | * Provide a systematic solution for 24 | 25 | * Patch semantic groups. 26 | * Enable combination of these groups. 27 | * Provide an interface that patching can be invoked from the command line. 28 | """ 29 | 30 | # pylint: disable=g-bad-import-order 31 | 32 | # Pattern-based patching. 33 | 34 | from pyglove.core.patching.pattern_based import patch_on_key 35 | from pyglove.core.patching.pattern_based import patch_on_path 36 | from pyglove.core.patching.pattern_based import patch_on_type 37 | from pyglove.core.patching.pattern_based import patch_on_value 38 | from pyglove.core.patching.pattern_based import patch_on_member 39 | 40 | # Patcher: modular rule-based patching. 41 | from pyglove.core.patching.rule_based import patcher 42 | from pyglove.core.patching.rule_based import patch 43 | 44 | from pyglove.core.patching.rule_based import Patcher 45 | from pyglove.core.patching.rule_based import from_uri 46 | 47 | from pyglove.core.patching.rule_based import patcher_names 48 | from pyglove.core.patching.rule_based import allow_repeated_patcher_registration 49 | 50 | # Object factory based on patchers. 
51 | from pyglove.core.patching.object_factory import ObjectFactory 52 | from pyglove.core.patching.object_factory import from_maybe_serialized 53 | 54 | 55 | # pylint: enable=g-bad-import-order 56 | -------------------------------------------------------------------------------- /pyglove/core/patching/object_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Object factory based on patchers.""" 15 | 16 | from typing import Any, Callable, Dict, Optional, Type, Union 17 | from pyglove.core import symbolic 18 | from pyglove.core import utils 19 | from pyglove.core.patching import rule_based 20 | 21 | 22 | def from_maybe_serialized( 23 | source: Union[Any, str], 24 | value_type: Optional[Type[Any]] = None) -> Any: 25 | """Load value from maybe serialized form (e.g. JSON file or JSON string). 26 | 27 | Args: 28 | source: Source of value. It can be value (non-string type) itself, or a 29 | filepath, or a JSON string from where the value will be loaded. 30 | value_type: An optional type to constrain the value. 31 | 32 | Returns: 33 | Value from source. 
34 | """ 35 | if isinstance(source, str): 36 | if source.endswith('.json'): 37 | value = symbolic.load(source) 38 | else: 39 | value = symbolic.from_json_str(source) 40 | else: 41 | value = source 42 | if value_type is not None and not isinstance(value, value_type): 43 | raise TypeError( 44 | f'Loaded value {value!r} is not an instance of {value_type!r}.') 45 | return value 46 | 47 | 48 | @symbolic.functor() 49 | def ObjectFactory( # pylint: disable=invalid-name 50 | value_type: Type[symbolic.Symbolic], 51 | base_value: Union[symbolic.Symbolic, 52 | Callable[[], symbolic.Symbolic], 53 | str], 54 | patches: Optional[rule_based.PatchType] = None, 55 | params_override: Optional[Union[Dict[str, Any], str]] = None) -> Any: 56 | """A factory to create symbolic object from a base value and patches. 57 | 58 | Args: 59 | value_type: Type of return value. 60 | base_value: An instance of `value_type`, 61 | or a callable object that produces an instance of `value_type`, 62 | or a string as the path to the serialized value. 63 | patches: Optional patching rules. See :func:`patch` for details. 64 | params_override: A rebind dict (or a JSON string as serialized rebind dict) 65 | as an additional patch to the value. 66 | 67 | Returns: 68 | Value after applying `patches` and `params_override` based on `base_value`. 69 | """ 70 | # Step 1: Load base value. 71 | if not isinstance(base_value, value_type) and callable(base_value): 72 | value = base_value() 73 | elif isinstance(base_value, str): 74 | value = symbolic.load(base_value) 75 | else: 76 | value = base_value 77 | 78 | if not isinstance(value, value_type): 79 | raise TypeError( 80 | f'{base_value!r} is neither an instance of {value_type!r}, ' 81 | f'nor a factory or a path of JSON file that produces an ' 82 | f'instance of {value_type!r}.') 83 | 84 | # Step 2: Patch with patchers if available. 
85 | if patches is not None: 86 | value = rule_based.patch(value, patches) 87 | 88 | # Step 3: Patch with additional parameter override dict if available. 89 | if params_override: 90 | value = value.rebind( 91 | utils.flatten(from_maybe_serialized(params_override, dict)), 92 | raise_on_no_change=False, 93 | ) 94 | return value 95 | -------------------------------------------------------------------------------- /pyglove/core/patching/pattern_based_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class PatternBasedPatchingTest(unittest.TestCase):
  """Exercises every `pattern_based.patch_on_*` helper on a small tree."""

  def test_patch_on_key(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    # A constant replaces the value under every key named 'a'.
    pattern_based.patch_on_key(tree, 'a', 3)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 3, 'b': 1}})

    # A value function receives the current value of each matched key.
    pattern_based.patch_on_key(tree, 'a', value_fn=lambda v: v + 1)
    self.assertEqual(tree, {'a': 4, 'b': {'a': 4, 'b': 1}})

    # Supplying both `value` and `value_fn` is rejected.
    with self.assertRaisesRegex(
        ValueError, 'Either `value` or `value_fn` should be specified'):
      pattern_based.patch_on_key(
          tree, 'a', value=1, value_fn=lambda v: v + 1)

  def test_patch_on_path(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    pattern_based.patch_on_path(tree, '.+b', 3)
    self.assertEqual(tree, {'a': 1, 'b': {'a': 2, 'b': 3}})

    pattern_based.patch_on_path(tree, '.*a', value_fn=lambda v: v + 1)
    self.assertEqual(tree, {'a': 2, 'b': {'a': 3, 'b': 3}})

  def test_patch_on_value(self):
    tree = symbolic.Dict(a=1, b={'a': 2, 'b': 1})

    pattern_based.patch_on_value(tree, 1, 3)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 2, 'b': 3}})

    pattern_based.patch_on_value(tree, 2, value_fn=lambda v: v * 2)
    self.assertEqual(tree, {'a': 3, 'b': {'a': 4, 'b': 3}})

  def test_patch_on_type(self):
    tree = symbolic.Dict(a='abc', b={'a': 2, 'b': 'def'})

    pattern_based.patch_on_type(tree, str, 'foo')
    self.assertEqual(tree, {'a': 'foo', 'b': {'a': 2, 'b': 'foo'}})

    pattern_based.patch_on_type(tree, int, value_fn=lambda v: v * 2)
    self.assertEqual(tree, {'a': 'foo', 'b': {'a': 4, 'b': 'foo'}})

  def test_patch_on_member(self):

    @symbolic.members([
        ('x', pg_typing.Int()),
        ('y', pg_typing.Int()),
    ])
    class A(symbolic.Object):
      pass

    # The top-level 'x' is not a member of `A` and stays untouched.
    tree = symbolic.Dict(a=A(x=1, y=2), x=1)

    pattern_based.patch_on_member(tree, A, 'x', 2)
    self.assertEqual(tree, {'a': A(x=2, y=2), 'x': 1})

    pattern_based.patch_on_member(tree, A, 'y', value_fn=lambda v: v * 2)
    self.assertEqual(tree, {'a': A(x=2, y=4), 'x': 1})
14 | 15 | import inspect 16 | import unittest 17 | from pyglove.core.symbolic import base 18 | from pyglove.core.symbolic import error_info as error_info_lib # pylint: disable=g-bad-import-order 19 | 20 | 21 | class ErrorInfoTest(unittest.TestCase): 22 | """Tests for ErrorInfo.""" 23 | 24 | def test_from_exception(self): 25 | 26 | def foo(): 27 | return 1 / 0 28 | 29 | def bar(): 30 | try: 31 | foo() 32 | except ZeroDivisionError as e: 33 | raise ValueError('Bad call to `foo`') from e 34 | 35 | error_info = None 36 | try: 37 | bar() 38 | except ValueError as e: 39 | error_info = error_info_lib.ErrorInfo.from_exception(e) 40 | self.assertIsNotNone(error_info) 41 | self.assertEqual(error_info.tag, 'ValueError.ZeroDivisionError') 42 | self.assertEqual(error_info.description, 'Bad call to `foo`') 43 | self.assertIn('Traceback (most recent call last)', error_info.stacktrace) 44 | 45 | def test_to_json(self): 46 | error_info = error_info_lib.ErrorInfo( 47 | tag='ValueError.ZeroDivisionError', 48 | description='Bad call to `foo`', 49 | stacktrace='Traceback (most recent call last)', 50 | ) 51 | json_dict = error_info.to_json() 52 | error_info2 = base.from_json(json_dict) 53 | self.assertIsNot(error_info2, error_info) 54 | self.assertEqual(error_info2, error_info) 55 | json_dict['_type'] = 'pyglove.core.utils.error_utils.ErrorInfo' 56 | error_info2 = base.from_json(json_dict) 57 | self.assertEqual(error_info2, error_info) 58 | 59 | def test_format(self): 60 | error_info = error_info_lib.ErrorInfo( 61 | tag='ValueError.ZeroDivisionError', 62 | description='Bad call to `foo`', 63 | stacktrace='Traceback (most recent call last)', 64 | ) 65 | self.assertEqual( 66 | error_info.format(compact=False), 67 | inspect.cleandoc( 68 | """ 69 | ErrorInfo( 70 | tag = 'ValueError.ZeroDivisionError', 71 | description = 'Bad call to `foo`', 72 | stacktrace = 'Traceback (most recent call last)' 73 | ) 74 | """ 75 | ) 76 | ) 77 | 78 | def test_to_html(self): 79 | error_info = 
class InferredValue(Object, base.Inferential):
  """Base class for inferred values."""

  def custom_apply(self, *args, **kwargs: Any) -> Tuple[bool, Any]:
    """Accepts this object as-is for any value spec.

    Args:
      *args: Ignored.
      **kwargs: Ignored.

    Returns:
      A tuple `(False, self)`: the first element tells the caller not to
      proceed with the standard apply logic, the second is the value to use.
    """
    # This is to make a ``InferredValue`` object assignable
    # to any symbolic attribute.
    return (False, self)


class ValueFromParentChain(InferredValue):
  """A value that could be inferred from the parent chain.

  For example::

    class A(pg.Object):
      x: int
      y: int = pg.symbolic.ValueFromParentChain()

    # Not okay: `x` is not inferential and is not specified.
    A()

    # Okay: both `x` and `y` are specified.
    A(x=1, y=2)

    # Okay: `y` is inferential, hence optional.
    a = A(x=1)

    # Raises: `y` is neither specified during __init__
    # nor provided from the context.
    a.y

    d = pg.Dict(y=2, z=pg.Dict(a=a))

    # `a.y` now refers to `d.y` since `d` is in its symbolic parent chain,
    # aka. context.
    assert a.y == 2
  """

  def infer(self, **kwargs) -> Any:
    """Returns the first value found while walking up the parent chain.

    Args:
      **kwargs: Forwarded to `value_from` for every ancestor visited.

    Returns:
      The first value returned by `value_from` that is not `MISSING_VALUE`.

    Raises:
      AttributeError: If the walk goes past the root without finding a value.
    """
    parent = self.sym_parent
    while True:
      v = self.value_from(parent, **kwargs)
      if v == pg_typing.MISSING_VALUE:
        if parent is None:
          # The walk has gone past the root: nothing provides the value.
          raise AttributeError(
              utils.message_on_path(
                  (
                      f'`{self.inference_key}` is not found under its context '
                      '(along its symbolic parent chain).'
                  ),
                  self.sym_path,
              )
          )
        parent = parent.sym_parent
      else:
        return v

  @property
  def inference_key(self) -> str:
    """Returns the key for attribute inference from parents."""
    return self.sym_path.key

  def value_from(self, parent: base.Symbolic, **kwargs) -> Any:
    """Looks up the inference key on a single ancestor.

    Args:
      parent: The ancestor to read from. May be None once the traversal has
        gone beyond the root.
      **kwargs: Unused.

    Returns:
      The value found under `inference_key` on `parent`, or `MISSING_VALUE`
      when `parent` cannot provide one.
    """
    del kwargs
    if parent is self.sym_parent or parent is None:
      # NOTE(daiyip): The inferred value could not be read from the immediate
      # parent since its key points to current inferential value.
      # We should also return MISSING_VALUE when the traversal has gone beyond
      # of the symbolic tree root.
      return pg_typing.MISSING_VALUE

    # Use current key to lookup from the parent.
    key = self.inference_key
    if isinstance(key, int):
      # An ancestor list that is too short for this index simply does not
      # provide the value; report it as missing so `infer` keeps walking up
      # (and eventually raises AttributeError) instead of leaking IndexError.
      if (isinstance(parent, (list, tuple))
          and -len(parent) <= key < len(parent)):
        return parent[key]
      return pg_typing.MISSING_VALUE
    return getattr(parent, key, pg_typing.MISSING_VALUE)
class ValueFromParentChain(unittest.TestCase):
  """Test cases covering `pg.symbolic.ValueFromParentChain`."""

  def test_inference(self):
    # The nested `y` is a placeholder; the outer dict provides the value.
    root = Dict(y=1, x=Dict(x=1, y=inferred.ValueFromParentChain()))
    self.assertEqual(root.x.y, 1)

    # Rebinding the provider is reflected in the inferred value.
    root.rebind(y=2)
    self.assertEqual(root.x.y, 2)

  def test_custom_typing(self):
    placeholder = inferred.ValueFromParentChain()
    # The placeholder passes through value specs of unrelated types untouched.
    int_applied = pg_typing.Int().apply(placeholder)
    self.assertIs(int_applied, placeholder)
    str_applied = pg_typing.Str().apply(placeholder)
    self.assertIs(str_applied, placeholder)
14 | """Tests for pyglove.symbolic.Origin.""" 15 | 16 | import unittest 17 | 18 | from pyglove.core.symbolic import flags 19 | from pyglove.core.symbolic.dict import Dict 20 | from pyglove.core.symbolic.origin import Origin 21 | 22 | 23 | class OriginTest(unittest.TestCase): 24 | """Tests for `pg.symbolic.Origin`.""" 25 | 26 | def test_basics(self): 27 | a = Dict(a=1) 28 | o = Origin(a, '__init__') 29 | self.assertIs(o.source, a) 30 | self.assertEqual(o.tag, '__init__') 31 | self.assertIsNone(o.stack) 32 | self.assertIsNone(o.stacktrace) 33 | 34 | with self.assertRaisesRegex(ValueError, '`tag` must be a string'): 35 | _ = Origin(a, 1) 36 | 37 | def test_include_stacktrace(self): 38 | a = Dict(a=1) 39 | o = Origin(a, '__init__', stacktrace=True, stacklimit=3) 40 | self.assertIs(o.source, a) 41 | self.assertEqual(o.tag, '__init__') 42 | self.assertEqual(len(o.stack), 3) 43 | self.assertIsNotNone(o.stacktrace) 44 | 45 | flags.set_origin_stacktrace_limit(2) 46 | o = Origin(a, '__init__', stacktrace=True) 47 | self.assertEqual(len(o.stack), 2) 48 | 49 | def test_root(self): 50 | a = Dict(a=1) 51 | b = Dict(b=2) 52 | c = Dict(c=3) 53 | 54 | c.sym_setorigin(b, 'foo') 55 | b.sym_setorigin(a, 'bar') 56 | self.assertIs(c.sym_origin.root.source, a) 57 | 58 | def test_history(self): 59 | a = Dict(a=1) 60 | b = Dict(b=2) 61 | c = Dict(c=3) 62 | 63 | c.sym_setorigin(b, 'foo') 64 | b.sym_setorigin(a, 'bar') 65 | self.assertEqual( 66 | c.sym_origin.history(), 67 | [ 68 | Origin(a, 'bar'), 69 | Origin(b, 'foo'), 70 | ]) 71 | 72 | self.assertEqual( 73 | c.sym_origin.history(lambda o: o.tag == 'foo'), 74 | [ 75 | Origin(b, 'foo'), 76 | ]) 77 | 78 | self.assertEqual( 79 | c.sym_origin.history(lambda o: o.tag == 'bar'), 80 | [ 81 | Origin(a, 'bar'), 82 | ]) 83 | 84 | def test_eq_ne(self): 85 | a = Dict(a=1) 86 | self.assertEqual(Origin(None, '__init__'), Origin(None, '__init__')) 87 | self.assertEqual(Origin(a, 'builder'), Origin(a, 'builder')) 88 | self.assertNotEqual(Origin(a, 
class PureSymbolic(pg_typing.CustomTyping):
  """Base class for classes whose objects are considered pure symbolic.

  A pure symbolic object represents an abstract concept - for example, a
  search space of objects - which is solely representational and cannot be
  executed.

  Pure symbolic objects are a key differentiator of symbolic OOP from regular
  OOP: they placehold values in an object as a high-level expression of ideas.
  Later, through symbolic manipulation, they are replaced with material values
  so the object can be evaluated. This decouples the expression of ideas from
  their implementation. For example, ``pg.oneof(['a', 'b', 'c'])`` will be
  manipulated into 'a', 'b' or 'c' based on the decision of a search
  algorithm, letting the program evolve itself.
  """

  def custom_apply(
      self,
      path: utils.KeyPath,
      value_spec: pg_typing.ValueSpec,
      allow_partial: bool,
      child_transform: Optional[
          Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
      ] = None,
  ) -> Tuple[bool, Any]:
    """Applies this value to a field regardless of the field's value spec.

    This implements ``pg.typing.CustomTyping`` so that a pure symbolic value
    can be assigned to any field. Subclasses may override this method to
    customize the behavior.

    Args:
      path: KeyPath of current object under its object tree. Unused here.
      value_spec: Original value spec for this field. Unused here.
      allow_partial: Whether partial objects are allowed. Unused here.
      child_transform: Function to transform child node values into their
        final values. Unused here.

    Returns:
      A tuple (proceed_with_standard_apply, value_to_proceed). This
      implementation always returns ``(False, self)``, i.e. the object itself
      is used as the final value.
    """
    # All arguments are intentionally ignored: the value is accepted as-is.
    # pylint: disable=unused-argument
    return False, self


class NonDeterministic(PureSymbolic):
  """Marker base class for objects considered non-deterministic.

  A non-deterministic value stands for a value that will be decided later.
  In PyGlove, `pg.one_of`, `pg.sublist_of` and `pg.float_value` are
  non-deterministic values. See the subclasses of `NonDeterministic` for more
  details.
  """
26 | * Provide a pluggable backend system for supporting user infrastructures. 27 | 28 | """ 29 | 30 | # pylint: disable=g-bad-import-order 31 | 32 | # User facing APIs for tuning. 33 | from pyglove.core.tuning.sample import sample 34 | from pyglove.core.tuning.backend import poll_result 35 | 36 | from pyglove.core.tuning.backend import default_backend 37 | from pyglove.core.tuning.backend import set_default_backend 38 | 39 | # Tuning protocols. 40 | from pyglove.core.tuning.protocols import Measurement 41 | from pyglove.core.tuning.protocols import Trial 42 | from pyglove.core.tuning.protocols import Result 43 | from pyglove.core.tuning.protocols import Feedback 44 | from pyglove.core.tuning.protocols import RaceConditionError 45 | 46 | # Interface for early stopping. 47 | from pyglove.core.tuning.early_stopping import EarlyStoppingPolicy 48 | 49 | # Interfaces for tuning backend developers. 50 | from pyglove.core.tuning.backend import Backend 51 | from pyglove.core.tuning.backend import add_backend 52 | from pyglove.core.tuning.backend import available_backends 53 | 54 | # Importing local backend. 55 | import pyglove.core.tuning.local_backend 56 | 57 | # pylint: enable=g-bad-import-order 58 | -------------------------------------------------------------------------------- /pyglove/core/tuning/backend_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BackendTest(unittest.TestCase):
  """Tests for pluggable backend."""

  def test_pluggable_backend(self):
    """Registers a fake backend and switches the default backend to it.

    NOTE(review): this test reads and mutates the module-level backend
    registry, so the statement order below matters.
    """
    # Presumably 'in-memory' is registered as a side effect of importing
    # `local_backend` at the top of this file - confirm against that module.
    self.assertEqual(backend.available_backends(), ['in-memory'])

    @backend.add_backend('test')
    class TestBackend(backend.Backend):  # pylint: disable=unused-variable
      """A fake backend factory for testing."""

      def __init__(self, **kwargs):
        pass

      @classmethod
      def poll_result(cls, name):
        return None

      def next(self):
        return None

    # Registering a backend extends the list but leaves the default unchanged.
    self.assertEqual(backend.available_backends(), ['in-memory', 'test'])
    self.assertEqual(backend.default_backend(), 'in-memory')
    backend.set_default_backend('test')
    self.assertEqual(backend.default_backend(), 'test')

    # An unregistered name cannot become the default.
    with self.assertRaisesRegex(
        ValueError, 'Backend .* does not exist'):
      backend.set_default_backend('non-exist-backend')

    # Only `backend.Backend` subclasses can be registered.
    with self.assertRaisesRegex(
        TypeError, '.* is not a `pg.tuning.Backend` subclass'):

      @backend.add_backend('bad')
      class BadBackend:  # pylint: disable=unused-variable
        pass
    # Put the default back to the value asserted at the start of the test.
    backend.set_default_backend('in-memory')
    self.assertEqual(backend.default_backend(), 'in-memory')
class EarlyStoppingPolicy(symbolic.Object):
  """Interface that early stopping policies implement."""

  def setup(self, dna_spec: geno.DNASpec) -> None:
    """Binds the policy to the DNA spec of the search.

    The base implementation only records `dna_spec`.

    Args:
      dna_spec: DNASpec for DNA to propose.

    Raises:
      RuntimeError: if dna_spec is not supported.
    """
    self._dna_spec = dna_spec

  @property
  def dna_spec(self) -> Optional[geno.DNASpec]:
    """Returns the DNA spec recorded by `setup`, or None before `setup`."""
    return getattr(self, '_dna_spec', None)

  @abc.abstractmethod
  def should_stop_early(self, trial: Trial) -> bool:
    """Should stop the input trial early based on its measurements."""

  def recover(self, history: Iterable[Trial]) -> None:
    """Rebuilds the policy state by replaying past trials.

    Subclasses may override this method.

    NOTE: `recover` will always be called before the first `should_stop_early`
    is called. It could be called multiple times if there are multiple sources
    of history, e.g: trials from a previous study and existing trials from
    current study.

    By default, `should_stop_early` is replayed on every trial whose status is
    'COMPLETED', 'PENDING' or 'STOPPING'; its return value is ignored.

    Args:
      history: An iterable object of trials.
    """
    replayable_statuses = ('COMPLETED', 'PENDING', 'STOPPING')
    for past_trial in history:
      if past_trial.status not in replayable_statuses:
        continue
      # TODO(daiyip): Stopped trials do not need to be fed back.
      self.should_stop_early(past_trial)
class TrialTest(unittest.TestCase):
  """Checks the reward that a `Trial` reports for feedback."""

  def test_get_reward_for_feedback(self):
    # A trial that is still pending has no reward to report.
    pending = Trial(
        id=0,
        dna=geno.DNA(0),
        status='PENDING',
        created_time=int(time.time()),
    )
    self.assertIsNone(pending.get_reward_for_feedback())

    # A completed but infeasible trial has no reward either.
    infeasible = Trial(
        id=0,
        dna=geno.DNA(0),
        status='COMPLETED',
        infeasible=True,
        created_time=int(time.time()),
    )
    self.assertIsNone(infeasible.get_reward_for_feedback())

    # A completed, feasible trial reports from its final measurement.
    completed = Trial(
        id=0,
        dna=geno.DNA(0),
        status='COMPLETED',
        infeasible=False,
        final_measurement=Measurement(
            step=1,
            elapse_secs=0.1,
            reward=1.0,
            metrics=dict(accuracy=0.9, latency=750.0),
        ),
        created_time=int(time.time()),
    )
    self.assertEqual(completed.get_reward_for_feedback(), 1.0)

    # Naming metrics yields a tuple in the requested order.
    multi_objective = completed.get_reward_for_feedback(
        ['accuracy', 'latency'])
    self.assertEqual(multi_objective, (0.9, 750.0))

    # Asking for an unknown metric is an error.
    with self.assertRaisesRegex(
        ValueError, 'Metric \'foo\' does not exist'):
      completed.get_reward_for_feedback(['foo'])
class Annotated(vs.Generic):
  """The PyGlove enhanced `typing.Annotated` for defining a class field.

  Example::

    class A(pg.Object):
      x = pg.typing.Annotated[
          int,                         # Field type or value spec.
          'Docstring for field `x`.',  # Field docstring.
          dict(foo=1, bar=2)           # Field metadata
      ]
  """

  def __init__(
      self,
      t: vs.ValueSpecOrAnnotation,
      docstring: Optional[str] = None,
      metadata: Optional[Dict[str, Any]] = None
  ):
    """Constructor.

    Args:
      t: Field type, as a value spec or an annotation that
        `ValueSpec.from_annotation` can convert.
      docstring: Optional docstring for the field.
      metadata: Optional metadata for the field. None is stored as an empty
        dict.
    """
    super().__init__()
    self._value_spec = class_schema.ValueSpec.from_annotation(
        t, auto_typing=True)
    self._docstring = docstring
    self._metadata = metadata or {}

  @property
  def value_spec(self) -> class_schema.ValueSpec:
    """Returns the value spec for the field."""
    return self._value_spec

  @property
  def docstring(self) -> Optional[str]:
    """Returns the docstring for the field."""
    return self._docstring

  @property
  def metadata(self) -> Dict[str, Any]:
    """Returns the metadata for the field."""
    return self._metadata

  @classmethod
  def with_type_args(cls, type_args: Tuple[Any, ...]) -> Any:
    """Creates an `Annotated` object from 1 to 3 type arguments.

    NOTE(review): presumably invoked by `vs.Generic` when the class is
    subscripted (e.g. ``Annotated[int, 'doc']``) - confirm against the base
    class.

    Args:
      type_args: A tuple of (field type, [field docstring], [field metadata]).

    Returns:
      An `Annotated` object, or the plain annotation of its value spec when
      `typing.TYPE_CHECKING` is set.

    Raises:
      TypeError: If the number of type arguments is not between 1 and 3, if
        the docstring is not a str, or if the metadata is not a dict.
    """
    if len(type_args) == 1:
      t = type_args[0]
      docstring, metadata = None, None
    elif len(type_args) == 2:
      t, docstring = type_args
      metadata = None
    elif len(type_args) == 3:
      t, docstring, metadata = type_args
    else:
      # The fragments below rely on implicit string concatenation so that the
      # exception carries one message string, not a tuple of arguments.
      raise TypeError(
          '`pg.typing.Annotated` accepts 1 to 3 type arguments '
          '(field type, [field docstring], [field metadata]). '
          f'Encountered: {type_args!r}')
    if docstring is not None and not isinstance(docstring, str):
      raise TypeError(
          'The second type argument (`docstring`) must be a str. '
          f'Encountered: {docstring!r}')
    if metadata is not None and not isinstance(metadata, dict):
      raise TypeError(
          'The third type argument (`metadata`) must be a dict with str keys. '
          f'Encountered: {metadata!r}')

    t = class_schema.ValueSpec.from_annotation(t, auto_typing=True)
    annotated = Annotated(t, docstring=docstring, metadata=metadata)

    # This makes `pg.typing.Annotated` compatible with third-party type checking
    # solutions.
    if typing.TYPE_CHECKING:
      return annotated.value_spec.annotation
    return annotated
class AnnotatedTest(unittest.TestCase):
  """Tests for `annotated.Annotated`."""

  def test_subscription(self):
    """Subscription accepts a type, plus optional docstring and metadata."""
    # Field type only.
    x = annotated.Annotated[int]
    self.assertEqual(x.value_spec, vs.Int())
    self.assertIsNone(x.docstring)
    self.assertEqual(x.metadata, {})

    # Field type and docstring
    x = annotated.Annotated[int, 'hello']
    self.assertEqual(x.value_spec, vs.Int())
    self.assertEqual(x.docstring, 'hello')
    self.assertEqual(x.metadata, {})

    # Field type, docstring and metadata
    x = annotated.Annotated[int, 'hello', dict(foo=1)]
    self.assertEqual(x.value_spec, vs.Int())
    self.assertEqual(x.docstring, 'hello')
    self.assertEqual(x.metadata, dict(foo=1))

  def test_bad_subscription(self):
    """Malformed subscriptions raise `TypeError` with a descriptive message."""
    with self.assertRaisesRegex(
        TypeError, '`pg.typing.Annotated` accepts 1 to 3 type arguments'):
      _ = annotated.Annotated[int, 'hello', dict(foo=1), 1]

    with self.assertRaisesRegex(
        TypeError, 'Cannot convert 1'):
      _ = annotated.Annotated[1]

    with self.assertRaisesRegex(
        TypeError, 'The second type argument .* must be a str'):
      _ = annotated.Annotated[int, 1]

    with self.assertRaisesRegex(
        TypeError, 'The third type argument .* must be a dict with str keys'):
      _ = annotated.Annotated[int, 'foo', [1, 2]]

  def test_type_checking(self):
    """Under `typing.TYPE_CHECKING`, subscription yields plain annotations."""
    # `typing.TYPE_CHECKING` is process-wide state. Restore it even when an
    # assertion fails, so that the flag does not leak into other tests.
    original = typing.TYPE_CHECKING
    typing.TYPE_CHECKING = True
    try:
      self.assertIs(annotated.Annotated[str], str)
      self.assertIs(annotated.Annotated[str, 'hello'], str)
      self.assertIs(
          annotated.Annotated[vs.Str().noneable(), 'hello'],
          typing.Optional[str])
    finally:
      typing.TYPE_CHECKING = original
# --- pyglove/core/typing/custom_typing.py ---


class CustomTyping(metaclass=abc.ABCMeta):
  """Interface of custom value type.

  Instances of subclasses of CustomTyping can be assigned to fields of
  any ValueSpec, and take over `apply` via `custom_apply` method.

  As a result, CustomTyping makes the schema system extensible without modifying
  existing value specs. For example, value generators can extend CustomTyping
  and be assignable to any fields.
  """

  @abc.abstractmethod
  def custom_apply(
      self,
      path: utils.KeyPath,
      value_spec: class_schema.ValueSpec,
      allow_partial: bool,
      child_transform: Optional[
          Callable[[utils.KeyPath, class_schema.Field, Any], Any]
      ] = None,
  ) -> Tuple[bool, Any]:
    """Custom apply on a value based on its original value spec.

    Args:
      path: KeyPath of current object under its object tree.
      value_spec: Original value spec for this field.
      allow_partial: Whether allow partial object to be created.
      child_transform: Function to transform child node values into their final
        values. Transform function is called on leaf nodes first, then on their
        parents, recursively.

    Returns:
      A tuple (proceed_with_standard_apply, value_to_proceed).
        If proceed_with_standard_apply is set to False, value_to_proceed
        will be used as final value.

    Raises:
      Error when the value is not compatible with the value spec.
    """


# --- pyglove/core/typing/pytype_support.py ---

if typing.TYPE_CHECKING:

  # Type variable that ties the type of the decorated callable to the type of
  # the callable returned by the decorator.
  _GenericCallable = typing.TypeVar('_GenericCallable')

  class Decorator(object):
    """A type annotation for decorators that do not change signatures.

    This is a stand-in for using `Callable[[T], T]` to represent a decorator.

    Given a decorator function, which takes in a callable and returns a callable
    with the same signature, apply this class as a decorator to that function.
    This can also be used for decorator factories.

    Examples:

    Plain decorator (decorator matches Callable[[T], T]):

    >>> @pg.typing.Decorator
    ... def my_decorator(func):
    ...   def wrapper(...):
    ...     ...
    ...   return wrapper

    Decorator factory (factory matches Callable[..., Callable[[T], T]]):

    >>> def my_decorator_factory(foo: int):
    ...
    ...   @pg.typing.Decorator
    ...   def my_decorator(func):
    ...     ...
    ...   return my_decorator

    This class only exists at build time, for typechecking. At runtime, the
    'Decorator' member of this module is a simple identity function.
    """

    def __init__(
        self,
        decorator: typing.Callable[[_GenericCallable], _GenericCallable]):  # pylint: disable=unused-argument
      ...  # pylint: disable=pointless-statement

    def __call__(self, func: _GenericCallable) -> _GenericCallable:
      ...  # pytype: disable=bad-return-type  # pylint: disable=pointless-statement

else:
  # Runtime fallback: return the decorated object unchanged.
  Decorator = lambda d: d
class MissingValue(utils.MissingValue, utils.Formattable):
  """A missing-value placeholder that remembers the value spec it stands for."""

  def __init__(self, value_spec: class_schema.ValueSpec):
    """Creates a typed missing value for `value_spec`."""
    self._value_spec = value_spec

  @property
  def value_spec(self) -> class_schema.ValueSpec:
    """The value spec this placeholder was created for."""
    return self._value_spec

  def __eq__(self, other: Any) -> bool:
    """Operator ==.

    NOTE: `MissingValue(value_spec)` and `utils.MissingValue` are considered
    equal, while `MissingValue(value_spec1)` and `MissingValue(value_spec2)`
    are considered different. The 'eq' operation is therefore not transitive.

    In practice this is not a problem, since users always compare against
    `schema.MISSING_VALUE`, which is a `utils.MissingValue`. For the same
    reason `__hash__` returns the same value as `utils.MissingValue`.

    Args:
      other: the value to compare against.

    Returns:
      True if `other` is this very object, a typed MissingValue with an equal
      value spec, or a value that equals the untyped `MISSING_VALUE`.
    """
    if other is self:
      return True
    if not isinstance(other, MissingValue):
      # Defer to the untyped placeholder for everything that is not typed.
      return MISSING_VALUE == other
    return self._value_spec == other.value_spec

  def __hash__(self) -> int:
    """Hashes like the untyped `MISSING_VALUE`, regardless of value spec."""
    return hash(MISSING_VALUE)

  def format(self,
             compact: bool = False,
             verbose: bool = True,
             root_indent: int = 0,
             **kwargs) -> str:
    """Formats this object; the compact form omits the value spec."""
    if compact:
      return 'MISSING_VALUE'
    rendered_spec = self._value_spec.format(
        compact=compact, verbose=verbose, root_indent=root_indent, **kwargs)
    return f'MISSING_VALUE({rendered_spec})'

  def __deepcopy__(self, memo):
    """Copies the placeholder while sharing `value_spec` by reference."""
    return MissingValue(self.value_spec)
class MissingValueTest(unittest.TestCase):
  """Tests for typed MissingValue class."""

  def test_eq(self):
    """Equality depends on the value spec, except against the untyped one."""
    make = typed_missing.MissingValue
    int_spec = value_specs.Int

    # Typed vs. untyped placeholder, in both operand orders.
    self.assertEqual(make(int_spec()), utils.MISSING_VALUE)
    self.assertEqual(utils.MISSING_VALUE, make(int_spec()))

    # Distinct instances: equal specs compare equal, different specs do not.
    self.assertEqual(make(int_spec()), make(int_spec()))
    self.assertNotEqual(make(int_spec()), make(int_spec(max_value=1)))
    self.assertNotEqual(make(int_spec()), make(value_specs.Str()))

    # Identity.
    same = make(int_spec())
    self.assertEqual(same, same)

  def test_hash(self):
    """All missing values share one hash, which differs from `hash(1)`."""
    make = typed_missing.MissingValue
    int_hash = hash(make(value_specs.Int()))
    float_hash = hash(make(value_specs.Float()))
    self.assertEqual(int_hash, float_hash)

    self.assertEqual(
        hash(make(value_specs.Int())),
        hash(utils.MISSING_VALUE),
    )

    self.assertNotEqual(hash(make(value_specs.Int())), hash(1))

  def test_format(self):
    """Test MissingValue.format."""
    expected_by_compact = {
        True: 'MISSING_VALUE',
        False: 'MISSING_VALUE(Int())',
    }
    for compact, expected in expected_by_compact.items():
      placeholder = typed_missing.MissingValue(value_specs.Int())
      self.assertEqual(placeholder.format(compact=compact), expected)
class MaybePartial(metaclass=abc.ABCMeta):
  """Interface for classes whose instances can be partially constructed.

  A ``MaybePartial`` object is an object whose ``__init__`` method can accept
  ``pg.MISSING_VALUE`` as its argument values. All symbolic types (see
  :class:`pyglove.Symbolic`) implement this interface, as their symbolic
  attributes can be partially filled.

  Example::

    d = pg.Dict(x=pg.MISSING_VALUE, y=1)
    assert d.is_partial
    assert 'x' in d.missing_values()
  """

  @property
  def is_partial(self) -> bool:
    """Returns True if this object is partial. Otherwise False.

    The default implementation reports partiality when `missing_values()` is
    non-empty. Subclasses can override this property to provide a more
    efficient solution.
    """
    missing = self.missing_values()
    return len(missing) != 0

  @abc.abstractmethod
  def missing_values(self, flatten: bool = True) -> Dict[Union[str, int], Any]:  # pylint: disable=redefined-outer-name
    """Returns missing values from this object.

    Args:
      flatten: If True, convert nested structures into a flattened dict using
        key path (delimited by '.' and '[]') as key.

    Returns:
      A dict of key to MISSING_VALUE.
    """


class Functor(metaclass=abc.ABCMeta):
  """Interface for callable objects (functors)."""

  @abc.abstractmethod
  def __call__(self, *args, **kwargs) -> Any:
    """Calls the functor.

    Args:
      *args: Any positional arguments.
      **kwargs: Any keyword arguments.

    Returns:
      Any value.
    """


def explicit_method_override(method):
  """Decorator that marks a member method as explicitly overridden.

  In PyGlove, many methods are managed by the framework - for example -
  ``pg.Object.__init__``. It's easy for users to override these methods
  unconsciously. Therefore, we introduce this decorator to catch errors at
  the first place when such overrides incidentally take place, while allowing
  advanced users to override them.

  Usage::

    class Foo(pg.Object):

      @pg.explicit_method_override
      def __init__(self, *args, **kwargs):
        ...

  Args:
    method: The method to mark as explicitly overridden.

  Returns:
    The same method object, stamped with `__explicit_override__ = True`.
  """
  method.__explicit_override__ = True
  return method


def ensure_explicit_method_override(
    method, error_message: Optional[str] = None) -> None:
  """Raises `TypeError` unless `method` is marked as explicitly overridden.

  Args:
    method: The method to check for a truthy `__explicit_override__` stamp.
    error_message: Optional message for the raised error. A default message
      naming the method is used when omitted.

  Raises:
    TypeError: If `method` does not carry the explicit-override stamp.
  """
  if getattr(method, '__explicit_override__', False):
    return
  if error_message is None:
    error_message = (
        f'{method} is a PyGlove managed method. If you do need to override '
        'it, please decorate the method with `@pg.explicit_method_override`.')
  raise TypeError(error_message)
class ExplicitlyOverrideTest(unittest.TestCase):
  """Tests for the explicit method override helpers."""

  def test_explicitly_override(self):
    """Only methods stamped by the decorator pass the check."""
    ensure = common_traits.ensure_explicit_method_override

    class Sample:

      @common_traits.explicit_method_override
      def __init__(self, x, y):
        pass

      def bar(self):
        pass

    # The decorated method passes silently.
    ensure(Sample.__init__)

    # The undecorated method is rejected.
    with self.assertRaisesRegex(TypeError, '.* is a PyGlove managed method'):
      ensure(Sample.bar)
class ContextualTest(unittest.TestCase):
  """Tests for contextual overrides."""

  def test_contextual_override(self):
    """Overrides are visible in scope and combine with nested scopes."""
    with contextual.contextual_override(x=3, y=3, z=3) as parent_override:
      self.assertEqual(
          parent_override,
          dict(
              x=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
              y=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
              z=contextual.ContextualOverride(
                  3, cascade=False, override_attrs=False
              ),
          ),
      )
      self.assertEqual(
          contextual.get_contextual_override('y'),
          contextual.ContextualOverride(3, cascade=False, override_attrs=False),
      )
      self.assertEqual(contextual.contextual_value('x'), 3)
      self.assertIsNone(contextual.contextual_value('f', None))
      with self.assertRaisesRegex(KeyError, '.* does not exist'):
        contextual.contextual_value('f')

      self.assertEqual(contextual.all_contextual_values(), dict(x=3, y=3, z=3))

      # Test nested contextual override with an explicit override_attrs=True.
      # `x` is inherited unchanged from the parent scope, while `y` and `z`
      # are replaced by the nested values.
      with contextual.contextual_override(
          y=4, z=4, override_attrs=True) as nested_override:
        self.assertEqual(
            nested_override,
            dict(
                x=contextual.ContextualOverride(
                    3, cascade=False, override_attrs=False
                ),
                y=contextual.ContextualOverride(
                    4, cascade=False, override_attrs=True
                ),
                z=contextual.ContextualOverride(
                    4, cascade=False, override_attrs=True
                ),
            ),
        )

    # Test nested contextual override with cascade=True: the values from the
    # outer scope win over the ones from the nested scope.
    with contextual.contextual_override(x=3, y=3, z=3, cascade=True):
      with contextual.contextual_override(y=4, z=4, cascade=True):
        self.assertEqual(contextual.contextual_value('x'), 3)
        self.assertEqual(contextual.contextual_value('y'), 3)
        self.assertEqual(contextual.contextual_value('z'), 3)

  def test_with_contextual_override(self):
    """A wrapped function sees the override, also from a worker thread."""
    def func(i):
      del i
      return contextual.contextual_value('x')

    # Use the executor as a context manager so its worker threads are shut
    # down when the test ends, even if an assertion fails.
    with concurrent.futures.ThreadPoolExecutor() as pool:
      with contextual.contextual_override(x=3):
        self.assertEqual(contextual.with_contextual_override(func)(0), 3)
        self.assertEqual(
            list(pool.map(contextual.with_contextual_override(func), range(1))),
            [3]
        )
class CatchErrorsTest(unittest.TestCase):
  """Tests for `error_utils.catch_errors`."""

  def assert_caught_error(self, func, errors_to_catch):
    """Asserts that `func` raises an error which `catch_errors` captures."""
    with error_utils.catch_errors(errors_to_catch) as context:
      func()
    self.assertIsNotNone(context.error)

  def assert_propagate_error(self, func, errors_to_catch):
    """Asserts that the error raised by `func` escapes `catch_errors`."""
    with self.assertRaises(Exception):
      with error_utils.catch_errors(errors_to_catch):
        func()

  def test_catch_errors(self):
    """Errors are matched by type and, optionally, by a message regex."""
    def raise_value_error():
      raise ValueError('this is an error')

    caught_specs = (
        ValueError,
        (ValueError,),
        (Exception, 'ValueError'),
        (KeyError, ValueError),
        (ValueError, 'an error'),
        (KeyError, (ValueError, 'an error'),),
    )
    for spec in caught_specs:
      self.assert_caught_error(raise_value_error, spec)

    propagated_specs = (
        KeyError,
        (ValueError, '^an error'),
        (ValueError, 'something else'),
    )
    for spec in propagated_specs:
      self.assert_propagate_error(raise_value_error, spec)

  def test_catch_errors_with_error_handler(self):
    """The error handler receives the caught error."""
    seen = []

    def record(error):
      seen.append(error)

    def raise_value_error():
      raise ValueError()

    with error_utils.catch_errors(ValueError, record) as context:
      raise_value_error()
    self.assertEqual(seen, [context.error])

  def test_catch_errors_bad_inputs(self):
    """Malformed error specifications are rejected with `TypeError`."""
    bad_inputs = (
        ('Each error specification should be either .*',
         [(ValueError, 'abc', 'abc')]),
        ('Each error specification should be either .*',
         [(ValueError, 1)]),
        ('Exception contains non-except types',
         [str, ValueError]),
    )
    for message_regex, spec in bad_inputs:
      with self.assertRaisesRegex(TypeError, message_regex):
        with error_utils.catch_errors(spec):
          pass
class MissingValue(formatting.Formattable, json_conversion.JSONConvertible):
  """Value placeholder for an unassigned attribute."""

  def format(self, *args, **kwargs):  # pytype: disable=signature-mismatch
    """Returns the fixed text 'MISSING_VALUE', ignoring all arguments."""
    return 'MISSING_VALUE'

  def __eq__(self, other: Any) -> bool:
    """Any two `MissingValue` instances are equal."""
    return isinstance(other, MissingValue)

  def __ne__(self, other: Any) -> bool:
    """Negation of `__eq__`."""
    is_equal = self.__eq__(other)
    return not is_equal

  def __hash__(self) -> int:
    """Hashes the class' module and name, so all instances hash alike."""
    identity = f'{MissingValue.__module__}{MissingValue.__name__}'
    return hash(identity)

  def to_json(self, **kwargs) -> Dict[str, Any]:
    """Serializes via `to_json_dict` with no fields."""
    return self.to_json_dict(fields={}, **kwargs)
class MissingValueTest(unittest.TestCase):
  """Tests for class MissingValue."""

  def test_basics(self):
    """Checks equality and the string forms of the placeholder."""
    make = missing.MissingValue

    # Distinct instances are equal to each other ...
    self.assertEqual(make(), make())
    # ... but not to values of other types.
    for other in (1, {}):
      self.assertNotEqual(make(), other)

    for to_text in (str, repr):
      self.assertEqual(to_text(make()), 'MISSING_VALUE')

  def test_to_json(self):
    """A JSON round trip yields an equal placeholder."""
    serialized = json_conversion.to_json(missing.MissingValue())
    restored = json_conversion.from_json(serialized)
    self.assertEqual(restored, missing.MissingValue())
class TextColorTest(unittest.TestCase):
  """Tests for `text_color`."""

  # Environment variables that `setUp` modifies.
  _COLOR_ENV_VARS = ('ANSI_COLORS_DISABLED', 'NO_COLOR', 'FORCE_COLOR')

  @staticmethod
  def _restore_env(name, value):
    """Restores environment variable `name` to `value` (None means unset)."""
    if value is None:
      os.environ.pop(name, None)
    else:
      os.environ[name] = value

  def setUp(self):
    super().setUp()
    # Snapshot the variables first, so that the changes below do not leak into
    # other tests running in the same process.
    for name in self._COLOR_ENV_VARS:
      self.addCleanup(self._restore_env, name, os.environ.get(name))
    os.environ.pop('ANSI_COLORS_DISABLED', None)
    os.environ.pop('NO_COLOR', None)
    os.environ['FORCE_COLOR'] = '1'

  def test_colored_block_without_colors_and_styles(self):
    """Without color or style arguments the text is returned unchanged."""
    self.assertEqual(text_color.colored_block('foo', '{{', '}}'), 'foo')

  def test_colored_block(self):
    """Blocks between the markers are re-colored within colored text."""
    original_text = inspect.cleandoc("""
        Hi << foo >>
        <# print x if x is present #>
        <% if x %>
        << x >>
        <% endif %>
        """)

    colored_text = text_color.colored_block(
        text_color.colored(original_text, color='blue'),
        '<<', '>>',
        color='white',
        background='blue',
    )
    origin_color = '\x1b[34m'
    reset = '\x1b[0m'
    # The escape prefix that `colored` emits for the block style.
    block_color = text_color.colored(
        'TEXT', color='white', background='blue'
    ).split('TEXT')[0]
    expected = (
        f'{origin_color}Hi {block_color}<< foo >>{reset}{origin_color}\n'
        '<# print x if x is present #>\n<% if x %>\n'
        f'{block_color}<< x >>{reset}{origin_color}\n'
        f'<% endif %>{reset}'
    )
    # On some termcolor versions, the color codes are not applied.
    self.assertIn(colored_text, (expected, original_text))
    self.assertEqual(text_color.decolor(colored_text), original_text)

  def test_colored_block_without_full_match(self):
    """An opening marker without a closing one leaves the text unchanged."""
    self.assertEqual(
        text_color.colored_block(
            'Hi {{ foo',
            '{{', '}}',
            color='white',
            background='blue',
        ),
        'Hi {{ foo'
    )

  def test_colored_block_without_termcolor(self):
    """With `termcolor` unavailable, coloring is a no-op."""
    # Put the module attribute back even when an assertion below fails.
    self.addCleanup(setattr, text_color, 'termcolor', text_color.termcolor)
    text_color.termcolor = None
    original_text = inspect.cleandoc("""
        Hi {{ foo }}
        {# print x if x is present #}
        {% if x %}
        {{ x }}
        {% endif %}
        """)

    colored_text = text_color.colored_block(
        text_color.colored(original_text, color='blue'),
        '{{', '}}',
        color='white',
        background='blue',
    )
    self.assertEqual(colored_text, original_text)
    self.assertEqual(text_color.decolor(colored_text), original_text)
14 | """PyGlove views.""" 15 | 16 | from pyglove.core.views import base 17 | from pyglove.core.views import html 18 | 19 | View = base.View 20 | view = base.view 21 | view_options = base.view_options 22 | 23 | # Pytype annotation. 24 | NodeFilter = base.NodeFilter 25 | 26 | Html = html.Html 27 | HtmlConvertible = html.HtmlConvertible 28 | HtmlView = html.HtmlView 29 | HtmlTreeView = html.HtmlTreeView 30 | 31 | to_html = html.to_html 32 | to_html_str = html.to_html_str 33 | -------------------------------------------------------------------------------- /pyglove/core/views/html/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """HTML views for PyGlove objects.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=g-bad-import-order 18 | 19 | from pyglove.core.views.html.base import Html 20 | from pyglove.core.views.html.base import HtmlConvertible 21 | from pyglove.core.views.html.base import HtmlView 22 | from pyglove.core.views.html.base import to_html 23 | from pyglove.core.views.html.base import to_html_str 24 | from pyglove.core.views.html.tree_view import HtmlTreeView 25 | 26 | # pylint: enable=g-bad-import-order 27 | # pylint: enable=g-importing-member 28 | -------------------------------------------------------------------------------- /pyglove/core/views/html/controls/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Common HTML controls.""" 15 | 16 | # pylint: disable=g-importing-member 17 | # pylint: disable=g-bad-import-order 18 | 19 | from pyglove.core.views.html.controls.base import HtmlControl 20 | 21 | 22 | from pyglove.core.views.html.controls.label import Label 23 | from pyglove.core.views.html.controls.label import LabelGroup 24 | from pyglove.core.views.html.controls.label import Badge 25 | 26 | from pyglove.core.views.html.controls.tooltip import Tooltip 27 | 28 | from pyglove.core.views.html.controls.tab import Tab 29 | from pyglove.core.views.html.controls.tab import TabControl 30 | 31 | from pyglove.core.views.html.controls.progress_bar import ProgressBar 32 | from pyglove.core.views.html.controls.progress_bar import SubProgress 33 | 34 | # pylint: enable=g-bad-import-order 35 | # pylint: enable=g-importing-member 36 | -------------------------------------------------------------------------------- /pyglove/core/views/html/controls/progress_bar_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import inspect 15 | import unittest 16 | 17 | from pyglove.core import symbolic # pylint: disable=unused-import 18 | from pyglove.core.views.html.controls import progress_bar 19 | 20 | 21 | class ProgressBarTest(unittest.TestCase): 22 | 23 | def assert_html_content(self, control, expected): 24 | expected = inspect.cleandoc(expected).strip() 25 | actual = control.to_html().content.strip() 26 | if actual != expected: 27 | print(actual) 28 | self.assertEqual(actual, expected) 29 | 30 | def test_basic(self): 31 | bar = progress_bar.ProgressBar( 32 | subprogresses=[ 33 | progress_bar.SubProgress('foo'), 34 | progress_bar.SubProgress('bar', 20), 35 | ], 36 | total=None, 37 | ) 38 | self.assert_html_content( 39 | bar, 40 | ( 41 | '
' 42 | f'
' 43 | '
' 45 | '
n/a' 48 | 'Not started
' 49 | ) 50 | ) 51 | self.assertEqual(bar['foo'], progress_bar.SubProgress('foo')) 52 | self.assertEqual(bar['bar'], progress_bar.SubProgress('bar', 20)) 53 | with self.assertRaisesRegex(KeyError, 'Sub progress bar .* not found'): 54 | _ = bar['baz'] 55 | self.assertIsNone(bar['foo'].total) 56 | self.assertIsNone(bar['foo'].width) 57 | 58 | bar.update(total=100) 59 | self.assertEqual(bar.total, 100) 60 | self.assertEqual(bar['foo'].total, 100) 61 | self.assertEqual(bar['foo'].width, '0%') 62 | self.assertEqual(bar['bar'].width, '20%') 63 | with bar.track_scripts() as scripts: 64 | bar['foo'].increment() 65 | self.assertEqual( 66 | scripts, 67 | [ 68 | inspect.cleandoc( 69 | f""" 70 | elem = document.getElementById("{bar['foo'].element_id()}"); 71 | elem.style = "width:1%;"; 72 | """ 73 | ), 74 | inspect.cleandoc( 75 | f""" 76 | elem = document.getElementById("{bar._progress_label.element_id()}"); 77 | elem.textContent = " 21.0% (21/100)"; 78 | """ 79 | ), 80 | inspect.cleandoc( 81 | f""" 82 | elem = document.getElementById("{bar._progress_label.tooltip.element_id()}"); 83 | elem.textContent = "foo: 1.0% (1/100)\\nbar: 20.0% (20/100)"; 84 | """ 85 | ), 86 | inspect.cleandoc( 87 | f""" 88 | elem = document.getElementById("{bar._progress_label.tooltip.element_id()}"); 89 | elem.classList.remove("html-content"); 90 | """ 91 | ), 92 | ] 93 | ) 94 | 95 | 96 | if __name__ == '__main__': 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /pyglove/core/views/html/controls/tooltip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Html tooltip.""" 15 | 16 | from typing import Optional, Union 17 | from pyglove.core import typing as pg_typing 18 | from pyglove.core.symbolic import object as pg_object 19 | # pylint: disable=g-importing-member 20 | from pyglove.core.views.html.base import Html 21 | from pyglove.core.views.html.controls.base import HtmlControl 22 | # pylint: enable=g-importing-member 23 | 24 | 25 | @pg_object.use_init_args( 26 | ['content', 'for_element', 'id', 'css_classes', 'styles'] 27 | ) 28 | class Tooltip(HtmlControl): 29 | """A tooltip control. 30 | 31 | Attributes: 32 | content: The content of the tooltip. It could be a string or an HTML object. 33 | id: The id of the tooltip. 34 | css_classes: The CSS classes for the tooltip. 35 | for_element: The CSS selector for the element to attach the tooltip to. 36 | e.g. '.my-element' or '#my-element'. 37 | """ 38 | 39 | content: Union[str, Html] 40 | for_element: Optional[str] = None 41 | 42 | def _to_html(self, **kwargs): 43 | if self.for_element is None: 44 | raise ValueError( 45 | 'CSS selector `for_element` is required for tooltip to display.' 
46 | ) 47 | content = self.content 48 | if isinstance(self.content, str): 49 | content = Html.escape(self.content) 50 | return Html.element( 51 | 'span', 52 | [content], 53 | id=self.element_id(), 54 | css_classes=[ 55 | 'tooltip', 56 | 'html-content' if isinstance(content, Html) else None, 57 | ] + self.css_classes, 58 | styles=self.styles, 59 | ).add_style( 60 | """ 61 | span.tooltip { 62 | visibility: hidden; 63 | white-space: pre-wrap; 64 | font-weight: normal; 65 | background-color: #484848; 66 | color: #fff; 67 | padding: 10px; 68 | border-radius: 6px; 69 | position: absolute; 70 | z-index: 1; 71 | } 72 | span.tooltip:hover { 73 | visibility: visible; 74 | } 75 | .tooltip.html-content { 76 | white-space: inherit; 77 | background-color: white; 78 | color: inherit; 79 | box-shadow: rgba(0, 0, 0, 0.16) 0px 1px 4px; 80 | } 81 | """, 82 | f""" 83 | {self.for_element}:hover + .tooltip {{ 84 | visibility: visible; 85 | }} 86 | """ 87 | ) 88 | 89 | def update(self, content: Union[str, Html]) -> None: 90 | self._sync_members(content=self._update_content(content)) 91 | if isinstance(content, Html): 92 | self._add_css_class('html-content') 93 | else: 94 | self._remove_css_class('html-content') 95 | 96 | 97 | # Register converter for automatic conversion. 98 | pg_typing.register_converter(str, Tooltip, Tooltip) 99 | pg_typing.register_converter(Html, Tooltip, Tooltip) 100 | -------------------------------------------------------------------------------- /pyglove/core/views/html/controls/tooltip_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import inspect 15 | import unittest 16 | 17 | from pyglove.core import symbolic # pylint: disable=unused-import 18 | from pyglove.core.views.html import base 19 | from pyglove.core.views.html.controls import tooltip as tooltip_lib 20 | 21 | 22 | class TooltipTest(unittest.TestCase): 23 | 24 | def assert_html_content(self, control, expected): 25 | expected = inspect.cleandoc(expected).strip() 26 | actual = control.to_html().content.strip() 27 | if actual != expected: 28 | print(actual) 29 | self.assertEqual(actual, expected) 30 | 31 | def test_basic(self): 32 | tooltip = tooltip_lib.Tooltip('foo') 33 | with self.assertRaisesRegex( 34 | ValueError, 'CSS selector `for_element` is required' 35 | ): 36 | tooltip.to_html() 37 | 38 | tooltip = tooltip_lib.Tooltip('foo', for_element='.bar') 39 | self.assertEqual(tooltip.for_element, '.bar') 40 | self.assert_html_content( 41 | tooltip, 42 | 'foo' 43 | ) 44 | self.assertIn( 45 | inspect.cleandoc( 46 | """ 47 | .bar:hover + .tooltip { 48 | visibility: visible; 49 | } 50 | """ 51 | ), 52 | tooltip.to_html().style_section, 53 | ) 54 | 55 | def test_update(self): 56 | tooltip = tooltip_lib.Tooltip('foo', for_element='.bar', interactive=True) 57 | self.assertIn('id="control-', tooltip.to_html_str(content_only=True)) 58 | with tooltip.track_scripts() as scripts: 59 | tooltip.update('normal text') 60 | self.assertEqual(tooltip.content, 'normal text') 61 | self.assertEqual( 62 | scripts, 63 | [ 64 | inspect.cleandoc( 65 | f""" 66 | elem = document.getElementById("{tooltip.element_id()}"); 
67 | elem.textContent = "normal text"; 68 | """ 69 | ), 70 | inspect.cleandoc( 71 | f""" 72 | elem = document.getElementById("{tooltip.element_id()}"); 73 | elem.classList.remove("html-content"); 74 | """ 75 | ), 76 | ] 77 | ) 78 | with tooltip.track_scripts() as scripts: 79 | tooltip.update(base.Html('bold text')) 80 | self.assertEqual( 81 | scripts, 82 | [ 83 | inspect.cleandoc( 84 | f""" 85 | elem = document.getElementById("{tooltip.element_id()}"); 86 | elem.innerHTML = "bold text"; 87 | """ 88 | ), 89 | inspect.cleandoc( 90 | f""" 91 | elem = document.getElementById("{tooltip.element_id()}"); 92 | elem.classList.add("html-content"); 93 | """ 94 | ), 95 | ] 96 | ) 97 | 98 | if __name__ == '__main__': 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /pyglove/ext/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Package pyglove.ext.""" 15 | 16 | from pyglove.ext import early_stopping 17 | from pyglove.ext import evolution 18 | from pyglove.ext import mutfun 19 | from pyglove.ext import scalars 20 | -------------------------------------------------------------------------------- /pyglove/ext/early_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """PyGlove's built-in early stopping policies.""" 15 | 16 | # pylint: disable=g-bad-import-order 17 | 18 | from pyglove.ext.early_stopping.base import EarlyStopingPolicyBase 19 | from pyglove.ext.early_stopping.base import And 20 | from pyglove.ext.early_stopping.base import Or 21 | from pyglove.ext.early_stopping.base import Not 22 | 23 | from pyglove.ext.early_stopping.step_wise import early_stop_by_rank 24 | from pyglove.ext.early_stopping.step_wise import early_stop_by_value 25 | from pyglove.ext.early_stopping.step_wise import StepWise 26 | 27 | # pylint: enable=g-bad-import-order 28 | -------------------------------------------------------------------------------- /pyglove/ext/early_stopping/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Base early stopping policy that is friendly to composition."""

from typing import Iterable
import pyglove.core as pg


class EarlyStopingPolicyBase(pg.tuning.EarlyStoppingPolicy):
  """An early stopping policy base class that supports composition.

  Policies deriving from this class can be combined with the `&`, `|`,
  unary `-` and `~` operators, which build `And`, `Or` and `Not` composites.
  """

  def __and__(self, other) -> pg.tuning.EarlyStoppingPolicy:
    """Operator &."""
    return And(self, other)

  def __or__(self, other) -> pg.tuning.EarlyStoppingPolicy:
    """Operator |."""
    return Or(self, other)

  def __neg__(self) -> pg.tuning.EarlyStoppingPolicy:
    """Operator -."""
    return Not(self)

  def __invert__(self) -> pg.tuning.EarlyStoppingPolicy:
    """Operator ~."""
    return Not(self)


@pg.members([
    ('children',
     pg.typing.List(pg.typing.Object(pg.tuning.EarlyStoppingPolicy))),
], init_arg_list=['*children'])
class Composite(EarlyStopingPolicyBase):
  """Base class for composite early stopping policies."""

  def recover(self, history: Iterable[pg.tuning.Trial]):
    """Recovers every child policy from the same trial history.

    Args:
      history: The trials to recover from.
    """
    # `history` is only required to be an iterable. Materialize it once so
    # that a one-shot iterator is not exhausted by the first child, leaving
    # the remaining children with nothing to recover from.
    history = list(history)
    for child in self.children:
      child.recover(history)


class And(Composite):
  """Logical AND as a composite early stopping policy."""

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    """Returns True only if every child votes to stop `trial` early."""
    # `all` short-circuits on the first child that does not want to stop,
    # so later children are not consulted.
    return all(child.should_stop_early(trial) for child in self.children)


class Or(Composite):
  """Logical OR as a composite early stopping policy."""

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    """Returns True as soon as any child votes to stop `trial` early."""
    # `any` short-circuits on the first child that wants to stop.
    return any(child.should_stop_early(trial) for child in self.children)


class Not(Composite):
  """Logical NOT as a composite early stopping policy."""

  def should_stop_early(self, trial: pg.tuning.Trial) -> bool:
    """Returns the negation of the single child's decision for `trial`."""
    # Negation is only defined for exactly one operand.
    assert len(self.children) == 1, self.children
    return not self.children[0].should_stop_early(trial)
-------------------------------------------------------------------------------- /pyglove/ext/early_stopping/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Iterable 16 | import unittest 17 | import pyglove.core as pg 18 | from pyglove.ext.early_stopping import base 19 | 20 | 21 | @pg.members([ 22 | ('decision', pg.typing.Bool()) 23 | ]) 24 | class ConstantPolicy(base.EarlyStopingPolicyBase): 25 | 26 | def _on_bound(self): 27 | super()._on_bound() 28 | self.requested_trials = [] 29 | self.recovered = False 30 | 31 | def should_stop_early(self, trial: pg.tuning.Trial) -> bool: 32 | self.requested_trials.append(trial) 33 | return self.decision 34 | 35 | def recover(self, history: Iterable[pg.tuning.Trial]) -> None: 36 | self.recovered = True 37 | 38 | 39 | class EarlyStoppingPolicyComposabilityTest(unittest.TestCase): 40 | """Test the composability of early stopping policies.""" 41 | 42 | def test_logical_and(self): 43 | t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0) 44 | x = ConstantPolicy(True) 45 | self.assertTrue(x.should_stop_early(t)) 46 | 47 | y = ConstantPolicy(False) 48 | self.assertFalse(y.should_stop_early(t)) 49 | 50 | self.assertTrue((x & x).should_stop_early(t)) 51 | self.assertFalse((x & y).should_stop_early(t)) 52 | self.assertFalse((y & 
x).should_stop_early(t)) 53 | self.assertFalse((y & y).should_stop_early(t)) 54 | 55 | def test_logical_or(self): 56 | t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0) 57 | self.assertTrue( 58 | (ConstantPolicy(True) | ConstantPolicy(True)).should_stop_early(t)) 59 | self.assertTrue( 60 | (ConstantPolicy(True) | ConstantPolicy(False)).should_stop_early(t)) 61 | self.assertTrue( 62 | (ConstantPolicy(False) | ConstantPolicy(True)).should_stop_early(t)) 63 | self.assertFalse( 64 | (ConstantPolicy(False) | ConstantPolicy(False)).should_stop_early(t)) 65 | 66 | def test_logical_not(self): 67 | t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0) 68 | self.assertFalse((~ConstantPolicy(True)).should_stop_early(t)) # pylint: disable=invalid-unary-operand-type 69 | self.assertTrue((-ConstantPolicy(False)).should_stop_early(t)) # pylint: disable=invalid-unary-operand-type 70 | 71 | def test_recorver(self): 72 | t = pg.tuning.Trial(id=1, dna=pg.DNA(1), created_time=0) 73 | x = ConstantPolicy(True) 74 | y = ConstantPolicy(False) 75 | z = x & y 76 | z.recover([t]) 77 | self.assertTrue(x.recovered) 78 | self.assertTrue(y.recovered) 79 | 80 | 81 | if __name__ == '__main__': 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /pyglove/ext/evolution/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """PyGlove's genetic computing framework.""" 15 | 16 | from pyglove.ext import scalars 17 | from pyglove.ext.evolution import base 18 | from pyglove.ext.evolution import hill_climb as hill_climb_lib 19 | from pyglove.ext.evolution import mutators 20 | from pyglove.ext.evolution import neat as neat_lib 21 | from pyglove.ext.evolution import nsga2 as nsga2_lib 22 | from pyglove.ext.evolution import recombinators 23 | from pyglove.ext.evolution import regularized_evolution as regularized_evolution_lib 24 | from pyglove.ext.evolution import selectors 25 | from pyglove.ext.evolution.hill_climb import hill_climb 26 | from pyglove.ext.evolution.neat import neat 27 | from pyglove.ext.evolution.nsga2 import nsga2 28 | from pyglove.ext.evolution.regularized_evolution import regularized_evolution 29 | 30 | # Alias for backward compatibility. 31 | # Remove once dependencies are fixed. 32 | RegularizedEvolution = regularized_evolution 33 | 34 | # Interfaces. 35 | Operation = base.Operation 36 | DNAOperation = base.DNAOperation 37 | Selector = base.Selector 38 | Recombinator = base.Recombinator 39 | Mutator = base.Mutator 40 | Scalar = scalars.Scalar 41 | 42 | # The compositional evolution class. 43 | Evolution = base.Evolution 44 | 45 | # Compositional operators. 46 | # Common operations that are not associated with an operator. 
47 | Lambda = base.Lambda 48 | 49 | Pipeline = base.Pipeline # operator >> 50 | Power = base.Power # operator ** 51 | 52 | Concatenation = base.Concatenation # operator + 53 | Slice = base.Slice # operator [] 54 | Repeat = base.Repeat # operator * 55 | 56 | Identity = base.Identity 57 | Union = base.Union # operator | 58 | Intersection = base.Intersection # operator & 59 | Difference = base.Difference # operator - 60 | SymmetricDifference = base.SymmetricDifference # operator ^ 61 | Inversion = base.Inversion # operator ~ 62 | 63 | Choice = base.Choice # .with_prob 64 | Conditional = base.Conditional # .if_true/.if_false 65 | ElementWise = base.ElementWise # .for_each 66 | Flatten = base.Flatten # .flatten 67 | UntilChange = base.UntilChange # .until_change 68 | 69 | 70 | GlobalStateGetter = base.GlobalStateGetter # .global_state 71 | GlobalStateSetter = base.GlobalStateSetter # .as_global_state 72 | 73 | # Helper method. 74 | scalar_spec = scalars.scalar_spec 75 | scalar_value = scalars.scalar_value 76 | 77 | set_fitness = base.set_fitness 78 | get_fitness = base.get_fitness 79 | 80 | get_generation_id = base.get_generation_id 81 | get_feedback_sequence_number = base.get_feedback_sequence_number 82 | is_initial_population = base.is_initial_population 83 | 84 | 85 | -------------------------------------------------------------------------------- /pyglove/ext/evolution/hill_climb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Generic Hill-Climbing Algorithm.""" 15 | 16 | from typing import Optional 17 | 18 | import pyglove.core as pg 19 | from pyglove.ext.evolution import base 20 | from pyglove.ext.evolution import mutators 21 | from pyglove.ext.evolution import selectors 22 | 23 | 24 | def hill_climb(mutator=mutators.Uniform(), 25 | batch_size: int = 1, 26 | init_population_size: int = 1, 27 | seed: Optional[int] = None) -> base.Evolution: 28 | """Hill-Climbing algorithm, with an extra batched setting. 29 | 30 | Batched setting was shown to be effective in 31 | https://arxiv.org/pdf/1911.06317.pdf and https://arxiv.org/pdf/2003.01239.pdf, 32 | especially in noisy objective settings. 33 | 34 | Args: 35 | mutator: Mutator to use. 36 | batch_size: Number of mutations of the current best. 37 | init_population_size: Initial population size (randomly generated). 38 | seed: Random seed. 39 | 40 | Returns: 41 | An `Evolution` object. 
42 | """ 43 | return base.Evolution( 44 | selectors.Top(1) >> (mutator * batch_size), # pytype: disable=unsupported-operands 45 | population_init=(pg.geno.Random(seed), init_population_size), 46 | population_update=selectors.Top(1)) 47 | 48 | -------------------------------------------------------------------------------- /pyglove/ext/evolution/hill_climb_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for hill climbing algorithm.""" 15 | 16 | import random 17 | import time 18 | import unittest 19 | 20 | import pyglove.core as pg 21 | from pyglove.ext.evolution import base 22 | from pyglove.ext.evolution import hill_climb_lib as hill_climb 23 | 24 | 25 | def get_trivial_search_space(): 26 | """Trivial search space. 27 | 28 | Each point in the space is a value in [0, 1]. 29 | 30 | Returns: 31 | A tunable value. 32 | """ 33 | return pg.float_value(0.0, 1.0) 34 | 35 | 36 | class TrivialMutator(base.Mutator): 37 | """Mutator for trivial search space. 38 | 39 | Mutations can only change the value by a small amount. 
40 | """ 41 | 42 | def mutate(self, dna, step): 43 | del step 44 | dna.value = dna.value + random.uniform(-0.01, 0.01) 45 | if dna.value < 0.0: 46 | dna.value = 0.0 47 | if dna.value > 1.0: 48 | dna.value = 1.0 49 | return dna 50 | 51 | 52 | def trivial_reward(example): 53 | """Reward for the trivial search space. 54 | 55 | The reward (i.e. fitness) is the value itself. The goal of the search, 56 | therefore, is to find the value 1. 57 | 58 | Args: 59 | example: a materialized value. 60 | 61 | Returns: 62 | The corresponding reward. 63 | """ 64 | return example 65 | 66 | 67 | def get_trivial_hash(search_space, algo): 68 | hashed_value = 0 69 | for example, _ in pg.iter(search_space, 30, algo): 70 | hashed_value ^= int(example * 1000000) 71 | return hashed_value 72 | 73 | 74 | class HillClimbingTest(unittest.TestCase): 75 | 76 | def test_integration(self): 77 | # Set up search space. 78 | search_space = get_trivial_search_space() 79 | 80 | # Search algorithm. 81 | algo = hill_climb.hill_climb(mutator=TrivialMutator(), batch_size=1) 82 | 83 | # Search. 84 | best_reward = None 85 | iters = 0 86 | start_time = time.time() 87 | while True: 88 | for example, feedback in pg.iter(search_space, 500, algo): 89 | reward = trivial_reward(example) 90 | feedback(reward) 91 | if best_reward is None or reward > best_reward: 92 | best_reward = reward 93 | iters += 1 94 | if reward >= 1.0: 95 | break 96 | if reward >= 1.0: 97 | break 98 | if time.time() - start_time > 600.0: 99 | self.fail('Took too long to find a solution.') 100 | 101 | def test_permanence(self): 102 | search_space = get_trivial_search_space() 103 | algo = hill_climb.hill_climb(mutator=TrivialMutator(), batch_size=1, seed=0) 104 | 105 | # If a CL causes the following assert to fail, it means that the CL is 106 | # causing a difference in the behavior of the evolutionary algorithms. If 107 | # this is expected (e.g. a change in the random number generator), then 108 | # simply update the hash to the new value. 
109 | self.assertEqual(get_trivial_hash(search_space, algo), 789108) 110 | 111 | 112 | if __name__ == '__main__': 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /pyglove/ext/evolution/regularized_evolution.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Regularized Evolution: https://arxiv.org/abs/1802.01548.""" 15 | 16 | from typing import Optional 17 | 18 | import pyglove.core as pg 19 | from pyglove.ext.evolution import base 20 | from pyglove.ext.evolution import mutators 21 | from pyglove.ext.evolution import selectors 22 | 23 | 24 | def regularized_evolution( 25 | mutator=mutators.Uniform(), 26 | population_size: int = 100, 27 | tournament_size: int = 10, 28 | seed: Optional[int] = None): 29 | """Regularized Evolution algorithm. 30 | 31 | https://www.aaai.org/ojs/index.php/AAAI/article/view/4405. 32 | 33 | Args: 34 | mutator: Mutator to use. 35 | population_size: Population size. Must be no less than `tournament_size`. 36 | tournament_size: Tournament size. Must be no less than 2. 37 | seed: Random seed. If None, the current system time is used as seed. 38 | 39 | Returns: 40 | An `Evolution` object. 41 | """ 42 | if tournament_size < 2: 43 | raise ValueError( 44 | f'`tournament_size` must be no less than 2. 
' 45 | f'Encountered: {tournament_size}') 46 | if population_size < tournament_size: 47 | raise ValueError( 48 | f'The value of `population_size` ({population_size}) must be no ' 49 | f'less than the value of `tournament_size` ({tournament_size}).') 50 | return base.Evolution( 51 | selectors.Random( 52 | tournament_size, seed=seed) >> selectors.Top(1) >> mutator, 53 | population_init=(pg.geno.Random(seed=seed), population_size), 54 | population_update=selectors.Last(population_size)) 55 | -------------------------------------------------------------------------------- /pyglove/ext/evolution/regularized_evolution_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for evolutionary algorithms.""" 15 | 16 | import random 17 | import time 18 | import unittest 19 | 20 | import pyglove.core as pg 21 | from pyglove.ext.evolution import base 22 | from pyglove.ext.evolution import regularized_evolution_lib as regularized_evolution 23 | 24 | 25 | def get_trivial_search_space(): 26 | """Trivial search space. 27 | 28 | Each point in the space is a value in [0, 1]. 29 | 30 | Returns: 31 | A tunable value. 
def get_trivial_search_space():
  """Trivial search space.

  Each point in the space is a value in [0, 1].

  Returns:
    A tunable value.
  """
  return pg.floatv(0.0, 1.0)


@pg.members([
    ('seed', pg.typing.Int().noneable()),
])
class TrivialMutator(base.Mutator):
  """Mutator for the trivial search space.

  Each mutation shifts the value by at most 0.01 and clamps the result to the
  bounds of the trivial search space, [0.0, 1.0].
  """

  def _on_bound(self):
    super()._on_bound()
    # Without a seed, fall back to the module-level random generator.
    self._random = random if self.seed is None else random.Random(self.seed)

  def mutate(self, dna, step):
    """Returns a DNA whose value is a small perturbation of `dna.value`."""
    del step
    shifted = dna.value + self._random.uniform(-0.01, 0.01)
    # Clamp to the search space bounds so the child DNA stays within
    # `pg.floatv(0.0, 1.0)`.
    return pg.DNA(min(1.0, max(0.0, shifted)))


def trivial_reward(example):
  """Reward for the trivial search space.

  The reward (i.e. fitness) is the value itself. The goal of the search,
  therefore, is to find the value 1.

  Args:
    example: a materialized value.

  Returns:
    The corresponding reward.
  """
  return example


def get_trivial_hash(search_space, algo):
  """XORs the first 30 examples (scaled by 1e6) proposed by `algo`."""
  hashed_value = 0
  for example, feedback in pg.iter(search_space, 30, algo):
    hashed_value ^= int(example * 1000000)
    feedback(example)
  return hashed_value


class RegularizedEvolutionTest(unittest.TestCase):
  """Integration and permanence tests for regularized evolution."""

  def test_integration(self):
    # Set up search space.
    search_space = get_trivial_search_space()

    # Search algorithm.
    algo = regularized_evolution.regularized_evolution(
        population_size=10, tournament_size=2, mutator=TrivialMutator())

    # Search in batches of 100 examples until the reward reaches 1.0; the
    # time limit is checked between batches.
    start_time = time.time()
    solved = False
    while not solved:
      for example, feedback in pg.iter(search_space, 100, algo):
        reward = trivial_reward(example)
        feedback(reward)
        if reward >= 1.0:
          solved = True
          break
      if not solved and time.time() - start_time > 300.0:
        self.fail('Took too long to find a solution.')

  def test_permanence(self):
    search_space = get_trivial_search_space()
    algo = regularized_evolution.regularized_evolution(
        population_size=10, tournament_size=2, mutator=TrivialMutator(seed=1),
        seed=1)

    # If a CL causes the following assert to fail, it means that the CL is
    # causing a difference in the behavior of the evolutionary algorithms. If
    # this is expected (e.g. a change in the random number generator), then
    # simply update the hash to the new value.
    self.assertEqual(get_trivial_hash(search_space, algo), 385892)
class DecisionPointFiltersTest(unittest.TestCase):
  """Tests for DecisionPointFilter subclasses."""

  @staticmethod
  def _float_points():
    """Returns five float decision points with max values 1.0 through 5.0."""
    return [pg.geno.Float(0., float(i + 1)) for i in range(5)]

  def test_lambda(self):
    first_only = where.Lambda(lambda xs: xs[:1])
    selected = first_only(self._float_points())
    self.assertEqual(len(selected), 1)

  def test_all(self):
    points = self._float_points()
    # `where.ALL` hands back the very same list object.
    self.assertIs(points, where.ALL(points))

  def test_any(self):
    points = self._float_points()

    selected = where.Any(seed=1)(points)
    self.assertEqual(len(selected), 1)
    self.assertEqual(selected[0].max_value, 2.0)

    selected = where.Any(4, seed=2)(points)
    self.assertEqual(len(selected), 4)
    # Check output is sorted.
    self.assertEqual([p.max_value for p in selected], [1., 2., 4., 5.])

    # Asking for more points than available returns the input list itself.
    selected = where.Any(10, seed=1)(points)
    self.assertIs(points, selected)

  def test_automatic_conversion(self):

    @pg.members([
        ('where', where.where_spec())
    ])
    class MyFilter(pg.Object):
      pass

    converted = MyFilter(lambda xs: xs[:2])
    self.assertIsInstance(converted.where, where.Lambda)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Step-based scalars used as hyper-parameters for PyGlove algorithms.""" 15 | 16 | # pylint: disable=g-bad-import-order 17 | 18 | # Helper methods for creating and retrieving scalars. 19 | from pyglove.ext.scalars.base import scalar_spec 20 | from pyglove.ext.scalars.base import scalar_value 21 | from pyglove.ext.scalars.base import make_scalar 22 | 23 | # Interface and common scalars. 24 | from pyglove.ext.scalars.base import Scalar 25 | 26 | from pyglove.ext.scalars.base import Lambda 27 | from pyglove.ext.scalars.base import Constant 28 | from pyglove.ext.scalars.base import STEP 29 | 30 | from pyglove.ext.scalars.base import UnaryOp 31 | from pyglove.ext.scalars.base import Negation 32 | from pyglove.ext.scalars.base import Absolute 33 | from pyglove.ext.scalars.base import Floor 34 | from pyglove.ext.scalars.base import Ceiling 35 | 36 | from pyglove.ext.scalars.base import BinaryOp 37 | from pyglove.ext.scalars.base import Addition 38 | from pyglove.ext.scalars.base import Substraction 39 | from pyglove.ext.scalars.base import Multiplication 40 | from pyglove.ext.scalars.base import Division 41 | from pyglove.ext.scalars.base import Mod 42 | from pyglove.ext.scalars.base import Power 43 | 44 | # Common math functions. 
45 | from pyglove.ext.scalars.maths import linear 46 | from pyglove.ext.scalars.maths import cosine_decay 47 | from pyglove.ext.scalars.maths import exponential_decay 48 | from pyglove.ext.scalars.maths import cyclic 49 | from pyglove.ext.scalars.maths import sqrt 50 | from pyglove.ext.scalars.maths import exp 51 | from pyglove.ext.scalars.maths import log 52 | from pyglove.ext.scalars.maths import cos 53 | from pyglove.ext.scalars.maths import sin 54 | 55 | from pyglove.ext.scalars.maths import SquareRoot 56 | from pyglove.ext.scalars.maths import Exp 57 | from pyglove.ext.scalars.maths import Log 58 | from pyglove.ext.scalars.maths import Cosine 59 | from pyglove.ext.scalars.maths import Sine 60 | 61 | # Common random scalars. 62 | from pyglove.ext.scalars.randoms import RandomScalar 63 | from pyglove.ext.scalars.randoms import Uniform 64 | from pyglove.ext.scalars.randoms import Triangular 65 | from pyglove.ext.scalars.randoms import Gaussian 66 | from pyglove.ext.scalars.randoms import Normal 67 | from pyglove.ext.scalars.randoms import LogNormal 68 | 69 | # Step-wise scalar. 70 | from pyglove.ext.scalars.step_wise import StepWise 71 | 72 | # pylint: enable=g-bad-import-order 73 | -------------------------------------------------------------------------------- /pyglove/ext/scalars/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The PyGlove Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BasicScalarTest(unittest.TestCase):
  """Test basic scalars."""

  def test_make_scalar(self):
    from_scalar = scalars.make_scalar(scalars.Constant(1))
    self.assertIsInstance(from_scalar, scalars.Scalar)
    self.assertEqual(from_scalar(0), 1)
    self.assertEqual(from_scalar(10), 1)

    from_number = scalars.make_scalar(1)
    self.assertIsInstance(from_number, scalars.Scalar)
    self.assertIsInstance(from_number(0), int)
    self.assertEqual(from_number(0), 1)
    self.assertEqual(from_number(10), 1)

    from_callable = scalars.make_scalar(lambda step: step)
    self.assertIsInstance(from_callable, scalars.Scalar)
    self.assertEqual(from_callable(1), 1)
    self.assertEqual(from_callable(10), 10)

  def test_step(self):
    doubled_step = scalars.STEP * 2
    self.assertEqual(doubled_step(0), 0)
    self.assertEqual(doubled_step(10), 20)


class UnaryOpTest(unittest.TestCase):
  """Tests for unary scalar operators."""

  def test_negation(self):
    negated_step = -scalars.STEP
    self.assertEqual(negated_step(1), -1)
    self.assertEqual(negated_step(2), -2)

  def test_floor(self):
    floored = scalars.Constant(1.6).floor()
    self.assertEqual(floored(0), 1)

  def test_ceil(self):
    ceiled = scalars.Constant(1.6).ceil()
    self.assertEqual(ceiled(0), 2)

  def test_abs(self):
    absolute = abs(scalars.Constant(-1))
    self.assertEqual(absolute(0), 1)


class BinaryOpTest(unittest.TestCase):
  """Tests for binary scalar operators.

  Every test checks the operator with the scalar on the left-hand side and
  then on the right-hand side (reflected operator).
  """

  def _assert_values_at_step_zero(self, cases):
    """Asserts that each `(scalar, expected)` pair matches at step 0."""
    for scalar, expected in cases:
      self.assertEqual(scalar(0), expected)

  def test_add(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(1) + 2, 3),
        (2 + scalars.Constant(1), 3),
    ])

  def test_substract(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(1) - 2, -1),
        (2 - scalars.Constant(1), 1),
    ])

  def test_multiply(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(1) * 2, 2),
        (2 * scalars.Constant(1), 2),
    ])

  def test_divide(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(1) / 2, 0.5),
        (2 / scalars.Constant(1), 2),
    ])

  def test_floor_divide(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(1) // 2, 0),
        (2 // scalars.Constant(1), 2),
    ])

  def test_mod(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(2) % 3, 2),
        (3 % scalars.Constant(2), 1),
    ])

  def test_power(self):
    self._assert_values_at_step_zero([
        (scalars.Constant(2) ** 3, 8),
        (3 ** scalars.Constant(2), 9),
    ])
class SquareRoot(base.UnaryOp):
  """Scalar that applies `math.sqrt` to its operand."""

  def operate(self, x: float) -> float:
    return math.sqrt(x)

sqrt = SquareRoot   # pylint: disable=invalid-name


@pg.members([])
class Exp(base.UnaryOp):
  """Scalar that applies `math.exp`, more accurate than `math.e ** x`."""

  def operate(self, x: float) -> float:
    return math.exp(x)

exp = Exp   # pylint: disable=invalid-name


@pg.members([
    ('x', base.scalar_spec(pg.typing.Union([
        pg.typing.Int(min_value=2),
        pg.typing.Float(min_value=0.0)]))),
    ('base', base.scalar_spec(pg.typing.Union([
        pg.typing.Int(min_value=2),
        pg.typing.Float(min_value=0.0)])).set_default(math.e),
     'Base of the log function.'),
])
class Log(base.Scalar):
  """Scalar that computes the logarithm of `x` in the given `base`.

  Both `x` and `base` may themselves be scalars; they are evaluated at the
  requested step before the logarithm is taken.
  """

  def _on_bound(self):
    super()._on_bound()
    # Normalize both operands into scalars so they can be called with a step.
    self._x = base.make_scalar(self.x)
    self._base = base.make_scalar(self.base)

  def call(self, step: int) -> float:
    operand = self._x(step)
    log_base = self._base(step)
    return math.log(operand, log_base)

log = Log   # pylint: disable=invalid-name


@pg.members([])
class Cosine(base.UnaryOp):
  """Scalar that applies `math.cos` to its operand."""

  def operate(self, x: float) -> float:
    return math.cos(x)

cos = Cosine   # pylint: disable=invalid-name


@pg.members([])
class Sine(base.UnaryOp):
  """Scalar that applies `math.sin` to its operand."""

  def operate(self, x: float) -> float:
    return math.sin(x)

sin = Sine   # pylint: disable=invalid-name
def linear(total_steps: int, start: float = 1.0, end: float = 0.0):
  """Returns a linear scalar from start to end.

  The value at a step is `start + step * (end - start) / total_steps`, so it
  equals `start` at step 0 and `end` at step `total_steps`.
  """
  slope = (end - start) / total_steps
  return start + base.STEP * slope


def cosine_decay(total_steps: int, start: float = 1.0, end: float = 0.0):
  """Returns a cosine decayed scalar from start to end.

  The value at a step is
  `0.5 * (start - end) * (1 + cos(pi * step / total_steps)) + end`.
  """
  half_range = 0.5 * (start - end)
  angle = math.pi * base.STEP / total_steps
  return half_range * (1 + cos(angle)) + end


def exponential_decay(
    decay_rate: float, decay_interval: int,
    start: float = 1.0, staircase: bool = True):
  """Returns a scalar that decays exponentially from `start`.

  The value at a step is `start * decay_rate ** (step / decay_interval)`.
  When `staircase` is True the exponent is floored, so the value only changes
  once every `decay_interval` steps.
  """
  elapsed_intervals = base.STEP / float(decay_interval)
  exponent = elapsed_intervals.floor() if staircase else elapsed_intervals
  return start * (decay_rate ** exponent)


def cyclic(cycle: int, initial_radiant: float = 0.0,
           high: float = 1.0, low: float = 0.0):
  """Returns a cyclic scalar that oscillates between `low` and `high`.

  The value at a step is
  `0.5 * (high - low) * (1 + cos(initial_radiant + 2 * pi * step / cycle))
  + low`, i.e. a cosine wave that repeats every `cycle` steps.
  """
  half_range = 0.5 * (high - low)
  angle = initial_radiant + math.pi * 2 * base.STEP / cycle
  return half_range * (1 + cos(angle)) + low
class RandomScalarsTest(unittest.TestCase):
  """Random scalars tests."""

  def _assert_draws(self, scalar, expected_values):
    """Asserts the values of successive evaluations of `scalar` at step 0."""
    for expected in expected_values:
      self.assertEqual(scalar(0), expected)

  def test_uniform(self):
    self._assert_draws(
        scalars.Uniform(seed=1),
        [0.13436424411240122, 0.8474337369372327])
    self._assert_draws(scalars.Uniform(1, 10, seed=1), [3, 10])

    with self.assertRaisesRegex(
        ValueError,
        '`low` must be less or equal than `high`.'):
      scalars.Uniform(10, 1, seed=1)

  def test_triangular(self):
    self._assert_draws(
        scalars.Triangular(0.0, 1.0, 0.9, seed=1),
        [0.34774677525630787, 0.8733214547023962])
    self._assert_draws(scalars.Triangular(10, 20, seed=1), [12, 17])

  def test_gaussian(self):
    self._assert_draws(
        scalars.Gaussian(1.0, 0.2, seed=1),
        [1.2576369506310927, 1.2898891217399542])

  def test_normal(self):
    self._assert_draws(
        scalars.Normal(1.0, 0.2, seed=1),
        [1.1214911715287412, 0.997154910897843])

  def test_log_normal(self):
    self._assert_draws(
        scalars.LogNormal(1.0, 0.2, seed=1),
        [3.0694278358084994, 2.710559065635824])
@pg.members([
    ('phases', pg.typing.List(
        pg.typing.Tuple([
            pg.typing.Union([pg.typing.Int(min_value=0),
                             pg.typing.Float(min_value=0.)]),
            base.scalar_spec(pg.typing.Any())
        ]), min_size=1),
     ('All the phases in the schedule. Each item in the list is a tuple of '
      '`(length of phase, scheduled value)`. The length of phase can be an '
      'integer representing number of steps used for that phase, or a float as '
      'the proportion of that phase if `total_steps` is specified. All items '
      'in the list should use the same type (integer or float) for the length '
      'of phase. When a proportion is used, their sum does not have to sum up '
      'to 1.')),
    ('total_steps', pg.typing.Int(min_value=1).noneable(),
     ('Total number of steps for the schedule. If None, the length of each '
      'phase must be an integer.'))
])
class StepWise(base.Scalar):
  """A step-wise schedule that is specified via multiple phases.

  Each phase owns a contiguous range of steps and a scalar. The scalar of a
  phase is evaluated with the step counted from the start of that phase.
  After the last phase has ended, the most recently computed value is
  returned for all further steps.

  The schedule keeps track of the current phase, so steps are expected to be
  requested in non-decreasing order. Phases of length zero are skipped.
  """

  def _on_bound(self):
    """Validates the phase lengths and computes where each phase ends.

    Raises:
      ValueError: If a float phase length is used without `total_steps`, if
        an integer phase length is used with `total_steps`, or if all
        proportions sum to 0.
    """
    super()._on_bound()

    last_step = 0
    phase_ending_steps = []
    if self.total_steps is None:
      for phase_len, phase_value in self.phases:
        if isinstance(phase_len, float):
          raise ValueError(
              f'`total_steps` must be specified when float is used as the '
              f'value for phase length. '
              f'Encountered: ({phase_len}, {phase_value}).')
        last_step += phase_len
        phase_ending_steps.append(last_step - 1)
    else:
      proportion_sum = 0.
      for proportion, phase_value in self.phases:
        if isinstance(proportion, int):
          raise ValueError(
              f'The phase length should be a float as a proportion of the '
              f'entire schedule when `total_steps` is specified. '
              f'Encountered: ({proportion}, {phase_value}).')
        proportion_sum += proportion

      if proportion_sum == 0:
        raise ValueError(
            f'The sum of all proportions must be greater than 0. '
            f'Encountered: {self.phases!r}')

      for proportion, _ in self.phases:
        # Truncation means the phase lengths may add up to fewer steps than
        # `total_steps`, and a small proportion may yield a zero-length phase.
        phase_len = int(proportion / proportion_sum * self.total_steps)
        last_step += phase_len
        phase_ending_steps.append(last_step - 1)

    # Phase ending step is the step AFTER which the next phase will start.
    # A zero-length phase ends one step before it starts.
    self._phase_ending_steps = phase_ending_steps
    self._phases = [base.make_scalar(p) for _, p in self.phases]
    self._current_phase = 0
    self._last_value = None
    self._evaluated = False

  def _phase_start(self, phase: int) -> int:
    """Returns the first step that belongs to the phase at index `phase`."""
    if phase == 0:
      return 0
    return self._phase_ending_steps[phase - 1] + 1

  def call(self, step: int) -> Any:
    """Returns the scheduled value at `step`."""
    num_phases = len(self._phases)

    # Leave every phase that ended before `step`. Comparing with `>` instead
    # of advancing only on an exact match keeps zero-length phases (whose
    # ending step is never requested) from stalling the schedule, and also
    # copes with steps that are not requested one by one.
    while (self._current_phase < num_phases
           and step > self._phase_ending_steps[self._current_phase]):
      self._current_phase += 1

    if self._current_phase < num_phases:
      phase_step = step - self._phase_start(self._current_phase)
      self._last_value = self._phases[self._current_phase](phase_step)
      self._evaluated = True
    elif not self._evaluated:
      # The schedule is already over but no value has been produced yet (the
      # first requested step is past the end, or every phase has zero
      # length): fall back to the final phase at its last step.
      final_phase = num_phases - 1
      final_step = min(step, self._phase_ending_steps[final_phase])
      phase_step = max(0, final_step - self._phase_start(final_phase))
      self._last_value = self._phases[final_phase](phase_step)
      self._evaluated = True
    return self._last_value
class StepWiseScalarTest(unittest.TestCase):
  """Test step-wise scalar schedule."""

  # Expected values for steps 0..9 of the three-phase schedules below.
  # For each phase, base.STEP is evaluated to 0 when the phase starts.
  _EXPECTED_VALUES = [
      1, 1,     # Phase 1
      0, 1,     # Phase 2
      0, 1, 4,  # Phase 3
      4, 4, 4   # Use the last value for the rest.
  ]

  def test_step_as_phase_length(self):
    schedule = scalars.StepWise([
        (2, 1),
        (2, base.STEP),
        (3, base.STEP ** 2)
    ])
    values = [schedule(step) for step in range(10)]
    self.assertEqual(values, self._EXPECTED_VALUES)

  def test_proportion_as_phase_length(self):
    schedule = scalars.StepWise([
        (0.2, 1),
        (0.2, base.STEP),
        (0.3, base.STEP ** 2)
    ], total_steps=8)
    values = [schedule(step) for step in range(10)]
    self.assertEqual(values, self._EXPECTED_VALUES)

  def test_bad_specification(self):
    # Float phase lengths require `total_steps`.
    with self.assertRaisesRegex(
        ValueError,
        '`total_steps` must be specified when float is used as the value'):
      _ = scalars.StepWise([
          (0.2, 1),
          (0.2, base.STEP),
          (0.3, base.STEP ** 2)])

    # Proportions must not all be zero.
    with self.assertRaisesRegex(
        ValueError,
        'The sum of all proportions must be greater than 0'):
      _ = scalars.StepWise([
          (0.0, 1),
          (0.0, base.STEP),
          (0.0, base.STEP ** 2)
      ], total_steps=10)

    # Integer phase lengths cannot be combined with `total_steps`.
    with self.assertRaisesRegex(
        ValueError,
        'The phase length should be a float as a proportion of the'):
      _ = scalars.StepWise([
          (1, 1),
          (2, base.STEP),
          (3, base.STEP ** 2)
      ], total_steps=10)
def _get_version():
  """Gets current version of PyGlove package.

  The version is taken from the first line of `pyglove/__init__.py` that
  starts with `__version__`. When `--nightly` is present in `sys.argv`, a
  `.dev<YYYYmmddHHMM>` suffix built from the current time is appended and the
  flag is removed from `sys.argv`.

  Returns:
    The version string.

  Raises:
    ValueError: If no line of `pyglove/__init__.py` starts with
      `__version__`.
  """
  version = None
  with open('pyglove/__init__.py') as init_file:
    for source_line in init_file:
      if not source_line.startswith('__version__'):
        continue
      # Execute only the assignment line to obtain the version value.
      namespace = {}
      exec(source_line, namespace)  # pylint: disable=exec-used
      version = namespace['__version__']
      break

  if version is None:
    raise ValueError('`__version__` not defined in `pyglove/__init__.py`')

  if '--nightly' in sys.argv:
    nightly_label = datetime.datetime.now().strftime('%Y%m%d%H%M')
    version = f'{version}.dev{nightly_label}'
    # Remove the custom flag from the process arguments once it is handled.
    sys.argv.remove('--nightly')
  return version
def _parse_requirements(requirements_txt_path: str) -> list[str]:
  """Returns a list of dependencies for setup() from requirements.txt.

  Args:
    requirements_txt_path: Path of the requirements file to read.

  Returns:
    One requirement string per non-empty line, with `#` comments stripped and
    lines mentioning `github.com` dropped.
  """

  def _strip_comments_from_line(s: str) -> str:
    """Returns the part of a requirements line before any `#` comment."""
    requirement, *_ = s.split('#')
    return requirement.strip()

  # Currently a requirements.txt is being used to specify dependencies. In order
  # to avoid specifying it in two places, we're going to use that file as the
  # source of truth.
  with open(requirements_txt_path) as fp:
    # Strip comments.
    lines = [_strip_comments_from_line(line) for line in fp.read().splitlines()]
  # Remove empty lines and direct github repos (not allowed in PyPI setups)
  return [
      requirement for requirement in lines
      if requirement and 'github.com' not in requirement
  ]


_VERSION = _get_version()

# Read the README through a context manager so the file handle is closed, and
# decode it explicitly as UTF-8 instead of the platform default encoding.
with open('README.md', encoding='utf-8') as _readme_file:
  _LONG_DESCRIPTION = _readme_file.read()

setup(
    name='pyglove',
    version=_VERSION,
    url='https://github.com/google/pyglove',
    license='Apache License 2.0',
    author='PyGlove Authors',
    description='PyGlove: A library for manipulating Python objects.',
    long_description=_LONG_DESCRIPTION,
    long_description_content_type='text/markdown',
    author_email='pyglove-authors@google.com',
    # Contained modules and scripts.
    packages=find_namespace_packages(include=['pyglove*']),
    install_requires=_parse_requirements('requirements.txt'),
    extras_require={},
    # The setuptools keyword is `python_requires`; the previous
    # `requires_python` is not a recognized option, so the version constraint
    # never reached the package metadata.
    python_requires='>=3.9',
    include_package_data=True,
    # PyPI package information.
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.9',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Scientific/Engineering :: Human Machine Interfaces',
        'Topic :: Software Development :: Code Generators',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Topic :: Software Development :: Libraries',
    ],
    keywords=(
        'ai machine learning automl mutable symbolic '
        'framework meta-programming'),
)