├── .github └── workflows │ ├── docs_publish.yaml │ └── tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── docs └── source │ ├── about.md │ ├── conf.py │ ├── contributing.md │ ├── examples.md │ ├── examples_pbe.md │ ├── grammars.md │ ├── images │ └── pipeline.png │ ├── index.rst │ ├── introduction.md │ ├── license.md │ ├── prediction.md │ ├── sharpening.md │ ├── tutorial.md │ ├── type_system.md │ └── usage.rst ├── examples ├── README.md ├── compare_enumeration.py ├── pbe │ ├── README.md │ ├── analysis │ │ ├── dsl_analyzer.py │ │ └── json_comparator.py │ ├── calculator │ │ ├── calculator.py │ │ ├── convert_calculator.py │ │ └── dataset │ │ │ └── calculator_dataset.json │ ├── convert_to_psl.py │ ├── dataset_explorer.py │ ├── dataset_generator.py │ ├── dataset_generator_unique.py │ ├── dataset_improve.py │ ├── dataset_learner.py │ ├── dataset_loader.py │ ├── deepcoder │ │ ├── convert_deepcoder.py │ │ └── deepcoder.py │ ├── dreamcoder │ │ ├── convert_dreamcoder.py │ │ └── dreamcoder.py │ ├── dsl_equation_generator.py │ ├── dsl_loader.py │ ├── equivalence_classes_to_filter.py │ ├── evaluate_deprecated.py │ ├── karel │ │ ├── karel.py │ │ ├── karel_equation_generator.py │ │ └── karel_task_generator.py │ ├── model_embeddings_visualizer.py │ ├── model_loader.py │ ├── model_prediction.py │ ├── model_trainer.py │ ├── quantum_circuits │ │ ├── quantum.py │ │ └── quantum_tasks_generator.py │ ├── regexp │ │ ├── convert_regexp.py │ │ ├── evaluator_regexp.py │ │ ├── regexp.py │ │ ├── task_generator_regexp.py │ │ └── type_regex.py │ ├── solve.py │ └── transduction │ │ ├── convert_transduction.py │ │ ├── dataset │ │ ├── convert_flashfill.py │ │ └── flashfill.json │ │ ├── knowledge_graph │ │ ├── README.md │ │ ├── constants.json │ │ ├── convert_kg_json_tasks.py │ │ ├── fill_knowledge_graph.sparql │ │ ├── kg_path_finder.py │ │ └── preprocess_tasks.py │ │ ├── task_generator_transduction.py │ │ └── transduction.py ├── plot_enumeration_results.py ├── plot_helper.py └── plot_solve_results.py ├── images └── logo.png ├── mypy.ini ├── poetry.lock ├── pyproject.toml ├── setup.py ├── synth ├── __init__.py ├── filter │ ├── __init__.py │ ├── constraints │ │ ├── __init__.py │ │ ├── dfta_constraints.py │ │ ├── parsing.py │ │ └── ttcfg_constraints.py │ ├── dfta_filter.py │ ├── filter.py │ ├── local_stateless_filter.py │ ├── obs_eq_filter.py │ └── syntactic_filter.py ├── generation │ ├── __init__.py │ └── sampler.py ├── library │ ├── __init__.py │ └── learning.py ├── nlp │ ├── __init__.py │ └── bert.py ├── nn │ ├── __init__.py │ ├── abstractions.py │ ├── det_grammar_predictor.py │ ├── spec_encoder.py │ ├── u_grammar_predictor.py │ └── utils.py ├── pbe │ ├── __init__.py │ ├── io_encoder.py │ ├── solvers │ │ ├── __init__.py │ │ ├── pbe_solver.py │ │ └── restart_pbe_solver.py │ └── task_generator.py ├── py.typed ├── semantic │ ├── __init__.py │ └── evaluator.py ├── specification.py ├── syntax │ ├── __init__.py │ ├── automata │ │ ├── __init__.py │ │ ├── dfa.py │ │ └── tree_automaton.py │ ├── dsl.py │ ├── grammars │ │ ├── __init__.py │ │ ├── cfg.py │ │ ├── det_grammar.py │ │ ├── enumeration │ │ │ ├── __init__.py │ │ │ ├── a_star.py │ │ │ ├── beap_search.py │ │ │ ├── bee_search.py │ │ │ ├── constant_delay.py │ │ │ ├── constant_delay_queue.py │ │ │ ├── grammar_splitter.py │ │ │ ├── heap_search.py │ │ │ ├── program_enumerator.py │ │ │ └── u_heap_search.py │ │ ├── grammar.py │ │ ├── tagged_det_grammar.py │ │ ├── tagged_u_grammar.py │ │ ├── ttcfg.py │ │ ├── u_cfg.py │ │ └── 
u_grammar.py │ ├── program.py │ ├── type_helper.py │ └── type_system.py ├── task.py └── utils │ ├── __init__.py │ ├── chrono.py │ ├── data_storage.py │ ├── generator_utils.py │ ├── import_utils.py │ ├── ordered.py │ └── vose_polyfill.py └── tests ├── filtering └── constraints │ ├── test_dfta_constraints.py │ ├── test_parsing.py │ └── test_ttcfg_constraints.py ├── generation └── test_sampler.py ├── nn ├── test_det_grammar_predictor.py └── test_u_grammar_predictor.py ├── pbe ├── solvers │ ├── test_pbe_solver.py │ └── test_restart_pbe_solver.py ├── test_io_encoder.py └── test_task_generator.py ├── semantic └── test_evaluator.py ├── syntax ├── automata │ └── test_tree_automaton.py ├── grammars │ ├── enumeration │ │ ├── test_a_star.py │ │ ├── test_beap_search.py │ │ ├── test_bee_search.py │ │ ├── test_constant_delay.py │ │ ├── test_grammar_splitter.py │ │ ├── test_heap_search.py │ │ └── test_u_heap_search.py │ ├── test_cfg.py │ ├── test_tagged_det_grammar.py │ ├── test_tagged_u_grammar.py │ ├── test_ttcfg.py │ └── test_ucfg.py ├── test_dsl.py ├── test_program.py ├── test_type_helper.py └── test_type_system.py ├── test_task.py ├── test_vose.py └── utils └── test_generator_utils.py /.github/workflows/docs_publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Documentation 2 | 3 | on: 4 | push: 5 | branches: [ 'main' ] 6 | pull_request: 7 | branches: [ main ] 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.9] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | cache: 'pip' # caching pip dependencies 22 | - name: Install Dependencies 23 | run: | 24 | pip install --upgrade pip 25 | pip install sphinx sphinx-rtd-theme myst-parser 26 | #---------------------------------------------- 27 | # Build documentation 28 | #---------------------------------------------- 29 | - name: Build documentation 30 | run: sphinx-build -b html docs/source docs/build 31 | 32 | #---------------------------------------------- 33 | # Clone documentation 34 | #---------------------------------------------- 35 | - uses: actions/checkout@v3 36 | with: 37 | ref: gh-pages 38 | path: pages 39 | #---------------------------------------------- 40 | # Move documentation 41 | #---------------------------------------------- 42 | - name: Move documentation 43 | run: | 44 | rm -r pages/docs/* 45 | mv -f docs/build/* pages/docs/ 46 | #---------------------------------------------- 47 | # Commit & Push changes 48 | #---------------------------------------------- 49 | - name: Commit and Push 50 | continue-on-error: true 51 | run: | 52 | cd pages 53 | git config user.name github-actions 54 | git config user.email github-actions@github.com 55 | git add -f docs/ 56 | git commit -m "Updated documentation" 57 | git push 58 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches-ignore: [ gh-pages ] 6 | pull_request: 7 | branches: [ main ] 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.9] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | 
python-version: ${{ matrix.python-version }} 21 | #---------------------------------------------- 22 | # ----- install & configure poetry ----- 23 | #---------------------------------------------- 24 | - name: Install Poetry 25 | uses: snok/install-poetry@v1 26 | with: 27 | virtualenvs-create: true 28 | virtualenvs-in-project: true 29 | installer-parallel: true 30 | #---------------------------------------------- 31 | # load cached venv if cache exists 32 | #---------------------------------------------- 33 | - name: Load cached venv 34 | id: cached-poetry-dependencies 35 | uses: actions/cache@v3 36 | with: 37 | path: .venv 38 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 39 | #---------------------------------------------- 40 | # install dependencies if cache does not exist 41 | #---------------------------------------------- 42 | - name: Install dependencies 43 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 44 | continue-on-error: true 45 | run: poetry install --no-interaction --no-root 46 | #---------------------------------------------- 47 | # install your root project, if required 48 | #---------------------------------------------- 49 | - name: Install library 50 | run: poetry install --no-interaction 51 | #---------------------------------------------- 52 | # type check, and test your code 53 | #---------------------------------------------- 54 | - name: MyPy 55 | run: | 56 | poetry run mypy synth 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac files 2 | .DS_store 3 | _site 4 | 5 | # VS code 6 | .vscode 7 | 8 | # Python files 9 | __pycache__ 10 | 11 | # Pytest 12 | .pytest_cache 13 | 14 | # Tensorboard runs 15 | runs 16 | 17 | # venv 18 | py3 19 | 20 | # documentation build 21 | docs/build 22 | 23 | # dataset files 24 | *.pickle 25 | 26 | # csv files 27 | *.csv 28 | 29 | # model files 30 | *.pt 31 | 32 | # VS Code Counter 33 | .VSCodeCounter 34 | 35 | # Parsing Sygus files 36 | examples/sygus/parsing/**/*.py -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | - repo: https://github.com/astral-sh/ruff-pre-commit 2 | # Ruff version. 3 | rev: v0.4.3 4 | hooks: 5 | # Run the linter. 6 | - id: ruff 7 | args: [ --fix ] 8 | # Run the formatter. 9 | - id: ruff-format -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DeepSynth2 2 | 3 | Feel free to open an issue or pull request if you have any questions or suggestions. 4 | If you plan to work on an issue, let us know in the issue thread so we can avoid duplicate work. 5 | 6 | Before attempting to push, make sure that the following holds: 7 | 8 | - No code formatting error, see [Code Formatting](#code-formatting); 9 | - No type error, see [Typing](#typing); 10 | - All tests pass, see [Testing](#testing). 11 | 12 | ## Dev Setup 13 | 14 | These development dependencies can be be found in [pyproject.toml](./pyproject.toml). 15 | If you are using a development install, they should be installed by default. Otherwise you need to install them manually. 16 | 17 | ## Code Formatting 18 | 19 | We use [Black](https://black.readthedocs.io/en/stable/) for code formatting. 
Th exact version used can be found in [pyproject.toml](./pyproject.toml). 20 | You can run the following to format all files: 21 | 22 | ```bash 23 | black . 24 | ``` 25 | 26 | ## Typing 27 | 28 | We use [mypy](http://mypy-lang.org/) to check typing. We require you to use type hints at all times. That means for all function signatures and all places where `mypy` can't deduce the full type, type hints should be placed. 29 | You can check if there are no typing errors with: 30 | 31 | ```bash 32 | mypy synth 33 | ``` 34 | 35 | ## Testing 36 | 37 | We use [Pytest](https://docs.pytest.org/en/latest/). 38 | Please ensure a few things: 39 | 40 | - When adding a new feature, also add relevant tests. 41 | - Tests should be deterministic. If your test depends on randomness, do not forget to seed. 42 | - No test should fail when you commit. 43 | 44 | Finally, you can run the tests with: 45 | 46 | ```bash 47 | pytest . 48 | ``` -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Nathanaël FIJALKOW & Théo MATRICON 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/source/about.md: -------------------------------------------------------------------------------- 1 | About 2 | === 3 | 4 | ```{include} ../../README.md 5 | --- 6 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 |
3 | # -- Project information
4 |
5 | project = "ProgSynth"
6 | copyright = "2023, Nathanaël Fijalkow & Théo Matricon"
7 | author = "Nathanaël Fijalkow & Théo Matricon"
8 |
9 | release = "0.1"
10 | version = "0.1.0"
11 |
12 | # -- General configuration
13 |
14 | extensions = [
15 |     "sphinx.ext.duration",
16 |     "sphinx.ext.doctest",
17 |     "sphinx.ext.autodoc",
18 |     "sphinx.ext.autosummary",
19 |     "sphinx.ext.intersphinx",
20 |     "sphinx.ext.autosectionlabel",
21 |     "myst_parser",
22 | ]
23 |
24 | intersphinx_mapping = {
25 |     "python": ("https://docs.python.org/3/", None),
26 |     "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
27 | }
28 | intersphinx_disabled_domains = ["std"]
29 |
30 | templates_path = ["_templates"]
31 |
32 | # -- Options for HTML output
33 |
34 | html_theme = "sphinx_rtd_theme"
35 |
36 | # -- Options for EPUB output
37 | epub_show_urls = "footnote"
38 | source_suffix = {
39 |     ".rst": "restructuredtext",
40 |     ".txt": "markdown",
41 |     ".md": "markdown",
42 | }
43 | # Make sure the target is unique
44 | autosectionlabel_prefix_document = True
45 |
--------------------------------------------------------------------------------
/docs/source/contributing.md:
--------------------------------------------------------------------------------
1 | ```{include} ../../CONTRIBUTING.md
2 | ---
3 |
--------------------------------------------------------------------------------
/docs/source/examples.md:
--------------------------------------------------------------------------------
1 | ```{include} ../../examples/README.md
2 | ---
3 |
--------------------------------------------------------------------------------
/docs/source/examples_pbe.md:
--------------------------------------------------------------------------------
1 | ```{include} ../../examples/pbe/README.md
2 | ---
3 |
--------------------------------------------------------------------------------
/docs/source/grammars.md:
--------------------------------------------------------------------------------
1 | # Grammars
2 |
3 | Grammars are program generators; all our grammars are finite.
4 | Typically, when you instantiate a grammar you always specify a maximum depth.
5 |
6 | The main objects of interest are probabilistic grammars, on which most methods to enumerate and sample programs are provided.
7 |
8 |
9 | Table of contents:
10 |
11 | - [Grammar Models](#grammar-models)
12 |   - [det-CFG](#det-cfg)
13 |   - [U-CFG](#u-cfg)
14 | - [Probabilistic Grammars](#probabilistic-grammars)
15 |
16 |
17 |
18 | ## Grammar Models
19 |
20 | Currently the only grammar model supported is [context-free grammars](https://en.wikipedia.org/wiki/Context-free_grammar) (CFG).
21 | All our rules have the following form:
22 |
23 | ```
24 | S -> f S1 S2 ... Sk
25 | S -> g
26 | ```
27 |
28 | where ``S``, ``S1``, ..., ``Sk`` are non-terminals, ``f`` is a primitive of arity ``k`` and ``g`` is a primitive of arity 0, in other words a constant.
29 |
30 | We have two different models: deterministic CFGs and unambiguous CFGs; the latter is more expressive but around 20% slower, and used correctly the gains are huge.
31 |
32 | The main way to generate a grammar is through static methods such as ``MyGrammarModel.depth_constraint(dsl, type_request)``.
33 | Grammars, albeit already complex objects, are not the final objects of interest in ProgSynth.
34 | The most relevant methods are:
35 |
36 | - ``program in grammar`` which returns whether the program belongs to the grammar or not;
37 | - ``grammar.programs()`` which returns the number of programs contained in the grammar; do not convert it to a float as this easily yields values over MAX_DOUBLE, hence we return an int to take advantage of Python's unbounded integers;
38 | - ``grammar.derive(...)`` which allows you to derive your program step by step;
39 | - ``grammar.derive_all(...)`` which derives the whole given subtree for you and hands you the result;
40 | - ``grammar.reduce_derivations(...)`` which is like a fold over the derivation steps of the given program.
41 |
42 | ### det-CFG
43 |
44 | A CFG which has the following property:
45 | > For a given non-terminal ``S``, for any primitive ``f``, there is at most one derivation from ``S`` using primitive ``f``
46 |
47 | In other words, it is deterministic to derive ``f`` from non-terminal ``S``.
48 |
49 | In ProgSynth this is the default model, that is ``CFG``.
50 | If you do not use [sharpening](sharpening.md) for example, then ProgSynth uses this model when producing a grammar.
51 |
52 | ### U-CFG
53 |
54 | A CFG which has the following property:
55 | > For a tree/program ``t``, there exists at most one derivation for tree/program ``t`` in the grammar
56 |
57 | In other words, there is no ambiguity when deriving a program from the grammar, but locally it may be ambiguous, that is, you may have to try all derivation rules for a primitive and only find out later which one allows deriving the program.
58 |
59 | ``UCFG`` in ProgSynth can express all regular tree languages and is generated when you use [sharpening](sharpening.md).
60 |
61 | ## Probabilistic Grammars
62 |
63 | We offer tagged grammars: grammars where derivations are tagged with a generic type; replacing 'probabilistic' with 'tagged' in what follows works as well.
64 | The most relevant case is when derivations are tagged with floats, giving you probabilistic grammars.
65 | > For a given non-terminal ``S``, the set of all derivations from ``S`` makes up a probability distribution, *i.e.* sums up to 1.
66 |
67 | There are two models: ``ProbGrammar`` and ``ProbUGrammar``, working respectively on ``CFG`` and ``UCFG``.
68 | Basically, adding a ``U`` to the class name and a ``u_`` prefix to the classic methods yields the equivalent methods for the unambiguous model.
69 |
70 | Probabilistic grammars offer a wide range of interesting methods to generate programs:
71 |
72 | - ``pgrammar.sample()`` samples a random program from the grammar; you first need to call ``pgrammar.init_sampling(seed)`` for sampling to work, and sampling is optimised compared to naive sampling;
73 | - ``enumerate_prob_(u_)grammar(pgrammar)`` which gives you an enumerator that will enumerate programs in the grammar by decreasing order of probability;
74 | - ``split(pgrammar, n)`` which gives you ``n`` disjoint probabilistic unambiguous grammars that make up a partition of the original given ``pgrammar``; the main interest is to easily parallelise the enumeration.
75 |
76 | Of course, since probabilistic grammars are grammars, they also offer the same methods as classic grammars.
77 |
78 | **But I want to enumerate programs by size?**
79 |
80 | Well, you can just use ``Prob(U)Grammar.uniform(grammar)``: enumerating that probabilistic grammar will give you an enumeration by program size, as in the sketch below.
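
Here is a minimal sketch putting these pieces together. It assumes you already have a ``dsl`` (for instance the calculator DSL from the examples) and uses an arbitrary ``"int -> int -> int"`` type request; the concrete names (``ProbDetGrammar``, ``enumerate_prob_grammar``) correspond to the ``Prob(U)Grammar`` / ``enumerate_prob_(u_)grammar`` shorthand used above, but check ``synth.syntax`` for the exact spellings and import paths in your version:

```python
from synth.syntax import CFG, ProbDetGrammar, auto_type, enumerate_prob_grammar

# `dsl` is assumed to be an already defined DSL, e.g. the calculator DSL from the examples
type_request = auto_type("int -> int -> int")
cfg = CFG.depth_constraint(dsl, type_request, 5)  # all programs of depth at most 5

# Uniform tags: enumerating this probabilistic grammar then follows program size
pcfg = ProbDetGrammar.uniform(cfg)

# Sampling requires an explicit initialisation with a seed
pcfg.init_sampling(42)
random_program = pcfg.sample()
print(random_program in cfg)  # membership test on the underlying grammar

for program in enumerate_prob_grammar(pcfg):
    # programs arrive by decreasing probability (with uniform tags: by increasing size)
    print(program)
    break
```

For the unambiguous model, the same sketch works with the ``U`` variants (``UCFG``, ``ProbUGrammar``, ``enumerate_prob_u_grammar``).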
81 |
--------------------------------------------------------------------------------
/docs/source/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthesisLab/DeepSynth2/3efba7bfc03d3e576d67fb433c5fd8d50e3ebdb5/docs/source/images/pipeline.png
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | ProgSynth, high-level framework for program synthesis
2 | =====================================================
3 |
4 | **ProgSynth**
5 | ProgSynth is a high-level framework that enables leveraging program synthesis for other domains such as reinforcement learning or system design.
6 |
7 | Contents
8 | ---------
9 | .. toctree::
10 |    about
11 |    introduction
12 |    usage
13 |    tutorial
14 |    type_system
15 |    grammars
16 |    Specifications
17 |    PBE
18 |    prediction
19 |    sharpening
20 |    contributing
21 |    license
22 |
--------------------------------------------------------------------------------
/docs/source/introduction.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | This page seeks to answer the following question:
4 | > How does ProgSynth work at a high level?
5 |
6 |
7 | Table of contents:
8 |
9 | - [What do you give in?](#what-do-you-give-in)
10 | - [What do you get out?](#what-do-you-get-out)
11 | - [How does that work?](#how-does-that-work)
12 |   - [Sharpening (Optional)](#sharpening-optional)
13 |   - [Compilation](#compilation)
14 |   - [Prediction (Optional in the future)](#prediction-optional-in-the-future)
15 |   - [Splitting (Optional)](#splitting-optional)
16 |   - [Enumeration](#enumeration)
17 |
18 |
19 |
20 | ## What do you give in?
21 |
22 | ProgSynth only needs two things: a language to work on and a check function.
23 |
24 | The provided language is called a domain specific language (DSL).
25 | You provide the different kinds of functions that can manipulate your data, and the framework automatically generates the associated grammar.
26 | Specifying a language with its types is quite fast; it can be done in a matter of minutes thanks to ProgSynth's helper functions, letting you focus on what really matters here: the synthesis.
27 |
28 | The check function is a Python function; it can contain whatever code you might need to check that a given program satisfies your constraints. An example of such a function is one that checks whether your program matches the given examples of input-output pairs.
29 |
30 | ## What do you get out?
31 |
32 | At the end you get a program that satisfies your condition, or, if all programs have been enumerated, the answer that there is no program in the given grammar matching your specification. In practice the negative answer is never given: since the number of programs grows exponentially, it would be infeasible to enumerate them all, so we recommend using a timeout after which the search is stopped.
33 |
34 | ## How does that work?
35 |
36 | This section attempts to give you a high-level overview of how ProgSynth works.
37 | Here is a figure that gives a rough overview of the pipeline:
38 | ![pipeline](./images/pipeline.png)
39 |
40 | ### Sharpening (Optional)
41 |
42 | Sharpening enables you to add syntactic constraints on the search space of programs, which can yield large speed-ups in the enumeration step.
43 | See [the page on sharpening](sharpening.md) for more details.
44 | If you want to add syntactic constraints to your grammar, then you need to write them and give them to ProgSynth in the compilation step.
45 |
46 | ### Compilation
47 |
48 | The language that you give is typed, and you also give at least a depth constraint to guarantee that programs have a maximum depth.
49 | Of course, you also specify the type of the program that you want to generate.
50 | This language is compiled into a context-free grammar (CFG).
51 |
52 | If you have specified constraints through [sharpening](sharpening.md), then the constraints and the grammar are compiled into deterministic bottom-up tree automata.
53 | The intersection is then computed and transformed back into a CFG.
54 |
55 | ### Prediction (Optional in the future)
56 |
57 | If you have trained a model to produce program distributions over grammars, then you can use this step; otherwise fear not, because this step will not be mandatory in the future.
58 |
59 | A model (often a neural network) takes as input your specification and whatever other information you trained it on, and then produces a vector which we translate into probabilities for our CFG.
60 | That means that we have a probabilistic CFG (PCFG).
61 |
62 | ### Splitting (Optional)
63 |
64 | You have multiple CPUs and you want to parallelize the search?
65 | Well, we have a ``split`` function that takes a PCFG and splits it into as many fragments as you want.
66 | Each fragment of the original PCFG is independent, that is, no other fragment contains its programs.
67 | In other words, with splitting you can get a speed-up in your enumeration that is linear in the number of CPUs.
68 |
69 | Do note that while splitting is provided by ProgSynth, we do not provide the parallelization framework; you will have to set that up yourself.
70 |
71 | ### Enumeration
72 |
73 | This PCFG actually gives us an order over programs: programs can be ordered by their probability.
74 | Therefore ProgSynth will enumerate programs in order of decreasing probability; that means that the most likely program will be enumerated first.
75 | All of the synthesis time is spent here, and in most cases the cost of calling your check function is what bounds the runtime.
76 |
77 | When a program is enumerated, we call your check function with the program as argument.
78 | If your function returns ``True`` then we stop and return the program, otherwise we continue.
79 |
--------------------------------------------------------------------------------
/docs/source/license.md:
--------------------------------------------------------------------------------
1 | License
2 | ===
3 |
4 | ```{include} ../../LICENSE.md
5 | ---
6 |
--------------------------------------------------------------------------------
/docs/source/type_system.md:
--------------------------------------------------------------------------------
1 | # The Type System
2 |
3 | The purpose of the type system in ProgSynth is to add constraints for the compilation of a DSL into a grammar.
4 |
5 | ProgSynth never checks that the data you manipulate has the correct type; types are only used at compilation time.
6 |
7 |
8 | Table of contents:
9 |
10 | - [Basic Type](#basic-type)
11 | - [Advanced Types](#advanced-types)
12 |   - [Sum Type](#sum-type)
13 |   - [Arrow](#arrow)
14 |   - [Generic](#generic)
15 | - [Polymorphic Types](#polymorphic-types)
16 | - [Methods of interest](#methods-of-interest)
17 |
18 |
19 |
20 | ## Basic Type
21 |
22 | Ground types are ``PrimitiveType``.
23 | An ``int`` is represented as ``PrimitiveType("int")``.
24 | Notice that it is uniquely identified by its name, therefore two instances with the same name represent the same type.
25 |
26 | ProgSynth already defines the following types: ``INT``, ``BOOL``, ``STRING``, ``UNIT``.
27 |
28 | ## Advanced Types
29 |
30 | Advanced types are built from other types.
31 | There are three:
32 |
33 | - Sum Type;
34 | - Arrow;
35 | - Generic.
36 |
37 | ### Sum Type
38 |
39 | Sum types are unions in Python.
40 | They can be built in two ways: with ``Sum(t1, t2)`` or with ``t1 | t2``.
41 |
42 | ### Arrow
43 |
44 | Arrows represent functions.
45 | They can be built with ``Arrow(t1, t2)``.
46 | A function ``int->int->int`` would have type ``Arrow(int, Arrow(int, int))``.
47 | An easier way to construct these arrows is to use ``FunctionType(int, int, int)`` in this case. While ``Arrow`` is a binary constructor, ``FunctionType`` allows any number of arguments and ensures that the ``Arrow`` objects are built correctly, especially if you are using higher-order functions.
48 |
49 | ### Generic
50 |
51 | Generics are parametric types. For example, the previous type ``Arrow`` is *almost* a generic, and the ``List`` type is one. You can make a list type out of any type using ``List(t)``.
52 | To instantiate a Generic builder you can use ``GenericFunctor``:
53 |
54 | ```python
55 | List = GenericFunctor("list", min_args=1, max_args=1)
56 | # Arrow behaves almost like a Generic defined the following way:
57 | Arrow = GenericFunctor(
58 |     "->",
59 |     min_args=2,
60 |     max_args=2,
61 |     infix=True,
62 | )
63 |
64 | ```
65 |
66 | ## Polymorphic Types
67 |
68 | One can instantiate polymorphic types with ``PolymorphicType("a")``; as with ``PrimitiveType``, they are uniquely identified by their name.
69 | At compilation, a polymorphic type will take as possible values any ground type that is present in the DSL, and advanced types built on top of them recursively up to some type size.
70 |
71 | You can limit the set of types a polymorphic type can take with ``FixedPolymorphicType("f", t1, t2, t3)`` which will only take types that are ``t1``, ``t2`` or ``t3``.
72 |
73 | You can check if a polymorphic type ``poly_t`` can be assigned some type ``t`` with ``poly_t.can_be(t)``.
74 |
75 | ## Methods of interest
76 |
77 | Creating a type can be done by instantiating the objects individually, or you can use the ``auto_type`` method which takes either your string type or your syntax dictionary and transforms it into a real type.
Here are a few examples:
78 |
79 | ```python
80 | from synth.syntax import auto_type
81 |
82 | t = auto_type("int")
83 | # PrimitiveType("int")
84 | t = auto_type("int | float -> float")
85 | # Arrow(int | float, float)
86 | t = auto_type("'a list ('a -> 'b ) -> 'b list")
87 | # let
88 | #   a = PolymorphicType("a")
89 | #   b = PolymorphicType("b")
90 | # in
91 | #   FunctionType(List(a), Arrow(a, b), List(b))
92 |
93 | t = auto_type("'a[int | float] -> 'a[int | float]")
94 | # let
95 | #   a = FixedPolymorphicType("a", PrimitiveType("int"), PrimitiveType("float"))
96 | # or equivalently
97 | #   a = FixedPolymorphicType("a", PrimitiveType("int") | PrimitiveType("float"))
98 | # in
99 | #   Arrow(a, a)
100 |
101 | t = auto_type("int optional")
102 | # Generic("optional", PrimitiveType("int"))
103 | ```
104 |
105 | - ``t1.is_instance(t2)`` computes whether type ``t1`` is an instance of ``t2``; notice that ``t2`` can be a Python type, some type in our system, or a ``TypeFunctor`` such as a ``GenericFunctor`` like ``List`` or ``Arrow``.
106 |   Note that there is no notion of covariant or contravariant types in our type system;
107 | - ``t.arguments()`` returns the list of arguments consumed by this type if it is an arrow; if it is not an arrow then an empty list is returned;
108 | - ``t.returns()`` returns the type returned by this arrow; if this is not an arrow it returns the type ``t`` itself;
109 | - ``t.all_versions()`` returns the list of all types that this type can take; this is only relevant for sum types, where each sum type takes all of its possible values;
110 | - ``t.unify({'a': INT})`` will return ``t`` where all polymorphic types named ``'a'`` take value ``INT``.
111 |
--------------------------------------------------------------------------------
/docs/source/usage.rst:
--------------------------------------------------------------------------------
1 | Usage of ProgSynth
2 | ==================
3 |
4 | The ProgSynth framework currently focuses on the Programming By Examples (PBE) specification of tasks. Other cases are planned to be supported; nevertheless, most elements of ProgSynth can be used independently of the specification.
5 |
6 | The :code:`synth` folder is a standalone library and does not provide ready-to-use code in the same manner a :code:`model.fit()` does; however, we provide in :code:`./examples` scripts and DSLs that are ready to use.
7 |
8 | These scripts enable you to reproduce the results from papers or can be modified to test your ideas. The scripts are pretty generic and in general can be used for your own custom DSL with little to no modification.
9 |
10 | For further information, in each specification folder inside :code:`./examples` there is a :code:`README` explaining the use of the scripts, which DSLs are implemented from which paper, and where to download the datasets.
11 |
12 | The tutorial in section :doc:`tutorial` uses the example of tasks based on additions and subtractions between integers or between floating point numbers, explaining step by step how to create a new DSL that can be used by the framework.
13 |
14 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | This folder contains ready-to-use scripts and files that you can leverage, for example to reproduce results from papers or to test your new ideas.
4 |
5 |
6 |
7 | - [Generics](#generics)
8 | - [Programming By Example](#programming-by-example)
9 | - [Programming from Natural Language](#programming-from-natural-language)
10 | - [SyGuS](#sygus)
11 |
12 | ## Generics
13 |
14 | Some scripts share the same name across the different kinds of specifications; they often have the same interface and provide the same features.
15 | Here is a list with explanations of such scripts:
16 |
17 | - ``dataset_explorer.py`` this script takes a dataset as input and enables you to explore different statistics about the dataset or to view specific tasks.
18 |
19 | ## Programming By Example
20 |
21 | This is the PBE folder. The specification of a task is given as input-output pairs from correct executions of the solution program.
22 | The available domains are:
23 |
24 | - integer list manipulation with deepcoder and dreamcoder;
25 | - regexp;
26 | - transductions.
27 |
28 | ## Programming from Natural Language
29 |
30 | This is the NLP folder. The specification of the task is given as a natural language string that explains the task.
31 |
32 | ## SyGuS
33 |
34 | This is the SyGuS folder. It provides scripts to work directly with the SyGuS format. The goal is mainly to edit the specification thanks to the tools offered by ProgSynth, such as sharpening.
35 |
--------------------------------------------------------------------------------
/examples/pbe/calculator/calculator.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Tuple, List as TList, Union
2 |
3 | import numpy as np
4 |
5 | from synth.generation.sampler import LexiconSampler, UnionSampler
6 | from synth.pbe.task_generator import (
7 |     TaskGenerator,
8 |     basic_output_validator,
9 |     reproduce_dataset,
10 | )
11 | from synth.semantic import DSLEvaluator, Evaluator
12 | from synth.specification import PBE
13 | from synth.syntax import (
14 |     DSL,
15 |     INT,
16 |     Arrow,
17 |     FixedPolymorphicType,
18 |     PrimitiveType,
19 |     BOOL,
20 |     FunctionType,
21 |     auto_type,
22 | )
23 | from synth.task import Dataset
24 |
25 | # a type representing either an int or a float
26 | FLOAT = PrimitiveType("float")
27 | type = FixedPolymorphicType("int/float", INT | FLOAT)
28 |
29 | __semantics = {
30 |     "+": lambda a: lambda b: round(a + b, 1),
31 |     "-": lambda a: lambda b: round(a - b, 1),
32 |     "int2float": lambda a: float(a),
33 |     "1": 1,
34 |     "2": 2,
35 |     "3": 3,
36 |     "3.0": 3.0,
37 | }
38 |
39 | __primitive_types = {
40 |     # int|float -> int|float -> int|float
41 |     "+": Arrow(type, Arrow(type, type)),
42 |     "-": FunctionType(type, type, type),
43 |     "int2float": auto_type("int->float"),
44 |     "1": INT,
45 |     "2": INT,
46 |     "3": INT,
47 | }
48 |
49 | # Short example of forbidden patterns (if add1 and sub1 are defined in __semantics and __primitive_types)
50 | _forbidden_patterns = {
51 |     ("add1", 0): {"sub1"},
52 |     ("sub1", 0): {"add1"},
53 | }
54 |
55 | dsl = DSL(__primitive_types, forbidden_patterns=_forbidden_patterns)
56 | dsl.instantiate_polymorphic_types()
57 | evaluator = DSLEvaluator(dsl.instantiate_semantics(__semantics))
58 | lexicon = [round(x, 1) for x in np.arange(-256, 256 + 1, 0.1)]
59 |
60 |
61 | def reproduce_calculator_dataset(
62 |     dataset: Dataset[PBE],
63 |     dsl: DSL,
64 |     evaluator: Evaluator,
65 |     seed: Optional[int] = None,
66 |     int_bound: int = 1000,
67 |     *args: Any,
68 |     **kwargs: Any,
69 | ) -> Tuple[TaskGenerator, TList[int]]:
70 |     int_range: TList[int] = [int_bound, 0]
71 |     int_range[1] = -int_range[0]
72 |
73 | 
float_range: TList[float] = [float(int_bound), 0] 74 | float_range[1] = -float_range[0] 75 | float_bound = float(int_bound) 76 | 77 | def analyser(start: None, element: Union[int, float]) -> None: 78 | if isinstance(element, int): 79 | int_range[0] = min(int_range[0], max(-int_bound, element)) 80 | int_range[1] = max(int_range[1], min(int_bound, element)) 81 | elif isinstance(element, float): 82 | float_range[0] = min(float_range[0], max(-float_bound, element)) 83 | float_range[1] = max(float_range[1], min(float_bound, element)) 84 | 85 | def get_element_sampler(start: None) -> UnionSampler: 86 | int_lexicon = list(range(int_range[0], int_range[1] + 1)) 87 | float_lexicon = [ 88 | round(x, 1) for x in np.arange(float_range[0], float_range[1] + 1, 0.1) 89 | ] 90 | return UnionSampler( 91 | { 92 | INT: LexiconSampler(int_lexicon, seed=seed), 93 | BOOL: LexiconSampler([True, False], seed=seed), 94 | FLOAT: LexiconSampler(float_lexicon, seed=seed), 95 | } 96 | ) 97 | 98 | def get_validator(start: None, max_list_length: int) -> Callable[[Any], bool]: 99 | return basic_output_validator( 100 | { 101 | int: list(range(int_range[0], int_range[1] + 1)), 102 | float: [ 103 | round(x, 1) 104 | for x in np.arange(float_range[0], float_range[1] + 1, 0.1) 105 | ], 106 | }, 107 | max_list_length, 108 | ) 109 | 110 | def get_lexicon(start: None) -> TList[float]: 111 | return [round(x, 1) for x in np.arange(float_range[0], float_range[1] + 1, 0.1)] 112 | 113 | return reproduce_dataset( 114 | dataset, 115 | dsl, 116 | evaluator, 117 | None, 118 | analyser, 119 | get_element_sampler, 120 | get_validator, 121 | get_lexicon, 122 | seed, 123 | *args, 124 | **kwargs, 125 | ) 126 | -------------------------------------------------------------------------------- /examples/pbe/calculator/convert_calculator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Callable, Dict, List as TList 3 | 4 | import tqdm 5 | 6 | from synth import Task, Dataset, PBE, Example 7 | from synth.syntax import ( 8 | FunctionType, 9 | Program, 10 | UnknownType, 11 | guess_type, 12 | ) 13 | 14 | from calculator import dsl, evaluator, FLOAT 15 | 16 | dsl.instantiate_polymorphic_types(5) 17 | 18 | 19 | def __convert__(load: Callable[[], Dataset[PBE]], name: str) -> None: 20 | tasks = load() 21 | tasks.save(name) 22 | sols = sum(1 for t in tasks if t.solution) 23 | print(f"Converted {len(tasks)} tasks {sols / len(tasks):.0%} containing solutions") 24 | # Integrity check 25 | for task in tqdm.tqdm(tasks, desc="integrity check"): 26 | for ex in task.specification.examples: 27 | obt = evaluator.eval(task.solution, ex.inputs) 28 | assert ( 29 | obt == ex.output 30 | ), f"failed on {task.solution} inputs:{ex.inputs} got:{obt} target:{ex.output}" 31 | 32 | 33 | def convert_calculator( 34 | file: str = "dataset/calculator_dataset.json", 35 | output_file: str = "calculator.pickle", 36 | ) -> None: 37 | def load() -> Dataset[PBE]: 38 | tasks: TList[Task[PBE]] = [] 39 | with open(file, "r") as fd: 40 | raw_tasks: TList[Dict[str, Any]] = json.load(fd) 41 | for raw_task in tqdm.tqdm(raw_tasks, desc="converting"): 42 | name: str = raw_task["program"] 43 | raw_examples: TList[Dict[str, Any]] = raw_task["examples"] 44 | inputs = [raw_example["inputs"] for raw_example in raw_examples] 45 | outputs: TList = [raw_example["output"] for raw_example in raw_examples] 46 | args_types = [guess_type(arg) for arg in inputs[0]] + [ 47 | guess_type(outputs[0]) 48 | ] 49 | # guess_type doesn't 
recognise FLOAT but since it is the only type not recognised we know that Unknown Type is acutally FLOAT 50 | args_types = [ 51 | at if not isinstance(at, UnknownType) else FLOAT 52 | for at in args_types 53 | ] 54 | type_request = FunctionType(*args_types) 55 | prog: Program = dsl.parse_program(name, type_request) 56 | examples = [ 57 | Example(inp, out) 58 | for inp, out in zip(inputs, outputs) 59 | if out is not None 60 | ] 61 | if len(examples) < len(inputs): 62 | continue 63 | tasks.append( 64 | Task[PBE](type_request, PBE(examples), prog, {"name": name}) 65 | ) 66 | return Dataset(tasks, metadata={"dataset": "calculator", "source:": file}) 67 | 68 | __convert__(load, output_file) 69 | 70 | 71 | if __name__ == "__main__": 72 | import argparse 73 | 74 | argument_parser: argparse.ArgumentParser = argparse.ArgumentParser( 75 | description="Convert calculator original dataset to ProgSynth format." 76 | ) 77 | 78 | argument_default_values = { 79 | "output": "calculator.pickle", 80 | } 81 | 82 | argument_parser.add_argument( 83 | type=str, 84 | dest="file", 85 | action="store", 86 | help="Source JSON calculator file to be converted", 87 | ) 88 | argument_parser.add_argument( 89 | "-o", 90 | "--output", 91 | type=str, 92 | action="store", 93 | default=argument_default_values["output"], 94 | help=f"Output dataset file in ProgSynth format (default: '{argument_default_values['output']}')", 95 | ) 96 | parsed_parameters = argument_parser.parse_args() 97 | convert_calculator(parsed_parameters.file, parsed_parameters.output) 98 | -------------------------------------------------------------------------------- /examples/pbe/convert_to_psl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as path 3 | from typing import Union 4 | 5 | from dsl_loader import add_dsl_choice_arg, load_DSL 6 | from dataset_loader import add_dataset_choice_arg, load_dataset 7 | 8 | from synth import Dataset, PBE 9 | from synth.specification import PBEWithConstants 10 | 11 | 12 | parser = argparse.ArgumentParser( 13 | description="Convert a ProgSynth dataset to the PSL format" 14 | ) 15 | add_dsl_choice_arg(parser) 16 | add_dataset_choice_arg(parser) 17 | parser.add_argument( 18 | "dest_folder", 19 | type=str, 20 | help="destination folder", 21 | ) 22 | parser.add_argument( 23 | "logics", 24 | type=str, 25 | help="logics used", 26 | ) 27 | 28 | parameters = parser.parse_args() 29 | dsl_name: str = parameters.dsl 30 | dataset_file: str = parameters.dataset 31 | dest_folder: str = parameters.dest_folder 32 | logics: str = parameters.logics 33 | # ================================ 34 | # Load constants specific to DSL 35 | # ================================ 36 | dsl_module = load_DSL(dsl_name) 37 | dsl = dsl_module.dsl 38 | # ================================ 39 | # Load dataset & Task Generator 40 | # ================================ 41 | # Load dataset 42 | full_dataset: Dataset[Union[PBE, PBEWithConstants]] = load_dataset( 43 | dsl_name, dataset_file 44 | ) 45 | COMMENT_PREFIX = "#" 46 | 47 | 48 | for i, task in enumerate(full_dataset.tasks): 49 | spec = task.specification 50 | name = task.metadata.get("name", f"{dsl_name}_{i}") 51 | filepath = path.join(dest_folder, name + ".psl") 52 | try: 53 | fd = open(filepath, "w") 54 | fd.close() 55 | except: 56 | filepath = path.join(dest_folder, f"{dsl_name}_{i}" + ".psl") 57 | 58 | with open(filepath, "w") as fd: 59 | fd.write(f"(set-logic {logics})\n") 60 | 61 | fd.write(f"\n{COMMENT_PREFIX} Function 
Synthesis\n") 62 | fd.write("(synth-fun f ") 63 | for j, arg in enumerate(task.type_request.arguments()): 64 | fd.write(f"(x{j+1} {arg}) ") 65 | fd.write(f"{task.type_request.returns()})\n") 66 | 67 | fd.write(f"\n{COMMENT_PREFIX} PBE Examples\n") 68 | for example in spec.examples: 69 | inputs = " ".join(map(str, example.inputs)) 70 | output = str(example.output) 71 | fd.write(f"(constraint-pbe (f {inputs}) {output})\n") 72 | 73 | if isinstance(spec, PBEWithConstants): 74 | constants = spec.constants 75 | fd.write(f"\n{COMMENT_PREFIX} Constants\n") 76 | for type, values in spec.constants.items(): 77 | allowed = " ".join(map(str, values)) 78 | fd.write(f"(define-const {type} {allowed})\n") 79 | 80 | fd.write("\n(check-progsynth)\n") 81 | if task.solution is not None: 82 | fd.write(f"\n(solution-pbe {task.solution})\n") 83 | lines = [] 84 | for name, val in task.metadata.items(): 85 | if name == "name": 86 | continue 87 | lines.append(f"{COMMENT_PREFIX} {name}: {val}") 88 | if lines: 89 | fd.write(f"\n{COMMENT_PREFIX} Metadata:\n") 90 | fd.writelines(lines) 91 | -------------------------------------------------------------------------------- /examples/pbe/dataset_improve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | from dsl_loader import add_dsl_choice_arg, load_DSL 5 | from dataset_loader import add_dataset_choice_arg, load_dataset 6 | 7 | from synth import Dataset, PBE 8 | from synth.utils import chrono 9 | 10 | 11 | parser = argparse.ArgumentParser( 12 | description="Generate a new dataset by replacing solutions found if they are shorter than the original's" 13 | ) 14 | add_dsl_choice_arg(parser) 15 | add_dataset_choice_arg(parser) 16 | parser.add_argument( 17 | "-s", 18 | "--solution", 19 | type=str, 20 | help="solution file", 21 | ) 22 | 23 | parameters = parser.parse_args() 24 | dsl_name: str = parameters.dsl 25 | dataset_file: str = parameters.dataset 26 | solution_file: str = parameters.solution 27 | # ================================ 28 | # Load constants specific to DSL 29 | # ================================ 30 | dsl_module = load_DSL(dsl_name) 31 | dsl, evaluator, lexicon = dsl_module.dsl, dsl_module.evaluator, dsl_module.lexicon 32 | # ================================ 33 | # Load dataset & Task Generator 34 | # ================================ 35 | # Load dataset 36 | full_dataset: Dataset[PBE] = load_dataset(dsl_name, dataset_file) 37 | 38 | print("Loading solutions...", end="", flush=True) 39 | with chrono.clock("solutions.load") as c: 40 | with open(solution_file, "r") as fd: 41 | reader = csv.reader(fd) 42 | trace = [tuple(row) for row in reader] 43 | trace.pop(0) 44 | solutions = [row[-1] if row[0] == "True" else None for row in trace] 45 | print("done in", c.elapsed_time(), "s") 46 | 47 | replaced = 0 48 | saved = 0 49 | print("Merging solutions and dataset...", end="", flush=True) 50 | with chrono.clock("merge") as c: 51 | for task, new_sol in zip(full_dataset.tasks, solutions): 52 | if new_sol is None: 53 | continue 54 | if task.solution is None: 55 | task.solution = dsl.parse_program(new_sol, task.type_request) 56 | continue 57 | size = new_sol.count(" ") + 1 58 | if size < task.solution.size(): 59 | saved += task.solution.size() - size 60 | task.solution = dsl.parse_program(new_sol, task.type_request) 61 | replaced += 1 62 | 63 | print("done in", c.elapsed_time(), "s") 64 | print(f"Replaced {replaced} original solutions saving {saved} size!") 65 | print("Saving merged dataset...", 
end="", flush=True) 66 | with chrono.clock("dataset.save") as c: 67 | full_dataset.save(dataset_file.replace(".pickle", "_merged.pickle")) 68 | print("done in", c.elapsed_time(), "s") 69 | -------------------------------------------------------------------------------- /examples/pbe/dataset_learner.py: -------------------------------------------------------------------------------- 1 | from synth import Dataset, PBE 2 | from synth.utils import chrono 3 | from synth.library import learn, make_score_probabilistic, score_description 4 | 5 | from dsl_loader import add_dsl_choice_arg, load_DSL 6 | from dataset_loader import add_dataset_choice_arg, load_dataset 7 | 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser(description="Learn a new primitive based on a dataset") 11 | add_dsl_choice_arg(parser) 12 | add_dataset_choice_arg(parser) 13 | parser.add_argument( 14 | "--probabilistic", 15 | action="store_true", 16 | help="Maximise probability instead of reducing description size", 17 | ) 18 | parameters = parser.parse_args() 19 | dsl_name: str = parameters.dsl 20 | proba: bool = parameters.probabilistic 21 | dataset_file: str = parameters.dataset 22 | # ================================ 23 | # Load constants specific to DSL 24 | # ================================ 25 | dsl_module = load_DSL(dsl_name) 26 | dsl, evaluator, lexicon = dsl_module.dsl, dsl_module.evaluator, dsl_module.lexicon 27 | # ================================ 28 | # Load dataset 29 | # ================================ 30 | # Load dataset 31 | print(f"Loading {dataset_file}...", end="") 32 | with chrono.clock("dataset.load") as c: 33 | full_dataset: Dataset[PBE] = Dataset.load(dataset_file) 34 | print("done in", c.elapsed_time(), "s") 35 | 36 | programs = [t.solution for t in full_dataset if t.solution is not None] 37 | score_fn = make_score_probabilistic(programs, False) if proba else score_description 38 | score, prog = learn(programs, score_fn, progress=True) 39 | if proba: 40 | print(f"Best program is {prog} which bumps the programs set's log prob to:{score}.") 41 | else: 42 | print(f"Best program is {prog} which reduces description size by:{score}.") 43 | # print(f"This would reduce the size by {(size - 1) * occs}.") 44 | # score, prog = learn( 45 | # [t.solution for t in full_dataset if t.solution is not None], progress=True 46 | # ) 47 | # # print(f"Found {occs} occurences of {prog}.") 48 | # print(f"This would reduce the size by {score}.") 49 | -------------------------------------------------------------------------------- /examples/pbe/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import pickle 3 | 4 | from colorama import Fore as F 5 | 6 | from synth import Dataset 7 | from synth.utils import chrono 8 | 9 | 10 | class DatasetUnpickler(pickle.Unpickler): 11 | def find_class(self, module, name): 12 | try: 13 | return super().find_class(module, name) 14 | except: 15 | return super().find_class(module + "." 
+ module, name) 16 | 17 | 18 | def add_dataset_choice_arg(parser: ArgumentParser) -> None: 19 | parser.add_argument( 20 | "-d", 21 | "--dataset", 22 | type=str, 23 | default="{dsl_name}.pickle", 24 | help="the dataset file to load (default: {dsl_name}.pickle)", 25 | ) 26 | 27 | 28 | def load_dataset(dsl_name: str, dataset_file: str, verbose: bool = True) -> Dataset: 29 | dataset_file = dataset_file.format(dsl_name=dsl_name) 30 | if verbose: 31 | print(f"Loading {F.LIGHTCYAN_EX}{dataset_file}{F.RESET}...", end="") 32 | with chrono.clock("dataset.load") as c: 33 | full_dataset: Dataset = Dataset.load(dataset_file, DatasetUnpickler) 34 | if verbose: 35 | print(f"done in {c.elapsed_time():.2}s") 36 | return full_dataset 37 | -------------------------------------------------------------------------------- /examples/pbe/deepcoder/convert_deepcoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Callable, Dict, Tuple, List as TList 3 | 4 | import tqdm 5 | 6 | from synth import Task, Dataset, PBE, Example 7 | from synth.syntax import ( 8 | INT, 9 | FunctionType, 10 | List, 11 | Type, 12 | Function, 13 | Primitive, 14 | Program, 15 | Variable, 16 | ) 17 | 18 | from deepcoder import dsl, evaluator 19 | 20 | name2type = {p.primitive: p.type for p in dsl.list_primitives} 21 | 22 | 23 | def __convert__(load: Callable[[], Dataset[PBE]], name: str) -> None: 24 | tasks = load() 25 | tasks.save(name) 26 | sols = sum(1 for t in tasks if t.solution) 27 | print(f"Converted {len(tasks)} tasks {sols / len(tasks):.0%} containing solutions") 28 | # Integrity check 29 | for task in tqdm.tqdm(tasks, desc="integrity check"): 30 | for ex in task.specification.examples: 31 | assert evaluator.eval(task.solution, ex.inputs) == ex.output 32 | 33 | 34 | def convert_deepcoder( 35 | file: str = "deepcoder_dataset/T=3_train.json", 36 | output_file: str = "deepcoder.pickle", 37 | ) -> None: 38 | def load() -> Dataset[PBE]: 39 | tasks: TList[Task[PBE]] = [] 40 | with open(file, "r") as fd: 41 | raw_tasks: TList[Dict[str, Any]] = json.load(fd) 42 | for raw_task in tqdm.tqdm(raw_tasks, desc="converting"): 43 | name: str = raw_task["program"] 44 | raw_examples: TList[Dict[str, Any]] = raw_task["examples"] 45 | inputs = [raw_example["inputs"] for raw_example in raw_examples] 46 | outputs: TList = [raw_example["output"] for raw_example in raw_examples] 47 | 48 | prog, type_request = __deepcoder_str2prog(name) 49 | examples = [ 50 | Example(inp, out) 51 | for inp, out in zip(inputs, outputs) 52 | if out is not None 53 | ] 54 | if len(examples) < len(inputs): 55 | continue 56 | tasks.append( 57 | Task[PBE](type_request, PBE(examples), prog, {"name": name}) 58 | ) 59 | return Dataset(tasks, metadata={"dataset": "deepcoder", "source:": file}) 60 | 61 | __convert__(load, output_file) 62 | 63 | 64 | def __deepcoder_str2prog(s: str) -> Tuple[Program, Type]: 65 | parts = s.split("|") 66 | stack: TList[Program] = [] 67 | var: int = 0 68 | type_stack: TList[Type] = [] 69 | for part in parts: 70 | subparts = part.split(",") 71 | name = subparts.pop(0) 72 | if name == "LIST": 73 | stack.append(Variable(var, List(INT))) 74 | var += 1 75 | type_stack.append(List(INT)) 76 | continue 77 | if name == "INT": 78 | stack.append(Variable(var, INT)) 79 | var += 1 80 | type_stack.append(INT) 81 | continue 82 | if name not in name2type.keys(): 83 | name = name + "[" + subparts.pop(0) + "]" 84 | primitive = Primitive(name, name2type[name]) 85 | targets = [int(x) for x in 
subparts] 86 | arguments = [stack[x] for x in targets] 87 | stack.append(Function(primitive, arguments)) 88 | type_stack.append(stack[-1].type) 89 | type_request = FunctionType(*type_stack) 90 | return stack[-1], type_request 91 | 92 | 93 | if __name__ == "__main__": 94 | import argparse 95 | 96 | argument_parser: argparse.ArgumentParser = argparse.ArgumentParser( 97 | description="Convert deepcoder original dataset to ProgSynth format." 98 | ) 99 | 100 | argument_default_values = { 101 | "output": "deepcoder.pickle", 102 | } 103 | 104 | argument_parser.add_argument( 105 | type=str, 106 | dest="file", 107 | action="store", 108 | help="Source JSON deepcoder file to be converted", 109 | ) 110 | argument_parser.add_argument( 111 | "-o", 112 | "--output", 113 | type=str, 114 | action="store", 115 | default=argument_default_values["output"], 116 | help=f"Output dataset file in ProgSynth format (default: '{argument_default_values['output']}')", 117 | ) 118 | parsed_parameters = argument_parser.parse_args() 119 | convert_deepcoder(parsed_parameters.file, parsed_parameters.output) 120 | -------------------------------------------------------------------------------- /examples/pbe/dreamcoder/convert_dreamcoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Callable 3 | 4 | import tqdm 5 | 6 | from synth import Task, Dataset, PBE, Example 7 | from synth.syntax.type_system import EmptyList 8 | 9 | 10 | def __convert__(load: Callable[[], Dataset[PBE]], name: str) -> None: 11 | tasks = load() 12 | tasks.save(name) 13 | sols = len([0 for t in tasks if t.solution]) 14 | print(f"Converted {len(tasks)} tasks {sols / len(tasks):.0%} containing solutions") 15 | 16 | 17 | def convert_dreamcoder( 18 | file: str, 19 | output_file: str = "dreamcoder.pickle", 20 | ) -> None: 21 | def load() -> Dataset[PBE]: 22 | tasks: List[Task[PBE]] = [] 23 | with open(file, "rb") as fd: 24 | li: List[Dict[str, Any]] = json.load(fd) 25 | for task_dict in tqdm.tqdm(li, desc="converting"): 26 | examples = [ 27 | Example([dico["i"]], dico["o"]) for dico in task_dict["examples"] 28 | ] 29 | spec = PBE(examples) 30 | type_req = spec.guess_type() 31 | # Skip []-only output tasks 32 | if EmptyList in type_req: 33 | continue 34 | tasks.append( 35 | Task[PBE](type_req, spec, metadata={"name": task_dict["name"]}) 36 | ) 37 | return Dataset(tasks, metadata={"dataset": "dreamcoder", "source:": file}) 38 | 39 | __convert__(load, output_file) 40 | 41 | 42 | if __name__ == "__main__": 43 | import argparse 44 | 45 | argument_parser: argparse.ArgumentParser = argparse.ArgumentParser( 46 | description="Convert deepcoder original dataset to ProgSynth format." 
47 | ) 48 | 49 | argument_default_values = { 50 | "output": "dreamcoder.pickle", 51 | } 52 | 53 | argument_parser.add_argument( 54 | type=str, 55 | dest="file", 56 | action="store", 57 | help="Source JSON file containing dreamcoder tasks to be converted", 58 | ) 59 | argument_parser.add_argument( 60 | "-o", 61 | "--output", 62 | type=str, 63 | action="store", 64 | default=argument_default_values["output"], 65 | help=f"Output dataset file in ProgSynth format (default: '{argument_default_values['output']}')", 66 | ) 67 | parsed_parameters = argument_parser.parse_args() 68 | convert_dreamcoder(parsed_parameters.file, parsed_parameters.output) 69 | -------------------------------------------------------------------------------- /examples/pbe/model_embeddings_visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.tensorboard.writer import SummaryWriter 4 | 5 | 6 | from dataset_loader import add_dataset_choice_arg, load_dataset 7 | from dsl_loader import add_dsl_choice_arg, load_DSL 8 | from model_loader import ( 9 | add_model_choice_arg, 10 | instantiate_predictor, 11 | ) 12 | 13 | 14 | from synth.nn import print_model_summary 15 | from synth.syntax import CFG, UCFG 16 | from synth.filter import add_dfta_constraints 17 | from synth.pbe.io_encoder import IOEncoder 18 | 19 | 20 | import argparse 21 | 22 | 23 | parser = argparse.ArgumentParser(description="Visualize model") 24 | parser.add_argument("-m", "--model", default="", type=str, help="model file") 25 | add_dataset_choice_arg(parser) 26 | add_dsl_choice_arg(parser) 27 | add_model_choice_arg(parser) 28 | 29 | parameters = parser.parse_args() 30 | dsl_name: str = parameters.dsl 31 | dataset_file: str = parameters.dataset 32 | cpu_only: bool = parameters.cpu 33 | model_file: str = parameters.model 34 | constrained: bool = parameters.constrained 35 | # Get device 36 | device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu" 37 | print("Using device:", device) 38 | # Load DSL ================================================================ 39 | dsl_module = load_DSL(dsl_name) 40 | dsl, lexicon = dsl_module.dsl, dsl_module.lexicon 41 | constraints = getattr(dsl_module, "constraints", []) 42 | constant_types = getattr(dsl_module, "constant_types", set()) 43 | # Load Dataset ============================================================ 44 | full_dataset = load_dataset(dsl_name, dataset_file) 45 | # Load CFGs =============================================================== 46 | all_type_requests = full_dataset.type_requests() 47 | 48 | if all(task.solution is not None for task in full_dataset): 49 | max_depth = max(task.solution.depth() for task in full_dataset) 50 | else: 51 | max_depth = 5 # TODO: set as parameter 52 | cfgs = [ 53 | CFG.depth_constraint( 54 | dsl, 55 | t, 56 | max_depth, 57 | upper_bound_type_size=10, 58 | constant_types=constant_types, 59 | min_variable_depth=0, 60 | ) 61 | for t in all_type_requests 62 | ] 63 | cfgs = [ 64 | UCFG.from_DFTA_with_ngrams( 65 | add_dfta_constraints(cfg, constraints, progress=False), 2 66 | ) 67 | if constrained 68 | else cfg 69 | for cfg in cfgs 70 | ] 71 | 72 | writer = SummaryWriter(comment=f"model_vizualizer_{model_file}") 73 | # Load Model ============================================================== 74 | predictor = instantiate_predictor(parameters, cfgs, lexicon) 75 | predictor.load_state_dict(torch.load(model_file, map_location=device)) 76 | predictor = 
predictor.to(device) 77 | predictor.eval() 78 | print_model_summary(predictor) 79 | # Plot embeddings ========================================================= 80 | print("Generating embeddings data:") 81 | encoder = predictor.packer.encoder 82 | embedder = predictor.packer.embedder 83 | # For now this part assumes isinstance(encoder, IOEncoder) 84 | assert isinstance(encoder, IOEncoder) 85 | encoded = [] 86 | for l in lexicon: 87 | encoder.__encode_element__(l, encoded) 88 | # Built as a Tensor 89 | res = torch.LongTensor(encoded).to(device).reshape((-1, 1)) 90 | output: Tensor = embedder(res).squeeze() 91 | writer.add_embedding(output, metadata=lexicon) 92 | # END ==================================================================== 93 | print("Additional model data can now be viewed with TensorBoard!") 94 | writer.close() 95 | -------------------------------------------------------------------------------- /examples/pbe/model_loader.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import List, Union 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | from torch.nn.utils.rnn import PackedSequence 8 | 9 | from synth import PBE, Task 10 | from synth.nn import ( 11 | DetGrammarPredictorLayer, 12 | UGrammarPredictorLayer, 13 | abstractions, 14 | Task2Tensor, 15 | ) 16 | from synth.pbe import IOEncoder 17 | from synth.syntax import UCFG, TTCFG 18 | from synth.syntax.grammars.cfg import CFG 19 | 20 | 21 | class MyPredictor(nn.Module): 22 | def __init__( 23 | self, 24 | size: int, 25 | constrained: bool, 26 | cfgs: Union[List[TTCFG], List[UCFG]], 27 | variable_probability: float, 28 | encoding_dimension: int, 29 | device: str, 30 | lexicon, 31 | ) -> None: 32 | super().__init__() 33 | layer = UGrammarPredictorLayer if constrained else DetGrammarPredictorLayer 34 | abstraction = ( 35 | abstractions.ucfg_bigram 36 | if constrained 37 | else abstractions.cfg_bigram_without_depth 38 | ) 39 | self.bigram_layer = layer( 40 | size, 41 | cfgs, 42 | abstraction, 43 | variable_probability, 44 | ) 45 | encoder = IOEncoder(encoding_dimension, lexicon) 46 | self.packer = Task2Tensor( 47 | encoder, nn.Embedding(len(encoder.lexicon), size), size, device=device 48 | ) 49 | self.rnn = nn.LSTM(size, size, 1) 50 | self.end = nn.Sequential( 51 | nn.Linear(size, size), 52 | nn.ReLU(), 53 | nn.Linear(size, size), 54 | nn.ReLU(), 55 | ) 56 | 57 | def forward(self, x: List[Task[PBE]]) -> Tensor: 58 | seq: PackedSequence = self.packer(x) 59 | _, (y, _) = self.rnn(seq) 60 | y: Tensor = y.squeeze(0) 61 | return self.bigram_layer(self.end(y)) 62 | 63 | 64 | def instantiate_predictor( 65 | parameters: Namespace, cfgs: Union[List[CFG], List[UCFG]], lexicon: List 66 | ) -> MyPredictor: 67 | variable_probability: float = parameters.var_prob 68 | encoding_dimension: int = parameters.encoding_dimension 69 | hidden_size: int = parameters.hidden_size 70 | cpu_only: bool = parameters.cpu 71 | constrained: bool = parameters.constrained 72 | device = "cuda" if not cpu_only and torch.cuda.is_available() else "cpu" 73 | 74 | return MyPredictor( 75 | hidden_size, 76 | constrained, 77 | cfgs, 78 | variable_probability, 79 | encoding_dimension, 80 | device, 81 | lexicon, 82 | ).to(device) 83 | 84 | 85 | def add_model_choice_arg(parser: ArgumentParser) -> None: 86 | gg = parser.add_argument_group("model parameters") 87 | gg.add_argument( 88 | "-v", 89 | "--var-prob", 90 | type=float, 91 | default=0.2, 92 | 
help="variable probability (default: .2)", 93 | ) 94 | gg.add_argument( 95 | "-ed", 96 | "--encoding-dimension", 97 | type=int, 98 | default=512, 99 | help="encoding dimension (default: 512)", 100 | ) 101 | gg.add_argument( 102 | "-hd", 103 | "--hidden-size", 104 | type=int, 105 | default=512, 106 | help="hidden layer size (default: 512)", 107 | ) 108 | gg.add_argument( 109 | "--cpu", 110 | action="store_true", 111 | default=False, 112 | help="do not try to run things on cuda", 113 | ) 114 | gg = parser.add_argument_group("grammar parameters") 115 | gg.add_argument( 116 | "--constrained", 117 | action="store_true", 118 | default=False, 119 | help="use unambigous grammar to include constraints in the grammar if available", 120 | ) 121 | 122 | gg.add_argument( 123 | "--max-depth", 124 | type=int, 125 | default=5, 126 | help="maximum depth of grammars used (-1 for infinite, default: 5)", 127 | ) 128 | gg.add_argument( 129 | "--ngram", 130 | type=int, 131 | default=2, 132 | choices=[1, 2], 133 | help="ngram used by grammars (default: 2)", 134 | ) 135 | -------------------------------------------------------------------------------- /examples/pbe/quantum_circuits/quantum.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import numpy as np 3 | 4 | from synth.syntax import ( 5 | PolymorphicType, 6 | Arrow, 7 | INT, 8 | BOOL, 9 | List, 10 | DSL, 11 | auto_type, 12 | Program, 13 | ) 14 | from synth.semantic import DSLEvaluator 15 | 16 | 17 | import qiskit as qk 18 | from synth.syntax.program import Primitive 19 | 20 | __syntax = auto_type( 21 | { 22 | "H": "circuit -> int -> circuit", 23 | "T": "circuit -> int -> circuit", 24 | "Tdg": "circuit -> int -> circuit", 25 | "CNOT": "circuit -> int -> int -> circuit", 26 | } 27 | ) 28 | 29 | 30 | __semantics = { 31 | "H": lambda QT: lambda q1: QT if QT.circuit.h(QT.q(q1)) is not None else QT, 32 | "T": lambda QT: lambda q1: QT if QT.circuit.t(QT.q(q1)) is not None else QT, 33 | "Tdg": lambda QT: lambda q1: QT if QT.circuit.tdg(QT.q(q1)) is not None else QT, 34 | "CNOT": lambda QT: lambda q1: lambda q2: QT 35 | if QT.circuit.cnot(QT.q(q1), QT.q(q2)) is not None 36 | else QT, 37 | } 38 | 39 | 40 | class QiskitTester: 41 | def __init__(self, n_qubits: int): 42 | self.n_qubits = n_qubits 43 | self.unitary_matrix = None 44 | self.qreg_q = qk.QuantumRegister(self.n_qubits, "q") 45 | self.circuit = qk.QuantumCircuit(self.qreg_q) 46 | 47 | def q(self, q_num: int) -> int: 48 | return self.n_qubits - 1 - q_num 49 | 50 | def __enter__(self): 51 | return self 52 | 53 | def __exit__(self, *args: Any, **kwargs: Any) -> None: 54 | pass 55 | 56 | def __str__(self) -> str: 57 | return self.circuit.__str__() 58 | 59 | def execute(self, backend: qk.AerWrapper) -> np.ndarray: 60 | return np.array(qk.execute(self.circuit, backend).result().get_unitary()).T 61 | 62 | 63 | class QuantumCircuitEvaluator(DSLEvaluator): 64 | def __init__(self, semantics: Dict[Primitive, Any], nqbits: int = 3) -> None: 65 | super().__init__(semantics, False) 66 | self.nqbits = nqbits 67 | self.backend = qk.Aer.get_backend("unitary_simulator") 68 | 69 | def eval(self, program: Program, input: List) -> Any: 70 | with QiskitTester(self.nqbits) as QT: 71 | super().eval(program, [QT] + input) 72 | 73 | return QT.execute(self.backend) 74 | 75 | 76 | dsl = DSL(__syntax) 77 | evaluator = QuantumCircuitEvaluator(dsl.instantiate_semantics(__semantics)) 78 | -------------------------------------------------------------------------------- 
/examples/pbe/regexp/convert_regexp.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Callable, Dict, Tuple, List as TList 3 | 4 | import tqdm 5 | 6 | from synth import Task, Dataset, PBE, Example 7 | from synth.syntax import ( 8 | STRING, 9 | FunctionType, 10 | List, 11 | Type, 12 | Function, 13 | Primitive, 14 | Program, 15 | Variable, 16 | ) 17 | 18 | from examples.pbe.regexp.type_regex import REGEXP 19 | from examples.pbe.regexp.regexp import dsl, evaluator 20 | 21 | 22 | name2type = {p.primitive: p.type for p in dsl.list_primitives} 23 | 24 | 25 | def __convert__(load: Callable[[], Dataset[PBE]], name: str) -> None: 26 | tasks = load() 27 | tasks.save(name) 28 | sols = sum(1 for t in tasks if t.solution) 29 | print( 30 | f"Converted {len(tasks)} tasks {int(100 * sols / len(tasks))}% containing solutions" 31 | ) 32 | # Integrity check 33 | for task in tqdm.tqdm(tasks, desc="integrity check"): 34 | for ex in task.specification.examples: 35 | assert evaluator.eval(task.solution, ex.inputs) == ex.output 36 | 37 | 38 | def convert_regexp( 39 | file: str = "dataset/new_data.json", 40 | output_file: str = "regexp.pickle", 41 | ) -> None: 42 | def load() -> Dataset[PBE]: 43 | tasks: TList[Task[PBE]] = [] 44 | with open(file, "r") as fd: 45 | raw_tasks: TList[Dict[str, Any]] = json.load(fd) 46 | for raw_task in tqdm.tqdm(raw_tasks, desc="converting"): 47 | name: str = raw_task["program"] 48 | raw_examples: TList[Dict[str, Any]] = raw_task["examples"] 49 | inputs = [raw_example["inputs"] for raw_example in raw_examples] 50 | outputs: TList = [raw_example["output"] for raw_example in raw_examples] 51 | prog, type_request = __regexp_str2prog(name) 52 | examples = [ 53 | Example([list(x) for x in inp], out) 54 | for inp, out in zip(inputs, outputs) 55 | if out is not None 56 | ] 57 | if len(examples) < len(inputs): 58 | continue 59 | tasks.append( 60 | Task[PBE](type_request, PBE(examples), prog, {"name": name}) 61 | ) 62 | return Dataset(tasks, metadata={"dataset": "regexp", "source:": file}) 63 | 64 | __convert__(load, output_file) 65 | 66 | 67 | def __regexp_str2prog(s: str) -> Tuple[Program, Type]: 68 | parts = s.split("|") 69 | stack: TList[Program] = [] 70 | var: int = 0 71 | type_stack: TList[Type] = [] 72 | for part in parts: 73 | subparts = part.split(",") 74 | name = subparts.pop(0) 75 | # composition of methods 76 | if name == "eval": 77 | primitive = Primitive(name, name2type[name]) 78 | targets = [int(x) for x in subparts] 79 | arguments = [stack[x] for x in targets] 80 | solution = stack[-1] 81 | arguments.append(solution) 82 | stack.append(Function(primitive, arguments)) 83 | elif name == "STRING": 84 | stack.append(Variable(var, List(STRING))) 85 | var += 1 86 | type_stack.append(List(STRING)) 87 | elif name == "begin": 88 | stack.append(Primitive(name, REGEXP)) 89 | elif name in name2type.keys(): 90 | primitive = Primitive(name, name2type[name]) 91 | stack.append(Function(primitive, [stack[-1]])) 92 | type_stack.append(stack[-1].type) 93 | type_request = FunctionType(*type_stack) 94 | return stack[-1], type_request 95 | 96 | 97 | if __name__ == "__main__": 98 | import argparse 99 | 100 | argument_parser: argparse.ArgumentParser = argparse.ArgumentParser( 101 | description="Convert regexp original dataset to ProgSynth format." 
102 | ) 103 | 104 | argument_default_values = { 105 | "output": "regexp.pickle", 106 | } 107 | 108 | argument_parser.add_argument( 109 | type=str, 110 | dest="file", 111 | action="store", 112 | help="Source JSON regexp file to be converted", 113 | ) 114 | argument_parser.add_argument( 115 | "-o", 116 | "--output", 117 | type=str, 118 | action="store", 119 | default=argument_default_values["output"], 120 | help=f"Output dataset file in ProgSynth format (default: '{argument_default_values['output']}')", 121 | ) 122 | parsed_parameters = argument_parser.parse_args() 123 | convert_regexp(parsed_parameters.file, parsed_parameters.output) 124 | -------------------------------------------------------------------------------- /examples/pbe/regexp/evaluator_regexp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Set 2 | from examples.pbe.regexp.type_regex import regex_match, Raw 3 | from synth.syntax.program import Function, Primitive, Program, Variable 4 | from synth.semantic.evaluator import Evaluator 5 | 6 | 7 | generalized_to_re = { 8 | "U": "[A-Z]", 9 | "L": "[a-z]", 10 | "N": "[0-9]", 11 | "O": "[^A-Za-z0-9]", 12 | "W": "\s", 13 | "begin": "", 14 | ".": ".", 15 | } 16 | 17 | 18 | def get_regexp(reg, groups: bool = True): 19 | modified = "" 20 | for char in reg: 21 | if char in generalized_to_re: 22 | modified += ( 23 | "(" + generalized_to_re[char] + ")" 24 | if groups 25 | else generalized_to_re[char] 26 | ) 27 | else: 28 | modified = modified[:-1] + char + ")" if groups else modified[:-1] + char 29 | return modified 30 | 31 | 32 | def __geometrical__(num_instances: int, probability: float) -> float: 33 | return ((1 - probability) ** (num_instances - 1)) * probability 34 | 35 | 36 | def __uniform__(lexicon: List[str]) -> float: 37 | return 1 / len(lexicon) 38 | 39 | 40 | def __tuplify__(element: Any) -> Any: 41 | if isinstance(element, List): 42 | return tuple(__tuplify__(x) for x in element) 43 | else: 44 | return element 45 | 46 | 47 | class RegexpEvaluator(Evaluator): 48 | def __init__(self, semantics: Dict[str, Any], use_cache: bool = True) -> None: 49 | super().__init__() 50 | self.semantics = semantics 51 | self.use_cache = use_cache 52 | self._cache: Dict[Any, Dict[str, Any]] = {} 53 | self.skip_exceptions: Set[Exception] = set() 54 | # Statistics 55 | self._total_requests = 0 56 | self._cache_hits = 0 57 | 58 | ### duplicated code with task_generator_regexp, to be factorized 59 | str_lexicon = list([chr(i) for i in range(32, 126)]) 60 | n_lexicon = [chr(i) for i in range(48, 58)] 61 | u_lexicon = [chr(i) for i in range(65, 91)] 62 | l_lexicon = [chr(i) for i in range(97, 123)] 63 | o_lexicon = str_lexicon[1:] # without whitespace 64 | o_lexicon = list(set(o_lexicon) - set(n_lexicon + u_lexicon + l_lexicon)) 65 | 66 | self.lexicons: Dict[str, List[str]] = { 67 | "W": ["\s"], 68 | "N": n_lexicon, 69 | "U": u_lexicon, 70 | "L": l_lexicon, 71 | "O": o_lexicon, 72 | } 73 | 74 | def eval(self, regexp: str, input: List) -> float: 75 | """ 76 | key = __tuplify__(regexp) 77 | if key not in self._cache and self.use_cache: 78 | self._cache[key] = {} 79 | evaluations: Dict[str, float] = self._cache[key] if self.use_cache else {} 80 | if regexp in evaluations: 81 | return evaluations[regexp] 82 | """ 83 | try: 84 | result = 1 85 | repeated = None 86 | re = get_regexp(regexp) 87 | match = regex_match(Raw(re), "".join(input)) 88 | if match is None or match.group() != "".join(input): 89 | print( 90 | "Regexp did not 
perfectly match with input. This word cannot be generated by this regexp." 91 | ) 92 | return 0 93 | group_index = len(match.groups()) 94 | for r in regexp[::-1]: 95 | self._total_requests += 1 96 | if r in self.lexicons.keys(): 97 | if repeated: 98 | group = match.groups(group_index) 99 | if repeated == "?": 100 | result *= __uniform__(self.lexicons) * 0.5 101 | else: 102 | tmp = 0 103 | if repeated == "+": 104 | tmp += __uniform__(self.lexicons) 105 | tmp *= __geometrical__( 106 | len(group), __uniform__(self.lexicons[r]) 107 | ) 108 | result *= tmp 109 | repeated = None 110 | else: 111 | result *= __uniform__(self.lexicons[r]) 112 | group_index -= 1 113 | elif r in ["+", "?", "*"]: 114 | repeated = r 115 | except Exception as e: 116 | if type(e) in self.skip_exceptions: 117 | return 0 118 | else: 119 | raise e 120 | 121 | return result 122 | 123 | def clear_cache(self) -> None: 124 | self._cache.clear() 125 | 126 | @property 127 | def cache_hit_rate(self) -> float: 128 | return self._cache_hits / self._total_requests 129 | -------------------------------------------------------------------------------- /examples/pbe/regexp/regexp.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from examples.pbe.regexp.type_regex import regex_match, Raw, REGEXP 4 | from examples.pbe.regexp.evaluator_regexp import RegexpEvaluator, get_regexp 5 | from examples.pbe.regexp.task_generator_regexp import reproduce_regexp_dataset 6 | 7 | from synth.semantic import DSLEvaluator 8 | from synth.syntax import DSL, PrimitiveType, Arrow, List, STRING, BOOL 9 | 10 | 11 | def pretty_print_solution(regexp: str) -> str: 12 | result = ( 13 | "".join("".join(regexp.__str__().split("(")[2:]).split(" ")[::-1]) 14 | .replace(")", "") 15 | .replace("begin", "") 16 | ) 17 | return f"(eval var0 {result})" 18 | 19 | 20 | def pretty_print_inputs(str: List) -> str: 21 | return "'" + "".join(str) + "'" 22 | 23 | 24 | init = PrimitiveType("") 25 | 26 | 27 | def __qmark__(x): 28 | return x + "?" 
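# Note: each of these builder functions appends one symbol to a compact regexp
# description string (e.g. applying N and then + to "begin" yields "N+");
# get_regexp in evaluator_regexp.py later expands these symbols into real
# regular expressions such as "([0-9]+)" before matching.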
29 | 30 | 31 | def __kleene__(x): 32 | return x + "*" 33 | 34 | 35 | def __plus__(x): 36 | return x + "+" 37 | 38 | 39 | def __lowercase__(x): 40 | return x + "L" 41 | 42 | 43 | def __uppercase__(x): 44 | return x + "U" 45 | 46 | 47 | def __number__(x): 48 | return x + "N" 49 | 50 | 51 | def __other__(x): 52 | return x + "O" 53 | 54 | 55 | def __whitespace__(x): 56 | return x + "W" 57 | 58 | 59 | def __eval__(x, reg): 60 | x = "".join(x) 61 | result = regex_match(Raw(get_regexp(reg)), x, flags=re.ASCII) 62 | # print(f"{result.match.group() if result else None} vs {x} => {result.string == x if result != None else False}") 63 | if result is None: 64 | return False 65 | return result.match.group() == x 66 | 67 | 68 | __semantics = { 69 | "begin": init.type_name, 70 | "?": __qmark__, 71 | "*": __kleene__, 72 | "+": __plus__, 73 | "U": __uppercase__, 74 | "L": __lowercase__, 75 | "N": __number__, 76 | "O": __other__, 77 | "W": __whitespace__, 78 | "eval": lambda x: lambda reg: __eval__(x, reg), 79 | } 80 | 81 | __primitive_types = { 82 | "begin": REGEXP, 83 | "?": Arrow(REGEXP, REGEXP), 84 | "*": Arrow(REGEXP, REGEXP), 85 | "+": Arrow(REGEXP, REGEXP), 86 | "U": Arrow(REGEXP, REGEXP), 87 | "L": Arrow(REGEXP, REGEXP), 88 | "N": Arrow(REGEXP, REGEXP), 89 | "O": Arrow(REGEXP, REGEXP), 90 | "W": Arrow(REGEXP, REGEXP), 91 | "eval": Arrow(List(STRING), Arrow(REGEXP, BOOL)), 92 | } 93 | 94 | __forbidden_patterns = { 95 | "*": {"?", "+", "*"}, 96 | "?": {"?", "+", "*"}, 97 | "+": {"?", "+", "*"}, 98 | "W": {"?", "+", "*"}, 99 | } 100 | 101 | dsl = DSL(__primitive_types, __forbidden_patterns) 102 | evaluator = DSLEvaluator(dsl.instantiate_semantics(__semantics)) 103 | evaluator.skip_exceptions.add(re.error) 104 | lexicon = list([chr(i) for i in range(32, 126)]) 105 | regexp_evaluator = RegexpEvaluator(__semantics) 106 | -------------------------------------------------------------------------------- /examples/pbe/transduction/knowledge_graph/README.md: -------------------------------------------------------------------------------- 1 | 2 | # How to reproduce the experiment? 3 | 4 | First you need a database that supports SPARQL queries. 5 | Once you have that, you can fill the database using the ``fill_knowledge_graph.sparql`` query. 6 | 7 | Then execute ``convert_kg_json_tasks.py`` to convert ``constants.json`` to ``constants.pickle`` (a format supported by AutoSynth); a short sketch of this step is given at the end of this README. 8 | Then preprocess the tasks with ``preprocess_tasks.py``, which guesses the constants. 9 | Finally, use the resulting ``constants.pickle`` file in ``evaluate.py`` with your model. 10 | 11 | In our paper, the model was the one obtained through the script ``test_performance.sh`` in experiments.
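For reference, the first conversion step can be run through the CLI (``python convert_kg_json_tasks.py constants.json -o constants.pickle``) or programmatically via the ``convert`` helper defined in ``convert_kg_json_tasks.py``; this is only a sketch of the command described above:

from convert_kg_json_tasks import convert

convert("constants.json", "constants.pickle")  # JSON tasks -> ProgSynth pickle dataset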
12 | -------------------------------------------------------------------------------- /examples/pbe/transduction/knowledge_graph/convert_kg_json_tasks.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from synth.specification import PBE, Example, TaskSpecification 4 | from synth.task import Dataset, Task 5 | 6 | 7 | def convert(dataset_file: str, output_file: str): 8 | tasks = [] 9 | 10 | with open(dataset_file) as fd: 11 | lst = json.load(fd) 12 | for el in lst: 13 | spec = PBE([Example(ex["inputs"], ex["output"]) for ex in el["examples"]]) 14 | task = Task[PBE]( 15 | spec.guess_type(), specification=spec, metadata=el["metadata"] 16 | ) 17 | tasks.append(task) 18 | 19 | dataset: Dataset[TaskSpecification] = Dataset(tasks) 20 | dataset.save(output_file) 21 | 22 | 23 | if __name__ == "__main__": 24 | import argparse 25 | 26 | argument_parser: argparse.ArgumentParser = argparse.ArgumentParser( 27 | description="Convert knowledge graph JSON dataset to ProgSynth format." 28 | ) 29 | 30 | argument_default_values = { 31 | "output": "constants.pickle", 32 | } 33 | 34 | argument_parser.add_argument( 35 | type=str, 36 | dest="file", 37 | action="store", 38 | help="Source JSON transduction file to be converted", 39 | ) 40 | argument_parser.add_argument( 41 | "-o", 42 | "--output", 43 | type=str, 44 | action="store", 45 | default=argument_default_values["output"], 46 | help=f"Output dataset file in ProgSynth format (default: '{argument_default_values['output']}')", 47 | ) 48 | parsed_parameters = argument_parser.parse_args() 49 | convert(parsed_parameters.file, parsed_parameters.output) 50 | -------------------------------------------------------------------------------- /examples/pbe/transduction/knowledge_graph/kg_path_finder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import List, Tuple 3 | from SPARQLWrapper import SPARQLWrapper, JSON 4 | 5 | 6 | def __make_query_path__(distance: int, id: int, tabs: int = 1) -> str: 7 | if distance == 0: 8 | return ("\t" * tabs) + f"w:{{}} ?p{distance} ?o_{id}_{distance} ." 9 | else: 10 | path = __make_query_path__(distance - 1, id, tabs) + "\n" 11 | path += ( 12 | "\t" * tabs 13 | ) + f"?o_{id}_{distance-1} ?p{distance} ?o_{id}_{distance} ." 
14 | return path 15 | 16 | 17 | def __format__(el: str) -> str: 18 | return el.replace(" ", "_").replace("'", "_").replace(",", "_") 19 | 20 | 21 | def build_search_path_query(entities: List[Tuple[str, str]], distance: int = 1) -> str: 22 | entities = [(__format__(a), __format__(b)) for a, b in entities] 23 | first = entities.pop() 24 | subquery = "" 25 | for i, item in enumerate(entities): 26 | subquery += "\tFILTER EXISTS {\n" 27 | subquery += ( 28 | __make_query_path__(distance, i + 1, 2) 29 | .format(item[0]) 30 | .replace(f"?o_{i + 1}_{distance}", "w:" + item[1]) 31 | ) 32 | subquery += "\n\t} .\n" 33 | sparql_request = "PREFIX w: \n" 34 | sparql_request += "SELECT " 35 | sparql_request += " ".join(f"?p{d}" for d in range(distance + 1)) 36 | sparql_request += " WHERE {\n" 37 | sparql_request += ( 38 | __make_query_path__(distance, 0) 39 | .format(first[0]) 40 | .replace(f"?o_{0}_{distance}", "w:" + first[1]) 41 | ) 42 | sparql_request += "\n" 43 | sparql_request += subquery 44 | sparql_request += "\n}" 45 | return sparql_request 46 | 47 | 48 | def build_count_paths_query(start: str, path: List[str]) -> str: 49 | sparql_request = "PREFIX w: \n" 50 | sparql_request += "SELECT " 51 | sparql_request += "?dst" 52 | sparql_request += " WHERE {\n" 53 | sparql_request += f"\tw:{__format__(start)} w:{path[0]} ?e0 ." 54 | for i in range(1, len(path) - 1): 55 | sparql_request += f"\t?e{i-1} w:{path[i]} ?e{i} ." 56 | sparql_request += f"\t?e{len(path) - 1} w:{path[-1]} ?dst" 57 | sparql_request += "\n}" 58 | return sparql_request 59 | 60 | 61 | def build_wrapper(endpoint: str) -> SPARQLWrapper: 62 | wrapper = SPARQLWrapper(endpoint) 63 | wrapper.setReturnFormat(JSON) 64 | return wrapper 65 | 66 | 67 | def __exec_search_path_query__(query: str, wrapper: SPARQLWrapper) -> List[List[str]]: 68 | try: 69 | wrapper.setQuery(query) 70 | answer = wrapper.query().convert() 71 | paths: List[List[str]] = [] 72 | for path in answer["results"]["bindings"]: 73 | cur_path = [] 74 | for rel in path: 75 | cur_path.append(path[rel]["value"].split("/")[-1]) 76 | paths.append(cur_path) 77 | return paths 78 | except Exception as e: 79 | print(e, file=sys.stderr) 80 | pass 81 | return [] 82 | 83 | 84 | forbidden_chars = "+|!/<>" 85 | 86 | 87 | def find_paths_from_level( 88 | pairs: List[Tuple[str, str]], 89 | wrapper: SPARQLWrapper, 90 | level: int, 91 | max_distance: int = 3, 92 | ) -> List[List[str]]: 93 | if level < 0: 94 | return [] 95 | for inp, out in pairs: 96 | if any(c in inp or c in out for c in forbidden_chars): 97 | return [] 98 | d = level 99 | while d < max_distance: 100 | query = build_search_path_query(pairs, d) 101 | out = __exec_search_path_query__(query, wrapper) 102 | if len(out) > 0: 103 | return out 104 | d += 1 105 | return [] 106 | 107 | 108 | def __exec_count_query__(query: str, wrapper: SPARQLWrapper) -> int: 109 | if "+" in query or "|" in query: 110 | return 0 111 | try: 112 | wrapper.setQuery(query) 113 | answer = wrapper.query().convert() 114 | return len(answer["results"]["bindings"]) 115 | except Exception as e: 116 | print(e, file=sys.stderr) 117 | pass 118 | return 0 119 | 120 | 121 | def choose_best_path( 122 | paths: List[List[str]], pairs: List[Tuple[str, str]], wrapper: SPARQLWrapper 123 | ) -> List[str]: 124 | best_path_index = 0 125 | best_score = 99999999999999999999 126 | for i, path in enumerate(paths): 127 | score = 0 128 | for start, _ in pairs: 129 | score += __exec_count_query__(build_count_paths_query(start, path), wrapper) 130 | if score < best_score: 131 | best_score = 
score 132 | best_path_index = i 133 | return paths[best_path_index] 134 | -------------------------------------------------------------------------------- /examples/pbe/transduction/task_generator_transduction.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | List as TList, 3 | Any, 4 | Optional, 5 | Tuple, 6 | ) 7 | 8 | import numpy as np 9 | 10 | from examples.pbe.regexp.type_regex import REGEXP 11 | 12 | from synth.generation.sampler import ListSampler 13 | from synth.pbe.task_generator import ( 14 | TaskGenerator, 15 | basic_output_validator, 16 | reproduce_dataset, 17 | ) 18 | from synth.syntax.program import Program 19 | from synth.syntax.type_helper import auto_type 20 | from synth.syntax.type_system import List, Type 21 | 22 | from synth.task import Dataset, Task 23 | from synth.specification import PBE, Example, PBEWithConstants 24 | from synth.semantic.evaluator import DSLEvaluator 25 | from synth.syntax import ( 26 | STRING, 27 | DSL, 28 | ) 29 | from synth.generation import ( 30 | LexiconSampler, 31 | UnionSampler, 32 | ) 33 | 34 | CST_IN = auto_type("CST_STR_INPUT") 35 | CST_OUT = auto_type("CST_STR_OUTPUT") 36 | 37 | 38 | class TransductionTaskGenerator(TaskGenerator): 39 | def generate_program(self, type_request: Type) -> Tuple[Program, bool]: 40 | program, is_unique = super().generate_program(type_request) 41 | self.__constants = super().sample_input([STRING, STRING]) 42 | return program, is_unique 43 | 44 | def make_task( 45 | self, 46 | type_request: Type, 47 | solution: Program, 48 | inputs: TList, 49 | outputs: TList, 50 | **kwargs: Any, 51 | ) -> Task[PBEWithConstants]: 52 | return Task( 53 | type_request, 54 | PBEWithConstants( 55 | [Example(inp, out) for inp, out in zip(inputs, outputs)], 56 | {CST_IN: [self.__constants[0]], CST_OUT: [self.__constants[1]]}, 57 | ), 58 | solution, 59 | {"generated": True, **kwargs}, 60 | ) 61 | 62 | 63 | def reproduce_transduction_dataset( 64 | dataset: Dataset[PBE], 65 | dsl: DSL, 66 | evaluator: DSLEvaluator, 67 | seed: Optional[int] = None, 68 | *args: Any, 69 | **kwargs: Any, 70 | ) -> Tuple[TaskGenerator, TList[int]]: 71 | def analyser(start: None, elment: Any) -> None: 72 | pass 73 | 74 | str_lexicon = list([chr(i) for i in range(32, 126)]) 75 | regexp_symbols = [ 76 | "_", 77 | ")", 78 | "{", 79 | "+", 80 | ";", 81 | "=", 82 | "$", 83 | "\\", 84 | "^", 85 | ",", 86 | "!", 87 | "*", 88 | "'", 89 | " ", 90 | ">", 91 | "}", 92 | "<", 93 | "[", 94 | '"', 95 | "#", 96 | "|", 97 | "`", 98 | "%", 99 | "?", 100 | ":", 101 | "]", 102 | "&", 103 | "(", 104 | "@", 105 | ".", 106 | "/", 107 | "-", 108 | ] 109 | probabilities = np.array([0.5**i for i in range(6)]) 110 | probabilities /= np.sum(probabilities) 111 | STR_list = List(STRING) 112 | string_sampler = ( 113 | ListSampler( 114 | LexiconSampler(str_lexicon, seed=seed), 115 | [(i + 4, probabilities[i]) for i in range(len(probabilities))], 116 | max_depth=2, 117 | seed=seed, 118 | ) 119 | .compose_with_type_mapper(lambda _: STR_list) 120 | .compose(lambda el: el if isinstance(el, str) else "".join(el)) 121 | ) 122 | 123 | def get_sampler(start: None) -> UnionSampler: 124 | return UnionSampler( 125 | { 126 | STRING: string_sampler, 127 | REGEXP: LexiconSampler(regexp_symbols, seed=seed), 128 | } 129 | ) 130 | 131 | str_bank = str_lexicon + regexp_symbols 132 | task_generator, str_lexicon = reproduce_dataset( 133 | dataset, 134 | dsl, 135 | evaluator, 136 | None, 137 | lambda _, __: None, 138 | get_sampler, 139 | lambda _, 
max_list_length: lambda x: x is not None 140 | and all(xi in str_bank for xi in x), 141 | lambda _: str_lexicon + regexp_symbols, 142 | seed, 143 | *args, 144 | **kwargs, 145 | ) 146 | 147 | generator = TransductionTaskGenerator( 148 | task_generator.input_generator, 149 | task_generator.evaluator, 150 | task_generator.gen_random_type_request, 151 | task_generator.gen_random_sample_number, 152 | task_generator.type2pgrammar.values(), 153 | task_generator.output_validator, 154 | task_generator.max_tries, 155 | task_generator.uniques, 156 | verbose=task_generator.verbose, 157 | ) 158 | 159 | generator.skip_exceptions.add(ValueError) 160 | 161 | return generator, str_lexicon 162 | -------------------------------------------------------------------------------- /examples/pbe/transduction/transduction.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from examples.pbe.transduction.task_generator_transduction import ( 4 | reproduce_transduction_dataset, 5 | ) 6 | 7 | from synth.semantic.evaluator import DSLEvaluator 8 | from synth.syntax import ( 9 | DSL, 10 | Arrow, 11 | STRING, 12 | PrimitiveType, 13 | ) 14 | from examples.pbe.regexp.evaluator_regexp import get_regexp 15 | from examples.pbe.regexp.type_regex import REGEXP 16 | from examples.pbe.regexp.type_regex import ( 17 | Raw, 18 | REGEXP, 19 | regex_search, 20 | ) 21 | 22 | CST_IN = PrimitiveType("CST_STR_INPUT") 23 | CST_OUT = PrimitiveType("CST_STR_OUTPUT") 24 | 25 | 26 | def __concat__(x, y): 27 | return "" + x + y 28 | 29 | 30 | def __concat_if__(x, y): 31 | if y in x: 32 | return x 33 | if x in y: 34 | return y 35 | return "" + x + y 36 | 37 | 38 | def __split_first__(x: str, regexp: str): 39 | sbstr = regex_search(Raw(get_regexp(regexp)), x, flags=re.ASCII) 40 | if sbstr == None: 41 | return "" 42 | return x.split(sbstr.match.group(), 1)[0] 43 | 44 | 45 | def __split_snd__(x: str, regexp: str): 46 | sbstr = regex_search(Raw(get_regexp(regexp)), x, flags=re.ASCII) 47 | if sbstr == None: 48 | return "" 49 | return x.split(sbstr.match.group(), 1)[1] 50 | 51 | 52 | def __match__(x: str, regexp: str): 53 | sbstr = regex_search(Raw(get_regexp(regexp)), x, flags=re.ASCII) 54 | if sbstr == None: 55 | return "" 56 | return sbstr.match.group() 57 | 58 | 59 | # untreated matching, done for constant text inputs (e.g. "." 
will be considered as a point instead of any char) 60 | def __split_first_cst__(x: str, text: str): 61 | regexp = "(\\" + text + ")" 62 | sbstr = regex_search(Raw(regexp), x, flags=re.ASCII) 63 | if sbstr == None: 64 | return "" 65 | return x.split(sbstr.match.group(), 1)[0] 66 | 67 | 68 | def __split_snd_cst__(x: str, text: str): 69 | regexp = "(\\" + text + ")" 70 | sbstr = regex_search(Raw(regexp), x, flags=re.ASCII) 71 | if sbstr == None: 72 | return "" 73 | return x.split(sbstr.match.group(), 1)[1] 74 | 75 | 76 | def __match_cst__(x: str, text: str): 77 | regexp = "(\\" + text + ")" 78 | sbstr = regex_search(Raw(regexp), x, flags=re.ASCII) 79 | if sbstr == None: 80 | return "" 81 | return sbstr.match.group() 82 | 83 | 84 | def __compose__(x, y): 85 | return x + y 86 | 87 | 88 | __semantics = { 89 | "concat": lambda x: lambda y: __concat__(x, y), 90 | "concat_cst": lambda x: lambda y: __concat__(x, y), 91 | "concat_if": lambda x: lambda y: __concat_if__(x, y), 92 | "split_first": lambda x: lambda regexp: __split_first__(x, regexp), 93 | "split_snd": lambda x: lambda regexp: __split_snd__(x, regexp), 94 | "match": lambda x: lambda regexp: __match__(x, regexp), 95 | "split_first_cst": lambda x: lambda text: __split_first_cst__(x, text), 96 | "split_snd_cst": lambda x: lambda text: __split_snd_cst__(x, text), 97 | "match_cst": lambda x: lambda text: __match_cst__(x, text), 98 | "compose": lambda x: lambda y: __compose__(x, y), 99 | "$": "$", 100 | ".": ".", 101 | "except": lambda x: "([^" + x + "]+", 102 | "except_end": lambda x: "([^" + x + "]+$", 103 | } 104 | 105 | __primitive_types = { 106 | "concat": Arrow(STRING, Arrow(STRING, STRING)), 107 | "concat_cst": Arrow(STRING, Arrow(CST_OUT, STRING)), 108 | "concat_if": Arrow(STRING, Arrow(CST_OUT, STRING)), 109 | "split_first": Arrow(STRING, Arrow(REGEXP, STRING)), 110 | "split_snd": Arrow(STRING, Arrow(REGEXP, STRING)), 111 | "match": Arrow(STRING, Arrow(REGEXP, STRING)), 112 | "split_first_cst": Arrow(STRING, Arrow(CST_IN, STRING)), 113 | "split_snd_cst": Arrow(STRING, Arrow(CST_IN, STRING)), 114 | "match_cst": Arrow(STRING, Arrow(CST_IN, STRING)), 115 | "compose": Arrow(REGEXP, Arrow(REGEXP, REGEXP)), 116 | "$": REGEXP, 117 | ".": REGEXP, 118 | "except": Arrow(CST_IN, REGEXP), 119 | "except_end": Arrow(CST_IN, REGEXP), 120 | } 121 | 122 | __forbidden_patterns = {} 123 | 124 | dsl = DSL(__primitive_types, __forbidden_patterns) 125 | constant_types = {CST_IN, CST_OUT} 126 | evaluator = DSLEvaluator(dsl.instantiate_semantics(__semantics)) 127 | evaluator.skip_exceptions.add(re.error) 128 | lexicon = list([chr(i) for i in range(32, 126)]) 129 | constraints = [ 130 | "concat ^concat _", 131 | "compose ^compose _", 132 | ] 133 | -------------------------------------------------------------------------------- /examples/plot_enumeration_results.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict 2 | from typing import Dict, List 3 | import matplotlib.pyplot as plt 4 | import pltpublish as pub 5 | import csv 6 | 7 | from plot_helper import ( 8 | plot_y_wrt_x, 9 | make_plot_wrapper, 10 | ) 11 | 12 | 13 | __DATA__ = { 14 | "time": (0, "Time (in s)"), 15 | "programs": (1, "Programs Enumerated"), 16 | "queued": (2, "Queue Size"), 17 | "banked": (3, "Programs in Banks"), 18 | "non_terminals": (4, "Non Terminals in Grammar"), 19 | "rules": (5, "Derivation Rules in Grammar"), 20 | } 21 | 22 | 23 | def load_data(output_file: str, verbose: bool = False) -> Dict[str, 
Dict[int, List]]: 24 | # Dict[name, data] 25 | methods = {} 26 | 27 | # filename should end with a specific pattern 28 | name = output_file[:-4] 29 | if not (name.endswith("_detailed") or name.endswith("_growth")): 30 | if verbose: 31 | print(f"filename:{output_file} does not seem valid!") 32 | return {} 33 | trace = [] 34 | with open(output_file, "r") as fd: 35 | reader = csv.reader(fd) 36 | trace = [tuple(row) for row in reader] 37 | # Pop columns names 38 | columns = {name: ind for ind, name in enumerate(trace.pop(0))} 39 | indices = [ 40 | columns["search"], 41 | columns["time"], 42 | columns["programs"], 43 | columns["queue"], 44 | columns["bank"], 45 | columns["non_terminals"], 46 | columns["derivation_rules"], 47 | columns.get("seed", -1), 48 | ] 49 | data = [tuple(row[k] if k >= 0 else 0 for k in indices) for row in trace] 50 | if len(data) == 0: 51 | if verbose: 52 | print(f"filename:{output_file} is empty!") 53 | return {} 54 | agg = defaultdict(dict) 55 | for row in data: 56 | seed = int(row[-1]) 57 | if seed not in agg[row[0]]: 58 | agg[row[0]][seed] = [] 59 | agg[row[0]][seed].append(row[1:-1]) 60 | for name, data in agg.items(): 61 | name = name.replace("_", " ") 62 | if name not in methods: 63 | methods[name] = {} 64 | # Save data for method 65 | for seed, vals in data.items(): 66 | methods[name][seed] = [tuple(float(x) for x in row) for row in vals] 67 | # Backend support onl yseeded data so we register every data as seed 1 68 | return methods 69 | 70 | 71 | # Generate all possible combinations 72 | __PLOTS__ = {} 73 | for ydata in list(__DATA__.keys()): 74 | for xdata in list(__DATA__.keys()): 75 | if xdata == ydata: 76 | continue 77 | __PLOTS__[f"{ydata}_wrt_{xdata}"] = make_plot_wrapper( 78 | plot_y_wrt_x, 79 | __DATA__[xdata], 80 | __DATA__[ydata], 81 | cumulative=False, 82 | logy=xdata == "non_terminals", 83 | ) 84 | 85 | if __name__ == "__main__": 86 | import argparse 87 | import sys 88 | 89 | parser = argparse.ArgumentParser(description="Plot results") 90 | parser.add_argument( 91 | "file", 92 | type=str, 93 | help="data file to load", 94 | ) 95 | parser.add_argument( 96 | "-v", 97 | "--verbose", 98 | action="store_true", 99 | default=False, 100 | help="verbose mode", 101 | ) 102 | parser.add_argument("plots", nargs="+", choices=list(__PLOTS__.keys())) 103 | parameters = parser.parse_args() 104 | output_file: str = parameters.file 105 | verbose: bool = parameters.verbose 106 | plots: List[str] = parameters.plots 107 | 108 | # Load data 109 | pub.setup() 110 | methods = load_data(output_file, verbose) 111 | # Check we have at least one file 112 | if len(methods) == 0: 113 | print("Error: no performance file was found!", file=sys.stderr) 114 | sys.exit(1) 115 | # Order by name so that it is always the same color for the same methods if diff. 
DSL 116 | ordered_methods = OrderedDict() 117 | for met in sorted(methods.keys()): 118 | ordered_methods[met] = methods[met] 119 | # Plotting 120 | for count, to_plot in enumerate(plots): 121 | ax = plt.subplot(1, len(plots), count + 1) 122 | __PLOTS__[to_plot](ax, ordered_methods) 123 | plt.show() 124 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SynthesisLab/DeepSynth2/3efba7bfc03d3e576d67fb433c5fd8d50e3ebdb5/images/logo.png -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | disallow_untyped_defs = True 3 | disallow_any_unimported = True 4 | no_implicit_optional = True 5 | check_untyped_defs = True 6 | warn_return_any = True 7 | show_error_codes = True 8 | warn_unused_ignores = True 9 | exclude = tools|examples 10 | 11 | [mypy-colorama.*] 12 | ignore_missing_imports = True 13 | [mypy-numpy.*] 14 | ignore_missing_imports = True 15 | [mypy-tqdm.*] 16 | ignore_missing_imports = True 17 | [mypy-vose.*] 18 | ignore_missing_imports = True 19 | [mypy-torch.*] 20 | ignore_missing_imports = True 21 | [mypy-matplotlib.*] 22 | ignore_missing_imports = True 23 | [mypy-transformers.*] 24 | ignore_missing_imports = True -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "synth" 3 | version = "0.1.0" 4 | description = "Automated Synthesis Framework" 5 | authors = ["Théo Matricon ", "Nathanaël Fijalkow "] 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/SynthesisLab/DeepSynth2" 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.8" 12 | torch = ">=1.13.1" 13 | tensorboard = ">=0" 14 | cython = ">=0.29" 15 | numpy = ">=1.22" 16 | vose = { git = "https://github.com/Theomat/vose.git"} 17 | colorama = ">=0.4.4" 18 | tqdm = ">=4.63.0" 19 | matplotlib = ">=3.5.1" 20 | pltpublish = ">=0.1.0" 21 | 22 | [tool.poetry.dev-dependencies] 23 | ruff = ">= 0.4.3" 24 | pytest = ">=7.2.0" 25 | mypy = ">=0.910" 26 | 27 | [build-system] 28 | requires = ["poetry-core>=1.0.0"] 29 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import setuptools # type: ignore 4 | 5 | if __name__ == "__main__": 6 | setuptools.setup() 7 | -------------------------------------------------------------------------------- /synth/__init__.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | from synth.task import Task, Dataset 3 | from synth.specification import TaskSpecification, PBE, NLP, NLPBE, Example 4 | -------------------------------------------------------------------------------- /synth/filter/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to pruning 3 | """ 4 | 5 | from synth.filter.filter import Filter, UnionFilter, IntersectionFilter 6 | from synth.filter.dfta_filter import DFTAFilter 7 | from synth.filter.obs_eq_filter import ObsEqFilter 8 | from synth.filter.local_stateless_filter import 
LocalStatelessFilter 9 | from synth.filter.syntactic_filter import ( 10 | UseAllVariablesFilter, 11 | FunctionFilter, 12 | SyntacticFilter, 13 | SetFilter, 14 | ) 15 | from synth.filter.constraints import add_constraints, add_dfta_constraints 16 | -------------------------------------------------------------------------------- /synth/filter/constraints/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.filter.constraints.ttcfg_constraints import add_constraints 2 | from synth.filter.constraints.dfta_constraints import add_dfta_constraints 3 | -------------------------------------------------------------------------------- /synth/filter/dfta_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generic, TypeVar, Optional 2 | 3 | from synth.filter.filter import Filter 4 | from synth.syntax.automata.tree_automaton import DFTA 5 | from synth.syntax.grammars.grammar import DerivableProgram 6 | from synth.syntax.program import Function, Program, Lambda 7 | 8 | V = TypeVar("V") 9 | 10 | 11 | class DFTAFilter(Filter, Generic[V]): 12 | """ 13 | Filters out programs depending on the given DFTA. 14 | 15 | If accepting_dfta then rejects programs that are not in the language of the DFTA. 16 | If not accepting_dfta, rejects programs that are in the language of the DFTA. 17 | 18 | """ 19 | 20 | def __init__( 21 | self, dfta: DFTA[V, DerivableProgram], accepting_dfta: bool = True 22 | ) -> None: 23 | self.dfta = dfta 24 | self._cache: Dict[Program, V] = {} 25 | self.accepting_dfta = accepting_dfta 26 | 27 | def _get_prog_state(self, prog: Program) -> Optional[V]: 28 | state = self._cache.get(prog, None) 29 | if state is not None: 30 | return state 31 | if isinstance(prog, Function): 32 | fun = prog.function 33 | args = tuple(self._get_prog_state(arg) for arg in prog.arguments) 34 | state = self.dfta.read(fun, args) # type: ignore 35 | if state is not None: 36 | self._cache[prog] = state 37 | return state 38 | elif isinstance(prog, Lambda): 39 | assert False, "Not implemented" 40 | else: 41 | state = self.dfta.read(prog, ()) # type: ignore 42 | if state is not None: 43 | self._cache[prog] = state 44 | return state 45 | 46 | def accept(self, obj: Program) -> bool: 47 | return (self._get_prog_state(obj) is not None) == self.accepting_dfta 48 | 49 | def reset_cache(self) -> None: 50 | self._cache.clear() 51 | -------------------------------------------------------------------------------- /synth/filter/filter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | class Filter(ABC, Generic[T]): 9 | @abstractmethod 10 | def accept(self, obj: T) -> bool: 11 | """ 12 | Accepts objects that should be kept. 13 | """ 14 | pass 15 | 16 | def reject(self, obj: T) -> bool: 17 | """ 18 | Rejects objects that should NOT be kept. 
19 | """ 20 | return not self.accept(obj) 21 | 22 | def __and__(self, other: "Filter[T]") -> "IntersectionFilter[T]": 23 | return self.intersection(other) 24 | 25 | def intersection(self, other: "Filter[T]") -> "IntersectionFilter[T]": 26 | if isinstance(other, IntersectionFilter): 27 | return other.intersection(self) 28 | elif isinstance(self, IntersectionFilter): 29 | if isinstance(other, IntersectionFilter): 30 | return IntersectionFilter(*self.filters, *other.filters) 31 | return IntersectionFilter(*self.filters, other) 32 | else: 33 | return IntersectionFilter(self, other) 34 | 35 | def __or__(self, other: "Filter[T]") -> "UnionFilter[T]": 36 | return self.union(other) 37 | 38 | def union(self, other: "Filter[T]") -> "UnionFilter[T]": 39 | if isinstance(other, UnionFilter): 40 | return other.union(self) 41 | elif isinstance(self, UnionFilter): 42 | if isinstance(other, UnionFilter): 43 | return UnionFilter(*self.filters, *other.filters) 44 | return UnionFilter(*self.filters, other) 45 | else: 46 | return UnionFilter(self, other) 47 | 48 | def __neg__(self) -> "Filter[T]": 49 | return self.complementary() 50 | 51 | def complementary(self) -> "Filter[T]": 52 | return NegFilter(self) 53 | 54 | 55 | class NegFilter(Filter, Generic[T]): 56 | def __init__(self, filter: Filter[T]) -> None: 57 | self.filter = filter 58 | 59 | def accept(self, obj: T) -> bool: 60 | return not self.filter.accept(obj) 61 | 62 | def complementary(self) -> "Filter[T]": 63 | return self.filter 64 | 65 | 66 | class UnionFilter(Filter, Generic[T]): 67 | def __init__(self, *filters: Filter[T]) -> None: 68 | self.filters = list(filters) 69 | 70 | def accept(self, obj: T) -> bool: 71 | return any(p.accept(obj) for p in self.filters) 72 | 73 | 74 | class IntersectionFilter(Filter, Generic[T]): 75 | def __init__(self, *filters: Filter[T]) -> None: 76 | self.filters = list(filters) 77 | 78 | def accept(self, obj: T) -> bool: 79 | return all(p.accept(obj) for p in self.filters) 80 | -------------------------------------------------------------------------------- /synth/filter/local_stateless_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Generic, TypeVar 2 | 3 | from synth.filter.filter import Filter 4 | from synth.syntax.program import Function, Program, Primitive 5 | 6 | V = TypeVar("V") 7 | 8 | 9 | class LocalStatelessFilter(Filter, Generic[V]): 10 | def __init__(self, should_reject: Dict[str, Callable]) -> None: 11 | self.should_reject = should_reject 12 | 13 | def accept(self, program: Program) -> bool: 14 | accepted = True 15 | if isinstance(program, Function): 16 | fun: Primitive = program.function # type: ignore 17 | rejects = self.should_reject.get(fun.primitive, None) 18 | accepted = rejects is None or not rejects(*program.arguments) 19 | return accepted 20 | 21 | 22 | def commutative_rejection(p1: Program, p2: Program) -> bool: 23 | """ 24 | Rejection filter to have unique programs for a commutative binary operator 25 | """ 26 | return hash(p1) <= hash(p2) 27 | 28 | 29 | def reject_functions(p: Program, *function_names: str) -> bool: 30 | """ 31 | Rejects any function whose name is in the parameters 32 | """ 33 | if isinstance(p, Function): 34 | return p.function.primitive in function_names # type: ignore 35 | return False 36 | -------------------------------------------------------------------------------- /synth/filter/obs_eq_filter.py: -------------------------------------------------------------------------------- 1 | from 
collections import defaultdict 2 | from typing import Any, Dict, List, Tuple 3 | 4 | from synth.filter.filter import Filter 5 | from synth.semantic.evaluator import Evaluator 6 | from synth.syntax.program import Program 7 | from synth.syntax.type_system import Type 8 | 9 | 10 | class ObsEqFilter(Filter): 11 | def __init__(self, evaluator: Evaluator, inputs_list: List[List[Any]]) -> None: 12 | self.evaluator = evaluator 13 | self.inputs_list = inputs_list 14 | self._cache: Dict[Type, Dict[Tuple[Any, ...], Program]] = defaultdict(dict) 15 | 16 | def _eval(self, prog: Program) -> bool: 17 | """ 18 | Returns True iff the prog is unique wrt to outputs 19 | """ 20 | outputs = None 21 | for inputs in self.inputs_list: 22 | out = self.evaluator.eval(prog, inputs) 23 | if out is None: 24 | return False 25 | elif isinstance(out, List): 26 | out = tuple(out) 27 | outputs = (outputs, out) 28 | original = self._cache[prog.type].get(outputs) # type: ignore 29 | if original is not None and hash(original) != hash(prog): 30 | return False 31 | else: 32 | self._cache[prog.type][outputs] = prog # type: ignore 33 | return True 34 | 35 | def accept(self, obj: Program) -> bool: 36 | return self._eval(obj) 37 | 38 | def reset_cache(self) -> None: 39 | self._cache.clear() 40 | -------------------------------------------------------------------------------- /synth/filter/syntactic_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Set, Tuple 2 | 3 | from synth.filter.filter import Filter 4 | from synth.syntax.program import Function, Primitive, Program 5 | from synth.syntax.type_system import Arrow, Type 6 | 7 | 8 | SyntacticFilter = Filter[Tuple[Type, Program]] 9 | 10 | 11 | class UseAllVariablesFilter(SyntacticFilter): 12 | def __init__(self) -> None: 13 | super().__init__() 14 | self._cached_variables_set: Dict[Type, Set[int]] = {} 15 | 16 | def __get_var_set__(self, treq: Type) -> Set[int]: 17 | if treq not in self._cached_variables_set: 18 | if treq.is_instance(Arrow): 19 | self._cached_variables_set[treq] = set(range(len(treq.arguments()))) 20 | else: 21 | self._cached_variables_set[treq] = set() 22 | 23 | return self._cached_variables_set[treq] 24 | 25 | def accept(self, obj: Tuple[Type, Program]) -> bool: 26 | treq, prog = obj 27 | target = self.__get_var_set__(treq) 28 | return prog.used_variables() == target 29 | 30 | 31 | class FunctionFilter(SyntacticFilter): 32 | def __init__(self, is_useless: Dict[str, Callable]) -> None: 33 | super().__init__() 34 | self.is_useless = is_useless 35 | 36 | def accept(self, obj: Tuple[Type, Program]) -> bool: 37 | _, prog = obj 38 | for P in prog.depth_first_iter(): 39 | if not isinstance(P, Function): 40 | continue 41 | f = P.function 42 | if not isinstance(f, Primitive): 43 | continue 44 | if f.primitive in self.is_useless and self.is_useless[f.primitive]( 45 | *P.arguments 46 | ): 47 | return False 48 | return True 49 | 50 | 51 | class SetFilter(SyntacticFilter): 52 | def __init__(self, forbidden: Set[Program]) -> None: 53 | super().__init__() 54 | self.forbidden = forbidden 55 | 56 | def accept(self, obj: Tuple[Type, Program]) -> bool: 57 | return obj[1] not in self.forbidden 58 | -------------------------------------------------------------------------------- /synth/generation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to the generation 3 | """ 4 | 5 | from synth.generation.sampler 
import ( 6 | Sampler, 7 | RequestSampler, 8 | LexiconSampler, 9 | ListSampler, 10 | UnionSampler, 11 | ) 12 | -------------------------------------------------------------------------------- /synth/library/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.library.learning import learn, make_score_probabilistic, score_description 2 | -------------------------------------------------------------------------------- /synth/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.nlp.bert import NLPEncoder 2 | -------------------------------------------------------------------------------- /synth/nlp/bert.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Literal, Optional, Tuple 2 | import re 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from transformers import BertTokenizer, BertModel 8 | 9 | from synth.nn.spec_encoder import SpecificationEncoder 10 | from synth.specification import NLP 11 | from synth.task import Task 12 | 13 | __QUOTED_TOKEN_RE__ = re.compile(r"(?P<quote>''|[`'\"])(?P<value>.*?)(?P=quote)") 14 | """ 15 | Pattern that finds strings surrounded by quotes or backquotes. 16 | """ 17 | 18 | __BERT_MODEL__ = "bert-base-uncased" 19 | 20 | 21 | class NLPEncoder(SpecificationEncoder[NLP, Tensor]): 22 | def __init__(self, max_var_num: int = 4) -> None: 23 | self.tokenizer = BertTokenizer.from_pretrained(__BERT_MODEL__) 24 | self.tokenizer.add_tokens( 25 | [f"var_{i}" for i in range(max_var_num + 1)] 26 | + [f"str_{i}" for i in range(max_var_num + 1)] 27 | ) 28 | self.encoder = BertModel.from_pretrained(__BERT_MODEL__) 29 | self.encoder.resize_token_embeddings(len(self.tokenizer)) 30 | 31 | @property 32 | def embedding_size(self) -> int: 33 | size: int = self.encoder.config.hidden_size 34 | return size 35 | 36 | def encode(self, task: Task[NLP], device: Optional[str] = None) -> Tensor: 37 | intent_tokens, slot_map = self.canonicalize_intent(task.specification.intent) 38 | tensor: Tensor = self.encoder(intent_tokens).last_hidden_state 39 | return tensor.to(device) 40 | # find the slot map 41 | 42 | def canonicalize_intent( 43 | self, intent: str 44 | ) -> Tuple[Tensor, Dict[str, Dict[str, str]]]: 45 | # handle the following special case: quote is `''` 46 | marked_token_matches = __QUOTED_TOKEN_RE__.findall(intent) 47 | 48 | slot_map = dict() 49 | ids_counts = {"var": 0, "str": 0} 50 | for match in marked_token_matches: 51 | quote: str = match[0] 52 | value: str = match[1] 53 | quoted_value = quote + value + quote 54 | 55 | slot_type = __infer_slot_type__(quote, value) 56 | slot_name = slot_type + ("_%d" % ids_counts[slot_type]) 57 | ids_counts[slot_type] += 1 58 | 59 | intent = intent.replace(quoted_value, slot_name) 60 | 61 | slot_map[slot_name] = { 62 | "value": value.strip().encode().decode("unicode_escape", "ignore"), 63 | "quote": quote, 64 | "type": slot_type, 65 | } 66 | 67 | intent_list: List[str] = self.tokenizer.tokenize(intent.lower()) 68 | intent_list = ["[CLS]"] + intent_list + ["[SEP]"] 69 | voc: Dict[str, int] = self.tokenizer.get_vocab() 70 | intent_tensor = torch.tensor([voc[x] for x in intent_list]).unsqueeze(0) 71 | return intent_tensor, slot_map 72 | 73 | 74 | def __infer_slot_type__(quote: str, value: str) -> Literal["var", "str"]: 75 | if quote == "`" and value.isidentifier(): 76 | return "var" 77 | return "str" 78 |
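A minimal usage sketch of NLPEncoder (illustrative only: the intent sentence is made up, and instantiating the encoder downloads the bert-base-uncased weights on first use):

encoder = NLPEncoder()
tokens, slot_map = encoder.canonicalize_intent("reverse the string 'hello' stored in `s`")
# slot_map == {"str_0": {"value": "hello", "quote": "'", "type": "str"},
#              "var_0": {"value": "s", "quote": "`", "type": "var"}}
# tokens is a (1, num_wordpiece_tokens) LongTensor of ids, ready for the BERT encoder.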
-------------------------------------------------------------------------------- /synth/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to neural networks 3 | """ 4 | 5 | from synth.nn.det_grammar_predictor import DetGrammarPredictorLayer 6 | from synth.nn.u_grammar_predictor import UGrammarPredictorLayer 7 | import synth.nn.abstractions as abstractions 8 | from synth.nn.utils import ( 9 | AutoPack, 10 | Task2Tensor, 11 | print_model_summary, 12 | free_pytorch_memory, 13 | ) 14 | -------------------------------------------------------------------------------- /synth/nn/abstractions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, TypeVar 2 | from synth.syntax.grammars.cfg import CFGState, NoneType 3 | from synth.syntax.grammars.det_grammar import DerivableProgram 4 | from synth.syntax.grammars.grammar import NGram 5 | from synth.syntax.program import Primitive 6 | from synth.syntax.type_system import Type 7 | 8 | T = TypeVar("T") 9 | S = TypeVar("S") 10 | 11 | 12 | def ucfg_bigram( 13 | ctx: Tuple[Type, Tuple[NGram, T]], 14 | ) -> Optional[Tuple[DerivableProgram, int]]: 15 | """ 16 | Abstract away a TTCFG into tuples of (parent, no_arg). 17 | We lose any other information. 18 | """ 19 | _, (ngram, __) = ctx 20 | while not isinstance(ngram, NGram): 21 | ngram = ngram[0] 22 | if len(ngram) > 0: 23 | return ngram.last() 24 | return None 25 | 26 | 27 | def ttcfg_bigram( 28 | ctx: Tuple[Type, Tuple[S, T]], 29 | ) -> Optional[Tuple[DerivableProgram, int]]: 30 | """ 31 | Abstract away a TTCFG into tuples of (parent, no_arg). 32 | We lose any other information. 33 | """ 34 | _, (ngram, __) = ctx 35 | while not isinstance(ngram, NGram): 36 | ngram = ngram[0] # type: ignore 37 | if len(ngram) > 0: 38 | return ngram.last() 39 | return None 40 | 41 | 42 | def cfg_bigram_without_depth( 43 | ctx: Tuple[Type, Tuple[CFGState, NoneType]], 44 | ) -> Optional[Tuple[DerivableProgram, int]]: 45 | """ 46 | Abstract away a CFG into tuples of (parent, no_arg). 47 | We lose depth information. 48 | """ 49 | _, (state, __) = ctx 50 | ngram, ___ = state 51 | if len(ngram) > 0: 52 | return ngram.last() 53 | return None 54 | 55 | 56 | def primitive_presence(*args: Any) -> None: 57 | """ 58 | Abstract away a grammar into just the presence or absence of a primitive. 
59 | """ 60 | return None 61 | -------------------------------------------------------------------------------- /synth/nn/spec_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from typing import Generic, TypeVar 3 | 4 | from synth.specification import TaskSpecification 5 | from synth.task import Task 6 | 7 | T = TypeVar("T", bound=TaskSpecification, covariant=True) 8 | U = TypeVar("U") 9 | 10 | 11 | class SpecificationEncoder(ABC, Generic[T, U]): 12 | @abstractmethod 13 | def encode(self, task: Task[T]) -> U: 14 | pass 15 | -------------------------------------------------------------------------------- /synth/pbe/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to the Programming By Example (PBE) framework 3 | """ 4 | 5 | from synth.pbe.io_encoder import IOEncoder 6 | from synth.pbe.task_generator import ( 7 | TaskGenerator, 8 | basic_output_validator, 9 | reproduce_dataset, 10 | ) 11 | -------------------------------------------------------------------------------- /synth/pbe/io_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from synth.nn.spec_encoder import SpecificationEncoder 7 | from synth.specification import PBE 8 | from synth.task import Task 9 | 10 | 11 | class IOEncoder(SpecificationEncoder[PBE, Tensor]): 12 | def __init__( 13 | self, output_dimension: int, lexicon: List[Any], undefined: bool = True 14 | ) -> None: 15 | self.output_dimension = output_dimension 16 | 17 | self.special_symbols = [ 18 | "PADDING", # padding symbol that can be used later 19 | "STARTING", # start of entire sequence 20 | "ENDOFINPUT", # delimits the ending of an input - we might have multiple inputs 21 | "STARTOFOUTPUT", # begins the start of the output 22 | "ENDING", # ending of entire sequence 23 | "STARTOFLIST", 24 | "ENDOFLIST", 25 | ] 26 | if undefined: 27 | self.special_symbols.append("UNDEFINED") 28 | self.lexicon = lexicon + self.special_symbols 29 | self.non_special_lexicon_size = len(lexicon) 30 | self.symbol2index = {symbol: index for index, symbol in enumerate(self.lexicon)} 31 | self._default = self.symbol2index["UNDEFINED"] if undefined else None 32 | self.starting_index = self.symbol2index["STARTING"] 33 | self.end_of_input_index = self.symbol2index["ENDOFINPUT"] 34 | self.start_of_output_index = self.symbol2index["STARTOFOUTPUT"] 35 | self.ending_index = self.symbol2index["ENDING"] 36 | self.start_list_index = self.symbol2index["STARTOFLIST"] 37 | self.end_list_index = self.symbol2index["ENDOFLIST"] 38 | self.pad_symbol = self.symbol2index["PADDING"] 39 | 40 | def __encode_element__(self, x: Any, encoding: List[int]) -> None: 41 | if isinstance(x, List): 42 | encoding.append(self.start_list_index) 43 | for el in x: 44 | self.__encode_element__(el, encoding) 45 | encoding.append(self.end_list_index) 46 | else: 47 | encoding.append(self.symbol2index.get(x, self._default)) # type: ignore 48 | 49 | def encode_IO(self, IO: Tuple[List, Any], device: Optional[str] = None) -> Tensor: 50 | """ 51 | embed a list of inputs and its associated output 52 | IO is of the form [[I1, I2, ..., Ik], O] 53 | where I1, I2, ..., Ik are inputs and O is an output 54 | 55 | outputs a tensor of dimension self.output_dimension 56 | """ 57 | e = [self.starting_index] 58 | inputs, 
output = IO 59 | for x in inputs: 60 | self.__encode_element__(x, e) 61 | e.append(self.end_of_input_index) 62 | e.append(self.start_of_output_index) 63 | self.__encode_element__(output, e) 64 | e.append(self.ending_index) 65 | size = len(e) 66 | if size > self.output_dimension: 67 | assert False, "IOEncoder: IO too large: {} > {} for {}".format( 68 | size, self.output_dimension, IO 69 | ) 70 | else: 71 | for _ in range(self.output_dimension - size): 72 | e.append(self.ending_index) 73 | res = torch.LongTensor(e).to(device) 74 | return res 75 | 76 | def encode(self, task: Task[PBE], device: Optional[str] = None) -> Tensor: 77 | return torch.stack( 78 | [ 79 | self.encode_IO((ex.inputs, ex.output), device) 80 | for ex in task.specification.examples 81 | ] 82 | ) 83 | -------------------------------------------------------------------------------- /synth/pbe/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.pbe.solvers.pbe_solver import ( 2 | PBESolver, 3 | NaivePBESolver, 4 | CutoffPBESolver, 5 | MetaPBESolver, 6 | ) 7 | from synth.pbe.solvers.restart_pbe_solver import RestartPBESolver 8 | -------------------------------------------------------------------------------- /synth/pbe/solvers/restart_pbe_solver.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Generator, List, Tuple 2 | from synth.semantic.evaluator import DSLEvaluator 3 | 4 | 5 | from synth.specification import PBE 6 | from synth.syntax.grammars.enumeration.program_enumerator import ProgramEnumerator 7 | from synth.syntax.grammars.grammar import DerivableProgram 8 | from synth.syntax.program import Program 9 | from synth.syntax.type_system import Type 10 | from synth.task import Task 11 | from synth.utils import chrono 12 | from synth.pbe.solvers.pbe_solver import MetaPBESolver, NaivePBESolver, PBESolver 13 | 14 | 15 | class RestartPBESolver(MetaPBESolver): 16 | def __init__( 17 | self, 18 | evaluator: DSLEvaluator, 19 | solver_builder: Callable[..., PBESolver] = NaivePBESolver, 20 | restart_criterion: Callable[["RestartPBESolver"], bool] = lambda self: len( 21 | self._data 22 | ) 23 | - self._last_size 24 | > 10000, 25 | uniform_prior: float = 0.05, 26 | **kwargs: Any, 27 | ) -> None: 28 | super().__init__(evaluator, solver_builder, **kwargs) 29 | self.restart_criterion = restart_criterion 30 | self._last_size: int = 0 31 | self.uniform_prior = uniform_prior 32 | 33 | def _init_stats_(self) -> None: 34 | super()._init_stats_() 35 | self._stats["restarts"] = 0 36 | 37 | @classmethod 38 | def name(cls) -> str: 39 | return "restart" 40 | 41 | def _init_task_solving_( 42 | self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60 43 | ) -> None: 44 | super()._init_task_solving_(task, enumerator, timeout) 45 | self._restarts = 0 46 | self._data: List[Tuple[Program, float]] = [] 47 | self._last_size = 0 48 | 49 | def _close_task_solving_( 50 | self, 51 | task: Task[PBE], 52 | enumerator: ProgramEnumerator[None], 53 | time_used: float, 54 | solution: bool, 55 | last_program: Program, 56 | ) -> None: 57 | super()._close_task_solving_( 58 | task, enumerator, time_used, solution, last_program 59 | ) 60 | self._stats["restarts"] += self._restarts 61 | 62 | def solve( 63 | self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60 64 | ) -> Generator[Program, None, bool]: 65 | with chrono.clock(f"solve.{self.name()}.{self.subsolver.name()}") as c: # type: 
ignore 66 | self._enumerator = enumerator 67 | self._init_task_solving_(task, self._enumerator, timeout) 68 | gen = self._enumerator.generator() 69 | program = next(gen) 70 | while program is not None: 71 | time = c.elapsed_time() 72 | if time >= timeout: 73 | self._close_task_solving_( 74 | task, self._enumerator, time, False, program 75 | ) 76 | return False 77 | self._programs += 1 78 | if self._test_(task, program): 79 | should_stop = yield program 80 | if should_stop: 81 | self._close_task_solving_( 82 | task, self._enumerator, time, True, program 83 | ) 84 | return True 85 | self._score = self.subsolver._score 86 | # Saves data 87 | if self._score > 0: 88 | self._data.append((program, self._score)) 89 | # If should restart 90 | if self._should_restart_(): 91 | self._restarts += 1 92 | self._enumerator = self._restart_(self._enumerator) 93 | gen = self._enumerator.generator() 94 | program = next(gen) 95 | return False 96 | 97 | def _should_restart_(self) -> bool: 98 | return self.restart_criterion(self) 99 | 100 | def _restart_(self, enumerator: ProgramEnumerator[None]) -> ProgramEnumerator[None]: 101 | pcfg = enumerator.G * 0 # type: ignore 102 | self._last_size = len(self._data) 103 | 104 | def reduce( 105 | score: float, S: Tuple[Type, Any], P: DerivableProgram, prob: float 106 | ) -> float: 107 | pcfg.probabilities[S][P] += score 108 | return score 109 | 110 | for program, score in self._data: 111 | pcfg.reduce_derivations(reduce, score, program) 112 | if self.uniform_prior > 0: 113 | pcfg = pcfg + (pcfg.uniform(pcfg.grammar) * self.uniform_prior) 114 | pcfg.normalise() 115 | new_enumerator = enumerator.clone(pcfg) 116 | return new_enumerator 117 | -------------------------------------------------------------------------------- /synth/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SynthesisLab/DeepSynth2/3efba7bfc03d3e576d67fb433c5fd8d50e3ebdb5/synth/py.typed -------------------------------------------------------------------------------- /synth/semantic/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to the semantic 3 | """ 4 | 5 | from synth.semantic.evaluator import Evaluator, DSLEvaluator 6 | -------------------------------------------------------------------------------- /synth/semantic/evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Set, Callable, Tuple 3 | 4 | from synth.syntax.program import Constant, Function, Primitive, Program, Variable, Type 5 | 6 | 7 | class Evaluator(ABC): 8 | @abstractmethod 9 | def eval(self, program: Program, input: Any) -> Any: 10 | pass 11 | 12 | @abstractmethod 13 | def clear_cache(self) -> None: 14 | """ 15 | Clear any cache this evaluator might use. 
16 | """ 17 | pass 18 | 19 | 20 | def __tuplify__(element: Any) -> Any: 21 | if isinstance(element, List): 22 | return tuple(__tuplify__(x) for x in element) 23 | else: 24 | return element 25 | 26 | 27 | class DSLEvaluator(Evaluator): 28 | def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> None: 29 | super().__init__() 30 | self.semantics = semantics 31 | self.use_cache = use_cache 32 | self._cache: Dict[Any, Dict[Program, Any]] = {} 33 | self._cons_cache: Dict[Any, Dict[Program, Any]] = {} 34 | self.skip_exceptions: Set[Exception] = set() 35 | # Statistics 36 | self._total_requests = 0 37 | self._cache_hits = 0 38 | self._dsl_constants: Dict[Tuple[Type, Any], Primitive] = {} 39 | for p, val in semantics.items(): 40 | if len(p.type.arguments()) == 0: 41 | self._dsl_constants[(p.type, __tuplify__(val))] = p 42 | 43 | def compress(self, program: Program, allow_constants: bool = True) -> Program: 44 | """ 45 | Return a semantically equivalent version of the program by evaluating constant expressions. 46 | Note for data saving/loading purposes, partial applications are left untouched. 47 | """ 48 | if isinstance(program, Function): 49 | args = [ 50 | self.compress(p, allow_constants=allow_constants) 51 | for p in program.arguments 52 | ] 53 | if len(program.type.returns().arguments()) == 0 and all( 54 | not a.uses_variables() for a in args 55 | ): 56 | before = self.use_cache 57 | self.use_cache = False 58 | value = self.eval(program, []) 59 | self.use_cache = before 60 | # Cancel compression of callable 61 | if isinstance(value, Callable): # type: ignore 62 | return Function(program.function, args) 63 | tval = __tuplify__(value) 64 | rtype = program.type 65 | if (rtype, tval) in self._dsl_constants: 66 | return self._dsl_constants[(rtype, tval)] 67 | if allow_constants: 68 | return Constant(program.type.returns(), value, True) 69 | else: 70 | return Function(program.function, args) 71 | else: 72 | return Function(program.function, args) 73 | else: 74 | return program 75 | 76 | def eval(self, program: Program, input: List) -> Any: 77 | key = __tuplify__(input) 78 | if self.use_cache and key not in self._cache: 79 | self._cache[key] = {} 80 | evaluations: Dict[Program, Any] = self._cache[key] if self.use_cache else {} 81 | if program in evaluations: 82 | return evaluations[program] 83 | try: 84 | for sub_prog in program.depth_first_iter(): 85 | self._total_requests += 1 86 | if sub_prog in evaluations: 87 | self._cache_hits += 1 88 | continue 89 | if isinstance(sub_prog, Primitive): 90 | evaluations[sub_prog] = self.semantics[sub_prog] 91 | elif isinstance(sub_prog, Variable): 92 | evaluations[sub_prog] = input[sub_prog.variable] 93 | elif isinstance(sub_prog, Constant): 94 | evaluations[sub_prog] = sub_prog.value 95 | elif isinstance(sub_prog, Function): 96 | fun = evaluations[sub_prog.function] 97 | for arg in sub_prog.arguments: 98 | fun = fun(evaluations[arg]) 99 | evaluations[sub_prog] = fun 100 | except Exception as e: 101 | if type(e) in self.skip_exceptions: 102 | evaluations[program] = None 103 | return None 104 | else: 105 | raise e 106 | 107 | return evaluations[program] 108 | 109 | def clear_cache(self) -> None: 110 | self._cache = {} 111 | self._cons_cache = {} 112 | 113 | @property 114 | def cache_hit_rate(self) -> float: 115 | return self._cache_hits / self._total_requests 116 | -------------------------------------------------------------------------------- /synth/specification.py: 
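A minimal sketch of how the specification dataclasses defined in this file combine with a task; the numeric examples and the int -> int type request are illustrative only, not taken from a dataset:

from synth.specification import PBE, Example
from synth.syntax.type_system import INT
from synth.syntax.type_helper import FunctionType
from synth.task import Task

# Four input/output examples of the target behaviour x -> x + 2.
spec = PBE([Example([x], x + 2) for x in [3, 4, 9, 12]])
task = Task(FunctionType(INT, INT), spec)
# spec.guess_type() inspects the examples and should infer the same int -> int request.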
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Generic, List, Optional, TypeVar 3 | 4 | from synth.syntax.type_system import EmptyList, Type 5 | from synth.syntax.type_helper import FunctionType, guess_type 6 | 7 | 8 | class TaskSpecification: 9 | def get_specification(self, specification: type) -> "Optional[TaskSpecification]": 10 | """ 11 | Gets the specification of the given TaskSpecification type that may be embed in this possibly compound specification. 12 | If none is found returns None. 13 | """ 14 | if isinstance(self, specification): 15 | return self 16 | return None 17 | 18 | 19 | @dataclass 20 | class Example: 21 | """ 22 | Represents an example pair of (inputs, output) 23 | """ 24 | 25 | inputs: List[Any] 26 | output: Any 27 | 28 | def guess_type(self) -> Type: 29 | types = list(map(guess_type, self.inputs)) + [guess_type(self.output)] 30 | return FunctionType(*types) 31 | 32 | 33 | @dataclass 34 | class PBE(TaskSpecification): 35 | """ 36 | Programming By Example (PBE) specification. 37 | """ 38 | 39 | examples: List[Example] 40 | 41 | def guess_type(self) -> Type: 42 | i = 0 43 | t = self.examples[i].guess_type() 44 | while EmptyList in t and i + 1 < len(self.examples): 45 | i += 1 46 | t = self.examples[i].guess_type() 47 | return t 48 | 49 | 50 | @dataclass 51 | class PBEWithConstants(PBE): 52 | """ 53 | Programming By Example (PBE) with constants specification 54 | """ 55 | 56 | constants: Dict[Type, List[Any]] 57 | 58 | 59 | @dataclass 60 | class NLP(TaskSpecification): 61 | """ 62 | Natural Language (NLP) specification. 63 | """ 64 | 65 | intent: str 66 | 67 | 68 | @dataclass 69 | class SketchedSpecification(TaskSpecification): 70 | sketch: str 71 | 72 | 73 | U = TypeVar("U", bound=TaskSpecification, covariant=True) 74 | V = TypeVar("V", bound=TaskSpecification, covariant=True) 75 | 76 | 77 | @dataclass 78 | class CompoundSpecification(TaskSpecification, Generic[U, V]): 79 | specification1: U 80 | specification2: V 81 | 82 | def get_specification(self, specification: type) -> "Optional[TaskSpecification]": 83 | a = self.specification1.get_specification(specification) 84 | if a is not None: 85 | return a 86 | return self.specification2.get_specification(specification) 87 | 88 | 89 | NLPBE = CompoundSpecification[NLP, PBE] 90 | SketchedPBE = CompoundSpecification[SketchedSpecification, PBE] 91 | -------------------------------------------------------------------------------- /synth/syntax/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module that contains anything relevant to the syntax 3 | """ 4 | 5 | from synth.syntax.dsl import DSL 6 | from synth.syntax.program import ( 7 | Primitive, 8 | Variable, 9 | Function, 10 | Lambda, 11 | Program, 12 | Constant, 13 | ) 14 | from synth.syntax.type_helper import guess_type, FunctionType, auto_type 15 | from synth.syntax.type_system import ( 16 | Type, 17 | match, 18 | PrimitiveType, 19 | PolymorphicType, 20 | FixedPolymorphicType, 21 | Generic, 22 | TypeFunctor, 23 | GenericFunctor, 24 | List, 25 | Arrow, 26 | Sum, 27 | UnknownType, 28 | INT, 29 | BOOL, 30 | STRING, 31 | UNIT, 32 | ) 33 | from synth.syntax.automata import DFA, DFTA 34 | from synth.syntax.grammars import ( 35 | CFG, 36 | UCFG, 37 | TTCFG, 38 | Grammar, 39 | DetGrammar, 40 | UGrammar, 41 | ProbDetGrammar, 42 | ProbUGrammar, 43 | TaggedDetGrammar, 44 | TaggedUGrammar, 45 | ProgramEnumerator, 46 | 
bs_enumerate_prob_grammar, 47 | bps_enumerate_prob_grammar, 48 | hs_enumerate_prob_grammar, 49 | hs_enumerate_prob_u_grammar, 50 | hs_enumerate_bucket_prob_grammar, 51 | hs_enumerate_bucket_prob_u_grammar, 52 | cd_enumerate_prob_grammar, 53 | as_enumerate_prob_grammar, 54 | split, 55 | ) 56 | -------------------------------------------------------------------------------- /synth/syntax/automata/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.automata.dfa import DFA 2 | from synth.syntax.automata.tree_automaton import DFTA 3 | -------------------------------------------------------------------------------- /synth/syntax/automata/dfa.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Generic, Set, Tuple, TypeVar 2 | 3 | 4 | U = TypeVar("U") 5 | V = TypeVar("V") 6 | W = TypeVar("W") 7 | X = TypeVar("X") 8 | 9 | 10 | class DFA(Generic[U, V]): 11 | """ 12 | Deterministic safe finite automaton. 13 | states: U 14 | alphabet: V 15 | Reads V elements from states U. 16 | If there is no transition from U reading V it means it is non accepting. (there are no final states) 17 | """ 18 | 19 | def __init__(self, initial: U, rules: Dict[U, Dict[V, U]]) -> None: 20 | self.start = initial 21 | self.rules = rules 22 | # Clean unreachable states 23 | reachables = self.states 24 | for u in list(self.rules.keys()): 25 | if u not in reachables: 26 | del self.rules[u] 27 | else: 28 | for P in list(self.rules[u].keys()): 29 | if self.rules[u][P] not in reachables: 30 | del self.rules[u][P] 31 | 32 | def __mul__(self, other: "DFA[W, X]") -> "DFA[Tuple[U, W], Tuple[V, X]]": 33 | start = (self.start, other.start) 34 | rules: Dict[Tuple[U, W], Dict[Tuple[V, X], Tuple[U, W]]] = {} 35 | for S1 in self.rules: 36 | for S2 in other.rules: 37 | rules[(S1, S2)] = {} 38 | for w1 in self.rules[S1]: 39 | for w2 in other.rules[S2]: 40 | rules[(S1, S2)][(w1, w2)] = ( 41 | self.rules[S1][w1], 42 | other.rules[S2][w2], 43 | ) 44 | return DFA(start, rules) 45 | 46 | def __str__(self) -> str: 47 | s = f"Print a DFA\n" 48 | s += "start: {}\n".format(self.start) 49 | for S in reversed(self.rules): 50 | s += "#\n {}\n".format(S) 51 | for P in self.rules[S]: 52 | out = self.rules[S][P] 53 | s += "\t{} -> {}\n".format(P, out) 54 | return s 55 | 56 | def __repr__(self) -> str: 57 | return self.__str__() 58 | 59 | @property 60 | def states(self) -> Set[U]: 61 | """ 62 | The set of reachables states. 
63 | """ 64 | all = set() 65 | frontier = [self.start] 66 | while frontier: 67 | state = frontier.pop() 68 | for P in self.rules[state]: 69 | new_state = self.rules[state][P] 70 | if new_state not in all: 71 | all.add(new_state) 72 | frontier.append(new_state) 73 | return all 74 | 75 | def can_read(self, start: U, word: V) -> bool: 76 | return start in self.rules and word in self.rules[start] 77 | 78 | def read(self, start: U, word: V) -> U: 79 | return self.rules[start][word] 80 | 81 | def map_states(self, f: Callable[[U], W]) -> "DFA[W, V]": 82 | mapping = {s: f(s) for s in self.states} 83 | dst_rules = { 84 | mapping[S]: {P: mapping[self.rules[S][P]] for P in self.rules[S]} 85 | for S in self.rules 86 | } 87 | return DFA(mapping[self.start], dst_rules) 88 | 89 | def then(self, other: "DFA[U, V]") -> "DFA[U, V]": 90 | assert self.states.isdisjoint(other.states) 91 | new_rules = { 92 | S: {P: self.rules[S][P] for P in self.rules[S]} for S in self.rules 93 | } 94 | for S in other.rules: 95 | new_rules[S] = {P: other.rules[S][P] for P in other.rules[S]} 96 | return DFA(self.start, new_rules) 97 | 98 | def read_product(self, other: "DFA[W, V]") -> "DFA[Tuple[U, W], V]": 99 | start = (self.start, other.start) 100 | rules: Dict[Tuple[U, W], Dict[V, Tuple[U, W]]] = {} 101 | for S1 in self.rules: 102 | for S2 in other.rules: 103 | rules[(S1, S2)] = {} 104 | for v in self.rules[S1]: 105 | if v in other.rules[S2]: 106 | rules[(S1, S2)][v] = ( 107 | self.rules[S1][v], 108 | other.rules[S2][v], 109 | ) 110 | return DFA(start, rules) 111 | -------------------------------------------------------------------------------- /synth/syntax/grammars/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.cfg import CFG 2 | from synth.syntax.grammars.ttcfg import TTCFG 3 | from synth.syntax.grammars.grammar import Grammar 4 | from synth.syntax.grammars.det_grammar import DetGrammar 5 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar, TaggedDetGrammar 6 | from synth.syntax.grammars.enumeration import ( 7 | ProgramEnumerator, 8 | bs_enumerate_prob_grammar, 9 | bps_enumerate_prob_grammar, 10 | hs_enumerate_prob_grammar, 11 | hs_enumerate_prob_u_grammar, 12 | hs_enumerate_bucket_prob_grammar, 13 | hs_enumerate_bucket_prob_u_grammar, 14 | cd_enumerate_prob_grammar, 15 | as_enumerate_prob_grammar, 16 | split, 17 | ) 18 | from synth.syntax.grammars.u_grammar import UGrammar 19 | from synth.syntax.grammars.u_cfg import UCFG 20 | from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar, TaggedUGrammar 21 | -------------------------------------------------------------------------------- /synth/syntax/grammars/enumeration/__init__.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.heap_search import ( 2 | enumerate_prob_grammar as hs_enumerate_prob_grammar, 3 | enumerate_bucket_prob_grammar as hs_enumerate_bucket_prob_grammar, 4 | ) 5 | from synth.syntax.grammars.enumeration.u_heap_search import ( 6 | enumerate_prob_u_grammar as hs_enumerate_prob_u_grammar, 7 | enumerate_bucket_prob_u_grammar as hs_enumerate_bucket_prob_u_grammar, 8 | ) 9 | from synth.syntax.grammars.enumeration.grammar_splitter import split 10 | from synth.syntax.grammars.enumeration.program_enumerator import ProgramEnumerator 11 | 12 | from synth.syntax.grammars.enumeration.bee_search import ( 13 | enumerate_prob_grammar as bs_enumerate_prob_grammar, 14 | ) 15 | from 
synth.syntax.grammars.enumeration.beap_search import ( 16 | enumerate_prob_grammar as bps_enumerate_prob_grammar, 17 | ) 18 | from synth.syntax.grammars.enumeration.constant_delay import ( 19 | enumerate_prob_grammar as cd_enumerate_prob_grammar, 20 | ) 21 | from synth.syntax.grammars.enumeration.a_star import ( 22 | enumerate_prob_grammar as as_enumerate_prob_grammar, 23 | ) 24 | 25 | enumerate_prob_grammar = cd_enumerate_prob_grammar 26 | -------------------------------------------------------------------------------- /synth/syntax/grammars/enumeration/a_star.py: -------------------------------------------------------------------------------- 1 | from heapq import heappush, heappop 2 | from typing import ( 3 | Generator, 4 | Generic, 5 | List, 6 | Optional, 7 | Tuple, 8 | TypeVar, 9 | Union, 10 | ) 11 | from dataclasses import dataclass, field 12 | 13 | import numpy as np 14 | 15 | from synth.filter.filter import Filter 16 | from synth.syntax.grammars.enumeration.program_enumerator import ProgramEnumerator 17 | from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar 18 | from synth.syntax.program import Function, Program 19 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar, DerivableProgram 20 | from synth.syntax.type_system import Type 21 | 22 | U = TypeVar("U") 23 | V = TypeVar("V") 24 | W = TypeVar("W") 25 | 26 | 27 | def _build_( 28 | elems: List[Tuple[DerivableProgram, Tuple[Type, U]]], G: ProbDetGrammar[U, V, W] 29 | ) -> Program: 30 | P, S = elems.pop(0) 31 | nargs = G.arguments_length_for(S, P) 32 | if nargs == 0: 33 | return P 34 | else: 35 | args = [] 36 | while nargs > 0: 37 | args.append(_build_(elems, G)) 38 | nargs -= 1 39 | return Function(P, args) 40 | 41 | 42 | @dataclass(order=True, frozen=True) 43 | class HeapElement(Generic[U]): 44 | priority: float 45 | to_expand: List[Tuple[Type, U]] = field(compare=False) 46 | parts: List[Tuple[DerivableProgram, Tuple[Type, U]]] = field(compare=False) 47 | 48 | def __repr__(self) -> str: 49 | return f"({self.priority}, {self.parts})" 50 | 51 | def make_program(self, g: ProbDetGrammar[U, V, W]) -> Program: 52 | return _build_(self.parts, g) 53 | 54 | 55 | class AStar( 56 | ProgramEnumerator[None], 57 | Generic[U, V, W], 58 | ): 59 | def __init__( 60 | self, 61 | G: ProbDetGrammar[U, V, W], 62 | filter: Optional[Filter[Program]] = None, 63 | ) -> None: 64 | super().__init__(filter) 65 | self.current: Optional[Program] = None 66 | 67 | self.G = G 68 | self.start = G.start 69 | self.rules = G.rules 70 | 71 | self.frontier: List[HeapElement[U]] = [] 72 | 73 | def probability(self, program: Program) -> float: 74 | return self.G.probability(program) 75 | 76 | @classmethod 77 | def name(cls) -> str: 78 | return "a-star" 79 | 80 | def generator(self) -> Generator[Program, None, None]: 81 | """ 82 | A generator which outputs the next most probable program 83 | """ 84 | first = (self.G.start[0], self.G.start[1][0]) # type: ignore 85 | heappush(self.frontier, HeapElement(0, [first], [])) 86 | 87 | while self.frontier: 88 | elem = heappop(self.frontier) 89 | if len(elem.to_expand) == 0: 90 | p = elem.make_program(self.G) 91 | if self._should_keep_subprogram(p): 92 | yield p 93 | else: 94 | partS = elem.to_expand.pop() 95 | S = (partS[0], (partS[1], None)) 96 | for P in self.G.rules[S]: # type: ignore 97 | args = self.G.rules[S][P][0] # type: ignore 98 | p = self.G.probabilities[S][P] # type: ignore 99 | new_el = HeapElement( 100 | elem.priority + p, # type: ignore 101 | elem.to_expand + list(args), 102 | elem.parts + 
[(P, S)], 103 | ) 104 | heappush(self.frontier, new_el) 105 | 106 | def merge_program(self, representative: Program, other: Program) -> None: 107 | """ 108 | Merge other into representative. 109 | In other words, other will no longer be generated through heap search 110 | """ 111 | pass 112 | 113 | def programs_in_banks(self) -> int: 114 | return 0 115 | 116 | def programs_in_queues(self) -> int: 117 | return len(self.frontier) 118 | 119 | def clone(self, G: Union[ProbDetGrammar, ProbUGrammar]) -> "AStar[U, V, W]": 120 | assert isinstance(G, ProbDetGrammar) 121 | enum = self.__class__(G) 122 | return enum 123 | 124 | 125 | def enumerate_prob_grammar(G: ProbDetGrammar[U, V, W]) -> AStar[U, V, W]: 126 | Gp: ProbDetGrammar = ProbDetGrammar( 127 | G.grammar, 128 | { 129 | S: {P: -np.log(p) for P, p in val.items() if p > 0} 130 | for S, val in G.probabilities.items() 131 | }, 132 | ) 133 | return AStar(Gp) 134 | -------------------------------------------------------------------------------- /synth/syntax/grammars/enumeration/program_enumerator.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Generator, 3 | Generic, 4 | Optional, 5 | TypeVar, 6 | Union, 7 | ) 8 | from abc import ABC, abstractmethod 9 | 10 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 11 | from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar 12 | from synth.syntax.program import Program 13 | from synth.filter import Filter 14 | 15 | 16 | U = TypeVar("U") 17 | 18 | 19 | class ProgramEnumerator(ABC, Generic[U]): 20 | """ 21 | Object that enumerates over programs. 22 | When a program is generated a feedback of type U is expected. 23 | If U is None then no feedback is expected. 24 | """ 25 | 26 | def __init__(self, filter: Optional[Filter[Program]] = None) -> None: 27 | super().__init__() 28 | self.filter = filter 29 | 30 | @classmethod 31 | @abstractmethod 32 | def name(cls) -> str: 33 | pass 34 | 35 | @abstractmethod 36 | def generator(self) -> Generator[Program, U, None]: 37 | pass 38 | 39 | def __iter__(self) -> Generator[Program, U, None]: 40 | return self.generator() 41 | 42 | @abstractmethod 43 | def programs_in_banks(self) -> int: 44 | pass 45 | 46 | @abstractmethod 47 | def programs_in_queues(self) -> int: 48 | pass 49 | 50 | @abstractmethod 51 | def probability(self, program: Program) -> float: 52 | """ 53 | Return the probability of generating the given program according to the grammar associated with this enumerator. 54 | """ 55 | pass 56 | 57 | def merge_program(self, representative: Program, other: Program) -> None: 58 | """ 59 | Merge other into representative. 60 | Function used for observational equivalence, that means other and representative are semantically equivalent for the current task. 61 | This is for a posteriori merging, it is rather inefficient compared to evaluating subprograms for most enumerative algorithms. 62 | """ 63 | pass 64 | 65 | def _should_keep_subprogram(self, program: Program) -> bool: 66 | return self.filter is None or self.filter.accept(program) 67 | 68 | @abstractmethod 69 | def clone( 70 | self, grammar: Union[ProbDetGrammar, ProbUGrammar] 71 | ) -> "ProgramEnumerator[U]": 72 | """ 73 | Clone this enumerator with the specified new grammar but remember every single program enumerated so that it does not enumerate them again. 
74 | """ 75 | pass 76 | -------------------------------------------------------------------------------- /synth/syntax/grammars/grammar.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List, Tuple, Union 4 | 5 | from synth.syntax.program import Constant, Primitive, Program, Variable 6 | from synth.syntax.type_system import Type 7 | 8 | DerivableProgram = Union[Primitive, Variable, Constant] 9 | 10 | 11 | @dataclass(frozen=True) 12 | class NGram: 13 | n: int 14 | predecessors: List[Tuple[DerivableProgram, int]] = field(default_factory=lambda: []) 15 | 16 | def __hash__(self) -> int: 17 | return hash((self.n, tuple(self.predecessors))) 18 | 19 | def __str__(self) -> str: 20 | return str(self.predecessors) 21 | 22 | def __repr__(self) -> str: 23 | return self.__str__() 24 | 25 | def __len__(self) -> int: 26 | return len(self.predecessors) 27 | 28 | def successor(self, new_succ: Tuple[DerivableProgram, int]) -> "NGram": 29 | new_pred = [new_succ] + self.predecessors 30 | if len(new_pred) + 1 > self.n and self.n >= 0: 31 | new_pred.pop() 32 | return NGram(self.n, new_pred) 33 | 34 | def last(self) -> Tuple[DerivableProgram, int]: 35 | return self.predecessors[0] 36 | 37 | 38 | class Grammar(ABC): 39 | @abstractmethod 40 | def __contains__(self, program: Program) -> bool: 41 | pass 42 | 43 | @abstractmethod 44 | def name(self) -> str: 45 | """ 46 | Returns the name of this class of grammar. 47 | """ 48 | pass 49 | 50 | @abstractmethod 51 | def clean(self) -> None: 52 | """ 53 | Clean the grammar. 54 | """ 55 | pass 56 | 57 | @abstractmethod 58 | def programs(self) -> int: 59 | """ 60 | Return the number of programs contained within this grammar. 61 | """ 62 | pass 63 | 64 | @abstractmethod 65 | def instantiate_constants(self, constants: Dict[Type, List[Any]]) -> "Grammar": 66 | """ 67 | Replace all occurences of non instantiated constants with all possible values of instantiated ones. 68 | """ 69 | pass 70 | -------------------------------------------------------------------------------- /synth/task.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import ( 3 | Any, 4 | Callable, 5 | Dict, 6 | Generic, 7 | Iterator, 8 | List, 9 | Optional, 10 | SupportsIndex, 11 | TypeVar, 12 | overload, 13 | Set, 14 | ) 15 | import pickle 16 | import bz2 17 | 18 | from synth.specification import TaskSpecification 19 | from synth.syntax.program import Program 20 | from synth.syntax.type_system import Type 21 | from synth.utils.data_storage import load_object, save_object 22 | 23 | 24 | T = TypeVar("T", bound=TaskSpecification) 25 | 26 | 27 | @dataclass 28 | class Task(Generic[T]): 29 | type_request: Type 30 | specification: T 31 | solution: Optional[Program] = field(default=None) 32 | metadata: Dict[str, Any] = field(default_factory=lambda: {}) 33 | 34 | def __str__(self) -> str: 35 | return "{} ({}, spec={}, {})".format( 36 | self.metadata.get("name", "Task"), 37 | self.solution or "no solution", 38 | self.specification, 39 | self.metadata, 40 | ) 41 | 42 | 43 | @dataclass 44 | class Dataset(Generic[T]): 45 | """ 46 | Represents a list of tasks in a given specification. 
47 | """ 48 | 49 | tasks: List[Task[T]] 50 | metadata: Dict[str, Any] = field(default_factory=lambda: {}) 51 | 52 | def __len__(self) -> int: 53 | return len(self.tasks) 54 | 55 | def __iter__(self) -> Iterator[Task[T]]: 56 | return self.tasks.__iter__() 57 | 58 | @overload 59 | def __getitem__(self, key: SupportsIndex) -> Task[T]: 60 | pass 61 | 62 | @overload 63 | def __getitem__(self, key: slice) -> List[Task[T]]: 64 | pass 65 | 66 | def __getitem__(self, key: Any) -> Any: 67 | return self.tasks.__getitem__(key) 68 | 69 | def type_requests(self) -> Set[Type]: 70 | return set([task.type_request for task in self.tasks]) 71 | 72 | def save(self, path: str) -> None: 73 | """ 74 | Save this dataset in the specified file. 75 | The dataset file is compressed. 76 | """ 77 | save_object(path, self) 78 | 79 | @classmethod 80 | def load( 81 | cls, 82 | path: str, 83 | unpickler: Optional[Callable[[bz2.BZ2File], pickle.Unpickler]] = None, 84 | ) -> "Dataset[T]": 85 | """ 86 | Load the dataset object stored in this file. 87 | """ 88 | d: Dataset = load_object(path, unpickler) 89 | return d 90 | -------------------------------------------------------------------------------- /synth/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility objects and functions that do not fit elsewhere. 3 | """ 4 | 5 | import synth.utils.chrono as chrono 6 | from synth.utils.generator_utils import gen_take 7 | from synth.utils.data_storage import load_object, save_object 8 | -------------------------------------------------------------------------------- /synth/utils/data_storage.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import pickle 3 | from typing import Any, Callable, Optional 4 | import pickletools 5 | 6 | 7 | def load_object( 8 | path: str, unpickler: Optional[Callable[[bz2.BZ2File], pickle.Unpickler]] = None 9 | ) -> Any: 10 | """ 11 | Load an arbitrary object from the specified file. 12 | """ 13 | with bz2.BZ2File(path, "rb") as fd: 14 | if unpickler is None: 15 | return pickle.load(fd) 16 | else: 17 | return unpickler(fd).load() 18 | 19 | 20 | def save_object( 21 | path: str, obj: Any, optimize: bool = True, compress_level: int = 9 22 | ) -> None: 23 | """ 24 | Save an arbitrary object to the specified path. 25 | Compression level must be in 1-9 where 9 is the highest level. 26 | """ 27 | with bz2.BZ2File(path, "w", compresslevel=compress_level) as fd: 28 | content = pickle.dumps(obj) 29 | if optimize: 30 | content = pickletools.optimize(content) 31 | fd.write(content) 32 | 33 | 34 | def legacy_load_object(path: str, **kwargs: Any) -> Any: 35 | """ 36 | DEPRECATED 37 | Load an arbitrary object from the specified file. 38 | """ 39 | with open(path, "rb") as fd: 40 | return pickle.load(fd) 41 | 42 | 43 | def legacy_save_object(path: str, obj: Any, **kwargs: Any) -> None: 44 | """ 45 | DEPRECATED 46 | Save an arbitrary object to the specified path. 
47 | """ 48 | with open(path, "wb") as fd: 49 | pickle.dump(obj, fd) 50 | -------------------------------------------------------------------------------- /synth/utils/generator_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generator, List, TypeVar 2 | 3 | from tqdm import trange 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def gen_take(gen: Generator[T, Any, Any], n: int, progress: bool = False) -> List[T]: 9 | """ 10 | Take the first n elements of a generator and return them as a list. 11 | """ 12 | out: List[T] = [] 13 | try: 14 | ite = range(n) if not progress else trange(n) 15 | for _ in ite: 16 | out.append(next(gen)) 17 | except StopIteration: 18 | pass 19 | return out 20 | -------------------------------------------------------------------------------- /synth/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from types import SimpleNamespace 3 | from typing import List, Tuple, TypeVar, Union, Iterable, Callable, Optional 4 | 5 | U = TypeVar("U") 6 | 7 | 8 | def __try_names__(name: str, f: Callable[[str], U], prefixes: List[str]) -> U: 9 | try: 10 | return f(name) 11 | except ModuleNotFoundError: 12 | l = prefixes[0] 13 | return __try_names__(l + "." + name, f, prefixes[1:]) 14 | 15 | 16 | def import_file_function( 17 | import_name: str, 18 | keys: Iterable[Union[str, Tuple[str, str]]], 19 | prefixes: List[str] = [], 20 | ) -> Callable[[bool], Optional[SimpleNamespace]]: 21 | """ 22 | Utility function that creates a simple loader for you if you only need 23 | > from name.name import X, Y, ... 24 | where "X, Y, ..." are elements of keys. 25 | 26 | prefixes is a list of prefixes to the import name to try in case of failure. 27 | """ 28 | 29 | def loader(fully_load: bool = True) -> Optional[SimpleNamespace]: 30 | if not fully_load: 31 | return __try_names__(import_name, importlib.util.find_spec, prefixes) # type: ignore 32 | 33 | module = __try_names__(import_name, importlib.import_module, prefixes) 34 | out = {} 35 | for key in keys: 36 | get, set = key if isinstance(key, tuple) else (key, key) 37 | out[set] = module.__getattribute__(get) 38 | return SimpleNamespace(**out) 39 | 40 | return loader 41 | -------------------------------------------------------------------------------- /synth/utils/ordered.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol 2 | from abc import abstractmethod 3 | 4 | 5 | class Ordered(Protocol): 6 | @abstractmethod 7 | def __le__(self, other: Any) -> bool: 8 | pass 9 | 10 | @abstractmethod 11 | def __lt__(self, other: Any) -> bool: 12 | pass 13 | 14 | @abstractmethod 15 | def __gt__(self, other: Any) -> bool: 16 | pass 17 | -------------------------------------------------------------------------------- /synth/utils/vose_polyfill.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | import numpy as np 3 | 4 | 5 | class PythonSampler: 6 | def __init__(self, weights: np.ndarray, seed: Optional[int] = None) -> None: 7 | self.rng = np.random.default_rng(seed or 1) 8 | n = len(weights) 9 | alias = np.zeros(n, dtype=int) 10 | proba = np.zeros(n, dtype=float) 11 | # Compute the average probability and cache it for later use. 12 | avg = 1.0 / n 13 | # Create two stacks to act as worklists as we populate the tables. 
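# For example, with weights [0.5, 0.3, 0.2] (n = 3, average 1/3): indices 1 and 2 start in the
# small list and index 0 in the large list; pairing 1 with 0 sets proba[1] = 0.9 and alias[1] = 0
# and leaves weights[0] at roughly 0.467, pairing 2 with 0 sets proba[2] = 0.6 and alias[2] = 0,
# and the leftover index 0 finally ends up with proba[0] = 1.0.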
14 |         small = []
15 |         large = []
16 |         # Populate the stacks with the input probabilities.
17 |         for i in range(n):
18 |             # If the probability is below the average probability, then we add it to the small
19 |             # list; otherwise we add it to the large list.
20 |             if weights[i] >= avg:
21 |                 large.append(i)
22 |             else:
23 |                 small.append(i)
24 |         # As a note: in the mathematical specification of the algorithm, we will always exhaust the
25 |         # small list before the big list. However, due to floating point inaccuracies, this is not
26 |         # necessarily true. Consequently, this inner loop (which tries to pair small and large
27 |         # elements) will have to check that both lists aren't empty.
28 |         while len(small) > 0 and len(large) > 0:
29 |             # Get the index of the small and the large probabilities.
30 |             less = small.pop(0)
31 |             more = large.pop(0)
32 |             # These probabilities have not yet been scaled up to be such that 1 / n is given weight
33 |             # 1.0. We do this here instead.
34 |             proba[less] = weights[less] * n
35 |             alias[less] = more
36 |             # Decrease the probability of the larger one by the appropriate amount.
37 |             weights[more] = weights[more] + weights[less] - avg
38 |             # If the new probability is less than the average, add it into the small list;
39 |             # otherwise add it to the large list.
40 |             if weights[more] >= avg:
41 |                 large.append(more)
42 |             else:
43 |                 small.append(more)
44 |         # At this point, everything is in one list, which means that the remaining probabilities
45 |         # should all be 1 / n. Based on this, set them appropriately. Due to numerical issues, we
46 |         # can't be sure which stack will hold the entries, so we empty both.
47 |         while len(small) > 0:
48 |             less = small.pop(0)
49 |             proba[less] = 1.0
50 |         while len(large) > 0:
51 |             more = large.pop(0)
52 |             proba[more] = 1.0
53 |         self.n = n
54 |         self.alias = alias
55 |         self.proba = proba
56 | 
57 |     def sample_1(self) -> int:
58 |         # Generate a fair die roll to determine which column to inspect.
59 |         col = int(self.rng.uniform(0, self.n))
60 |         # Generate a coin toss biased by that column's probability to determine which option to pick.
61 |         heads = self.rng.uniform() < self.proba[col]
62 | 
63 |         # Based on the outcome, return either the column or its alias.
64 |         if heads:
65 |             return col
66 |         return self.alias[col]  # type: ignore
67 | 
68 |     def sample(
69 |         self, k: int = 1, values: Optional[np.ndarray] = None
70 |     ) -> Union[int, np.ndarray]:
71 |         """Sample a random integer or a value from a given array.
72 | 
73 |         Parameters:
74 |             k: The number of integers to sample. If `k = 1`, then a single int (or float if values is not None) is returned. In any
75 |                 other case, a numpy array is returned.
76 |             values: The numpy array of values to sample from.
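            Returns: a single sample (an int, or one entry of `values` when it is given) if k = 1, otherwise a numpy array of k samples.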
77 | 78 | """ 79 | if values is None: 80 | if k == 1: 81 | return self.sample_1() 82 | return np.asarray([self.sample_1() for _ in range(k)]) 83 | else: 84 | if k == 1: 85 | return values[self.sample_1()] # type: ignore 86 | return np.asarray([values[self.sample_1()] for _ in range(k)]) 87 | 88 | 89 | try: 90 | import vose 91 | 92 | Sampler = vose.Sampler 93 | except ImportError: 94 | Sampler = PythonSampler 95 | -------------------------------------------------------------------------------- /tests/filtering/constraints/test_parsing.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.ttcfg import TTCFG 2 | from synth.syntax.dsl import DSL 3 | from synth.syntax.program import Variable 4 | from synth.syntax.type_system import ( 5 | INT, 6 | STRING, 7 | List, 8 | PolymorphicType, 9 | PrimitiveType, 10 | ) 11 | from synth.syntax.type_helper import FunctionType 12 | from synth.filter.constraints.parsing import ( 13 | parse_specification, 14 | TokenAnything, 15 | TokenAllow, 16 | TokenAtLeast, 17 | TokenAtMost, 18 | TokenFunction, 19 | TokenForceSubtree, 20 | TokenForbidSubtree, 21 | ) 22 | 23 | 24 | syntax = { 25 | "+": FunctionType(INT, INT, INT), 26 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 27 | "non_reachable": PrimitiveType("non_reachable"), 28 | "1": INT, 29 | "non_productive": FunctionType(INT, STRING), 30 | } 31 | dsl = DSL(syntax) 32 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), 20) 33 | ONE = dsl.get_primitive("1") 34 | PLUS = dsl.get_primitive("+") 35 | 36 | 37 | def test_bases() -> None: 38 | assert parse_specification("_", cfg) == TokenAnything() 39 | assert parse_specification("#(1)<=3", cfg) == TokenAtMost([ONE], count=3) 40 | assert parse_specification("(+ #(1)<=1 _)", cfg) == TokenFunction( 41 | TokenAllow([PLUS]), args=[TokenAtMost([ONE], count=1), TokenAnything()] 42 | ) 43 | assert parse_specification("#(1,+)>=3", cfg) == TokenAtLeast([ONE, PLUS], count=3) 44 | assert parse_specification("#(1,+,+)>=4", cfg) == TokenAtLeast([ONE, PLUS], count=4) 45 | assert parse_specification("(+ 1 _)", cfg) == TokenFunction( 46 | TokenAllow([PLUS]), 47 | [TokenAllow([ONE]), TokenAnything()], 48 | ) 49 | assert parse_specification(">(var0)", cfg) == TokenForceSubtree([Variable(0, INT)]) 50 | assert parse_specification(">^(1,var0)", cfg) == TokenForbidSubtree( 51 | [ONE, Variable(0, INT)] 52 | ) 53 | -------------------------------------------------------------------------------- /tests/generation/test_sampler.py: -------------------------------------------------------------------------------- 1 | from synth.generation.sampler import LexiconSampler, ListSampler, UnionSampler 2 | from synth.syntax.type_system import ( 3 | BOOL, 4 | INT, 5 | STRING, 6 | List, 7 | Type, 8 | UnknownType, 9 | ) 10 | 11 | 12 | def test_lexicon_sampling() -> None: 13 | lexicon = list(range(100)) 14 | sampler = LexiconSampler(lexicon) 15 | for _ in range(1000): 16 | x = sampler.sample(type=INT) 17 | assert x in lexicon 18 | 19 | 20 | def test_list_sampling() -> None: 21 | lexicon = list(range(100)) 22 | sampler = LexiconSampler(lexicon) 23 | for max_depth in [2, 3, 4]: 24 | a = ListSampler(sampler, [0.2] * 5, max_depth=max_depth, seed=10) 25 | my_type: Type = INT 26 | for _ in range(max_depth - 1): 27 | my_type = List(my_type) 28 | for _ in range(100): 29 | l = a.sample(type=my_type) 30 | el = l 31 | for _ in range(max_depth - 1): 32 | assert isinstance( 33 | el, list 34 | ), f"Max depth:{max_depth} 
Type:{my_type} list:{l}" 35 | assert len(el) <= 5 and len(el) > 0 36 | el = l[0] 37 | b = a.sample(type=INT) 38 | assert isinstance(b, int) and b in lexicon 39 | 40 | 41 | def test_union_sampler() -> None: 42 | lexicon = list(range(100)) 43 | bool_lexicon = [True, False] 44 | str_lexicon = ["a", "b", "c", "d"] 45 | sampler = UnionSampler( 46 | { 47 | INT: LexiconSampler(lexicon), 48 | BOOL: LexiconSampler(bool_lexicon), 49 | STRING: LexiconSampler(str_lexicon), 50 | }, 51 | LexiconSampler([[], []]), 52 | ) 53 | 54 | for _ in range(100): 55 | x = sampler.sample(type=INT) 56 | assert isinstance(x, int) and x in lexicon 57 | y = sampler.sample(type=BOOL) 58 | assert isinstance(y, bool) and y in bool_lexicon 59 | z = sampler.sample(type=STRING) 60 | assert isinstance(z, str) and z in str_lexicon 61 | d = sampler.sample(type=UnknownType()) 62 | assert isinstance(d, list) and len(d) == 0 63 | 64 | 65 | def test_seeding() -> None: 66 | lexicon = list(range(100)) 67 | aint = LexiconSampler(lexicon, seed=10) 68 | bint = LexiconSampler(lexicon, seed=10) 69 | for _ in range(1000): 70 | assert aint.sample(type=INT) == bint.sample(type=INT), _ 71 | 72 | a = ListSampler(aint, [0.2] * 5, max_depth=3, seed=10) 73 | b = ListSampler(bint, [0.2] * 5, max_depth=3, seed=10) 74 | for _ in range(1000): 75 | assert a.sample(type=List(INT)) == b.sample(type=List(INT)) 76 | -------------------------------------------------------------------------------- /tests/pbe/solvers/test_pbe_solver.py: -------------------------------------------------------------------------------- 1 | from synth.semantic.evaluator import DSLEvaluator 2 | from synth.specification import PBE, Example 3 | from synth.syntax.grammars.enumeration.heap_search import enumerate_prob_grammar 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 6 | from synth.syntax.dsl import DSL 7 | from synth.syntax.type_system import ( 8 | INT, 9 | STRING, 10 | List, 11 | PolymorphicType, 12 | PrimitiveType, 13 | ) 14 | from synth.syntax.type_helper import FunctionType 15 | from synth.pbe.solvers import NaivePBESolver, CutoffPBESolver, PBESolver 16 | 17 | import pytest 18 | 19 | from synth.task import Task 20 | 21 | 22 | syntax = { 23 | "+": FunctionType(INT, INT, INT), 24 | "-": FunctionType(INT, INT, INT), 25 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 26 | "non_reachable": PrimitiveType("non_reachable"), 27 | "1": INT, 28 | "non_productive": FunctionType(INT, STRING), 29 | } 30 | 31 | semantics = {"+": lambda x: lambda y: x + y, "-": lambda x: lambda y: x - y, "1": 1} 32 | 33 | 34 | type_req = FunctionType(INT, INT) 35 | int_lexicon = list(range(-100, 100)) 36 | max_depth = 4 37 | dsl = DSL(syntax) 38 | evaluator = DSLEvaluator(dsl.instantiate_semantics(semantics)) 39 | testdata = [ 40 | NaivePBESolver(evaluator), 41 | CutoffPBESolver(evaluator), 42 | ] 43 | 44 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), 4) 45 | pcfg = ProbDetGrammar.uniform(cfg) 46 | 47 | 48 | tasks = [ 49 | Task(cfg.type_request, PBE([Example([x], x + 2) for x in [3, 4, 9, 12]])), 50 | Task(cfg.type_request, PBE([Example([x], x - 2) for x in [3, 4, 9, 12]])), 51 | ] 52 | 53 | 54 | @pytest.mark.parametrize("solver", testdata) 55 | def test_solving(solver: PBESolver) -> None: 56 | for task in tasks: 57 | failed = True 58 | for program in solver.solve(task, enumerate_prob_grammar(pcfg), 10): 59 | for example in task.specification.examples: 60 | assert evaluator.eval(program, 
example.inputs) == example.output 61 | failed = False 62 | assert solver._score > 0 63 | break 64 | assert not failed 65 | -------------------------------------------------------------------------------- /tests/pbe/solvers/test_restart_pbe_solver.py: -------------------------------------------------------------------------------- 1 | from synth.semantic.evaluator import DSLEvaluator 2 | from synth.specification import PBE, Example 3 | from synth.syntax.grammars.enumeration.heap_search import enumerate_prob_grammar 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 6 | from synth.syntax.dsl import DSL 7 | from synth.syntax.type_system import ( 8 | INT, 9 | STRING, 10 | List, 11 | PolymorphicType, 12 | PrimitiveType, 13 | ) 14 | from synth.syntax.type_helper import FunctionType 15 | from synth.pbe.solvers import NaivePBESolver, ObsEqPBESolver, CutoffPBESolver, PBESolver 16 | from synth.pbe.solvers.restart_pbe_solver import RestartPBESolver 17 | 18 | import pytest 19 | 20 | from synth.task import Task 21 | 22 | 23 | syntax = { 24 | "+": FunctionType(INT, INT, INT), 25 | "-": FunctionType(INT, INT, INT), 26 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 27 | "non_reachable": PrimitiveType("non_reachable"), 28 | "1": INT, 29 | "non_productive": FunctionType(INT, STRING), 30 | } 31 | 32 | semantics = {"+": lambda x: lambda y: x + y, "-": lambda x: lambda y: x - y, "1": 1} 33 | 34 | 35 | type_req = FunctionType(INT, INT) 36 | max_depth = 4 37 | dsl = DSL(syntax) 38 | evaluator = DSLEvaluator(dsl.instantiate_semantics(semantics)) 39 | testdata = [ 40 | NaivePBESolver(evaluator), 41 | ObsEqPBESolver(evaluator), 42 | CutoffPBESolver(evaluator), 43 | ] 44 | 45 | cfg = CFG.depth_constraint(dsl, type_req, max_depth) 46 | pcfg = ProbDetGrammar.uniform(cfg) 47 | 48 | 49 | tasks = [ 50 | Task(cfg.type_request, PBE([Example([x], x + 2) for x in range(50)])), 51 | Task(cfg.type_request, PBE([Example([x], x - 2) for x in range(50)])), 52 | ] 53 | 54 | 55 | @pytest.mark.parametrize("solver", testdata) 56 | def test_solving(solver: PBESolver) -> None: 57 | real_solver = RestartPBESolver( 58 | solver.evaluator, 59 | lambda *args, **kwargs: solver, 60 | restart_criterion=lambda self: len(self._data) - self._last_size > 3, 61 | ) 62 | for task in tasks: 63 | failed = True 64 | real_solver.reset_stats() 65 | for program in real_solver.solve(task, enumerate_prob_grammar(pcfg), 5): 66 | for example in task.specification.examples: 67 | assert evaluator.eval(program, example.inputs) == example.output 68 | failed = False 69 | if real_solver._restarts <= 0: 70 | continue 71 | break 72 | assert not failed 73 | -------------------------------------------------------------------------------- /tests/pbe/test_io_encoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from synth.pbe.io_encoder import IOEncoder 6 | 7 | from synth.task import Task, Dataset 8 | from synth.specification import PBE, Example 9 | from synth.syntax.type_system import ( 10 | INT, 11 | List, 12 | ) 13 | from synth.syntax.type_helper import FunctionType 14 | 15 | 16 | def test_encoding() -> None: 17 | random.seed(0) 18 | dataset = Dataset( 19 | [ 20 | Task( 21 | FunctionType(INT, List(INT), INT), 22 | PBE( 23 | [ 24 | Example( 25 | [random.randint(0, 100), [random.randint(0, 100)]], 26 | random.randint(0, 100), 27 | ) 28 | for _ in range(random.randint(3, 8)) 29 | ] 30 | ), 31 | 
metadata={"index": i}, 32 | ) 33 | for i in range(100) 34 | ], 35 | metadata={"something": False, "else": "is", "coming": 42}, 36 | ) 37 | for output_dim in [32, 64, 512]: 38 | encoder = IOEncoder(output_dim, list(range(100 + 1))) 39 | for task in dataset: 40 | encoded = encoder.encode(task) 41 | assert encoded.shape == torch.Size( 42 | [len(task.specification.examples), output_dim] 43 | ) 44 | assert torch.min(encoded).item() >= 0 45 | assert torch.max(encoded).item() < len(encoder.lexicon) 46 | -------------------------------------------------------------------------------- /tests/pbe/test_task_generator.py: -------------------------------------------------------------------------------- 1 | from synth.generation.sampler import LexiconSampler 2 | from synth.pbe.task_generator import TaskGenerator, basic_output_validator 3 | from synth.semantic.evaluator import DSLEvaluator 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 6 | from synth.syntax.dsl import DSL 7 | from synth.syntax.type_system import ( 8 | INT, 9 | STRING, 10 | List, 11 | PolymorphicType, 12 | PrimitiveType, 13 | ) 14 | from synth.syntax.type_helper import FunctionType 15 | 16 | 17 | syntax = { 18 | "+": FunctionType(INT, INT, INT), 19 | "-": FunctionType(INT, INT, INT), 20 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 21 | "non_reachable": PrimitiveType("non_reachable"), 22 | "1": INT, 23 | "non_productive": FunctionType(INT, STRING), 24 | } 25 | 26 | semantics = {"+": lambda x: lambda y: x + y, "-": lambda x: lambda y: x - y, "1": 1} 27 | 28 | type_req = FunctionType(INT, INT) 29 | int_lexicon = list(range(-100, 100)) 30 | max_depth = 4 31 | dsl = DSL(syntax) 32 | validator = basic_output_validator({int: int_lexicon}, -1) 33 | 34 | 35 | def test_gen() -> None: 36 | samples_lexicon = [2, 3, 4] 37 | pcfg = ProbDetGrammar.uniform(CFG.depth_constraint(dsl, type_req, max_depth)) 38 | pcfg.init_sampling(20) 39 | g = TaskGenerator( 40 | LexiconSampler(int_lexicon, seed=10), 41 | DSLEvaluator(dsl.instantiate_semantics(semantics)), 42 | LexiconSampler([type_req], seed=10), 43 | LexiconSampler(samples_lexicon, [0.25, 0.5, 0.25], seed=10), 44 | {pcfg}, 45 | validator, 46 | ) 47 | for _ in range(100): 48 | task = g.generate_task() 49 | assert task.type_request == type_req 50 | assert len(task.specification.examples) in samples_lexicon 51 | assert task.solution 52 | assert task.solution.depth() <= max_depth 53 | for ex in task.specification.examples: 54 | assert validator(ex.output) 55 | assert len(ex.inputs) == 1 56 | assert all(x in int_lexicon for x in ex.inputs) 57 | 58 | 59 | def test_seed() -> None: 60 | pcfg = ProbDetGrammar.uniform(CFG.depth_constraint(dsl, type_req, max_depth)) 61 | pcfg.init_sampling(10) 62 | g1 = TaskGenerator( 63 | LexiconSampler(int_lexicon, seed=10), 64 | DSLEvaluator(dsl.instantiate_semantics(semantics)), 65 | LexiconSampler([type_req], seed=10), 66 | LexiconSampler([2, 3, 4], [0.25, 0.5, 0.25], seed=10), 67 | {pcfg}, 68 | validator, 69 | ) 70 | pcfg = ProbDetGrammar.uniform(CFG.depth_constraint(dsl, type_req, max_depth)) 71 | pcfg.init_sampling(10) 72 | g2 = TaskGenerator( 73 | LexiconSampler(int_lexicon, seed=10), 74 | DSLEvaluator(dsl.instantiate_semantics(semantics)), 75 | LexiconSampler([type_req], seed=10), 76 | LexiconSampler([2, 3, 4], [0.25, 0.5, 0.25], seed=10), 77 | {pcfg}, 78 | validator, 79 | ) 80 | for _ in range(100): 81 | assert g1.generate_task() == g2.generate_task() 82 | 83 | 84 | 
test_gen() 85 | -------------------------------------------------------------------------------- /tests/semantic/test_evaluator.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.cfg import CFG 2 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 3 | from synth.semantic.evaluator import DSLEvaluator, __tuplify__ 4 | from synth.syntax.dsl import DSL 5 | from synth.syntax.type_system import ( 6 | INT, 7 | STRING, 8 | List, 9 | PolymorphicType, 10 | PrimitiveType, 11 | ) 12 | from synth.syntax.type_helper import FunctionType 13 | 14 | 15 | syntax = { 16 | "+1": FunctionType(INT, INT), 17 | "0": INT, 18 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 19 | "non_reachable": PrimitiveType("non_reachable"), 20 | "non_productive": FunctionType(INT, STRING), 21 | } 22 | 23 | semantics = { 24 | "+1": lambda x: x + 1, 25 | "0": 0, 26 | } 27 | max_depth = 4 28 | dsl = DSL(syntax) 29 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 30 | 31 | other_syntax = {"+1": FunctionType(INT, INT), "0": INT, "2": INT, "True": STRING} 32 | 33 | other_semantics = { 34 | "+1": lambda x: x + 1, 35 | "0": 0, 36 | "True": True, 37 | "2": 2, 38 | } 39 | other_dsl = DSL(other_syntax) 40 | other_cfg = CFG.depth_constraint(other_dsl, FunctionType(INT, INT), max_depth) 41 | 42 | 43 | def test_eval() -> None: 44 | eval = DSLEvaluator(dsl.instantiate_semantics(semantics)) 45 | pcfg = ProbDetGrammar.uniform(cfg) 46 | pcfg.init_sampling(0) 47 | for _ in range(100): 48 | program = pcfg.sample_program() 49 | try: 50 | for i in range(-25, 25): 51 | if len(program.used_variables()) == 0: 52 | assert eval.eval(program, [i]) == program.size() - 1 53 | else: 54 | assert eval.eval(program, [i]) == program.size() + i - 1 55 | except Exception as e: 56 | assert False, e 57 | 58 | 59 | def test_supports_list() -> None: 60 | eval = DSLEvaluator(dsl.instantiate_semantics(semantics)) 61 | pcfg = ProbDetGrammar.uniform(cfg) 62 | pcfg.init_sampling(0) 63 | for _ in range(100): 64 | program = pcfg.sample_program() 65 | try: 66 | for i in range(-25, 25): 67 | if len(program.used_variables()) == 0: 68 | assert eval.eval(program, [i]) == program.size() - 1 69 | else: 70 | assert eval.eval(program, [i]) == program.size() + i - 1 71 | except Exception as e: 72 | assert False, e 73 | 74 | 75 | def test_use_cache() -> None: 76 | eval = DSLEvaluator(dsl.instantiate_semantics(semantics)) 77 | pcfg = ProbDetGrammar.uniform(cfg) 78 | pcfg.init_sampling(0) 79 | for _ in range(100): 80 | program = pcfg.sample_program() 81 | try: 82 | for i in range(-25, 25): 83 | if len(program.used_variables()) == 0: 84 | assert eval.eval(program, [i]) == program.size() - 1 85 | assert eval._cache[__tuplify__([i])][program] == program.size() - 1 86 | else: 87 | assert eval.eval(program, [i]) == program.size() + i - 1 88 | assert ( 89 | eval._cache[__tuplify__([i])][program] == program.size() + i - 1 90 | ) 91 | except Exception as e: 92 | assert False, e 93 | 94 | 95 | def test_compress() -> None: 96 | eval = DSLEvaluator(other_dsl.instantiate_semantics(other_semantics)) 97 | p = other_dsl.auto_parse_program("(+1 0)") 98 | pp = other_dsl.auto_parse_program("1", constants={"1": (INT, 1)}) 99 | c = eval.compress(p) 100 | assert c != p 101 | assert c == pp 102 | p = other_dsl.auto_parse_program("(+1 (+1 0))") 103 | pp = other_dsl.auto_parse_program("2") 104 | c = eval.compress(p) 105 | assert c == pp 106 | 
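Taken together, the components exercised by the tests above form the basic synthesis loop: declare a DSL, build a grammar for a type request, enumerate programs and evaluate them. The sketch below uses a deliberately tiny illustrative DSL (a successor primitive and a constant); it is not taken from the test suite:

from synth.semantic.evaluator import DSLEvaluator
from synth.syntax.dsl import DSL
from synth.syntax.grammars.cfg import CFG
from synth.syntax.grammars.enumeration.heap_search import enumerate_prob_grammar
from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar
from synth.syntax.type_system import INT
from synth.syntax.type_helper import FunctionType

dsl = DSL({"+1": FunctionType(INT, INT), "0": INT})
evaluator = DSLEvaluator(dsl.instantiate_semantics({"+1": lambda x: x + 1, "0": 0}))
pcfg = ProbDetGrammar.uniform(CFG.depth_constraint(dsl, FunctionType(INT, INT), 4))
# Enumerate programs most probable first until one maps 3 to 5 (e.g. "+1" applied twice to var0).
for program in enumerate_prob_grammar(pcfg):
    if evaluator.eval(program, [3]) == 5:
        break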
-------------------------------------------------------------------------------- /tests/syntax/automata/test_tree_automaton.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from typing import Dict, Set, Tuple 3 | 4 | from synth.syntax.dsl import DSL 5 | from synth.syntax.type_system import ( 6 | INT, 7 | STRING, 8 | List, 9 | PolymorphicType, 10 | PrimitiveType, 11 | Type, 12 | ) 13 | from synth.syntax.type_helper import FunctionType 14 | from synth.syntax.grammars.grammar import DerivableProgram, NGram 15 | from synth.syntax.automata.tree_automaton import DFTA 16 | from synth.syntax.grammars.cfg import CFG 17 | 18 | 19 | def cfg2dfta( 20 | grammar: CFG, 21 | ) -> DFTA[Tuple[Type, int], DerivableProgram]: 22 | StateT = Tuple[Type, int] 23 | dfta_rules: Dict[Tuple[DerivableProgram, Tuple[StateT, ...]], StateT] = {} 24 | max_depth = grammar.max_program_depth() 25 | all_cases: Dict[ 26 | Tuple[int, Tuple[Type, ...]], Set[Tuple[Tuple[Type, int], ...]] 27 | ] = {} 28 | for S in grammar.rules: 29 | for P in grammar.rules[S]: 30 | args = grammar.rules[S][P][0] 31 | if len(args) == 0: 32 | dfta_rules[(P, ())] = (P.type, 0) 33 | else: 34 | key = (len(args), tuple([arg[0] for arg in args])) 35 | if key not in all_cases: 36 | all_cases[key] = set( 37 | [ 38 | tuple(x) 39 | for x in product( 40 | *[ 41 | [(arg[0], j) for j in range(max_depth)] 42 | for arg in args 43 | ] 44 | ) 45 | ] 46 | ) 47 | for nargs in all_cases[key]: 48 | dfta_rules[(P, nargs)] = ( 49 | S[0], 50 | max(i for _, i in nargs) + 1, 51 | ) 52 | r = grammar.type_request.returns() 53 | return DFTA(dfta_rules, {(r, x) for x in range(max_depth)}) 54 | 55 | 56 | import pytest 57 | 58 | syntax = { 59 | "+": FunctionType(INT, INT, INT), 60 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 61 | "non_reachable": PrimitiveType("non_reachable"), 62 | "1": INT, 63 | "non_productive": FunctionType(INT, STRING), 64 | } 65 | dsl = DSL(syntax) 66 | max_depths = [3, 7, 11] 67 | 68 | 69 | @pytest.mark.parametrize("max_depth", max_depths) 70 | def test_reduce(max_depth: int) -> None: 71 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth, n_gram=1) 72 | dfta = cfg2dfta(cfg) 73 | dfta.reduce() 74 | for (P, args), dst in dfta.rules.items(): 75 | assert not ( 76 | all(x == 0 for x in args) and len(args) > 0 77 | ), f"Unreachable rule: {P} {args}" 78 | assert dst != max_depth, f"Unproductive rule: {P} {args} -> {dst}" 79 | 80 | 81 | @pytest.mark.parametrize("max_depth", max_depths) 82 | def test_states(max_depth: int) -> None: 83 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth, n_gram=1) 84 | dfta = cfg2dfta(cfg) 85 | dfta.reduce() 86 | for (P, args), dst in dfta.rules.items(): 87 | if dst[1] < 1: 88 | continue 89 | state = (dst[0], ((NGram(1), dst[1] - 1), None)) 90 | assert state in cfg.rules 91 | assert P in cfg.rules[state] 92 | assert all(a[0] == b[0] for a, b in zip(args, cfg.rules[state][P][0])) 93 | 94 | 95 | @pytest.mark.parametrize("max_depth", max_depths) 96 | def test_minimise(max_depth: int) -> None: 97 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 98 | dfta = cfg2dfta(cfg) 99 | dfta.reduce() 100 | ndfta = dfta.minimise() 101 | for P, args in ndfta.rules: 102 | assert not (all(x == (0,) for x in args) and len(args) > 0) 103 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_a_star.py: 
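The tests below check that A* enumerates every program of the grammar exactly once and in non-increasing probability order. The same entry point is also re-exported from synth.syntax, alongside the heap-search, bee-search, beap-search and constant-delay variants, so the search strategy can be swapped without touching the rest of the pipeline. A small sketch with an illustrative two-primitive DSL:

from itertools import islice

from synth.syntax import CFG, DSL, INT, FunctionType, ProbDetGrammar, as_enumerate_prob_grammar

dsl = DSL({"+1": FunctionType(INT, INT), "0": INT})
pcfg = ProbDetGrammar.uniform(CFG.depth_constraint(dsl, FunctionType(INT, INT), 3))
# The first programs come out most probable first.
best = list(islice(as_enumerate_prob_grammar(pcfg), 10))
assert all(pcfg.probability(p) >= pcfg.probability(q) for p, q in zip(best, best[1:]))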
-------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.a_star import ( 2 | enumerate_prob_grammar, 3 | ) 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.ttcfg import TTCFG 6 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 7 | from synth.syntax.dsl import DSL 8 | from synth.syntax.type_system import ( 9 | INT, 10 | STRING, 11 | List, 12 | PolymorphicType, 13 | PrimitiveType, 14 | ) 15 | from synth.syntax.type_helper import FunctionType, auto_type 16 | 17 | import pytest 18 | 19 | 20 | syntax = { 21 | "+": FunctionType(INT, INT, INT), 22 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 23 | "non_reachable": PrimitiveType("non_reachable"), 24 | "1": INT, 25 | "2": INT, 26 | "non_productive": FunctionType(INT, STRING), 27 | } 28 | dsl = DSL(syntax) 29 | dsl.instantiate_polymorphic_types() 30 | testdata = [ 31 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 3), 32 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 4), 33 | ] 34 | 35 | 36 | @pytest.mark.parametrize("cfg", testdata) 37 | def test_unicity_a_star(cfg: TTCFG) -> None: 38 | pcfg = ProbDetGrammar.uniform(cfg) 39 | seen = set() 40 | print(cfg) 41 | for program in enumerate_prob_grammar(pcfg): 42 | assert program not in seen 43 | seen.add(program) 44 | # print(pcfg.grammar) 45 | assert len(seen) == cfg.programs() 46 | 47 | 48 | @pytest.mark.parametrize("cfg", testdata) 49 | def test_order_a_star(cfg: TTCFG) -> None: 50 | pcfg = ProbDetGrammar.uniform(cfg) 51 | last = 1.0 52 | for program in enumerate_prob_grammar(pcfg): 53 | p = pcfg.probability(program) 54 | assert p <= last 55 | last = p 56 | 57 | 58 | def test_infinite() -> None: 59 | pcfg = ProbDetGrammar.random( 60 | CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1 61 | ) 62 | count = 10000 63 | last = 1.0 64 | for program in enumerate_prob_grammar(pcfg): 65 | count -= 1 66 | p = pcfg.probability(program) 67 | assert -1e-12 <= last - p, f"failed at program n°{count}:{program}" 68 | last = p 69 | if count < 0: 70 | break 71 | assert count == -1 72 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_beap_search.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.beap_search import ( 2 | enumerate_prob_grammar, 3 | ) 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.ttcfg import TTCFG 6 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 7 | from synth.syntax.dsl import DSL 8 | from synth.syntax.type_system import ( 9 | INT, 10 | STRING, 11 | List, 12 | PolymorphicType, 13 | PrimitiveType, 14 | ) 15 | from synth.syntax.type_helper import FunctionType, auto_type 16 | 17 | import pytest 18 | 19 | 20 | syntax = { 21 | "+": FunctionType(INT, INT, INT), 22 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 23 | "non_reachable": PrimitiveType("non_reachable"), 24 | "1": INT, 25 | "2": INT, 26 | "non_productive": FunctionType(INT, STRING), 27 | } 28 | dsl = DSL(syntax) 29 | dsl.instantiate_polymorphic_types() 30 | testdata = [ 31 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 3), 32 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 4), 33 | ] 34 | 35 | 36 | @pytest.mark.parametrize("cfg", testdata) 37 | def test_unicity_beep_search(cfg: TTCFG) -> None: 38 | pcfg = ProbDetGrammar.uniform(cfg) 39 | seen = set() 40 | for 
program in enumerate_prob_grammar(pcfg): 41 | assert program not in seen 42 | seen.add(program) 43 | # print(pcfg.grammar) 44 | assert len(seen) == cfg.programs() 45 | 46 | 47 | @pytest.mark.parametrize("cfg", testdata) 48 | def test_order_beep_search(cfg: TTCFG) -> None: 49 | pcfg = ProbDetGrammar.uniform(cfg) 50 | last = 1.0 51 | for program in enumerate_prob_grammar(pcfg): 52 | p = pcfg.probability(program) 53 | assert p <= last 54 | last = p 55 | 56 | 57 | def test_infinite() -> None: 58 | pcfg = ProbDetGrammar.random( 59 | CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1 60 | ) 61 | count = 10000 62 | last = 1.0 63 | for program in enumerate_prob_grammar(pcfg): 64 | count -= 1 65 | p = pcfg.probability(program) 66 | assert -1e-12 <= last - p, f"failed at program n°{count}:{program}" 67 | last = p 68 | if count < 0: 69 | break 70 | assert count == -1 71 | 72 | 73 | @pytest.mark.parametrize("cfg", testdata) 74 | def test_merge(cfg: TTCFG) -> None: 75 | pcfg = ProbDetGrammar.uniform(cfg) 76 | seen = set() 77 | for program in enumerate_prob_grammar(pcfg): 78 | assert program not in seen 79 | seen.add(program) 80 | en = enumerate_prob_grammar(pcfg) 81 | removed = dsl.parse_program("(+ 1 1)", auto_type("int")) 82 | en.merge_program(dsl.parse_program("2", auto_type("int")), removed) 83 | new_seen = set() 84 | for program in en: 85 | assert removed not in program 86 | new_seen.add(program) 87 | diff = seen.difference(new_seen) 88 | for x in diff: 89 | assert removed in x 90 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_bee_search.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.bee_search import ( 2 | enumerate_prob_grammar, 3 | ) 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.grammars.ttcfg import TTCFG 6 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 7 | from synth.syntax.dsl import DSL 8 | from synth.syntax.type_system import ( 9 | INT, 10 | STRING, 11 | List, 12 | PolymorphicType, 13 | PrimitiveType, 14 | ) 15 | from synth.syntax.type_helper import FunctionType, auto_type 16 | 17 | import pytest 18 | 19 | 20 | syntax = { 21 | "+": FunctionType(INT, INT, INT), 22 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 23 | "non_reachable": PrimitiveType("non_reachable"), 24 | "1": INT, 25 | "2": INT, 26 | "non_productive": FunctionType(INT, STRING), 27 | } 28 | dsl = DSL(syntax) 29 | dsl.instantiate_polymorphic_types() 30 | testdata = [ 31 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 3), 32 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 4), 33 | ] 34 | 35 | 36 | @pytest.mark.parametrize("cfg", testdata) 37 | def test_unicity_beeSearch(cfg: TTCFG) -> None: 38 | pcfg = ProbDetGrammar.uniform(cfg) 39 | seen = set() 40 | for program in enumerate_prob_grammar(pcfg): 41 | assert program not in seen 42 | seen.add(program) 43 | # print(pcfg.grammar) 44 | assert len(seen) == cfg.programs() 45 | 46 | 47 | @pytest.mark.parametrize("cfg", testdata) 48 | def test_order_beeSearch(cfg: TTCFG) -> None: 49 | pcfg = ProbDetGrammar.uniform(cfg) 50 | last = 1.0 51 | for program in enumerate_prob_grammar(pcfg): 52 | p = pcfg.probability(program) 53 | assert p <= last 54 | last = p 55 | 56 | 57 | @pytest.mark.parametrize("cfg", testdata) 58 | def test_merge(cfg: TTCFG) -> None: 59 | pcfg = ProbDetGrammar.uniform(cfg) 60 | seen = set() 61 | for program in 
enumerate_prob_grammar(pcfg): 62 | assert program not in seen 63 | seen.add(program) 64 | en = enumerate_prob_grammar(pcfg) 65 | removed = dsl.parse_program("(+ 1 1)", auto_type("int")) 66 | en.merge_program(dsl.parse_program("2", auto_type("int")), removed) 67 | new_seen = set() 68 | for program in en: 69 | assert removed not in program 70 | new_seen.add(program) 71 | diff = seen.difference(new_seen) 72 | for x in diff: 73 | assert removed in x 74 | 75 | 76 | def test_infinite() -> None: 77 | pcfg = ProbDetGrammar.random( 78 | CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1 79 | ) 80 | count = 10000 81 | last = 1.0 82 | for program in enumerate_prob_grammar(pcfg): 83 | count -= 1 84 | p = pcfg.probability(program) 85 | assert -1e-12 <= last - p, f"failed at program n°{count}:{program}" 86 | last = p 87 | if count < 0: 88 | break 89 | assert count == -1 90 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_constant_delay.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.constant_delay import ( 2 | enumerate_prob_grammar, 3 | ) 4 | from synth.syntax.grammars.enumeration.beap_search import ( 5 | enumerate_prob_grammar as enumerate, 6 | ) 7 | 8 | from synth.syntax.grammars.cfg import CFG 9 | from synth.syntax.grammars.ttcfg import TTCFG 10 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 11 | from synth.syntax.dsl import DSL 12 | from synth.syntax.type_system import ( 13 | INT, 14 | STRING, 15 | List, 16 | PolymorphicType, 17 | PrimitiveType, 18 | ) 19 | from synth.syntax.type_helper import FunctionType, auto_type 20 | 21 | import numpy as np 22 | 23 | import pytest 24 | 25 | 26 | syntax = { 27 | "+": FunctionType(INT, INT, INT), 28 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 29 | "non_reachable": PrimitiveType("non_reachable"), 30 | "1": INT, 31 | "2": INT, 32 | "non_productive": FunctionType(INT, STRING), 33 | } 34 | dsl = DSL(syntax) 35 | dsl.instantiate_polymorphic_types() 36 | testdata = [ 37 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 3), 38 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 4), 39 | ] 40 | kvals = [4, 16, 64] 41 | precision = [1e-2, 1e-4, 1e-8] 42 | 43 | 44 | @pytest.mark.parametrize("cfg", testdata) 45 | @pytest.mark.parametrize("k", kvals) 46 | @pytest.mark.parametrize("precis", precision) 47 | def test_unicity_constant_delay(cfg: TTCFG, k: int, precis: float) -> None: 48 | pcfg = ProbDetGrammar.uniform(cfg) 49 | seen = set() 50 | for program in enumerate_prob_grammar(pcfg, k, precis): 51 | assert program not in seen 52 | seen.add(program) 53 | # print(pcfg.grammar) 54 | assert len(seen) == cfg.programs() 55 | 56 | 57 | @pytest.mark.parametrize("cfg", testdata) 58 | @pytest.mark.parametrize("k", kvals) 59 | @pytest.mark.parametrize("precis", precision) 60 | def test_order_constant_delay(cfg: TTCFG, k: int, precis: float) -> None: 61 | # pcfg = ProbDetGrammar.uniform(cfg) 62 | pcfg = ProbDetGrammar.random(cfg, seed=4) 63 | last = 1.0 64 | for program in enumerate_prob_grammar(pcfg, k, precis): 65 | p = pcfg.probability(program) 66 | assert p <= last or abs(p / last) <= 1 + precis * (2**4) 67 | last = p 68 | 69 | 70 | @pytest.mark.parametrize("k", kvals) 71 | @pytest.mark.parametrize("precis", precision) 72 | def test_infinite(k: int, precis: float) -> None: 73 | pcfg = ProbDetGrammar.random( 74 | CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1 75 | )
76 | count = 10000 77 | last = 1.0 78 | for program in enumerate_prob_grammar(pcfg, k, precis): 79 | count -= 1 80 | p = pcfg.probability(program) 81 | assert ( 82 | -1e-12 <= last - p or abs(p / last) <= 1 + precis 83 | ), f"failed at program n°{count}:{program}, p={p} last={last}, p={np.log(p)} last={np.log(last)}" 84 | last = p 85 | if count < 0: 86 | break 87 | assert count == -1 88 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_grammar_splitter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from synth.syntax.grammars.enumeration.u_heap_search import enumerate_prob_u_grammar 3 | from synth.syntax.grammars.enumeration.grammar_splitter import split 4 | from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar 5 | from synth.syntax.grammars.u_cfg import UCFG 6 | from synth.syntax.dsl import DSL 7 | from synth.syntax.type_system import ( 8 | INT, 9 | STRING, 10 | List, 11 | PolymorphicType, 12 | PrimitiveType, 13 | ) 14 | from synth.syntax.type_helper import FunctionType 15 | 16 | 17 | syntax = { 18 | "+": FunctionType(INT, INT, INT), 19 | "-": FunctionType(INT, INT, INT), 20 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 21 | "non_reachable": PrimitiveType("non_reachable"), 22 | "1": INT, 23 | "non_productive": FunctionType(INT, STRING), 24 | } 25 | dsl = DSL(syntax) 26 | pucfg = ProbUGrammar.uniform(UCFG.depth_constraint(dsl, FunctionType(INT, INT), 3)) 27 | testdata = list(range(2, 5)) 28 | seen = set() 29 | for program in enumerate_prob_u_grammar(pucfg): 30 | seen.add(program) 31 | 32 | 33 | @pytest.mark.parametrize("splits", testdata) 34 | def test_unicity(splits: int) -> None: 35 | fragments, _ = split(pucfg, splits, desired_ratio=1.05) 36 | seen = set() 37 | for sub_pcfg in fragments: 38 | print(sub_pcfg) 39 | for program in enumerate_prob_u_grammar(sub_pcfg): 40 | assert program not in seen 41 | seen.add(program) 42 | 43 | 44 | @pytest.mark.parametrize("splits", testdata) 45 | def test_none_missing(splits: int) -> None: 46 | fragments, _ = split(pucfg, splits, desired_ratio=1.05) 47 | new_seen = set() 48 | for sub_pcfg in fragments: 49 | a = set() 50 | for program in enumerate_prob_u_grammar(sub_pcfg): 51 | a.add(program) 52 | new_seen |= a 53 | assert len(new_seen.difference(seen)) == 0, new_seen.difference(seen) 54 | assert len(seen.difference(new_seen)) == 0, seen.difference(new_seen) 55 | -------------------------------------------------------------------------------- /tests/syntax/grammars/enumeration/test_heap_search.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.enumeration.heap_search import ( 2 | Bucket, 3 | enumerate_prob_grammar, 4 | enumerate_bucket_prob_grammar, 5 | ) 6 | from synth.syntax.grammars.cfg import CFG 7 | from synth.syntax.grammars.ttcfg import TTCFG 8 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 9 | from synth.syntax.dsl import DSL 10 | from synth.syntax.type_system import ( 11 | INT, 12 | STRING, 13 | List, 14 | PolymorphicType, 15 | PrimitiveType, 16 | ) 17 | from synth.syntax.type_helper import FunctionType, auto_type 18 | 19 | import pytest 20 | 21 | 22 | syntax = { 23 | "+": FunctionType(INT, INT, INT), 24 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 25 | "non_reachable": PrimitiveType("non_reachable"), 26 | "1": INT, 27 | "2": INT, 28 | "non_productive": 
FunctionType(INT, STRING), 29 | } 30 | dsl = DSL(syntax) 31 | dsl.instantiate_polymorphic_types() 32 | testdata = [ 33 | CFG.depth_constraint(dsl, FunctionType(INT, INT), 3), 34 | TTCFG.size_constraint(dsl, FunctionType(INT, INT), 5), 35 | ] 36 | 37 | 38 | @pytest.mark.parametrize("cfg", testdata) 39 | def test_unicity_heapSearch(cfg: TTCFG) -> None: 40 | pcfg = ProbDetGrammar.uniform(cfg) 41 | seen = set() 42 | for program in enumerate_prob_grammar(pcfg): 43 | assert program not in seen 44 | seen.add(program) 45 | assert len(seen) == cfg.programs() 46 | 47 | 48 | @pytest.mark.parametrize("cfg", testdata) 49 | def test_order_heapSearch(cfg: TTCFG) -> None: 50 | pcfg = ProbDetGrammar.uniform(cfg) 51 | last = 1.0 52 | for program in enumerate_prob_grammar(pcfg): 53 | p = pcfg.probability(program) 54 | assert p <= last 55 | last = p 56 | 57 | 58 | @pytest.mark.parametrize("cfg", testdata) 59 | def test_threshold(cfg: TTCFG) -> None: 60 | pcfg = ProbDetGrammar.uniform(cfg) 61 | threshold = 0.15 62 | seen = set() 63 | for program in enumerate_prob_grammar(pcfg): 64 | p = pcfg.probability(program) 65 | if p <= threshold: 66 | break 67 | seen.add(p) 68 | seent = set() 69 | for program in enumerate_prob_grammar(pcfg, threshold): 70 | p = pcfg.probability(program) 71 | assert p > threshold 72 | seent.add(p) 73 | 74 | assert len(seent.symmetric_difference(seen)) == 0 75 | 76 | 77 | @pytest.mark.parametrize("cfg", testdata) 78 | def test_unicity_bucketSearch(cfg: TTCFG) -> None: 79 | pcfg = ProbDetGrammar.uniform(cfg) 80 | for bucketSize in range(3, 10): 81 | seen = set() 82 | for program in enumerate_bucket_prob_grammar(pcfg, bucket_size=bucketSize): 83 | assert program not in seen 84 | seen.add(program) 85 | assert len(seen) == cfg.programs() 86 | 87 | 88 | @pytest.mark.parametrize("cfg", testdata) 89 | def test_order_bucketSearch(cfg: TTCFG) -> None: 90 | pcfg = ProbDetGrammar.uniform(cfg) 91 | for bucketSize in range(3, 10): 92 | last = Bucket(bucketSize) 93 | for program in enumerate_bucket_prob_grammar(pcfg, bucket_size=bucketSize): 94 | p = pcfg.reduce_derivations( 95 | lambda b, S, P, _: b.add_prob_uniform(pcfg.probabilities[S][P]), 96 | Bucket(bucketSize), 97 | program, 98 | ) 99 | assert p.size == bucketSize 100 | assert p >= last or last == Bucket(bucketSize) 101 | last = p 102 | 103 | 104 | @pytest.mark.parametrize("cfg", testdata) 105 | def test_merge(cfg: TTCFG) -> None: 106 | pcfg = ProbDetGrammar.uniform(cfg) 107 | seen = set() 108 | for program in enumerate_prob_grammar(pcfg): 109 | assert program not in seen 110 | seen.add(program) 111 | en = enumerate_prob_grammar(pcfg) 112 | removed = dsl.parse_program("(+ 1 1)", auto_type("int")) 113 | en.merge_program(dsl.parse_program("2", auto_type("int")), removed) 114 | new_seen = set() 115 | for program in en: 116 | assert removed not in program 117 | new_seen.add(program) 118 | diff = seen.difference(new_seen) 119 | for x in diff: 120 | assert removed in x 121 | 122 | 123 | def test_infinite() -> None: 124 | pcfg = ProbDetGrammar.random( 125 | CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1 126 | ) 127 | count = 10000 128 | last = 1.0 129 | for program in enumerate_prob_grammar(pcfg): 130 | count -= 1 131 | p = pcfg.probability(program) 132 | assert -1e-12 <= last - p, f"failed at program n°{count}:{program}" 133 | last = p 134 | if count < 0: 135 | break 136 | assert count == -1 137 | -------------------------------------------------------------------------------- /tests/syntax/grammars/test_cfg.py: 
-------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.cfg import CFG 2 | from synth.syntax.dsl import DSL 3 | from synth.syntax.program import Primitive 4 | from synth.syntax.type_system import ( 5 | INT, 6 | STRING, 7 | Arrow, 8 | List, 9 | PolymorphicType, 10 | PrimitiveType, 11 | ) 12 | from synth.syntax.type_helper import FunctionType 13 | 14 | import pytest 15 | 16 | syntax = { 17 | "+": FunctionType(INT, INT, INT), 18 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 19 | "non_reachable": PrimitiveType("non_reachable"), 20 | "1": INT, 21 | "non_productive": FunctionType(INT, STRING), 22 | } 23 | dsl = DSL(syntax) 24 | max_depths = [3, 7, 11] 25 | 26 | 27 | def test_function_as_variable() -> None: 28 | dsl = DSL(syntax) 29 | max_depth = 5 30 | cfg = CFG.depth_constraint(dsl, FunctionType(Arrow(INT, INT), INT), max_depth) 31 | assert cfg.programs() > 0 32 | 33 | 34 | @pytest.mark.parametrize("max_depth", max_depths) 35 | def test_clean(max_depth: int) -> None: 36 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 37 | for rule in cfg.rules: 38 | assert rule[1][0][1] <= max_depth 39 | for P in cfg.rules[rule]: 40 | if isinstance(P, Primitive): 41 | assert P.primitive != "non_reachable" 42 | assert P.primitive != "non_productive" 43 | assert P.primitive != "head" 44 | 45 | 46 | @pytest.mark.parametrize("max_depth", max_depths) 47 | def test_depth_constraint(max_depth: int) -> None: 48 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 49 | res = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 50 | print(cfg) 51 | while res.depth() <= max_depth: 52 | assert ( 53 | res in cfg 54 | ), f"Program depth:{res.depth()} should be in the TTCFG max_depth:{max_depth}" 55 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 56 | assert ( 57 | res not in cfg 58 | ), f"Program depth:{res.depth()} should NOT be in the TTCFG max_depth:{max_depth}" 59 | 60 | 61 | def test_infinite() -> None: 62 | cfg = CFG.infinite(dsl, FunctionType(INT, INT), n_gram=2) 63 | res = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 64 | print(cfg) 65 | while res.depth() <= 30: 66 | assert ( 67 | res in cfg 68 | ), f"Program depth:{res.depth()} should be in the infinite TTCFG" 69 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 70 | -------------------------------------------------------------------------------- /tests/syntax/grammars/test_tagged_det_grammar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar 4 | from synth.syntax.grammars.cfg import CFG 5 | from synth.syntax.dsl import DSL 6 | from synth.syntax.grammars.ttcfg import TTCFG 7 | from synth.syntax.type_system import ( 8 | INT, 9 | STRING, 10 | List, 11 | PolymorphicType, 12 | PrimitiveType, 13 | ) 14 | from synth.syntax.type_helper import FunctionType 15 | 16 | import pytest 17 | 18 | syntax = { 19 | "+": FunctionType(INT, INT, INT), 20 | "-": FunctionType(INT, INT, INT), 21 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 22 | "non_reachable": PrimitiveType("non_reachable"), 23 | "1": INT, 24 | "2": INT, 25 | "non_productive": FunctionType(INT, STRING), 26 | } 27 | dsl = DSL(syntax) 28 | max_depths = [3, 7, 11] 29 | 30 | 31 | @pytest.mark.parametrize("max_depth", max_depths) 32 | def test_from_cfg(max_depth: int) -> None: 33 | cfg = 
CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 34 | pcfg = ProbDetGrammar.uniform(cfg) 35 | for rule in pcfg.rules: 36 | n = len(pcfg.rules[rule]) 37 | for P in pcfg.rules[rule]: 38 | prob = pcfg.probabilities[rule][P] 39 | assert np.isclose(prob, 1 / n) 40 | 41 | 42 | @pytest.mark.parametrize("max_depth", max_depths) 43 | def test_from_ttcfg(max_depth: int) -> None: 44 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_depth) 45 | pcfg = ProbDetGrammar.uniform(cfg) 46 | for rule in pcfg.rules: 47 | n = len(pcfg.rules[rule]) 48 | for P in pcfg.rules[rule]: 49 | prob = pcfg.probabilities[rule][P] 50 | assert np.isclose(prob, 1 / n) 51 | 52 | 53 | @pytest.mark.parametrize("max_depth", max_depths) 54 | def test_ready_for_sampling(max_depth: int) -> None: 55 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 56 | pcfg = ProbDetGrammar.uniform(cfg) 57 | assert not pcfg.ready_for_sampling 58 | pcfg.init_sampling() 59 | assert pcfg.ready_for_sampling 60 | 61 | 62 | @pytest.mark.parametrize("max_depth", max_depths) 63 | def test_seeding(max_depth: int) -> None: 64 | seed = 100 65 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 66 | pcfg = ProbDetGrammar.uniform(cfg) 67 | pcfg.init_sampling(seed) 68 | g1 = pcfg.sampling() 69 | cpy = ProbDetGrammar.uniform(cfg) 70 | cpy.init_sampling(seed) 71 | assert pcfg == cpy 72 | g2 = cpy.sampling() 73 | for _ in range(200): 74 | p1, p2 = next(g1), next(g2) 75 | assert p1 == p2, f"[n°{_}]: {p1} != {p2}" 76 | 77 | 78 | @pytest.mark.parametrize("max_depth", max_depths) 79 | def test_depth(max_depth: int) -> None: 80 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 81 | pcfg = ProbDetGrammar.uniform(cfg) 82 | pcfg.init_sampling(0) 83 | g = pcfg.sampling() 84 | for _ in range(200): 85 | assert next(g).depth() <= max_depth 86 | -------------------------------------------------------------------------------- /tests/syntax/grammars/test_tagged_u_grammar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar 4 | from synth.syntax.grammars.u_cfg import UCFG 5 | from synth.syntax.dsl import DSL 6 | from synth.syntax.type_system import ( 7 | INT, 8 | STRING, 9 | List, 10 | PolymorphicType, 11 | PrimitiveType, 12 | ) 13 | from synth.syntax.type_helper import FunctionType 14 | 15 | import pytest 16 | 17 | syntax = { 18 | "+": FunctionType(INT, INT, INT), 19 | "-": FunctionType(INT, INT, INT), 20 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 21 | "non_reachable": PrimitiveType("non_reachable"), 22 | "1": INT, 23 | "2": INT, 24 | "non_productive": FunctionType(INT, STRING), 25 | } 26 | 27 | dsl = DSL(syntax) 28 | max_depths = [3, 7, 11] 29 | 30 | 31 | @pytest.mark.parametrize("max_depth", max_depths) 32 | def test_from_cfg(max_depth: int) -> None: 33 | cfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 34 | pcfg = ProbUGrammar.uniform(cfg) 35 | for rule in pcfg.rules: 36 | n = sum(len(pcfg.rules[rule][P]) for P in pcfg.rules[rule]) 37 | for P in pcfg.rules[rule]: 38 | dico = pcfg.probabilities[rule][P] 39 | for _, prob in dico.items(): 40 | assert np.isclose(prob, 1 / n) 41 | 42 | 43 | # def test_from_ttcfg() -> None: 44 | # dsl = DSL(syntax) 45 | # for max_depth in [3, 7, 11]: 46 | # cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_depth) 47 | # pcfg = ProbUGrammar.uniform(cfg) 48 | # for rule in pcfg.rules: 49 | 
# n = len(pcfg.rules[rule]) 50 | # for P in pcfg.rules[rule]: 51 | # prob = pcfg.probabilities[rule][P] 52 | # assert np.isclose(prob, 1 / n) 53 | 54 | 55 | @pytest.mark.parametrize("max_depth", max_depths) 56 | def test_ready_for_sampling(max_depth: int) -> None: 57 | cfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 58 | pcfg = ProbUGrammar.uniform(cfg) 59 | assert not pcfg.ready_for_sampling 60 | pcfg.init_sampling() 61 | assert pcfg.ready_for_sampling 62 | 63 | 64 | @pytest.mark.parametrize("max_depth", max_depths) 65 | def test_seeding(max_depth: int) -> None: 66 | seed = 100 67 | cfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 68 | pcfg = ProbUGrammar.uniform(cfg) 69 | pcfg.init_sampling(seed) 70 | g1 = pcfg.sampling() 71 | cpy = ProbUGrammar.uniform(cfg) 72 | cpy.init_sampling(seed) 73 | assert pcfg == cpy 74 | g2 = cpy.sampling() 75 | for _ in range(200): 76 | p1, p2 = next(g1), next(g2) 77 | assert p1 == p2, f"[n°{_}]: {p1} != {p2}" 78 | 79 | 80 | @pytest.mark.parametrize("max_depth", max_depths) 81 | def test_depth(max_depth: int) -> None: 82 | cfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 83 | pcfg = ProbUGrammar.uniform(cfg) 84 | pcfg.init_sampling(0) 85 | g = pcfg.sampling() 86 | for _ in range(200): 87 | assert next(g).depth() <= max_depth 88 | -------------------------------------------------------------------------------- /tests/syntax/grammars/test_ttcfg.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.ttcfg import TTCFG 2 | from synth.syntax.dsl import DSL 3 | from synth.syntax.program import Primitive 4 | from synth.syntax.type_system import ( 5 | INT, 6 | STRING, 7 | List, 8 | PolymorphicType, 9 | PrimitiveType, 10 | ) 11 | from synth.syntax.type_helper import FunctionType 12 | 13 | import pytest 14 | 15 | syntax = { 16 | "+": FunctionType(INT, INT, INT), 17 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 18 | "non_reachable": PrimitiveType("non_reachable"), 19 | "1": INT, 20 | "non_productive": FunctionType(INT, STRING), 21 | } 22 | dsl = DSL(syntax) 23 | dsl.instantiate_polymorphic_types() 24 | max_sizes = [3, 7, 11] 25 | 26 | 27 | @pytest.mark.parametrize("max_size", max_sizes) 28 | def test_clean(max_size: int) -> None: 29 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 30 | for rule in cfg.rules: 31 | for P in cfg.rules[rule]: 32 | if isinstance(P, Primitive): 33 | assert P.primitive != "non_reachable" 34 | assert P.primitive != "non_productive" 35 | assert P.primitive != "head" 36 | else: 37 | assert P.type == INT 38 | 39 | 40 | @pytest.mark.parametrize("max_size,progs", [(1, 2), (3, 2 + 4), (5, 22)]) 41 | def test_size(max_size: int, progs: int) -> None: 42 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 43 | size = cfg.programs() 44 | print(cfg) 45 | assert size == progs 46 | 47 | 48 | @pytest.mark.parametrize("max_size", max_sizes) 49 | def test_size_constraint(max_size: int) -> None: 50 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 51 | size1 = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 52 | res = size1 53 | while res.size() <= max_size: 54 | assert ( 55 | res in cfg 56 | ), f"Program size:{res.size()} should be in the TTCFG max_size:{max_size}" 57 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 58 | assert ( 59 | res not in cfg 60 | ), f"Program size:{res.size()} should NOT be in the TTCFG max_size:{max_size}" 
61 | 62 | 63 | @pytest.mark.parametrize("max_occ", [3, 4, 5]) 64 | def test_at_most(max_occ: int) -> None: 65 | cfg = TTCFG.at_most_k(dsl, FunctionType(INT, INT), "+", max_occ) 66 | res = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 67 | while res.depth() - 1 <= max_occ: 68 | assert ( 69 | res in cfg 70 | ), f"Occurrences:{res.depth() - 1} should be in the TTCFG max occurrences:{max_occ}" 71 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 72 | assert ( 73 | res not in cfg 74 | ), f"Occurrences:{res.depth() - 1} should NOT be in the TTCFG max occurrences:{max_occ}" 75 | 76 | 77 | @pytest.mark.parametrize("max_size", max_sizes) 78 | def test_clean_idempotent(max_size: int) -> None: 79 | cfg = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 80 | for rule in cfg.rules: 81 | for P in cfg.rules[rule]: 82 | if isinstance(P, Primitive): 83 | assert P.primitive != "non_reachable" 84 | assert P.primitive != "non_productive" 85 | assert P.primitive != "head" 86 | 87 | cpy = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 88 | cpy.clean() 89 | assert cfg == cpy 90 | 91 | 92 | def test_product() -> None: 93 | max_size = 3 94 | cfg1 = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size * 2) 95 | cfg2 = TTCFG.size_constraint(dsl, FunctionType(INT, INT), max_size) 96 | cfg = cfg1 * cfg2 97 | assert cfg 98 | size1 = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 99 | res = size1 100 | while res.size() <= max_size: 101 | assert ( 102 | res in cfg 103 | ), f"Program size:{res.size()} should be in the TTCFG max_size:{max_size}" 104 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 105 | assert ( 106 | res not in cfg 107 | ), f"Program size:{res.size()} should NOT be in the TTCFG max_size:{max_size}" 108 | -------------------------------------------------------------------------------- /tests/syntax/grammars/test_ucfg.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.grammars.cfg import CFG 2 | from synth.syntax.grammars.u_cfg import UCFG 3 | from synth.syntax.dsl import DSL 4 | from synth.syntax.program import Primitive 5 | from synth.syntax.type_system import ( 6 | INT, 7 | STRING, 8 | List, 9 | PolymorphicType, 10 | PrimitiveType, 11 | ) 12 | from synth.syntax.type_helper import FunctionType 13 | 14 | import pytest 15 | 16 | 17 | syntax = { 18 | "+": FunctionType(INT, INT, INT), 19 | "head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")), 20 | "non_reachable": PrimitiveType("non_reachable"), 21 | "1": INT, 22 | "non_productive": FunctionType(INT, STRING), 23 | } 24 | dsl = DSL(syntax) 25 | dsl.instantiate_polymorphic_types() 26 | max_depths = [3, 7, 11] 27 | 28 | 29 | @pytest.mark.parametrize("max_depth", max_depths) 30 | def test_size(max_depth: int) -> None: 31 | ucfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 32 | cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 33 | assert cfg.programs() == ucfg.programs() 34 | 35 | 36 | @pytest.mark.parametrize("max_depth", max_depths) 37 | def test_clean(max_depth: int) -> None: 38 | dirty_ucfg = UCFG.from_CFG( 39 | CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth), clean=False 40 | ) 41 | clean_ucfg = UCFG.from_CFG( 42 | CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth), clean=True 43 | ) 44 | 45 | assert clean_ucfg.programs() == dirty_ucfg.programs() 46 | 47 | for rule in clean_ucfg.rules: 48 | assert rule[1][1] <= max_depth 49 | for P in 
clean_ucfg.rules[rule]: 50 | if isinstance(P, Primitive): 51 | assert P.primitive != "non_reachable" 52 | assert P.primitive != "non_productive" 53 | assert P.primitive != "head" 54 | 55 | 56 | @pytest.mark.parametrize("max_depth", max_depths) 57 | def test_depth_constraint(max_depth: int) -> None: 58 | cfg = UCFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth) 59 | res = dsl.parse_program("(+ 1 var0)", FunctionType(INT, INT)) 60 | print(cfg) 61 | while res.depth() <= max_depth: 62 | assert ( 63 | res in cfg 64 | ), f"Program depth:{res.depth()} should be in the TTCFG max_depth:{max_depth}" 65 | res = dsl.parse_program(f"(+ {res} var0)", FunctionType(INT, INT)) 66 | assert ( 67 | res not in cfg 68 | ), f"Program depth:{res.depth()} should NOT be in the TTCFG max_depth:{max_depth}" 69 | -------------------------------------------------------------------------------- /tests/syntax/test_dsl.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.dsl import DSL 2 | from synth.syntax.type_system import INT, PolymorphicType 3 | from synth.syntax.type_helper import FunctionType 4 | 5 | 6 | syntax = { 7 | "+": FunctionType(INT, INT, INT), 8 | "id": FunctionType(PolymorphicType("a"), PolymorphicType("a")), 9 | } 10 | 11 | 12 | def test_instantiate_polymorphic() -> None: 13 | polymorphic_keys = [key for key, t in syntax.items() if t.is_polymorphic()] 14 | for size in [2, 5, 7, 11]: 15 | dsl = DSL(syntax) 16 | dsl.instantiate_polymorphic_types(size) 17 | 18 | for primitive in dsl.list_primitives: 19 | type = primitive.type 20 | assert not type.is_polymorphic() 21 | print(primitive.primitive) 22 | if primitive.primitive in polymorphic_keys: 23 | assert type.size() <= size 24 | 25 | cpy = DSL(syntax) 26 | cpy.instantiate_polymorphic_types(size) 27 | cpy.instantiate_polymorphic_types(size) 28 | assert cpy == dsl 29 | -------------------------------------------------------------------------------- /tests/syntax/test_program.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, List 2 | import random 3 | 4 | from synth.syntax.program import Primitive, Function, Program, Variable 5 | from synth.syntax.type_system import BOOL, INT 6 | from synth.syntax.type_helper import FunctionType 7 | 8 | 9 | def __gen2list__(g: Generator) -> List: 10 | out = [] 11 | try: 12 | while True: 13 | out.append(next(g)) 14 | except StopIteration: 15 | return out 16 | 17 | 18 | def test_guess_function_type() -> None: 19 | vars: List[Program] = [Variable(i, INT) for i in range(10)] 20 | random.seed(0) 21 | for c in range(1, len(vars)): 22 | sub_vars = vars[:c] 23 | random.shuffle(sub_vars) 24 | f = Function( 25 | Primitive("f", FunctionType(*[INT for _ in range(c + 1)])), sub_vars 26 | ) 27 | assert f.type == INT 28 | if c > 1: 29 | for _ in range(random.randint(1, c - 1)): 30 | sub_vars.pop() 31 | f = Function( 32 | Primitive("f", FunctionType(*[INT for _ in range(c + 1)])), sub_vars 33 | ) 34 | assert f.type == FunctionType(*[INT for _ in range(c - len(sub_vars) + 1)]) 35 | 36 | 37 | def test_depth_first_iter() -> None: 38 | i, b, f = ( 39 | Primitive("a", INT), 40 | Variable(0, BOOL), 41 | Primitive("f", FunctionType(INT, BOOL, INT)), 42 | ) 43 | 44 | assert __gen2list__(i.depth_first_iter()) == [i] 45 | assert __gen2list__(b.depth_first_iter()) == [b] 46 | assert __gen2list__(f.depth_first_iter()) == [f] 47 | fun = Function(f, [i, b]) 48 | assert __gen2list__(fun.depth_first_iter()) == [f, i, b, fun] 49 
| fun2 = Function(f, [fun, b]) 50 | assert __gen2list__(fun2.depth_first_iter()) == [f, f, i, b, fun, b, fun2] 51 | 52 | 53 | def test_is_using_all_variables() -> None: 54 | vars: List[Program] = [Variable(i, INT) for i in range(10)] 55 | random.seed(0) 56 | for c in range(1, len(vars)): 57 | sub_vars = vars[:c] 58 | random.shuffle(sub_vars) 59 | f = Function( 60 | Primitive("f", FunctionType(*[INT for _ in range(c + 1)])), sub_vars 61 | ) 62 | assert len(f.used_variables()) == c 63 | sub_vars.pop() 64 | f = Function( 65 | Primitive("f", FunctionType(*[INT for _ in range(c + 1)])), sub_vars 66 | ) 67 | assert len(f.used_variables()) == c - 1 68 | -------------------------------------------------------------------------------- /tests/syntax/test_type_helper.py: -------------------------------------------------------------------------------- 1 | from synth.syntax.type_system import ( 2 | STRING, 3 | INT, 4 | BOOL, 5 | FixedPolymorphicType, 6 | Generic, 7 | GenericFunctor, 8 | PolymorphicType, 9 | List, 10 | Arrow, 11 | PrimitiveType, 12 | UnknownType, 13 | match, 14 | ) 15 | from synth.syntax.type_helper import auto_type, FunctionType, guess_type, Optional 16 | import random 17 | 18 | 19 | def test_guess_type() -> None: 20 | # Bool 21 | assert guess_type(True) == BOOL 22 | assert guess_type(False) == BOOL 23 | # Int 24 | random.seed(0) 25 | for _ in range(100): 26 | assert guess_type(random.randint(-100, 100)) == INT 27 | # String 28 | assert guess_type("") == STRING 29 | # List 30 | assert match(guess_type([]), List(PolymorphicType(""))) 31 | assert guess_type([True]) == List(BOOL) 32 | assert guess_type([""]) == List(STRING) 33 | assert guess_type([1]) == List(INT) 34 | # Unknown 35 | assert isinstance(guess_type(int), UnknownType) 36 | 37 | 38 | def test_FunctionType() -> None: 39 | assert FunctionType(INT, BOOL, STRING, List(INT)) == Arrow( 40 | INT, Arrow(BOOL, Arrow(STRING, List(INT))) 41 | ) 42 | 43 | 44 | def test_auto_type_base() -> None: 45 | assert PrimitiveType("int") == auto_type("int") 46 | assert PrimitiveType("bb") == auto_type("bb") 47 | assert PolymorphicType("bb") == auto_type("'bb") 48 | assert PolymorphicType("aa") == auto_type("'aa") 49 | assert PrimitiveType("a_a") == auto_type("a_a") 50 | 51 | 52 | def test_auto_type_advanced() -> None: 53 | assert List(PrimitiveType("int")) == auto_type("int list") 54 | assert List(PolymorphicType("a")) == auto_type("'a list") 55 | 56 | some = GenericFunctor("some", min_args=1, max_args=1) 57 | 58 | assert some(PolymorphicType("a")) == auto_type("'a some") 59 | assert Optional(PolymorphicType("a")) == auto_type("'a optional") 60 | assert Optional(some(PolymorphicType("a"))) == auto_type("'a some optional") 61 | 62 | x = PrimitiveType("bb") | PolymorphicType("aa") 63 | assert x == auto_type("bb | 'aa") 64 | assert x == auto_type("bb|'aa") 65 | assert x == auto_type("'aa | bb") 66 | 67 | 68 | def test_auto_type_arrows() -> None: 69 | a = PrimitiveType("a") 70 | b = PrimitiveType("b") 71 | assert FunctionType(a, b) == auto_type("a->b") 72 | assert FunctionType(a, b, b) == auto_type("a->b->b") 73 | assert FunctionType(a, FunctionType(a, b), b) == auto_type("a->(a->b)->b") 74 | 75 | assert Generic("*", a, b, infix=True) == auto_type("a*b") 76 | 77 | 78 | def test_auto_type_fixed_poly() -> None: 79 | x = FixedPolymorphicType("z", PrimitiveType("b") | PrimitiveType("c")) 80 | assert x == auto_type("'z[b|c]") 81 | assert x == auto_type("'z [b|c]") 82 | assert x == auto_type("'z[ b|c ]") 83 | 
-------------------------------------------------------------------------------- /tests/test_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pathlib 3 | 4 | from synth.syntax.type_system import INT 5 | from synth.syntax.type_helper import FunctionType 6 | from synth.syntax.program import Variable 7 | from synth.task import Task, Dataset 8 | from synth.specification import PBE, Example 9 | 10 | 11 | def test_dataset_save_and_load(tmp_path: pathlib.Path) -> None: 12 | file_path = tmp_path / "dataset.pickle" 13 | random.seed(0) 14 | dataset = Dataset( 15 | [ 16 | Task( 17 | FunctionType(INT, INT, INT), 18 | PBE( 19 | [ 20 | Example( 21 | [random.randint(0, 100), random.randint(0, 100)], 22 | random.randint(0, 100), 23 | ) 24 | for _ in range(5) 25 | ] 26 | ), 27 | Variable(0, INT) if random.random() > 0.5 else None, 28 | metadata={"index": i}, 29 | ) 30 | for i in range(100) 31 | ], 32 | metadata={"something": False, "else": "is", "coming": 42}, 33 | ) 34 | dataset.save(file_path.as_posix()) 35 | loaded = Dataset[PBE].load(file_path.as_posix()) 36 | assert dataset == loaded 37 | 38 | 39 | def test_dataset_behaviour() -> None: 40 | random.seed(0) 41 | dataset = Dataset( 42 | [ 43 | Task( 44 | FunctionType(INT, INT, INT), 45 | PBE( 46 | [ 47 | Example( 48 | [random.randint(0, 100), random.randint(0, 100)], 49 | random.randint(0, 100), 50 | ) 51 | for _ in range(5) 52 | ] 53 | ), 54 | Variable(0, INT) if random.random() > 0.5 else None, 55 | metadata={"index": i}, 56 | ) 57 | for i in range(100) 58 | ], 59 | metadata={"something": False, "else": "is", "coming": 42}, 60 | ) 61 | assert dataset[0] == dataset.tasks[0] 62 | assert dataset[-5:-1] == dataset.tasks[-5:-1] 63 | assert dataset.tasks == [x for x in dataset] 64 | assert len(dataset) == len(dataset.tasks) 65 | -------------------------------------------------------------------------------- /tests/test_vose.py: -------------------------------------------------------------------------------- 1 | import vose 2 | 3 | import numpy as np 4 | 5 | 6 | def test_seeding() -> None: 7 | for _ in range(100): 8 | probs = np.random.randn((10)) 9 | probs /= np.sum(probs) 10 | seed = np.random.randint(9999999) 11 | a = vose.Sampler(probs, seed=seed) 12 | b = vose.Sampler(probs, seed=seed) 13 | for i in range(100): 14 | assert a.sample() == b.sample() 15 | -------------------------------------------------------------------------------- /tests/utils/test_generator_utils.py: -------------------------------------------------------------------------------- 1 | from synth.utils.generator_utils import gen_take 2 | 3 | 4 | def test_gen_take() -> None: 5 | g = (x for x in range(10000)) 6 | for i in range(10): 7 | l = gen_take(g, 100) 8 | assert l == list(range(i * 100, (i + 1) * 100)) 9 | --------------------------------------------------------------------------------