├── .bandit ├── .git-blame-ignore-revs ├── .github └── workflows │ ├── build-and-test.yaml │ └── code-quality.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .yapfignore ├── CODEOWNERS ├── Dockerfile.test ├── LICENSE ├── MANIFEST.in ├── README.md ├── VERSION ├── docs ├── contributing.md ├── examples │ ├── grid_cond_gfn.ipynb │ └── grid_cond_gfn.py ├── getting_started.md └── implementation_notes.md ├── generate_requirements.sh ├── pyproject.toml ├── requirements ├── dev-3.10.in ├── dev-3.10.txt ├── main-3.10.in └── main-3.10.txt ├── setup.py ├── src └── gflownet │ ├── __init__.py │ ├── algo │ ├── __init__.py │ ├── advantage_actor_critic.py │ ├── config.py │ ├── envelope_q_learning.py │ ├── flow_matching.py │ ├── graph_sampling.py │ ├── multiobjective_reinforce.py │ ├── soft_q_learning.py │ └── trajectory_balance.py │ ├── config.py │ ├── data │ ├── __init__.py │ ├── config.py │ ├── data_source.py │ ├── qm9.py │ └── replay_buffer.py │ ├── envs │ ├── __init__.py │ ├── frag_mol_env.py │ ├── graph_building_env.py │ ├── mol_building_env.py │ ├── seq_building_env.py │ └── test.py │ ├── hyperopt │ └── wandb_demo │ │ ├── README.md │ │ ├── init_wandb_sweep.py │ │ └── launch_wandb_agents.sh │ ├── models │ ├── __init__.py │ ├── bengio2021flow.py │ ├── config.py │ ├── graph_transformer.py │ ├── mxmnet.py │ └── seq_transformer.py │ ├── online_trainer.py │ ├── tasks │ ├── __init__.py │ ├── config.py │ ├── make_rings.py │ ├── qm9.py │ ├── qm9_moo.py │ ├── seh_frag.py │ ├── seh_frag_moo.py │ └── toy_seq.py │ ├── trainer.py │ └── utils │ ├── __init__.py │ ├── conditioning.py │ ├── config.py │ ├── focus_model.py │ ├── fpscores.pkl.gz │ ├── graphs.py │ ├── metrics.py │ ├── misc.py │ ├── multiobjective_hooks.py │ ├── multiprocessing_proxy.py │ ├── sascore.py │ ├── sqlite_log.py │ └── transforms.py ├── tests ├── __init__.py ├── test_envs.py ├── test_graph_building_env.py └── test_subtb.py └── tox.ini /.bandit: -------------------------------------------------------------------------------- 1 | [bandit] 2 | exclude = ./.tox,tests,docs 3 | skips = B101,B614 4 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | 915625c873755f10ae348b5262372c4a78dfa6d9 2 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yaml: -------------------------------------------------------------------------------- 1 | name: Build-and-Test 2 | on: 3 | push: 4 | branches: 5 | - trunk 6 | pull_request: 7 | jobs: 8 | tests: 9 | name: ${{ matrix.name }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | include: 15 | - {name: Linux, python: '3.10', os: ubuntu-latest, tox: py310} 16 | - {name: Windows, python: '3.10', os: windows-latest, tox: py310} 17 | # Some packages fail on M1 (macos-latest) due to not having wheels and not being able to build from source 18 | - {name: Mac, python: '3.10', os: macos-13, tox: py310} 19 | steps: 20 | - uses: actions/checkout@v3 21 | - uses: actions/setup-python@v3 22 | with: 23 | python-version: ${{ matrix.python }} 24 | cache: 'pip' 25 | cache-dependency-path: 'requirements/*.txt' 26 | - name: update pip 27 | run: | 28 | pip install -U wheel 29 | pip install -U setuptools 30 | python -m pip install -U pip 31 | - run: pip install tox 32 | - run: tox -e ${{ matrix.tox }} 33 | -------------------------------------------------------------------------------- 
/.github/workflows/code-quality.yaml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | on: 3 | push: 4 | branches: 5 | - trunk 6 | pull_request: 7 | jobs: 8 | tests: 9 | name: Code Quality 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: 3.10.13 16 | cache: 'pip' 17 | cache-dependency-path: 'requirements/*.txt' 18 | architecture: 'x64' 19 | - name: update pip 20 | run: | 21 | pip install -U wheel 22 | pip install -U setuptools 23 | python -m pip install -U pip 24 | - run: pip install tox 25 | - run: tox -e style 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Model cache 2 | src/gflownet/models/cache/ 3 | 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com/hooks.html 2 | 3 | exclude: ^(docs/|.tox) 4 | default_language_version: 5 | python: python3.10 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.1.0 9 | hooks: 10 | - id: check-added-large-files # Prevent giant files from being committed 11 | - id: check-ast # Simply check whether the files parse as valid python. 12 | - id: check-byte-order-marker # forbid files which have a UTF-8 byte-order marker 13 | - id: check-builtin-literals # Require literal syntax when initializing empty or zero Python builtin types. 14 | - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems 15 | - id: check-docstring-first # Checks a common error of defining a docstring after code. 16 | - id: check-executables-have-shebangs # Ensures that (non-binary) executables have a shebang. 17 | - id: check-json # This hook checks json files for parseable syntax. 18 | - id: check-shebang-scripts-are-executable # Ensures that (non-binary) files with a shebang are executable. 19 | - id: pretty-format-json # This hook sets a standard for formatting JSON files. 20 | - id: check-merge-conflict # Check for files that contain merge conflict strings. 21 | - id: check-symlinks # Checks for symlinks which do not point to anything. 22 | - id: check-toml # This hook checks toml files for parseable syntax. 23 | - id: check-vcs-permalinks # Ensures that links to vcs websites are permalinks. 24 | - id: check-xml # This hook checks xml files for parseable syntax. 25 | - id: check-yaml # This hook checks yaml files for parseable syntax. 26 | - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source. 27 | - id: destroyed-symlinks # Detects symlinks which are changed to regular files with a content of a path which that symlink was pointing to. 28 | - id: detect-private-key # Detects the presence of private keys 29 | - id: end-of-file-fixer # Ensures that a file is either empty, or ends with one newline. 30 | - id: fix-byte-order-marker # removes UTF-8 byte order marker 31 | - id: mixed-line-ending # Replaces or checks mixed line ending 32 | - id: sort-simple-yaml # Sorts simple YAML files which consist only of top-level keys, preserving comments and blocks. 33 | - id: trailing-whitespace # This hook trims trailing whitespace.
34 | - repo: local 35 | hooks: 36 | - id: isort 37 | name: isort 38 | entry: "isort --settings-path=pyproject.toml" 39 | language: system 40 | types: [python] 41 | require_serial: true 42 | - id: black 43 | name: black 44 | entry: "black --config=pyproject.toml" 45 | language: system 46 | types: [python] 47 | require_serial: true 48 | - id: mypy 49 | name: mypy 50 | entry: "mypy --config-file=pyproject.toml" 51 | language: system 52 | types: [python] 53 | require_serial: true 54 | - id: bandit 55 | name: bandit 56 | entry: "bandit -r -c pyproject.toml ." 57 | language: system 58 | types: [python] 59 | require_serial: true 60 | - repo: https://github.com/pre-commit/pygrep-hooks 61 | rev: v1.9.0 62 | hooks: 63 | - id: python-check-blanket-noqa # Enforce that noqa annotations always occur with specific codes. Sample annotations: # noqa: F401, # noqa: F401,W203 64 | # - id: python-check-blanket-type-ignore # Enforce that # type: ignore annotations always occur with specific codes. Sample annotations: # type: ignore[attr-defined], # type: ignore[attr-defined, name-defined] 65 | - id: python-check-mock-methods # Prevent common mistakes of assert mck.not_called(), assert mck.called_once_with(...) and mck.assert_called. 66 | - id: python-use-type-annotations # Enforce that python3.6+ type annotations are used instead of type comments 67 | - id: text-unicode-replacement-char # Forbid files which have a UTF-8 Unicode replacement character 68 | - repo: https://github.com/charliermarsh/ruff-pre-commit 69 | rev: 'v0.0.261' 70 | hooks: 71 | - id: ruff 72 | -------------------------------------------------------------------------------- /.yapfignore: -------------------------------------------------------------------------------- 1 | venv*/ 2 | build/ 3 | dist/ 4 | .tox 5 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @bengioe 2 | -------------------------------------------------------------------------------- /Dockerfile.test: -------------------------------------------------------------------------------- 1 | FROM gcr.io/eng-infrastructure/rxrx-pyenv as test 2 | ENV LC_ALL=C.UTF-8 3 | ENV LANG=C.UTF-8 4 | 5 | ENV CONFIGOME_ENV=test 6 | ENTRYPOINT [ "tox", "--parallel" ] 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Recursion Pharmaceuticals 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![Build-and-Test](https://github.com/recursionpharma/gflownet/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/recursionpharma/gflownet/actions/workflows/build-and-test.yaml) 4 | [![Code Quality](https://github.com/recursionpharma/gflownet/actions/workflows/code-quality.yaml/badge.svg)](https://github.com/recursionpharma/gflownet/actions/workflows/code-quality.yaml) 5 | [![Python versions](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/downloads/) 6 | [![license: MIT](https://img.shields.io/badge/License-MIT-purple.svg)](LICENSE) 7 | 8 | # gflownet 9 | 10 | GFlowNet-related training and environment code on graphs. 11 | 12 | **Primer** 13 | 14 | GFlowNet [[1]](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), [[2]](https://www.gflownet.org/), [[3]](https://github.com/zdhNarsil/Awesome-GFlowNets), short for Generative Flow Network, is a novel generative modeling framework, particularly suited for discrete, combinatorial objects. Here in particular it is implemented for graph generation. 15 | 16 | The idea behind GFN is to estimate flows in a (graph-theoretic) directed acyclic network*. The network represents all possible ways of constructing objects, and so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a _trajectory_. *Perhaps confusingly, the _network_ in GFN refers to the state space, not a neural network architecture. 17 | 18 | The main focus of this library (although it can do other things) is to construct graphs (e.g. graphs of atoms), which are constructed node by node. To make policy predictions, we use a graph neural network. This GNN outputs per-node logits (e.g. add an atom to this atom, or add a bond between these two atoms), as well as per-graph logits (e.g. stop/"done constructing this object"). 19 | 20 | This library supports a variety of GFN algorithms (as well as some baselines), and supports training on a mix of existing data (offline) and self-generated data (online), the latter being obtained by querying the model sequentially to obtain trajectories. 21 | 22 | 23 | ## Installation 24 | 25 | ### PIP 26 | 27 | This package is installable as a PIP package, but since it depends on some torch-geometric package wheels, the `--find-links` arguments must be specified as well: 28 | 29 | ```bash 30 | pip install -e . --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html 31 | ``` 32 | Or for CPU use: 33 | 34 | ```bash 35 | pip install -e . 
--find-links https://data.pyg.org/whl/torch-2.1.2+cpu.html 36 | ``` 37 | 38 | To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository-pip-dependency/) a specific tag, for example here `v0.0.10`, use the following scheme: 39 | ```bash 40 | pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ... 41 | ``` 42 | 43 | If package dependencies seem not to work, you may need to install the exact frozen versions listed in `requirements/`, i.e. `pip install -r requirements/main-3.10.txt`. 44 | 45 | ## Getting started 46 | 47 | A good place to get started immediately is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`). 48 | 49 | For a gentler introduction to the library, see [Getting Started](docs/getting_started.md). For a more in-depth look at the library, see [Implementation Notes](docs/implementation_notes.md). 50 | 51 | ## Repo overview 52 | 53 | - [algo](src/gflownet/algo), contains GFlowNet algorithm implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. 54 | - [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. 55 | - [envs](src/gflownet/envs), contains environment classes; the base environment is agnostic to what kind of graph is being made, and context classes specify mappings from graphs to objects (e.g. molecules) and torch_geometric `Data`. 56 | - [examples](docs/examples), contains simple example implementations of GFlowNet. 57 | - [models](src/gflownet/models), contains model definitions. 58 | - [tasks](src/gflownet/tasks), contains training code. 59 | - [qm9](src/gflownet/tasks/qm9.py), a temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. 60 | - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein. 61 | - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). 62 | - [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). 63 | - [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. 64 | - [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. 65 | 66 | See [implementation notes](docs/implementation_notes.md) for more. 67 | 68 | 69 | ## Developing & Contributing 70 | 71 | External contributions are welcome. 72 | 73 | To install the developer dependencies: 74 | ``` 75 | pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html 76 | ``` 77 | 78 | We use `tox` to run tests and linting, and `pre-commit` to run checks before committing. 79 | To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively. 80 | 81 | For more information, see [Contributing](docs/contributing.md).
82 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | MAJOR="0" 2 | MINOR="1" 3 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions to the repository are welcome, and we encourage you to open issues and pull requests. In general, it is recommended to fork this repository and open a pull request from your fork to the `trunk` branch. PRs are encouraged to be short and focused, and to include tests and documentation where appropriate. 4 | 5 | ## Installation 6 | 7 | To install the developer dependencies, run: 8 | ``` 9 | pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html 10 | ``` 11 | 12 | ## Dependencies 13 | 14 | Dependencies are defined in `pyproject.toml`, and frozen versions that are known to work are provided in `requirements/`. 15 | 16 | To regenerate the frozen versions, run `./generate_requirements.sh <name>`, where `<name>` is one of the requirement sets in `requirements/` (e.g. `dev-3.10`). See comments within. 17 | 18 | ## Linting and testing 19 | 20 | We use `tox` to run tests and linting, and `pre-commit` to run checks before committing. 21 | To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively. 22 | 23 | `tox` itself runs many linters, but the most important ones are `black`, `ruff`, `isort`, and `mypy`. The full list 24 | of linting tools is found in `.pre-commit-config.yaml`, while `tox.ini` defines the environments under which these 25 | linters (as well as tests) are run. 26 | 27 | ## GitHub Actions 28 | 29 | We use GitHub Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`. 30 | 31 | The cascade of events is as follows: 32 | - For `build-and-test`, `tox -> testenv:py310 -> pytest` is run. 33 | - For `code-quality`, `tox -e style -> testenv:style -> pre-commit -> {isort, black, mypy, bandit, ruff, & others}`. This and the "others" are defined in `.pre-commit-config.yaml` and include things like checking for secrets and trailing whitespace. 34 | 35 | ## Style Guide 36 | 37 | On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). 38 | Our docstrings follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) format, and we use type hints throughout the codebase. 39 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | For an introduction to the library, see [this colab notebook](https://colab.research.google.com/drive/1wANyo6Y-ceYEto9-p50riCsGRb_6U6eH). 4 | 5 | For an introduction to using `wandb` to log experiments, see [this demo](../src/gflownet/hyperopt/wandb_demo). 6 | 7 | For more general introductions to GFlowNets, check out the following: 8 | - The 2023 [GFlowNet workshop](https://gflownet.org/) has several introductory talks and colab tutorials. 9 | - This high-level [GFlowNet colab tutorial](https://colab.research.google.com/drive/1fUMwgu2OhYpQagpzU5mhe9_Esib3Q2VR) (updated versions of which were written for the 2023 workshop, in particular for continuous GFNs).
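If you would rather start from code, the following is a minimal sketch of a training run. It is modeled on the `main()` of `src/gflownet/tasks/seh_frag.py`; the `init_empty` helper and the exact `Config` field names shown are taken from this repo's `gflownet.config` module, but treat the specific values as placeholders:

```python
import torch

from gflownet.config import Config, init_empty
from gflownet.tasks.seh_frag import SEHFragTrainer

# Start from an empty config and set only the fields we care about;
# everything else falls back to the defaults defined in gflownet.config.
config = init_empty(Config())
config.log_dir = "./logs/debug_run_seh_frag"  # hypothetical output directory
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.num_training_steps = 1_000
config.num_workers = 0  # keep sampling in the main process while debugging

trial = SEHFragTrainer(config)
trial.run()
```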
10 | 11 | A good place to get started immediately is with the [sEH fragment-based MOO task](../src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`). -------------------------------------------------------------------------------- /docs/implementation_notes.md: -------------------------------------------------------------------------------- 1 | # Implementation notes 2 | 3 | This repo is centered around training GFlowNets that produce graphs, although sequences are also supported. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments. 4 | 5 | ## Environment, Context, Task, Trainers 6 | 7 | We separate experiment concerns into four categories: 8 | - The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP. 9 | - The Context provides an interface between the agent and the environment, it 10 | - maps graphs to torch_geometric `Data` 11 | instances 12 | - maps GraphActions to action indices 13 | - communicates to the model what inputs it should expect 14 | - The Task class is responsible for computing the reward of a state, and for sampling conditioning information 15 | - The Trainer class is responsible for instantiating everything, and running the training & testing loop 16 | 17 | Typically one would set up a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`. 18 | 19 | 20 | ## Graphs 21 | 22 | This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations. 23 | 24 | Some notes: 25 | - graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs. 26 | - When converting from `GraphAction`s (nx) to `ActionIndex`s (tuple of ints), the action indices are encoding-bound, i.e. they point to specific rows and columns in the torch encoding. 27 | 28 | 29 | ### Graph policies & graph action categoricals 30 | 31 | The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits across a minibatch. 32 | 33 | Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. 34 | 35 | The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, action masks and so on; it can also be used to sample from the distribution.
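To make the batched softmax concrete before expanding on it below, here is a self-contained sketch in plain PyTorch (not the library's actual implementation) of the per-graph log-softmax for a single action type; it uses the same numbers as the worked example that follows:

```python
import torch

# Logits for one action type (e.g. AddNode): one row per node, one column per action.
logits = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 0.0]])
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
num_graphs = 2

# Per-graph maximum, for numerical stability.
g_max = torch.full((num_graphs,), float("-inf")).scatter_reduce(
    0, batch, logits.max(dim=1).values, reduce="amax"
)
shifted = logits - g_max[batch, None]
# Per-graph normalizer: sum of exp over every (node, action) logit of each graph.
g_Z = torch.zeros(num_graphs).scatter_add(0, batch, shifted.exp().sum(dim=1))
log_probs = shifted - g_Z.log()[batch, None]
# exp(log_probs[:3]) sums to 1 (the softmax over graph 0's six logits); exp(log_probs[3:]) likewise for graph 1.
```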
36 | 37 | To expand, the logits are always 2d tensors, and there’s going to be one such tensor per “action type” that the agent is allowed to take. 38 | Since graphs have a variable number of nodes, and since each node has `n_node_actions` associated possible actions/logits, the `(n_nodes, n_node_actions)` tensor will vary from minibatch to minibatch. 39 | In addition, the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`). 40 | 41 | Here’s an example: say we have 2 graphs in a minibatch, the first has 3 nodes, the second 2 nodes. The logits associated with AddNode will be of shape `(5, n)` (assuming there are `n` types of nodes in the problem). Say `n=2`, and `logits[AddNode] = [[1,2],[3,4],[5,6],[7,8],[9,0]]`, and `batch=[0,0,0,1,1]`. 42 | Then to compute the policy, we have to compute a softmax appropriately, i.e. the softmax for the first graph would be `softmax([1,2,3,4,5,6])` and for the second `softmax([7,8,9,0])`. This is possible thanks to `batch` and is what `GraphActionCategorical` does behind the scenes. 43 | Now that would be for when we only have the `AddNode` action. With more than one action we also have to compute the log-softmax log-normalization factor over the logits of these other tensors, log-add them together and then subtract it from all corresponding logits. 44 | 45 | ## Data sources 46 | 47 | The data used for training GFlowNets can come from a variety of sources. `DataSource` implements these different use-cases as individual iterators that collectively assemble the training batches before passing them to the trainer. Some of these use-cases include: 48 | - Generating new trajectories on-policy 49 | - Sampling trajectories from past policies via a replay buffer 50 | - Sampling trajectories from a fixed, offline dataset 51 | 52 | `DataSource` also covers validation sets, including cases such as: 53 | - Generating new trajectories (w.r.t. a fixed dataset of conditioning goals) 54 | - Evaluating the model's likelihood on trajectories from a fixed, offline dataset -------------------------------------------------------------------------------- /generate_requirements.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: ./generate_requirements.sh <name> (e.g. ./generate_requirements.sh dev-3.10) 4 | 5 | # set env variable 6 | # to allow pip-compile-cross-platform to use pip with --find-links. 7 | # not entirely sure this is needed. 8 | export PIP_FIND_LINKS=https://data.pyg.org/whl/torch-2.1.2+cpu.html 9 | 10 | # compile the dependencies from .in files 11 | pip-compile-cross-platform requirements/$1.in --min-python-version 3.10 -o requirements/$1.txt 12 | 13 | # remove the hashes from the .txt files 14 | # this is slightly less safe in terms of reproducibility 15 | # (e.g.
if a package was re-uploaded to PyPI with the same version) 16 | # but it is necessary to allow `pip install -r requirements` to use --find-links 17 | # in our case, without --find-links, torch-cluster often cannot find the 18 | # proper wheels and throws out an error `no torch module` when trying to build 19 | sed -i '/--hash=/d' requirements/$1.txt 20 | sed -i 's/\\//g' requirements/$1.txt 21 | 22 | # removes the nvidia requirements 23 | sed -i '/nvidia/d' requirements/$1.txt 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.distutils.bdist_wheel] 6 | universal = "true" 7 | 8 | [tool.bandit] 9 | # B101 tests the use of assert 10 | # B301 and B403 test the use of pickle 11 | skips = ["B101", "B301", "B403"] 12 | exclude_dirs = ["tests", ".tox", ".venv"] 13 | 14 | [tool.pytest.ini_options] 15 | addopts = [ 16 | "-v", 17 | "-x", 18 | "--color=yes", 19 | "--cov-report=term-missing", 20 | "--cov=gflownet", 21 | "--typeguard-packages=ml_kit,tests" 22 | ] 23 | testpaths = ["tests"] 24 | pythonpath = "src/" 25 | 26 | [tool.mypy] 27 | ignore_missing_imports = true 28 | show_error_codes = true 29 | show_error_context = true 30 | show_traceback = true 31 | strict = false 32 | strict_optional = false 33 | implicit_reexport = true 34 | allow_redefinition = true 35 | files = "src" 36 | 37 | [[tool.mypy.overrides]] 38 | module = "tests.*" 39 | allow_untyped_defs = true 40 | allow_incomplete_defs = true 41 | 42 | [tool.isort] 43 | profile = "black" 44 | py_version = "auto" 45 | line_length = 120 46 | 47 | [tool.black] 48 | line-length = 120 49 | target-version = ["py310"] 50 | 51 | [project] 52 | name = "gflownet" 53 | readme = "README.md" 54 | classifiers = ["Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3"] 55 | keywords = ["gflownet"] 56 | requires-python = ">=3.10,<3.11" 57 | dynamic = ["version"] 58 | dependencies = [ 59 | "torch==2.1.2", 60 | "torch-geometric==2.4.0", 61 | "torch-scatter==2.1.2", 62 | "torch-sparse==0.6.18", 63 | "torch-cluster==1.6.3", 64 | "rdkit", 65 | "tables", 66 | "scipy", 67 | "networkx", 68 | "tensorboard", 69 | "cvxopt", 70 | "pyarrow", 71 | "gitpython", 72 | "botorch", 73 | "pyro-ppl", 74 | "gpytorch", 75 | "omegaconf>=2.3", 76 | "wandb", 77 | "pandas", # needed for QM9 and HDF5 support. 
78 | ] 79 | 80 | [project.optional-dependencies] 81 | dev = [ 82 | "bandit[toml]", 83 | "black", 84 | "isort", 85 | "mypy", 86 | "pip-compile-cross-platform", 87 | "pre-commit", 88 | "pytest", 89 | "pytest-cov", 90 | "ruff", 91 | "tox", 92 | "typeguard", 93 | # Security pin 94 | "gitpython>=3.1.30", 95 | ] 96 | 97 | [[project.authors]] 98 | name = "Recursion Pharmaceuticals" 99 | email = "devs@recursionpharma.com" 100 | 101 | [tool.ruff] 102 | line-length = 120 103 | -------------------------------------------------------------------------------- /requirements/dev-3.10.in: -------------------------------------------------------------------------------- 1 | -r main-3.10.in 2 | bandit[toml] 3 | black 4 | isort 5 | mypy 6 | pip-compile-multi 7 | pre-commit 8 | pytest 9 | pytest-cov 10 | ruff 11 | tox 12 | gitpython>=3.1.30 13 | -------------------------------------------------------------------------------- /requirements/dev-3.10.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile-cross-platform 3 | # To update, run: 4 | # 5 | # pip-compile-cross-platform requirements/dev-3.10.in --output-file requirements/dev-3.10.txt --min-python-version 3.10 6 | # 7 | absl-py==2.1.0 ; python_version >= "3.10" and python_version < "4.0" 8 | antlr4-python3-runtime==4.9.3 ; python_version >= "3.10" and python_version < "4.0" 9 | appdirs==1.4.4 ; python_version >= "3.10" and python_version < "4.0" 10 | bandit[toml]==1.7.7 ; python_version >= "3.10" and python_version < "4.0" 11 | black==24.2.0 ; python_version >= "3.10" and python_version < "4.0" 12 | blosc2==2.5.1 ; python_version >= "3.10" and python_version < "4" 13 | botorch==0.9.5 ; python_version >= "3.10" and python_version < "4.0" 14 | build==1.0.3 ; python_version >= "3.10" and python_version < "4.0" 15 | certifi==2024.2.2 ; python_version >= "3.10" and python_version < "4.0" 16 | cfgv==3.4.0 ; python_version >= "3.10" and python_version < "4.0" 17 | charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0" 18 | click==8.1.7 ; python_version >= "3.10" and python_version < "4.0" 19 | colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32" or os_name == "nt") 20 | coverage[toml]==7.4.1 ; python_version >= "3.10" and python_version < "4.0" 21 | cvxopt==1.3.2 ; python_version >= "3.10" and python_version < "4.0" 22 | distlib==0.3.8 ; python_version >= "3.10" and python_version < "4.0" 23 | docker-pycreds==0.4.0 ; python_version >= "3.10" and python_version < "4.0" 24 | exceptiongroup==1.2.0 ; python_version >= "3.10" and python_version < "3.11" 25 | filelock==3.13.1 ; python_version >= "3.10" and python_version < "4.0" 26 | fsspec==2024.2.0 ; python_version >= "3.10" and python_version < "4.0" 27 | gitdb==4.0.11 ; python_version >= "3.10" and python_version < "4.0" 28 | gitpython==3.1.42 ; python_version >= "3.10" and python_version < "4.0" 29 | gpytorch==1.11 ; python_version >= "3.10" and python_version < "4.0" 30 | grpcio==1.60.1 ; python_version >= "3.10" and python_version < "4.0" 31 | identify==2.5.34 ; python_version >= "3.10" and python_version < "4.0" 32 | idna==3.6 ; python_version >= "3.10" and python_version < "4.0" 33 | iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "4.0" 34 | isort==5.13.2 ; python_version >= "3.10" and python_version < "4.0" 35 | jaxtyping==0.2.25 ; python_version >= "3.10" and python_version < "4.0" 36 | jinja2==3.1.3 ; 
python_version >= "3.10" and python_version < "4.0" 37 | joblib==1.3.2 ; python_version >= "3.10" and python_version < "4.0" 38 | linear-operator==0.5.1 ; python_version >= "3.10" and python_version < "4.0" 39 | markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "4.0" 40 | markdown==3.5.2 ; python_version >= "3.10" and python_version < "4.0" 41 | markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0" 42 | mdurl==0.1.2 ; python_version >= "3.10" and python_version < "4.0" 43 | mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0" 44 | msgpack==1.0.7 ; python_version >= "3.10" and python_version < "4" 45 | multipledispatch==1.0.0 ; python_version >= "3.10" and python_version < "4.0" 46 | mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "4.0" 47 | mypy==1.8.0 ; python_version >= "3.10" and python_version < "4.0" 48 | ndindex==1.8 ; python_version >= "3.10" and python_version < "4" 49 | networkx==3.2.1 ; python_version >= "3.10" and python_version < "4.0" 50 | nodeenv==1.8.0 ; python_version >= "3.10" and python_version < "4.0" 51 | numexpr==2.9.0 ; python_version >= "3.10" and python_version < "4.0" 52 | numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0" 53 | omegaconf==2.3.0 ; python_version >= "3.10" and python_version < "4.0" 54 | opt-einsum==3.3.0 ; python_version >= "3.10" and python_version < "4.0" 55 | packaging==23.2 ; python_version >= "3.10" and python_version < "4.0" 56 | pathspec==0.12.1 ; python_version >= "3.10" and python_version < "4.0" 57 | pbr==6.0.0 ; python_version >= "3.10" and python_version < "4.0" 58 | pillow==10.2.0 ; python_version >= "3.10" and python_version < "4.0" 59 | pip-compile-multi==2.6.3 ; python_version >= "3.10" and python_version < "4.0" 60 | pip-tools==7.3.0 ; python_version >= "3.10" and python_version < "4.0" 61 | pip==24.0 ; python_version >= "3.10" and python_version < "4.0" 62 | platformdirs==4.2.0 ; python_version >= "3.10" and python_version < "4.0" 63 | pluggy==1.4.0 ; python_version >= "3.10" and python_version < "4.0" 64 | pre-commit==3.6.1 ; python_version >= "3.10" and python_version < "4.0" 65 | protobuf==4.25.3 ; python_version >= "3.10" and python_version < "4.0" 66 | psutil==5.9.8 ; python_version >= "3.10" and python_version < "4.0" 67 | py-cpuinfo==9.0.0 ; python_version >= "3.10" and python_version < "4.0" 68 | pyarrow==15.0.0 ; python_version >= "3.10" and python_version < "4.0" 69 | pygments==2.17.2 ; python_version >= "3.10" and python_version < "4.0" 70 | pyparsing==3.1.1 ; python_version >= "3.10" and python_version < "4.0" 71 | pyproject-hooks==1.0.0 ; python_version >= "3.10" and python_version < "4.0" 72 | pyro-api==0.1.2 ; python_version >= "3.10" and python_version < "4.0" 73 | pyro-ppl==1.8.6 ; python_version >= "3.10" and python_version < "4.0" 74 | pytest-cov==4.1.0 ; python_version >= "3.10" and python_version < "4.0" 75 | pytest==8.0.0 ; python_version >= "3.10" and python_version < "4.0" 76 | pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0" 77 | rdkit==2023.9.5 ; python_version >= "3.10" and python_version < "4.0" 78 | requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0" 79 | rich==13.7.0 ; python_version >= "3.10" and python_version < "4.0" 80 | ruff==0.2.1 ; python_version >= "3.10" and python_version < "4.0" 81 | scikit-learn==1.4.0 ; python_version >= "3.10" and python_version < "4.0" 82 | scipy==1.12.0 ; python_version >= "3.10" and python_version < "4.0" 83 | sentry-sdk==1.40.4 ; 
python_version >= "3.10" and python_version < "4.0" 84 | setproctitle==1.3.3 ; python_version >= "3.10" and python_version < "4.0" 85 | setuptools==69.1.0 ; python_version >= "3.10" and python_version < "4.0" 86 | six==1.16.0 ; python_version >= "3.10" and python_version < "4.0" 87 | smmap==5.0.1 ; python_version >= "3.10" and python_version < "4.0" 88 | stevedore==5.1.0 ; python_version >= "3.10" and python_version < "4.0" 89 | sympy==1.12 ; python_version >= "3.10" and python_version < "4.0" 90 | tables==3.9.2 ; python_version >= "3.10" and python_version < "4.0" 91 | tensorboard-data-server==0.7.2 ; python_version >= "3.10" and python_version < "4.0" 92 | tensorboard==2.16.1 ; python_version >= "3.10" and python_version < "4.0" 93 | tf-keras==2.15.0 ; python_version >= "3.10" and python_version < "4.0" 94 | threadpoolctl==3.3.0 ; python_version >= "3.10" and python_version < "4.0" 95 | tomli==2.0.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.10" 96 | toposort==1.10 ; python_version >= "3.10" and python_version < "4.0" 97 | torch-cluster==1.6.3 ; python_version >= "3.10" and python_version < "4.0" 98 | torch-geometric==2.4.0 ; python_version >= "3.10" and python_version < "4.0" 99 | torch-scatter==2.1.2 ; python_version >= "3.10" and python_version < "4.0" 100 | torch-sparse==0.6.18 ; python_version >= "3.10" and python_version < "4.0" 101 | torch==2.1.2 ; python_version >= "3.10" and python_version < "4.0" 102 | tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0" 103 | triton==2.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0" 104 | typeguard==2.13.3 ; python_version >= "3.10" and python_version < "4.0" 105 | typing-extensions==4.9.0 ; python_version >= "3.10" and python_version < "4.0" 106 | urllib3==2.2.0 ; python_version >= "3.10" and python_version < "4.0" 107 | virtualenv==20.25.0 ; python_version >= "3.10" and python_version < "4.0" 108 | wandb==0.16.3 ; python_version >= "3.10" and python_version < "4.0" 109 | werkzeug==3.0.1 ; python_version >= "3.10" and python_version < "4.0" 110 | wheel==0.42.0 ; python_version >= "3.10" and python_version < "4.0" 111 | -------------------------------------------------------------------------------- /requirements/main-3.10.in: -------------------------------------------------------------------------------- 1 | torch==2.1.2 2 | torch-geometric==2.4.0 3 | torch-scatter==2.1.2 4 | torch-sparse==0.6.18 5 | torch-cluster==1.6.3 6 | rdkit 7 | tables 8 | scipy 9 | networkx 10 | tensorboard 11 | cvxopt 12 | pyarrow 13 | gitpython 14 | botorch 15 | pyro-ppl 16 | gpytorch 17 | omegaconf>=2.3 18 | wandb 19 | -------------------------------------------------------------------------------- /requirements/main-3.10.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile-cross-platform 3 | # To update, run: 4 | # 5 | # pip-compile-cross-platform requirements/main-3.10.in --output-file requirements/main-3.10.txt --min-python-version 3.10 6 | # 7 | absl-py==2.1.0 ; python_version >= "3.10" and python_version < "4.0" 8 | antlr4-python3-runtime==4.9.3 ; python_version >= "3.10" and python_version < "4.0" 9 | appdirs==1.4.4 ; python_version >= "3.10" and python_version < "4.0" 10 | blosc2==2.5.1 ; python_version >= "3.10" and python_version < "4" 11 | botorch==0.9.5 ; python_version >= "3.10" and python_version < "4.0" 12 | certifi==2024.2.2 ; python_version >= "3.10" and 
python_version < "4.0" 13 | charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0" 14 | click==8.1.7 ; python_version >= "3.10" and python_version < "4.0" 15 | colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Windows" 16 | cvxopt==1.3.2 ; python_version >= "3.10" and python_version < "4.0" 17 | docker-pycreds==0.4.0 ; python_version >= "3.10" and python_version < "4.0" 18 | filelock==3.13.1 ; python_version >= "3.10" and python_version < "4.0" 19 | fsspec==2024.2.0 ; python_version >= "3.10" and python_version < "4.0" 20 | gitdb==4.0.11 ; python_version >= "3.10" and python_version < "4.0" 21 | gitpython==3.1.42 ; python_version >= "3.10" and python_version < "4.0" 22 | gpytorch==1.11 ; python_version >= "3.10" and python_version < "4.0" 23 | grpcio==1.60.1 ; python_version >= "3.10" and python_version < "4.0" 24 | idna==3.6 ; python_version >= "3.10" and python_version < "4.0" 25 | jaxtyping==0.2.25 ; python_version >= "3.10" and python_version < "4.0" 26 | jinja2==3.1.3 ; python_version >= "3.10" and python_version < "4.0" 27 | joblib==1.3.2 ; python_version >= "3.10" and python_version < "4.0" 28 | linear-operator==0.5.1 ; python_version >= "3.10" and python_version < "4.0" 29 | markdown==3.5.2 ; python_version >= "3.10" and python_version < "4.0" 30 | markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0" 31 | mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0" 32 | msgpack==1.0.7 ; python_version >= "3.10" and python_version < "4" 33 | multipledispatch==1.0.0 ; python_version >= "3.10" and python_version < "4.0" 34 | ndindex==1.8 ; python_version >= "3.10" and python_version < "4" 35 | networkx==3.2.1 ; python_version >= "3.10" and python_version < "4.0" 36 | numexpr==2.9.0 ; python_version >= "3.10" and python_version < "4.0" 37 | numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0" 38 | omegaconf==2.3.0 ; python_version >= "3.10" and python_version < "4.0" 39 | opt-einsum==3.3.0 ; python_version >= "3.10" and python_version < "4.0" 40 | packaging==23.2 ; python_version >= "3.10" and python_version < "4.0" 41 | pillow==10.2.0 ; python_version >= "3.10" and python_version < "4.0" 42 | protobuf==4.25.3 ; python_version >= "3.10" and python_version < "4.0" 43 | psutil==5.9.8 ; python_version >= "3.10" and python_version < "4.0" 44 | py-cpuinfo==9.0.0 ; python_version >= "3.10" and python_version < "4.0" 45 | pyarrow==15.0.0 ; python_version >= "3.10" and python_version < "4.0" 46 | pyparsing==3.1.1 ; python_version >= "3.10" and python_version < "4.0" 47 | pyro-api==0.1.2 ; python_version >= "3.10" and python_version < "4.0" 48 | pyro-ppl==1.8.6 ; python_version >= "3.10" and python_version < "4.0" 49 | pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0" 50 | rdkit==2023.9.5 ; python_version >= "3.10" and python_version < "4.0" 51 | requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0" 52 | scikit-learn==1.4.0 ; python_version >= "3.10" and python_version < "4.0" 53 | scipy==1.12.0 ; python_version >= "3.10" and python_version < "4.0" 54 | sentry-sdk==1.40.4 ; python_version >= "3.10" and python_version < "4.0" 55 | setproctitle==1.3.3 ; python_version >= "3.10" and python_version < "4.0" 56 | setuptools==69.1.0 ; python_version >= "3.10" and python_version < "4.0" 57 | six==1.16.0 ; python_version >= "3.10" and python_version < "4.0" 58 | smmap==5.0.1 ; python_version >= "3.10" and python_version < "4.0" 59 | sympy==1.12 ; python_version 
>= "3.10" and python_version < "4.0" 60 | tables==3.9.2 ; python_version >= "3.10" and python_version < "4.0" 61 | tensorboard-data-server==0.7.2 ; python_version >= "3.10" and python_version < "4.0" 62 | tensorboard==2.16.1 ; python_version >= "3.10" and python_version < "4.0" 63 | tf-keras==2.15.0 ; python_version >= "3.10" and python_version < "4.0" 64 | threadpoolctl==3.3.0 ; python_version >= "3.10" and python_version < "4.0" 65 | torch-cluster==1.6.3 ; python_version >= "3.10" and python_version < "4.0" 66 | torch-geometric==2.4.0 ; python_version >= "3.10" and python_version < "4.0" 67 | torch-scatter==2.1.2 ; python_version >= "3.10" and python_version < "4.0" 68 | torch-sparse==0.6.18 ; python_version >= "3.10" and python_version < "4.0" 69 | torch==2.1.2 ; python_version >= "3.10" and python_version < "4.0" 70 | tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0" 71 | triton==2.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0" 72 | typeguard==2.13.3 ; python_version >= "3.10" and python_version < "4.0" 73 | typing-extensions==4.9.0 ; python_version >= "3.10" and python_version < "4.0" 74 | urllib3==2.2.0 ; python_version >= "3.10" and python_version < "4.0" 75 | wandb==0.16.3 ; python_version >= "3.10" and python_version < "4.0" 76 | werkzeug==3.0.1 ; python_version >= "3.10" and python_version < "4.0" 77 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ast import literal_eval 3 | from subprocess import check_output # nosec - command is hard-coded, no possibility of injection 4 | 5 | from setuptools import setup 6 | 7 | 8 | def _get_next_version(): 9 | if "SEMVER" in os.environ: 10 | return os.environ.get("SEMVER") 11 | 12 | # Note, this should only be used for development builds. Only robots can 13 | # create releases on PyPI from trunk, and the robots should know have the 14 | # `SEMVER` variable loaded at runtime. 15 | with open("VERSION", "r") as f: 16 | lines = f.read().splitlines() 17 | version_parts = {k: literal_eval(v) for k, v in map(lambda x: x.split("="), lines)} 18 | major = int(version_parts["MAJOR"]) 19 | minor = int(version_parts["MINOR"]) 20 | versions = check_output(["git", "tag", "--list"], encoding="utf-8").splitlines() # nosec - command is hard-coded 21 | try: 22 | latest_patch = max(int(v.rsplit(".", 1)[1]) for v in versions if v.startswith(f"v{major}.{minor}.")) 23 | except ValueError: # no tags for this major.minor exist yet 24 | latest_patch = -1 25 | return f"{major}.{minor}.{latest_patch+1}" 26 | 27 | 28 | setup(name="gflownet", version=_get_next_version()) 29 | -------------------------------------------------------------------------------- /src/gflownet/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, NewType, Optional, Tuple 2 | 3 | import torch_geometric.data as gd 4 | from torch import Tensor, nn 5 | 6 | from .config import Config 7 | 8 | # This type represents a set of scalar properties attached to each object in a batch. 
9 | ObjectProperties = NewType("ObjectProperties", Tensor) # type: ignore 10 | 11 | # This type represents log-scalars, in particular log-rewards at the scale at which we operate with GFlowNets 12 | # for example, converting a reward ObjectProperties to a log-scalar with log [(sum R_i omega_i) ** beta] 13 | LogScalar = NewType("LogScalar", Tensor) # type: ignore 14 | # This type represents linear-scalars 15 | LinScalar = NewType("LinScalar", Tensor) # type: ignore 16 | 17 | 18 | class GFNAlgorithm: 19 | updates: int = 0 20 | global_cfg: Config 21 | is_eval: bool = False 22 | 23 | def step(self): 24 | self.updates += 1 # This isn't used anywhere? 25 | 26 | def compute_batch_losses( 27 | self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 28 | ) -> Tuple[Tensor, Dict[str, Tensor]]: 29 | """Computes the loss for a batch of data, and provides logging information 30 | 31 | Parameters 32 | ---------- 33 | model: nn.Module 34 | The model being trained or evaluated 35 | batch: gd.Batch 36 | A batch of graphs 37 | num_bootstrap: Optional[int] 38 | The number of trajectories with reward targets in the batch (if applicable). 39 | 40 | Returns 41 | ------- 42 | loss: Tensor 43 | The loss for that batch 44 | info: Dict[str, Tensor] 45 | Logged information about model predictions. 46 | """ 47 | raise NotImplementedError() 48 | 49 | def construct_batch(self, trajs, cond_info, log_rewards): 50 | """Construct a batch from a list of trajectories and their information 51 | 52 | Typically calls ctx.graph_to_Data and ctx.collate to convert the trajectories into 53 | a batch of graphs and adds the necessary attributes for training. 54 | 55 | Parameters 56 | ---------- 57 | trajs: List[List[tuple[Graph, GraphAction]]] 58 | A list of N trajectories. 59 | cond_info: Tensor 60 | The conditional info that is considered for each trajectory. Shape (N, n_info) 61 | log_rewards: Tensor 62 | The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,) 63 | Returns 64 | ------- 65 | batch: gd.Batch 66 | A (CPU) Batch object with relevant attributes added 67 | """ 68 | raise NotImplementedError() 69 | 70 | def get_random_action_prob(self, it: int): 71 | if self.is_eval: 72 | return self.global_cfg.algo.valid_random_action_prob 73 | if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: 74 | return self.global_cfg.algo.train_random_action_prob 75 | return 0 76 | 77 | 78 | class GFNTask: 79 | def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: 80 | """Combines a minibatch of reward signal vectors and conditional information into a scalar log-reward. 81 | 82 | Parameters 83 | ---------- 84 | cond_info: Dict[str, Tensor] 85 | A dictionary with various conditional information (e.g. temperature) 86 | obj_props: ObjectProperties 87 | A 2d tensor where each row represents a series of object properties. 88 | 89 | Returns 90 | ------- 91 | log_reward: LogScalar 92 | A 1d tensor, a scalar log-reward for each minibatch entry. 93 | """ 94 | raise NotImplementedError() 95 | 96 | def compute_obj_properties(self, objs: List[Any]) -> Tuple[ObjectProperties, Tensor]: 97 | """Compute the properties of objs according to the task's proxies 98 | 99 | Parameters 100 | ---------- 101 | objs: List[Any] 102 | A list of n objects. 103 | Returns 104 | ------- 105 | obj_props: ObjectProperties 106 | A 2d tensor (m, p), a vector of scalar properties for the m <= n valid objects.
107 | is_valid: Tensor 108 | A 1d tensor (n,), a boolean indicating whether each object is valid. 109 | """ 110 | raise NotImplementedError() 111 | 112 | def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: 113 | """Sample conditional information for n objects 114 | 115 | Parameters 116 | ---------- 117 | n: int 118 | The number of objects to sample conditional information for. 119 | train_it: int 120 | The training iteration number. 121 | 122 | Returns 123 | ------- 124 | cond_info: Dict[str, Tensor] 125 | A dictionary with various conditional information (e.g. temperature) 126 | """ 127 | raise NotImplementedError() 128 | -------------------------------------------------------------------------------- /src/gflownet/algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/algo/__init__.py -------------------------------------------------------------------------------- /src/gflownet/algo/advantage_actor_critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.data as gd 4 | from torch import Tensor 5 | 6 | from gflownet.config import Config 7 | from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory 8 | from gflownet.utils.misc import get_worker_device 9 | 10 | from .graph_sampling import GraphSampler 11 | 12 | 13 | class A2C: 14 | def __init__( 15 | self, 16 | env: GraphBuildingEnv, 17 | ctx: GraphBuildingEnvContext, 18 | cfg: Config, 19 | ): 20 | """Advantage Actor-Critic implementation, see 21 | Asynchronous Methods for Deep Reinforcement Learning, 22 | Volodymyr Mnih, Adria Puigdomenech Badia, Mehdi Mirza, Alex Graves, Timothy Lillicrap, Tim 23 | Harley, David Silver, Koray Kavukcuoglu 24 | Proceedings of The 33rd International Conference on Machine Learning, 2016 25 | 26 | Hyperparameters used: 27 | illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions 28 | 29 | Parameters 30 | ---------- 31 | env: GraphBuildingEnv 32 | A graph environment. 33 | ctx: GraphBuildingEnvContext 34 | A context. 35 | cfg: Config 36 | The experiment configuration 37 | 38 | """ 39 | self.ctx = ctx 40 | self.env = env 41 | self.max_len = cfg.algo.max_len 42 | self.max_nodes = cfg.algo.max_nodes 43 | self.illegal_action_logreward = cfg.algo.illegal_action_logreward 44 | self.entropy_coef = cfg.algo.a2c.entropy 45 | self.gamma = cfg.algo.a2c.gamma 46 | self.invalid_penalty = cfg.algo.a2c.penalty 47 | assert self.gamma == 1 48 | self.bootstrap_own_reward = False 49 | # Experimental flags 50 | self.sample_temp = 1 51 | self.do_q_prime_correction = False 52 | self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp) 53 | 54 | def create_training_data_from_own_samples( 55 | self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float 56 | ): 57 | """Generate trajectories by sampling a model 58 | 59 | Parameters 60 | ---------- 61 | model: nn.Module 62 | The model being sampled 63 | n: int 64 | The number of trajectories to sample 65 | cond_info: torch.tensor 66 | Conditional information, shape (N, n_info) 67 | random_action_prob: float 68 | Probability of taking a random action 69 | Returns 70 | ------- 71 | data: List[Dict] 72 | A list of trajectories.
Each trajectory is a dict with keys 73 | - traj: List[Tuple[Graph, GraphAction]] 74 | - fwd_logprob: log Z + sum logprobs P_F 75 | - bck_logprob: sum logprobs P_B 76 | - is_valid: is the generated graph valid according to the env & ctx 77 | """ 78 | dev = get_worker_device() 79 | cond_info = cond_info.to(dev) 80 | data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) 81 | return data 82 | 83 | def create_training_data_from_graphs(self, graphs): 84 | """Generate trajectories from known endpoints 85 | 86 | Parameters 87 | ---------- 88 | graphs: List[Graph] 89 | List of Graph endpoints 90 | 91 | Returns 92 | ------- 93 | trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] 94 | A list of trajectories. 95 | """ 96 | return [{"traj": generate_forward_trajectory(i)} for i in graphs] 97 | 98 | def construct_batch(self, trajs, cond_info, log_rewards): 99 | """Construct a batch from a list of trajectories and their information 100 | 101 | Parameters 102 | ---------- 103 | trajs: List[List[tuple[Graph, GraphAction]]] 104 | A list of N trajectories. 105 | cond_info: Tensor 106 | The conditional info that is considered for each trajectory. Shape (N, n_info) 107 | log_rewards: Tensor 108 | The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,) 109 | Returns 110 | ------- 111 | batch: gd.Batch 112 | A (CPU) Batch object with relevant attributes added 113 | """ 114 | torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] 115 | actions = [ 116 | self.ctx.GraphAction_to_ActionIndex(g, a) 117 | for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) 118 | ] 119 | batch = self.ctx.collate(torch_graphs) 120 | batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) 121 | batch.actions = torch.tensor(actions) 122 | batch.log_rewards = log_rewards 123 | batch.cond_info = cond_info 124 | batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() 125 | return batch 126 | 127 | def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): 128 | """Compute the losses over trajectories contained in the batch 129 | 130 | Parameters 131 | ---------- 132 | model: nn.Module 133 | A GNN taking in a batch of graphs as input, as constructed by `self.construct_batch`, and 134 | returning a `GraphActionCategorical` policy along with per-state value predictions 135 | batch: gd.Batch 136 | batch of graph inputs, as constructed by `self.construct_batch` 137 | num_bootstrap: int 138 | the number of trajectories for which the reward loss is computed. Ignored if 0.""" 139 | dev = batch.x.device 140 | # A single trajectory is comprised of many graphs 141 | num_trajs = int(batch.traj_lens.shape[0]) 142 | rewards = torch.exp(batch.log_rewards) 143 | cond_info = batch.cond_info 144 | 145 | # This index says which trajectory each graph belongs to, so 146 | # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is 147 | # of length 4, trajectory 1 of length 3, and so on.
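# For example (illustrative): torch.arange(3).repeat_interleave(torch.tensor([4, 3, 2])) # gives tensor([0, 0, 0, 0, 1, 1, 1, 2, 2])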
148 | batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) 149 | 150 | # Forward pass of the model, returns a GraphActionCategorical and per graph predictions 151 | # Here we will interpret the logits of the fwd_cat as Q values 152 | policy, per_state_preds = model(batch, cond_info[batch_idx]) 153 | V = per_state_preds[:, 0] 154 | G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1 155 | G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object 156 | A = G - V 157 | log_probs = policy.log_prob(batch.actions) 158 | 159 | V_loss = A.pow(2).mean() 160 | pol_objective = (log_probs * A.detach()).mean() + self.entropy_coef * policy.entropy().mean() 161 | pol_loss = -pol_objective 162 | 163 | loss = V_loss + pol_loss 164 | invalid_mask = 1 - batch.is_valid 165 | info = { 166 | "V_loss": V_loss, 167 | "A": A.mean(), 168 | "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, 169 | "loss": loss.item(), 170 | } 171 | 172 | if not torch.isfinite(loss).all(): 173 | raise ValueError("loss is not finite") 174 | return loss, info 175 | -------------------------------------------------------------------------------- /src/gflownet/algo/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import IntEnum 3 | from typing import Optional 4 | 5 | from gflownet.utils.misc import StrictDataClass 6 | 7 | 8 | class Backward(IntEnum): 9 | """ 10 | See algo.trajectory_balance.TrajectoryBalance for details. 11 | The A variants of `Maxent` and `GSQL` (`MaxentA`, `GSQLA`) require the environment to provide $n$. 12 | This is true for sEH but not QM9. 13 | """ 14 | 15 | Uniform = 1 16 | Free = 2 17 | Maxent = 3 18 | MaxentA = 4 19 | GSQL = 5 20 | GSQLA = 6 21 | 22 | 23 | class NLoss(IntEnum): 24 | """See algo.trajectory_balance.TrajectoryBalance for details.""" 25 | 26 | none = 0 27 | Transition = 1 28 | SubTB1 = 2 29 | TermTB1 = 3 30 | StartTB1 = 4 31 | TB = 5 32 | 33 | 34 | class TBVariant(IntEnum): 35 | """See algo.trajectory_balance.TrajectoryBalance for details.""" 36 | 37 | TB = 0 38 | SubTB1 = 1 39 | DB = 2 40 | 41 | 42 | class LossFN(IntEnum): 43 | """ 44 | The loss function to use. 45 | 46 | - GHL: Kaan Gokcesu, Hakan Gokcesu 47 | https://arxiv.org/pdf/2108.12627.pdf, 48 | Note: This can be used as a differentiable version of HUB. 49 | """ 50 | 51 | MSE = 0 52 | MAE = 1 53 | HUB = 2 54 | GHL = 3 55 | 56 | 57 | @dataclass 58 | class TBConfig(StrictDataClass): 59 | """Trajectory Balance config. 60 | 61 | Attributes 62 | ---------- 63 | bootstrap_own_reward : bool 64 | Whether to bootstrap the reward with the model's own reward prediction. (deprecated) 65 | epsilon : Optional[float] 66 | The epsilon parameter in log-flow smoothing (see paper) 67 | reward_loss_multiplier : float 68 | The multiplier for the reward loss when bootstrapping the reward. (deprecated) 69 | variant : TBVariant 70 | The loss variant. See algo.trajectory_balance.TrajectoryBalance for details.
71 | do_correct_idempotent : bool 72 | Whether to correct for idempotent actions 73 | do_parameterize_p_b : bool 74 | Whether to parameterize the P_B distribution (otherwise it is uniform) 75 | do_predict_n : bool 76 | Whether to predict the number of paths in the graph 77 | do_length_normalize : bool 78 | Whether to normalize the loss by the length of the trajectory 79 | subtb_max_len : int 80 | The maximum length trajectories, used to cache subTB computation indices 81 | Z_learning_rate : float 82 | The learning rate for the logZ parameter (only relevant when do_subtb is False) 83 | Z_lr_decay : float 84 | The learning rate decay for the logZ parameter (only relevant when do_subtb is False) 85 | loss_fn: LossFN 86 | The loss function to use 87 | loss_fn_par: float 88 | The loss function parameter in case of Huber loss, it is the delta 89 | n_loss: NLoss 90 | The $n$ loss to use (defaults to NLoss.none i.e., do not learn $n$) 91 | n_loss_multiplier: float 92 | The multiplier for the $n$ loss 93 | backward_policy: Backward 94 | The backward policy to use 95 | """ 96 | 97 | bootstrap_own_reward: bool = False 98 | epsilon: Optional[float] = None 99 | reward_loss_multiplier: float = 1.0 100 | variant: TBVariant = TBVariant.TB 101 | do_correct_idempotent: bool = False 102 | do_parameterize_p_b: bool = False 103 | do_predict_n: bool = False 104 | do_sample_p_b: bool = False 105 | do_length_normalize: bool = False 106 | subtb_max_len: int = 128 107 | Z_learning_rate: float = 1e-4 108 | Z_lr_decay: float = 50_000 109 | cum_subtb: bool = True 110 | loss_fn: LossFN = LossFN.MSE 111 | loss_fn_par: float = 1.0 112 | n_loss: NLoss = NLoss.none 113 | n_loss_multiplier: float = 1.0 114 | backward_policy: Backward = Backward.Uniform 115 | 116 | 117 | @dataclass 118 | class MOQLConfig(StrictDataClass): 119 | gamma: float = 1 120 | num_omega_samples: int = 32 121 | num_objectives: int = 2 122 | lambda_decay: int = 10_000 123 | penalty: float = -10 124 | 125 | 126 | @dataclass 127 | class A2CConfig(StrictDataClass): 128 | entropy: float = 0.01 129 | gamma: float = 1 130 | penalty: float = -10 131 | 132 | 133 | @dataclass 134 | class FMConfig(StrictDataClass): 135 | epsilon: float = 1e-38 136 | balanced_loss: bool = False 137 | leaf_coef: float = 10 138 | correct_idempotent: bool = False 139 | 140 | 141 | @dataclass 142 | class SQLConfig(StrictDataClass): 143 | alpha: float = 0.01 144 | gamma: float = 1 145 | penalty: float = -10 146 | 147 | 148 | @dataclass 149 | class AlgoConfig(StrictDataClass): 150 | """Generic configuration for algorithms 151 | 152 | Attributes 153 | ---------- 154 | method : str 155 | The name of the algorithm to use (e.g. "TB") 156 | num_from_policy : int 157 | The number of on-policy samples for a training batch. 158 | If using a replay buffer, see `replay.num_from_replay` for the number of samples from the replay buffer, and 159 | `replay.num_new_samples` for the number of new samples to add to the replay buffer (e.g. `num_from_policy=0`, 160 | and `num_new_samples=N` inserts `N` new samples in the replay buffer at each step, but does not make that data 161 | part of the training batch). 
162 | num_from_dataset : int 163 | The number of samples from the dataset for a training batch 164 | valid_num_from_policy : int 165 | The number of on-policy samples for a validation batch 166 | valid_num_from_dataset : int 167 | The number of samples from the dataset for a validation batch 168 | max_len : int 169 | The maximum length of a trajectory 170 | max_nodes : int 171 | The maximum number of nodes in a generated graph 172 | max_edges : int 173 | The maximum number of edges in a generated graph 174 | illegal_action_logreward : float 175 | The log reward an agent gets for illegal actions 176 | train_random_action_prob : float 177 | The probability of taking a random action during training 178 | train_det_after: Optional[int] 179 | Do not take random actions after this number of steps 180 | valid_random_action_prob : float 181 | The probability of taking a random action during validation 182 | sampling_tau : float 183 | The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta) 184 | """ 185 | 186 | method: str = "TB" 187 | num_from_policy: int = 64 188 | num_from_dataset: int = 0 189 | valid_num_from_policy: int = 64 190 | valid_num_from_dataset: int = 0 191 | max_len: int = 128 192 | max_nodes: int = 128 193 | max_edges: int = 128 194 | illegal_action_logreward: float = -100 195 | train_random_action_prob: float = 0.0 196 | train_det_after: Optional[int] = None 197 | valid_random_action_prob: float = 0.0 198 | sampling_tau: float = 0.0 199 | tb: TBConfig = field(default_factory=TBConfig) 200 | moql: MOQLConfig = field(default_factory=MOQLConfig) 201 | a2c: A2CConfig = field(default_factory=A2CConfig) 202 | fm: FMConfig = field(default_factory=FMConfig) 203 | sql: SQLConfig = field(default_factory=SQLConfig) 204 | -------------------------------------------------------------------------------- /src/gflownet/algo/flow_matching.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | import torch.nn as nn 4 | import torch_geometric.data as gd 5 | from torch_scatter import scatter 6 | 7 | from gflownet.algo.trajectory_balance import TrajectoryBalance 8 | from gflownet.config import Config 9 | from gflownet.envs.graph_building_env import ( 10 | Graph, 11 | GraphAction, 12 | GraphActionType, 13 | GraphBuildingEnv, 14 | GraphBuildingEnvContext, 15 | ) 16 | 17 | 18 | def relabel(ga: GraphAction, g: Graph): 19 | """Relabel the nodes for g to 0-N, and the graph action ga applied to g. 20 | 21 | This is necessary because torch_geometric and EnvironmentContext classes expect nodes to be 22 | labeled 0-N, whereas GraphBuildingEnv.parent can return parents with e.g. a removed node that 23 | creates a gap in 0-N, leading to a faulty encoding of the graph. 
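    For example (an illustrative case): if g has nodes [0, 2] (node 1 was removed) and ga
    is an AddEdge action with source=0, target=2, then rmap == {0: 0, 2: 1} and the
    returned action connects nodes 0 and 1 of the relabeled graph.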
24 | """ 25 | rmap = dict(zip(g.nodes, range(len(g.nodes)))) 26 | if not len(g) and ga.action == GraphActionType.AddNode: 27 | rmap[0] = 0 # AddNode can add to the empty graph, the source is still 0 28 | g = nx.relabel_nodes(g, rmap) 29 | if ga.source is not None: 30 | ga.source = rmap[ga.source] 31 | if ga.target is not None: 32 | ga.target = rmap[ga.target] 33 | return ga, g 34 | 35 | 36 | class FlowMatching(TrajectoryBalance): # TODO: FM inherits from TB but we could have a generic GFNAlgorithm class 37 | def __init__( 38 | self, 39 | env: GraphBuildingEnv, 40 | ctx: GraphBuildingEnvContext, 41 | cfg: Config, 42 | ): 43 | super().__init__(env, ctx, cfg) 44 | self.fm_epsilon = torch.as_tensor(cfg.algo.fm.epsilon).log() 45 | # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but 46 | # in a number of settings the regular loss is more stable. 47 | self.fm_balanced_loss = cfg.algo.fm.balanced_loss 48 | self.fm_leaf_coef = cfg.algo.fm.leaf_coef 49 | self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent 50 | 51 | def construct_batch(self, trajs, cond_info, log_rewards): 52 | """Construct a batch from a list of trajectories and their information 53 | 54 | Parameters 55 | ---------- 56 | trajs: List[List[tuple[Graph, GraphAction]]] 57 | A list of N trajectories. 58 | cond_info: Tensor 59 | The conditional info that is considered for each trajectory. Shape (N, n_info) 60 | log_rewards: Tensor 61 | The transformed reward (e.g. log(R(x) ** beta)) for each trajectory. Shape (N,) 62 | Returns 63 | ------- 64 | batch: gd.Batch 65 | A (CPU) Batch object with relevant attributes added 66 | """ 67 | if not self.correct_idempotent: 68 | # For every s' (i.e. every state except the first of each trajectory), enumerate parents 69 | parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]] 70 | # convert parents to Data 71 | parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent] 72 | else: 73 | # Here we again enumerate parents 74 | states = [i[0] for tj in trajs for i in tj["traj"][1:]] 75 | base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states] 76 | base_parent_graphs = [ 77 | [self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents 78 | ] 79 | parents = [] 80 | parent_graphs = [] 81 | for state, parent_set, parent_set_graphs in zip(states, base_parents, base_parent_graphs): 82 | new_parent_set = [] 83 | new_parent_graphs = [] 84 | # But for each parent we add all the possible (action, parent) pairs to the sets of parents 85 | for (ga, p), pd in zip(parent_set, parent_set_graphs): 86 | ipa = self.get_idempotent_actions(p, pd, state, ga, return_aidx=False) 87 | new_parent_set += [(a, p) for a in ipa] 88 | new_parent_graphs += [pd] * len(ipa) 89 | parents.append(new_parent_set) 90 | parent_graphs += new_parent_graphs 91 | # Implementation Note: no further correction is required for environments where episodes 92 | # always end in a Stop action. If this is not the case, then this implementation is 93 | # incorrect in that it doesn't account for the multiple ways that one could reach the 94 | # terminal state (because it assumes that a terminal state has only one parent and gives 95 | # 100% of the reward-flow to the edge between that parent and the terminal state, which 96 | # for stop actions is correct). 
Notably, this error will happen in environments where 97 | # there are invalid states that make episodes end prematurely (when those invalid states 98 | # have multiple possible parents). 99 | 100 | # convert actions to ActionIndex 101 | parent_actions = [pact for parent in parents for pact, pstate in parent] 102 | parent_actionidxs = [ 103 | self.ctx.GraphAction_to_ActionIndex(gdata, a) for gdata, a in zip(parent_graphs, parent_actions) 104 | ] 105 | # convert state to Data 106 | state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]] 107 | terminal_actions = [ 108 | self.ctx.GraphAction_to_ActionIndex(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) 109 | for tj in trajs 110 | ] 111 | 112 | # Create a batch from [*parents, *states]. This order will make it easier when computing the loss 113 | batch = self.ctx.collate(parent_graphs + state_graphs) 114 | batch.num_parents = torch.tensor([len(i) for i in parents]) 115 | batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) 116 | batch.parent_acts = torch.tensor(parent_actionidxs) 117 | batch.terminal_acts = torch.tensor(terminal_actions) 118 | batch.log_rewards = log_rewards 119 | batch.cond_info = cond_info 120 | batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() 121 | if self.correct_idempotent: 122 | raise ValueError("Not implemented") 123 | return batch 124 | 125 | def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): 126 | dev = batch.x.device 127 | eps = self.fm_epsilon.to(dev) 128 | # Compute relevant quantities 129 | num_trajs = len(batch.log_rewards) 130 | num_states = int(batch.num_parents.shape[0]) 131 | total_num_parents = batch.num_parents.sum() 132 | # Compute, for every parent, the index of the state it corresponds to (all states are 133 | # considered numbered 0..N regardless of which trajectory they correspond to) 134 | parents_state_idx = torch.arange(num_states, device=dev).repeat_interleave(batch.num_parents) 135 | # Compute, for every state, the index of the trajectory it corresponds to 136 | states_traj_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens - 1) 137 | # Idem for parents 138 | parents_traj_idx = states_traj_idx.repeat_interleave(batch.num_parents) 139 | # Compute the index of the first graph of every trajectory via a cumsum of the trajectory 140 | # lengths. This works because by design the first parent of every trajectory is s0 (i.e. s1 141 | # only has one parent that is s0) 142 | num_parents_per_traj = scatter(batch.num_parents, states_traj_idx, 0, reduce="sum") 143 | first_graph_idx = torch.cumsum( 144 | torch.cat([torch.zeros_like(num_parents_per_traj[0])[None], num_parents_per_traj]), 0 145 | ) 146 | # Similarly we want the index of the last graph of each trajectory 147 | final_graph_idx = torch.cumsum(batch.traj_lens - 1, 0) + total_num_parents - 1 148 | 149 | # Query the model for Fsa. The model will output a GraphActionCategorical, but we will 150 | # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of 151 | # log F(s,a) so we don't have to change anything in the sampling routines. 152 | cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) 153 | # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the 154 | # parent actions. 
To do so we reuse the log_prob mechanism, but specify that the logprobs 155 | # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a)). We 156 | # only need the parent's outputs so we specify those batch indices. 157 | parent_log_F_sa = cat.log_prob( 158 | batch.parent_acts, logprobs=cat.logits, batch=torch.arange(total_num_parents, device=dev) 159 | ) 160 | # The inflows are then simply the sum reduction of the exponentiated log edge flows. The 161 | # indices are the state index that each parent belongs to. 162 | log_inflows = scatter(parent_log_F_sa.exp(), parents_state_idx, 0, reduce="sum").log() 163 | # To compute the outflows we can just logsumexp the log F(s,a) predictions. We do so for the 164 | # entire batch, which is slightly wasteful (TODO). We only take the last outflows here, and 165 | # later take the log outflows of s0 to estimate logZ. 166 | all_log_outflows = cat.logsumexp() 167 | log_outflows = all_log_outflows[total_num_parents:] 168 | 169 | # The loss of intermediate states is inflow - outflow. We use the log-epsilon variant (see FM paper) 170 | intermediate_loss = (torch.logaddexp(log_inflows, eps) - torch.logaddexp(log_outflows, eps)).pow(2) 171 | # To compute the loss of the terminal states we match F(s, a'), where a' is the action that 172 | # terminated the trajectory, to R(s). We again use the mechanism of log_prob 173 | log_F_s_stop = cat.log_prob(batch.terminal_acts, cat.logits, final_graph_idx) 174 | terminal_loss = (torch.logaddexp(log_F_s_stop, eps) - torch.logaddexp(batch.log_rewards, eps)).pow(2) 175 | 176 | if self.fm_balanced_loss: 177 | loss = intermediate_loss.mean() + terminal_loss.mean() * self.fm_leaf_coef 178 | else: 179 | loss = (intermediate_loss.sum() + terminal_loss.sum()) / ( 180 | intermediate_loss.shape[0] + terminal_loss.shape[0] 181 | ) 182 | 183 | # logZ is simply the outflow of s0, the first graph of each parent set.
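        # An illustrative example (hypothetical counts): if a batch holds two trajectories
        # with 5 and 7 parents in total, then first_graph_idx == cumsum([0, 5, 7]) == [0, 5, 12];
        # positions 0 and 5 are each trajectory's s0 in the concatenated parents (the last
        # entry marks the end of the parents block).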
184 | logZ = all_log_outflows[first_graph_idx] 185 | info = { 186 | "intermediate_loss": intermediate_loss.mean().item(), 187 | "terminal_loss": terminal_loss.mean().item(), 188 | "loss": loss.item(), 189 | "logZ": logZ.mean().item(), 190 | } 191 | return loss, info 192 | -------------------------------------------------------------------------------- /src/gflownet/algo/multiobjective_reinforce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.data as gd 3 | from torch_scatter import scatter 4 | 5 | from gflownet.algo.trajectory_balance import TrajectoryBalance, TrajectoryBalanceModel 6 | from gflownet.config import Config 7 | from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext 8 | 9 | 10 | class MultiObjectiveReinforce(TrajectoryBalance): 11 | """ 12 | Class that inherits from TrajectoryBalance and implements the multi-objective REINFORCE algorithm 13 | """ 14 | 15 | def __init__( 16 | self, 17 | env: GraphBuildingEnv, 18 | ctx: GraphBuildingEnvContext, 19 | cfg: Config, 20 | ): 21 | super().__init__(env, ctx, cfg) 22 | 23 | def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): 24 | """Compute multi objective REINFORCE loss over trajectories contained in the batch""" 25 | dev = batch.x.device 26 | # A single trajectory is comprised of many graphs 27 | num_trajs = int(batch.traj_lens.shape[0]) 28 | rewards = torch.exp(batch.log_rewards) 29 | cond_info = batch.cond_info 30 | 31 | # This index says which trajectory each graph belongs to, so 32 | # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is 33 | # of length 4, trajectory 1 of length 3, and so on. 34 | batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) 35 | 36 | # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions 37 | fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx]) 38 | 39 | # This is the log prob of each action in the trajectory 40 | log_prob = fwd_cat.log_prob(batch.actions) 41 | 42 | # Sum the per-action log probs into per-trajectory log probs 43 | assert rewards.ndim == 1 44 | traj_log_prob = scatter(log_prob, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") 45 | 46 | # REINFORCE objective with a mean-reward baseline 47 | traj_losses = traj_log_prob * -(rewards - rewards.mean()) 48 | 49 | loss = traj_losses.mean() 50 | info = { 51 | "loss": loss.item(), 52 | } 53 | if not torch.isfinite(traj_losses).all(): 54 | raise ValueError("loss is not finite") 55 | return loss, info 56 | -------------------------------------------------------------------------------- /src/gflownet/algo/soft_q_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.data as gd 4 | from torch import Tensor 5 | from torch_scatter import scatter 6 | 7 | from gflownet.algo.graph_sampling import GraphSampler 8 | from gflownet.config import Config 9 | from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory 10 | from gflownet.utils.misc import get_worker_device 11 | 12 | 13 | class SoftQLearning: 14 | def __init__( 15 | self, 16 | env: GraphBuildingEnv, 17 | ctx: GraphBuildingEnvContext, 18 | cfg: Config, 19 | ): 20 | """Soft Q-Learning implementation, see 21 | Haarnoja, Tuomas, Haoran Tang, Pieter Abbeel, and Sergey Levine. "Reinforcement learning with deep 22 | energy-based policies."
In International conference on machine learning, pp. 1352-1361. PMLR, 2017. 23 | 24 | Hyperparameters used: 25 | illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions 26 | 27 | Parameters 28 | ---------- 29 | env: GraphBuildingEnv 30 | A graph environment. 31 | ctx: GraphBuildingEnvContext 32 | A context. 33 | cfg: Config 34 | The experiment configuration 35 | """ 36 | self.ctx = ctx 37 | self.env = env 38 | self.max_len = cfg.algo.max_len 39 | self.max_nodes = cfg.algo.max_nodes 40 | self.illegal_action_logreward = cfg.algo.illegal_action_logreward 41 | self.alpha = cfg.algo.sql.alpha 42 | self.gamma = cfg.algo.sql.gamma 43 | self.invalid_penalty = cfg.algo.sql.penalty 44 | self.bootstrap_own_reward = False 45 | # Experimental flags 46 | self.sample_temp = 1 47 | self.do_q_prime_correction = False 48 | self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp) 49 | 50 | def create_training_data_from_own_samples( 51 | self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float 52 | ): 53 | """Generate trajectories by sampling a model 54 | 55 | Parameters 56 | ---------- 57 | model: nn.Module 58 | The model being sampled 59 | n: int 60 | The number of trajectories to sample 61 | cond_info: torch.tensor 62 | Conditional information, shape (N, n_info) 63 | random_action_prob: float 64 | Probability of taking a random action 65 | Returns 66 | ------- 67 | data: List[Dict] 68 | A list of trajectories. Each trajectory is a dict with keys 69 | - traj: List[Tuple[Graph, GraphAction]] 70 | - fwd_logprob: log Z + sum logprobs P_F 71 | - bck_logprob: sum logprobs P_B 72 | - is_valid: is the generated graph valid according to the env & ctx 73 | """ 74 | dev = get_worker_device() 75 | cond_info = cond_info.to(dev) 76 | data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) 77 | return data 78 | 79 | def create_training_data_from_graphs(self, graphs): 80 | """Generate trajectories from known endpoints 81 | 82 | Parameters 83 | ---------- 84 | graphs: List[Graph] 85 | List of Graph endpoints 86 | 87 | Returns 88 | ------- 89 | trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] 90 | A list of trajectories. 91 | """ 92 | return [{"traj": generate_forward_trajectory(i)} for i in graphs] 93 | 94 | def construct_batch(self, trajs, cond_info, log_rewards): 95 | """Construct a batch from a list of trajectories and their information 96 | 97 | Parameters 98 | ---------- 99 | trajs: List[List[tuple[Graph, GraphAction]]] 100 | A list of N trajectories. 101 | cond_info: Tensor 102 | The conditional info that is considered for each trajectory. Shape (N, n_info) 103 | log_rewards: Tensor 104 | The transformed log-reward (e.g. torch.log(R(x) ** beta)) for each trajectory.
Shape (N,) 105 | Returns 106 | ------- 107 | batch: gd.Batch 108 | A (CPU) Batch object with relevant attributes added 109 | """ 110 | torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] 111 | actions = [ 112 | self.ctx.GraphAction_to_ActionIndex(g, a) 113 | for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) 114 | ] 115 | batch = self.ctx.collate(torch_graphs) 116 | batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) 117 | batch.actions = torch.tensor(actions) 118 | batch.log_rewards = log_rewards 119 | batch.cond_info = cond_info 120 | batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() 121 | return batch 122 | 123 | def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): 124 | """Compute the losses over trajectories contained in the batch 125 | 126 | Parameters 127 | ---------- 128 | model: TrajectoryBalanceModel 129 | A GNN taking in a batch of graphs as input as per constructed by `self.construct_batch`. 130 | Must have a `logZ` attribute, itself a model, which predicts log of Z(cond_info) 131 | batch: gd.Batch 132 | batch of graphs inputs as per constructed by `self.construct_batch` 133 | num_bootstrap: int 134 | the number of trajectories for which the reward loss is computed. Ignored if 0.""" 135 | dev = batch.x.device 136 | # A single trajectory is comprised of many graphs 137 | num_trajs = int(batch.traj_lens.shape[0]) 138 | rewards = torch.exp(batch.log_rewards) 139 | cond_info = batch.cond_info 140 | 141 | # This index says which trajectory each graph belongs to, so 142 | # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is 143 | # of length 4, trajectory 1 of length 3, and so on. 144 | batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) 145 | # The position of the last graph of each trajectory 146 | final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1 147 | 148 | # Forward pass of the model, returns a GraphActionCategorical and per object predictions 149 | # Here we will interpret the logits of the fwd_cat as Q values 150 | Q, per_state_preds = model(batch, cond_info[batch_idx]) 151 | 152 | if self.do_q_prime_correction: 153 | # First we need to estimate V_soft. We will use q_a' = \pi 154 | log_policy = Q.logsoftmax() 155 | # in Eq (10) we have an expectation E_{a~q_a'}[exp(1/alpha Q(s,a))/q_a'(a)] 156 | # we rewrite the inner part `exp(a)/b` as `exp(a-log(b))` since we have the log_policy probabilities 157 | soft_expectation = [Q_sa / self.alpha - logprob for Q_sa, logprob in zip(Q.logits, log_policy)] 158 | # This allows us to more neatly just call logsumexp on the logits, and then multiply by alpha 159 | V_soft = self.alpha * Q.logsumexp(soft_expectation).detach() # shape: (num_graphs,) 160 | else: 161 | V_soft = Q.logsumexp(Q.logits).detach() 162 | rewards = rewards / self.alpha 163 | 164 | # Here were are again hijacking the GraphActionCategorical machinery to get Q[s,a], but 165 | # instead of logprobs we're just going to use the logits, i.e. the Q values. 166 | Q_sa = Q.log_prob(batch.actions, logprobs=Q.logits) 167 | 168 | # We now need to compute the target, \hat Q = R_t + V_soft(s_t+1) 169 | # Shift t+1-> t, pad last state with a 0, multiply by gamma 170 | shifted_V_soft = self.gamma * torch.cat([V_soft[1:], torch.zeros_like(V_soft[:1])]) 171 | # Replace V(s_T) with R(tau). 
Since we've shifted the values in the array, V(s_T) is V(s_0) 172 | # of the next trajectory in the array, and rewards are terminal (0 except at s_T). 173 | shifted_V_soft[final_graph_idx] = rewards + (1 - batch.is_valid) * self.invalid_penalty 174 | # The result is \hat Q = R_t + gamma V(s_t+1) 175 | hat_Q = shifted_V_soft 176 | 177 | losses = (Q_sa - hat_Q).pow(2) 178 | traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") 179 | loss = losses.mean() 180 | invalid_mask = 1 - batch.is_valid 181 | info = { 182 | "mean_loss": loss, 183 | "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, 184 | "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, 185 | "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, 186 | "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), 187 | } 188 | 189 | if not torch.isfinite(traj_losses).all(): 190 | raise ValueError("loss is not finite") 191 | return loss, info 192 | -------------------------------------------------------------------------------- /src/gflownet/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, fields, is_dataclass 2 | from typing import Optional 3 | 4 | from omegaconf import MISSING 5 | 6 | from gflownet.algo.config import AlgoConfig 7 | from gflownet.data.config import ReplayConfig 8 | from gflownet.models.config import ModelConfig 9 | from gflownet.tasks.config import TasksConfig 10 | from gflownet.utils.config import ConditionalsConfig 11 | from gflownet.utils.misc import StrictDataClass 12 | 13 | 14 | @dataclass 15 | class OptimizerConfig(StrictDataClass): 16 | """Generic configuration for optimizers 17 | 18 | Attributes 19 | ---------- 20 | opt : str 21 | The optimizer to use (either "adam" or "sgd") 22 | learning_rate : float 23 | The learning rate 24 | lr_decay : float 25 | The learning rate decay (in steps, f = 2 ** (-steps / self.cfg.opt.lr_decay)) 26 | weight_decay : float 27 | The L2 weight decay 28 | momentum : float 29 | The momentum parameter value 30 | clip_grad_type : str 31 | The type of gradient clipping to use (either "norm" or "value") 32 | clip_grad_param : float 33 | The parameter for gradient clipping 34 | adam_eps : float 35 | The epsilon parameter for Adam 36 | """ 37 | 38 | opt: str = "adam" 39 | learning_rate: float = 1e-4 40 | lr_decay: float = 20_000 41 | weight_decay: float = 1e-8 42 | momentum: float = 0.9 43 | clip_grad_type: str = "norm" 44 | clip_grad_param: float = 10.0 45 | adam_eps: float = 1e-8 46 | 47 | 48 | @dataclass 49 | class Config(StrictDataClass): 50 | """Base configuration for training 51 | 52 | Attributes 53 | ---------- 54 | desc : str 55 | A description of the experiment 56 | log_dir : str 57 | The directory where to store logs, checkpoints, and samples. 
58 | device : str 59 | The device to use for training (either "cpu" or "cuda[:<device_id>]") 60 | seed : int 61 | The random seed 62 | validate_every : int 63 | The number of training steps after which to validate the model 64 | checkpoint_every : Optional[int] 65 | The number of training steps after which to checkpoint the model 66 | store_all_checkpoints : bool 67 | Whether to store all checkpoints or only the last one 68 | print_every : int 69 | The number of training steps after which to print the training loss 70 | start_at_step : int 71 | The training step to start at (default: 0) 72 | num_final_gen_steps : Optional[int] 73 | After training, the number of steps to generate graphs for 74 | num_training_steps : int 75 | The number of training steps 76 | num_workers : int 77 | The number of workers to use for creating minibatches (0 = no multiprocessing) 78 | hostname : Optional[str] 79 | The hostname of the machine on which the experiment is run 80 | pickle_mp_messages : bool 81 | Whether to pickle messages sent between processes (only relevant if num_workers > 0) 82 | git_hash : Optional[str] 83 | The git hash of the current commit 84 | overwrite_existing_exp : bool 85 | Whether to overwrite the contents of the log_dir if it already exists 86 | """ 87 | 88 | desc: str = "noDesc" 89 | log_dir: str = MISSING 90 | device: str = "cuda" 91 | seed: int = 0 92 | validate_every: int = 1000 93 | checkpoint_every: Optional[int] = None 94 | store_all_checkpoints: bool = False 95 | print_every: int = 100 96 | start_at_step: int = 0 97 | num_final_gen_steps: Optional[int] = None 98 | num_validation_gen_steps: Optional[int] = None 99 | num_training_steps: int = 10_000 100 | num_workers: int = 0 101 | hostname: Optional[str] = None 102 | pickle_mp_messages: bool = False 103 | git_hash: Optional[str] = None 104 | overwrite_existing_exp: bool = False 105 | algo: AlgoConfig = field(default_factory=AlgoConfig) 106 | model: ModelConfig = field(default_factory=ModelConfig) 107 | opt: OptimizerConfig = field(default_factory=OptimizerConfig) 108 | replay: ReplayConfig = field(default_factory=ReplayConfig) 109 | task: TasksConfig = field(default_factory=TasksConfig) 110 | cond: ConditionalsConfig = field(default_factory=ConditionalsConfig) 111 | 112 | 113 | def init_empty(cfg): 114 | """ 115 | Initialize a dataclass instance with all fields set to MISSING, 116 | including nested dataclasses. 117 | 118 | This is meant to be used on the user side (tasks) to provide 119 | some configuration using the Config class while overwriting 120 | only the fields that have been set by the user.
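    For example (mirroring the usage in the wandb demo task):

        cfg = init_empty(Config())
        cfg.log_dir = "./logs"  # only log_dir is set; every other field remains MISSING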
121 | """ 122 | for f in fields(cfg): 123 | if is_dataclass(f.type): 124 | setattr(cfg, f.name, init_empty(f.type())) 125 | else: 126 | setattr(cfg, f.name, MISSING) 127 | 128 | return cfg 129 | -------------------------------------------------------------------------------- /src/gflownet/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/data/__init__.py -------------------------------------------------------------------------------- /src/gflownet/data/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from gflownet.utils.misc import StrictDataClass 5 | 6 | 7 | @dataclass 8 | class ReplayConfig(StrictDataClass): 9 | """Replay buffer configuration 10 | 11 | Attributes 12 | ---------- 13 | use : bool 14 | Whether to use a replay buffer 15 | capacity : int 16 | The capacity of the replay buffer 17 | warmup : int 18 | The number of samples to collect before starting to sample from the replay buffer 19 | hindsight_ratio : float 20 | The ratio of hindsight samples within a batch 21 | num_from_replay : Optional[int] 22 | The number of replayed samples for a training batch (defaults to cfg.algo.num_from_policy, i.e. a 50/50 split) 23 | num_new_samples : Optional[int] 24 | The number of new samples added to the replay at every training step. Defaults to cfg.algo.num_from_policy. If 25 | smaller than num_from_policy then not all on-policy samples will be added to the replay. If larger 26 | than num_from_policy then the training batch will not contain all the new samples, but the buffer will. 27 | For example, if one wishes to sample N samples every step but only add them to the buffer and not make them 28 | part of the training batch, then one should set replay.num_new_samples=N and algo.num_from_policy=0. 
29 | """ 30 | 31 | use: bool = False 32 | capacity: Optional[int] = None 33 | warmup: Optional[int] = None 34 | hindsight_ratio: float = 0 35 | num_from_replay: Optional[int] = None 36 | num_new_samples: Optional[int] = None 37 | -------------------------------------------------------------------------------- /src/gflownet/data/qm9.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tarfile 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import rdkit.Chem as Chem 7 | import torch 8 | from rdkit.Chem import QED, Descriptors 9 | from torch.utils.data import Dataset 10 | 11 | from gflownet.utils import sascore 12 | 13 | 14 | class QM9Dataset(Dataset): 15 | def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9): 16 | if h5_file is not None: 17 | self.hdf = pd.HDFStore(h5_file, "r") 18 | self.df = self.hdf["df"] 19 | self.is_hdf = True 20 | elif xyz_file is not None: 21 | self.df = load_tar(xyz_file) 22 | self.is_hdf = False 23 | else: 24 | raise ValueError("Either h5_file or xyz_file must be provided") 25 | rng = np.random.default_rng(split_seed) 26 | idcs = np.arange(len(self.df)) 27 | rng.shuffle(idcs) 28 | self.targets = targets 29 | if train: 30 | self.idcs = idcs[: int(np.floor(ratio * len(self.df)))] 31 | else: 32 | self.idcs = idcs[int(np.floor(ratio * len(self.df))) :] 33 | self.obj_to_graph = lambda x: x 34 | 35 | def setup(self, task, ctx): 36 | self.obj_to_graph = ctx.obj_to_graph 37 | 38 | def get_stats(self, target=None, percentile=0.95): 39 | if target is None: 40 | target = self.targets[0] 41 | y = self.df[target] 42 | return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)] 43 | 44 | def __len__(self): 45 | return len(self.idcs) 46 | 47 | def __getitem__(self, idx): 48 | return ( 49 | self.obj_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])), 50 | torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(), 51 | ) 52 | 53 | def terminate(self): 54 | if self.is_hdf: 55 | self.hdf.close() 56 | 57 | 58 | def load_tar(xyz_file): 59 | labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] 60 | f = tarfile.TarFile(xyz_file, "r") 61 | all_mols = [] 62 | for pt in f: 63 | pt = f.extractfile(pt) # type: ignore 64 | data = pt.read().decode().splitlines() # type: ignore 65 | all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) 66 | df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) 67 | mols = df["SMILES"].map(Chem.MolFromSmiles) 68 | df["qed"] = mols.map(QED.qed) 69 | df["sa"] = mols.map(sascore.calculateScore) 70 | df["mw"] = mols.map(Descriptors.MolWt) 71 | return df 72 | 73 | 74 | def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"): 75 | """ 76 | Convert `xyz_file` and dump it into `h5_file` 77 | """ 78 | # File obtained from 79 | # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 80 | # (from http://quantum-machine.org/datasets/) 81 | df = load_tar(xyz_file) 82 | with pd.HDFStore(h5_file, "w") as store: 83 | store["df"] = df 84 | 85 | 86 | if __name__ == "__main__": 87 | convert_h5(*sys.argv[1:]) 88 | -------------------------------------------------------------------------------- /src/gflownet/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from 
gflownet.config import Config 7 | from gflownet.utils.misc import get_worker_rng 8 | 9 | 10 | class ReplayBuffer(object): 11 | def __init__(self, cfg: Config): 12 | """ 13 | Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) 14 | In self.push(), the buffer detaches any torch tensor and sends it to the CPU. 15 | """ 16 | self.capacity = cfg.replay.capacity 17 | self.warmup = cfg.replay.warmup 18 | assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" 19 | 20 | self.buffer: List[tuple] = [] 21 | self.position = 0 22 | 23 | def push(self, *args): 24 | if len(self.buffer) == 0: 25 | self._input_size = len(args) 26 | else: 27 | assert self._input_size == len(args), "ReplayBuffer input size must be constant" 28 | if len(self.buffer) < self.capacity: 29 | self.buffer.append(None) 30 | args = detach_and_cpu(args) 31 | self.buffer[self.position] = args 32 | self.position = (self.position + 1) % self.capacity 33 | 34 | def sample(self, batch_size): 35 | idxs = get_worker_rng().choice(len(self.buffer), batch_size) 36 | out = list(zip(*[self.buffer[idx] for idx in idxs])) 37 | for i in range(len(out)): 38 | # stack if all elements are numpy arrays or torch tensors 39 | # (this is much more efficient to send arrays through multiprocessing queues) 40 | if all([isinstance(x, np.ndarray) for x in out[i]]): 41 | out[i] = np.stack(out[i], axis=0) 42 | elif all([isinstance(x, torch.Tensor) for x in out[i]]): 43 | out[i] = torch.stack(out[i], dim=0) 44 | else: 45 | out[i] = list(out[i]) 46 | return out 47 | 48 | def __len__(self): 49 | return len(self.buffer) 50 | 51 | 52 | def detach_and_cpu(x): 53 | if isinstance(x, torch.Tensor): 54 | x = x.detach().cpu() 55 | elif isinstance(x, dict): 56 | x = {k: detach_and_cpu(v) for k, v in x.items()} 57 | elif isinstance(x, list): 58 | x = [detach_and_cpu(v) for v in x] 59 | elif isinstance(x, tuple): 60 | x = tuple(detach_and_cpu(v) for v in x) 61 | return x 62 | -------------------------------------------------------------------------------- /src/gflownet/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/envs/__init__.py -------------------------------------------------------------------------------- /src/gflownet/envs/seq_building_env.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch_geometric.data import Data 7 | 8 | from gflownet.envs.graph_building_env import ( 9 | ActionIndex, 10 | Graph, 11 | GraphAction, 12 | GraphActionType, 13 | GraphBuildingEnv, 14 | GraphBuildingEnvContext, 15 | ) 16 | 17 | 18 | # For typing's sake, we'll pretend that a sequence is a graph. 
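# For example (illustrative): a Seq s with s.seq == ["A", "B"] prints as "AB", and
# s.nodes is simply s.seq, so graph-generic code that inspects nodes keeps working.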
19 | class Seq(Graph): 20 | def __init__(self): 21 | self.seq: list[Any] = [] 22 | 23 | def __repr__(self): 24 | return "".join(map(str, self.seq)) 25 | 26 | @property 27 | def nodes(self): 28 | return self.seq 29 | 30 | 31 | class SeqBuildingEnv(GraphBuildingEnv): 32 | """This class masquerades as a GraphBuildingEnv, but actually generates sequences of tokens.""" 33 | 34 | def __init__(self, variant): 35 | super().__init__() 36 | 37 | def new(self): 38 | return Seq() 39 | 40 | def step(self, g: Graph, a: GraphAction): 41 | s: Seq = deepcopy(g) # type: ignore 42 | if a.action == GraphActionType.AddNode: 43 | s.seq.append(a.value) 44 | return s 45 | 46 | def count_backward_transitions(self, g: Graph, check_idempotent: bool = False): 47 | return 1 48 | 49 | def parents(self, g: Graph): 50 | s: Seq = deepcopy(g) # type: ignore 51 | if not len(s.seq): 52 | return [] 53 | v = s.seq.pop() 54 | return [(GraphAction(GraphActionType.AddNode, value=v), s)] 55 | 56 | def reverse(self, g: Graph, ga: GraphAction): 57 | # TODO: if we implement non-LR variants we'll need to do something here 58 | return GraphAction(GraphActionType.Stop) 59 | 60 | 61 | class SeqBatch: 62 | def __init__(self, seqs: List[torch.Tensor], pad: int): 63 | self.seqs = seqs 64 | self.x = pad_sequence(seqs, batch_first=False, padding_value=pad) 65 | self.mask = self.x.eq(pad).T 66 | self.lens = torch.tensor([len(i) for i in seqs], dtype=torch.long) 67 | # This tells where (in the flattened array of outputs) the non-masked outputs are. 68 | # E.g. if the batch is [["ABC", "VWXYZ"]], logit_idx would be [0, 1, 2, 5, 6, 7, 8, 9] 69 | self.logit_idx = self.x.ne(pad).T.flatten().nonzero().flatten() 70 | # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this 71 | # is the total number of timesteps. 72 | self.num_graphs = self.lens.sum().item() 73 | 74 | def to(self, device): 75 | for name in dir(self): 76 | x = getattr(self, name) 77 | if isinstance(x, torch.Tensor): 78 | setattr(self, name, x.to(device)) 79 | return self 80 | 81 | 82 | class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext): 83 | """This class masquerades as a GraphBuildingEnvContext, but actually generates sequences of tokens. 84 | 85 | This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion. 
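    For example (an illustrative alphabet), with alphabet=("A", "B", "C"): num_tokens == 5
    (3 letters + BOS + PAD) and num_actions == 4 (one AddNode action per letter, plus Stop).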
86 | """ 87 | 88 | def __init__(self, alphabet: Sequence[str], num_cond_dim=0): 89 | self.alphabet = alphabet 90 | self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode] 91 | 92 | self.num_tokens = len(alphabet) + 2 # Alphabet + BOS + PAD 93 | self.bos_token = len(alphabet) 94 | self.pad_token = len(alphabet) + 1 95 | self.num_actions = len(alphabet) + 1 # Alphabet + Stop 96 | self.num_cond_dim = num_cond_dim 97 | 98 | def ActionIndex_to_GraphAction(self, g: Data, aidx: ActionIndex, fwd: bool = True) -> GraphAction: 99 | # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0 100 | t = self.action_type_order[aidx.action_type] 101 | if t is GraphActionType.Stop: 102 | return GraphAction(t) 103 | elif t is GraphActionType.AddNode: 104 | return GraphAction(t, value=aidx.col_idx) 105 | raise ValueError(aidx) 106 | 107 | def GraphAction_to_ActionIndex(self, g: Data, action: GraphAction) -> ActionIndex: 108 | if action.action is GraphActionType.Stop: 109 | col = 0 110 | type_idx = self.action_type_order.index(action.action) 111 | elif action.action is GraphActionType.AddNode: 112 | col = action.value 113 | type_idx = self.action_type_order.index(action.action) 114 | else: 115 | raise ValueError(action) 116 | return ActionIndex(action_type=type_idx, row_idx=0, col_idx=int(col)) 117 | 118 | def graph_to_Data(self, g: Graph): 119 | s: Seq = g # type: ignore 120 | return torch.tensor([self.bos_token] + s.seq, dtype=torch.long) 121 | 122 | def collate(self, graphs: List[Data]): 123 | return SeqBatch(graphs, pad=self.pad_token) 124 | 125 | def is_sane(self, g: Graph) -> bool: 126 | return True 127 | 128 | def graph_to_obj(self, g: Graph): 129 | s: Seq = g # type: ignore 130 | return "".join(self.alphabet[int(i)] for i in s.seq) 131 | 132 | def object_to_log_repr(self, g: Graph): 133 | return self.graph_to_obj(g) 134 | -------------------------------------------------------------------------------- /src/gflownet/envs/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | This test shows an example of how to setup a model and environment. 3 | It trains a model to overfit generating one single molecule. 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch_geometric.data as gd 9 | import torch_geometric.nn as gnn 10 | from tqdm import tqdm 11 | 12 | from gflownet.envs.graph_building_env import ( 13 | GraphActionCategorical, 14 | GraphActionType, 15 | GraphBuildingEnv, 16 | generate_forward_trajectory, 17 | ) 18 | from gflownet.envs.mol_building_env import MolBuildingEnvContext 19 | 20 | 21 | class Model(nn.Module): 22 | def __init__(self, env_ctx, num_emb=64): 23 | super().__init__() 24 | self.x2h = nn.Linear(env_ctx.num_node_dim, num_emb) 25 | self.e2h = nn.Linear(env_ctx.num_edge_dim, num_emb) 26 | self.graph2emb = nn.ModuleList( 27 | sum( 28 | [ 29 | [ 30 | gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add"), 31 | gnn.TransformerConv(num_emb, num_emb, edge_dim=num_emb), 32 | ] 33 | for i in range(6) 34 | ], 35 | [], 36 | ) 37 | ) 38 | 39 | def h2l(nl): 40 | return nn.Sequential(nn.Linear(num_emb, num_emb), nn.LeakyReLU(), nn.Linear(num_emb, nl)) 41 | 42 | self.emb2add_edge = h2l(1) 43 | self.emb2add_node = h2l(env_ctx.num_new_node_values) 44 | self.emb2add_node_attr = h2l(env_ctx.num_node_attr_logits) 45 | self.emb2add_edge_attr = h2l(env_ctx.num_edge_attr_logits) 46 | self.emb2stop = h2l(1) 47 | self.action_type_order = [ 48 | GraphActionType.Stop, 49 | GraphActionType.AddNode, 50 | GraphActionType.SetNodeAttr, 51 | GraphActionType.AddEdge, 52 | GraphActionType.SetEdgeAttr, 53 | ] 54 | 55 | def forward(self, g: gd.Batch): 56 | o = self.x2h(g.x) 57 | e = self.e2h(g.edge_attr) 58 | for layer in self.graph2emb: 59 | o = o + layer(o, g.edge_index, e) 60 | glob = gnn.global_mean_pool(o, g.batch) 61 | ne_row, ne_col = g.non_edge_index 62 | # On `::2`, edges are duplicated to make graphs undirected, only take the even ones 63 | e_row, e_col = g.edge_index[:, ::2] 64 | cat = GraphActionCategorical( 65 | g, 66 | raw_logits=[ 67 | self.emb2stop(glob), 68 | self.emb2add_node(o), 69 | self.emb2add_node_attr(o), 70 | self.emb2add_edge(o[ne_row] + o[ne_col]), 71 | self.emb2add_edge_attr(o[e_row] + o[e_col]), 72 | ], 73 | keys=[None, "x", "x", "non_edge_index", "edge_index"], 74 | types=self.action_type_order, 75 | ) 76 | return cat 77 | 78 | 79 | def main(smi, n_steps): 80 | """This trains a model to overfit producing a molecule, runs a 81 | generative episode and tests whether the model has successfully 82 | generated that molecule 83 | 84 | """ 85 | import networkx as nx 86 | import numpy as np 87 | from rdkit import Chem 88 | 89 | np.random.seed(123) 90 | env = GraphBuildingEnv() 91 | ctx = MolBuildingEnvContext() 92 | model = Model(ctx, num_emb=64) 93 | opt = torch.optim.Adam(model.parameters(), 5e-4) 94 | mol = Chem.MolFromSmiles(smi) 95 | molg = ctx.obj_to_graph(mol) 96 | traj = generate_forward_trajectory(molg) 97 | for g, a in traj: 98 | print(a.action, a.source, a.target, a.value) 99 | graphs = [ctx.graph_to_Data(i) for i, _ in traj] 100 | traj_batch = ctx.collate(graphs) 101 | actions = [ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(graphs, [i[1] for i in traj])] 102 | 103 | # Train to overfit 104 | for i in tqdm(range(n_steps)): 105 | fwd_cat = model(traj_batch) 106 | logprob = fwd_cat.log_prob(actions) 107 | loss = -logprob.mean() 108 | if not i % 100: 109 | print(fwd_cat.logits) 110 | print(logprob.exp()) 111 | print(loss) 112 | loss.backward() 113 | opt.step() 114 | opt.zero_grad() 115 | print() 116 | # Generation episode 117 | model.eval() 118 | g = env.new() 119 | for t in range(100): 120 | tg = ctx.graph_to_Data(g) 121 | with torch.no_grad(): 122 | 
fwd_cat = model(ctx.collate([tg])) 123 | fwd_cat.logsoftmax() 124 | print("stop:", fwd_cat.logprobs[0].exp()) 125 | action = fwd_cat.sample()[0] 126 | print("action prob:", fwd_cat.log_prob([action]).exp()) 127 | if fwd_cat.log_prob([action]).exp().item() < 0.2: 128 | # This test should work but obviously it's not perfect, 129 | # some probability is left on unlikely (wrong) steps 130 | print("oops, starting step over") 131 | continue 132 | graph_action = ctx.ActionIndex_to_GraphAction(tg, action) 133 | print(graph_action.action, graph_action.source, graph_action.target, graph_action.value) 134 | if graph_action.action is GraphActionType.Stop: 135 | break 136 | g = env.step(g, graph_action) 137 | # Make sure the subgraph is isomorphic to the target molecule 138 | issub = nx.algorithms.isomorphism.GraphMatcher(molg, g).subgraph_is_monomorphic() 139 | print(issub) 140 | if not issub: 141 | raise ValueError() 142 | print(g) 143 | new_mol = ctx.graph_to_obj(g) 144 | print(Chem.MolToSmiles(new_mol)) 145 | # This should be True as well 146 | print(new_mol.HasSubstructMatch(mol) and mol.HasSubstructMatch(new_mol)) 147 | 148 | 149 | if __name__ == "__main__": 150 | # Simple mol 151 | main("C1N2C3C2C2C4OC12C34", 500) 152 | # More complicated mol 153 | # main("O=C(NC1=CC=2NC(=NC2C=C1)C=3C=CC=CC3)C4=NN(C=C4N(=O)=O)C", 2000) 154 | -------------------------------------------------------------------------------- /src/gflownet/hyperopt/wandb_demo/README.md: -------------------------------------------------------------------------------- 1 | Everything is contained in one file; `init_wandb_sweep.py` both defines the search space of the sweep and is the entrypoint of wandb agents. 2 | 3 | To launch the search: 4 | 1. `python init_wandb_sweep.py` to initialize the sweep 5 | 2. `sbatch launch_wandb_agents.sh <SWEEP_ID>` to schedule a job array in Slurm which will launch wandb agents. 6 | The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being swept.
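For example, if step 1 prints a (hypothetical) sweep id `abc123`, step 2 becomes `sbatch launch_wandb_agents.sh abc123`.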
7 | -------------------------------------------------------------------------------- /src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import wandb 6 | 7 | from gflownet.config import Config, init_empty 8 | from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer 9 | 10 | TIME = time.strftime("%m-%d-%H-%M") 11 | ENTITY = "valencelabs" 12 | PROJECT = "gflownet" 13 | SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay" 14 | STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}" 15 | 16 | 17 | # Define the search space of the sweep 18 | sweep_config = { 19 | "name": SWEEP_NAME, 20 | "program": "init_wandb_sweep.py", 21 | "controller": { 22 | "type": "cloud", 23 | }, 24 | "method": "grid", 25 | "parameters": { 26 | "config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]}, 27 | "config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]}, 28 | }, 29 | } 30 | 31 | 32 | def wandb_config_merger(): 33 | config = init_empty(Config()) 34 | wandb_config = wandb.config 35 | 36 | # Set desired config values 37 | config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}" 38 | config.print_every = 100 39 | config.validate_every = 1000 40 | config.num_final_gen_steps = 1000 41 | config.num_training_steps = 40_000 42 | config.pickle_mp_messages = True 43 | config.overwrite_existing_exp = False 44 | config.algo.sampling_tau = 0.95 45 | config.algo.train_random_action_prob = 0.01 46 | config.algo.tb.Z_learning_rate = 1e-3 47 | config.task.seh_moo.objectives = ["seh", "qed"] 48 | config.cond.temperature.sample_dist = "constant" 49 | config.cond.temperature.dist_params = [60.0] 50 | config.cond.weighted_prefs.preference_type = "dirichlet" 51 | config.cond.focus_region.focus_type = None 52 | config.replay.use = False 53 | 54 | # Merge the wandb sweep config with the nested config from gflownet 55 | config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"] 56 | config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"] 57 | 58 | return config 59 | 60 | 61 | if __name__ == "__main__": 62 | # if there are no arguments, initialize the sweep; otherwise this is a wandb agent 63 | if len(sys.argv) == 1: 64 | if os.path.exists(STORAGE_DIR): 65 | raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.") 66 | 67 | wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT) 68 | 69 | else: 70 | wandb.init(entity=ENTITY, project=PROJECT) 71 | config = wandb_config_merger() 72 | trial = SEHMOOFragTrainer(config) 73 | trial.run() 74 | -------------------------------------------------------------------------------- /src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Purpose: Script to allocate a node and run a wandb sweep agent on it 4 | # Usage: sbatch launch_wandb_agents.sh <SWEEP_ID> 5 | 6 | #SBATCH --job-name=wandb_sweep_agent 7 | #SBATCH --array=1-6 8 | #SBATCH --time=23:59:00 9 | #SBATCH --output=slurm_output_files/%x_%N_%A_%a.out 10 | #SBATCH --gpus=1 11 | #SBATCH --cpus-per-task=16 12 | #SBATCH --mem=16GB 13 | #SBATCH --partition compute 14 | 15 | source activate gfn-py39-torch113 16 | echo "Using environment={$CONDA_DEFAULT_ENV}" 17 | 18 | # launch wandb agent 19 | wandb agent --count 1 --entity valencelabs --project gflownet $1 20 | --------------------------------------------------------------------------------
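A quick consistency check between the two files above (a sketch; the values come from `sweep_config` and the `#SBATCH --array` line):

    >>> len([1e-4, 1e-3, 1e-2]) * len([2_000, 50_000])  # grid points in the sweep
    6

so `--array=1-6` schedules exactly one agent (each running `--count 1`) per grid point.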
/src/gflownet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/models/__init__.py -------------------------------------------------------------------------------- /src/gflownet/models/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | 4 | from gflownet.utils.misc import StrictDataClass 5 | 6 | 7 | @dataclass 8 | class GraphTransformerConfig(StrictDataClass): 9 | num_heads: int = 2 10 | ln_type: str = "pre" 11 | num_mlp_layers: int = 0 12 | concat_heads: bool = True 13 | 14 | 15 | class SeqPosEnc(int, Enum): 16 | Pos = 0 17 | Rotary = 1 18 | 19 | 20 | @dataclass 21 | class SeqTransformerConfig(StrictDataClass): 22 | num_heads: int = 2 23 | posenc: SeqPosEnc = SeqPosEnc.Rotary 24 | 25 | 26 | @dataclass 27 | class ModelConfig(StrictDataClass): 28 | """Generic configuration for models 29 | 30 | Attributes 31 | ---------- 32 | num_layers : int 33 | The number of layers in the model 34 | num_emb : int 35 | The number of dimensions of the embedding 36 | """ 37 | 38 | num_layers: int = 3 39 | num_emb: int = 128 40 | dropout: float = 0 41 | graph_transformer: GraphTransformerConfig = field(default_factory=GraphTransformerConfig) 42 | seq_transformer: SeqTransformerConfig = field(default_factory=SeqTransformerConfig) 43 | -------------------------------------------------------------------------------- /src/gflownet/models/seq_transformer.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from https://github.com/MJ10/mo_gfn 2 | import math 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from gflownet.config import Config 9 | from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext 10 | from gflownet.envs.seq_building_env import SeqBatch 11 | from gflownet.models.config import SeqPosEnc 12 | 13 | 14 | class MLPWithDropout(nn.Module): 15 | def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False): 16 | super(MLPWithDropout, self).__init__() 17 | self.in_dim = in_dim 18 | self.out_dim = out_dim 19 | layers = [nn.Linear(in_dim, hidden_layers[0]), nn.ReLU()] 20 | layers += [nn.Dropout(dropout_prob)] if init_drop else [] 21 | for i in range(1, len(hidden_layers)): 22 | layers.extend([nn.Linear(hidden_layers[i - 1], hidden_layers[i]), nn.ReLU(), nn.Dropout(dropout_prob)]) 23 | layers.append(nn.Linear(hidden_layers[-1], out_dim)) 24 | self.model = nn.Sequential(*layers) 25 | 26 | def forward(self, x): 27 | return self.model(x) 28 | 29 | 30 | class SeqTransformerGFN(nn.Module): 31 | """A standard transformer-encoder based GFN model for sequences.""" 32 | 33 | ctx: GraphBuildingEnvContext 34 | 35 | def __init__( 36 | self, 37 | env_ctx, 38 | cfg: Config, 39 | num_state_out=1, 40 | ): 41 | super().__init__() 42 | self.ctx = env_ctx 43 | self.num_state_out = num_state_out 44 | num_hid = cfg.model.num_emb 45 | num_outs = env_ctx.num_actions + num_state_out 46 | mc = cfg.model 47 | if mc.seq_transformer.posenc == SeqPosEnc.Pos: 48 | self.pos = PositionalEncoding(num_hid, dropout=cfg.model.dropout, max_len=cfg.algo.max_len + 2) 49 | elif mc.seq_transformer.posenc == SeqPosEnc.Rotary: 50 | self.pos = RotaryEmbedding(num_hid) 51 | self.use_cond = env_ctx.num_cond_dim > 0 52 | 
self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid) 53 | encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout) 54 | self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers) 55 | self._logZ = nn.Linear(env_ctx.num_cond_dim, 1) 56 | if self.use_cond: 57 | self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout) 58 | self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid) 59 | else: 60 | self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout) 61 | self.num_hid = num_hid 62 | 63 | def logZ(self, cond_info: Optional[torch.Tensor]): 64 | if cond_info is None: 65 | return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) 66 | return self._logZ(cond_info) 67 | 68 | def forward(self, xs: SeqBatch, cond, batched=False): 69 | """Returns a GraphActionCategorical and a tensor of state predictions. 70 | 71 | Parameters 72 | ---------- 73 | xs: SeqBatch 74 | A batch of sequences. 75 | cond: torch.Tensor 76 | A tensor of conditional information. 77 | batched: bool 78 | If True, it's assumed that the cond tensor is constant along a sequence, and the output is given 79 | at each timestep (of the autoregressive process), which works because we are using causal self-attention. 80 | If False, only the last timestep's output is returned, which one would use to sample the next token.""" 81 | x = self.embedding(xs.x) 82 | x = self.pos(x) # (time, batch, nemb) 83 | x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) 84 | pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) 85 | 86 | if self.use_cond: 87 | cond_var = self.cond_embed(cond) # (batch, nemb) 88 | cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var 89 | final_rep = torch.cat((x, cond_var), axis=-1) if batched else torch.cat((pooled_x, cond_var), axis=-1) 90 | else: 91 | final_rep = x if batched else pooled_x 92 | 93 | out: torch.Tensor = self.output(final_rep) 94 | ns = self.num_state_out 95 | if batched: 96 | # out is (time, batch, nout) 97 | out = out.transpose(1, 0).contiguous().reshape((-1, out.shape[2])) # (batch * time, nout) 98 | # logit_idx tells us where (in the flattened array of outputs) the non-masked outputs are. 99 | # E.g.
if the batch is [["ABC", "VWXYZ"]], logit_idx would be [0, 1, 2, 5, 6, 7, 8, 9] 100 | state_preds = out[xs.logit_idx, 0:ns] # (proper_time, num_state_out) 101 | stop_logits = out[xs.logit_idx, ns : ns + 1] # (proper_time, 1) 102 | add_node_logits = out[xs.logit_idx, ns + 1 :] # (proper_time, nout - 1) 103 | # `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs)) 104 | # which is what we need to give to GraphActionCategorical 105 | else: 106 | # The default num_graphs is computed for the batched case, so we need to fix it here so that 107 | # GraphActionCategorical knows how many "graphs" (sequence inputs) there are 108 | xs.num_graphs = out.shape[0] 109 | # out is (batch, nout) 110 | state_preds = out[:, 0:ns] 111 | stop_logits = out[:, ns : ns + 1] 112 | add_node_logits = out[:, ns + 1 :] 113 | 114 | return ( 115 | GraphActionCategorical( 116 | xs, 117 | raw_logits=[stop_logits, add_node_logits], 118 | keys=[None, None], 119 | types=self.ctx.action_type_order, 120 | slice_dict={}, 121 | ), 122 | state_preds, 123 | ) 124 | 125 | 126 | def generate_square_subsequent_mask(sz: int): 127 | """Generates an upper-triangular matrix of -inf, with zeros on diag.""" 128 | return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) 129 | 130 | 131 | class PositionalEncoding(nn.Module): 132 | def __init__(self, d_model, dropout=0.1, max_len=5000): 133 | super(PositionalEncoding, self).__init__() 134 | self.dropout = nn.Dropout(p=dropout) 135 | pe = torch.zeros(max_len, d_model) 136 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 137 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 138 | pe[:, 0::2] = torch.sin(position * div_term) 139 | pe[:, 1::2] = torch.cos(position * div_term) 140 | pe = pe.unsqueeze(0).transpose(0, 1) 141 | self.register_buffer("pe", pe) 142 | 143 | def forward(self, x): 144 | x = x + self.pe[: x.size(0), :] 145 | return self.dropout(x) 146 | 147 | 148 | # This is adapted from https://github.com/lucidrains/x-transformers 149 | class RotaryEmbedding(nn.Module): 150 | def __init__(self, dim, interpolation_factor=1.0, base=10000, base_rescale_factor=1.0): 151 | super().__init__() 152 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 153 | # has some connection to NTK literature 154 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 155 | base *= base_rescale_factor ** (dim / (dim - 2)) 156 | 157 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 158 | self.register_buffer("inv_freq", inv_freq) 159 | 160 | assert interpolation_factor >= 1.0 161 | self.interpolation_factor = interpolation_factor 162 | 163 | def get_emb(self, seq_len, device): 164 | t = torch.arange(seq_len, device=device).type_as(self.inv_freq) 165 | t = t / self.interpolation_factor 166 | 167 | freqs = torch.einsum("i , j -> i j", t, self.inv_freq) 168 | freqs = torch.cat((freqs, freqs), dim=-1) 169 | 170 | return freqs 171 | 172 | def forward(self, x, scale=1): 173 | x1, x2 = x.reshape(x.shape[:-1] + (2, -1)).unbind(dim=-2) 174 | xrot = torch.cat((-x2, x1), dim=-1) 175 | freqs = self.get_emb(x.shape[0], x.device)[:, None, :] 176 | return (x * freqs.cos() * scale) + (xrot * freqs.sin() * scale) 177 | -------------------------------------------------------------------------------- /src/gflownet/online_trainer.py: 
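[Annotation, not part of the repository] A detail worth calling out in the next file: the learning-rate schedules are exponential in the step count, 2 ** (-steps / lr_decay), i.e. the learning rate halves every lr_decay steps, and the logZ parameters get their own optimizer and schedule so they can decay at a different rate than the rest of the model. A minimal sketch of that schedule in isolation, with hypothetical values:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda steps: 2 ** (-steps / 20_000))
# After 20k calls to sched.step() the effective LR is 5e-5, after 40k it is 2.5e-5, etc.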
-------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import pathlib 4 | 5 | import git 6 | import torch 7 | from omegaconf import OmegaConf 8 | from torch import Tensor 9 | 10 | from gflownet.algo.advantage_actor_critic import A2C 11 | from gflownet.algo.flow_matching import FlowMatching 12 | from gflownet.algo.soft_q_learning import SoftQLearning 13 | from gflownet.algo.trajectory_balance import TrajectoryBalance 14 | from gflownet.data.replay_buffer import ReplayBuffer 15 | from gflownet.models.graph_transformer import GraphTransformerGFN 16 | 17 | from .trainer import GFNTrainer 18 | 19 | 20 | def model_grad_norm(model): 21 | x = 0 22 | for i in model.parameters(): 23 | if i.grad is not None: 24 | x += (i.grad * i.grad).sum() 25 | return torch.sqrt(x) 26 | 27 | 28 | class StandardOnlineTrainer(GFNTrainer): 29 | def setup_model(self): 30 | self.model = GraphTransformerGFN( 31 | self.ctx, 32 | self.cfg, 33 | do_bck=self.cfg.algo.tb.do_parameterize_p_b, 34 | num_graph_out=self.cfg.algo.tb.do_predict_n + 1, 35 | ) 36 | 37 | def setup_algo(self): 38 | algo = self.cfg.algo.method 39 | if algo == "TB": 40 | algo = TrajectoryBalance 41 | elif algo == "FM": 42 | algo = FlowMatching 43 | elif algo == "A2C": 44 | algo = A2C 45 | elif algo == "SQL": 46 | algo = SoftQLearning 47 | else: 48 | raise ValueError(algo) 49 | self.algo = algo(self.env, self.ctx, self.cfg) 50 | 51 | def setup_data(self): 52 | self.training_data = [] 53 | self.test_data = [] 54 | 55 | def _opt(self, params, lr=None, momentum=None): 56 | if lr is None: 57 | lr = self.cfg.opt.learning_rate 58 | if momentum is None: 59 | momentum = self.cfg.opt.momentum 60 | if self.cfg.opt.opt == "adam": 61 | return torch.optim.Adam( 62 | params, 63 | lr, 64 | (momentum, 0.999), 65 | weight_decay=self.cfg.opt.weight_decay, 66 | eps=self.cfg.opt.adam_eps, 67 | ) 68 | 69 | raise NotImplementedError(f"{self.cfg.opt.opt} is not implemented") 70 | 71 | def setup(self): 72 | super().setup() 73 | self.offline_ratio = 0 74 | self.replay_buffer = ReplayBuffer(self.cfg) if self.cfg.replay.use else None 75 | self.sampling_hooks.append(AvgRewardHook()) 76 | self.valid_sampling_hooks.append(AvgRewardHook()) 77 | 78 | # Separate Z parameters from non-Z to allow for LR decay on the former 79 | if hasattr(self.model, "_logZ"): 80 | Z_params = list(self.model._logZ.parameters()) 81 | non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] 82 | else: 83 | Z_params = [] 84 | non_Z_params = list(self.model.parameters()) 85 | self.opt = self._opt(non_Z_params) 86 | self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) 87 | self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) 88 | self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( 89 | self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) 90 | ) 91 | 92 | self.sampling_tau = self.cfg.algo.sampling_tau 93 | if self.sampling_tau > 0: 94 | self.sampling_model = copy.deepcopy(self.model) 95 | else: 96 | self.sampling_model = self.model 97 | 98 | self.clip_grad_callback = { 99 | "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), 100 | "norm": lambda params: [torch.nn.utils.clip_grad_norm_(p, self.cfg.opt.clip_grad_param) for p in params], 101 | "total_norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), 102 | "none": lambda x: None, 103 | 
}[self.cfg.opt.clip_grad_type] 104 | 105 | # saving hyperparameters 106 | try: 107 | self.cfg.git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] 108 | except git.InvalidGitRepositoryError: 109 | self.cfg.git_hash = "unknown" # May not have been installed through git 110 | 111 | yaml_cfg = OmegaConf.to_yaml(self.cfg) 112 | if self.print_config: 113 | print("\n\nHyperparameters:\n") 114 | print(yaml_cfg) 115 | os.makedirs(self.cfg.log_dir, exist_ok=True) 116 | with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: 117 | f.write(yaml_cfg) 118 | 119 | def step(self, loss: Tensor): 120 | loss.backward() 121 | with torch.no_grad(): 122 | g0 = model_grad_norm(self.model) 123 | self.clip_grad_callback(self.model.parameters()) 124 | g1 = model_grad_norm(self.model) 125 | self.opt.step() 126 | self.opt.zero_grad() 127 | self.opt_Z.step() 128 | self.opt_Z.zero_grad() 129 | self.lr_sched.step() 130 | self.lr_sched_Z.step() 131 | if self.sampling_tau > 0: 132 | for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): 133 | b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) 134 | return {"grad_norm": g0, "grad_norm_clip": g1} 135 | 136 | 137 | class AvgRewardHook: 138 | def __call__(self, trajs, rewards, obj_props, extra_info): 139 | return {"sampled_reward_avg": rewards.mean().item()} 140 | -------------------------------------------------------------------------------- /src/gflownet/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/tasks/__init__.py -------------------------------------------------------------------------------- /src/gflownet/tasks/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | from gflownet.utils.misc import StrictDataClass 5 | 6 | 7 | @dataclass 8 | class SEHTaskConfig(StrictDataClass): 9 | reduced_frag: bool = False 10 | 11 | 12 | @dataclass 13 | class SEHMOOTaskConfig(StrictDataClass): 14 | """Config for the SEHMOOTask 15 | 16 | Attributes 17 | ---------- 18 | n_valid : int 19 | The number of valid cond_info tensors to sample. 20 | n_valid_repeats : int 21 | The number of times to repeat the valid cond_info tensors. 22 | objectives : List[str] 23 | The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "mw"]. 24 | online_pareto_front : bool 25 | Whether to calculate the pareto front online. 26 | """ 27 | 28 | n_valid: int = 15 29 | n_valid_repeats: int = 128 30 | objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) 31 | log_topk: bool = False 32 | online_pareto_front: bool = True 33 | 34 | 35 | @dataclass 36 | class QM9TaskConfig(StrictDataClass): 37 | h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py 38 | model_path: str = "./data/qm9/qm9_model.pt" 39 | 40 | 41 | @dataclass 42 | class QM9MOOTaskConfig(StrictDataClass): 43 | """ 44 | Config for the QM9MooTask 45 | 46 | Attributes 47 | ---------- 48 | n_valid : int 49 | The number of valid cond_info tensors to sample. 50 | n_valid_repeats : int 51 | The number of times to repeat the valid cond_info tensors. 52 | objectives : List[str] 53 | The objectives to use for the multi-objective optimization. Should be a subset of ["gap", "qed", "sa", "mw"]. 
54 | While "mw" can be used, it is not recommended as the molecules are already small. 55 | online_pareto_front : bool 56 | Whether to calculate the pareto front online. 57 | """ 58 | 59 | n_valid: int = 15 60 | n_valid_repeats: int = 128 61 | objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) 62 | online_pareto_front: bool = True 63 | 64 | 65 | @dataclass 66 | class TasksConfig(StrictDataClass): 67 | qm9: QM9TaskConfig = field(default_factory=QM9TaskConfig) 68 | qm9_moo: QM9MOOTaskConfig = field(default_factory=QM9MOOTaskConfig) 69 | seh: SEHTaskConfig = field(default_factory=SEHTaskConfig) 70 | seh_moo: SEHMOOTaskConfig = field(default_factory=SEHMOOTaskConfig) 71 | -------------------------------------------------------------------------------- /src/gflownet/tasks/make_rings.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | from rdkit import Chem 6 | from rdkit.Chem.rdchem import Mol as RDMol 7 | from torch import Tensor 8 | 9 | from gflownet import GFNTask, LogScalar, ObjectProperties 10 | from gflownet.config import Config, init_empty 11 | from gflownet.envs.mol_building_env import MolBuildingEnvContext 12 | from gflownet.online_trainer import StandardOnlineTrainer 13 | 14 | 15 | class MakeRingsTask(GFNTask): 16 | """A toy task where the reward is the number of rings in the molecule.""" 17 | 18 | def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: 19 | return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)} 20 | 21 | def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: 22 | scalar_logreward = torch.as_tensor(obj_props).squeeze().clamp(min=1e-30).log() 23 | return LogScalar(scalar_logreward.flatten()) 24 | 25 | def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: 26 | rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float() 27 | return ObjectProperties(rs.reshape((-1, 1))), torch.ones(len(mols)).bool() 28 | 29 | 30 | class MakeRingsTrainer(StandardOnlineTrainer): 31 | def set_default_hps(self, cfg: Config): 32 | cfg.hostname = socket.gethostname() 33 | cfg.num_workers = 8 34 | cfg.algo.num_from_policy = 64 35 | cfg.model.num_emb = 128 36 | cfg.model.num_layers = 4 37 | 38 | cfg.algo.method = "TB" 39 | cfg.algo.max_nodes = 6 40 | cfg.algo.sampling_tau = 0.9 41 | cfg.algo.illegal_action_logreward = -75 42 | cfg.algo.train_random_action_prob = 0.0 43 | cfg.algo.valid_random_action_prob = 0.0 44 | cfg.algo.tb.do_parameterize_p_b = True 45 | 46 | cfg.replay.use = False 47 | 48 | def setup_task(self): 49 | self.task = MakeRingsTask() 50 | 51 | def setup_env_context(self): 52 | self.ctx = MolBuildingEnvContext( 53 | ["C"], 54 | charges=[0], # disable charge 55 | chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED], # disable chirality 56 | num_rw_feat=0, 57 | max_nodes=self.cfg.algo.max_nodes, 58 | num_cond_dim=1, 59 | ) 60 | 61 | 62 | def main(): 63 | """Example of how this model can be run.""" 64 | config = init_empty(Config()) 65 | config.print_every = 1 66 | config.log_dir = "./logs/debug_run_mr4" 67 | config.device = "cuda" 68 | config.num_training_steps = 10_000 69 | config.num_workers = 8 70 | config.algo.tb.do_parameterize_p_b = True 71 | 72 | trial = MakeRingsTrainer(config) 73 | trial.run() 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | 
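[Annotation, not part of the repository] A small, self-contained illustration of the reward-to-log-reward transform used by MakeRingsTask.cond_info_to_logreward above: ring counts are clamped away from zero before taking the log, so an object with zero rings gets a very negative but finite log-reward instead of -inf. The input values are illustrative only:

import torch

ring_counts = torch.tensor([[0.0], [1.0], [3.0]])  # shape as returned by compute_obj_properties
log_reward = ring_counts.squeeze().clamp(min=1e-30).log().flatten()
# -> tensor([-69.0776, 0.0000, 1.0986]); log(1e-30) ~ -69.08 keeps downstream losses finite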
-------------------------------------------------------------------------------- /src/gflownet/tasks/qm9.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch_geometric.data as gd 7 | from rdkit.Chem.rdchem import Mol as RDMol 8 | from torch import Tensor 9 | from torch.utils.data import Dataset 10 | 11 | import gflownet.models.mxmnet as mxmnet 12 | from gflownet import GFNTask, LogScalar, ObjectProperties 13 | from gflownet.config import Config, init_empty 14 | from gflownet.data.qm9 import QM9Dataset 15 | from gflownet.envs.mol_building_env import MolBuildingEnvContext 16 | from gflownet.online_trainer import StandardOnlineTrainer 17 | from gflownet.utils.conditioning import TemperatureConditional 18 | from gflownet.utils.misc import get_worker_device 19 | from gflownet.utils.transforms import to_logreward 20 | 21 | 22 | class QM9GapTask(GFNTask): 23 | """This class captures conditional information generation and reward transforms""" 24 | 25 | def __init__( 26 | self, 27 | dataset: Dataset, 28 | cfg: Config, 29 | wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, 30 | ): 31 | self._wrap_model = wrap_model if wrap_model is not None else (lambda x: x) 32 | self.device = get_worker_device() 33 | self.models = self.load_task_models(cfg.task.qm9.model_path) 34 | self.dataset = dataset 35 | self.temperature_conditional = TemperatureConditional(cfg) 36 | self.num_cond_dim = self.temperature_conditional.encoding_size() 37 | # TODO: fix interface 38 | self._min, self._max, self._percentile_95 = self.dataset.get_stats("gap", percentile=0.05) # type: ignore 39 | self._width = self._max - self._min 40 | self._rtrans = "unit+95p" # TODO: hyperparameter 41 | 42 | def reward_transform(self, y: Union[float, Tensor]) -> ObjectProperties: 43 | """Transforms a target quantity y (e.g. the LUMO energy in QM9) to a positive reward scalar""" 44 | if self._rtrans == "exp": 45 | flat_r = np.exp(-(y - self._min) / self._width) 46 | elif self._rtrans == "unit": 47 | flat_r = 1 - (y - self._min) / self._width 48 | elif self._rtrans == "unit+95p": 49 | # Add constant such that 5% of rewards are > 1 50 | flat_r = 1 - (y - self._percentile_95) / self._width 51 | else: 52 | raise ValueError(self._rtrans) 53 | return ObjectProperties(flat_r) 54 | 55 | def inverse_reward_transform(self, rp): 56 | if self._rtrans == "exp": 57 | return -np.log(rp) * self._width + self._min 58 | elif self._rtrans == "unit": 59 | return (1 - rp) * self._width + self._min 60 | elif self._rtrans == "unit+95p": 61 | return (1 - rp + (1 - self._percentile_95)) * self._width + self._min 62 | 63 | def load_task_models(self, path): 64 | gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) 65 | # TODO: this path should be part of the config?
66 | try: 67 | state_dict = torch.load(path, map_location=self.device) 68 | except Exception as e: 69 | print( 70 | "Could not load model.", 71 | e, 72 | "\nModel weights can be found at https://storage.valencelabs.com/gflownet/models/mxmnet_gap_model.pt", 73 | ) 74 | raise e 75 | gap_model.load_state_dict(state_dict) 76 | gap_model.to(self.device) 77 | gap_model = self._wrap_model(gap_model) 78 | return {"mxmnet_gap": gap_model} 79 | 80 | def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: 81 | return self.temperature_conditional.sample(n) 82 | 83 | def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LogScalar: 84 | return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) 85 | 86 | def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: 87 | batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) 88 | batch.to( 89 | self.models["mxmnet_gap"].device if hasattr(self.models["mxmnet_gap"], "device") else get_worker_device() 90 | ) 91 | preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] 92 | preds[preds.isnan()] = 1 93 | preds = ( 94 | self.reward_transform(preds) 95 | .clip(1e-4, 2) 96 | .reshape( 97 | -1, 98 | ) 99 | ) 100 | return preds 101 | 102 | def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: 103 | graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] 104 | is_valid = torch.tensor([i is not None for i in graphs]).bool() 105 | if not is_valid.any(): 106 | return ObjectProperties(torch.zeros((0, 1))), is_valid 107 | 108 | preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) 109 | assert len(preds) == is_valid.sum() 110 | return ObjectProperties(preds), is_valid 111 | 112 | 113 | class QM9GapTrainer(StandardOnlineTrainer): 114 | def set_default_hps(self, cfg: Config): 115 | cfg.num_workers = 8 116 | cfg.num_training_steps = 100000 117 | cfg.opt.learning_rate = 1e-4 118 | cfg.opt.weight_decay = 1e-8 119 | cfg.opt.momentum = 0.9 120 | cfg.opt.adam_eps = 1e-8 121 | cfg.opt.lr_decay = 20000 122 | cfg.opt.clip_grad_type = "norm" 123 | cfg.opt.clip_grad_param = 10 124 | cfg.algo.max_nodes = 9 125 | cfg.algo.num_from_policy = 32 126 | cfg.algo.num_from_dataset = 32 127 | cfg.algo.train_random_action_prob = 0.001 128 | cfg.algo.illegal_action_logreward = -75 129 | cfg.algo.sampling_tau = 0.0 130 | cfg.model.num_emb = 128 131 | cfg.model.num_layers = 4 132 | cfg.cond.temperature.sample_dist = "uniform" 133 | cfg.cond.temperature.dist_params = [0.5, 32.0] 134 | cfg.cond.temperature.num_thermometer_dim = 32 135 | 136 | def setup_env_context(self): 137 | self.ctx = MolBuildingEnvContext( 138 | ["C", "N", "F", "O"], 139 | expl_H_range=[0, 1, 2, 3], 140 | num_cond_dim=self.task.num_cond_dim, 141 | allow_5_valence_nitrogen=True, 142 | ) 143 | # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories 144 | # from the dataset. For example, consider the Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action 145 | # for setting the explicit hydrogen is used before the positive charge is set, it will be considered 146 | # an invalid action. However, generate_forward_trajectory does not consider this implementation detail; 147 | # it assumes that attribute-setting will always be valid.
For the molecular environment, as of writing 148 | # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. 149 | 150 | def setup_data(self): 151 | self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=["gap"]) 152 | self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=["gap"]) 153 | self.to_terminate.append(self.training_data.terminate) 154 | self.to_terminate.append(self.test_data.terminate) 155 | 156 | def setup_task(self): 157 | self.task = QM9GapTask( 158 | dataset=self.training_data, 159 | cfg=self.cfg, 160 | wrap_model=self._wrap_for_mp, 161 | ) 162 | 163 | def setup(self): 164 | super().setup() 165 | self.training_data.setup(self.task, self.ctx) 166 | self.test_data.setup(self.task, self.ctx) 167 | 168 | 169 | def main(): 170 | """Example of how this model can be run.""" 171 | config = init_empty(Config()) 172 | config.num_workers = 0 173 | config.num_training_steps = 100000 174 | config.validate_every = 100 175 | config.log_dir = "./logs/debug_qm9" 176 | config.opt.lr_decay = 10000 177 | config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5" 178 | config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt" 179 | 180 | trial = QM9GapTrainer(config) 181 | trial.run() 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /src/gflownet/tasks/seh_frag.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from typing import Callable, Dict, List, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch_geometric.data as gd 7 | from rdkit import Chem 8 | from rdkit.Chem.rdchem import Mol as RDMol 9 | from torch import Tensor 10 | from torch.utils.data import Dataset 11 | from torch_geometric.data import Data 12 | 13 | from gflownet import GFNTask, LogScalar, ObjectProperties 14 | from gflownet.config import Config, init_empty 15 | from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph 16 | from gflownet.models import bengio2021flow 17 | from gflownet.online_trainer import StandardOnlineTrainer 18 | from gflownet.utils.conditioning import TemperatureConditional 19 | from gflownet.utils.misc import get_worker_device 20 | from gflownet.utils.transforms import to_logreward 21 | 22 | 23 | class SEHTask(GFNTask): 24 | """Sets up a task where the reward is computed using a proxy for the binding energy of a molecule to 25 | Soluble Epoxide Hydrolases. 26 | 27 | The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. 28 | 29 | This setup essentially reproduces the results of the Trajectory Balance paper when using the TB 30 | objective, or of the original paper when using Flow Matching. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | cfg: Config, 36 | wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, 37 | ) -> None: 38 | self._wrap_model = wrap_model if wrap_model is not None else (lambda x: x) 39 | self.models = self._load_task_models() 40 | self.temperature_conditional = TemperatureConditional(cfg) 41 | self.num_cond_dim = self.temperature_conditional.encoding_size() 42 | 43 | def _load_task_models(self): 44 | model = bengio2021flow.load_original_model() 45 | model.to(get_worker_device()) 46 | model = self._wrap_model(model) 47 | return {"seh": model} 48 | 49 | def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: 50 | return self.temperature_conditional.sample(n) 51 | 52 | def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LogScalar: 53 | return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) 54 | 55 | def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: 56 | batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) 57 | batch.to(self.models["seh"].device if hasattr(self.models["seh"], "device") else get_worker_device()) 58 | preds = self.models["seh"](batch).reshape((-1,)).data.cpu() / 8 59 | preds[preds.isnan()] = 0 60 | return preds.clip(1e-4, 100).reshape((-1,)) 61 | 62 | def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: 63 | graphs = [bengio2021flow.mol2graph(i) for i in mols] 64 | is_valid = torch.tensor([i is not None for i in graphs]).bool() 65 | if not is_valid.any(): 66 | return ObjectProperties(torch.zeros((0, 1))), is_valid 67 | 68 | preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) 69 | assert len(preds) == is_valid.sum() 70 | return ObjectProperties(preds), is_valid 71 | 72 | 73 | SOME_MOLS = [ 74 | "O=C(NCc1cc(CCc2cccc(N3CCC(c4cc(-c5cc(-c6cncnc6)[nH]n5)ccn4)CC3)c2)ccn1)c1cccc2ccccc12", 75 | "O=c1nc2[nH]c3cc(-c4cc(C5CC(c6ccc(CNC7CCOC7c7csc(C8=CC(c9ccc%10ccccc%10c9)CCC8)n7)cc6)CO5)c[nH]4)ccc3nc-2c(=O)[nH]1", 76 | "c1ccc(-c2cnn(-c3cc(-c4cc(CCc5cc(C6CCC(c7cc(-c8ccccc8)[nH]n7)CO6)ncn5)n[nH]4)ccn3)c2)cc1", 77 | "O=C(NCc1cc(C2CCNC2C2CCNC2)ncn1)c1cccc(-c2cccc(-c3cccc(C4CCC(c5ccccc5)CO4)c3)c2)c1", 78 | "O=C(NCc1cccc(C2COC(c3ccc4nc5c(=O)[nH]c(=O)nc-5[nH]c4c3)C2)c1)c1cccc(CCc2cccc(-c3ncnc4c3ncn4C3CCCN3)c2)c1", 79 | "O=C(NCc1ccc(OCc2ccc(-c3ccncc3C3CCNCC3)cn2)cc1)c1cccc(N2CCC(C3CCCN3)CC2)n1", 80 | "O=C(NCc1ccc(C2CCC(c3cccc(-c4cccc(C5CCOC5)c4)c3)CO2)cn1)c1ccc(-n2ccc(-c3ccc4nc5c(=O)[nH]c(=O)nc-5[nH]c4c3)n2)cn1", 81 | "O=C(NCc1nc2c(c(=O)[nH]1)NC(c1cn(N3CCN(c4ccc5nc6c(=O)[nH]c(=O)nc-6[nH]c5c4)CC3)c(=O)[nH]c1=O)CN2)c1ccc[n+](-c2cccc(-c3nccc(-c4ccccc4)n3)c2)c1", 82 | "C1=C(C2CCC(c3ccnc(-c4ccc(CNC5CCC(c6ccncc6)OC5)cc4)n3)CO2)CCC(c2cc(-c3cncnc3)c3ccccc3c2)C1", 83 | "O=C(NCc1cccc(-c2nccc(-c3cc(-c4ccc5ccccc5c4)n[nH]3)n2)c1)C1CCC(C2CCC(c3cn(-c4ccc5nc6c(=O)[nH]c(=O)nc-6[nH]c5c4)c(=O)[nH]c3=O)OC2)O1", 84 | "O=C(Nc1ccc2ccccc2c1)c1cccc(-c2cccc(CNN3CCN(C4CCCC(c5cccc(C6CCCN6)c5)C4)CC3)c2)c1", 85 | "O=C(NCC1CC=C(c2cc(CCc3c[nH]c(-c4cccc(-c5ccccc5)c4)c3)n[nH]2)CC1)c1cccc(C2CCNC2)n1", 86 | "O=C(Nc1nccc(CNc2cc(C3CCNC3)n[nH]2)n1)c1nccc(C2CCC(C3CCNCC3c3ccc4ccccc4c3)CO2)n1", 87 | "C1=C(C2CCC(c3ccc(-c4cc(C5CCCNC5)n[nH]4)cc3)OC2)CCCC1CCc1cccc(-c2cccc(-c3ncnc4[nH]cnc34)c2)n1", 88 | "O=C(NCc1cc(C2CCC(C3CCN(c4cc(-c5nccc(-c6cccc(-c7ccccc7)c6)n5)c[nH]4)CC3)CO2)ccn1)c1ccccc1", 89 | "O=C(NCc1cccc(-c2ccn(NCc3ccc(-c4cc(C5CNC(c6ccncc6)C5)c[nH]4)cc3)n2)c1)c1ccc2ccccc2c1", 90 | 
"O=c1nc2n(-c3cccc(OCc4cccc(CNC5CCC(c6cccc(-c7ccc(C8CCNC8)cc7)c6)OC5)c4)n3)c3ccccc3nc-2c(=O)[nH]1", 91 | "O=C(NCc1ccc(C2OCCC2C2CC(c3ccnc(-c4ccc5ccccc5c4)c3)CO2)cc1)c1nccc(N2C=CCC(c3ccccc3)=C2)n1", 92 | "O=C(NCNC(=O)c1cccc(C(=O)NCc2cccc(-c3ccc4[nH]c5nc(=O)[nH]c(=O)c-5nc4c3)c2)n1)c1ccnc(-c2nccc(C3CCCN3)n2)c1", 93 | "O=c1nc2[nH]c3cc(C4CCC(c5ccc(-c6cc(C7CCC(C8CCCC(C9CCC(c%10ccc(-c%11cncnc%11)cc%10)O9)O8)OC7)ccn6)cn5)CO4)ccc3nc-2c(=O)[nH]1", 94 | "O=c1[nH]c(CNc2cc(-c3cccc(-n4ccc(-c5ccc6ccccc6c5)n4)c3)c[nH]2)nc2c1NC(n1ccc(C3CCC(c4cccnc4)CO3)n1)CN2", 95 | "O=c1nc2[nH]c3cc(C=CC4COC(C5CCCC(C6CCOC(C7CCC(c8cccc(-c9ccnc(-c%10ccc%11ccccc%11c%10)n9)c8)CO7)C6)O5)C4)ccc3nc-2c(=O)[nH]1", 96 | "c1ccc2c(C3CC(CNc4ccnc(C5CCNC5)c4)CO3)cc(NCc3ccc(-c4cc(C5CCNC5)c[nH]4)cc3)cc2c1", 97 | "O=C(NCc1nccc(C2CC(C(=O)NC3CCC(c4ccc5nc6c(=O)[nH]c(=O)nc-6[nH]c5c4)CO3)CCO2)n1)c1ccnc(-n2cc(-n3cnc4cncnc43)cn2)n1", 98 | "O=C(NCc1ccc(-c2ccccc2)cc1)c1cccc(C(=O)NCc2nccc(N3C=CCC(c4ncnc5c4ncn5-c4cccc5ccccc45)=C3)n2)c1", 99 | ] 100 | 101 | 102 | class LittleSEHDataset(Dataset): 103 | """Note: this dataset isn't used by default, but turning it on showcases some features of this codebase. 104 | 105 | To turn on, self `cfg.algo.num_from_dataset > 0`""" 106 | 107 | def __init__(self, smis) -> None: 108 | super().__init__() 109 | self.props: ObjectProperties 110 | self.mols: List[Graph] = [] 111 | self.smis = smis 112 | 113 | def setup(self, task: SEHTask, ctx: FragMolBuildingEnvContext) -> None: 114 | rdmols = [Chem.MolFromSmiles(i) for i in SOME_MOLS] 115 | self.mols = [ctx.obj_to_graph(i) for i in rdmols] 116 | self.props = task.compute_obj_properties(rdmols)[0] 117 | 118 | def __len__(self): 119 | return len(self.mols) 120 | 121 | def __getitem__(self, index): 122 | return self.mols[index], self.props[index] 123 | 124 | 125 | class SEHFragTrainer(StandardOnlineTrainer): 126 | task: SEHTask 127 | training_data: LittleSEHDataset 128 | 129 | def set_default_hps(self, cfg: Config): 130 | cfg.hostname = socket.gethostname() 131 | cfg.pickle_mp_messages = False 132 | cfg.num_workers = 8 133 | cfg.opt.learning_rate = 1e-4 134 | cfg.opt.weight_decay = 1e-8 135 | cfg.opt.momentum = 0.9 136 | cfg.opt.adam_eps = 1e-8 137 | cfg.opt.lr_decay = 20_000 138 | cfg.opt.clip_grad_type = "norm" 139 | cfg.opt.clip_grad_param = 10 140 | cfg.algo.num_from_policy = 64 141 | cfg.model.num_emb = 128 142 | cfg.model.num_layers = 4 143 | 144 | cfg.algo.method = "TB" 145 | cfg.algo.max_nodes = 9 146 | cfg.algo.sampling_tau = 0.9 147 | cfg.algo.illegal_action_logreward = -75 148 | cfg.algo.train_random_action_prob = 0.0 149 | cfg.algo.valid_random_action_prob = 0.0 150 | cfg.algo.valid_num_from_policy = 64 151 | cfg.num_validation_gen_steps = 10 152 | cfg.algo.tb.epsilon = None 153 | cfg.algo.tb.bootstrap_own_reward = False 154 | cfg.algo.tb.Z_learning_rate = 1e-3 155 | cfg.algo.tb.Z_lr_decay = 50_000 156 | cfg.algo.tb.do_parameterize_p_b = False 157 | cfg.algo.tb.do_sample_p_b = True 158 | 159 | cfg.replay.use = False 160 | cfg.replay.capacity = 10_000 161 | cfg.replay.warmup = 1_000 162 | 163 | def setup_task(self): 164 | self.task = SEHTask( 165 | cfg=self.cfg, 166 | wrap_model=self._wrap_for_mp, 167 | ) 168 | 169 | def setup_data(self): 170 | super().setup_data() 171 | if self.cfg.task.seh.reduced_frag: 172 | # The examples don't work with the 18 frags 173 | self.training_data = LittleSEHDataset([]) 174 | else: 175 | self.training_data = LittleSEHDataset(SOME_MOLS) 176 | 177 | def setup_env_context(self): 178 | self.ctx = FragMolBuildingEnvContext( 179 | 
max_frags=self.cfg.algo.max_nodes, 180 | num_cond_dim=self.task.num_cond_dim, 181 | fragments=bengio2021flow.FRAGMENTS_18 if self.cfg.task.seh.reduced_frag else bengio2021flow.FRAGMENTS, 182 | ) 183 | 184 | def setup(self): 185 | super().setup() 186 | self.training_data.setup(self.task, self.ctx) 187 | 188 | 189 | def main(): 190 | """Example of how this model can be run.""" 191 | import datetime 192 | 193 | config = init_empty(Config()) 194 | config.print_every = 1 195 | config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" 196 | config.device = "cuda" if torch.cuda.is_available() else "cpu" 197 | config.overwrite_existing_exp = True 198 | config.num_training_steps = 1_00 199 | config.validate_every = 20 200 | config.num_final_gen_steps = 10 201 | config.num_workers = 1 202 | config.opt.lr_decay = 20_000 203 | config.algo.sampling_tau = 0.99 204 | config.cond.temperature.sample_dist = "uniform" 205 | config.cond.temperature.dist_params = [0, 64.0] 206 | 207 | trial = SEHFragTrainer(config) 208 | trial.run() 209 | 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /src/gflownet/tasks/toy_seq.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from gflownet import GFNTask, LogScalar, ObjectProperties 8 | from gflownet.config import Config, init_empty 9 | from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv 10 | from gflownet.models.seq_transformer import SeqTransformerGFN 11 | from gflownet.online_trainer import StandardOnlineTrainer 12 | from gflownet.utils.conditioning import TemperatureConditional 13 | from gflownet.utils.transforms import to_logreward 14 | 15 | 16 | class ToySeqTask(GFNTask): 17 | """Sets up a task where the reward is the number of times some sequences appear in the input. 
Normalized to be 18 | in [0,1]""" 19 | 20 | def __init__( 21 | self, 22 | seqs: List[str], 23 | cfg: Config, 24 | ) -> None: 25 | self.seqs = seqs 26 | self.temperature_conditional = TemperatureConditional(cfg) 27 | self.num_cond_dim = self.temperature_conditional.encoding_size() 28 | self.norm = cfg.algo.max_len / min(map(len, seqs)) 29 | 30 | def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: 31 | return self.temperature_conditional.sample(n) 32 | 33 | def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: 34 | return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(obj_props))) 35 | 36 | def compute_obj_properties(self, objs: List[str]) -> Tuple[ObjectProperties, Tensor]: 37 | rs = torch.tensor([sum([s.count(p) for p in self.seqs]) for s in objs]).float() / self.norm 38 | return ObjectProperties(rs[:, None]), torch.ones(len(objs), dtype=torch.bool) 39 | 40 | 41 | class ToySeqTrainer(StandardOnlineTrainer): 42 | task: ToySeqTask 43 | 44 | def set_default_hps(self, cfg: Config): 45 | cfg.hostname = socket.gethostname() 46 | cfg.pickle_mp_messages = False 47 | cfg.num_workers = 8 48 | cfg.num_validation_gen_steps = 1 49 | cfg.opt.learning_rate = 1e-4 50 | cfg.opt.weight_decay = 1e-8 51 | cfg.opt.momentum = 0.9 52 | cfg.opt.adam_eps = 1e-8 53 | cfg.opt.lr_decay = 20_000 54 | cfg.opt.clip_grad_type = "norm" 55 | cfg.opt.clip_grad_param = 10 56 | cfg.algo.num_from_policy = 64 57 | cfg.model.num_emb = 64 58 | cfg.model.num_layers = 4 59 | 60 | cfg.algo.method = "TB" 61 | cfg.algo.max_nodes = 10 62 | cfg.algo.max_len = 10 63 | cfg.algo.sampling_tau = 0.9 64 | cfg.algo.illegal_action_logreward = -75 65 | cfg.algo.train_random_action_prob = 0.0 66 | cfg.algo.valid_random_action_prob = 0.0 67 | cfg.algo.tb.epsilon = None 68 | cfg.algo.tb.bootstrap_own_reward = False 69 | cfg.algo.tb.Z_learning_rate = 1e-2 70 | cfg.algo.tb.Z_lr_decay = 50_000 71 | cfg.algo.tb.do_parameterize_p_b = False 72 | 73 | def setup_model(self): 74 | self.model = SeqTransformerGFN( 75 | self.ctx, 76 | self.cfg, 77 | ) 78 | 79 | def setup_task(self): 80 | self.task = ToySeqTask( 81 | ["aa", "bb", "cc"], 82 | cfg=self.cfg, 83 | ) 84 | 85 | def setup_env_context(self): 86 | self.env = SeqBuildingEnv(None) 87 | self.ctx = AutoregressiveSeqBuildingContext( 88 | "abc", 89 | self.task.num_cond_dim, 90 | ) 91 | 92 | def setup_algo(self): 93 | super().setup_algo() 94 | # If the algo implements it, avoid giving, ["A", "AB", "ABC", ...] as a sequence of inputs, and instead give 95 | # "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using a transformer with 96 | # causal self-attention. 
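        # [Annotation, not in the original file] Concretely: for the sequence "abc",
        # the non-autoregressive path would encode the prefixes "a", "ab", "abc" as
        # three separate inputs, while the autoregressive path encodes "abc" once
        # under a causal mask and reads off the logits at every position; see
        # SeqTransformerGFN.forward(..., batched=True) above, where the per-timestep
        # outputs are gathered via xs.logit_idx.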
97 | self.algo.model_is_autoregressive = True 98 | 99 | 100 | def main(): 101 | """Example of how this model can be run.""" 102 | config = init_empty(Config()) 103 | config.log_dir = "./logs/debug_run_toy_seq" 104 | config.device = "cuda" 105 | config.overwrite_existing_exp = True 106 | config.num_training_steps = 2_000 107 | config.checkpoint_every = 200 108 | config.num_workers = 4 109 | config.print_every = 1 110 | config.cond.temperature.sample_dist = "constant" 111 | config.cond.temperature.dist_params = [2.0] 112 | config.cond.temperature.num_thermometer_dim = 1 113 | config.algo.train_random_action_prob = 0.05 114 | 115 | trial = ToySeqTrainer(config) 116 | trial.run() 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /src/gflownet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/utils/__init__.py -------------------------------------------------------------------------------- /src/gflownet/utils/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List, Optional 3 | 4 | from gflownet.utils.misc import StrictDataClass 5 | 6 | 7 | @dataclass 8 | class TempCondConfig(StrictDataClass): 9 | """Config for the temperature conditional. 10 | 11 | Attributes 12 | ---------- 13 | 14 | sample_dist : str 15 | The distribution to sample the inverse temperature from. Can be one of: 16 | - "uniform": uniform distribution 17 | - "loguniform": log-uniform distribution 18 | - "gamma": gamma distribution 19 | - "constant": constant temperature 20 | - "beta": beta distribution 21 | dist_params : List[Any] 22 | The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. 23 | num_thermometer_dim : int 24 | The number of thermometer encoding dimensions to use. 25 | """ 26 | 27 | sample_dist: str = "uniform" 28 | dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) 29 | num_thermometer_dim: int = 32 30 | 31 | 32 | @dataclass 33 | class MultiObjectiveConfig(StrictDataClass): 34 | num_objectives: int = 2 # TODO: Change that as it can conflict with cfg.task.seh_moo.num_objectives 35 | num_thermometer_dim: int = 16 36 | 37 | 38 | @dataclass 39 | class WeightedPreferencesConfig(StrictDataClass): 40 | """Config for the weighted preferences conditional. 41 | 42 | Attributes 43 | ---------- 44 | preference_type : str 45 | The preference sampling distribution, defaults to "dirichlet". Can be one of: 46 | - "dirichlet": Dirichlet distribution 47 | - "dirichlet_exponential": Dirichlet distribution with exponential temperature 48 | - "seeded": Enumerated preferences 49 | - None: All rewards equally weighted""" 50 | 51 | preference_type: Optional[str] = "dirichlet" 52 | preference_param: Optional[float] = 1.5 53 | 54 | 55 | @dataclass 56 | class FocusRegionConfig(StrictDataClass): 57 | """Config for the focus region conditional. 58 | 59 | Attributes 60 | ---------- 61 | focus_type : str 62 | The type of focus distribution used, see FocusRegionConditional.setup_focus_regions.
Can be one of: 63 | [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] 64 | """ 65 | 66 | focus_type: Optional[str] = "centered" 67 | use_steer_thermomether: bool = False 68 | focus_cosim: float = 0.98 69 | focus_limit_coef: float = 0.1 70 | focus_model_training_limits: tuple[float, float] = (0.25, 0.75) 71 | focus_model_state_space_res: int = 30 72 | max_train_it: int = 20_000 73 | 74 | 75 | @dataclass 76 | class ConditionalsConfig(StrictDataClass): 77 | valid_sample_cond_info: bool = True 78 | temperature: TempCondConfig = field(default_factory=TempCondConfig) 79 | moo: MultiObjectiveConfig = field(default_factory=MultiObjectiveConfig) 80 | weighted_prefs: WeightedPreferencesConfig = field(default_factory=WeightedPreferencesConfig) 81 | focus_region: FocusRegionConfig = field(default_factory=FocusRegionConfig) 82 | -------------------------------------------------------------------------------- /src/gflownet/utils/focus_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gflownet.utils.metrics import get_limits_of_hypercube 7 | 8 | 9 | class FocusModel: 10 | """ 11 | Abstract class for a belief model over focus directions for goal-conditioned GFNs. 12 | Goal-conditioned GFNs allow for more control over the objective-space region from which 13 | we wish to sample. However, due to the growing number of empty regions in the objective space, 14 | if we naively sample focus-directions from the entire objective space, we will condition 15 | our GFN with a lot of infeasible directions which significantly harms its sample efficiency 16 | compared to a simpler preference-conditioned model. 17 | To alleviate this problem, we introduce a focus belief model which is used to sample 18 | focus directions from a subset of the objective space. The belief model is 19 | trained to predict the probability of a focus direction being feasible. The likelihood 20 | to sample a focus direction is then proportional to its population. Directions that have never 21 | been sampled should be given the maximum likelihood. 22 | """ 23 | 24 | def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: 25 | """ 26 | args: 27 | device: torch device 28 | n_objectives: number of objectives 29 | state_space_res: resolution of the state space discretisation. The number of focus directions to consider 30 | grows as O(state_space_res ** n_objectives) and depends on the amount of filtering we apply 31 | (e.g. valid focus-directions should sum to 1 [dirichlet], should contain a 1 [limits], etc.) 32 | """ 33 | self.device = device 34 | self.n_objectives = n_objectives 35 | self.state_space_res = state_space_res 36 | 37 | self.feasible_flow = 1.0 38 | self.infeasible_flow = 0.1 39 | 40 | def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): 41 | raise NotImplementedError 42 | 43 | def sample_focus_directions(self, n: int): 44 | raise NotImplementedError 45 | 46 | 47 | class TabularFocusModel(FocusModel): 48 | """ 49 | Tabular model of the feasibility of focus directions for goal-conditioning. 50 | We keep a count of the number of times each focus direction has been sampled and whether 51 | this direction successfully led to a sample in this region of the objective space. The (unnormalized) likelihood 52 | of a focus direction being feasible is then given by the ratio of these numbers.
53 | If a focus direction has not been sampled yet it obtains the maximum likelihood of one. 54 | """ 55 | 56 | def __init__(self, device: torch.device, n_objectives: int, state_space_res: int) -> None: 57 | super().__init__(device, n_objectives, state_space_res) 58 | self.n_objectives = n_objectives 59 | self.state_space_res = state_space_res 60 | self.focus_dir_dataset = ( 61 | nn.functional.normalize(torch.tensor(get_limits_of_hypercube(n_objectives, state_space_res)), dim=1) 62 | .float() 63 | .to(self.device) 64 | ) 65 | self.focus_dir_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) 66 | self.focus_dir_population_count = torch.zeros(self.focus_dir_dataset.shape[0]).to(self.device) 67 | 68 | def update_belief(self, focus_dirs: torch.Tensor, flat_rewards: torch.Tensor): 69 | """ 70 | Updates the focus model with the focus directions and rewards 71 | of the last batch. 72 | """ 73 | focus_dirs = nn.functional.normalize(focus_dirs, dim=1) 74 | flat_rewards = nn.functional.normalize(flat_rewards, dim=1) 75 | 76 | focus_dirs_indices = torch.argmin(torch.cdist(focus_dirs, self.focus_dir_dataset), dim=1) 77 | flat_rewards_indices = torch.argmin(torch.cdist(flat_rewards, self.focus_dir_dataset), dim=1) 78 | 79 | for idxs, count in zip( 80 | [focus_dirs_indices, flat_rewards_indices], 81 | [self.focus_dir_count, self.focus_dir_population_count], 82 | ): 83 | idx_increments = torch.bincount(idxs, minlength=len(count)) 84 | count += idx_increments 85 | 86 | def sample_focus_directions(self, n: int): 87 | """ 88 | Samples n focus directions from the focus model. 89 | """ 90 | sampling_likelihoods = torch.zeros_like(self.focus_dir_count).float().to(self.device) 91 | sampling_likelihoods[self.focus_dir_count == 0] = self.feasible_flow 92 | sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0)] = ( 93 | self.feasible_flow 94 | ) 95 | sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0)] = ( 96 | self.infeasible_flow 97 | ) 98 | focus_dir_indices = torch.multinomial(sampling_likelihoods, n, replacement=True) 99 | return self.focus_dir_dataset[focus_dir_indices].to("cpu") 100 | 101 | def save(self, path: Path): 102 | params = { 103 | "n_objectives": self.n_objectives, 104 | "state_space_res": self.state_space_res, 105 | "focus_dir_dataset": self.focus_dir_dataset.to("cpu"), 106 | "focus_dir_count": self.focus_dir_count.to("cpu"), 107 | "focus_dir_population_count": self.focus_dir_population_count.to("cpu"), 108 | } 109 | torch.save(params, open(path / "tabular_focus_model.pt", "wb")) 110 | 111 | def load(self, device: torch.device, path: Path): 112 | params = torch.load(open(path / "tabular_focus_model.pt", "rb")) 113 | self.n_objectives = params["n_objectives"] 114 | self.state_space_res = params["state_space_res"] 115 | self.focus_dir_dataset = params["focus_dir_dataset"].to(device) 116 | self.focus_dir_count = params["focus_dir_count"].to(device) 117 | self.focus_dir_population_count = params["focus_dir_population_count"].to(device) 118 | -------------------------------------------------------------------------------- /src/gflownet/utils/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/src/gflownet/utils/fpscores.pkl.gz -------------------------------------------------------------------------------- /src/gflownet/utils/graphs.py: 
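[Annotation, not part of the repository] The next file computes random-walk return probabilities: with P = D^-1 A, random_walk_probs returns diag(P), diag(P^2), ..., diag(P^k) per node, a common structural positional feature. A minimal worked example (assuming torch_geometric is installed), using a triangle graph where every node has degree 2:

import torch
from torch_geometric.data import Data
# from gflownet.utils.graphs import random_walk_probs

# Undirected triangle 0-1-2, both edge directions listed explicitly
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]])
g = Data(edge_index=edge_index, num_nodes=3)
# random_walk_probs(g, k=3) -> each row is approximately [0.0, 0.5, 0.25]:
# the walker cannot return in 1 step, returns with prob 1/2 in 2 steps,
# and with prob 1/4 in 3 steps (two 3-cycles, each of probability 1/8).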
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | from torch_geometric.utils import to_dense_adj 4 | from torch_scatter import scatter_add 5 | 6 | 7 | def random_walk_probs(g: Data, k: int, skip_odd=False): 8 | source, _ = g.edge_index[0], g.edge_index[1] 9 | deg = scatter_add(torch.ones_like(source), source, dim=0, dim_size=g.num_nodes) 10 | deg_inv = deg.pow(-1.0) 11 | deg_inv.masked_fill_(deg_inv == float("inf"), 0) 12 | 13 | if g.edge_index.shape[1] == 0: 14 | P = g.edge_index.new_zeros((1, g.num_nodes, g.num_nodes)) 15 | else: 16 | # P = D^-1 * A 17 | P = torch.diag(deg_inv) @ to_dense_adj(g.edge_index, max_num_nodes=g.num_nodes) # (1, num nodes, num nodes) 18 | diags = [] 19 | if skip_odd: 20 | Pmult = P @ P 21 | else: 22 | Pmult = P 23 | Pk = Pmult 24 | for _ in range(k): 25 | diags.append(torch.diagonal(Pk, dim1=-2, dim2=-1)) 26 | Pk = Pk @ Pmult 27 | p = torch.cat(diags, dim=0).transpose(0, 1) # (num nodes, k) 28 | return p 29 | -------------------------------------------------------------------------------- /src/gflownet/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): 9 | logger = logging.getLogger(name) 10 | logger.setLevel(loglevel) 11 | while len([logger.removeHandler(i) for i in logger.handlers]): 12 | pass # Remove all handlers (only useful when debugging) 13 | formatter = logging.Formatter( 14 | fmt="%(asctime)s - %(levelname)s - {} - %(message)s".format(name), 15 | datefmt="%d/%m/%Y %H:%M:%S", 16 | ) 17 | 18 | handlers = [] 19 | if logfile is not None: 20 | handlers.append(logging.FileHandler(logfile, mode="a")) 21 | if streamHandle: 22 | handlers.append(logging.StreamHandler(stream=sys.stdout)) 23 | 24 | for handler in handlers: 25 | handler.setFormatter(formatter) 26 | logger.addHandler(handler) 27 | 28 | return logger 29 | 30 | 31 | _worker_rngs = {} 32 | _worker_rng_seed = [142857] 33 | _main_process_device = [torch.device("cpu")] 34 | 35 | 36 | def get_worker_rng(): 37 | worker_info = torch.utils.data.get_worker_info() 38 | wid = worker_info.id if worker_info is not None else 0 39 | if wid not in _worker_rngs: 40 | _worker_rngs[wid] = np.random.RandomState(_worker_rng_seed[0] + wid) 41 | return _worker_rngs[wid] 42 | 43 | 44 | def set_worker_rng_seed(seed): 45 | _worker_rng_seed[0] = seed 46 | for wid in _worker_rngs: 47 | _worker_rngs[wid].seed(seed + wid) 48 | 49 | 50 | def set_main_process_device(device): 51 | _main_process_device[0] = device 52 | 53 | 54 | def get_worker_device(): 55 | worker_info = torch.utils.data.get_worker_info() 56 | return _main_process_device[0] if worker_info is None else torch.device("cpu") 57 | 58 | 59 | class StrictDataClass: 60 | """ 61 | A dataclass that raises an error if any field is created outside of the __init__ method. 62 | """ 63 | 64 | def __setattr__(self, name, value): 65 | if hasattr(self, name) or name in self.__annotations__: 66 | super().__setattr__(name, value) 67 | else: 68 | raise AttributeError( 69 | f"'{type(self).__name__}' object has no attribute '{name}'." 70 | f" '{type(self).__name__}' is a StrictDataClass object." 71 | f" Attributes can only be defined in the class definition." 
72 | ) 73 | -------------------------------------------------------------------------------- /src/gflownet/utils/multiobjective_hooks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import queue 4 | import threading 5 | from collections import defaultdict 6 | from typing import List 7 | 8 | import numpy as np 9 | import torch 10 | import torch.multiprocessing as mp 11 | from torch import Tensor 12 | 13 | from gflownet.utils import metrics 14 | 15 | 16 | class MultiObjectiveStatsHook: 17 | """ 18 | This hook is multithreaded and the keep_alive object needs to be closed for graceful termination. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | num_to_keep: int, 24 | log_dir: str, 25 | save_every: int = 50, 26 | compute_hvi=False, 27 | compute_hsri=False, 28 | compute_normed=False, 29 | compute_igd=False, 30 | compute_pc_entropy=False, 31 | compute_focus_accuracy=False, 32 | focus_cosim=None, 33 | ): 34 | # This __init__ is only called in the main process. This object is then (potentially) cloned 35 | # in pytorch data worker processes and __call__'ed from within those processes. This means 36 | # each process will compute its own Pareto front, which we will accumulate in the main 37 | # process by pushing local fronts to self.pareto_queue. 38 | self.num_to_keep = num_to_keep 39 | self.hsri_epsilon = 0.3 40 | 41 | self.compute_hvi = compute_hvi 42 | self.compute_hsri = compute_hsri 43 | self.compute_normed = compute_normed 44 | self.compute_igd = compute_igd 45 | self.compute_pc_entropy = compute_pc_entropy 46 | self.compute_focus_accuracy = compute_focus_accuracy 47 | self.focus_cosim = focus_cosim 48 | 49 | self.all_flat_rewards: List[Tensor] = [] 50 | self.all_focus_dirs: List[Tensor] = [] 51 | self.all_smi: List[str] = [] 52 | self.pareto_queue: mp.Queue = mp.Queue() 53 | self.pareto_front = None 54 | self.pareto_front_smi = None 55 | self.pareto_metrics = mp.Array("f", 4) 56 | 57 | self.stop = threading.Event() 58 | self.save_every = save_every 59 | self.log_path = pathlib.Path(log_dir) / "pareto.pt" 60 | self.pareto_thread = threading.Thread(target=self._run_pareto_accumulation, daemon=True) 61 | self.pareto_thread.start() 62 | 63 | def _hsri(self, x): 64 | assert x.ndim == 2, "x should have shape (num points, num objectives)" 65 | upper = np.zeros(x.shape[-1]) + self.hsri_epsilon 66 | lower = np.ones(x.shape[-1]) * -1 - self.hsri_epsilon 67 | hsr_indicator = metrics.HSR_Calculator(lower, upper) 68 | try: 69 | hsri, _ = hsr_indicator.calculate_hsr(-x) 70 | except Exception: 71 | hsri = 1e-42 72 | return hsri 73 | 74 | def _run_pareto_accumulation(self): 75 | num_updates = 0 76 | timeouts = 0 77 | while not self.stop.is_set() and timeouts < 200: 78 | try: 79 | r, smi, owid = self.pareto_queue.get(block=True, timeout=1) 80 | except queue.Empty: 81 | timeouts += 1 82 | continue 83 | except ConnectionError as e: 84 | print("Pareto Accumulation thread Queue ConnectionError", e) 85 | break 86 | 87 | timeouts = 0 88 | # accumulates pareto fronts across batches 89 | if self.pareto_front is None: 90 | p = self.pareto_front = r 91 | psmi = smi 92 | else: 93 | p = np.concatenate([self.pareto_front, r], 0) 94 | psmi = self.pareto_front_smi + smi 95 | 96 | # distills down by removing dominated points 97 | idcs = metrics.is_pareto_efficient(-p, False) 98 | self.pareto_front = p[idcs] 99 | self.pareto_front_smi = [psmi[i] for i in idcs] 100 | 101 | # computes pareto metrics and stores them in the multiprocessing array 102 | if
self.compute_hvi: 103 | self.pareto_metrics[0] = metrics.get_hypervolume(torch.tensor(self.pareto_front), zero_ref=True) 104 | if self.compute_hsri: 105 | self.pareto_metrics[1] = self._hsri(self.pareto_front) 106 | if self.compute_igd: 107 | self.pareto_metrics[2] = metrics.get_IGD(torch.tensor(self.pareto_front)) 108 | if self.compute_pc_entropy: 109 | self.pareto_metrics[3] = metrics.get_PC_entropy(torch.tensor(self.pareto_front)) 110 | 111 | # saves data to disk 112 | num_updates += 1 113 | if num_updates % self.save_every == 0: 114 | if self.pareto_queue.qsize() > 10: 115 | print("Warning: pareto metrics computation lagging") 116 | self._save() 117 | self._save() 118 | 119 | def _save(self): 120 | with open(self.log_path, "wb") as fd: 121 | torch.save( 122 | { 123 | "pareto_front": self.pareto_front, 124 | "pareto_metrics": list(self.pareto_metrics), 125 | "pareto_front_smi": self.pareto_front_smi, 126 | }, 127 | fd, 128 | ) 129 | 130 | def __call__(self, trajs, rewards, flat_rewards, cond_info): 131 | # locally (in-process) accumulate flat rewards to build a better pareto estimate 132 | self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards) 133 | if self.compute_focus_accuracy: 134 | self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) 135 | self.all_smi = self.all_smi + list([i.get("smi", None) for i in trajs]) 136 | if len(self.all_flat_rewards) > self.num_to_keep: 137 | self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep :] 138 | self.all_focus_dirs = self.all_focus_dirs[-self.num_to_keep :] 139 | self.all_smi = self.all_smi[-self.num_to_keep :] 140 | 141 | flat_rewards = torch.stack(self.all_flat_rewards).numpy() 142 | if self.compute_focus_accuracy: 143 | focus_dirs = torch.stack(self.all_focus_dirs).numpy() 144 | 145 | # collects empirical pareto front from in-process samples 146 | pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False) 147 | gfn_pareto = flat_rewards[pareto_idces] 148 | pareto_smi = [self.all_smi[i] for i in pareto_idces] 149 | 150 | # send pareto front to main process for lifetime accumulation 151 | worker_info = torch.utils.data.get_worker_info() 152 | wid = worker_info.id if worker_info is not None else 0 153 | self.pareto_queue.put((gfn_pareto, pareto_smi, wid)) 154 | 155 | # compute in-process pareto metrics and collects lifetime pareto metrics from main process 156 | info = {} 157 | if self.compute_hvi: 158 | unnorm_hypervolume_with_zero_ref = metrics.get_hypervolume(torch.tensor(gfn_pareto), zero_ref=True) 159 | unnorm_hypervolume_wo_zero_ref = metrics.get_hypervolume(torch.tensor(gfn_pareto), zero_ref=False) 160 | info = { 161 | **info, 162 | "UHV, zero_ref=True": unnorm_hypervolume_with_zero_ref, 163 | "UHV, zero_ref=False": unnorm_hypervolume_wo_zero_ref, 164 | "lifetime_hv0": self.pareto_metrics[0], 165 | } 166 | if self.compute_normed: 167 | target_min = flat_rewards.min(0).copy() 168 | target_range = flat_rewards.max(0).copy() - target_min 169 | hypercube_transform = metrics.Normalizer(loc=target_min, scale=target_range) 170 | normed_gfn_pareto = hypercube_transform(gfn_pareto) 171 | hypervolume_with_zero_ref = metrics.get_hypervolume(torch.tensor(normed_gfn_pareto), zero_ref=True) 172 | hypervolume_wo_zero_ref = metrics.get_hypervolume(torch.tensor(normed_gfn_pareto), zero_ref=False) 173 | info = { 174 | **info, 175 | "HV, zero_ref=True": hypervolume_with_zero_ref, 176 | "HV, zero_ref=False": hypervolume_wo_zero_ref, 177 | } 178 | if self.compute_hsri: 179 | hsri_w_pareto = 
179 |             hsri_w_pareto = self._hsri(gfn_pareto)
180 |             info = {
181 |                 **info,
182 |                 "hsri": hsri_w_pareto,
183 |                 "lifetime_hsri": self.pareto_metrics[1],
184 |             }
185 |         if self.compute_igd:
186 |             igd = metrics.get_IGD(flat_rewards, ref_front=None)
187 |             info = {
188 |                 **info,
189 |                 "igd": igd,
190 |                 "lifetime_igd_frontOnly": self.pareto_metrics[2],
191 |             }
192 |         if self.compute_pc_entropy:
193 |             pc_ent = metrics.get_PC_entropy(flat_rewards, ref_front=None)
194 |             info = {
195 |                 **info,
196 |                 "PCent": pc_ent,
197 |                 "lifetime_PCent_frontOnly": self.pareto_metrics[3],
198 |             }
199 |         if self.compute_focus_accuracy:
200 |             focus_acc = metrics.get_focus_accuracy(
201 |                 torch.tensor(flat_rewards), torch.tensor(focus_dirs), self.focus_cosim
202 |             )
203 |             info = {
204 |                 **info,
205 |                 "focus_acc": focus_acc,
206 |             }
207 | 
208 |         return info
209 | 
210 |     def terminate(self):
211 |         self.stop.set()
212 |         self.pareto_thread.join()
213 | 
214 | 
215 | class TopKHook:
216 |     def __init__(self, k, repeats, num_preferences):
217 |         self.queue: mp.Queue = mp.Queue()
218 |         self.k = k
219 |         self.repeats = repeats
220 |         self.num_preferences = num_preferences
221 | 
222 |     def __call__(self, trajs, rewards, flat_rewards, cond_info):
223 |         self.queue.put([(i["data_idx"], r) for i, r in zip(trajs, rewards)])
224 |         return {}
225 | 
226 |     def finalize(self):
227 |         data = []
228 |         while not self.queue.empty():
229 |             try:
230 |                 data += self.queue.get(True, 1)
231 |             except queue.Empty:
232 |                 # print("Warning, TopKHook queue timed out!")
233 |                 break
234 |         repeats = defaultdict(list)
235 |         for idx, r in data:
236 |             repeats[idx // self.repeats].append(r)
237 |         top_ks = [np.mean(sorted(i)[-self.k :]) for i in repeats.values()]
238 |         assert len(top_ks) == self.num_preferences  # sanity check: one top-k average per preference
239 |         return top_ks
240 | 
241 | 
242 | class RewardPercentilesHook:
243 |     """
244 |     Calculate percentiles of the reward.
245 | 
246 |     Parameters
247 |     ----------
248 |     percentiles: List[float]
249 |         The percentiles to calculate. Should be in the range [0, 1].
250 |         Default: [1.0, 0.75, 0.5, 0.25, 0]
251 |     """
252 | 
253 |     def __init__(self, percentiles=None):
254 |         if percentiles is None:
255 |             percentiles = [1.0, 0.75, 0.5, 0.25, 0]
256 |         self.percentiles = percentiles
257 | 
258 |     def __call__(self, trajs, rewards, flat_rewards, cond_info):
259 |         x = np.sort(flat_rewards.numpy(), axis=0)
260 |         ret = {}
261 |         y = np.sort(rewards.numpy())
262 |         for p in self.percentiles:
263 |             f = max(min(math.floor(x.shape[0] * p), x.shape[0] - 1), 0)
264 |             for j in range(x.shape[1]):
265 |                 ret[f"percentile_flat_reward_{j}_{p:.2f}"] = x[f, j]
266 |             ret[f"percentile_reward_{p:.2f}%"] = y[f]
267 |         return ret
268 | 
269 | 
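A minimal worked example of the keys this hook reports (the tensors below are illustrative; during training the hook is called with each sampled batch):

    import torch
    hook = RewardPercentilesHook()
    rewards = torch.tensor([0.1, 0.4, 0.9, 0.2])
    flat_rewards = torch.tensor([[0.1, 0.5], [0.4, 0.3], [0.9, 0.8], [0.2, 0.1]])
    info = hook([], rewards, flat_rewards, {})  # trajs and cond_info are unused by this hook
    # info["percentile_reward_1.00%"] ~= 0.9 (max), info["percentile_reward_0.00%"] ~= 0.1 (min)
    # info["percentile_flat_reward_0_0.50"] ~= 0.4 (floor-indexed median of objective 0)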
270 | class TrajectoryLengthHook:
271 |     """
272 |     Report the average trajectory length.
273 |     """
274 | 
275 |     def __init__(self) -> None:
276 |         pass
277 | 
278 |     def __call__(self, trajs, rewards, flat_rewards, cond_info):
279 |         ret = {}
280 |         ret["sample_len"] = sum([len(i["traj"]) for i in trajs]) / len(trajs)
281 |         return ret
282 | -------------------------------------------------------------------------------- /src/gflownet/utils/multiprocessing_proxy.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import queue
3 | import threading
4 | import traceback
5 | 
6 | import torch
7 | import torch.multiprocessing as mp
8 | 
9 | 
10 | class MPObjectPlaceholder:
11 |     """This class can be used for example as a model or dataset placeholder
12 |     in a worker process, and translates calls to the object-placeholder into
13 |     queries for the main process to execute on the real object."""
14 | 
15 |     def __init__(self, in_queues, out_queues, pickle_messages=False):
16 |         self.qs = in_queues, out_queues
17 |         self.device = torch.device("cpu")
18 |         self.pickle_messages = pickle_messages
19 |         self._is_init = False
20 | 
21 |     def _check_init(self):
22 |         if self._is_init:
23 |             return
24 |         info = torch.utils.data.get_worker_info()
25 |         if info is None:
26 |             self.in_queue = self.qs[0][-1]
27 |             self.out_queue = self.qs[1][-1]
28 |         else:
29 |             self.in_queue = self.qs[0][info.id]
30 |             self.out_queue = self.qs[1][info.id]
31 |         self._is_init = True
32 | 
33 |     def encode(self, m):
34 |         if self.pickle_messages:
35 |             return pickle.dumps(m)
36 |         return m
37 | 
38 |     def decode(self, m):
39 |         if self.pickle_messages:
40 |             m = pickle.loads(m)
41 |         if isinstance(m, Exception):
42 |             print("Received exception from main process, reraising.")
43 |             raise m
44 |         return m
45 | 
46 |     def __getattr__(self, name):
47 |         def method_wrapper(*a, **kw):
48 |             self._check_init()
49 |             self.in_queue.put(self.encode((name, a, kw)))
50 |             return self.decode(self.out_queue.get())
51 | 
52 |         return method_wrapper
53 | 
54 |     def __call__(self, *a, **kw):
55 |         self._check_init()
56 |         self.in_queue.put(self.encode(("__call__", a, kw)))
57 |         return self.decode(self.out_queue.get())
58 | 
59 |     def __len__(self):
60 |         self._check_init()
61 |         self.in_queue.put(self.encode(("__len__", (), {})))
62 |         return self.decode(self.out_queue.get())
63 | 
64 | 
65 | class MPObjectProxy:
66 |     """This class maintains a reference to some object and
67 |     creates a `placeholder` attribute which can be safely passed to
68 |     multiprocessing DataLoader workers.
69 | 
70 |     The placeholders in each process send messages across multiprocessing
71 |     queues which are received by this proxy instance. The proxy instance then
72 |     runs the calls on our object and sends the return value back to the worker.
73 | 
74 |     Starts its own (daemon) thread.
75 |     Always passes CPU tensors between processes.
76 |     """
77 | 
78 |     def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False):
79 |         """Construct a multiprocessing object proxy.
80 | 
81 |         Parameters
82 |         ----------
83 |         obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer)
84 |             Lives in the main process, to which method calls are passed
85 |         num_workers: int
86 |             Number of DataLoader workers
87 |         cast_types: tuple
88 |             Types that will be cast to cuda when received as arguments of method calls.
89 |             torch.Tensor is cast by default.
90 |         pickle_messages: bool
91 |             If True, pickle messages sent between processes. This reduces load on shared
92 |             memory, but increases load on CPU. It is recommended to activate this flag if
93 |             encountering "Too many open files"-type errors.
94 |         """
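        # Note: one in/out queue pair is created per DataLoader worker, plus one
        # extra pair (the last slot) used when the placeholder is called from the
        # main process itself, where torch reports no worker_info.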
95 |         self.in_queues = [mp.Queue() for i in range(num_workers + 1)]  # type: ignore
96 |         self.out_queues = [mp.Queue() for i in range(num_workers + 1)]  # type: ignore
97 |         self.pickle_messages = pickle_messages
98 |         self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages)
99 |         self.obj = obj
100 |         if hasattr(obj, "parameters"):
101 |             self.device = next(obj.parameters()).device
102 |         else:
103 |             self.device = torch.device("cpu")
104 |         self.cuda_types = (torch.Tensor,) + cast_types
105 |         self.stop = threading.Event()
106 |         self.thread = threading.Thread(target=self.run, daemon=True)
107 |         self.thread.start()
108 | 
109 |     def encode(self, m):
110 |         if self.pickle_messages:
111 |             return pickle.dumps(m)
112 |         return m
113 | 
114 |     def decode(self, m):
115 |         if self.pickle_messages:
116 |             return pickle.loads(m)
117 |         return m
118 | 
119 |     def to_cpu(self, i):
120 |         return i.detach().to(torch.device("cpu")) if isinstance(i, self.cuda_types) else i
121 | 
122 |     def run(self):
123 |         timeouts = 0
124 | 
125 |         while not self.stop.is_set() or timeouts < 500:  # the `or` lets the loop keep draining the queues briefly after stop is set
126 |             for qi, q in enumerate(self.in_queues):
127 |                 try:
128 |                     r = self.decode(q.get(True, 1e-5))
129 |                 except queue.Empty:
130 |                     timeouts += 1
131 |                     continue
132 |                 except ConnectionError:
133 |                     break
134 |                 timeouts = 0
135 |                 attr, args, kwargs = r
136 |                 f = getattr(self.obj, attr)
137 |                 args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args]
138 |                 kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()}
139 |                 try:
140 |                     # There's no need to compute gradients, since we can't transfer them back to the worker
141 |                     with torch.no_grad():
142 |                         result = f(*args, **kwargs)
143 |                 except Exception as e:
144 |                     result = e
145 |                     exc_str = traceback.format_exc()
146 |                     try:
147 |                         pickle.dumps(e)
148 |                     except Exception:
149 |                         result = RuntimeError("Exception raised in MPObjectProxy, but it cannot be pickled.\n" + exc_str)
150 |                 if isinstance(result, (list, tuple)):
151 |                     msg = [self.to_cpu(i) for i in result]
152 |                 elif isinstance(result, dict):
153 |                     msg = {k: self.to_cpu(i) for k, i in result.items()}
154 |                 else:
155 |                     msg = self.to_cpu(result)
156 |                 self.out_queues[qi].put(self.encode(msg))
157 | 
158 |     def terminate(self):
159 |         self.stop.set()
160 | 
161 | 
162 | def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False):
163 |     """Construct a multiprocessing object proxy for torch DataLoaders so
164 |     that it does not need to be copied in every worker's memory. For example,
165 |     this can be used to wrap a model such that only the main process makes
166 |     cuda calls by forwarding data through the model, or a replay buffer
167 |     such that the new data is pushed in from the worker processes but only the
168 |     main process has to hold the full buffer in memory.
176 | 
177 |     Parameters
178 |     ----------
179 |     obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer)
180 |         Lives in the main process, to which method calls are passed
181 |     num_workers: int
182 |         Number of DataLoader workers
183 |     cast_types: tuple
184 |         Types that will be cast to cuda when received as arguments of method calls.
185 |         torch.Tensor is cast by default.
186 |     pickle_messages: bool
187 |         If True, pickle messages sent between processes. This reduces load on shared
188 |         memory, but increases load on CPU. It is recommended to activate this flag if
189 |         encountering "Too many open files"-type errors.
190 | 
191 |     Returns
192 |     -------
193 |     proxy: MPObjectProxy
194 |         A proxy object whose `placeholder` attribute routes method calls to the main process
195 | 
196 |     """
197 |     return MPObjectProxy(obj, num_workers, cast_types, pickle_messages)
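A minimal usage sketch (the wrapped module and worker count are illustrative, not taken from this repo's training code):

    import torch.nn as nn
    from gflownet.utils.multiprocessing_proxy import mp_object_wrapper

    model = nn.Linear(8, 2)  # stand-in for a real model held by the main process
    proxy = mp_object_wrapper(model, num_workers=2, cast_types=())
    placeholder = proxy.placeholder  # safe to hand to DataLoader workers
    # In a worker, placeholder(x) routes the call to the main process, which runs
    # model(x) under torch.no_grad() and sends a CPU tensor back.
    proxy.terminate()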
198 | -------------------------------------------------------------------------------- /src/gflownet/utils/sascore.py: --------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on
5 | # Molecular Complexity and Fragment Contributions
6 | # Peter Ertl and Ansgar Schuffenhauer
7 | # Journal of Cheminformatics 1:8 (2009)
8 | # http://www.jcheminf.com/content/1/1/8
9 | #
10 | # several small modifications to the original paper are included
11 | # particularly a slightly different formula for the macrocyclic penalty
12 | # and taking into account also molecule symmetry (fingerprint density)
13 | #
14 | # for a set of 10k diverse molecules the agreement between the original method
15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
16 | #
17 | # peter ertl & greg landrum, september 2013
18 | #
19 | 
20 | import math
21 | import os.path as op
22 | import pickle  # nosec
23 | 
24 | from rdkit import Chem
25 | from rdkit.Chem import rdMolDescriptors
26 | 
27 | _fscores = None
28 | 
29 | 
30 | def readFragmentScores(name="fpscores"):
31 |     import gzip
32 | 
33 |     global _fscores
34 |     # generate the full path filename:
35 |     if name == "fpscores":
36 |         name = op.join(op.dirname(__file__), name)
37 |     data = pickle.load(gzip.open("%s.pkl.gz" % name))  # nosec
38 |     outDict = {}
39 |     for i in data:
40 |         for j in range(1, len(i)):
41 |             outDict[i[j]] = float(i[0])
42 |     _fscores = outDict
43 | 
44 | 
45 | def numBridgeheadsAndSpiro(mol, ri=None):
46 |     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
47 |     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
48 |     return nBridgehead, nSpiro
49 | 
50 | 
51 | def calculateScore(m):
52 |     if _fscores is None:
53 |         readFragmentScores()
54 | 
55 |     # fragment score
56 |     try:
57 |         fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
58 |     except RuntimeError:
59 |         return 9.99
60 |     fps = fp.GetNonzeroElements()
61 |     score1 = 0.0
62 |     nf = 0
63 |     for bitId, v in fps.items():
64 |         nf += v
65 |         sfp = bitId
66 |         score1 += _fscores.get(sfp, -4) * v
67 |     score1 /= nf
68 | 
69 |     # features score
70 |     nAtoms = m.GetNumAtoms()
71 |     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72 |     ri = m.GetRingInfo()
73 |     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74 |     nMacrocycles = 0
75 |     for x in ri.AtomRings():
76 |         if len(x) > 8:
77 |             nMacrocycles += 1
78 | 
79 |     sizePenalty = nAtoms**1.005 - nAtoms
80 |     stereoPenalty = math.log10(nChiralCenters + 1)
81 |     spiroPenalty = math.log10(nSpiro + 1)
82 |     bridgePenalty = math.log10(nBridgeheads + 1)
83 |     macrocyclePenalty = 0.0
84 |     # ---------------------------------------
85 |     # This differs from the paper, which defines:
86 |     # macrocyclePenalty = math.log10(nMacrocycles+1)
87 |     # This form generates better results when 2 or more macrocycles are present
88 |     if nMacrocycles > 0:
89 |         macrocyclePenalty = math.log10(2)
90 | 
91 |     score2 = 0.0 - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
92 | 
93 |     # correction for the fingerprint density
94 |     # not in the original publication, added in version 1.1
95 |     # to make highly symmetrical molecules easier to synthesise
96 |     score3 = 0.0
97 |     if nAtoms > len(fps):
98 |         score3 = math.log(float(nAtoms) / len(fps)) * 0.5
99 | 
100 |     sascore = score1 + score2 + score3
101 | 
102 |     # need to transform "raw" value into scale between 1 and 10
103 |     min = -4.0
104 |     max = 2.5
105 |     sascore = 11.0 - (sascore - min + 1) / (max - min) * 9.0
106 |     # smooth the 10-end
107 |     if sascore > 8.0:
108 |         sascore = 8.0 + math.log(sascore + 1.0 - 9.0)
109 |     if sascore > 10.0:
110 |         sascore = 10.0
111 |     elif sascore < 1.0:
112 |         sascore = 1.0
113 | 
114 |     return sascore
115 | 
116 | 
117 | def processMols(mols):
118 |     print("smiles\tName\tsa_score")
119 |     for i, m in enumerate(mols):
120 |         if m is None:
121 |             continue
122 | 
123 |         s = calculateScore(m)
124 | 
125 |         smiles = Chem.MolToSmiles(m)
126 |         print(smiles + "\t" + m.GetProp("_Name") + "\t%3f" % s)
127 | 
128 | 
129 | if __name__ == "__main__":
130 |     import sys
131 |     import time
132 | 
133 |     t1 = time.time()
134 |     readFragmentScores("fpscores")
135 |     t2 = time.time()
136 | 
137 |     suppl = Chem.SmilesMolSupplier(sys.argv[1])
138 |     t3 = time.time()
139 |     processMols(suppl)
140 |     t4 = time.time()
141 | 
142 |     print("Reading took %.2f seconds. Calculating took %.2f seconds" % ((t2 - t1), (t4 - t3)), file=sys.stderr)
143 | 
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | #     * Redistributions of source code must retain the above copyright
153 | #       notice, this list of conditions and the following disclaimer.
154 | #     * Redistributions in binary form must reproduce the above
155 | #       copyright notice, this list of conditions and the following
156 | #       disclaimer in the documentation and/or other materials provided
157 | #       with the distribution.
158 | #     * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | #       nor the names of its contributors may be used to endorse or promote
160 | #       products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 173 | # 174 | -------------------------------------------------------------------------------- /src/gflownet/utils/sqlite_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | from typing import Iterable 4 | 5 | import torch 6 | 7 | 8 | class SQLiteLogHook: 9 | def __init__(self, log_dir, ctx) -> None: 10 | self.log = None # Only initialized in __call__, which will occur inside the worker 11 | self.log_dir = log_dir 12 | self.ctx = ctx 13 | self.data_labels = None 14 | 15 | def __call__(self, trajs, rewards, obj_props, cond_info): 16 | if self.log is None: 17 | worker_info = torch.utils.data.get_worker_info() 18 | self._wid = worker_info.id if worker_info is not None else 0 19 | os.makedirs(self.log_dir, exist_ok=True) 20 | self.log_path = f"{self.log_dir}/generated_objs_{self._wid}.db" 21 | self.log = SQLiteLog() 22 | self.log.connect(self.log_path) 23 | 24 | if hasattr(self.ctx, "object_to_log_repr"): 25 | objs = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] 26 | else: 27 | objs = [""] * len(trajs) 28 | 29 | obj_props = obj_props.reshape((len(obj_props), -1)).data.numpy().tolist() 30 | rewards = rewards.data.numpy().tolist() 31 | preferences = cond_info.get("preferences", torch.zeros((len(objs), 0))).data.numpy().tolist() 32 | focus_dir = cond_info.get("focus_dir", torch.zeros((len(objs), 0))).data.numpy().tolist() 33 | logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] 34 | 35 | data = [ 36 | [objs[i], rewards[i]] 37 | + obj_props[i] 38 | + preferences[i] 39 | + focus_dir[i] 40 | + [cond_info[k][i].item() for k in logged_keys] 41 | for i in range(len(trajs)) 42 | ] 43 | if self.data_labels is None: 44 | self.data_labels = ( 45 | ["smi", "r"] 46 | + [f"fr_{i}" for i in range(len(obj_props[0]))] 47 | + [f"pref_{i}" for i in range(len(preferences[0]))] 48 | + [f"focus_{i}" for i in range(len(focus_dir[0]))] 49 | + [f"ci_{k}" for k in logged_keys] 50 | ) 51 | 52 | self.log.insert_many(data, self.data_labels) 53 | return {} 54 | 55 | 56 | class SQLiteLog: 57 | def __init__(self, timeout=300): 58 | """Creates a log instance, but does not connect it to any db.""" 59 | self.is_connected = False 60 | self.db = None 61 | self.timeout = timeout 62 | 63 | def connect(self, db_path: str): 64 | """Connects to db_path 65 | 66 | Parameters 67 | ---------- 68 | db_path: str 69 | The sqlite3 database path. If it does not exist, it will be created. 
70 |         """
71 |         self.db = sqlite3.connect(db_path, timeout=self.timeout)
72 |         cur = self.db.cursor()
73 |         self._has_results_table = len(
74 |             cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall()
75 |         )
76 |         cur.close()
77 | 
78 |     def _make_results_table(self, types, names):
79 |         type_map = {str: "text", float: "real", int: "real"}  # ints are stored with "real" affinity as well
80 |         col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names))
81 |         cur = self.db.cursor()
82 |         cur.execute(f"create table results ({col_str})")
83 |         self._has_results_table = True
84 |         cur.close()
85 | 
86 |     def insert_many(self, rows, column_names):
87 |         assert all(
88 |             [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]]
89 |         ), "rows must only contain scalars"
90 |         if not self._has_results_table:
91 |             self._make_results_table([type(i) for i in rows[0]], column_names)
92 |         cur = self.db.cursor()
93 |         cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows)  # nosec
94 |         cur.close()
95 |         self.db.commit()
96 | 
97 |     def __del__(self):
98 |         if self.db is not None:
99 |             self.db.close()
100 | 
101 | 
102 | def read_all_results(path):
103 |     # E402: module level import not at top of file, but pandas is an optional dependency
104 |     import pandas as pd  # noqa: E402
105 | 
106 |     num_workers = len([f for f in os.listdir(path) if f.startswith("generated_objs")])
107 |     dfs = [
108 |         pd.read_sql_query("SELECT * FROM results", sqlite3.connect(f"file:{path}/generated_objs_{i}.db?mode=ro"))
109 |         for i in range(num_workers)
110 |     ]
111 |     return pd.concat(dfs).sort_index().reset_index(drop=True)
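A short sketch of reading the logs back (the run directory is hypothetical; one generated_objs_{wid}.db file is produced per DataLoader worker):

    from gflownet.utils.sqlite_log import read_all_results

    df = read_all_results("./logs/my_run")  # hypothetical log_dir given to SQLiteLogHook
    print(df[["smi", "r"]].head())  # logged object representations and their rewards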
112 | -------------------------------------------------------------------------------- /src/gflownet/utils/transforms.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | 
4 | from gflownet import LogScalar
5 | 
6 | 
7 | def to_logreward(reward: Tensor) -> LogScalar:
8 |     return LogScalar(reward.squeeze().clamp(min=1e-30).log())
9 | 
10 | 
11 | def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) -> Tensor:
12 |     """Thermometer encoding of a scalar quantity.
13 | 
14 |     Parameters
15 |     ----------
16 |     v: Tensor
17 |         Value(s) to encode. Can be any shape.
18 |     n_bins: int
19 |         The number of dimensions to encode the values into
20 |     vmin: float
21 |         The smallest value, below which the encoding is equal to torch.zeros(n_bins)
22 |     vmax: float
23 |         The largest value, beyond which the encoding is equal to torch.ones(n_bins)
24 |     Returns
25 |     -------
26 |     encoding: Tensor
27 |         The encoded values, shape: `v.shape + (n_bins,)`
28 |     """
29 |     bins = torch.linspace(vmin, vmax, n_bins)
30 |     gap = bins[1] - bins[0]
31 |     assert gap > 0, "vmin and vmax must be different"
32 |     return (v[..., None] - bins.reshape((1,) * v.ndim + (-1,))).clamp(0, gap.item()) / gap
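A quick worked example: with n_bins=5 over [0, 1] the bins sit at 0, 0.25, 0.5, 0.75 and 1.0, and each bin fills linearly over one gap:

    import torch
    from gflownet.utils.transforms import thermometer

    thermometer(torch.tensor(0.5), n_bins=5)  # -> tensor([1.0, 1.0, 0.0, 0.0, 0.0])
    thermometer(torch.tensor(0.6), n_bins=5)  # -> tensor([1.0, 1.0, 0.4, 0.0, 0.0]), up to float precision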
33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/gflownet/d04f9e1b23310f5442bede61e18b22ddfe5857d7/tests/__init__.py --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /tests/test_envs.py: --------------------------------------------------------------------------------
1 | import base64
2 | import pickle
3 | 
4 | import networkx as nx
5 | import pytest
6 | 
7 | from gflownet.algo.trajectory_balance import TrajectoryBalance
8 | from gflownet.config import Config
9 | from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext
10 | from gflownet.envs.graph_building_env import ActionIndex, GraphBuildingEnv, GraphBuildingEnvContext
11 | from gflownet.envs.mol_building_env import MolBuildingEnvContext
12 | from gflownet.models import bengio2021flow
13 | 
14 | 
15 | def build_two_node_states(ctx: GraphBuildingEnvContext):
16 |     # TODO: This is actually fairly generic code that will probably be reused by other tests in the future.
17 |     # Having a proper class to handle graph-indexed hash maps would probably be good.
18 |     graph_cache: dict[str, tuple[nx.Graph, int]] = {}
19 |     graph_by_idx = {}
20 |     _graph_cache_buckets = {}
21 | 
22 |     # We're enumerating all states of length two, but we could've just as well randomly sampled
23 |     # some states.
24 |     env = GraphBuildingEnv()
25 | 
26 |     def g2h(g):
27 |         gc = g.to_directed()
28 |         for e in gc.edges:
29 |             gc.edges[e]["v"] = (
30 |                 str(gc.edges[e].get(f"{e[0]}_attach", -1)) + "_" + str(gc.edges[e].get(f"{e[1]}_attach", -1))
31 |             )
32 |         h = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(gc, "v", "v")
33 |         if h not in _graph_cache_buckets:
34 |             _graph_cache_buckets[h] = [g]
35 |             return h + "_0"
36 |         else:
37 |             bucket = _graph_cache_buckets[h]
38 |             for i, gp in enumerate(bucket):
39 |                 if nx.algorithms.isomorphism.is_isomorphic(g, gp, lambda a, b: a == b, lambda a, b: a == b):
40 |                     return h + "_" + str(i)
41 |             # Nothing was isomorphic
42 |             bucket.append(g)
43 |             return h + "_" + str(len(bucket) - 1)
44 | 
45 |     mdp_graph = nx.DiGraph()
46 |     mdp_graph.add_node(0)
47 |     graph_by_idx[0] = env.new()
48 | 
49 |     def expand(s, idx):
50 |         # Recursively expand all children of s
51 |         gd = ctx.graph_to_Data(s)
52 |         action_masks = [getattr(gd, gat.mask_name) for gat in ctx.action_type_order]
53 |         for at, mask in enumerate(action_masks):
54 |             if at == 0:  # Ignore Stop action
55 |                 continue
56 |             nz = mask.nonzero()
57 |             for i in nz:  # Only expand non-masked legal actions
58 |                 aidx = ActionIndex(at, i[0].item(), i[1].item())
59 |                 ga = ctx.ActionIndex_to_GraphAction(gd, aidx)
60 |                 sp = env.step(s, ga)
61 |                 h = g2h(sp)
62 |                 if h in graph_cache:
63 |                     idxp = graph_cache[h][1]
64 |                 else:
65 |                     idxp = len(mdp_graph)
66 |                     mdp_graph.add_node(idxp)
67 |                     graph_cache[h] = (sp, idxp)
68 |                     graph_by_idx[idxp] = sp
69 |                     expand(sp, idxp)
70 |                 mdp_graph.add_edge(idx, idxp)
71 | 
72 |     expand(graph_by_idx[0], 0)
73 |     return [graph_by_idx[i] for i in list(nx.topological_sort(mdp_graph))]
74 | 
75 | 
76 | def get_frag_env_ctx() -> FragMolBuildingEnvContext:
77 |     return FragMolBuildingEnvContext(max_frags=2, fragments=bengio2021flow.FRAGMENTS[:20])
78 | 
79 | 
80 | def get_atom_env_ctx() -> MolBuildingEnvContext:
81 |     return MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2)
82 | 
83 | 
84 | @pytest.fixture
85 | def two_node_states_frags(request):
86 |     data = request.config.cache.get("frag_env/two_node_states", None)
87 |     if data is None:
88 |         data = build_two_node_states(get_frag_env_ctx())
89 |         # pytest caches through JSON so we have to make a clean enough string
90 |         request.config.cache.set("frag_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode())
91 |     else:
92 |         data = pickle.loads(base64.b64decode(data))
93 |     return data
94 | 
95 | 
96 | @pytest.fixture
97 | def two_node_states_atoms(request):
98 |     data = request.config.cache.get("atom_env/two_node_states", None)
99 |     if data is None:
100 |         data = build_two_node_states(get_atom_env_ctx())
101 |         # pytest caches through JSON so we have to make a clean enough string
102 |         request.config.cache.set("atom_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode())
103 |     else:
104 |         data = pickle.loads(base64.b64decode(data))
105 |     return data
106 | 
107 | 
108 | def _test_backwards_action_mask_equivalence(two_node_states: list[nx.Graph], ctx: GraphBuildingEnvContext) -> None:
109 |     """This tests that FragMolBuildingEnvContext implements backwards action masks correctly. It treats
110 |     GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is
111 |     a different number of actions leading to the parents of any state.
112 | """ 113 | env = GraphBuildingEnv() 114 | for i in range(1, len(two_node_states)): 115 | g = two_node_states[i] 116 | n = env.count_backward_transitions(g, check_idempotent=False) 117 | nm = 0 118 | gd = ctx.graph_to_Data(g) 119 | for u, k in enumerate(ctx.bck_action_type_order): 120 | m = getattr(gd, k.mask_name) 121 | nm += m.sum() 122 | if n != nm: 123 | raise ValueError() 124 | 125 | 126 | def _test_backwards_action_mask_equivalence_ipa(two_node_states: list[nx.Graph], ctx: GraphBuildingEnvContext) -> None: 127 | """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats 128 | GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is 129 | a different number of actions leading to the parents of any state. 130 | 131 | This test also accounts for idempotent actions. 132 | """ 133 | env = GraphBuildingEnv() 134 | cfg = Config() 135 | cfg.algo.max_nodes = 2 136 | algo = TrajectoryBalance(env, ctx, cfg) 137 | for i in range(1, len(two_node_states)): 138 | g = two_node_states[i] 139 | n = env.count_backward_transitions(g, check_idempotent=True) 140 | gd = ctx.graph_to_Data(g) 141 | # To check that we're computing masks correctly, we need to check that there is the same 142 | # number of idempotent action classes, i.e. groups of actions that lead to the same parent. 143 | equivalence_classes: list[list[tuple[int, int, int]]] = [] 144 | for u, k in enumerate(ctx.bck_action_type_order): 145 | m = getattr(gd, k.mask_name) 146 | for aidx in m.nonzero(): 147 | aidx = ActionIndex(u, aidx[0].item(), aidx[1].item()) 148 | for c in equivalence_classes: 149 | # Here `a` could have been added in another equivalence class by 150 | # get_idempotent_actions. If so, no need to check it. 151 | if aidx in c: 152 | break 153 | else: 154 | ga = ctx.ActionIndex_to_GraphAction(gd, aidx, fwd=False) 155 | gp = env.step(g, ga) 156 | # TODO: It is a bit weird that get_idempotent_actions is in an algo class, 157 | # probably also belongs in a graph utils file. 
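                    # get_idempotent_actions returns the backward actions of g that
                    # lead to the same parent gp (including aidx itself), so each
                    # equivalence class below is counted exactly once.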
158 | ipa = algo.get_idempotent_actions(g, gd, gp, ga) 159 | equivalence_classes.append(ipa) 160 | if n != len(equivalence_classes): 161 | raise ValueError() 162 | 163 | 164 | def test_backwards_action_mask_equivalence_frag(two_node_states_frags): 165 | _test_backwards_action_mask_equivalence(two_node_states_frags, get_frag_env_ctx()) 166 | 167 | 168 | def test_backwards_action_mask_equivalence_ipa_frag(two_node_states_frags): 169 | _test_backwards_action_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx()) 170 | 171 | 172 | def test_backwards_action_mask_equivalence_atom(two_node_states_atoms): 173 | _test_backwards_action_mask_equivalence(two_node_states_atoms, get_atom_env_ctx()) 174 | 175 | 176 | def test_backwards_action_mask_equivalence_ipa_atom(two_node_states_atoms): 177 | _test_backwards_action_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx()) 178 | -------------------------------------------------------------------------------- /tests/test_graph_building_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 3 | from torch_geometric.data import Batch, Data 4 | 5 | from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType 6 | 7 | 8 | def make_test_cat(): 9 | batch = Batch.from_data_list( 10 | [ 11 | Data(x=torch.ones((2, 8)), y=torch.ones((1, 8))), 12 | Data(x=torch.ones((2, 8)), y=torch.ones((1, 8))), 13 | Data(x=torch.ones((2, 8)), y=torch.ones((0, 8))), 14 | ], 15 | follow_batch=["y"], 16 | ) 17 | cat = GraphActionCategorical( 18 | # Let's use arange to have different logit values 19 | batch, 20 | raw_logits=[ 21 | torch.arange(3).reshape((3, 1)).float(), 22 | torch.arange(6 * 4).reshape((6, 4)).float(), 23 | torch.arange(2 * 3).reshape((2, 3)).float(), 24 | ], 25 | types=[GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.AddEdge], 26 | keys=[None, "x", "y"], 27 | ) 28 | return cat 29 | 30 | 31 | def test_batch(): 32 | cat = make_test_cat() 33 | assert (cat.batch[0] == torch.tensor([0, 1, 2])).all() 34 | assert (cat.batch[1] == torch.tensor([0, 0, 1, 1, 2, 2])).all() 35 | assert (cat.batch[2] == torch.tensor([0, 1])).all() 36 | 37 | 38 | def test_slice(): 39 | cat = make_test_cat() 40 | assert (cat.slice[0] == torch.tensor([0, 1, 2, 3])).all() 41 | assert (cat.slice[1] == torch.tensor([0, 2, 4, 6])).all() 42 | assert (cat.slice[2] == torch.tensor([0, 1, 2, 2])).all() 43 | 44 | 45 | def test_logsoftmax(): 46 | cat = make_test_cat() 47 | ls = cat.logsoftmax() 48 | # There are 3 graphs in the batch, so the total probability should be 3 49 | assert torch.isclose(sum([i.exp().sum() for i in ls]), torch.tensor(3.0)) 50 | 51 | 52 | def test_logsoftmax_grad(): 53 | # Purposefully large values to test extremal behaviors 54 | logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float() 55 | logits.requires_grad_(True) 56 | batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"]) 57 | cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None]) 58 | cat._epsilon = 0 59 | gac_softmax = cat.logsoftmax() 60 | torch_softmax = torch.nn.functional.log_softmax(logits, dim=1) 61 | (grad_gac,) = torch.autograd.grad(gac_softmax[0].sum() + gac_softmax[1].sum(), logits, retain_graph=True) 62 | (grad_torch,) = torch.autograd.grad(torch_softmax.sum(), logits) 63 | assert torch.isclose(grad_gac, grad_torch).all() 64 | 65 | 66 | def test_logsumexp(): 67 | cat = 
make_test_cat()
68 |     totals = torch.tensor(
69 |         [
70 |             # Plug in the arange values for each graph
71 |             torch.logsumexp(torch.tensor([0.0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2]), 0),
72 |             torch.logsumexp(torch.tensor([1.0, 8, 9, 10, 11, 12, 13, 14, 15, 3, 4, 5]), 0),
73 |             torch.logsumexp(torch.tensor([2.0, 16, 17, 18, 19, 20, 21, 22, 23]), 0),
74 |         ]
75 |     )
76 |     assert torch.isclose(cat.logsumexp(), totals).all()
77 | 
78 | 
79 | def test_logsumexp_grad():
80 |     # Purposefully large values to test extremal behaviors
81 |     logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float()
82 |     logits.requires_grad_(True)
83 |     batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"])
84 |     cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None])
85 |     cat._epsilon = 0
86 |     (grad_gac,) = torch.autograd.grad(cat.logsumexp(), logits, retain_graph=True)
87 |     (grad_torch,) = torch.autograd.grad(torch.logsumexp(logits, dim=1), logits)
88 |     assert torch.isclose(grad_gac, grad_torch).all()
89 | 
90 | 
91 | def test_sample():
92 |     # Let's just make sure we can sample and compute logprobs without error
93 |     cat = make_test_cat()
94 |     actions = cat.sample()
95 |     logprobs = cat.log_prob(actions)
96 |     assert logprobs is not None
97 | 
98 | 
99 | def test_argmax():
100 |     cat = make_test_cat()
101 |     # The AddNode head has the most actions and two rows per graph, so for every graph the argmax
102 |     # is (1, 1, 3): action type 1 (AddNode), row 1 (larger values due to arange), column 3
103 |     # (largest value due to arange)
104 |     assert cat.argmax(cat.logits) == [(1, 1, 3), (1, 1, 3), (1, 1, 3)]
105 | 
106 | 
107 | def test_log_prob():
108 |     cat = make_test_cat()
109 |     logprobs = cat.logsoftmax()
110 |     actions = [[0, 0, 0], [2, 0, 2], [1, 1, 3]]
111 |     correct_lp = torch.stack([logprobs[t][row + cat.slice[t][i], col] for i, (t, row, col) in enumerate(actions)])
112 |     assert (cat.log_prob(actions) == correct_lp).all()
113 | 
114 |     actions = [[1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3]]
115 |     batch = torch.tensor([1, 1, 1, 1])
116 |     correct_lp = torch.stack([logprobs[t][row + cat.slice[t][i], col] for i, (t, row, col) in zip(batch, actions)])
117 |     assert (cat.log_prob(actions, batch=batch) == correct_lp).all()
118 | 
119 |     actions = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
120 |     correct_lp = torch.arange(3)
121 |     assert (cat.log_prob(actions, logprobs=cat.logits) == correct_lp).all()
122 | 
123 | 
124 | def test_entropy():
125 |     cat = make_test_cat()
126 |     cat.entropy()
127 | -------------------------------------------------------------------------------- /tests/test_subtb.py: --------------------------------------------------------------------------------
1 | from functools import reduce
2 | 
3 | import networkx as nx
4 | import numpy as np
5 | import torch
6 | 
7 | from gflownet.algo.trajectory_balance import subTB
8 | from gflownet.envs.frag_mol_env import NCounter
9 | 
10 | 
11 | def subTB_ref(P_F, P_B, F):
12 |     h = F.shape[0] - 1
13 |     assert P_F.shape == P_B.shape == (h,)
14 |     assert F.ndim == 1
15 | 
16 |     dtype = reduce(torch.promote_types, [P_F.dtype, P_B.dtype, F.dtype])
17 |     D = torch.zeros(h, h, dtype=dtype)  # D[i, j] is the SubTB residual for the sub-trajectory from state i to state j + 1
18 |     for i in range(h):
19 |         for j in range(i, h):
20 |             D[i, j] = F[i] - F[j + 1]
21 |             D[i, j] = D[i, j] + P_F[i : j + 1].sum()
22 |             D[i, j] = D[i, j] - P_B[i : j + 1].sum()
23 |     return D
24 | 
25 | 
26 | def test_subTB():
27 |     for T in range(5, 20):
29 |         P_F = torch.randint(1, 10, (T,))
30 |         P_B = torch.randint(1, 10, (T,))
31 |         F =
torch.randint(1, 10, (T + 1,)) 32 | assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all() 33 | 34 | 35 | def test_n(): 36 | n = NCounter() 37 | x = 0 38 | for i in range(1, 10): 39 | x += np.log(i) 40 | assert np.isclose(n.lfac(i), x) 41 | 42 | assert np.isclose(n.lcomb(5, 2), np.log(10)) 43 | 44 | 45 | def test_g1(): 46 | n = NCounter() 47 | g = nx.Graph() 48 | for i in range(3): 49 | g.add_node(i) 50 | g.add_edge(0, 1) 51 | g.add_edge(1, 2) 52 | rg = n.root_tree(g, 0) 53 | assert n.f(rg, 0) == 0 54 | rg = n.root_tree(g, 2) 55 | assert n.f(rg, 2) == 0 56 | rg = n.root_tree(g, 1) 57 | assert np.isclose(n.f(rg, 1), np.log(2)) 58 | 59 | assert np.isclose(n(g), np.log(4)) 60 | 61 | 62 | def test_g(): 63 | n = NCounter() 64 | g = nx.Graph() 65 | for i in range(3): 66 | g.add_node(i) 67 | g.add_edge(0, 1) 68 | g.add_edge(1, 2, weight=2) 69 | rg = n.root_tree(g, 0) 70 | assert n.f(rg, 0) == 0 71 | rg = n.root_tree(g, 2) 72 | assert np.isclose(n.f(rg, 2), np.log(2)) 73 | rg = n.root_tree(g, 1) 74 | assert np.isclose(n.f(rg, 1), np.log(3)) 75 | 76 | assert np.isclose(n(g), np.log(6)) 77 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py3{10}, report 3 | 4 | [testenv] 5 | commands = pytest 6 | skip_install = true 7 | depends = 8 | report: py3{10} 9 | setenv = 10 | py3{10}: COVERAGE_FILE = .coverage.{envname} 11 | install_command = 12 | python -m pip install -U {opts} {packages} --find-links https://data.pyg.org/whl/torch-2.1.2+cpu.html 13 | deps = 14 | py310: -r requirements/dev-3.10.txt 15 | 16 | 17 | [testenv:report] 18 | deps = coverage 19 | skip_install = true 20 | commands = 21 | coverage combine 22 | coverage report --fail-under=0 23 | 24 | [testenv:style] 25 | deps = 26 | types-setuptools 27 | pre-commit 28 | ruff 29 | isort 30 | mypy 31 | bandit[toml] 32 | black 33 | skip_install = true 34 | commands = pre-commit run --all-files --show-diff-on-failure 35 | --------------------------------------------------------------------------------